
Commit

remove duplicated code
AUTOMATIC1111 committed Oct 14, 2023
1 parent 891ccb7 commit a8cbe50
Showing 2 changed files with 42 additions and 63 deletions.
extensions-builtin/Lora/networks.py: 31 changes (2 additions, 29 deletions)
@@ -15,7 +15,7 @@
 from typing import Union
 
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
-from modules.textual_inversion.textual_inversion import Embedding
+import modules.textual_inversion.textual_inversion as textual_inversion
 
 from lora_logger import logger
 
@@ -210,34 +210,7 @@ def load_network(name, network_on_disk):
 
     embeddings = {}
     for emb_name, data in bundle_embeddings.items():
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, emb_name)
-        embedding.vectors = vectors
-        embedding.shape = shape
+        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
         embedding.loaded = None
         embeddings[emb_name] = embedding
 
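In extensions-builtin/Lora/networks.py the per-format parsing is gone; each bundled payload now goes straight through the shared helper. Below is a minimal sketch of the resulting call path, assuming the stable-diffusion-webui environment; load_bundled_embeddings is a hypothetical wrapper written for illustration, not a function in the repository.

# Hypothetical wrapper sketching the new path inside load_network (illustration only).
import modules.textual_inversion.textual_inversion as textual_inversion


def load_bundled_embeddings(bundle_embeddings, network_on_disk):
    """bundle_embeddings maps an embedding name to its raw tensor payload."""
    embeddings = {}
    for emb_name, data in bundle_embeddings.items():
        # The "<lora file>/<embedding name>" string is only used for error messages;
        # no filepath is passed, so no sha256 hash is computed for bundled embeddings.
        embedding = textual_inversion.create_embedding_from_data(
            data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
        embedding.loaded = None  # loaded state is managed later by the Lora extension
        embeddings[emb_name] = embedding
    return embeddings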
modules/textual_inversion/textual_inversion.py: 74 changes (40 additions, 34 deletions)
@@ -181,40 +181,7 @@ def load_from_file(self, path, filename):
         else:
             return
 
-
-        # textual inversion embeddings
-        if 'string_to_param' in data:
-            param_dict = data['string_to_param']
-            param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
-            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
-            emb = next(iter(param_dict.items()))[1]
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
-            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
-            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
-            vectors = data['clip_g'].shape[0]
-        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
-            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
-            emb = next(iter(data.values()))
-            if len(emb.shape) == 1:
-                emb = emb.unsqueeze(0)
-            vec = emb.detach().to(devices.device, dtype=torch.float32)
-            shape = vec.shape[-1]
-            vectors = vec.shape[0]
-        else:
-            raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
-
-        embedding = Embedding(vec, name)
-        embedding.step = data.get('step', None)
-        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
-        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
-        embedding.vectors = vectors
-        embedding.shape = shape
-        embedding.filename = path
-        embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
+        embedding = create_embedding_from_data(data, name, filename=filename, filepath=path)
 
         if self.expected_shape == -1 or self.expected_shape == embedding.shape:
             self.register_embedding(embedding, shared.sd_model)
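For context, load_from_file still reads the raw payload itself and only the interpretation moved into the helper. Below is a rough sketch of that path, assuming the webui environment; the real load_from_file supports more file types, and the path used here is a placeholder.

import os

import safetensors.torch
import torch

from modules.textual_inversion import textual_inversion

path = "embeddings/my-style.safetensors"  # placeholder path
name, ext = os.path.splitext(os.path.basename(path))

# Read the raw tensors, then delegate interpretation to the shared helper.
if ext.lower() == ".safetensors":
    data = safetensors.torch.load_file(path, device="cpu")
else:
    data = torch.load(path, map_location="cpu")

# Passing filepath= makes the helper record the on-disk name and compute its sha256.
embedding = textual_inversion.create_embedding_from_data(
    data, name, filename=os.path.basename(path), filepath=path)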
@@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     return fn
 
 
+def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None):
+    if 'string_to_param' in data: # textual inversion embeddings
+        param_dict = data['string_to_param']
+        param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
+        assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+        emb = next(iter(param_dict.items()))[1]
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
+        vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+        vectors = data['clip_g'].shape[0]
+    elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
+        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+        emb = next(iter(data.values()))
+        if len(emb.shape) == 1:
+            emb = emb.unsqueeze(0)
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    else:
+        raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+    embedding = Embedding(vec, name)
+    embedding.step = data.get('step', None)
+    embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+    embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+    embedding.vectors = vectors
+    embedding.shape = shape
+
+    if filepath:
+        embedding.filename = filepath
+        embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '')
+
+    return embedding
+
+
 def write_loss(log_directory, filename, step, epoch_len, values):
     if shared.opts.training_write_csv_every == 0:
         return
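The new helper accepts the same three payload layouts the removed inline code handled. Below is a minimal sketch of those layouts, assuming the webui environment; the zero tensors and names are placeholders, not real trained embeddings.

import torch

from modules.textual_inversion import textual_inversion

# 1) Classic textual-inversion checkpoint: {'string_to_param': {token: tensor}}.
ti_data = {"string_to_param": {"*": torch.zeros(4, 768)}}  # 4 vectors of width 768
emb = textual_inversion.create_embedding_from_data(ti_data, "my-ti", filename="my-ti.pt")
print(emb.vectors, emb.shape)  # 4 768

# 2) SDXL embedding: separate clip_g/clip_l tensors; shape is the sum of their widths.
sdxl_data = {"clip_g": torch.zeros(4, 1280), "clip_l": torch.zeros(4, 768)}
emb = textual_inversion.create_embedding_from_data(sdxl_data, "my-sdxl", filename="my-sdxl.safetensors")
print(emb.vectors, emb.shape)  # 4 2048

# 3) Diffusers concept: a dict with a single tensor; 1-D tensors become (1, dim).
concept_data = {"<my-concept>": torch.zeros(768)}
emb = textual_inversion.create_embedding_from_data(concept_data, "my-concept", filename="learned_embeds.bin")
print(emb.vectors, emb.shape)  # 1 768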
