Skip to content

Commit

Permalink
Update upscale button text/enabled based on context #1269
Browse files Browse the repository at this point in the history
  • Loading branch information
Acly committed Nov 17, 2024
1 parent 135c6a8 commit 4e9bab7
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 11 deletions.
25 changes: 19 additions & 6 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def upscale_image(self):
eventloop.run(_report_errors(self, self._enqueue_job(job, inputs)))

self._doc.resize(job.params.bounds.extent)
self.upscale.can_generate = False
self.upscale.set_in_progress(True)
self.upscale.target_extent_changed.emit(self.upscale.target_extent)

def estimate_cost(self, kind=JobKind.diffusion):
Expand Down Expand Up @@ -482,7 +482,7 @@ def handle_message(self, message: ClientMessage):

def _finish_job(self, job: Job, event: ClientEvent):
if job.kind is JobKind.upscaling:
self.upscale.can_generate = True
self.upscale.set_in_progress(False)

if event is ClientEvent.finished:
self.jobs.notify_finished(job)
Expand Down Expand Up @@ -757,7 +757,7 @@ class UpscaleParams(NamedTuple):

class UpscaleWorkspace(QObject, ObservableProperties):
upscaler = Property("", persist=True)
factor = Property(2.0, persist=True)
factor = Property(2.0, persist=True, setter="_set_factor")
use_diffusion = Property(True, persist=True)
strength = Property(0.3, persist=True)
unblur_strength = Property(1, persist=True)
Expand All @@ -774,12 +774,11 @@ class UpscaleWorkspace(QObject, ObservableProperties):
can_generate_changed = pyqtSignal(bool)
modified = pyqtSignal(QObject, str)

_model: Model

def __init__(self, model: Model):
super().__init__()
self._model = model
self.factor_changed.connect(lambda _: self.target_extent_changed.emit(self.target_extent))
self._in_progress = False
self.use_diffusion_changed.connect(self._update_can_generate)
self._init_model()
model._connection.models_changed.connect(self._init_model)

Expand All @@ -788,6 +787,20 @@ def _init_model(self):
if self.upscaler not in client.models.upscalers:
self.upscaler = client.models.default_upscaler

def set_in_progress(self, in_progress: bool):
self._in_progress = in_progress
self._update_can_generate()

def _set_factor(self, value: float):
if self._factor != value:
self._factor = value
self.factor_changed.emit(value)
self.target_extent_changed.emit(self.target_extent)
self._update_can_generate()

def _update_can_generate(self):
self.can_generate = not self._in_progress and (self.factor > 1.0 or self.use_diffusion)

@property
def target_extent(self):
return self._model.document.extent * self.factor
Expand Down
7 changes: 7 additions & 0 deletions ai_diffusion/ui/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(self):
layout.addLayout(model_layout)

self.factor_widget = FactorWidget(self)
self.factor_widget.value_changed.connect(self._update_factor)
layout.addWidget(self.factor_widget)

self.refinement_checkbox = QGroupBox(_("Refine upscaled image"), self)
Expand Down Expand Up @@ -282,6 +283,12 @@ def _update_prompt(self):
self.prompt_warning.hide()
set_text_clipped(self.prompt_label, text, padding=padding)

def _update_factor(self):
if self.factor_widget.value == 1.0 and self.model.upscale.use_diffusion:
self.upscale_button.operation = _("Refine")
else:
self.upscale_button.operation = _("Upscale")


def _upscaler_order(filename: str):
return {
Expand Down
19 changes: 14 additions & 5 deletions ai_diffusion/ui/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,22 +739,31 @@ def _create_action(self, name: str, workspace: Workspace):

class GenerateButton(QPushButton):
model: Model
operation: str
_operation: str
_kind: JobKind
_cost: int = 0
_cost_icon: QIcon

def __init__(self, kind: JobKind, parent: QWidget):
super().__init__(parent)
self.model = root.active_model
self.operation = _("Generate")
self._operation = _("Generate")
self._kind = kind
self._cost_icon = theme.icon("interstice")
self.setAttribute(Qt.WidgetAttribute.WA_Hover)

@property
def operation(self):
return self._operation

@operation.setter
def operation(self, value: str):
self._operation = value
self.update()

def minimumSizeHint(self):
fm = self.fontMetrics()
return QSize(fm.width(self.operation) + 40, 12 + int(1.3 * fm.height()))
return QSize(fm.width(self._operation) + 40, 12 + int(1.3 * fm.height()))

def enterEvent(self, a0: QEvent | None):
if client := root.connection.client_if_connected:
Expand All @@ -776,12 +785,12 @@ def paintEvent(self, a0: QPaintEvent | None) -> None:
is_hover = int(opt.state) & QStyle.StateFlag.State_MouseOver
element = QStyle.PrimitiveElement.PE_PanelButtonCommand
vcenter = Qt.AlignmentFlag.AlignVCenter
content_width = fm.width(self.operation) + 5 + pixmap.width()
content_width = fm.width(self._operation) + 5 + pixmap.width()
content_rect = rect.adjusted(int(0.5 * (rect.width() - content_width)), 0, 0, 0)
style.drawPrimitive(element, opt, painter, self)
style.drawItemPixmap(painter, content_rect, vcenter, pixmap)
content_rect = content_rect.adjusted(pixmap.width() + 5, 0, 0, 0)
style.drawItemText(painter, content_rect, vcenter, self.palette(), True, self.operation)
style.drawItemText(painter, content_rect, vcenter, self.palette(), True, self._operation)

if is_hover and self._cost > 0:
cost_width = fm.width(str(self._cost))
Expand Down

0 comments on commit 4e9bab7

Please sign in to comment.