fix: correct the handling of weight loading
leejet committed Aug 30, 2023
1 parent 1b5a868 commit c542a77
Showing 2 changed files with 23 additions and 9 deletions.
26 changes: 20 additions & 6 deletions models/convert.py
@@ -31,7 +31,7 @@
 
 QK4_0 = 32
 def quantize_q4_0(x):
-    assert x.shape[-1] % QK4_0 == 0
+    assert x.shape[-1] % QK4_0 == 0 and x.shape[-1] > QK4_0
     x = x.reshape(-1, QK4_0)
     max = np.take_along_axis(x, np.argmax(np.abs(x), axis=-1)[:, np.newaxis], axis=-1)
     d = max / -8
@@ -44,7 +44,7 @@ def quantize_q4_0(x):
 
 QK4_1 = 32
 def quantize_q4_1(x):
-    assert x.shape[-1] % QK4_1 == 0
+    assert x.shape[-1] % QK4_1 == 0 and x.shape[-1] > QK4_1
     x = x.reshape(-1, QK4_1)
     min = np.min(x, axis=-1, keepdims=True)
     max = np.max(x, axis=-1, keepdims=True)
@@ -59,7 +59,7 @@ def quantize_q4_1(x):
 
 QK5_0 = 32
 def quantize_q5_0(x):
-    assert x.shape[1] % QK5_0 == 0
+    assert x.shape[-1] % QK5_0 == 0 and x.shape[-1] > QK5_0
     x = x.reshape(-1, QK5_0)
     max = np.take_along_axis(x, np.argmax(np.abs(x), axis=-1)[:, np.newaxis], axis=-1)
     d = max / -16
@@ -76,7 +76,7 @@ def quantize_q5_0(x):
 
 QK5_1 = 32
 def quantize_q5_1(x):
-    assert x.shape[-1] % QK5_1 == 0
+    assert x.shape[-1] % QK5_1 == 0 and x.shape[-1] > QK5_1
     x = x.reshape(-1, QK5_1)
     min = np.min(x, axis=-1, keepdims=True)
     max = np.max(x, axis=-1, keepdims=True)
@@ -95,7 +95,7 @@ def quantize_q5_1(x):
 
 QK8_0 = 32
 def quantize_q8_0(x):
-    assert x.shape[-1] % QK8_0 == 0
+    assert x.shape[-1] % QK8_0 == 0 and x.shape[-1] > QK8_0
     x = x.reshape(-1, QK8_0)
     amax = np.max(np.abs(x), axis=-1, keepdims=True)
     d = amax / ((1 << 7) - 1)
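
All five quantizers get the same tightened precondition: the last dimension must be a multiple of the 32-element block size and strictly larger than a single block. A minimal sketch of what the new check accepts and rejects (the helper name can_quantize is hypothetical):

import numpy as np

QK4_0 = 32

def can_quantize(x):
    # Mirrors the new assert: a multiple of the block size, and more
    # than one block, in the last dimension.
    return x.shape[-1] % QK4_0 == 0 and x.shape[-1] > QK4_0

print(can_quantize(np.zeros((320, 1280))))  # True:  1280 = 40 blocks
print(can_quantize(np.zeros((32,))))        # False: exactly one block
print(can_quantize(np.zeros((320, 100))))   # False: 100 % 32 != 0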
@@ -156,7 +156,10 @@ def get_alpha_comprod(linear_start=0.00085, linear_end=0.0120, timesteps=1000):
     "posterior_mean_coef2",
     "cond_stage_model.transformer.text_model.embeddings.position_ids",
     "model_ema.decay",
-    "model_ema.num_updates"
+    "model_ema.num_updates",
+    "control_model",
+    "lora_te_text_model",
+    "embedding_manager"
 ]
 
 def convert(model_path, out_type = None, out_file=None):
@@ -182,6 +185,10 @@ def convert(model_path, out_type = None, out_file=None):
             out_type = "f32"
         elif weight.dtype == np.float16:
             out_type = "f16"
+        elif weight.dtype == np.float64:
+            out_type = "f32"
+        else:
+            raise Exception("unsupported weight type %s" % weight.dtype)
     if out_file == None:
         out_file = os.path.splitext(os.path.basename(model_path))[0] + f"-ggml-model-{out_type}.bin"
         out_file = os.path.join(os.getcwd(), out_file)
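
The dtype dispatch now downcasts f64 checkpoints to f32 and fails loudly on anything else instead of silently picking a default. A standalone sketch of the same mapping (infer_out_type is a hypothetical name):

import numpy as np

def infer_out_type(dtype):
    # Same mapping as the diff above: f64 is stored as f32,
    # unknown dtypes raise instead of being mis-written.
    if dtype == np.float32:
        return "f32"
    elif dtype == np.float16:
        return "f16"
    elif dtype == np.float64:
        return "f32"
    raise Exception("unsupported weight type %s" % dtype)

print(infer_out_type(np.dtype(np.float64)))  # f32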
@@ -207,6 +214,13 @@ def convert(model_path, out_type = None, out_file=None):
     for name in state_dict.keys():
         if not isinstance(state_dict[name], torch.Tensor):
             continue
+        skip = False
+        for unused_tensor in unused_tensors:
+            if name.startswith(unused_tensor):
+                skip = True
+                break
+        if skip:
+            continue
         if name in unused_tensors:
             continue
         data = state_dict[name].numpy()
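The new skip loop matches by prefix rather than by exact name, so every tensor inside an unused subtree is filtered out, not just the exact strings in the list. Roughly, given the unused_tensors entries above:

unused_tensors = [
    "model_ema.decay",
    "model_ema.num_updates",
    "control_model",
    "lora_te_text_model",
    "embedding_manager",
]

def is_unused(name):
    # Prefix match, as in the new loop; the old exact membership
    # test would miss names like "control_model.<...>.weight".
    return any(name.startswith(prefix) for prefix in unused_tensors)

print(is_unused("model_ema.decay"))                        # True (exact)
print(is_unused("control_model.input_blocks.0.0.weight"))  # True (prefix)
print(is_unused("model.diffusion_model.out.2.weight"))     # False
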
6 changes: 3 additions & 3 deletions stable-diffusion.cpp
@@ -2864,6 +2864,8 @@ class StableDiffusionGGML {
                 nelements *= ne[i];
             }
 
+            const size_t num_bytes = nelements / ggml_blck_size(ggml_type(ttype)) * ggml_type_size(ggml_type(ttype));
+
             std::string name(length, 0);
             file.read(&name[0], length);
 
@@ -2891,7 +2893,7 @@ class StableDiffusionGGML {
                     return false;
                 }
             }
-            file.ignore(nelements * ggml_type_size((ggml_type)ttype));
+            file.ignore(num_bytes);
             continue;
         }
 

@@ -2919,8 +2921,6 @@ class StableDiffusionGGML {
                 return false;
             }
 
-            const size_t num_bytes = nelements / ggml_blck_size(ggml_type(ttype)) * ggml_type_size(ggml_type(ttype));
-
             file.read(reinterpret_cast<char*>(tensor->data), num_bytes);
 
             total_size += ggml_nbytes(tensor);
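
On the C++ side, num_bytes is hoisted above the skip path so that file.ignore() advances by the tensor's real on-disk size. For block-quantized types, ggml_type_size() is the size of one block, not one element, so the old nelements * ggml_type_size(...) over-skips. A rough Python illustration, assuming Q4_0's layout in ggml of this era (32 elements per block, 18 bytes per block: a 2-byte scale plus 16 nibble bytes):

GGML_BLCK_SIZE_Q4_0 = 32   # elements per block
GGML_TYPE_SIZE_Q4_0 = 18   # bytes per block

nelements = 4096

# Old skip: treats the per-block size as a per-element size.
wrong_bytes = nelements * GGML_TYPE_SIZE_Q4_0                        # 73728

# Fixed: bytes = nelements / block_size * type_size.
num_bytes = nelements // GGML_BLCK_SIZE_Q4_0 * GGML_TYPE_SIZE_Q4_0   # 2304

print(wrong_bytes, num_bytes)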
