Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add task manager #16570

Open
wants to merge 1 commit into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions modules/call_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import html
import time

from modules import shared, progress, errors, devices, fifo_lock, profiling
from modules import shared, progress, errors, devices, fifo_lock, profiling, manager

queue_lock = fifo_lock.FIFOLock()

Expand Down Expand Up @@ -34,7 +34,7 @@ def f(*args, **kwargs):
progress.start_task(id_task)

try:
res = func(*args, **kwargs)
res = manager.task.run_and_wait_result(func, *args, **kwargs)
progress.record_results(id_task, res)
finally:
progress.finish_task(id_task)
Expand Down
6 changes: 6 additions & 0 deletions modules/launch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,11 +463,17 @@ def configure_for_tests():
def start():
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
import webui

from modules import manager

if '--nowebui' in sys.argv:
webui.api_only()
else:
webui.webui()

manager.task.main_loop()
return


def dump_sysinfo():
from modules import sysinfo
Expand Down
83 changes: 83 additions & 0 deletions modules/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#
# based on forge's work from https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py
#
# Original author comment:
# This file is the main thread that handles all gradio calls for major t2i or i2i processing.
# Other gradio calls (like those from extensions) are not influenced.
# By using one single thread to process all major calls, model moving is significantly faster.
#
# 2024/09/28 classified,

import random
import string
import threading
import time

from collections import OrderedDict


class Task:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)


class TaskManager:
last_exception = None
pending_tasks = []
finished_tasks = OrderedDict()
lock = None
running = False

def __init__(self):
self.lock = threading.Lock()

def work(self, task):
try:
task.result = task.func(*task.args, **task.kwargs)
except Exception as e:
task.exception = e
self.last_exception = e


def stop(self):
self.running = False


def main_loop(self):
self.running = True
while self.running:
time.sleep(0.01)
if len(self.pending_tasks) > 0:
with self.lock:
task = self.pending_tasks.pop(0)

self.work(task)

self.finished_tasks[task.task_id] = task


def push_task(self, func, *args, **kwargs):
if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
task_id = args[0]
else:
task_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=7))
task = Task(task_id=task_id, func=func, args=args, kwargs=kwargs, result=None, exception=None)
self.pending_tasks.append(task)

return task.task_id


def run_and_wait_result(self, func, *args, **kwargs):
current_id = self.push_task(func, *args, **kwargs)

while True:
time.sleep(0.01)
if current_id in self.finished_tasks:
finished = self.finished_tasks.pop(current_id)
if finished.exception is not None:
raise finished.exception

return finished.result


task = TaskManager()
21 changes: 16 additions & 5 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from modules import timer
from modules import initialize_util
from modules import initialize
from modules import manager
from threading import Thread

startup_timer = timer.startup_timer
startup_timer.record("launcher")
Expand All @@ -14,6 +16,8 @@

initialize.check_versions()

initialize.initialize()


def create_api(app):
from modules.api.api import Api
Expand All @@ -23,12 +27,10 @@ def create_api(app):
return api


def api_only():
def _api_only():
from fastapi import FastAPI
from modules.shared_cmd_options import cmd_opts

initialize.initialize()

app = FastAPI()
initialize_util.setup_middleware(app)
api = create_api(app)
Expand Down Expand Up @@ -83,11 +85,10 @@ def abspath(path):
{"!"*25} Warning {"!"*25}''')


def webui():
def _webui():
from modules.shared_cmd_options import cmd_opts

launch_api = cmd_opts.api
initialize.initialize()

from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks

Expand Down Expand Up @@ -177,6 +178,7 @@ def webui():
print("Stopping server...")
# If we catch a keyboard interrupt, we want to stop the server and exit.
shared.demo.close()
manager.task.stop()
break

# disable auto launch webui in browser for subsequent UI Reload
Expand All @@ -193,10 +195,19 @@ def webui():
initialize.initialize_rest(reload_script_modules=True)


def api_only():
Thread(target=_api_only, daemon=True).start()


def webui():
Thread(target=_webui, daemon=True).start()

if __name__ == "__main__":
from modules.shared_cmd_options import cmd_opts

if cmd_opts.nowebui:
api_only()
else:
webui()

manager.task.main_loop()
Loading