From 091f56e169dbfd4bb2a16bcab06a43d7f8706b91 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 24 Aug 2024 19:57:35 +0900 Subject: [PATCH 01/10] initial SDv3 / FLUX support added. --- javascript/update.js | 50 ++- scripts/model_mixer.py | 874 ++++++++++++++++++++++++++++------------- 2 files changed, 641 insertions(+), 283 deletions(-) diff --git a/javascript/update.js b/javascript/update.js index 98b2775..5a17f7b 100644 --- a/javascript/update.js +++ b/javascript/update.js @@ -1,25 +1,43 @@ function mm_slider_to_text() { let res = Array.from(arguments); - const ISXLBLOCK = [ - /* base, in01, in02,... */ - true, true, true, true, true, true, true, true, true, true, false, false, false, - /* mid, out01, out02,... */ - true, true, true, true, true, true, true, true, true, true, false, false, false, - ]; - let isxl = res[0]; + + function is_block_sliders(sdversion) { + let sliders = Array(1 + 38 + 1 + 38).fill(false); // base + in 38 blocks + middle + out 38 blocks + if (sdversion == "v1" || sdversion == "v2") { + sliders[0] = true; // base + sliders.splice(1, 12, ...Array(12).fill(true)); // input blocks 00-11 + sliders[39] = true; // middle + sliders.splice(40, 12, ...Array(12).fill(true)); // output blocks 00-11 + } else if (sdversion == "XL") { + sliders[0] = true; // base + sliders.splice(1, 9, ...Array(9).fill(true)); // input blocks 00-08 + sliders[39] = true; // middle + sliders.splice(40, 9, ...Array(9).fill(true)); // output blocks 00-08 + } else if (sdversion == "v3") { + sliders[0] = true; // base + sliders.splice(1, 12, ...Array(12).fill(true)); // joint blocks 00-11 + sliders[39] = false; // no middle + sliders.splice(40, 12, ...Array(12).fill(true)); // joint blocks 12-23 + } else if (sdversion == "FLUX") { + sliders[0] = true; // base + sliders.splice(1, 19, ...Array(19).fill(true)); // double blocks 00-11 + sliders[39] = false; // no middle + sliders.splice(40, 38, ...Array(38).fill(true)); // single blocks 00-37 + } + return sliders; + } + + let sdv = res[0]; let selected = []; - let slider = res.slice(1); - if (isxl) { - selected = [] - for (let i = 0; i < slider.length; i++) { - if (ISXLBLOCK[i]) { - selected.push(slider[i]); - } + const slider = res.slice(1); + const block_sliders = is_block_sliders(sdv); + for (let i = 0; i < slider.length; i++) { + if (block_sliders[i]) { + selected.push(slider[i]); } - } else { - selected = slider; } + let mbw = null; let mbws = gradioApp().querySelectorAll(".mm_mbw textarea"); for (let i = 0; i < mbws.length; i++) { diff --git a/scripts/model_mixer.py b/scripts/model_mixer.py index 0234ee0..e02bb4d 100644 --- a/scripts/model_mixer.py +++ b/scripts/model_mixer.py @@ -107,16 +107,43 @@ def gr_enable(interactive=True): def gr_open(open=True): return {"open": open, "__type__": "update"} -def slider2text(isxl, *slider): - if isxl: - selected = [] - for i,v in enumerate(slider): - if ISXLBLOCK[i]: - selected.append(slider[i]) - else: - selected = slider + +def is_block_sliders(sdversion): + sliders = [False] * (1 + 38 + 1 + 38) # base + in 38 blocks + middle + out 38 blocks + if sdversion in ["v1", "v2"]: + sliders[0] = True # base + sliders[1:13] = [True]*12 # input blocks 00-11 + sliders[39] = True # middle + sliders[40:52] = [True]*12 # output blocks 00-11 + elif sdversion == "XL": + sliders[0] = True # base + sliders[1:10] = [True]*9 # input blocks 00-08 + sliders[39] = True # middle + sliders[40:49] = [True]*9 # output blocks 00-08 + elif sdversion == "v3": + sliders[0] = True # base + sliders[1:13] = [True]*12 # input blocks 00-11 + sliders[39] = False # no middle + sliders[40:52] = [True]*12 # input blocks 12-23 + elif sdversion == "FLUX": + sliders[0] = True # base + sliders[1:20] = [True]*19 # double blocks 00-11 + sliders[39] = False # no middle + sliders[40:78] = [True]*38 # single blocks 00-37 + + return sliders + + +def slider2text(version, *slider): + is_selected_blocks = is_block_sliders(version) + selected = [] + for i,v in enumerate(slider): + if is_selected_blocks[i]: + selected.append(slider[i]) + return gr.update(value = ",".join([str(x) for x in selected])) + parsed_mbwpresets = {} def _load_mbwpresets(): raw = None @@ -241,145 +268,202 @@ def get_selected_blocks(mbw_blocks, isxl=False): return selected_blocks -def calc_mbws(mbw, mbw_blocks, isxl=False): +def normalize_blocks(blocks, sdv): + # no mbws blocks selected or have 'ALL' alias + if len(blocks) == 0 or 'ALL' in blocks: + # select all blocks + blocks = [ 'BASE', 'INP*', 'MID', 'OUT*' ] + + # fix alias + if 'MID' in blocks: + i = blocks.index('MID') + blocks[i] = 'M00' + + if sdv in ["v1", "v2", "XL"]: + isxl = sdv == "XL" + BLOCKLEN = 12 - (0 if not isxl else 3) + + # expand some aliases + if 'INP*' in blocks: + for i in range(0, BLOCKLEN): + name = f"IN{i:02d}" + if name not in blocks: + blocks.append(name) + if 'OUT*' in blocks: + for i in range(0, BLOCKLEN): + name = f"OUT{i:02d}" + if name not in blocks: + blocks.append(name) + + elif sdv == "FLUX": + # expand some aliases + if 'INP*' in blocks or 'DOUBLE*' in blocks: + for i in range(0, 19): + name = f"DOUBLE{i:02d}" + if name not in blocks: + blocks.append(name) + if 'OUT*' in blocks or 'SINGLE*' in blocks: + for i in range(0, 38): + name = f"SINGLE{i:02d}" + if name not in blocks: + blocks.append(name) + + elif sdv == "v3": + # expand some aliases + if 'INP*' in blocks: + for i in range(0, 24): + name = f"IN{i:02d}" + if name not in blocks: + blocks.append(name) + + blocks = list(set(blocks)) + + # filter valid blocks + BLOCKIDS = all_blocks(sdv) + MAXLEN = len(BLOCKIDS) + selected = [False]*MAXLEN + + normalized = [] + for i, name in enumerate(BLOCKIDS): + if name in blocks: + selected[i] = True + normalized.append(name) + + return normalized, selected + + +def calc_mbws(mbw, mbw_blocks, sdver): if "," in mbw: weights = [t.strip() for t in mbw.strip().split(",")] elif " " in mbw.strip(): weights = [t.strip() for t in mbw.strip().split(" ")] else: weights = [mbw.strip()] - expect = 0 - MAXLEN = 26 - (0 if not isxl else 6) - BLOCKLEN = 12 - (0 if not isxl else 3) - BLOCKOFFSET = 13 if not isxl else 10 - selected = [False]*MAXLEN - compact_blocks = [] - BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL - # no mbws blocks selected or have 'ALL' alias - if len(mbw_blocks) == 0 or 'ALL' in mbw_blocks: - # select all blocks - mbw_blocks = [ 'BASE', 'INP*', 'MID', 'OUT*' ] + normalized, selected = normalize_blocks(mbw_blocks, sdver) - # fix alias - if 'MID' in mbw_blocks: - i = mbw_blocks.index('MID') - mbw_blocks[i] = 'M00' + MAXLEN = len(selected) - # expand some aliases - if 'INP*' in mbw_blocks: - for i in range(0, BLOCKLEN): - name = f"IN{i:02d}" - if name not in mbw_blocks: - mbw_blocks.append(name) - if 'OUT*' in mbw_blocks: - for i in range(0, BLOCKLEN): - name = f"OUT{i:02d}" - if name not in mbw_blocks: - mbw_blocks.append(name) + # full weights given + mbws = [0.0] * MAXLEN + expect = len([s for s in selected if s is True]) + wlen = len(weights) - for i, name in enumerate(BLOCKIDS): - if name in mbw_blocks: - if name[0:2] == 'IN': - expect += 1 - num = int(name[2:]) - selected[num + 1] = True - compact_blocks.append(f'inp.{num}.') - elif name[0:3] == 'OUT': - expect += 1 - num = int(name[3:]) - selected[num + BLOCKOFFSET + 1] = True - compact_blocks.append(f'out.{num}.') - elif name == 'M00': - expect += 1 - selected[BLOCKOFFSET] = True - compact_blocks.append('mid.1.') - elif name == 'BASE': - expect +=1 - selected[0] = True - compact_blocks.append('base') + if wlen > MAXLEN: # too many weights given + weights = weights[:MAXLEN] # trim weights - if len(weights) > MAXLEN: - weights = weights[:MAXLEN] - elif len(weights) > expect: - for i in range(len(weights), MAXLEN): - weights.append(weights[len(weights)-1]) # fill up last weight - elif len(weights) < expect: - for i in range(len(weights), expect): - weights.append(weights[len(weights)-1]) # fill up last weight + # parse weights + for i, f in enumerate(weights): + try: + f = float(f) + weights[i] = f + except: + pass # ignore invalid entries + if wlen < expect: + # in this case, fill up to expected number of weights. 0.5 will be expanded to 0.5,0.5,...,0.5 + weights.extend([weights[-1]] * (expect - wlen)) + elif wlen > expect: + # in this case, given weights are suppose to be the trimed result of empty values (trimed values are ,0.0, 0.0,...,0.0) + weights.extend([0.0] * (MAXLEN - wlen)) # fill up 0.0 to MAXLEN + + mbws = [0.0] * MAXLEN if len(weights) == MAXLEN: # full weights given - mbws = [0.0]*len(weights) - compact_mbws = [] - for i,f in enumerate(weights): - try: - f = float(f) - weights[i] = f - except: - pass # ignore invalid entries + for i, f in enumerate(weights): if selected[i]: mbws[i] = weights[i] - compact_mbws.append(mbws[i]) else: - # short weights given - compact_mbws = [0.0]*len(weights) - mbws = [0.0]*MAXLEN - for i,f in enumerate(weights): - try: - f = float(f) - compact_mbws[i] = f - except: - pass # ignore invalid entries + # only selected weights given + k = 0 + for i in range(MAXLEN): + if selected[i]: + mbws[i] = weights[k] + k += 1 - block = compact_blocks[i] - if 'base' == block: - off = 0 - num = 0 - else: - block, num, _ = compact_blocks[i].split(".") - num = int(num) - if 'inp' == block: - off = 1 - elif 'mid' == block: - off = BLOCKOFFSET - num = 0 - elif 'out' == block: - off = BLOCKOFFSET + 1 - - mbws[off + num] = compact_mbws[i] - - return mbws, compact_mbws, selected - -def get_mbws(mbw, use_advanced, mbw_blocks, simple_blocks, isxl=False): - mbws, compact_mbws, selected = calc_mbws(mbw, mbw_blocks if use_advanced else simple_blocks, isxl=isxl) - if isxl: - j = 0 - ret = [] - for i, v in enumerate(ISXLBLOCK): - if v: - ret.append(gr.update(value = mbws[j])) - j += 1 - else: - ret.append(gr.update()) - return ret + [gr.update(open=True)] + return mbws, None, selected - return [gr.update(value = v) for v in mbws] + [gr.update(open=True)] -def _all_blocks(isxl=False): - BLOCKLEN = 12 - (0 if not isxl else 3) - # return all blocks - base_prefix = "cond_stage_model." if not isxl else "conditioner." - blocks = [ base_prefix ] - for i in range(0, BLOCKLEN): - blocks.append(f"input_blocks.{i}.") - blocks.append("middle_block.") - for i in range(0, BLOCKLEN): - blocks.append(f"output_blocks.{i}.") - - blocks += [ "time_embed.", "out." ] +def get_mbws(mbw, use_advanced, mbw_blocks, simple_blocks, sdver): + mbws, dummy, selected = calc_mbws(mbw, mbw_blocks if use_advanced else simple_blocks, sdver) + + is_sliders = is_block_sliders(sdver) + + j = 0 + ret = [] + for i, v in enumerate(is_sliders): + if v: + ret.append(gr.update(value=mbws[j])) + j += 1 + else: + ret.append(gr.update()) + return ret + [gr.update(open=True)] + + +def _all_blocks(sdversion): + if sdversion is True: + # for old behavior called by _all_blocks(isxl) + sdversion = "XL" + if sdversion in ["v1", "v2", "XL"]: + BLOCKLEN = 12 - (0 if sdversion != "XL" else 3) + # return all blocks + base_prefix = "cond_stage_model." if sdversion != "XL" else "conditioner." + blocks = [ base_prefix ] + for i in range(0, BLOCKLEN): + blocks.append(f"input_blocks.{i}.") + blocks.append("middle_block.") + for i in range(0, BLOCKLEN): + blocks.append(f"output_blocks.{i}.") + + blocks += [ "time_embed.", "out." ] + if sdversion == "XL": + blocks += [ "label_emb." ] + + elif sdversion == "v3": + #blocks = [ "text_encoders.clip_l.", "text_encoders.clip_g.", "text_encoders.t5xxl." ] + blocks = [ "text_encoders." ] + for i in range(0, 24): + blocks.append(f"joint_blocks.{i}.") + + blocks += [ "x_embedder.", "t_embedder.", "y_embedder.", "context_embedder.", "pos_embed", "final_layer." ] + + elif sdversion == "FLUX": + #blocks = [ "text_encoders.clip_l.", "text_encoders.t5xxl." ] + blocks = [ "text_encoders." ] + for i in range(0, 19): + blocks.append(f"double_blocks.{i}.") + for i in range(0, 38): + blocks.append(f"single_blocks.{i}.") + + blocks += [ "img_in.", "time_in.", "vector_in.", "guidance_in.", "txt_in.", "final_layer." ] + return blocks + +def all_blocks(sdversion): + blocks = [ 'BASE' ] + if sdversion in ["v1", "v2", "XL"]: + BLOCKLEN = 12 - (0 if sdversion != "XL" else 3) + # return all blocks + for i in range(0, BLOCKLEN): + blocks.append(f"IN{i:02d}") + blocks.append("M00") + for i in range(0, BLOCKLEN): + blocks.append(f"OUT{i:02d}") + + elif sdversion == "v3": + for i in range(0, 24): + blocks.append(f"IN{i:02d}") + + elif sdversion == "FLUX": + for i in range(0, 19): + blocks.append(f"DOUBLE{i:02d}") + for i in range(0, 38): + blocks.append(f"SINGLE{i:02d}") + + return blocks + + def print_blocks(blocks): str = [] for i,x in enumerate(blocks): @@ -394,6 +478,18 @@ def print_blocks(blocks): n = int(x[14:len(x)-1]) block = f"OUT{n:02d}" str.append(block) + elif "single_blocks." in x: + n = int(x[14:len(x)-1]) + block = f"SINGLE{n:02d}" + str.append(block) + elif "double_blocks." in x: + n = int(x[14:len(x)-1]) + block = f"DOUBLE{n:02d}" + str.append(block) + elif "joint_blocks." in x: + n = int(x[13:len(x)-1]) + block = f"IN{n:02d}" + str.append(block) elif "cond_stage_model" in x or "conditioner." in x: block = f"BASE" str.append(block) @@ -422,27 +518,57 @@ def _selected_blocks_and_weights(mbw, isxl=False): sel_mbws.append(v) return sel_blocks, sel_mbws -def _weight_index(key, isxl=False): - num = -1 - offset = [ 0, 1, 13, 14 ] if not isxl else [ 0, 1, 10, 11 ] - base_prefix = "cond_stage_model." if not isxl else "conditioner." - for k, s in enumerate([ base_prefix, "input_blocks.", "middle_block.", "output_blocks." ]): - if s in key: + +def _weight_index(key, sdversion): + if sdversion in ["v1", "v2", "XL"]: + isxl = sdversion == "XL" + base_prefix = "cond_stage_model." if not isxl else "conditioner." + candidates = [ base_prefix, "input_blocks.", "middle_block.", "output_blocks." ] + elif sdversion == "v3": + candidates = [ "text_encoders.", "joint_blocks." ] + elif sdversion == "FLUX": + candidates = [ "text_encoders.", "double_blocks.", "single_blocks." ] + + all_blocks = _all_blocks(sdversion) + + for k, candidate in enumerate(candidates): + if candidate in key: if k == 0: return 0 # base - if k == 2: return offset[2] # middle_block + if candidate == "middle_block.": + if candidate in all_blocks: + return all_blocks.index(candidate) + else: + return -1 + + i = key.find(candidate) + if i > 0: + j = key.find(".", i + len(candidate)) + name = "{}{}.".format(candidate, key[i+len(candidate):j]) + + return all_blocks.index(name) + else: + return -1 - i = key.find(s) - j = key.find(".", i+len(s)) - num = int(key[i+len(s):j]) + offset[k] - return num + return -1 + + +def prune_model(state_dict, sdver): + keys = list(state_dict.keys()) + if sdver in ["v1", "v2", "XL"]: + acceptables = [ "diffusion_model.", "first_stage_model.", ] + base_prefix = "conditioner." if sdver == "XL" else "cond_stage_model." + acceptables += [ base_prefix] + elif sdver == "FLUX": + acceptables = [ "diffusion_model.", "vae.decoder.", "vae.encoder.", "vae.decoder." ] + acceptables += [ "text_encoders.clip_l.", "text_encoders.t5xxl." ] + elif sdver == "v3": + acceptables = [ "diffusion_model.", "vae.decoder.", "vae.encoder.", "vae.decoder." ] + acceptables += [ "text_encoders.clip_l.", "text_encoders.clip_g.", "text_encoders.t5xxl." ] -def prune_model(model, isxl=False): - keys = list(model.keys()) - base_prefix = "conditioner." if isxl else "cond_stage_model." for k in keys: - if "diffusion_model." not in k and "first_stage_model." not in k and base_prefix not in k: - model.pop(k, None) - return model + if all(acceptable not in k for acceptable in acceptables): + state_dict.pop(k, None) + return state_dict def to_half(tensor, enable): if enable and type(tensor) in [dict, collections.OrderedDict]: @@ -657,20 +783,21 @@ def is_xl(modelname): return None -def sdversion(modelname): - checkpointinfo = sd_models.get_closet_checkpoint_match(modelname) - if checkpointinfo is None: - return None +def sdversion(modelname_or_header): + if type(modelname_or_header) == str: + checkpointinfo = sd_models.get_closet_checkpoint_match(modelname_or_header) + if checkpointinfo is None: + return None - is_safetensors = getattr(checkpointinfo, "is_safetensors", None) - if is_safetensors is None: - checkpointinfo.is_safetensors = checkpointinfo.filename.endswith(".safetensors") - if checkpointinfo.is_safetensors: - header = get_safetensors_header(checkpointinfo.filename) - elif checkpointinfo.filename.endswith(".ckpt"): - header = get_ckpt_header(checkpointinfo.filename) - else: - return None + is_safetensors = getattr(checkpointinfo, "is_safetensors", None) + if is_safetensors is None: + checkpointinfo.is_safetensors = checkpointinfo.filename.endswith(".safetensors") + if checkpointinfo.is_safetensors: + header = get_safetensors_header(checkpointinfo.filename) + elif checkpointinfo.filename.endswith(".ckpt"): + header = get_ckpt_header(checkpointinfo.filename) + else: + return None if header is not None: if "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in header: @@ -838,22 +965,51 @@ def get_device(): return cuda_device -def unet_blocks_map(diffusion_model, isxl=False): +def unet_blocks_map(diffusion_model, sdver): block_map = {} - block_map['time_embed.'] = diffusion_model.time_embed - - BLOCKLEN = 12 - (0 if not isxl else 3) - for j in range(BLOCKLEN): - block_name = f"input_blocks.{j}." - block_map[block_name] = diffusion_model.input_blocks[j] - - block_map["middle_block."] = diffusion_model.middle_block - - for j in range(BLOCKLEN): - block_name = f"output_blocks.{j}." - block_map[block_name] = diffusion_model.output_blocks[j] - - block_map["out."] = diffusion_model.out + if sdver in ["v1", "v2", "XL"]: + isxl = sdver == "XL" + block_map['time_embed.'] = diffusion_model.time_embed + if isxl: + block_map['label_emb.'] = diffusion_model.label_emb + + BLOCKLEN = 12 - (0 if not isxl else 3) + for j in range(BLOCKLEN): + block_name = f"input_blocks.{j}." + block_map[block_name] = diffusion_model.input_blocks[j] + + block_map["middle_block."] = diffusion_model.middle_block + + for j in range(BLOCKLEN): + block_name = f"output_blocks.{j}." + block_map[block_name] = diffusion_model.output_blocks[j] + + block_map["out."] = diffusion_model.out + elif sdver == "v3": + # depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64 -> 24 + block_map['x_embedder.'] = diffusion_model.x_embedder + block_map['t_embedder.'] = diffusion_model.t_embedder + block_map['y_embedder.'] = diffusion_model.y_embedder + block_map['context_embedder.'] = diffusion_model.context_embedder + block_map['pos_embed'] = diffusion_model.pos_embed + for j in range(24): + block_name = f"joint_blocks.{j}." + block_map[block_name] = diffusion_model.joint_blocks[j] + block_map['final_layer.'] = diffusion_model.final_layer + elif sdver == "FLUX": + # based on https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/ldm/flux/model.py + block_map['img_in.'] = diffusion_model.img_in + block_map['time_in.'] = diffusion_model.time_in + block_map['vector_in.'] = diffusion_model.vector_in + block_map['guidance_in.'] = diffusion_model.guidance_in + block_map['txt_in.'] = diffusion_model.txt_in + for j in range(19): + block_name = f"double_blocks.{j}." + block_map[block_name] = diffusion_model.double_blocks[j] + for j in range(37): + block_name = f"single_blocks.{j}." + block_map[block_name] = diffusion_model.single_blocks[j] + block_map['final_layer.'] = diffusion_model.final_layer return block_map @@ -1072,7 +1228,7 @@ def initial_checkpoint(): create_refresh_button(base_model, mm_list_models,lambda: {"choices": ["None"]+sd_models.checkpoint_tiles()},"mm_refresh_base_model") with gr.Row(): enable_sync = gr.Checkbox(label="Sync with Default SD checkpoint", value=False, visible=True) - is_sdxl = gr.Checkbox(label="is SDXL", value=False, visible=True) + is_sdxl = gr.Radio(label="", choices=[("SDv1", "v1"), ("SDv2", "v2"), ("SDXL", "XL"), ("SDv3", "v3"), ("FLUX", "FLUX"),], value="v1", visible=True) with gr.Row(): calc_settings = gr.CheckboxGroup(label=f"Calculation options", info="Optional paramters for calculation if needed. e.g.) Rebasin", choices=[("Use GPU", "GPU"), ("Use CPU", "CPU"), ("Fast Rebasin", "fastrebasin"), ("use FP16 to reduce RAM usage", "usefp16"), ("Full merge", "full")], value=["GPU", "fastrebasin"]) @@ -1289,7 +1445,59 @@ def calc_weights_for_sum(*weights): in09 = gr.Slider(label="IN09", minimum=0.0, maximum=1, step=0.0001, value=0.5) in10 = gr.Slider(label="IN10", minimum=0.0, maximum=1, step=0.0001, value=0.5) in11 = gr.Slider(label="IN11", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in12 = gr.Slider(label="IN12", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in13 = gr.Slider(label="IN13", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in14 = gr.Slider(label="IN14", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in15 = gr.Slider(label="IN15", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in16 = gr.Slider(label="IN16", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in17 = gr.Slider(label="IN17", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in18 = gr.Slider(label="IN18", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in19 = gr.Slider(label="IN19", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in20 = gr.Slider(label="IN20", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in21 = gr.Slider(label="IN21", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in22 = gr.Slider(label="IN22", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in23 = gr.Slider(label="IN23", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in24 = gr.Slider(label="IN24", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in25 = gr.Slider(label="IN25", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in26 = gr.Slider(label="IN26", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in27 = gr.Slider(label="IN27", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in28 = gr.Slider(label="IN28", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in29 = gr.Slider(label="IN29", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in30 = gr.Slider(label="IN30", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in31 = gr.Slider(label="IN31", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in32 = gr.Slider(label="IN32", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in33 = gr.Slider(label="IN33", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in34 = gr.Slider(label="IN34", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in35 = gr.Slider(label="IN35", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in36 = gr.Slider(label="IN36", minimum=0.0, maximum=1, step=0.0001, value=0.5) + in37 = gr.Slider(label="IN37", minimum=0.0, maximum=1, step=0.0001, value=0.5) with gr.Column(scale=2, min_width=200): + ou37 = gr.Slider(label="OUT37", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou36 = gr.Slider(label="OUT36", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou35 = gr.Slider(label="OUT35", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou34 = gr.Slider(label="OUT34", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou33 = gr.Slider(label="OUT33", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou32 = gr.Slider(label="OUT32", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou31 = gr.Slider(label="OUT31", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou30 = gr.Slider(label="OUT30", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou29 = gr.Slider(label="OUT29", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou28 = gr.Slider(label="OUT28", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou27 = gr.Slider(label="OUT27", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou26 = gr.Slider(label="OUT26", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou25 = gr.Slider(label="OUT25", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou24 = gr.Slider(label="OUT24", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou23 = gr.Slider(label="OUT23", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou22 = gr.Slider(label="OUT22", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou21 = gr.Slider(label="OUT21", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou20 = gr.Slider(label="OUT20", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou19 = gr.Slider(label="OUT19", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou18 = gr.Slider(label="OUT18", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou17 = gr.Slider(label="OUT17", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou16 = gr.Slider(label="OUT16", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou15 = gr.Slider(label="OUT15", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou14 = gr.Slider(label="OUT14", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou13 = gr.Slider(label="OUT13", minimum=0.0, maximum=1, step=0.0001, value=0.5) + ou12 = gr.Slider(label="OUT12", minimum=0.0, maximum=1, step=0.0001, value=0.5) ou11 = gr.Slider(label="OUT11", minimum=0.0, maximum=1, step=0.0001, value=0.5) ou10 = gr.Slider(label="OUT10", minimum=0.0, maximum=1, step=0.0001, value=0.5) ou09 = gr.Slider(label="OUT09", minimum=0.0, maximum=1, step=0.0001, value=0.5) @@ -2268,12 +2476,13 @@ def resetvalopt(opt): return gr.update(value = value) def sync_main_checkpoint(enable_sync, model): - isxl = is_xl(model) - ret = gr.update(value=True if isxl else False, interactive=True if isxl is None else False) + sdv = sdversion(model) + if sdv == "v2": + sdv = "v1" # same way prepare_model(model) - return ret + return gr.update(value=sdv, interactive=True if sdv == "v1" else False) def import_image_from_gallery(gallery): prompt = "" @@ -2390,7 +2599,17 @@ def model_metadata(model): if self.init_on_after_callback is False: script_callbacks.on_after_component(on_after_components) - members = [base,in00,in01,in02,in03,in04,in05,in06,in07,in08,in09,in10,in11,mi00,ou00,ou01,ou02,ou03,ou04,ou05,ou06,ou07,ou08,ou09,ou10,ou11] + members = [base, + in00, in01, in02, in03, in04, in05, in06, in07, in08, in09, + in10, in11, in12, in13, in14, in15, in16, in17, in18, in19, + in20, in21, in22, in23, in24, in25, in26, in27, in28, in29, + in30, in31, in32, in33, in34, in35, in36, in37, + mi00, + ou00, ou01, ou02, ou03, ou04, ou05, ou06, ou07, ou08, ou09, + ou10, ou11, ou12, ou13, ou14, ou15, ou16, ou17, ou18, ou19, + ou20, ou21, ou22, ou23, ou24, ou25, ou26, ou27, ou28, ou29, + ou30, ou31, ou32, ou33, ou34, ou35, ou36, ou37, + ] def update_slider_range(advanced_mode): if advanced_mode: @@ -2577,20 +2796,38 @@ def finetune_reader(finetune): col2.release(fn=finetune_update, inputs=[mm_finetune, *finetunes], outputs=mm_finetune, show_progress=False) col3.release(fn=finetune_update, inputs=[mm_finetune, *finetunes], outputs=mm_finetune, show_progress=False) - def config_sdxl(isxl, num_models): - if isxl: - BLOCKS = BLOCKIDXL - else: - BLOCKS = BLOCKID + def config_sliders(sdver, num_models): + BLOCKS = all_blocks(sdver) ret = [gr.update(choices=BLOCKS)] - ret += [gr.update(visible=True) for _ in range(26)] if not isxl else [gr.update(visible=ISXLBLOCK[i]) for i in range(26)] - choices = ["ALL","BASE","INP*","MID","OUT*"]+BLOCKID[1:] if not isxl else ["ALL","BASE","INP*","MID","OUT*"]+BLOCKIDXL[1:] + + is_selected_blocks = is_block_sliders(sdver) + labs = [] + k = 0 + + for j, i in enumerate(is_selected_blocks): + if i: + labs.append(BLOCKS[k]) + k += 1 + else: + labs.append(f"NONAME{j:02d}") + + ret += [gr.update(visible=is_selected_blocks[i], label=labs[i]) for i in range(len(is_selected_blocks))] + if sdver in ["v1", "v2", "XL"]: + choices = ["ALL","BASE","INP*","MID","OUT*"] + BLOCKS[1:] + else: + choices = ["ALL","BASE"] + BLOCKS[1:] ret += [gr.update(choices=choices) for _ in range(num_models)] - last = 11 if not isxl else 8 - info = f"Merge Block Weights: BASE,IN00,IN02,...IN{last:02d},M00,OUT00,...,OUT{last:02d}" + if sdver in ["v1", "v2", "XL"]: + last = 11 if sdver != "XL" else 8 + info = f"Merge Block Weights: BASE,IN00,IN02,...,IN{last:02d},M00,OUT00,...,OUT{last:02d}" + elif sdver == "v3": + info = f"Merge Block Weights: BASE,IN00,IN02,...,IN23" + elif sdver == "FLUX": + info = f"Merge Block Weights: BASE,DOUBLE00,DOUBLE01,...,SINGLE00,SINGLE02,...,SINGLE37" ret += [gr.update(label=info) for _ in range(num_models)] return ret + def select_block_elements(blocks): # change choices for selected blocks elements = [] @@ -2659,7 +2896,7 @@ def read_elemental(elemental): elemental_write.click(fn=write_elemental, inputs=[not_elemblks, not_elements, elemblks, elements, elemental_ratio, mm_elemental_main], outputs=mm_elemental_main) elemental_read.click(fn=read_elemental, inputs=mm_elemental_main, outputs=[not_elemblks, not_elements, elemblks, elements, elemental_ratio]) - is_sdxl.change(fn=config_sdxl, inputs=[is_sdxl, mm_max_models], outputs=[elemblks, *members, *mm_usembws, *mm_weights]) + is_sdxl.change(fn=config_sliders, inputs=[is_sdxl, mm_max_models], outputs=[elemblks, *members, *mm_usembws, *mm_weights]) resetopt.change(fn=resetvalopt, inputs=[resetopt], outputs=[resetval]) resetweight.click(fn=resetblockweights, inputs=[resetval,resetblockopt], outputs=members) @@ -3068,6 +3305,15 @@ def _update_model_list(max_models): show_progress=False, ) + # init mbw sliders + demo.load( + fn=config_sliders, + inputs=[is_sdxl, mm_max_models], + outputs=[elemblks, *members, *mm_usembws, *mm_weights], + show_progress=False, + queue=False, + ) + self.init_on_app_started = True @@ -3398,9 +3644,9 @@ def before_process(self, p, enabled, model_a, base_model, mm_max_models, mm_fine all_elemental_blocks = set(all_elemental_blocks) if "elemental merge" in debugs: print(" Elemental: all elemental blocks = ", all_elemental_blocks) - def selected_elemental_blocks(blocks, isxl): - max_blocks = 26 - (0 if not isxl else 6) - BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL + def selected_elemental_blocks(blocks, sdver): + BLOCKIDS = all_blocks(sdver) + max_blocks = len(BLOCKIDS) elemental_selected = [False] * max_blocks for j, b in enumerate(BLOCKIDS): if b in blocks: @@ -3461,11 +3707,8 @@ def load_state_dict(checkpoint_info): # check SDXL, FLUX etc. sdv = sdversion(model_a) isxl = sdv == 'XL' - isflux = sdv == 'FLUX' - isv3 = sdv == 'v3' - isv20 = sdv == 'v2' - print("isxl =", isxl, ", sd2 =", isv20, ", sd3 =", isv3, ", flux =", isflux) + print("sdversion =", sdv) # check base_model use_safe_open = shared.opts.data.get("mm_use_safe_open", False) @@ -3479,7 +3722,7 @@ def load_state_dict(checkpoint_info): # get all selected elemental blocks elemental_selected = [] if len(all_elemental_blocks) > 0: - elemental_selected = selected_elemental_blocks(all_elemental_blocks, isxl) + elemental_selected = selected_elemental_blocks(all_elemental_blocks, sdv) # prepare for merges compact_mode = None @@ -3487,11 +3730,12 @@ def load_state_dict(checkpoint_info): for j, model in enumerate(mm_models): if len(mm_usembws[j]) > 0: # normalize Merge block weights - mm_weights[j], compact_mbws, mm_selected[j] = calc_mbws(mm_weights[j], mm_usembws[j], isxl=isxl) + mm_weights[j], dummy, mm_selected[j] = calc_mbws(mm_weights[j], mm_usembws[j], sdv) compact_mode = True if compact_mode is None else compact_mode else: compact_mode = False - max_blocks = 26 - (0 if not isxl else 6) + + max_blocks = len(all_blocks(sdv)) mm_weights[j] = [mm_alpha[j]] * max_blocks # fix mm_weights to use poinpoint blocks xyz @@ -3508,7 +3752,9 @@ def load_state_dict(checkpoint_info): # get overall selected blocks if compact_mode: - max_blocks = 26 - (0 if not isxl else 6) + BLOCKIDS = all_blocks(sdv) + max_blocks = len(BLOCKIDS) + selected_blocks = [] mm_selected_all = [False] * max_blocks for j in range(len(mm_models)): @@ -3520,11 +3766,10 @@ def load_state_dict(checkpoint_info): for k in range(max_blocks): mm_selected_all[k] = mm_selected_all[k] or elemental_selected[k] - all_blocks = _all_blocks(isxl) - BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL + allblocks = _all_blocks(sdv) # get all blocks affected by same perm groups by rebasin merge - if not isxl and not isv20 and "Rebasin" in mm_calcmodes: + if sdv == "v1" and "Rebasin" in mm_calcmodes: print("check affected permutation blocks by rebasin merge...") jj = 0 while True: @@ -3539,10 +3784,10 @@ def load_state_dict(checkpoint_info): for k in range(max_blocks): if mm_selected_all[k]: - selected_blocks.append(all_blocks[k]) + selected_blocks.append(allblocks[k]) else: # no compact mode, get all blocks - selected_blocks = _all_blocks(isxl) + selected_blocks = _all_blocks(sdv) print("compact_mode = ", compact_mode) @@ -3594,15 +3839,27 @@ def load_state_dict(checkpoint_info): keyremains = [] if compact_mode: # get keylist of all selected blocks - base_prefix = "cond_stage_model." if not isxl else "conditioner." + candidates = ["diffusion_model."] + if sdv in ["v1", "v2", "XL"]: + if sdv in ["v1", "v2"]: + candidates += ["cond_stage_model."] + else: + candidates += ["conditioner."] + elif sdv == "FLUX": + #candidates += ["text_encoders.clip_l.", "text_encoders.t5xxl."]: + candidates += ["text_encoders."] + elif sdv == "v3": + #candidates += ["text_encoders.clip_l.", "text_encoders.clip_g.", "text_encoders.t5xxl."]: + candidates += ["text_encoders."] + for k in models['model_a'].keys(): keyadded = False for s in selected_blocks: - if s not in ["cond_stage_model.", "conditioner."]: + if s not in ["cond_stage_model.", "conditioner.", "text_encoders."]: s = f"model.diffusion_model.{s}" if s in k: # ignore all non block releated keys - if "diffusion_model." not in k and base_prefix not in k: + if all(candidate not in k for candidate in candidates): continue keys.append(k) theta_0[k] = models['model_a'][k] @@ -3611,10 +3868,48 @@ def load_state_dict(checkpoint_info): keyremains.append(k) # add some missing extra_elements - last_block = "output_blocks.11." if not isxl else "output_blocks.8." + BLOCKIDS = all_blocks(sdv) + allblocks = _all_blocks(sdv) + + last_block = allblocks[len(BLOCKIDS)-1] if use_extra_elements and (last_block in selected_blocks) or ("" in all_elemental_blocks): - selected_blocks += [ "time_embed.", "out." ] - for el in [ "time_embed.0.bias", "time_embed.0.weight", "time_embed.2.bias", "time_embed.2.weight", "out.0.bias", "out.0.weight", "out.2.bias", "out.2.weight" ]: + if sdv in ["v1", "v2", "XL"]: + selected_blocks += [ "time_embed.", "out." ] + add_extra_elements = [ + "time_embed.0.bias", "time_embed.0.weight", + "time_embed.2.bias", "time_embed.2.weight", + "out.0.bias", "out.0.weight", + "out.2.bias", "out.2.weight", + ] + elif sdv == "FLUX": + selected_blocks += [ "img_in.", "time_in.", "vector_in.", "guidance_in.", "txt_in.", "final_layer." ] + add_extra_elements = [ + "img_in.bias", "img_in.weight", + "time_in.in_layer.bias", "time_in.in_layer.weight", + "time_in.out_layer.bias", "time_in.out_layer.weight", + "vector_in.in_layer.bias", "vector_in.in_layer.weight", + "vector_in.out_layer.bias", "vector_in.out_layer.weight", + "guidance_in.in_layer.bias", "guidance_in.in_layer.weight", + "guidance_in.out_layer.bias", "guidance_in.out_layer.weight", + "txt_in.bias", "txt_in.weight", + "final_layer.adaLN_modulation.1.bias", "final_layer.adaLN_modulation.1.weight", + "final_layer.linear.bias", "final_layer.linear.weight", + ] + elif sdv == "v3": + selected_blocks += [ "x_embedder.", "t_embedder.", "y_embedder.", "context_embedder.", "pos_embed", "final_layer." ] + add_extra_elements = [ + "x_embedder.proj.bias", "x_embedder.proj.weight", + "t_embedder.mlp.0.bias", "t_embedder.mlp.0.weight", + "t_embedder.mlp.2.bias", "t_embedder.mlp.2.weight", + "y_embedder.mlp.0.bias", "y_embedder.mlp.0.weight", + "y_embedder.mlp.2.bias", "y_embedder.mlp.2.weight", + "context_embedder.bias", "context_embedder.weight", + "pos_embed", + "final_layer.adaLN_modulation.1.bias", "final_layer.adaLN_modulation.1.weight", + "final_layer.linear.bias", "final_layer.linear.weight", + ] + + for el in add_extra_elements: k = f"model.diffusion_model.{el}" keys.append(k) theta_0[k] = models['model_a'][k] @@ -3628,7 +3923,7 @@ def load_state_dict(checkpoint_info): theta_0 = {k: v for k, v in models['model_a'].items()} # check finetune - if mm_finetune.rstrip(",0") != "": + if sdv in ["v1", "v2", "XL"] and mm_finetune.rstrip(",0") != "": fines = fineman(mm_finetune, isxl) if fines is not None: for tune_block in [ "input_blocks.0.", "out."]: @@ -3801,7 +4096,8 @@ def dare_merge(theta0, theta1, alpha, density, rescale=True, mode='random'): if first_model_is_the_same: print(" - check possible UNet partial update...") - max_blocks = 26 - (0 if not isxl else 6) + BLOCKIDS = all_blocks(sdv) + max_blocks = len(BLOCKIDS) # check changed weights weights = current["weights"] @@ -3842,27 +4138,33 @@ def dare_merge(theta0, theta1, alpha, density, rescale=True, mode='random'): if len(mm_weights) > j: changed |= np.array(mm_weights[j][:max_blocks]) != np.array([0.0]*max_blocks) - BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL print(" - partial changed blocks = ", [BLOCKIDS[k] for k, b in enumerate(changed) if b]) - all_blocks = _all_blocks(isxl) + allblocks = _all_blocks(sdv) weight_changed_blocks = [] for j, b in enumerate(changed): # recalculate all elemental blocks if len(elemental_selected) > 0: b |= elemental_selected[j] if b: - weight_changed_blocks.append(all_blocks[j]) + weight_changed_blocks.append(allblocks[j]) # check ".out.", ".time_embed." elemental blocks if "" in all_elemental_blocks: # always update elemental blocks - weight_changed_blocks += ["time_embed.", "out."] + weight_changed_blocks += allblocks[max_blocks:] elif changed[max_blocks - 1]: # last block changed. add time_embed. and out. - weight_changed_blocks += ["time_embed.", "out."] + weight_changed_blocks += allblocks[max_blocks:] + + # fix for compatiblity + if sdv == "XL" and weight_changed_blocks[-1] == "label_emb.": + # FIXME + weight_changed_blocks = weight_changed_blocks[:-1] # check finetune - finetune_changed = current["adjust"] != mm_finetune + finetune_changed = False + if sdv in ["v1", "v2", "XL"]: + finetune_changed = current["adjust"] != mm_finetune if finetune_changed: # add "input_blocks.0.", "out." for finetune @@ -3875,7 +4177,7 @@ def dare_merge(theta0, theta1, alpha, density, rescale=True, mode='random'): # get changed keys for k in keys: for s in weight_changed_blocks: - if s not in ["cond_stage_model.", "conditioner."]: + if s not in ["cond_stage_model.", "conditioner.", "text_encoders."]: ss = f"model.diffusion_model.{s}" else: ss = s @@ -3895,7 +4197,7 @@ def dare_merge(theta0, theta1, alpha, density, rescale=True, mode='random'): print(" - No change blocks detected.") # check Rebasin mode - if not isxl and not isv20 and "Rebasin" in calcmodes: + if sdv == "v1" and "Rebasin" in calcmodes: fullmatching = "fastrebasin" not in calc_settings print(" - Dynamic loading rebasin module...") load_module(os.path.join(scriptdir, "scripts", "rebasin", "weight_matching.py")) @@ -3934,7 +4236,7 @@ def dare_merge(theta0, theta1, alpha, density, rescale=True, mode='random'): shared.state.job_count = 0 if len(mm_models) > 0 and len(keys) > 0: shared.state.job_count += len(mm_models) - if not isxl and "Rebasin" in calcmodes: + if sdv == "v1" and "Rebasin" in calcmodes: shared.state.job_count += 1 sel_keys = changed_keys if partial_update else keys @@ -3953,6 +4255,8 @@ def dare_merge(theta0, theta1, alpha, density, rescale=True, mode='random'): item = theta_0.pop(k) keyremains.append(k) + BLOCKIDS = all_blocks(sdv) + timer.record("prepare") stage = 1 theta_1 = None @@ -4054,7 +4358,7 @@ def dare_merge(theta0, theta1, alpha, density, rescale=True, mode='random'): if "model" in key and key in theta_1: if usembw: - i = _weight_index(key, isxl=isxl) + i = _weight_index(key, sdv) if i == -1: if use_extra_elements and any(s in key for s in extra_elements.keys()): # FIXME @@ -4068,7 +4372,7 @@ def dare_merge(theta0, theta1, alpha, density, rescale=True, mode='random'): if i < 0: name = "" # empty block name -> extra elements else: - name = BLOCKID[i] if not isxl else BLOCKIDXL[i] + name = BLOCKIDS[i] ws = mm_elementals[n].get(name, None) new_alpha = None if ws is not None: @@ -4178,7 +4482,7 @@ def cosim(theta0, theta1, calcmode): print(f" +{k}") theta_0[key] = theta_1[key] - if not isxl and not isv20 and "Rebasin" in calcmodes[n]: + if sdv == "v1" and "Rebasin" in calcmodes[n]: print("Rebasin calc...") # rebasin mode # Replace theta_0 with a permutated version using model A and B @@ -4259,7 +4563,7 @@ def make_recipe(modes, model_a, models): if len(remains_blocks) > 0: for k in keyremains: for remain in remains_blocks: - if remain not in ["cond_stage_model.", "conditioner."]: + if remain not in ["cond_stage_model.", "conditioner.", "text_encoders."]: r = f"model.diffusion_model.{remain}" else: r = remain @@ -4269,7 +4573,7 @@ def make_recipe(modes, model_a, models): break # apply finetune - if mm_finetune.rstrip(",0") != "": + if sdv in ["v1", "v2", "XL"] and mm_finetune.rstrip(",0") != "": fines = fineman(mm_finetune, isxl) if fines is not None: old_finetune = shared.opts.data.get("mm_use_old_finetune", False) @@ -4339,7 +4643,7 @@ def fake_checkpoint(checkpoint_info, metadata, model_name, sha256, fake=True): return checkpoint_info # fix/check bad CLIP ids - fixclip(theta_0, mm_states["save_settings"], isxl) + fixclip(theta_0, mm_states["save_settings"], sdv) timer.record("merging") shared.state.textinfo = "Merge completed..." @@ -4348,7 +4652,7 @@ def fake_checkpoint(checkpoint_info, metadata, model_name, sha256, fake=True): t = Timer() if not partial_update and shared.opts.data.get("mm_use_precalculate_hash", False): shared.state.textinfo = "Precalculate model hash..." - sha256 = precalculate_safetensors_hashes(theta_0, metadata.copy(), mm_states["save_settings"], isxl) + sha256 = precalculate_safetensors_hashes(theta_0, metadata.copy(), mm_states["save_settings"], sdv) print(" - precalculated hash = ", sha256) t.record("precalculate hash") @@ -4396,27 +4700,27 @@ def fake_checkpoint(checkpoint_info, metadata, model_name, sha256, fake=True): send_model_to_cpu(shared.sd_model) sd_hijack.model_hijack.undo_hijack(shared.sd_model) - if "cond_stage_model." in weight_changed_blocks or "conditioner." in weight_changed_blocks: + allblocks = _all_blocks(sdv) + if allblocks[0] in weight_changed_blocks: # Textencoder(BASE) - if isxl: - prefix = "conditioner." - else: - prefix = "cond_stage_model." + prefix = allblocks[0] base_dict = {} for k in weight_changed[prefix]: # remove prefix, 'cond_stage_model.' or 'conditioner.' will be removed key = k[len(prefix):] if k in theta_0: base_dict[key] = theta_0[k] - if isxl: + if prefix == "conditioner.": shared.sd_model.conditioner.load_state_dict(base_dict, strict=False) - else: + elif prefix == "cond_stage_model.": shared.sd_model.cond_stage_model.load_state_dict(base_dict, strict=False) + elif prefix == "text_encoders.": + shared.sd_model.text_encoders.load_state_dict(base_dict, strict=False) print(" - \033[92mTextencoder(BASE) has been successfully updated\033[0m") shared.state.textinfo = "Update Textencoder..." # get unet_blocks_map - unet_map = unet_blocks_map(shared.sd_model.model.diffusion_model, isxl) + unet_map = unet_blocks_map(shared.sd_model.model.diffusion_model, sdv) # partial update unet blocks state_dict unet_updated = 0 @@ -4594,7 +4898,7 @@ def list_dirs(parent="None"): return None -def precalculate_safetensors_hashes(state_dict, metadata, save_settings, isxl, fixheader=True): +def precalculate_safetensors_hashes(state_dict, metadata, save_settings, sdver, fixheader=True): import safetensors if "sd_merge_models" in metadata: @@ -4605,7 +4909,7 @@ def precalculate_safetensors_hashes(state_dict, metadata, save_settings, isxl, f if "fp16" in save_settings: state_dict = to_half(state_dict, True) if "prune" in save_settings: - state_dict = prune_model(state_dict, isxl) + state_dict = prune_model(state_dict, sdver) hash_sha256 = hashlib.sha256() bytes = safetensors.torch.save(state_dict, metadata) @@ -4666,10 +4970,15 @@ def precalculate_safetensors_hashes(state_dict, metadata, save_settings, isxl, f return hash_sha256.hexdigest() -def fixclip(theta_0, settings, isxl): +def fixclip(theta_0, settings, sdver): """fix/check bad CLIP ids""" - base_prefix = "cond_stage_model." if not isxl else "conditioner." - position_id_key = f"{base_prefix}transformer.text_model.embeddings.position_ids" + if sdver in ["v1", "v2", "XL"]: + base_prefix = "cond_stage_model." if sdver != "XL" else "conditioner." + position_id_key = f"{base_prefix}transformer.text_model.embeddings.position_ids" + elif sdver in ["v3", "FLUX"]: + # ignore + return + if position_id_key in theta_0: correct = torch.tensor([list(range(77))], dtype=torch.int64, device="cpu") current = theta_0[position_id_key].to(torch.int64) @@ -4843,7 +5152,12 @@ def extract_lora_from_current_model(save_lora_mode, model_orig, model_tuned, dif if is_equal: return gr.update(value="No difference found") - isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in state_dict_base + sdv = sdversion(state_dict_base) + isxl = sdv == "XL" + if sdv not in ["v1", "v2", "XL"]: + err_msg = f"only SDv1, SDv2, SDXL are supported." + print(err_msg) + return gr.update(value=err_msg) gc.collect() devices.torch_gc() @@ -4947,9 +5261,7 @@ def extract_lora_from_current_model(save_lora_mode, model_orig, model_tuned, dif "ss_network_alpha": str(float(lora_dim)), "ss_output_name": custom_name, } - v2 = False - if 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight' in state_dict_base: - v2 = state_dict_base['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1] == 1024 + v2 = sdv == "v2" metadata["ss_v2"] = str(v2) v_parameterization = v2 try: @@ -5027,13 +5339,10 @@ def save_as_diffusers(custom_name, save_settings, metadata_settings, state_dict= print(" - \033[92mget the merged model\033[0m...") state_dict = get_current_state_dict(lora=False, base=True)[0] - v2 = False - if 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight' in state_dict: - v2 = state_dict['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1] == 1024 - print(" v2 = ", v2) - - isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in state_dict - print(" isxl = ", isxl) + sdv = sdversion(state_dict) + isxl = sdv == "XL" + v2 = sdv == "v2" + print(" isxl = ", isxl, ", v2 = ", v2) pipeline_type = None scheduler_type = "pndm" @@ -5068,7 +5377,7 @@ def save_as_diffusers(custom_name, save_settings, metadata_settings, state_dict= # fix/check bad CLIP ids - fixclip(state_dict, save_settings, isxl) + fixclip(state_dict, save_settings, sdv) # for safetensors contiguous error print(" - check contiguous...") @@ -5193,9 +5502,11 @@ def save_current_model(custom_name, bake_in_vae, save_settings, metadata_setting del vae_dict print("Saving...") - isxl = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in state_dict - print("isxl = ", isxl) + sdver = sdversion(state_dict) + isxl = sdver == "XL" + print("sdver = ", sdver) if isxl: + # FIXME XXX # prune share memory tensors, "cond_stage_model." prefixed base tensors are share memory with "conditioner." prefixed tensors for i, key in enumerate(state_dict.keys()): if "cond_stage_model." in key: @@ -5204,10 +5515,10 @@ def save_current_model(custom_name, bake_in_vae, save_settings, metadata_setting if "fp16" in save_settings: state_dict = to_half(state_dict, True) if "prune" in save_settings: - state_dict = prune_model(state_dict, isxl) + state_dict = prune_model(state_dict, sdver) # fix/check bad CLIP ids - fixclip(state_dict, save_settings, isxl) + fixclip(state_dict, save_settings, sdver) # for safetensors contiguous error print(" - check contiguous...") @@ -5690,22 +6001,31 @@ def prepare_model(model): elemental_blocks = prepare_elemental_blocks(model) else: # check settings again - isxl = is_xl(model) - if isxl: + sdv = sdversion(model).lower() + reload = False + if sdv == "v3": + if elemental_blocks.get("IN23", None) is None: + reload = True + elif sdv == "xl": if elemental_blocks.get("IN09", None) is not None: - # read elements-xl.json - elemental_blocks = prepare_elemental_blocks(model) - else: - if elemental_blocks.get("IN09", None) is None: - # read elements.json - elemental_blocks = prepare_elemental_blocks(model) + reload = True + elif elemental_blocks.get("IN08", None) is None: + reload = True + elif sdv == "flux": + if elemental_blocks.get("SINGLE00", None) is None: + reload = True + + if reload: + # reload elemental info. + elemental_blocks = prepare_elemental_blocks(model) + def prepare_elemental_blocks(model=None, force=False): if model is not None: - isxl = is_xl(model) + sdv = sdversion(model).lower() else: - isxl = False - elemdata = "elements.json" if not isxl else "elements-xl.json" + sdv = None + elemdata = "elements.json" if sdv is None or sdv == "v1" else f"elements-{sdv}.json" elempath = os.path.join(scriptdir, "data", elemdata) if not os.path.exists(os.path.join(scriptdir, "data")): os.makedirs(os.path.join(scriptdir, "data")) @@ -5760,10 +6080,29 @@ def prepare_elemental_blocks(model=None, force=False): return elements + def get_blocks_elements(res): import collections - blockmap = { "input_blocks": "IN", "output_blocks": "OUT", "middle_block": "M" } + ver = 'v1' + if "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in res: + ver = 'XL' + elif "model.diffusion_model.context_embedder.weight" in res: + ver = 'v3' + elif "model.diffusion_model.double_blocks.0.img_attn.proj.weight" in res: + ver = 'FLUX' + + if ver == "v1": + blockmap = {"input_blocks": "IN", "output_blocks": "OUT", "middle_block": "M",} + else: + blockmaps = { + "v3": {"joint_blocks": "IN",}, + "XL": {"input_blocks": "IN", "output_blocks": "OUT", "middle_block": "M",}, + "FLUX": {"single_blocks": "SINGLE", "double_blocks": "DOUBLE",}, + } + blockmap = blockmaps[ver] + + blocks = list(blockmap.keys()) key_re = re.compile(r"^(?:\d+\.)?(.*?)(?:\.\d+)?$") key_split_re = re.compile(r"\.\d+\.") @@ -5788,7 +6127,7 @@ def get_blocks_elements(res): name = None # only for block level keys - if any(item in k for item in ["input_blocks.", "output_blocks.", "middle_block."]): + if any(item in k for item in blocks): tmp = k.split(".",2) num = int(tmp[1]) name = f"{blockmap[tmp[0]]}{num:02d}" @@ -5834,6 +6173,7 @@ def get_blocks_elements(res): return sorted_elements + def prepblocks(blocks, blockids, select=True): #blocks = sorted(set(blocks)) # one liner block sorter if len(blocks) == 0 or (len(blocks) == 1 and blocks[0] == '*'): From da41441346270e93f408ce87677c219e16edd30e Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sat, 24 Aug 2024 20:44:01 +0900 Subject: [PATCH 02/10] fix automerger --- sd_modelmixer/hyper.py | 51 ++----------- sd_modelmixer/utils.py | 163 +++++++++++++++++++++++++++++++++++------ 2 files changed, 148 insertions(+), 66 deletions(-) diff --git a/sd_modelmixer/hyper.py b/sd_modelmixer/hyper.py index ce97ff4..4bae351 100644 --- a/sd_modelmixer/hyper.py +++ b/sd_modelmixer/hyper.py @@ -17,14 +17,14 @@ from .classifier import get_classifiers, classifier_score from .optimizers import optimizer_types -from .utils import all_blocks, _all_blocks, load_module +from .utils import all_blocks, _all_blocks, normalize_blocks, load_module classifiers = get_classifiers() def para_to_weights(para, weights=None, alpha=None, isxl=False): - BLOCKS = all_blocks(isxl) - BLOCKLEN = (12 if not isxl else 9)*2 + 2 + BLOCKIDS = all_blocks(isxl) + BLOCKLEN = len(BLOCKIDS) weights = {} if weights is None else dict(zip(range(len(weights)), weights)) alpha = {} @@ -36,7 +36,7 @@ def para_to_weights(para, weights=None, alpha=None, isxl=False): continue weight = weights.get(modelidx, [0.0]*BLOCKLEN) - j = BLOCKS.index(name[1]) + j = BLOCKIDS.index(name[1]) weight[j] = para[k] weights[modelidx] = weight @@ -55,43 +55,6 @@ def para_to_weights(para, weights=None, alpha=None, isxl=False): return nweights, nalpha -def normalize_mbw(mbw, isxl): - """Normalize Merge Block Weights""" - MAXLEN = 26 - (0 if not isxl else 6) - BLOCKLEN = 12 - (0 if not isxl else 3) - - # no mbws blocks selected or have 'ALL' alias - if len(mbw) == 0 or 'ALL' in mbw: - # select all blocks - mbw = [ 'BASE', 'INP*', 'MID', 'OUT*' ] - - # fix alias - if 'MID' in mbw: - i = mbw.index('MID') - mbw[i] = 'M00' - - # expand some aliases - if 'INP*' in mbw: - for i in range(BLOCKLEN): - name = f"IN{i:02d}" - if name not in mbw: - mbw.append(name) - if 'OUT*' in mbw: - for i in range(BLOCKLEN): - name = f"OUT{i:02d}" - if name not in mbw: - mbw.append(name) - - BLOCKS = all_blocks(isxl) - - sort = [] - for b in BLOCKS[:MAXLEN]: - if b in mbw: - sort.append(b) - - return sort - - def unquote(text): if len(text) == 0 or text[0] != '"' or text[-1] != '"': return text @@ -370,7 +333,7 @@ def hyper_score(localargs): # setup search space search_space = {} if variable_blocks is not None and len(variable_blocks) > 0: - variable_blocks = normalize_mbw(variable_blocks, isxl) + variable_blocks, _ = normalize_blocks(variable_blocks, isxl) else: variable_blocks = None @@ -408,7 +371,7 @@ def hyper_score(localargs): continue weight = weights[k] - mbw = normalize_mbw(usembws[k], isxl) + mbw, _ = normalize_blocks(usembws[k], isxl) for b in selected_blocks: j = blocks.index(b) if j < len(weight) and _BLOCKS[j] in mbw: @@ -456,7 +419,7 @@ def hyper_score(localargs): continue weight = _weights[k] - mbw = normalize_mbw(_usembws[k], isxl) + mbw, _ = normalize_blocks(_usembws[k], isxl) for b in _selected_blocks: j = blocks.index(b) if j < len(weight) and _BLOCKS[j] in mbw: diff --git a/sd_modelmixer/utils.py b/sd_modelmixer/utils.py index 80239e0..317dcf7 100644 --- a/sd_modelmixer/utils.py +++ b/sd_modelmixer/utils.py @@ -14,35 +14,154 @@ scriptdir = basedir() -def all_blocks(isxl=False): - BLOCKLEN = 12 if not isxl else 9 - # return all blocks - blocks = [ "BASE" ] - for i in range(0, BLOCKLEN): - blocks.append(f"IN{i:02d}") - blocks.append("M00") - for i in range(0, BLOCKLEN): - blocks.append(f"OUT{i:02d}") - - blocks += [ "TIME_EMBED", "OUT" ] + +def all_blocks(sdversion): + """simple BLOCKIDS""" + + if type(sdversion) is bool: + # for old behavior called by all_blocks(isxl) + sdversion = "XL" if sdversion else "v1" + + blocks = [ 'BASE' ] + if sdversion in ["v1", "v2", "XL"]: + BLOCKLEN = 12 - (0 if sdversion != "XL" else 3) + # return all blocks + for i in range(0, BLOCKLEN): + blocks.append(f"IN{i:02d}") + blocks.append("M00") + for i in range(0, BLOCKLEN): + blocks.append(f"OUT{i:02d}") + + elif sdversion == "v3": + for i in range(0, 24): + blocks.append(f"IN{i:02d}") + + elif sdversion == "FLUX": + for i in range(0, 19): + blocks.append(f"DOUBLE{i:02d}") + for i in range(0, 38): + blocks.append(f"SINGLE{i:02d}") + return blocks -def _all_blocks(isxl=False): - BLOCKLEN = 12 - (0 if not isxl else 3) - # return all blocks - base_prefix = "cond_stage_model." if not isxl else "conditioner." - blocks = [ base_prefix ] - for i in range(0, BLOCKLEN): - blocks.append(f"input_blocks.{i}.") - blocks.append("middle_block.") - for i in range(0, BLOCKLEN): - blocks.append(f"output_blocks.{i}.") +def _all_blocks(sdversion): + """1:1 mapping BLOCKIDS to tensor keys""" + + if type(sdversion) is bool: + # for old behavior called by all_blocks(isxl) + sdversion = "XL" if sdversion else "v1" + + if sdversion is True: + # for old behavior called by _all_blocks(isxl) + sdversion = "XL" + + if sdversion in ["v1", "v2", "XL"]: + BLOCKLEN = 12 - (0 if sdversion != "XL" else 3) + # return all blocks + base_prefix = "cond_stage_model." if sdversion != "XL" else "conditioner." + blocks = [ base_prefix ] + for i in range(0, BLOCKLEN): + blocks.append(f"input_blocks.{i}.") + blocks.append("middle_block.") + for i in range(0, BLOCKLEN): + blocks.append(f"output_blocks.{i}.") + + blocks += [ "time_embed.", "out." ] + if sdversion == "XL": + blocks += [ "label_emb." ] + + elif sdversion == "v3": + #blocks = [ "text_encoders.clip_l.", "text_encoders.clip_g.", "text_encoders.t5xxl." ] + blocks = [ "text_encoders." ] + for i in range(0, 24): + blocks.append(f"joint_blocks.{i}.") + + blocks += [ "x_embedder.", "t_embedder.", "y_embedder.", "context_embedder.", "pos_embed", "final_layer." ] + + elif sdversion == "FLUX": + #blocks = [ "text_encoders.clip_l.", "text_encoders.t5xxl." ] + blocks = [ "text_encoders." ] + for i in range(0, 19): + blocks.append(f"double_blocks.{i}.") + for i in range(0, 38): + blocks.append(f"single_blocks.{i}.") + + blocks += [ "img_in.", "time_in.", "vector_in.", "guidance_in.", "txt_in.", "final_layer." ] - blocks += [ "time_embed.", "out." ] return blocks +def normalize_blocks(blocks, sdv): + """Normalize Merge Block Weights""" + + if type(sdv) is bool: + # for old behavior + sdv = "XL" if sdv else "v1" + + # no mbws blocks selected or have 'ALL' alias + if len(blocks) == 0 or 'ALL' in blocks: + # select all blocks + blocks = [ 'BASE', 'INP*', 'MID', 'OUT*' ] + + # fix alias + if 'MID' in blocks: + i = blocks.index('MID') + blocks[i] = 'M00' + + if sdv in ["v1", "v2", "XL"]: + isxl = sdv == "XL" + BLOCKLEN = 12 - (0 if not isxl else 3) + + # expand some aliases + if 'INP*' in blocks: + for i in range(0, BLOCKLEN): + name = f"IN{i:02d}" + if name not in blocks: + blocks.append(name) + if 'OUT*' in blocks: + for i in range(0, BLOCKLEN): + name = f"OUT{i:02d}" + if name not in blocks: + blocks.append(name) + + elif sdv == "FLUX": + # expand some aliases + if 'INP*' in blocks or 'DOUBLE*' in blocks: + for i in range(0, 19): + name = f"DOUBLE{i:02d}" + if name not in blocks: + blocks.append(name) + if 'OUT*' in blocks or 'SINGLE*' in blocks: + for i in range(0, 38): + name = f"SINGLE{i:02d}" + if name not in blocks: + blocks.append(name) + + elif sdv == "v3": + # expand some aliases + if 'INP*' in blocks: + for i in range(0, 24): + name = f"IN{i:02d}" + if name not in blocks: + blocks.append(name) + + blocks = list(set(blocks)) + + # filter valid blocks + BLOCKIDS = all_blocks(sdv) + MAXLEN = len(BLOCKIDS) + selected = [False]*MAXLEN + + normalized = [] + for i, name in enumerate(BLOCKIDS): + if name in blocks: + selected[i] = True + normalized.append(name) + + return normalized, selected + + def module_name(path): p = Path(path) if "sd_modelmixer" in p.parts: From 52bc56c838eb4a39887728b3ecbc34bf48f3c1c9 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sun, 25 Aug 2024 12:19:41 +0900 Subject: [PATCH 03/10] fix elemental blocks --- scripts/model_mixer.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/scripts/model_mixer.py b/scripts/model_mixer.py index e02bb4d..7e79ae6 100644 --- a/scripts/model_mixer.py +++ b/scripts/model_mixer.py @@ -464,6 +464,18 @@ def all_blocks(sdversion): return blocks +def alias_blocks(sdversion): + aliases = [] + if sdversion == "XL": + aliases = ["CLIP_L", "CLIP_G"] + elif sdversion == "v3": + aliases = ["CLIP_L", "CLIP_G", "T5XXL"] + elif sdversion == "FLUX": + aliases = ["CLIP_L", "T5XXL"] + + return aliases + + def print_blocks(blocks): str = [] for i,x in enumerate(blocks): @@ -2798,7 +2810,9 @@ def finetune_reader(finetune): def config_sliders(sdver, num_models): BLOCKS = all_blocks(sdver) - ret = [gr.update(choices=BLOCKS)] + aliases = alias_blocks(sdver) + + ret = [gr.update(choices=BLOCKS+aliases)] is_selected_blocks = is_block_sliders(sdver) labs = [] @@ -2828,15 +2842,23 @@ def config_sliders(sdver, num_models): return ret - def select_block_elements(blocks): + def select_block_elements(blocks, model): # change choices for selected blocks elements = [] - if elemental_blocks is None or len(elemental_blocks) == 0: - return gr.update(choices=["time_embed", "time_embed.0", "time_embed.2", "out", "out.0", "out.2"]) + if elemental_blocks is None: + prepare_model(model) + + if len(blocks) == 0: + sdv = sdversion(model) + max_blocks = all_blocks(sdv) + allblocks = _all_blocks(sdv) for b in blocks: elements += elemental_blocks.get(b, []) + if len(blocks) == 0: + elements += elemental_blocks.get("", []) + elements = list(set(elements)) elements = sorted(elements) return gr.update(choices=elements) @@ -2891,7 +2913,7 @@ def read_elemental(elemental): return [gr.update(value=not_blks), gr.update(value=not_elem), gr.update(value=blks), gr.update(value=elem), gr.update(value=ratio)] - elemblks.change(fn=select_block_elements, inputs=[elemblks], outputs=[elements]) + elemblks.change(fn=select_block_elements, inputs=[elemblks, model_a], outputs=[elements]) elemental_reset.click(fn=lambda: [gr.update(value=False)]*2 + [gr.update(value=[])]*2+[gr.update(value=0.5)], inputs=[], outputs=[not_elemblks, not_elements, elemblks, elements, elemental_ratio]) elemental_write.click(fn=write_elemental, inputs=[not_elemblks, not_elements, elemblks, elements, elemental_ratio, mm_elemental_main], outputs=mm_elemental_main) elemental_read.click(fn=read_elemental, inputs=mm_elemental_main, outputs=[not_elemblks, not_elements, elemblks, elements, elemental_ratio]) From 0e48dda60fff3c22cb6c49d39af2b127d2a618d2 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sun, 25 Aug 2024 12:21:22 +0900 Subject: [PATCH 04/10] fix get_blocks_elements() parser to get all extra unet elements --- scripts/model_mixer.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/scripts/model_mixer.py b/scripts/model_mixer.py index 7e79ae6..639f0b3 100644 --- a/scripts/model_mixer.py +++ b/scripts/model_mixer.py @@ -6126,7 +6126,8 @@ def get_blocks_elements(res): blocks = list(blockmap.keys()) - key_re = re.compile(r"^(?:\d+\.)?(.*?)(?:\.\d+)?$") + key_re = re.compile(r"^(?:\d+\.)?(.*?)$") + #key_re = re.compile(r"^(?:\d+\.)?(.*?)(?:\.\d+)?$") key_split_re = re.compile(r"\.\d+\.") elements = {} @@ -6137,7 +6138,7 @@ def get_blocks_elements(res): if res.get(k, None) is not None: continue tmp = key.split(".") - if tmp[0] not in ["cond_stage_model", "conditioner"]: + if tmp[0] not in ["cond_stage_model", "conditioner", "text_encoders"]: if tmp[0] == "model" and tmp[1] == "diffusion_model": pass else: @@ -6145,7 +6146,12 @@ def get_blocks_elements(res): k = key.replace(".weight", "") # strip .weight k = k.replace("model.diffusion_model.", "") - k = k.replace("cond_stage_model.transformer.text_model.", "BASE.") + k = k.replace("cond_stage_model.transformer.", "BASE.") + k = k.replace("conditioner.embedders.0.transformer.", "CLIP_L.") # XL + k = k.replace("conditioner.embedders.1.model.", "CLIP_G.") # XL + k = k.replace("text_encoders.t5xxl.", "T5XXL.") + k = k.replace("text_encoders.clip_l.", "CLIP_L.") + k = k.replace("text_encoders.clip_g.", "CLIP_G.") name = None # only for block level keys @@ -6156,14 +6162,17 @@ def get_blocks_elements(res): if name in [ "M00", "M01", "M02" ]: name = "M00" # supermerger does not distinguish M01 and M02 last = tmp[2] - elif "BASE" in k: + elif any(prefix in k for prefix in ["BASE", "CLIP_L", "CLIP_G", "T5XXL"]): if "position_ids" in k: continue tmp = k.split(".",1) - name = "BASE" + name = tmp[0] #"BASE" last = tmp[1] - last = last.replace("encoder.layers.", "") + last = last.replace("text_model.", "").replace("transformer.", "") + else: + name = "" + last = k - if name and last != "": + if name is not None and last != "": m = key_re.match(last) # trim out some numbering: 0.foobar.1 => foobar if m: elem = elements.get(name, None) @@ -6176,14 +6185,16 @@ def get_blocks_elements(res): for e in tmp: if e == "0": # for IN00 case, only has 0.bias and 0.weight -> "0" remain continue - elem[e] = 1 + if e.rstrip("012345") != "": + elem[e] = 1 if e.find(".") != -1: # split attn1.to_q -> attn1, to_q tmp1 = e.split(".") if len(tmp1) > 0: for e1 in tmp1: - elem[e1] = 1 - e2 = e1.rstrip("12345") # attn1 -> attn - if e1 != e2: + if e1.rstrip("012345") != "": + elem[e1] = 1 + e2 = e1.rstrip("012345") # attn1 -> attn + if e2 != "" and e1 != e2: elem[e2] = 1 # sort elements From 85d85f7c27223c31b27029c028bb87a57a7b0924 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sun, 25 Aug 2024 18:48:56 +0900 Subject: [PATCH 05/10] fix remains * use `BLOCKIDS = all_blocks(sdver)` * check sd version of model to get correct block ids * remove duplicated function get_selected_blocks() --- scripts/model_mixer.py | 163 ++++++++++++++++------------------------- 1 file changed, 65 insertions(+), 98 deletions(-) diff --git a/scripts/model_mixer.py b/scripts/model_mixer.py index 639f0b3..7184236 100644 --- a/scripts/model_mixer.py +++ b/scripts/model_mixer.py @@ -218,56 +218,6 @@ def find_preset_by_name(preset, presets=None, reload=False): return None -def get_selected_blocks(mbw_blocks, isxl=False): - MAXLEN = 26 - (0 if not isxl else 6) - BLOCKLEN = 12 - (0 if not isxl else 3) - BLOCKOFFSET = 13 if not isxl else 10 - selected = [False]*MAXLEN - BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL - - # no mbws blocks selected or have 'ALL' alias - if 'ALL' in mbw_blocks: - # select all blocks - mbw_blocks += [ 'BASE', 'INP*', 'MID', 'OUT*' ] - - # fix alias - if 'MID' in mbw_blocks: - i = mbw_blocks.index('MID') - mbw_blocks[i] = 'M00' - - # expand some aliases - if 'INP*' in mbw_blocks: - for i in range(0, BLOCKLEN): - name = f"IN{i:02d}" - if name not in mbw_blocks: - mbw_blocks.append(name) - if 'OUT*' in mbw_blocks: - for i in range(0, BLOCKLEN): - name = f"OUT{i:02d}" - if name not in mbw_blocks: - mbw_blocks.append(name) - - for i, name in enumerate(BLOCKIDS): - if name in mbw_blocks: - if name[0:2] == 'IN': - num = int(name[2:]) - selected[num + 1] = True - elif name[0:3] == 'OUT': - num = int(name[3:]) - selected[num + BLOCKOFFSET + 1] = True - elif name == 'M00': - selected[BLOCKOFFSET] = True - elif name == 'BASE': - selected[0] = True - - all_blocks = _all_blocks(isxl) - selected_blocks = [] - for i, v in enumerate(selected): - if v: - selected_blocks.append(all_blocks[i]) - return selected_blocks - - def normalize_blocks(blocks, sdv): # no mbws blocks selected or have 'ALL' alias if len(blocks) == 0 or 'ALL' in blocks: @@ -513,23 +463,6 @@ def print_blocks(blocks): str.append(block) return ','.join(str) -def _selected_blocks_and_weights(mbw, isxl=False): - if type(mbw) is str: - weights = [t.strip() for t in mbw.split(",")] - else: - weights = mbw - # get all blocks - all_blocks = _all_blocks(isxl) - - sel_blocks = [] - sel_mbws = [] - for i, w in enumerate(weights): - v = float(w) - if v != 0.0: - sel_blocks.append(all_blocks[i]) - sel_mbws.append(v) - return sel_blocks, sel_mbws - def _weight_index(key, sdversion): if sdversion in ["v1", "v2", "XL"]: @@ -880,7 +813,7 @@ def mm_list_models(): permutation_spec = None -def get_rebasin_perms(mbws, isxl): +def get_rebasin_perms(mbws, sdver): """all blocks permutations of selected blocks""" global permutation_spec @@ -895,13 +828,13 @@ def get_rebasin_perms(mbws, isxl): if True in mbws or False in mbws: # already have selected _selected = mbws - all_blocks = _all_blocks(isxl) + all_blocks = _all_blocks(sdver) selected = [] for i, v in enumerate(_selected): if v: selected.append(all_blocks[i]) else: - selected = get_selected_blocks(mbws, isxl) + normalized, selected = normalize_blocks(mbws, sdv) if len(selected) > 0: axes = [] @@ -920,10 +853,10 @@ def get_rebasin_perms(mbws, isxl): return None -def get_rebasin_axes(mbws, isxl): +def get_rebasin_axes(mbws, sdver): """select all blocks correspond their permutation groups""" - perms = get_rebasin_perms(mbws, isxl) + perms = get_rebasin_perms(mbws, sdver) if perms is None: return None @@ -937,12 +870,14 @@ def get_rebasin_axes(mbws, isxl): return axes -def _get_rebasin_blocks(mbws, isxl): +def _get_rebasin_blocks(mbws, sdver): """select all blocks correspond their permutation groups""" - perms = get_rebasin_perms(mbws, isxl) + perms = get_rebasin_perms(mbws, sdver) if perms is None: return None + if sdver != "v1": # only v1 supported. FIXME + return None # get all axes and corresponde blocks blocks = [] @@ -952,13 +887,11 @@ def _get_rebasin_blocks(mbws, isxl): axes = list(set(axes)) # get all block representations to show gr.Dropdown - MAXLEN = 26 - (0 if not isxl else 6) - BLOCKLEN = 12 - (0 if not isxl else 3) - BLOCKOFFSET = 13 if not isxl else 10 + BLOCKIDS = all_blocks(sdver) + MAXLEN = len(BLOCKIDS) selected = [False]*MAXLEN - BLOCKIDS = BLOCKID if not isxl else BLOCKIDXL - all_blocks = _all_blocks(isxl) + all_blocks = _all_blocks(sdver) for j, block in enumerate(all_blocks[:MAXLEN]): if block not in ["cond_stage_model.", "conditioner."]: block = f"model.diffusion_model.{block}" @@ -2214,6 +2147,7 @@ def setup_download_ui(fileinfo): def load_mm_settings(text_or_image=None, reset=True): """load weight settings from text or image""" + sdver = None if text_or_image is None: # load from the current selected checkpoint current_model = shared.opts.data.get("sd_model_checkpoint", None) @@ -2221,6 +2155,7 @@ def load_mm_settings(text_or_image=None, reset=True): if checkpoint_info is None: raise gr.Error("Not a valid image or text") text_or_image = checkpoint_info.title + sdver = sdversion(checkpoint_info.model_name) if type(text_or_image) is str: geninfo = text_or_image.replace("\n", "").strip() @@ -2240,6 +2175,7 @@ def load_mm_settings(text_or_image=None, reset=True): checkpoint = sd_models.get_closet_checkpoint_match(geninfo) if checkpoint is not None: parsed = read_metadata_from_safetensors(checkpoint.filename) + sdver = sdversion(checkpoint.model_name) if parsed is not None: recipe = parsed.get("sd_merge_recipe", None) @@ -2261,7 +2197,31 @@ def load_mm_settings(text_or_image=None, reset=True): params["ModelMixer adjust"] = parsed.get("adjust", "") params["ModelMixer model a"] = parsed.get("model_a", "None") - BLOCKIDS = BLOCKID if len(weights[0]) > 20 else BLOCKIDXL + if sdver is None: + # no version info. check model + model = parsed.get(f"model_a", "None") + if model is not None and model != "None": + checkpointinfo = sd_models.get_closet_checkpoint_match(model) + if checkpointinfo is not None: + sdver = sdversion(checkpointinfo.model_name) + + if sdver is None: + print("Can't detect SD version. try to parse weight info...") + if len(weights[0]) == 20: # BASE + MID + 9 * 2 + sdver = "XL" + elif len(weights[0]) == 26: # BASE + MID + 12 * 2 + sdver = "v1" + elif len(weights[0]) == 25: # BASE + 24 + sdver = "v3" + elif len(weights[0]) == 58: # BASE + 19 + 38 = 58 + sdver = "FLUX" + else: + raise gr.Error(f"Can't detect SD version!") + print(f"Detected SD version is '{sdver}'") + else: + print(f"SD version is '{sdver}'") + + BLOCKIDS = all_blocks(sdver) if weights is not None: if type(weights) is list: for n, mbw in enumerate(weights): @@ -2863,13 +2823,14 @@ def select_block_elements(blocks, model): elements = sorted(elements) return gr.update(choices=elements) - def write_elemental(not_blocks, not_elements, blocks, elements, ratio, elemental): + def write_elemental(not_blocks, not_elements, blocks, elements, ratio, elemental, sdver): # update elemental information if len(blocks) == 0 and len(elements) == 0: return gr.update() # newly added - info = ("NOT " if not_blocks else "") + " ".join(zipblocks(blocks, BLOCKID)) + BLOCKIDS = all_blocks(sdver) + info = ("NOT " if not_blocks else "") + " ".join(zipblocks(blocks, BLOCKIDS)) info += ":" + ("NOT " if not_elements else "") + " ".join(elements) + ":" + str(ratio) # old @@ -2884,7 +2845,7 @@ def write_elemental(not_blocks, not_elements, blocks, elements, ratio, elemental info = "\n".join(newtmp) + "\n" return gr.update(value=info) - def read_elemental(elemental): + def read_elemental(elemental, sdver): tmp = elemental.strip() if len(tmp) == 0: return [gr.update()]*5 @@ -2902,7 +2863,8 @@ def read_elemental(elemental): not_blks = True blks = blks[1:] # expand any block ranges - blks = prepblocks(blks, BLOCKID) + BLOCKIDS = all_blocks(sdver) + blks = prepblocks(blks, BLOCKIDS) elem = tmp[1].strip().split(" ") elem = list(filter(None, elem)) @@ -2915,8 +2877,8 @@ def read_elemental(elemental): elemblks.change(fn=select_block_elements, inputs=[elemblks, model_a], outputs=[elements]) elemental_reset.click(fn=lambda: [gr.update(value=False)]*2 + [gr.update(value=[])]*2+[gr.update(value=0.5)], inputs=[], outputs=[not_elemblks, not_elements, elemblks, elements, elemental_ratio]) - elemental_write.click(fn=write_elemental, inputs=[not_elemblks, not_elements, elemblks, elements, elemental_ratio, mm_elemental_main], outputs=mm_elemental_main) - elemental_read.click(fn=read_elemental, inputs=mm_elemental_main, outputs=[not_elemblks, not_elements, elemblks, elements, elemental_ratio]) + elemental_write.click(fn=write_elemental, inputs=[not_elemblks, not_elements, elemblks, elements, elemental_ratio, mm_elemental_main, is_sdxl], outputs=mm_elemental_main) + elemental_read.click(fn=read_elemental, inputs=[mm_elemental_main,is_sdxl], outputs=[not_elemblks, not_elements, elemblks, elements, elemental_ratio]) is_sdxl.change(fn=config_sliders, inputs=[is_sdxl, mm_max_models], outputs=[elemblks, *members, *mm_usembws, *mm_weights]) @@ -3078,7 +3040,7 @@ def set_elemental(elemental, elemental_edit): preset_weight.change(fn=on_change_preset_weight, inputs=[preset_weight], outputs=members) preset_save.click( - fn=lambda isxl, *mem: [slider2text(isxl, *mem), gr.update(visible=True)], + fn=lambda sdver, *mem: [slider2text(sdver, *mem), gr.update(visible=True)], inputs=[is_sdxl, *members], outputs=[preset_edit_weight, preset_edit_dialog], show_progress=False, @@ -3482,7 +3444,12 @@ def before_process(self, p, enabled, model_a, base_model, mm_max_models, mm_fine j = k.rfind(" ") name = k[j+1:] # model name: model b -> get "b" idx = ord(name) - 98 # model index: model b -> get 0 - if pinpoint in BLOCKID: + + # check sdversion and get blockids + sdv = sdversion(model_a) + blockids = all_blocks(sdv) + + if pinpoint in blockids: if f"pinpoint alpha {name}" in p.modelmixer_xyz: alpha = p.modelmixer_xyz[f"pinpoint alpha {name}"] else: @@ -3653,7 +3620,7 @@ def before_process(self, p, enabled, model_a, base_model, mm_max_models, mm_fine for j in range(len(mm_models)): elemental_ws = None if mm_use_elemental[j]: - elemental_ws = parse_elemental(mm_elementals[j]) + elemental_ws = parse_elemental(mm_elementals[j], sdv) if "elemental merge" in debugs: print(" Elemental merge wegiths = ", elemental_ws) if elemental_ws is not None: mm_elementals[j] = elemental_ws @@ -3747,6 +3714,7 @@ def load_state_dict(checkpoint_info): elemental_selected = selected_elemental_blocks(all_elemental_blocks, sdv) # prepare for merges + BLOCKIDS = all_blocks(sdv) compact_mode = None mm_selected = [[]] * num_models for j, model in enumerate(mm_models): @@ -3764,17 +3732,16 @@ def load_state_dict(checkpoint_info): for j in range(len(mm_models)): # get original model index n = modelindex[j] - BLOCKS = BLOCKID if not isxl else BLOCKIDXL + if len(xyz_pinpoint_blocks[n]) > 0: for pin in xyz_pinpoint_blocks[n].keys(): - if pin in BLOCKS: - mm_weights[j][BLOCKS.index(pin)] = xyz_pinpoint_blocks[n][pin] + if pin in BLOCKIDS: + mm_weights[j][BLOCKIDS.index(pin)] = xyz_pinpoint_blocks[n][pin] else: print("WARN: No pinpoint block found. ignore...") # get overall selected blocks if compact_mode: - BLOCKIDS = all_blocks(sdv) max_blocks = len(BLOCKIDS) selected_blocks = [] @@ -3795,7 +3762,7 @@ def load_state_dict(checkpoint_info): print("check affected permutation blocks by rebasin merge...") jj = 0 while True: - xx_selected_all = _get_rebasin_blocks(mm_selected_all, isxl) + xx_selected_all = _get_rebasin_blocks(mm_selected_all, sdv) changed = [BLOCKIDS[i] for i, v in enumerate(mm_selected_all) if v != xx_selected_all[i]] if len(changed) > 0: print(f" - [{jj+1}] {changed} block{'s' if len(changed) > 1 else ''} added") @@ -3890,7 +3857,6 @@ def load_state_dict(checkpoint_info): keyremains.append(k) # add some missing extra_elements - BLOCKIDS = all_blocks(sdv) allblocks = _all_blocks(sdv) last_block = allblocks[len(BLOCKIDS)-1] @@ -4277,7 +4243,6 @@ def dare_merge(theta0, theta1, alpha, density, rescale=True, mode='random'): item = theta_0.pop(k) keyremains.append(k) - BLOCKIDS = all_blocks(sdv) timer.record("prepare") stage = 1 @@ -6279,12 +6244,13 @@ def zipblocks(blocks, blockids): i = start + 1 return out -def parse_elemental(elemental): +def parse_elemental(elemental, sdver): if len(elemental) > 0: elemental = elemental.replace(",","\n").strip().split("\n") elemental = [f.strip() for f in elemental] elemental_weights = {} + BLOCKIDS = all_blocks(sdver) if len(elemental) > 0: for d in elemental: if d.count(":") != 2: @@ -6301,7 +6267,7 @@ def parse_elemental(elemental): dbs = list(filter(None, dbs)) if len(dbs) > 0: dbn, dbs = (False, dbs[1:]) if dbs[0].upper() == "NOT" else (True, dbs) - dbs = prepblocks(dbs, BLOCKID, select=dbn) + dbs = prepblocks(dbs, BLOCKIDS, select=dbn) else: dbn = True @@ -6857,6 +6823,7 @@ def format_elemental_add_label(p, opt, x): f"[Model Mixer] Pinpoint block {Name}", str, partial(set_value, field=f"pinpoint block {name}"), + # XXX FIXME for FLUX choices=lambda: BLOCKID, ), xyz_grid.AxisOption( From 1977498300b5a43425062ba1a6f1bdebec0dfc20 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Sun, 25 Aug 2024 21:22:26 +0900 Subject: [PATCH 06/10] dynamic change xyz_grid options --- scripts/model_mixer.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/scripts/model_mixer.py b/scripts/model_mixer.py index 7184236..13a53c4 100644 --- a/scripts/model_mixer.py +++ b/scripts/model_mixer.py @@ -3301,8 +3301,25 @@ def _update_model_list(max_models): self.init_on_app_started = True + # dynamic change xyz-grid + def fix_xyz_grid_options(demo, app): + xyz_grid = [x for x in scripts.scripts_data if x.script_class.__module__ == "xyz_grid.py"][0].module + + def update_blockids(sdv): + axis_options = [x for x in xyz_grid.axis_options + if type(x) == xyz_grid.AxisOption or x.is_img2img == is_img2img and x.label.startswith("[Model Mixer] Pinpoint block ")] + + blockids = all_blocks(sdv) + for axis_option in axis_options: + axis_option.choices = lambda: blockids + + with demo: + is_sdxl.change(fn=update_blockids, inputs=[is_sdxl], outputs=[], queue=False) + + if self.init_on_app_started is False: script_callbacks.on_app_started(on_app_started) + script_callbacks.on_app_started(fix_xyz_grid_options) generate_button = MM.components["img2img_generate" if is_img2img else "txt2img_generate"] @@ -6823,7 +6840,7 @@ def format_elemental_add_label(p, opt, x): f"[Model Mixer] Pinpoint block {Name}", str, partial(set_value, field=f"pinpoint block {name}"), - # XXX FIXME for FLUX + # fix blockids in the fix_xyz_grid_options(() choices=lambda: BLOCKID, ), xyz_grid.AxisOption( From 4acfdf80978fed5daf7f106bc7e6460565c6ffff Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Mon, 26 Aug 2024 00:01:47 +0900 Subject: [PATCH 07/10] check model_a sd version first. cleanup --- scripts/model_mixer.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/scripts/model_mixer.py b/scripts/model_mixer.py index 13a53c4..1b88c84 100644 --- a/scripts/model_mixer.py +++ b/scripts/model_mixer.py @@ -3631,6 +3631,22 @@ def before_process(self, p, enabled, model_a, base_model, mm_max_models, mm_fine # save original mm_elementals orig_elementals = mm_elementals.copy() + mm_weights_orig = mm_weights + + # check model_a + checkpoint_info = sd_models.get_closet_checkpoint_match(model_a) + if checkpoint_info is None: + print(f"ERROR: Fail to get {model_a}") + return + model_a = checkpoint_info.model_name + print(f"model_a = {model_a}") + + # check SDXL, FLUX etc. + sdv = sdversion(model_a) + isxl = sdv == 'XL' + + print("sdversion =", sdv) + # parse elemental weights if "elemental merge" in debugs: print(" - Parse elemental merge...") all_elemental_blocks = [] @@ -3659,16 +3675,6 @@ def selected_elemental_blocks(blocks, sdver): elemental_selected[j] = True return elemental_selected - mm_weights_orig = mm_weights - - # load model_a - # check model_a - checkpoint_info = sd_models.get_closet_checkpoint_match(model_a) - if checkpoint_info is None: - print(f"ERROR: Fail to get {model_a}") - return - model_a = checkpoint_info.model_name - print(f"model_a = {model_a}") # load models models = {} @@ -3710,12 +3716,6 @@ def load_state_dict(checkpoint_info): return sd_models.read_state_dict(checkpoint_info.filename, map_location = "cpu").copy() - # check SDXL, FLUX etc. - sdv = sdversion(model_a) - isxl = sdv == 'XL' - - print("sdversion =", sdv) - # check base_model use_safe_open = shared.opts.data.get("mm_use_safe_open", False) From 8bcc6191f2940ed34013d295545617dd421a5664 Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Mon, 26 Aug 2024 00:05:17 +0900 Subject: [PATCH 08/10] copy() original weights --- scripts/model_mixer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/model_mixer.py b/scripts/model_mixer.py index 1b88c84..249618f 100644 --- a/scripts/model_mixer.py +++ b/scripts/model_mixer.py @@ -3631,7 +3631,7 @@ def before_process(self, p, enabled, model_a, base_model, mm_max_models, mm_fine # save original mm_elementals orig_elementals = mm_elementals.copy() - mm_weights_orig = mm_weights + mm_weights_orig = mm_weights.copy() # check model_a checkpoint_info = sd_models.get_closet_checkpoint_match(model_a) From 6e1c335813f498810f8dec9082f8bbe19ec4ccaf Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Mon, 26 Aug 2024 23:00:52 +0900 Subject: [PATCH 09/10] fix for text_encoders. --- scripts/model_mixer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/model_mixer.py b/scripts/model_mixer.py index 249618f..13da4ea 100644 --- a/scripts/model_mixer.py +++ b/scripts/model_mixer.py @@ -452,7 +452,7 @@ def print_blocks(blocks): n = int(x[13:len(x)-1]) block = f"IN{n:02d}" str.append(block) - elif "cond_stage_model" in x or "conditioner." in x: + elif "cond_stage_model" in x or "conditioner." in x or "text_encoders." in x: block = f"BASE" str.append(block) elif "time_embed." in x: @@ -4480,7 +4480,7 @@ def cosim(theta0, theta1, calcmode): for key in (tqdm(sel_keys, desc=f"Check uninitialized #{n+2-weight_start}/{stages}")): if "model" in key: for s in selected_blocks: - if s not in ["cond_stage_model.", "conditioner."]: + if s not in ["cond_stage_model.", "conditioner.", "text_encoders."]: s = f"model.diffusion_model.{s}" if s in key and key not in theta_0 and key not in checkpoint_dict_skip_on_merge: print(f" +{k}") @@ -4730,7 +4730,7 @@ def fake_checkpoint(checkpoint_info, metadata, model_name, sha256, fake=True): unet_updated = 0 for s in weight_changed_blocks: shared.state.textinfo = "Update UNet Blocks..." - if s in ["cond_stage_model.", "conditioner."]: + if s in ["cond_stage_model.", "conditioner.", "text_encoders."]: # Textencoder(BASE) continue print(" - update UNet block", s) @@ -4746,7 +4746,7 @@ def fake_checkpoint(checkpoint_info, metadata, model_name, sha256, fake=True): print(" - \033[92mUNet partial blocks have been successfully updated\033[0m") # textencoder partial update does not work as expected. read state_dict() and set state_dict. - if "cond_stage_model." in weight_changed_blocks or "conditioner." in weight_changed_blocks: + if any(prefix in weight_changed_blocks for prefix in ["cond_stage_model.", "conditioner.", "text_encoders."]): print(" - \033[93mReload full state_dict...\033[0m") shared.state.textinfo = "Reload full state_dict..." state_dict = shared.sd_model.state_dict().copy() From ffa38d088bc6db34540e842905d4700949d2bf6f Mon Sep 17 00:00:00 2001 From: Won-Kyu Park Date: Fri, 6 Sep 2024 18:36:15 +0900 Subject: [PATCH 10/10] fix sdversion() --- scripts/model_mixer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/model_mixer.py b/scripts/model_mixer.py index 13da4ea..ff19f2a 100644 --- a/scripts/model_mixer.py +++ b/scripts/model_mixer.py @@ -743,6 +743,8 @@ def sdversion(modelname_or_header): header = get_ckpt_header(checkpointinfo.filename) else: return None + else: + header = modelname_or_header if header is not None: if "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight" in header: @@ -756,7 +758,11 @@ def sdversion(modelname_or_header): v2 = False if 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight' in header: - v2 = header['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight']["shape"][1] == 1024 + w = header['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'] + if type(w) == dict: + v2 = w["shape"][1] == 1024 + else: + v2 = w.shape[1] == 1024 return 'v1' if not v2 else 'v2' return None