Skip to content

Commit

Permalink
fix load_vae() to check for size mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
wkpark committed Sep 17, 2024
1 parent e063757 commit de9497a
Showing 1 changed file with 20 additions and 8 deletions.
28 changes: 20 additions & 8 deletions modules/sd_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,47 +197,58 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):

cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0

loaded = False
if vae_file:
if cache_enabled and vae_file in checkpoints_loaded:
# use vae checkpoint cache
print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
store_base_vae(model)
_load_vae_dict(model, checkpoints_loaded[vae_file])
loaded = _load_vae_dict(model, checkpoints_loaded[vae_file])
else:
assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
print(f"Loading VAE weights {vae_source}: {vae_file}")
store_base_vae(model)

vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
_load_vae_dict(model, vae_dict_1)
loaded = _load_vae_dict(model, vae_dict_1)

if cache_enabled:
if loaded and cache_enabled:
# cache newly loaded vae
checkpoints_loaded[vae_file] = vae_dict_1.copy()

# clean up cache if limit is reached
if cache_enabled:
if loaded and cache_enabled:
while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
checkpoints_loaded.popitem(last=False) # LRU

# If vae used is not in dict, update it
# It will be removed on refresh though
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
if loaded and vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file

elif loaded_vae_file:
restore_base_vae(model)
loaded = True

loaded_vae_file = vae_file
if loaded:
loaded_vae_file = vae_file
model.base_vae = base_vae
model.loaded_vae_file = loaded_vae_file
return loaded


# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
conv_out = model.first_stage_model.state_dict().get("encoder.conv_out.weight")
# check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3]
if conv_out.shape != vae_dict_1["encoder.conv_out.weight"].shape:
print("Failed to load VAE. Size mismatched!")
return False

model.first_stage_model.load_state_dict(vae_dict_1)
model.first_stage_model.to(devices.dtype_vae)
return True


def clear_loaded_vae():
Expand Down Expand Up @@ -270,7 +281,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):

sd_hijack.model_hijack.undo_hijack(sd_model)

load_vae(sd_model, vae_file, vae_source)
loaded = load_vae(sd_model, vae_file, vae_source)

sd_hijack.model_hijack.hijack(sd_model)

Expand All @@ -279,5 +290,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):

script_callbacks.model_loaded_callback(sd_model)

print("VAE weights loaded.")
if loaded:
print("VAE weights loaded.")
return sd_model

0 comments on commit de9497a

Please sign in to comment.