image embedding data cache (#16556)
w-e-w authored Oct 29, 2024
1 parent d88a3c1 commit deb3803
Showing 2 changed files with 32 additions and 13 deletions.
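In outline, the commit gives EmbeddingDatabase a persistent image_embedding_cache backed by cache.cache('image-embedding') and moves the image-handling part of load_from_file into a new read_embedding_from_image helper; cache entries are keyed by file path and reused only while the file's modification time is unchanged. The full diff follows. As a standalone illustration of that mtime-keyed lookup/store pattern, here is a minimal sketch: the cache is modeled as a plain dict and parse_embedding is a hypothetical stand-in for the actual PNG parsing, so neither reflects the repository's real cache API.

import os

image_embedding_cache = {}  # stand-in for the persistent cache; assumed to be dict-like

def parse_embedding(path):
    # hypothetical stand-in: return embedding data parsed from the image,
    # or None if the image is not an embedding (e.g. just a preview picture)
    return None

def read_with_cache(path, cache_embedding_data=False):
    mtime = os.path.getmtime(path)

    entry = image_embedding_cache.get(path)
    if entry and entry.get('mtime', 0) == mtime:
        # hit: the file has not changed since the entry was stored
        return entry.get('data')

    data = parse_embedding(path)

    if data is None or cache_embedding_data:
        # images that are not embeddings are always recorded, so future scans can
        # skip them; parsed embedding data is stored only when explicitly requested
        image_embedding_cache[path] = {'data': data, 'mtime': mtime}

    return data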
modules/shared_options.py: 1 change (1 addition & 0 deletions)
@@ -291,6 +291,7 @@
     "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"),
     "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"),
     "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks),
+    "textual_inversion_image_embedding_data_cache": OptionInfo(False, 'Cache the data of image embeddings').info('potentially increase TI load time at the cost some disk space'),

strawberrymelonpanda (Oct 29, 2024):

I believe this should say:
"potentially decrease TI load time at the cost of some disk space"

 }))
 
 options_templates.update(options_section(('ui_prompt_editing', "Prompt editing", "ui"), {
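Like the neighboring textual_inversion_* options, the new setting is registered through OptionInfo with a default of False, so caching of parsed embedding data is opt-in; presumably it appears alongside the other Textual Inversion settings in the UI and is persisted with the rest of the saved options.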
modules/textual_inversion/textual_inversion.py: 44 changes (31 additions & 13 deletions)
@@ -12,7 +12,7 @@
 import numpy as np
 from PIL import Image, PngImagePlugin
 
-from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes
+from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes, cache
 import modules.textual_inversion.dataset
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
 
@@ -116,6 +116,7 @@ def __init__(self):
         self.expected_shape = -1
         self.embedding_dirs = {}
         self.previously_displayed_embeddings = ()
+        self.image_embedding_cache = cache.cache('image-embedding')
 
     def add_embedding_dir(self, path):
         self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
@@ -154,6 +155,31 @@ def get_expected_shape(self):
         vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
         return vec.shape[1]
 
+    def read_embedding_from_image(self, path, name):
+        try:
+            ondisk_mtime = os.path.getmtime(path)
+
+            if (cache_embedding := self.image_embedding_cache.get(path)) and ondisk_mtime == cache_embedding.get('mtime', 0):
+                # the cached entry is used only if the file's modification time matches, i.e. the file has not been modified since it was cached
+                return cache_embedding.get('data', None), cache_embedding.get('name', None)
+
+            embed_image = Image.open(path)
+            if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
+                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+                name = data.get('name', name)
+            elif data := extract_image_data_embed(embed_image):
+                name = data.get('name', name)
+
+            if data is None or shared.opts.textual_inversion_image_embedding_data_cache:
+                # the data of image embeddings is cached only if the option textual_inversion_image_embedding_data_cache is enabled
+                # results for images that are not embeddings are always cached, to avoid unnecessary disk reads in the future
+                self.image_embedding_cache[path] = {'data': data, 'name': None if data is None else name, 'mtime': ondisk_mtime}
+
+            return data, name
+        except Exception:
+            errors.report(f"Error loading embedding {path}", exc_info=True)
+            return None, None
+
     def load_from_file(self, path, filename):
         name, ext = os.path.splitext(filename)
         ext = ext.upper()
@@ -163,17 +189,10 @@ def load_from_file(self, path, filename):
             if second_ext.upper() == '.PREVIEW':
                 return
 
-            embed_image = Image.open(path)
-            if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
-                data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
-                name = data.get('name', name)
-            else:
-                data = extract_image_data_embed(embed_image)
-                if data:
-                    name = data.get('name', name)
-                else:
-                    # if data is None, means this is not an embedding, just a preview image
-                    return
+            data, name = self.read_embedding_from_image(path, name)
+            if data is None:
+                return
+
         elif ext in ['.BIN', '.PT']:
             data = torch.load(path, map_location="cpu")
         elif ext in ['.SAFETENSORS']:
@@ -191,7 +210,6 @@ def load_from_file(self, path, filename):
         else:
             print(f"Unable to load Textual inversion embedding due to data issue: '{name}'.")
 
-
     def load_from_dir(self, embdir):
         if not os.path.isdir(embdir.path):
             return
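Net effect: when embeddings are loaded from image files, results for files that turn out not to be embeddings (data is None) are always cached, so later scans skip re-reading them, while parsed data for actual image embeddings is cached only when the new textual_inversion_image_embedding_data_cache option is enabled, spending some disk space in the cache to speed up subsequent loads (as the inline comment above points out, the option's description presumably should say 'decrease', not 'increase'). Each cache entry stores the parsed data, the resolved name (None for non-embeddings), and the file's mtime used for invalidation.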
