fix missing infotext caused by conds cache #16677

Draft · wants to merge 2 commits into base: dev
40 changes: 34 additions & 6 deletions modules/processing.py
@@ -16,7 +16,7 @@
from typing import Any

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling, util
from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@@ -457,6 +457,20 @@ def cached_params(self, required_prompts, steps, extra_network_data, hires_steps
opts.emphasis,
)

def apply_generation_params_list(self, generation_params_states):
"""add and apply generation_params_states to self.extra_generation_params"""
for key, value in generation_params_states.items():
if key in self.extra_generation_params and isinstance(current_value := self.extra_generation_params[key], util.GenerationParametersList):
self.extra_generation_params[key] = current_value + value
else:
self.extra_generation_params[key] = value

def clear_marked_generation_params(self):
"""clears any generation parameters that are with the attribute to_be_clear_before_batch = True"""
for key, value in list(self.extra_generation_params.items()):
if getattr(value, 'to_be_clear_before_batch', False):
self.extra_generation_params.pop(key)

def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
@@ -480,13 +494,24 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr

for cache in caches:
if cache[0] is not None and cached_params == cache[0]:
if len(cache) == 3:
generation_params_states, cached_cached_params = cache[2]
if cached_params == cached_cached_params:
self.apply_generation_params_list(generation_params_states)
return cache[1]

cache = caches[0]

with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)

generation_params_states = model_hijack.extract_generation_params_states()
self.apply_generation_params_list(generation_params_states)
if len(cache) == 2:
cache.append((generation_params_states, cached_params))
else:
cache[2] = (generation_params_states, cached_params)

cache[0] = cached_params
return cache[1]

@@ -502,6 +527,8 @@ def setup_conds(self):
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)

self.extra_generation_params.update(model_hijack.extra_generation_params)

def get_conds(self):
return self.c, self.uc

@@ -801,10 +828,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter

for key, value in generation_params.items():
try:
if isinstance(value, list):
generation_params[key] = value[index]
elif callable(value):
if callable(value):
generation_params[key] = value(**locals())
elif isinstance(value, list):
generation_params[key] = value[index]
except Exception:
errors.report(f'Error creating infotext for key "{key}"', exc_info=True)
generation_params[key] = None
@@ -938,6 +965,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.interrupted or state.stopping_generation:
break

p.clear_marked_generation_params() # clean up generation params that are tagged to be cleared before each batch
sd_models.reload_model_weights() # model can be changed for example by refiner

p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
@@ -965,8 +993,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

p.setup_conds()

p.extra_generation_params.update(model_hijack.extra_generation_params)

# params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback
# Example: a wildcard processed by process_batch sets an extra model
@@ -1513,6 +1539,8 @@ def calculate_hr_conds(self):
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)

self.extra_generation_params.update(model_hijack.extra_generation_params)

def setup_conds(self):
if self.is_hr_pass:
# if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model
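
Note on the processing.py changes: each conds cache entry grows from [cached_params, conds] to [cached_params, conds, (generation_params_states, cached_params)], so a cache hit can re-apply the generation params that the skipped text-encoder run would otherwise have produced. The swap in create_infotext matters because a GenerationParametersList is both a list and callable: checking callable(value) first renders it via __call__ instead of indexing it per image. A minimal sketch of the cache lookup, simplified from get_conds_with_caching above (lookup_conds is a hypothetical name, p a StableDiffusionProcessing instance):

    def lookup_conds(p, cache, cached_params):
        # cache layout after this PR: [params, conds, (generation_params_states, params)]
        if cache[0] is not None and cached_params == cache[0]:
            if len(cache) == 3:
                states, saved_params = cache[2]
                if cached_params == saved_params:  # guard against a stale third element
                    p.apply_generation_params_list(states)  # restore params the cache hit skipped
            return cache[1]
        return None  # miss: caller recomputes the conds and repopulates the entry
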
10 changes: 9 additions & 1 deletion modules/sd_hijack.py
@@ -2,7 +2,7 @@
from torch.nn.functional import silu
from types import MethodType

from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches, util
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, xlmr_m18
@@ -321,6 +321,14 @@ def clear_comments(self):
self.comments = []
self.extra_generation_params = {}

def extract_generation_params_states(self):
"""Extracts GenerationParametersList so that they can be cached and restored later"""
states = {}
for key in list(self.extra_generation_params):
if isinstance(self.extra_generation_params[key], util.GenerationParametersList):
states[key] = self.extra_generation_params.pop(key)
return states

def get_prompt_lengths(self, text):
if self.clip is None:
return "-", "-"
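
For illustration, extract_generation_params_states pops only the GenerationParametersList entries and leaves plain values in place (a sketch with hypothetical values):

    hijack.extra_generation_params = {
        "TI hashes": util.GenerationParametersList(["embed: aaaa1111"]),
        "Clip skip": 2,
    }
    states = hijack.extract_generation_params_states()
    # states == {"TI hashes": GenerationParametersList(["embed: aaaa1111"])}
    # hijack.extra_generation_params == {"Clip skip": 2}
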
34 changes: 28 additions & 6 deletions modules/sd_hijack_clip.py
@@ -3,7 +3,7 @@

import torch

from modules import prompt_parser, devices, sd_hijack, sd_emphasis
from modules import prompt_parser, devices, sd_hijack, sd_emphasis, util
from modules.shared import opts


@@ -27,6 +27,30 @@ def __init__(self):
are applied by sd_hijack.EmbeddingsWithFixes's forward function."""


class EmphasisMode(util.GenerationParametersList):
def __init__(self, emphasis_mode: str = None):
super().__init__()
self.emphasis_mode = emphasis_mode

def __call__(self, *args, **kwargs):
return self.emphasis_mode

def __add__(self, other):
if isinstance(other, EmphasisMode):
return self if self.emphasis_mode else other
elif isinstance(other, str):
return self.__str__() + other
return NotImplemented

def __radd__(self, other):
if isinstance(other, str):
return other + self.__str__()
return NotImplemented

def __str__(self):
return self.emphasis_mode if self.emphasis_mode else ''


class TextConditionalModel(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -238,12 +262,10 @@ def forward(self, texts):
hashes.append(f"{name}: {shorthash}")

if hashes:
if self.hijack.extra_generation_params.get("TI hashes"):
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
self.hijack.extra_generation_params["TI hashes"] = util.GenerationParametersList(hashes)

if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original":
self.hijack.extra_generation_params["Emphasis"] = opts.emphasis
if opts.emphasis != 'Original' and any(x for x in texts if '(' in x or '[' in x):
self.hijack.extra_generation_params["Emphasis"] = EmphasisMode(opts.emphasis)

if self.return_pooled:
return torch.hstack(zs), zs[0].pooled
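
On the EmphasisMode merge semantics: __add__ keeps the left operand's mode when it is set and otherwise falls back to the right one, so re-applying cached states never clobbers a mode recorded in the current pass with an empty value (a sketch, assuming the class above):

    a = EmphasisMode("No norm")
    b = EmphasisMode(None)
    str(a + b)        # 'No norm' -- the left mode wins when set
    str(b + a)        # 'No norm' -- an empty left falls through to the right operand
    "Emphasis: " + a  # 'Emphasis: No norm' -- plain string concatenation via __radd__
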
46 changes: 46 additions & 0 deletions modules/util.py
@@ -288,3 +288,49 @@ def compare_sha256(file_path: str, hash_prefix: str) -> bool:
for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())


class GenerationParametersList(list):
"""A special object used in sd_hijack.StableDiffusionModelHijack for setting extra_generation_params
due to StableDiffusionProcessing.get_conds_with_caching
extra_generation_params set in StableDiffusionModelHijack will be lost when cached is used

When an extra_generation_params is set in StableDiffusionModelHijack using this object,
the params will be extracted by StableDiffusionModelHijack.extract_generation_params_states
the extracted params will be cached in StableDiffusionProcessing.get_conds_with_caching
and applyed to StableDiffusionProcessing.extra_generation_params by StableDiffusionProcessing.apply_generation_params_states

Example see modules.sd_hijack_clip.TextConditionalModel.hijack.extra_generation_params 'TI hashes' 'Emphasis'

Depending on the use case the methods can be overwritten.
In general __call__ method should return str or None, as normally it's called in modules.processing.create_infotext.
When called by create_infotext it will access to the locals() of the caller,
if return str, the value will be written to infotext, if return None will be ignored.
"""

def __init__(self, *args, to_be_clear_before_batch=True, **kwargs):
super().__init__(*args, **kwargs)
self._to_be_clear_before_batch = to_be_clear_before_batch

def __call__(self, *args, **kwargs):
return ', '.join(sorted(set(self), key=natural_sort_key))

@property
def to_be_clear_before_batch(self):
return self._to_be_clear_before_batch

def __add__(self, other):
if isinstance(other, GenerationParametersList):
return self.__class__([*self, *other])
elif isinstance(other, str):
return self.__str__() + other
return NotImplemented

def __radd__(self, other):
if isinstance(other, str):
return other + self.__str__()
return NotImplemented

def __str__(self):
return self.__call__()
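
Usage sketch for GenerationParametersList (hash values hypothetical): + concatenates two lists into a new one, and calling the object deduplicates and naturally sorts the entries into the final infotext string:

    h1 = util.GenerationParametersList(["foo: aaa1"])
    h2 = util.GenerationParametersList(["bar: bbb2", "foo: aaa1"])
    merged = h1 + h2                  # new GenerationParametersList with all three items
    merged()                          # 'bar: bbb2, foo: aaa1' -- set() drops the duplicate
    "TI hashes: " + merged            # plain str via __radd__, as used in infotext assembly
    merged.to_be_clear_before_batch   # True by default, so it is dropped before each batch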