From d4234f6bce17c260720964565f3cc9d5d07ac231 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 28 Sep 2024 23:19:08 +0900 Subject: [PATCH] task manager added based on https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py * classified * this way, gc.collect() will work as intended. --- modules/call_queue.py | 4 +- modules/launch_utils.py | 6 +++ modules/manager.py | 83 +++++++++++++++++++++++++++++++++++++++++ webui.py | 21 ++++++++--- 4 files changed, 107 insertions(+), 7 deletions(-) create mode 100644 modules/manager.py diff --git a/modules/call_queue.py b/modules/call_queue.py index 555c35312dd..b20badcaf25 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -3,7 +3,7 @@ import html import time -from modules import shared, progress, errors, devices, fifo_lock, profiling +from modules import shared, progress, errors, devices, fifo_lock, profiling, manager queue_lock = fifo_lock.FIFOLock() @@ -34,7 +34,7 @@ def f(*args, **kwargs): progress.start_task(id_task) try: - res = func(*args, **kwargs) + res = manager.task.run_and_wait_result(func, *args, **kwargs) progress.record_results(id_task, res) finally: progress.finish_task(id_task) diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 20c7dc127a7..5c868747e7a 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -463,11 +463,17 @@ def configure_for_tests(): def start(): print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}") import webui + + from modules import manager + if '--nowebui' in sys.argv: webui.api_only() else: webui.webui() + manager.task.main_loop() + return + def dump_sysinfo(): from modules import sysinfo diff --git a/modules/manager.py b/modules/manager.py new file mode 100644 index 00000000000..34c67c6b3cc --- /dev/null +++ b/modules/manager.py @@ -0,0 +1,83 @@ +# +# based on forge's work from https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py +# +# Original author comment: +# This file is the main thread that handles all gradio calls for major t2i or i2i processing. +# Other gradio calls (like those from extensions) are not influenced. +# By using one single thread to process all major calls, model moving is significantly faster. +# +# 2024/09/28 classified, + +import random +import string +import threading +import time + +from collections import OrderedDict + + +class Task: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +class TaskManager: + last_exception = None + pending_tasks = [] + finished_tasks = OrderedDict() + lock = None + running = False + + def __init__(self): + self.lock = threading.Lock() + + def work(self, task): + try: + task.result = task.func(*task.args, **task.kwargs) + except Exception as e: + task.exception = e + self.last_exception = e + + + def stop(self): + self.running = False + + + def main_loop(self): + self.running = True + while self.running: + time.sleep(0.01) + if len(self.pending_tasks) > 0: + with self.lock: + task = self.pending_tasks.pop(0) + + self.work(task) + + self.finished_tasks[task.task_id] = task + + + def push_task(self, func, *args, **kwargs): + if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"): + task_id = args[0] + else: + task_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=7)) + task = Task(task_id=task_id, func=func, args=args, kwargs=kwargs, result=None, exception=None) + self.pending_tasks.append(task) + + return task.task_id + + + def run_and_wait_result(self, func, *args, **kwargs): + current_id = self.push_task(func, *args, **kwargs) + + while True: + time.sleep(0.01) + if current_id in self.finished_tasks: + finished = self.finished_tasks.pop(current_id) + if finished.exception is not None: + raise finished.exception + + return finished.result + + +task = TaskManager() diff --git a/webui.py b/webui.py index 421e3b8334f..398d8355023 100644 --- a/webui.py +++ b/webui.py @@ -6,6 +6,8 @@ from modules import timer from modules import initialize_util from modules import initialize +from modules import manager +from threading import Thread startup_timer = timer.startup_timer startup_timer.record("launcher") @@ -14,6 +16,8 @@ initialize.check_versions() +initialize.initialize() + def create_api(app): from modules.api.api import Api @@ -23,12 +27,10 @@ def create_api(app): return api -def api_only(): +def _api_only(): from fastapi import FastAPI from modules.shared_cmd_options import cmd_opts - initialize.initialize() - app = FastAPI() initialize_util.setup_middleware(app) api = create_api(app) @@ -83,11 +85,10 @@ def abspath(path): {"!"*25} Warning {"!"*25}''') -def webui(): +def _webui(): from modules.shared_cmd_options import cmd_opts launch_api = cmd_opts.api - initialize.initialize() from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks @@ -177,6 +178,7 @@ def webui(): print("Stopping server...") # If we catch a keyboard interrupt, we want to stop the server and exit. shared.demo.close() + manager.task.stop() break # disable auto launch webui in browser for subsequent UI Reload @@ -193,6 +195,13 @@ def webui(): initialize.initialize_rest(reload_script_modules=True) +def api_only(): + Thread(target=_api_only, daemon=True).start() + + +def webui(): + Thread(target=_webui, daemon=True).start() + if __name__ == "__main__": from modules.shared_cmd_options import cmd_opts @@ -200,3 +209,5 @@ def webui(): api_only() else: webui() + + manager.task.main_loop()