diff --git a/control.hpp b/control.hpp index 3b2c1e1b..0cf081ce 100644 --- a/control.hpp +++ b/control.hpp @@ -317,12 +317,10 @@ struct ControlNet : public GGMLRunner { bool guided_hint_cached = false; ControlNet(ggml_backend_t backend, - SDVersion version = VERSION_SD1) + std::map& tensor_types, + SDVersion version = VERSION_SD1) : GGMLRunner(backend), control_net(version) { - } - - void init_params(std::map& tensor_types, const std::string prefix) { - control_net.init(params_ctx, tensor_types, prefix); + control_net.init(params_ctx, tensor_types, ""); } ~ControlNet() { diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 49d60810..a12cfed6 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -358,7 +358,7 @@ class StableDiffusionGGML { first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } else { - tae_first_stage = std::make_shared(backend, vae_decode_only); + tae_first_stage = std::make_shared(backend, model_loader.tensor_storages_types, "decoder.layers", vae_decode_only); } // first_stage_model->get_param_tensors(tensors, "first_stage_model."); @@ -370,7 +370,7 @@ class StableDiffusionGGML { } else { controlnet_backend = backend; } - control_net = std::make_shared(controlnet_backend, version); + control_net = std::make_shared(controlnet_backend, model_loader.tensor_storages_types, version); } if (id_embeddings_path.find("v2") != std::string::npos) { diff --git a/tae.hpp b/tae.hpp index b9cc3312..ac061115 100644 --- a/tae.hpp +++ b/tae.hpp @@ -188,14 +188,15 @@ struct TinyAutoEncoder : public GGMLRunner { bool decode_only = false; TinyAutoEncoder(ggml_backend_t backend, + std::map& tensor_types, + const std::string prefix, bool decoder_only = true) : decode_only(decoder_only), taesd(decode_only), GGMLRunner(backend) { - } - void init_params(std::map& tensor_types, const std::string prefix) { taesd.init(params_ctx, tensor_types, prefix); } + std::string get_desc() { return "taesd"; }