Skip to content

Commit

Permalink
task manager added
Browse files Browse the repository at this point in the history
based on https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py

 * classified
 * this way, gc.collect() will work as intended.
  • Loading branch information
wkpark committed Nov 1, 2024
1 parent 1b16c62 commit d4234f6
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 7 deletions.
4 changes: 2 additions & 2 deletions modules/call_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import html
import time

from modules import shared, progress, errors, devices, fifo_lock, profiling
from modules import shared, progress, errors, devices, fifo_lock, profiling, manager

queue_lock = fifo_lock.FIFOLock()

Expand Down Expand Up @@ -34,7 +34,7 @@ def f(*args, **kwargs):
progress.start_task(id_task)

try:
res = func(*args, **kwargs)
res = manager.task.run_and_wait_result(func, *args, **kwargs)
progress.record_results(id_task, res)
finally:
progress.finish_task(id_task)
Expand Down
6 changes: 6 additions & 0 deletions modules/launch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,11 +463,17 @@ def configure_for_tests():
def start():
print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {shlex.join(sys.argv[1:])}")
import webui

from modules import manager

if '--nowebui' in sys.argv:
webui.api_only()
else:
webui.webui()

manager.task.main_loop()
return


def dump_sysinfo():
from modules import sysinfo
Expand Down
83 changes: 83 additions & 0 deletions modules/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#
# based on forge's work from https://github.com/lllyasviel/stable-diffusion-webui-forge/blob/main/modules_forge/main_thread.py
#
# Original author comment:
# This file is the main thread that handles all gradio calls for major t2i or i2i processing.
# Other gradio calls (like those from extensions) are not influenced.
# By using one single thread to process all major calls, model moving is significantly faster.
#
# 2024/09/28 classified,

import random
import string
import threading
import time

from collections import OrderedDict


class Task:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)


class TaskManager:
last_exception = None
pending_tasks = []
finished_tasks = OrderedDict()
lock = None
running = False

def __init__(self):
self.lock = threading.Lock()

def work(self, task):
try:
task.result = task.func(*task.args, **task.kwargs)
except Exception as e:
task.exception = e
self.last_exception = e


def stop(self):
self.running = False


def main_loop(self):
self.running = True
while self.running:
time.sleep(0.01)
if len(self.pending_tasks) > 0:
with self.lock:
task = self.pending_tasks.pop(0)

self.work(task)

self.finished_tasks[task.task_id] = task


def push_task(self, func, *args, **kwargs):
if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"):
task_id = args[0]
else:
task_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=7))
task = Task(task_id=task_id, func=func, args=args, kwargs=kwargs, result=None, exception=None)
self.pending_tasks.append(task)

return task.task_id


def run_and_wait_result(self, func, *args, **kwargs):
current_id = self.push_task(func, *args, **kwargs)

while True:
time.sleep(0.01)
if current_id in self.finished_tasks:
finished = self.finished_tasks.pop(current_id)
if finished.exception is not None:
raise finished.exception

return finished.result


task = TaskManager()
21 changes: 16 additions & 5 deletions webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from modules import timer
from modules import initialize_util
from modules import initialize
from modules import manager
from threading import Thread

startup_timer = timer.startup_timer
startup_timer.record("launcher")
Expand All @@ -14,6 +16,8 @@

initialize.check_versions()

initialize.initialize()


def create_api(app):
from modules.api.api import Api
Expand All @@ -23,12 +27,10 @@ def create_api(app):
return api


def api_only():
def _api_only():
from fastapi import FastAPI
from modules.shared_cmd_options import cmd_opts

initialize.initialize()

app = FastAPI()
initialize_util.setup_middleware(app)
api = create_api(app)
Expand Down Expand Up @@ -83,11 +85,10 @@ def abspath(path):
{"!"*25} Warning {"!"*25}''')


def webui():
def _webui():
from modules.shared_cmd_options import cmd_opts

launch_api = cmd_opts.api
initialize.initialize()

from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks

Expand Down Expand Up @@ -177,6 +178,7 @@ def webui():
print("Stopping server...")
# If we catch a keyboard interrupt, we want to stop the server and exit.
shared.demo.close()
manager.task.stop()
break

# disable auto launch webui in browser for subsequent UI Reload
Expand All @@ -193,10 +195,19 @@ def webui():
initialize.initialize_rest(reload_script_modules=True)


def api_only():
Thread(target=_api_only, daemon=True).start()


def webui():
Thread(target=_webui, daemon=True).start()

if __name__ == "__main__":
from modules.shared_cmd_options import cmd_opts

if cmd_opts.nowebui:
api_only()
else:
webui()

manager.task.main_loop()

0 comments on commit d4234f6

Please sign in to comment.