Skip to content

Commit

Permalink
clip_g support for SD3
Browse files Browse the repository at this point in the history
  • Loading branch information
stduhpf committed Oct 23, 2024
1 parent 14206fd commit f74fdab
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 12 deletions.
10 changes: 10 additions & 0 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ struct SDParams {
SDMode mode = TXT2IMG;

std::string model_path;
std::string clip_g_path;
std::string clip_l_path;
std::string t5xxl_path;
std::string diffusion_model_path;
Expand Down Expand Up @@ -127,6 +128,7 @@ void print_params(SDParams params) {
printf(" mode: %s\n", modes_str[params.mode]);
printf(" model_path: %s\n", params.model_path.c_str());
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
printf(" clip_l_path: %s\n", params.clip_l_path.c_str());
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
Expand Down Expand Up @@ -175,6 +177,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
printf(" -m, --model [MODEL] path to full model\n");
printf(" --diffusion-model path to the standalone diffusion model\n");
printf(" --clip_g path to the clip-g text encoder\n");
printf(" --clip_l path to the clip-l text encoder\n");
printf(" --t5xxl path to the the t5xxl text encoder.\n");
printf(" --vae [VAE] path to vae\n");
Expand Down Expand Up @@ -256,6 +259,12 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.model_path = argv[i];
} else if (arg == "--clip_g") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.clip_g_path = argv[i];
} else if (arg == "--clip_l") {
if (++i >= argc) {
invalid_arg = true;
Expand Down Expand Up @@ -764,6 +773,7 @@ int main(int argc, const char* argv[]) {
}

sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(),
params.clip_g_path.c_str(),
params.clip_l_path.c_str(),
params.t5xxl_path.c_str(),
params.diffusion_model_path.c_str(),
Expand Down
47 changes: 35 additions & 12 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class StableDiffusionGGML {
}

bool load_from_file(const std::string& model_path,
const std::string& clip_g_path,
const std::string& clip_l_path,
const std::string& t5xxl_path,
const std::string& diffusion_model_path,
Expand Down Expand Up @@ -167,7 +168,7 @@ class StableDiffusionGGML {
for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
backend = ggml_backend_vk_init(device);
}
if(!backend) {
if (!backend) {
LOG_WARN("Failed to initialize Vulkan backend");
}
#endif
Expand All @@ -181,7 +182,7 @@ class StableDiffusionGGML {
backend = ggml_backend_cpu_init();
}
#ifdef SD_USE_FLASH_ATTENTION
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined (SD_USE_SYCL) || defined(SD_USE_VULKAN)
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN)
LOG_WARN("Flash Attention not supported with GPU Backend");
#else
LOG_INFO("Flash Attention enabled");
Expand All @@ -198,24 +199,44 @@ class StableDiffusionGGML {
}
}

if (diffusion_model_path.size() > 0) {
LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str());
if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) {
LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str());
}
}
version = model_loader.get_sd_version();

if (clip_g_path.size() > 0) {
LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str());
std::string prefix = "text_encoders.clip_g.";
if (version == VERSION_SD3_2B ) {
prefix = "text_encoders.clip_g.transformer.";
}
if (!model_loader.init_from_file(clip_g_path, prefix)) {
LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str());
}
}

if (clip_l_path.size() > 0) {
LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str());
if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.")) {
std::string prefix = "text_encoders.clip_l.";
if (version == VERSION_SD3_2B ) {
prefix = "text_encoders.clip_l.transformer.";
}
if (!model_loader.init_from_file(clip_l_path, prefix)) {
LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str());
}
}

if (t5xxl_path.size() > 0) {
LOG_INFO("loading t5xxl from '%s'", t5xxl_path.c_str());
if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.")) {
LOG_WARN("loading t5xxl from '%s' failed", t5xxl_path.c_str());
std::string prefix = "text_encoders.t5xxl.";
if (version == VERSION_SD3_2B ) {
prefix = "text_encoders.t5xxl.transformer.";
}
}

if (diffusion_model_path.size() > 0) {
LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str());
if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) {
LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str());
if (!model_loader.init_from_file(t5xxl_path, prefix)) {
LOG_WARN("loading t5xxl from '%s' failed", t5xxl_path.c_str());
}
}

Expand All @@ -226,7 +247,6 @@ class StableDiffusionGGML {
}
}

version = model_loader.get_sd_version();
if (version == VERSION_COUNT) {
LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str());
return false;
Expand Down Expand Up @@ -1007,6 +1027,7 @@ struct sd_ctx_t {
};

sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
const char* clip_g_path_c_str,
const char* clip_l_path_c_str,
const char* t5xxl_path_c_str,
const char* diffusion_model_path_c_str,
Expand All @@ -1031,6 +1052,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
return NULL;
}
std::string model_path(model_path_c_str);
std::string clip_g_path(clip_g_path_c_str);
std::string clip_l_path(clip_l_path_c_str);
std::string t5xxl_path(t5xxl_path_c_str);
std::string diffusion_model_path(diffusion_model_path_c_str);
Expand All @@ -1051,6 +1073,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
}

if (!sd_ctx->sd->load_from_file(model_path,
clip_g_path,
clip_l_path,
t5xxl_path_c_str,
diffusion_model_path,
Expand Down
1 change: 1 addition & 0 deletions stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ typedef struct {
typedef struct sd_ctx_t sd_ctx_t;

SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
const char* clip_g_path,
const char* clip_l_path,
const char* t5xxl_path,
const char* diffusion_model_path,
Expand Down

0 comments on commit f74fdab

Please sign in to comment.