Fix erroneous "unused lora tensors"
stduhpf committed Dec 11, 2024
1 parent d22f183 commit f17c9f5
Showing 1 changed file with 32 additions and 41 deletions.
lora.hpp: 73 changes (32 additions & 41 deletions)
@@ -302,31 +302,22 @@ struct LoraModel : public GGMLRunner {
lora_down = ggml_cont(compute_ctx, lora_down_concat);
lora_up = ggml_cont(compute_ctx, lora_up_concat);

lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
// lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
// lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";

lora_tensors[lora_down_name] = lora_down;
lora_tensors[lora_up_name] = lora_up;
// lora_tensors[lora_down_name] = lora_down;
// lora_tensors[lora_up_name] = lora_up;

// Would be nice to be able to clean up lora_tensors, but it breaks because this is called twice :/
// lora_tensors.erase(split_q_u_name);
// lora_tensors.erase(split_k_u_name);
// lora_tensors.erase(split_v_u_name);
// lora_tensors.erase(split_m_u_name);

// lora_tensors.erase(split_q_d_name);
// lora_tensors.erase(split_k_d_name);
// lora_tensors.erase(split_v_d_name);
// lora_tensors.erase(split_m_d_name);

} else {
// lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
// lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
// if (lora_tensors.find(lora_up_name) != lora_tensors.end()) {
// // print_ggml_tensor(lora_tensors[lora_down_name], true); // [3072, R, 1, 1]
// // print_ggml_tensor(lora_tensors[lora_up_name], true); // [R, 21504, 1, 1]
// // print_ggml_tensor(it.second, true); // [3072, 21504, 1, 1]
// }
applied_lora_tensors.insert(split_q_u_name);
applied_lora_tensors.insert(split_k_u_name);
applied_lora_tensors.insert(split_v_u_name);
applied_lora_tensors.insert(split_m_u_name);

applied_lora_tensors.insert(split_q_d_name);
applied_lora_tensors.insert(split_k_d_name);
applied_lora_tensors.insert(split_v_d_name);
applied_lora_tensors.insert(split_m_d_name);
}
} else if (linear2 != std::string::npos) {
linear2--;
@@ -341,8 +332,8 @@ struct LoraModel : public GGMLRunner {
lora_down = lora_tensors[lora_down_name];
}

applied_lora_tensors.insert(lora_up_name);
applied_lora_tensors.insert(lora_down_name);
applied_lora_tensors.insert(lora_up_name);
}
} else if (modulation != std::string::npos) {
modulation--;
@@ -357,8 +348,8 @@ struct LoraModel : public GGMLRunner {
lora_down = lora_tensors[lora_down_name];
}

applied_lora_tensors.insert(lora_up_name);
applied_lora_tensors.insert(lora_down_name);
applied_lora_tensors.insert(lora_up_name);
}
}
// Double blocks
@@ -446,20 +437,20 @@ struct LoraModel : public GGMLRunner {
lora_down = ggml_cont(compute_ctx, lora_down_concat);
lora_up = ggml_cont(compute_ctx, lora_up_concat);

lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";
// lora_down_name = lora_pre[type] + key + lora_downs[type] + ".weight";
// lora_up_name = lora_pre[type] + key + lora_ups[type] + ".weight";

lora_tensors[lora_down_name] = lora_down;
lora_tensors[lora_up_name] = lora_up;
// lora_tensors[lora_down_name] = lora_down;
// lora_tensors[lora_up_name] = lora_up;

// Would be nice to be able to clean up lora_tensors, but it breaks because this is called twice :/
// lora_tensors.erase(split_q_u_name);
// lora_tensors.erase(split_k_u_name);
// lora_tensors.erase(split_v_u_name);
applied_lora_tensors.insert(split_q_u_name);
applied_lora_tensors.insert(split_k_u_name);
applied_lora_tensors.insert(split_v_u_name);

// lora_tensors.erase(split_q_d_name);
// lora_tensors.erase(split_k_d_name);
// lora_tensors.erase(split_v_d_name);
applied_lora_tensors.insert(split_q_d_name);
applied_lora_tensors.insert(split_k_d_name);
applied_lora_tensors.insert(split_v_d_name);
}
} else if (txt_attn_proj != std::string::npos || img_attn_proj != std::string::npos) {
size_t match = txt_attn_proj;
@@ -481,8 +472,8 @@ struct LoraModel : public GGMLRunner {
lora_down = lora_tensors[lora_down_name];
}

applied_lora_tensors.insert(lora_up_name);
applied_lora_tensors.insert(lora_down_name);
applied_lora_tensors.insert(lora_up_name);
}
} else if (txt_mlp_0 != std::string::npos || txt_mlp_2 != std::string::npos || img_mlp_0 != std::string::npos || img_mlp_2 != std::string::npos) {
bool has_two = txt_mlp_2 != std::string::npos || img_mlp_2 != std::string::npos;
@@ -514,8 +505,8 @@ struct LoraModel : public GGMLRunner {
lora_down = lora_tensors[lora_down_name];
}

applied_lora_tensors.insert(lora_up_name);
applied_lora_tensors.insert(lora_down_name);
applied_lora_tensors.insert(lora_up_name);
}
} else if (txt_mod_lin != std::string::npos || img_mod_lin != std::string::npos) {
size_t match = txt_mod_lin;
@@ -537,8 +528,8 @@ struct LoraModel : public GGMLRunner {
lora_down = lora_tensors[lora_down_name];
}

applied_lora_tensors.insert(lora_up_name);
applied_lora_tensors.insert(lora_down_name);
applied_lora_tensors.insert(lora_up_name);
}
}
}
@@ -564,11 +555,11 @@ struct LoraModel : public GGMLRunner {
if (lora_tensors.find(lora_down_name) != lora_tensors.end()) {
lora_down = lora_tensors[lora_down_name];
}
applied_lora_tensors.insert(lora_up_name);
applied_lora_tensors.insert(lora_down_name);
applied_lora_tensors.insert(alpha_name);
applied_lora_tensors.insert(scale_name);
}
applied_lora_tensors.insert(lora_up_name);
applied_lora_tensors.insert(lora_down_name);
applied_lora_tensors.insert(alpha_name);
applied_lora_tensors.insert(scale_name);

if (lora_up == NULL || lora_down == NULL) {
continue;
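Read together, the hunks comment out the re-registration of the concatenated q/k/v (and mlp) tensors under a synthetic lora_tensors key and instead insert the original split tensor names into applied_lora_tensors; the later hunks appear to move the existing applied_lora_tensors.insert calls so that they run unconditionally. This fits a scheme where the runner warns about every lora_tensors entry whose key never shows up in applied_lora_tensors, which would be the source of the erroneous message. A minimal sketch of that check, assuming a name-keyed tensor map and a set of applied names (hypothetical illustration, not code from lora.hpp):

// Hypothetical sketch of the bookkeeping this commit adjusts: tensor names
// loaded into lora_tensors but never recorded in applied_lora_tensors get
// reported as unused. Names and container types are assumptions made for
// illustration, not taken from lora.hpp.
#include <cstdio>
#include <map>
#include <set>
#include <string>

struct ggml_tensor;  // opaque here; only the keys matter for this check

static void warn_unused_lora_tensors(const std::map<std::string, ggml_tensor*>& lora_tensors,
                                      const std::set<std::string>& applied_lora_tensors) {
    for (const auto& kv : lora_tensors) {
        if (applied_lora_tensors.count(kv.first) == 0) {
            // Split q/k/v (and mlp) tensors that were concatenated but never
            // marked as applied would land here, producing the spurious warning.
            printf("WARN: unused lora tensor '%s'\n", kv.first.c_str());
        }
    }
}

Under that assumption, marking the split_*_u/_d names as applied at the point where they are concatenated silences the warning without erasing entries from lora_tensors, which the removed comment notes breaks because that code path runs twice.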
