Skip to content

Commit

Permalink
Load previously used custom graph from document (somehow)
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Oct 12, 2024
1 parent aed73dc commit 089dc66
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 31 deletions.
53 changes: 36 additions & 17 deletions ai_diffusion/custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,14 @@ class WorkflowCollection(QAbstractListModel):
_icon_remote = theme.icon("web-connection")
_icon_document = theme.icon("file-kra")

loaded = pyqtSignal()

def __init__(self, connection: Connection, folder: Path | None = None):
super().__init__()
self._connection = connection
self._folder = folder or user_data_dir / "workflows"
self._workflows: list[CustomWorkflow] = []
self._pending_workflows: list[tuple[str, WorkflowSource, dict]] = []

self._connection.state_changed.connect(self._handle_connection)
self._connection.workflow_published.connect(self._process_remote_workflow)
Expand All @@ -63,6 +66,10 @@ def _handle_connection(self, state: ConnectionState):
self.clear()

if state is ConnectionState.connected:
for id, source, graph in self._pending_workflows:
self._process_workflow(id, source, graph)
self._pending_workflows.clear()

for file in self._folder.glob("*.json"):
try:
self._process_file(file)
Expand All @@ -72,31 +79,36 @@ def _handle_connection(self, state: ConnectionState):
for wf in self._connection.workflows.keys():
self._process_remote_workflow(wf)

self.loaded.emit()

def _node_inputs(self):
return self._connection.client.models.node_inputs

def _create_workflow(
def _process_workflow(
self, id: str, source: WorkflowSource, graph: dict, path: Path | None = None
):
wf = ComfyWorkflow.import_graph(graph, self._node_inputs())
return CustomWorkflow(id, source, wf, path)

def _process_remote_workflow(self, id: str):
graph = self._connection.workflows[id]
self._process(self._create_workflow(id, WorkflowSource.remote, graph))

def _process_file(self, file: Path):
with file.open("r") as f:
graph = json.load(f)
self._process(self._create_workflow(file.stem, WorkflowSource.local, graph, file))
if self._connection.state is not ConnectionState.connected:
self._pending_workflows.append((id, source, graph))
return

def _process(self, workflow: CustomWorkflow):
comfy_flow = ComfyWorkflow.import_graph(graph, self._node_inputs())
workflow = CustomWorkflow(id, source, comfy_flow, path)
idx = self.find_index(workflow.id)
if idx.isValid():
self._workflows[idx.row()] = workflow
self.dataChanged.emit(idx, idx)
else:
self.append(workflow)
return idx

def _process_remote_workflow(self, id: str):
graph = self._connection.workflows[id]
self._process_workflow(id, WorkflowSource.remote, graph)

def _process_file(self, file: Path):
with file.open("r") as f:
graph = json.load(f)
self._process_workflow(file.stem, WorkflowSource.local, graph, file)

def rowCount(self, parent=QModelIndex()):
return len(self._workflows)
Expand All @@ -121,7 +133,7 @@ def append(self, item: CustomWorkflow):
self.endInsertRows()

def add_from_document(self, id: str, graph: dict):
self.append(self._create_workflow(id, WorkflowSource.document, graph))
self._process_workflow(id, WorkflowSource.document, graph)

def remove(self, id: str):
idx = self.find_index(id)
Expand Down Expand Up @@ -154,7 +166,7 @@ def save_as(self, id: str, graph: dict):
self._folder.mkdir(exist_ok=True)
path = self._folder / f"{id}.json"
path.write_text(json.dumps(graph, indent=2))
self.append(self._create_workflow(id, WorkflowSource.local, graph, path))
self._process_workflow(id, WorkflowSource.local, graph, path)
return id

def import_file(self, filepath: Path):
Expand Down Expand Up @@ -336,12 +348,16 @@ def __init__(self, workflows: WorkflowCollection, generator: ImageGenerator, job

jobs.job_finished.connect(self._handle_job_finished)
workflows.dataChanged.connect(self._update_workflow)
workflows.rowsInserted.connect(self._set_default_workflow)
workflows.loaded.connect(self._set_default_workflow)
self._set_default_workflow()

def _set_default_workflow(self):
if not self.workflow_id and len(self._workflows) > 0:
self.workflow_id = self._workflows[0].id
else:
current_index = self._workflows.find_index(self.workflow_id)
if current_index.isValid():
self._update_workflow(current_index, QModelIndex())

def _update_workflow(self, idx: QModelIndex, _: QModelIndex):
wf = self._workflows[idx.row()]
Expand All @@ -358,10 +374,13 @@ def _set_workflow_id(self, id: str):
self._workflow_id = id
self.workflow_id_changed.emit(id)
self.modified.emit(self, "workflow_id")
self._update_workflow(self._workflows.find_index(id), QModelIndex())
index = self._workflows.find_index(id)
if index.isValid(): # might be invalid when loading document before connecting
self._update_workflow(index, QModelIndex())

def set_graph(self, id: str, graph: dict):
if self._workflows.find(id) is None:
id = "Document Workflow (embedded)"
self._workflows.add_from_document(id, graph)
self.workflow_id = id

Expand Down
22 changes: 19 additions & 3 deletions ai_diffusion/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from PyQt5.QtWidgets import QMessageBox

from .api import InpaintMode, FillMode
from .image import Bounds, Image, ImageCollection, ImageFileFormat
from .image import ImageCollection
from .model import Model, InpaintContext
from .custom_workflow import CustomWorkspace
from .control import ControlLayer, ControlLayerList
from .region import RootRegion, Region
from .jobs import Job, JobKind, JobParams, JobQueue
Expand Down Expand Up @@ -132,7 +133,7 @@ def _save(self):
state["upscale"] = _serialize(model.upscale)
state["live"] = _serialize(model.live)
state["animation"] = _serialize(model.animation)
state["custom"] = _serialize(model.custom)
state["custom"] = _serialize_custom(model.custom)
state["history"] = [asdict(h) for h in self._history]
state["root"] = _serialize(model.regions)
state["control"] = [_serialize(c) for c in model.regions.control]
Expand All @@ -151,7 +152,7 @@ def _load(self, model: Model, state_bytes: bytes):
_deserialize(model.upscale, state.get("upscale", {}))
_deserialize(model.live, state.get("live", {}))
_deserialize(model.animation, state.get("animation", {}))
_deserialize(model.custom, state.get("custom", {}))
_deserialize_custom(model.custom, state.get("custom", {}))
_deserialize(model.regions, state.get("root", {}))
for control_state in state.get("control", []):
_deserialize(model.regions.control.emplace(), control_state)
Expand Down Expand Up @@ -264,6 +265,21 @@ def converter(type, value):
return deserialize(obj, data, converter)


def _serialize_custom(custom: CustomWorkspace):
result = _serialize(custom)
result["workflow_id"] = custom.workflow_id
result["graph"] = custom.graph.root if custom.graph else None
return result


def _deserialize_custom(custom: CustomWorkspace, data: dict[str, Any]):
_deserialize(custom, data)
workflow_id = data.get("workflow_id", "")
graph = data.get("graph", None)
if workflow_id and graph:
custom.set_graph(workflow_id, graph)


def _find_annotation(document, name: str):
if result := document.find_annotation(name):
return result
Expand Down
4 changes: 2 additions & 2 deletions ai_diffusion/ui/custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None):
self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4)
self._widget.valueChanged.connect(self._notify)
self._label = QLabel(self)
self._label.setFixedWidth(40)
self._label.setFixedWidth(32)
self._label.setAlignment(Qt.AlignmentFlag.AlignRight)
layout.addWidget(self._widget)
layout.addWidget(self._label)
Expand Down Expand Up @@ -135,7 +135,7 @@ def __init__(self, param: CustomParam, parent: QWidget | None = None):
self._widget.setMinimumHeight(self._widget.minimumSizeHint().height() + 4)
self._widget.valueChanged.connect(self._notify)
self._label = QLabel(self)
self._label.setFixedWidth(40)
self._label.setFixedWidth(32)
self._label.setAlignment(Qt.AlignmentFlag.AlignRight)
layout.addWidget(self._widget)
layout.addWidget(self._label)
Expand Down
26 changes: 17 additions & 9 deletions tests/test_custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,24 @@ def test_collection(tmp_path: Path):
connection = create_mock_connection(connection_workflows, state=ConnectionState.disconnected)

collection = WorkflowCollection(connection, tmp_path)
events = []

assert len(collection) == 0

def on_loaded():
events.append("loaded")

collection.loaded.connect(on_loaded)
doc_graph = {"0": {"class_type": "D1", "inputs": {}}}
collection.add_from_document("doc1", doc_graph)

connection.state = ConnectionState.connected
assert len(collection) == 3
assert len(collection) == 4
assert events == ["loaded"]
_assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph, file1)
_assert_has_workflow(collection, "file2", WorkflowSource.local, file2_graph, file2)
_assert_has_workflow(collection, "connection1", WorkflowSource.remote, connection_graph)

events = []
_assert_has_workflow(collection, "doc1", WorkflowSource.document, doc_graph)

def on_begin_insert(index, first, last):
events.append(("begin_insert", first))
Expand All @@ -109,15 +119,13 @@ def on_data_changed(start, end):
connection_workflows["connection2"] = connection2_graph
connection.workflow_published.emit("connection2")

assert len(collection) == 4
assert len(collection) == 5
_assert_has_workflow(collection, "connection2", WorkflowSource.remote, connection2_graph)

file1_graph_changed = {"0": {"class_type": "F3", "inputs": {}}}
collection.set_graph(collection.index(0), file1_graph_changed)
collection.set_graph(collection.find_index("file1"), file1_graph_changed)
_assert_has_workflow(collection, "file1", WorkflowSource.local, file1_graph_changed, file1)
assert events == [("begin_insert", 3), "end_insert", ("data_changed", 0)]

collection.add_from_document("doc1", {"0": {"class_type": "D1", "inputs": {}}})
assert events == ["loaded", ("begin_insert", 4), "end_insert", ("data_changed", 1)]

sorted = SortedWorkflows(collection)
assert sorted[0].source is WorkflowSource.document
Expand Down Expand Up @@ -207,7 +215,7 @@ def test_workspace():
}
}
workspace.set_graph("doc1", doc_graph)
assert workspace.workflow_id == "doc1"
assert workspace.workflow_id == "Document Workflow (embedded)"
assert workspace.workflow and workspace.workflow.source is WorkflowSource.document
assert workspace.graph and workspace.graph.node(0).type == "ETN_Parameter"
assert workspace.metadata[0].name == "param2"
Expand Down

0 comments on commit 089dc66

Please sign in to comment.