Skip to content

Commit

Permalink
Add a grace period between detecting a change and triggering generati…
Browse files Browse the repository at this point in the history
…on in live preview (#1412)

* Add a grace period between detecting a change and triggering generation in live preview
* This will prevent some of the "useless" generations, e.g. from the very start of
  the brush stroke
* Period is configurable in settings; setting the default to 0 to
  preserve the existing behaviour
This at least partially addresses/follows the discussions in #628 and #1248
* Don't exit live generation loop when switching documents
  • Loading branch information
modelflat authored Nov 29, 2024
1 parent 9902368 commit 81f0195
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 17 deletions.
41 changes: 24 additions & 17 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from dataclasses import replace
from pathlib import Path
from enum import Enum
from typing import Any, NamedTuple
import time
from typing import NamedTuple
from PyQt5.QtCore import QObject, QUuid, pyqtSignal, Qt
from PyQt5.QtGui import QImage, QPainter, QColor, QBrush
from PyQt5.QtGui import QPainter, QColor, QBrush
import uuid

from . import eventloop, workflow, util
Expand All @@ -28,7 +29,7 @@
from .connection import Connection
from .properties import Property, ObservableProperties
from .jobs import Job, JobKind, JobParams, JobQueue, JobState, JobRegion
from .control import ControlLayer, ControlLayerList
from .control import ControlLayer
from .region import Region, RegionLink, RootRegion, process_regions, get_region_inpaint_mask
from .resources import ControlMode
from .resolution import compute_bounds, compute_relative_bounds
Expand Down Expand Up @@ -311,9 +312,10 @@ def estimate_cost(self, kind=JobKind.diffusion):
return 0

def generate_live(self):
eventloop.run(_report_errors(self, self._generate_live()))
input, job_params = self._prepare_live_workflow()
eventloop.run(_report_errors(self, self._generate_live(input, job_params)))

async def _generate_live(self, last_input: WorkflowInput | None = None):
def _prepare_live_workflow(self):
strength = self.live.strength
workflow_kind = WorkflowKind.generate if strength == 1.0 else WorkflowKind.refine
client = self._connection.client
Expand Down Expand Up @@ -361,13 +363,12 @@ async def _generate_live(self, last_input: WorkflowInput | None = None):
inpaint=inpaint if mask else None,
is_live=True,
)
if input != last_input:
self.clear_error()
params = JobParams(bounds, conditioning.positive, regions=job_regions)
await self.enqueue_jobs(input, JobKind.live_preview, params)
return input
params = JobParams(bounds, conditioning.positive, regions=job_regions)
return input, params

return None
async def _generate_live(self, input: WorkflowInput, job_params: JobParams):
self.clear_error()
await self.enqueue_jobs(input, JobKind.live_preview, job_params)

async def _generate_custom(self, previous_input: WorkflowInput | None):
if self.workspace is not Workspace.custom or not self.document.is_active:
Expand Down Expand Up @@ -856,6 +857,7 @@ class LiveWorkspace(QObject, ObservableProperties):

_model: Model
_last_input: WorkflowInput | None = None
_last_change: float = 0
_result: Image | None = None
_result_composition: Image | None = None
_result_params: JobParams | None = None
Expand Down Expand Up @@ -902,12 +904,17 @@ def handle_job_finished(self, job: Job):
eventloop.run(_report_errors(self._model, self._continue_generating()))

async def _continue_generating(self):
while self.is_active and self._model.document.is_active:
new_input = await self._model._generate_live(self._last_input)
if new_input is not None: # frame was scheduled
self._last_input = new_input
return
# no changes in input data
while self.is_active:
if self._model.document.is_active:
new_input, job_params = self._model._prepare_live_workflow()
if self._last_input != new_input:
now = time.monotonic()
if self._last_change + settings.live_redraw_grace_period <= now:
await self._model._generate_live(new_input, job_params)
self._last_input = new_input
return
else:
self._last_change = time.monotonic()
await asyncio.sleep(self._poll_rate)

def apply_result(self, layer_only=False):
Expand Down
7 changes: 7 additions & 0 deletions ai_diffusion/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ class Settings(QObject):
_("Pick a new seed after copying the result to the canvas in Live mode"),
)

live_redraw_grace_period: float
_live_redraw_grace_period = Setting(
_("Live: Redraw grace period"),
0.0,
_("How long to delay scheduling the live preview job for after a change is made"),
)

prompt_translation: str
_prompt_translation = Setting(
_("Prompt Translation"),
Expand Down
4 changes: 4 additions & 0 deletions ai_diffusion/ui/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,10 @@ def __init__(self):
self.add("auto_preview", SwitchSetting(S._auto_preview, parent=self))
self.add("show_steps", SwitchSetting(S._show_steps, parent=self))
self.add("new_seed_after_apply", SwitchSetting(S._new_seed_after_apply, parent=self))
self.add(
"live_redraw_grace_period",
SliderSetting(S._live_redraw_grace_period, self, 0.0, 3.0, "{} s"),
)
self.add("debug_dump_workflow", SwitchSetting(S._debug_dump_workflow, parent=self))

languages = [(lang.name, lang.id) for lang in Localization.available]
Expand Down

0 comments on commit 81f0195

Please sign in to comment.