From 089dc66d256a7e31f2764487b9d828120248bbe1 Mon Sep 17 00:00:00 2001 From: Acly Date: Sat, 12 Oct 2024 20:10:07 +0200 Subject: [PATCH] Load previously used custom graph from document (somehow) --- ai_diffusion/custom_workflow.py | 53 ++++++++++++++++++++---------- ai_diffusion/persistence.py | 22 +++++++++++-- ai_diffusion/ui/custom_workflow.py | 4 +-- tests/test_custom_workflow.py | 26 ++++++++++----- 4 files changed, 74 insertions(+), 31 deletions(-) diff --git a/ai_diffusion/custom_workflow.py b/ai_diffusion/custom_workflow.py index 1f37d32230..804f15a9c1 100644 --- a/ai_diffusion/custom_workflow.py +++ b/ai_diffusion/custom_workflow.py @@ -48,11 +48,14 @@ class WorkflowCollection(QAbstractListModel): _icon_remote = theme.icon("web-connection") _icon_document = theme.icon("file-kra") + loaded = pyqtSignal() + def __init__(self, connection: Connection, folder: Path | None = None): super().__init__() self._connection = connection self._folder = folder or user_data_dir / "workflows" self._workflows: list[CustomWorkflow] = [] + self._pending_workflows: list[tuple[str, WorkflowSource, dict]] = [] self._connection.state_changed.connect(self._handle_connection) self._connection.workflow_published.connect(self._process_remote_workflow) @@ -63,6 +66,10 @@ def _handle_connection(self, state: ConnectionState): self.clear() if state is ConnectionState.connected: + for id, source, graph in self._pending_workflows: + self._process_workflow(id, source, graph) + self._pending_workflows.clear() + for file in self._folder.glob("*.json"): try: self._process_file(file) @@ -72,31 +79,36 @@ def _handle_connection(self, state: ConnectionState): for wf in self._connection.workflows.keys(): self._process_remote_workflow(wf) + self.loaded.emit() + def _node_inputs(self): return self._connection.client.models.node_inputs - def _create_workflow( + def _process_workflow( self, id: str, source: WorkflowSource, graph: dict, path: Path | None = None ): - wf = ComfyWorkflow.import_graph(graph, self._node_inputs()) - return CustomWorkflow(id, source, wf, path) - - def _process_remote_workflow(self, id: str): - graph = self._connection.workflows[id] - self._process(self._create_workflow(id, WorkflowSource.remote, graph)) - - def _process_file(self, file: Path): - with file.open("r") as f: - graph = json.load(f) - self._process(self._create_workflow(file.stem, WorkflowSource.local, graph, file)) + if self._connection.state is not ConnectionState.connected: + self._pending_workflows.append((id, source, graph)) + return - def _process(self, workflow: CustomWorkflow): + comfy_flow = ComfyWorkflow.import_graph(graph, self._node_inputs()) + workflow = CustomWorkflow(id, source, comfy_flow, path) idx = self.find_index(workflow.id) if idx.isValid(): self._workflows[idx.row()] = workflow self.dataChanged.emit(idx, idx) else: self.append(workflow) + return idx + + def _process_remote_workflow(self, id: str): + graph = self._connection.workflows[id] + self._process_workflow(id, WorkflowSource.remote, graph) + + def _process_file(self, file: Path): + with file.open("r") as f: + graph = json.load(f) + self._process_workflow(file.stem, WorkflowSource.local, graph, file) def rowCount(self, parent=QModelIndex()): return len(self._workflows) @@ -121,7 +133,7 @@ def append(self, item: CustomWorkflow): self.endInsertRows() def add_from_document(self, id: str, graph: dict): - self.append(self._create_workflow(id, WorkflowSource.document, graph)) + self._process_workflow(id, WorkflowSource.document, graph) def remove(self, id: str): idx = self.find_index(id) @@ -154,7 +166,7 @@ def save_as(self, id: str, graph: dict): self._folder.mkdir(exist_ok=True) path = self._folder / f"{id}.json" path.write_text(json.dumps(graph, indent=2)) - self.append(self._create_workflow(id, WorkflowSource.local, graph, path)) + self._process_workflow(id, WorkflowSource.local, graph, path) return id def import_file(self, filepath: Path): @@ -336,12 +348,16 @@ def __init__(self, workflows: WorkflowCollection, generator: ImageGenerator, job jobs.job_finished.connect(self._handle_job_finished) workflows.dataChanged.connect(self._update_workflow) - workflows.rowsInserted.connect(self._set_default_workflow) + workflows.loaded.connect(self._set_default_workflow) self._set_default_workflow() def _set_default_workflow(self): if not self.workflow_id and len(self._workflows) > 0: self.workflow_id = self._workflows[0].id + else: + current_index = self._workflows.find_index(self.workflow_id) + if current_index.isValid(): + self._update_workflow(current_index, QModelIndex()) def _update_workflow(self, idx: QModelIndex, _: QModelIndex): wf = self._workflows[idx.row()] @@ -358,10 +374,13 @@ def _set_workflow_id(self, id: str): self._workflow_id = id self.workflow_id_changed.emit(id) self.modified.emit(self, "workflow_id") - self._update_workflow(self._workflows.find_index(id), QModelIndex()) + index = self._workflows.find_index(id) + if index.isValid(): # might be invalid when loading document before connecting + self._update_workflow(index, QModelIndex()) def set_graph(self, id: str, graph: dict): if self._workflows.find(id) is None: + id = "Document Workflow (embedded)" self._workflows.add_from_document(id, graph) self.workflow_id = id diff --git a/ai_diffusion/persistence.py b/ai_diffusion/persistence.py index b01dec15e8..348e62cfbc 100644 --- a/ai_diffusion/persistence.py +++ b/ai_diffusion/persistence.py @@ -8,8 +8,9 @@ from PyQt5.QtWidgets import QMessageBox from .api import InpaintMode, FillMode -from .image import Bounds, Image, ImageCollection, ImageFileFormat +from .image import ImageCollection from .model import Model, InpaintContext +from .custom_workflow import CustomWorkspace from .control import ControlLayer, ControlLayerList from .region import RootRegion, Region from .jobs import Job, JobKind, JobParams, JobQueue @@ -132,7 +133,7 @@ def _save(self): state["upscale"] = _serialize(model.upscale) state["live"] = _serialize(model.live) state["animation"] = _serialize(model.animation) - state["custom"] = _serialize(model.custom) + state["custom"] = _serialize_custom(model.custom) state["history"] = [asdict(h) for h in self._history] state["root"] = _serialize(model.regions) state["control"] = [_serialize(c) for c in model.regions.control] @@ -151,7 +152,7 @@ def _load(self, model: Model, state_bytes: bytes): _deserialize(model.upscale, state.get("upscale", {})) _deserialize(model.live, state.get("live", {})) _deserialize(model.animation, state.get("animation", {})) - _deserialize(model.custom, state.get("custom", {})) + _deserialize_custom(model.custom, state.get("custom", {})) _deserialize(model.regions, state.get("root", {})) for control_state in state.get("control", []): _deserialize(model.regions.control.emplace(), control_state) @@ -264,6 +265,21 @@ def converter(type, value): return deserialize(obj, data, converter) +def _serialize_custom(custom: CustomWorkspace): + result = _serialize(custom) + result["workflow_id"] = custom.workflow_id + result["graph"] = custom.graph.root if custom.graph else None + return result + + +def _deserialize_custom(custom: CustomWorkspace, data: dict[str, Any]): + _deserialize(custom, data) + workflow_id = data.get("workflow_id", "") + graph = data.get("graph", None) + if workflow_id and graph: + custom.set_graph(workflow_id, graph) + + def _find_annotation(document, name: str): if result := document.find_annotation(name): return result diff --git a/ai_diffusion/ui/custom_workflow.py b/ai_diffusion/ui/custom_workflow.py index 5624dbca6b..85def2632b 100644 --- a/ai_diffusion/ui/custom_workflow.py +++ b/ai_diffusion/ui/custom_workflow.py @@ -87,7 +87,7 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None): self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4) self._widget.valueChanged.connect(self._notify) self._label = QLabel(self) - self._label.setFixedWidth(40) + self._label.setFixedWidth(32) self._label.setAlignment(Qt.AlignmentFlag.AlignRight) layout.addWidget(self._widget) layout.addWidget(self._label) @@ -135,7 +135,7 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None): self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4) self._widget.valueChanged.connect(self._notify) self._label = QLabel(self) - self._label.setFixedWidth(40) + self._label.setFixedWidth(32) self._label.setAlignment(Qt.AlignmentFlag.AlignRight) layout.addWidget(self._widget) layout.addWidget(self._label) diff --git a/tests/test_custom_workflow.py b/tests/test_custom_workflow.py index 8679928d79..ebf1693d81 100644 --- a/tests/test_custom_workflow.py +++ b/tests/test_custom_workflow.py @@ -83,14 +83,24 @@ def test_collection(tmp_path: Path): connection = create_mock_connection(connection_workflows, state=ConnectionState.disconnected) collection = WorkflowCollection(connection, tmp_path) + events = [] + assert len(collection) == 0 + + def on_loaded(): + events.append("loaded") + + collection.loaded.connect(on_loaded) + doc_graph = {"0": {"class_type": "D1", "inputs": {}}} + collection.add_from_document("doc1", doc_graph) + connection.state = ConnectionState.connected - assert len(collection) == 3 + assert len(collection) == 4 + assert events == ["loaded"] _assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph, file1) _assert_has_workflow(collection, "file2", WorkflowSource.local, file2_graph, file2) _assert_has_workflow(collection, "connection1", WorkflowSource.remote, connection_graph) - - events = [] + _assert_has_workflow(collection, "doc1", WorkflowSource.document, doc_graph) def on_begin_insert(index, first, last): events.append(("begin_insert", first)) @@ -109,15 +119,13 @@ def on_data_changed(start, end): connection_workflows["connection2"] = connection2_graph connection.workflow_published.emit("connection2") - assert len(collection) == 4 + assert len(collection) == 5 _assert_has_workflow(collection, "connection2", WorkflowSource.remote, connection2_graph) file1_graph_changed = {"0": {"class_type": "F3", "inputs": {}}} - collection.set_graph(collection.index(0), file1_graph_changed) + collection.set_graph(collection.find_index("file1"), file1_graph_changed) _assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph_changed, file1) - assert events == [("begin_insert", 3), "end_insert", ("data_changed", 0)] - - collection.add_from_document("doc1", {"0": {"class_type": "D1", "inputs": {}}}) + assert events == ["loaded", ("begin_insert", 4), "end_insert", ("data_changed", 1)] sorted = SortedWorkflows(collection) assert sorted[0].source is WorkflowSource.document @@ -207,7 +215,7 @@ def test_workspace(): } } workspace.set_graph("doc1", doc_graph) - assert workspace.workflow_id == "doc1" + assert workspace.workflow_id == "Document Workflow (embedded)" assert workspace.workflow and workspace.workflow.source is WorkflowSource.document assert workspace.graph and workspace.graph.node(0).type == "ETN_Parameter" assert workspace.metadata[0].name == "param2"