Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Perturbed-Attention Guidance toggle. #622

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ai_diffusion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class CheckpointInput:
clip_skip: int = 0  # presumably number of CLIP layers to skip; 0 = default behavior — confirm
v_prediction_zsnr: bool = False  # v-prediction / zero-terminal-SNR toggle (name-based; verify against workflow use)
self_attention_guidance: bool = False  # when True, workflow wraps the model in SelfAttentionGuidance
perturbed_attention_guidance: bool = False  # when True, workflow wraps the model in PerturbedAttentionGuidance


@dataclass
Expand Down
3 changes: 3 additions & 0 deletions ai_diffusion/comfy_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,9 @@ def apply_ip_adapter_face(
def apply_self_attention_guidance(self, model: Output):
    """Wrap `model` in ComfyUI's SelfAttentionGuidance node with fixed tuning values."""
    sag_args = dict(model=model, scale=0.5, blur_sigma=2.0)
    return self.add("SelfAttentionGuidance", 1, **sag_args)

def apply_perturbed_attention_guidance(self, model: Output, scale: float = 3.0):
    """Wrap `model` in ComfyUI's PerturbedAttentionGuidance node.

    `scale` is exposed for parity with `apply_self_attention_guidance`,
    which passes its tuning values explicitly. The default of 3.0 matches
    the node's own default, so existing callers (which passed nothing and
    relied on the node default) keep identical behavior.
    """
    return self.add("PerturbedAttentionGuidance", 1, model=model, scale=scale)

def inpaint_preprocessor(self, image: Output, mask: Output):
    """Add an InpaintPreprocessor node combining `image` with its inpaint `mask`."""
    node = self.add("InpaintPreprocessor", 1, image=image, mask=mask)
    return node

Expand Down
8 changes: 8 additions & 0 deletions ai_diffusion/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ class StyleSettings:
_("Pay more attention to difficult parts of the image. Can improve fine details."),
)

perturbed_attention_guidance = Setting(
    # Wrap UI strings in _() for localization — every other Setting in this
    # class (e.g. self_attention_guidance, preferred_resolution) does so.
    _("Enable PAG / Perturbed-Attention Guidance"),
    False,
    _('Deliberately introduce errors in "difficult" parts to steer away from. Can improve coherence.'),
)

# 0 appears to mean "no preference"; otherwise the training resolution of
# the checkpoint — TODO confirm how resolution 0 is interpreted by consumers.
preferred_resolution = Setting(
    _("Preferred Resolution"), 0, _("Image resolution the checkpoint was trained on")
)
Expand Down Expand Up @@ -117,6 +123,7 @@ class Style:
clip_skip: int = StyleSettings.clip_skip.default
v_prediction_zsnr: bool = StyleSettings.v_prediction_zsnr.default
self_attention_guidance: bool = StyleSettings.self_attention_guidance.default
perturbed_attention_guidance: bool = StyleSettings.perturbed_attention_guidance.default
preferred_resolution: int = StyleSettings.preferred_resolution.default
sampler: str = StyleSettings.sampler.default
sampler_steps: int = StyleSettings.sampler_steps.default
Expand Down Expand Up @@ -188,6 +195,7 @@ def get_models(self):
v_prediction_zsnr=self.v_prediction_zsnr,
loras=[LoraInput.from_dict(l) for l in self.loras],
self_attention_guidance=self.self_attention_guidance,
perturbed_attention_guidance=self.perturbed_attention_guidance,
)
return result

Expand Down
7 changes: 7 additions & 0 deletions ai_diffusion/ui/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,12 @@ def add(name: str, widget: SettingWidget):
)
self._checkpoint_advanced_widgets.append(self._sag)

# Toggle widget for Perturbed-Attention Guidance. Registered in
# _checkpoint_advanced_widgets so it is indented, shown/hidden, and
# enabled together with the other advanced checkpoint settings.
self._pag = add(
    "perturbed_attention_guidance",
    SwitchSetting(StyleSettings.perturbed_attention_guidance, parent=self),
)
self._checkpoint_advanced_widgets.append(self._pag)

for widget in self._checkpoint_advanced_widgets:
widget.indent = 1
self._toggle_checkpoint_advanced(False)
Expand Down Expand Up @@ -794,6 +800,7 @@ def _enable_checkpoint_advanced(self):
# Clip-skip needs both architecture support and a non-zero style value;
# ZSNR, SAG and PAG all share the same attention-guidance capability flag.
self._clip_skip.enabled = arch.supports_clip_skip and self.current_style.clip_skip > 0
self._zsnr.enabled = arch.supports_attention_guidance
self._sag.enabled = arch.supports_attention_guidance
self._pag.enabled = arch.supports_attention_guidance

def _read_style(self, style: Style):
with self._write_guard:
Expand Down
3 changes: 3 additions & 0 deletions ai_diffusion/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ def load_checkpoint_with_lora(w: ComfyWorkflow, checkpoint: CheckpointInput, mod
if arch.supports_attention_guidance and checkpoint.self_attention_guidance:
model = w.apply_self_attention_guidance(model)

if checkpoint.perturbed_attention_guidance:
model = w.apply_perturbed_attention_guidance(model)

return model, clip, vae


Expand Down
Loading