From 1bb97d6c5da9d20e050d661f117a42f82f2a2a8e Mon Sep 17 00:00:00 2001
From: FeepingCreature
Date: Tue, 16 Apr 2024 21:47:30 +0200
Subject: [PATCH] Add Perturbed-Attention Guidance toggle.

See https://ku-cvlab.github.io/Perturbed-Attention-Guidance/
---
 ai_diffusion/api.py            | 1 +
 ai_diffusion/comfy_workflow.py | 3 +++
 ai_diffusion/style.py          | 8 ++++++++
 ai_diffusion/ui/style.py       | 7 +++++++
 ai_diffusion/workflow.py       | 3 +++
 5 files changed, 22 insertions(+)

diff --git a/ai_diffusion/api.py b/ai_diffusion/api.py
index 657217ec9c..67b25c2a8b 100644
--- a/ai_diffusion/api.py
+++ b/ai_diffusion/api.py
@@ -61,6 +61,7 @@ class CheckpointInput:
     clip_skip: int = 0
     v_prediction_zsnr: bool = False
     self_attention_guidance: bool = False
+    perturbed_attention_guidance: bool = False
 
 
 @dataclass
diff --git a/ai_diffusion/comfy_workflow.py b/ai_diffusion/comfy_workflow.py
index b985a1457b..95a67c8d8c 100644
--- a/ai_diffusion/comfy_workflow.py
+++ b/ai_diffusion/comfy_workflow.py
@@ -510,6 +510,9 @@ def apply_ip_adapter_face(
     def apply_self_attention_guidance(self, model: Output):
         return self.add("SelfAttentionGuidance", 1, model=model, scale=0.5, blur_sigma=2.0)
 
+    def apply_perturbed_attention_guidance(self, model: Output):
+        return self.add("PerturbedAttentionGuidance", 1, model=model)
+
     def inpaint_preprocessor(self, image: Output, mask: Output):
         return self.add("InpaintPreprocessor", 1, image=image, mask=mask)
 
diff --git a/ai_diffusion/style.py b/ai_diffusion/style.py
index 9f6bcc0b6e..43ed227c94 100644
--- a/ai_diffusion/style.py
+++ b/ai_diffusion/style.py
@@ -83,6 +83,12 @@ class StyleSettings:
         _("Pay more attention to difficult parts of the image. Can improve fine details."),
     )
 
+    perturbed_attention_guidance = Setting(
+        _("Enable PAG / Perturbed-Attention Guidance"),
+        False,
+        _('Deliberately introduces errors in "difficult" parts of the image and steers generation away from them. Can improve coherence.'),
+    )
+
     preferred_resolution = Setting(
         _("Preferred Resolution"), 0, _("Image resolution the checkpoint was trained on")
     )
@@ -119,6 +125,7 @@ class Style:
     clip_skip: int = StyleSettings.clip_skip.default
     v_prediction_zsnr: bool = StyleSettings.v_prediction_zsnr.default
     self_attention_guidance: bool = StyleSettings.self_attention_guidance.default
+    perturbed_attention_guidance: bool = StyleSettings.perturbed_attention_guidance.default
     preferred_resolution: int = StyleSettings.preferred_resolution.default
     sampler: str = StyleSettings.sampler.default
     sampler_steps: int = StyleSettings.sampler_steps.default
@@ -188,6 +195,7 @@ def get_models(self):
             v_prediction_zsnr=self.v_prediction_zsnr,
             loras=[LoraInput.from_dict(l) for l in self.loras],
             self_attention_guidance=self.self_attention_guidance,
+            perturbed_attention_guidance=self.perturbed_attention_guidance,
         )
         return result
 
diff --git a/ai_diffusion/ui/style.py b/ai_diffusion/ui/style.py
index 345b343bae..9ea663ccd1 100644
--- a/ai_diffusion/ui/style.py
+++ b/ai_diffusion/ui/style.py
@@ -611,6 +611,12 @@ def add(name: str, widget: SettingWidget):
         )
         self._checkpoint_advanced_widgets.append(self._sag)
 
+        self._pag = add(
+            "perturbed_attention_guidance",
+            SwitchSetting(StyleSettings.perturbed_attention_guidance, parent=self),
+        )
+        self._checkpoint_advanced_widgets.append(self._pag)
+
         for widget in self._checkpoint_advanced_widgets:
             widget.indent = 1
         self._toggle_checkpoint_advanced(False)
@@ -751,6 +757,7 @@ def _enable_checkpoint_advanced(self):
         self._clip_skip.enabled = arch.supports_clip_skip and self.current_style.clip_skip > 0
         self._zsnr.enabled = arch.supports_attention_guidance
         self._sag.enabled = arch.supports_attention_guidance
+        self._pag.enabled = arch.supports_attention_guidance
 
     def _read_style(self, style: Style):
         with self._write_guard:
diff --git a/ai_diffusion/workflow.py b/ai_diffusion/workflow.py
index ca38af1e8e..3ace9b6cfc 100644
--- a/ai_diffusion/workflow.py
+++ b/ai_diffusion/workflow.py
@@ -126,6 +126,9 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
     if arch.supports_attention_guidance and checkpoint.self_attention_guidance:
         model = w.apply_self_attention_guidance(model)
 
+    if arch.supports_attention_guidance and checkpoint.perturbed_attention_guidance:
+        model = w.apply_perturbed_attention_guidance(model)
+
     return model, clip, vae
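
Note for reviewers, not part of the patch: ComfyUI's "PerturbedAttentionGuidance"
node wired in above applies the guidance rule from the linked paper. Each
denoising step runs the model a second time with its self-attention maps
replaced by the identity, then pushes the final prediction away from that
deliberately degraded one. A minimal sketch of the combination step, assuming
two noise predictions are already available (the function name, tensor names,
and default scale below are illustrative, not taken from the node's API):

import torch

def pag_combine(
    eps_normal: torch.Tensor, eps_perturbed: torch.Tensor, scale: float = 3.0
) -> torch.Tensor:
    # eps_normal:    prediction from the unmodified model
    # eps_perturbed: prediction with self-attention replaced by the identity
    # Push the result away from the degraded prediction, analogous to how
    # classifier-free guidance pushes away from the unconditional prediction.
    return eps_normal + scale * (eps_normal - eps_perturbed)

This is also why apply_perturbed_attention_guidance only wraps the model and
passes no extra parameters: the guidance scale and the choice of perturbed
layers are left at the node's defaults, mirroring how
apply_self_attention_guidance hardcodes scale=0.5 and blur_sigma=2.0.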