Skip to content

Commit

Permalink
Refactor: fix controlnet and tae
Browse files Browse the repository at this point in the history
  • Loading branch information
stduhpf committed Nov 26, 2024
1 parent 371d81f commit 04ca926
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
8 changes: 3 additions & 5 deletions control.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,10 @@ struct ControlNet : public GGMLRunner {
bool guided_hint_cached = false;

ControlNet(ggml_backend_t backend,
SDVersion version = VERSION_SD1)
std::map<std::string, enum ggml_type>& tensor_types,
SDVersion version = VERSION_SD1)
: GGMLRunner(backend), control_net(version) {
}

void init_params(std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix) {
control_net.init(params_ctx, tensor_types, prefix);
control_net.init(params_ctx, tensor_types, "");
}

~ControlNet() {
Expand Down
4 changes: 2 additions & 2 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ class StableDiffusionGGML {
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else {
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, vae_decode_only);
tae_first_stage = std::make_shared<TinyAutoEncoder>(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only);
}
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");

Expand All @@ -370,7 +370,7 @@ class StableDiffusionGGML {
} else {
controlnet_backend = backend;
}
control_net = std::make_shared<ControlNet>(controlnet_backend, version);
control_net = std::make_shared<ControlNet>(controlnet_backend, model_loader.tensor_storages_types, version);
}

if (id_embeddings_path.find("v2") != std::string::npos) {
Expand Down
5 changes: 3 additions & 2 deletions tae.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,15 @@ struct TinyAutoEncoder : public GGMLRunner {
bool decode_only = false;

TinyAutoEncoder(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
const std::string prefix,
bool decoder_only = true)
: decode_only(decoder_only),
taesd(decode_only),
GGMLRunner(backend) {
}
void init_params(std::map<std::string, enum ggml_type>& tensor_types, const std::string prefix) {
taesd.init(params_ctx, tensor_types, prefix);
}

std::string get_desc() {
return "taesd";
}
Expand Down

0 comments on commit 04ca926

Please sign in to comment.