From 015e20a3dbcf4ede6178533ad5b4a1d5064b5252 Mon Sep 17 00:00:00 2001
From: Won-Kyu Park
Date: Tue, 17 Sep 2024 11:05:48 +0900
Subject: [PATCH] fix load_vae() to check size mismatch

---
 modules/sd_vae.py | 27 +++++++++++++++++++--------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 43687e48dcf..357f6febbe4 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -202,42 +202,52 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
             # use vae checkpoint cache
             print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
             store_base_vae(model)
-            _load_vae_dict(model, checkpoints_loaded[vae_file])
+            loaded = _load_vae_dict(model, checkpoints_loaded[vae_file])
         else:
             assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
             print(f"Loading VAE weights {vae_source}: {vae_file}")
             store_base_vae(model)
 
             vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
-            _load_vae_dict(model, vae_dict_1)
+            loaded = _load_vae_dict(model, vae_dict_1)
 
-            if cache_enabled:
+            if loaded and cache_enabled:
                 # cache newly loaded vae
                 checkpoints_loaded[vae_file] = vae_dict_1.copy()
 
         # clean up cache if limit is reached
-        if cache_enabled:
+        if loaded and cache_enabled:
             while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model
                 checkpoints_loaded.popitem(last=False)  # LRU
 
         # If vae used is not in dict, update it
         # It will be removed on refresh though
         vae_opt = get_filename(vae_file)
-        if vae_opt not in vae_dict:
+        if loaded and vae_opt not in vae_dict:
             vae_dict[vae_opt] = vae_file
 
     elif loaded_vae_file:
         restore_base_vae(model)
+        loaded = True
 
-    loaded_vae_file = vae_file
+    if loaded:
+        loaded_vae_file = vae_file
     model.base_vae = base_vae
     model.loaded_vae_file = loaded_vae_file
+    return loaded
 
 
 # don't call this from outside
 def _load_vae_dict(model, vae_dict_1):
+    conv_out = model.first_stage_model.state_dict().get("encoder.conv_out.weight")
+    # check shape of "encoder.conv_out.weight". SD1.5/SDXL: [8, 512, 3, 3], FLUX/SD3: [32, 512, 3, 3]
+    if conv_out.shape != vae_dict_1["encoder.conv_out.weight"].shape:
+        print("Failed to load VAE. Size mismatched!")
+        return False
+
     model.first_stage_model.load_state_dict(vae_dict_1)
     model.first_stage_model.to(devices.dtype_vae)
+    return True
 
 
 def clear_loaded_vae():
@@ -270,7 +280,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
 
     sd_hijack.model_hijack.undo_hijack(sd_model)
 
-    load_vae(sd_model, vae_file, vae_source)
+    loaded = load_vae(sd_model, vae_file, vae_source)
 
     sd_hijack.model_hijack.hijack(sd_model)
 
@@ -279,5 +289,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
     script_callbacks.model_loaded_callback(sd_model)
 
-    print("VAE weights loaded.")
+    if loaded:
+        print("VAE weights loaded.")
 
     return sd_model