diff --git a/conditioner.hpp b/conditioner.hpp index 5b3f20dd..eea068ae 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1190,4 +1190,171 @@ struct FluxCLIPEmbedder : public Conditioner { } }; +struct SimpleT5Embedder : public Conditioner { + T5UniGramTokenizer t5_tokenizer; + std::shared_ptr t5; + + SimpleT5Embedder(ggml_backend_t backend, + std::map& tensor_types, + int clip_skip = -1) { + t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); + } + + void get_param_tensors(std::map& tensors) { + t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer"); + } + + void alloc_params_buffer() { + t5->alloc_params_buffer(); + } + + void free_params_buffer() { + t5->free_params_buffer(); + } + + size_t get_params_buffer_size() { + size_t buffer_size = t5->get_params_buffer_size(); + return buffer_size; + } + + std::pair, std::vector> tokenize(std::string text, + size_t max_length = 0, + bool padding = false) { + auto parsed_attention = parse_prompt_attention(text); + + { + std::stringstream ss; + ss << "["; + for (const auto& item : parsed_attention) { + ss << "['" << item.first << "', " << item.second << "], "; + } + ss << "]"; + LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); + } + + auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool { + return false; + }; + + std::vector t5_tokens; + std::vector t5_weights; + for (const auto& item : parsed_attention) { + const std::string& curr_text = item.first; + float curr_weight = item.second; + + std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); + t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); + } + + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding); + + // for (int i = 0; i < clip_l_tokens.size(); i++) { + // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", "; + // } + // std::cout << std::endl; + + // for (int i = 0; i < t5_tokens.size(); i++) { + // std::cout << t5_tokens[i] << ":" << t5_weights[i] << ", "; + // } + // std::cout << std::endl; + + return {t5_tokens, t5_weights}; + } + + SDCondition get_learned_condition_common(ggml_context* work_ctx, + int n_threads, + std::pair, std::vector> token_and_weights, + int clip_skip, + bool force_zero_embeddings = false) { + auto& t5_tokens = token_and_weights.first; + auto& t5_weights = token_and_weights.second; + + int64_t t0 = ggml_time_ms(); + struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096] + struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096] + std::vector hidden_states_vec; + + size_t chunk_len = 256; + size_t chunk_count = t5_tokens.size() / chunk_len; + for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { + std::vector chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len, + t5_tokens.begin() + (chunk_idx + 1) * chunk_len); + std::vector chunk_weights(t5_weights.begin() + chunk_idx * chunk_len, + t5_weights.begin() + (chunk_idx + 1) * chunk_len); + + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); + + t5->compute(n_threads, + input_ids, + &chunk_hidden_states, + work_ctx); + { + auto tensor = chunk_hidden_states; + float original_mean = ggml_tensor_mean(tensor); + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float value = ggml_tensor_get_f32(tensor, i0, i1, i2); + value *= chunk_weights[i1]; + 
ggml_tensor_set_f32(tensor, value, i0, i1, i2); + } + } + } + float new_mean = ggml_tensor_mean(tensor); + ggml_tensor_scale(tensor, (original_mean / new_mean)); + } + + int64_t t1 = ggml_time_ms(); + LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); + if (force_zero_embeddings) { + float* vec = (float*)chunk_hidden_states->data; + for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) { + vec[i] = 0; + } + } + + hidden_states_vec.insert(hidden_states_vec.end(), + (float*)chunk_hidden_states->data, + ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states)); + } + + hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec); + hidden_states = ggml_reshape_2d(work_ctx, + hidden_states, + chunk_hidden_states->ne[0], + ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); + return SDCondition(hidden_states, NULL, NULL); + } + + SDCondition get_learned_condition(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + auto tokens_and_weights = tokenize(text, 256, true); + return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); + } + + std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int num_input_imgs, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + GGML_ASSERT(0 && "Not implemented yet!"); + } + + std::string remove_trigger_from_prompt(ggml_context* work_ctx, + const std::string& prompt) { + GGML_ASSERT(0 && "Not implemented yet!"); + } +}; + #endif \ No newline at end of file diff --git a/diffusion_model.hpp b/diffusion_model.hpp index cbc0cd4c..f76a10db 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -2,6 +2,7 @@ #define __DIFFUSION_MODEL_H__ #include "flux.hpp" +#include "ltx.hpp" #include "mmdit.hpp" #include "unet.hpp" @@ -178,4 +179,54 @@ struct FluxModel : public DiffusionModel { } }; +struct LTXModel : public DiffusionModel { + Ltx::LTXRunner ltx; + + LTXModel(ggml_backend_t backend, + std::map& tensor_types, + bool flash_attn = false) + : ltx(backend, tensor_types, "model.diffusion_model") { + } + + void alloc_params_buffer() { + ltx.alloc_params_buffer(); + } + + void free_params_buffer() { + ltx.free_params_buffer(); + } + + void free_compute_buffer() { + ltx.free_compute_buffer(); + } + + void get_param_tensors(std::map& tensors) { + ltx.get_param_tensors(tensors, "model.diffusion_model"); + } + + size_t get_params_buffer_size() { + return ltx.get_params_buffer_size(); + } + + int64_t get_adm_in_channels() { + return 768; + } + + void compute(int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* c_concat, + struct ggml_tensor* y, + struct ggml_tensor* guidance, + int num_video_frames = -1, + std::vector controls = {}, + float control_strength = 0.f, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + return ltx.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers); + } +}; + #endif diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 4b47286f..814ddf7d 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -54,6 +54,7 @@ const char* schedule_str[] = { const char* modes_str[] = { "txt2img", 
"img2img", + "txt2vid", "img2vid", "convert", }; @@ -61,6 +62,7 @@ const char* modes_str[] = { enum SDMode { TXT2IMG, IMG2IMG, + TXT2VID, IMG2VID, CONVERT, MODE_COUNT @@ -264,7 +266,7 @@ void parse_args(int argc, const char** argv, SDParams& params) { } if (mode_found == -1) { fprintf(stderr, - "error: invalid mode %s, must be one of [txt2img, img2img, img2vid, convert]\n", + "error: invalid mode %s, must be one of [txt2img, img2img, txt2vid, img2vid, convert]\n", mode_selected); exit(1); } @@ -931,6 +933,23 @@ int main(int argc, const char* argv[]) { params.slg_scale, params.skip_layer_start, params.skip_layer_end); + } else if (params.mode == TXT2VID) { + results = txt2vid(sd_ctx, + params.prompt.c_str(), + params.negative_prompt.c_str(), + params.clip_skip, + params.cfg_scale, + params.width, + params.height, + params.sample_method, + params.sample_steps, + params.seed, + params.batch_count, + params.skip_layers.data(), + params.skip_layers.size(), + params.slg_scale, + params.skip_layer_start, + params.skip_layer_end); } else { sd_image_t input_image = {(uint32_t)params.width, (uint32_t)params.height, diff --git a/ltx.hpp b/ltx.hpp new file mode 100644 index 00000000..9838e4a6 --- /dev/null +++ b/ltx.hpp @@ -0,0 +1,532 @@ +#ifndef __LTX_HPP__ +#define __LTX_HPP__ + +#include "ggml_extend.hpp" +#include "model.h" + +#define LTX_GRAPH_SIZE 10240 +namespace Ltx { + + struct Mlp : public GGMLBlock { + public: + Mlp(int64_t in_features, + int64_t hidden_features = -1, + int64_t out_features = -1, + bool bias = true) { + // act_layer is always lambda: nn.GELU(approximate="tanh") + // norm_layer is always None + // use_conv is always False + if (hidden_features == -1) { + hidden_features = in_features; + } + if (out_features == -1) { + out_features = in_features; + } + blocks["net.0.proj"] = std::shared_ptr(new Linear(in_features, hidden_features, bias)); + blocks["net.2"] = std::shared_ptr(new Linear(hidden_features, out_features, bias)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + // x: [N, n_token, in_features] + auto fc1 = std::dynamic_pointer_cast(blocks["net.0.proj"]); + auto fc2 = std::dynamic_pointer_cast(blocks["net.2"]); + + x = fc1->forward(ctx, x); + x = ggml_gelu_inplace(ctx, x); + x = fc2->forward(ctx, x); + return x; + } + }; + + struct EmbedProjection : public GGMLBlock { + // Embeds scalar timesteps into vector representations. 
+    public:
+        EmbedProjection(int64_t hidden_size,
+                        int64_t embedding_size = 256) {
+            blocks["linear_1"] = std::shared_ptr<GGMLBlock>(new Linear(embedding_size, hidden_size, true, true));
+            blocks["linear_2"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size, true, true));
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* t) {
+            // t: [N, ]
+            // return: [N, hidden_size]
+            auto mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["linear_1"]);
+            auto mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["linear_2"]);
+
+            auto t_emb = mlp_0->forward(ctx, t);
+            t_emb      = ggml_silu_inplace(ctx, t_emb);
+            t_emb      = mlp_2->forward(ctx, t_emb);
+            return t_emb;
+        }
+    };
+
+    struct AdaLnSingleEmbedder : public GGMLBlock {
+    protected:
+        int64_t frequency_embedding_size;
+
+    public:
+        AdaLnSingleEmbedder(int64_t hidden_size, int64_t frequency_embedding_size = 256)
+            : frequency_embedding_size(frequency_embedding_size) {
+            blocks["timestep_embedder"] = std::shared_ptr<GGMLBlock>(new EmbedProjection(hidden_size, frequency_embedding_size));
+        }
+        struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* t) {
+            auto t_embedder = std::dynamic_pointer_cast<EmbedProjection>(blocks["timestep_embedder"]);
+            auto t_freq     = ggml_nn_timestep_embedding(ctx, t, frequency_embedding_size);  // [N, frequency_embedding_size]
+
+            return t_embedder->forward(ctx, t_freq);
+        }
+    };
+
+    struct AdaLnSingle : public GGMLBlock {
+        // Embeds scalar timesteps into vector representations.
+    public:
+        AdaLnSingle(int64_t hidden_size, int64_t frequency_embedding_size = 256, int64_t num_scales_shifts = 6) {
+            blocks["emb"]    = std::shared_ptr<GGMLBlock>(new AdaLnSingleEmbedder(hidden_size, frequency_embedding_size));
+            blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, hidden_size * num_scales_shifts, true, true));
+        }
+
+        struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* t) {
+            auto embedder = std::dynamic_pointer_cast<AdaLnSingleEmbedder>(blocks["emb"]);
+            auto linear   = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
+
+            auto embeds = embedder->forward(ctx, t);
+            embeds      = ggml_silu_inplace(ctx, embeds);
+            return linear->forward(ctx, embeds);
+        }
+    };
+
+    class RMSNorm : public UnaryBlock {
+    protected:
+        int64_t hidden_size;
+        float eps;
+
+        void init_params(struct ggml_context* ctx, std::map<std::string, enum ggml_type>& tensor_types, std::string prefix = "") {
+            enum ggml_type wtype = GGML_TYPE_F32;  //(tensor_types.find(prefix + "weight") != tensor_types.end()) ?
tensor_types[prefix + "weight"] : GGML_TYPE_F32; + params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); + } + + public: + RMSNorm(int64_t hidden_size, + float eps = 1e-06f) + : hidden_size(hidden_size), + eps(eps) {} + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["weight"]; + x = ggml_rms_norm(ctx, x, eps); + x = ggml_mul(ctx, x, w); + return x; + } + }; + + class Attention : public GGMLBlock { + public: + int64_t num_heads; + std::string qk_norm; + + public: + Attention(int64_t dim, + int64_t num_heads = 8, + std::string qk_norm = "", + bool qkv_bias = false) + : num_heads(num_heads), qk_norm(qk_norm) { + + blocks["to_q"] = std::shared_ptr(new Linear(dim, dim, qkv_bias)); + blocks["to_k"] = std::shared_ptr(new Linear(dim, dim, qkv_bias)); + blocks["to_v"] = std::shared_ptr(new Linear(dim, dim, qkv_bias)); + blocks["to_out.0"] = std::shared_ptr(new Linear(dim, dim, qkv_bias)); + + blocks["k_norm"] = std::shared_ptr(new RMSNorm(dim, 1.0e-6)); + blocks["q_norm"] = std::shared_ptr(new RMSNorm(dim, 1.0e-6)); + } + + std::vector pre_attention(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* y) { + auto q_proj = std::dynamic_pointer_cast(blocks["to_q"]); + auto k_proj = std::dynamic_pointer_cast(blocks["to_k"]); + auto v_proj = std::dynamic_pointer_cast(blocks["to_v"]); + + auto q = q_proj->forward(ctx, x); + auto k = k_proj->forward(ctx, y); + auto v = v_proj->forward(ctx, y); + + { + auto q_norm = std::dynamic_pointer_cast(blocks["q_norm"]); + auto k_norm = std::dynamic_pointer_cast(blocks["k_norm"]); + q = q_norm->forward(ctx, q); + k = k_norm->forward(ctx, k); + } + + q = ggml_reshape_3d(ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]); // [N, n_token, n_head*d_head] + k = ggml_reshape_3d(ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]); // [N, n_token, n_head*d_head] + + return {q, k, v}; + } + + struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) { + auto out_proj = std::dynamic_pointer_cast(blocks["to_out.0"]); + + x = out_proj->forward(ctx, x); // [N, n_token, dim] + return x; + } + + // x: [N, n_token, dim] + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* y) { + auto qkv = pre_attention(ctx, x, y); + x = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] + return x; + } + }; + + __STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* shift, + struct ggml_tensor* scale) { + // x: [N, L, C] + // scale: [N, C] + // shift: [N, C] + scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] + shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]); // [N, 1, C] + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); + x = ggml_add(ctx, x, shift); + return x; + } + + struct TransformerBlock : public GGMLBlock { + public: + int64_t hidden_size; + + public: + void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { + enum ggml_type wtype = (tensor_types.find(prefix + "scale_shift_table") != tensor_types.end()) ? 
tensor_types[prefix + "scale_shift_table"] : GGML_TYPE_F32; + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, wtype, hidden_size, 6); + ; + } + TransformerBlock(int64_t hidden_size, + int64_t num_heads, + float mlp_ratio = 4.0, + std::string qk_norm = "", + bool qkv_bias = false) + : hidden_size(hidden_size) { + blocks["attn1"] = std::shared_ptr(new Attention(hidden_size, num_heads, qk_norm, qkv_bias)); + blocks["attn2"] = std::shared_ptr(new Attention(hidden_size, num_heads, qk_norm, qkv_bias)); + + blocks["ff"] = std::shared_ptr(new Mlp(hidden_size, hidden_size * mlp_ratio, hidden_size, true)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* context, + struct ggml_tensor* x, + struct ggml_tensor* c, + struct ggml_tensor* shift_scale) { + struct ggml_tensor* ss_table = params["scale_shift_table"]; // [hidden_size, 6] + + auto ss = ggml_add(ctx, shift_scale, ss_table); + + int64_t offset = ss->nb[0] * ss->ne[0]; + // TODO: Is that the right order? + // assuming [scale0, scale2, shift0, shift2, scale1, scale3] from Pixart alpha paper + + auto scale_0 = ggml_view_1d(ctx, ss, ss->ne[0], offset * 0); + auto scale_2 = ggml_view_1d(ctx, ss, ss->ne[0], offset * 1); + + auto shift_0 = ggml_view_1d(ctx, ss, ss->ne[0], offset * 2); + auto shift_2 = ggml_view_1d(ctx, ss, ss->ne[0], offset * 3); + + auto scale_1 = ggml_view_1d(ctx, ss, ss->ne[0], offset * 4); + auto scale_3 = ggml_view_1d(ctx, ss, ss->ne[0], offset * 5); + + auto attn1 = std::dynamic_pointer_cast(blocks["attn1"]); + auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + + auto ff = std::dynamic_pointer_cast(blocks["ff"]); + + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale_0)); + x = ggml_add(ctx, x, shift_0); + + x = attn1->forward(ctx, x, x); + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale_1)); + + x = attn2->forward(ctx, x, c); + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale_2)); + x = ggml_add(ctx, x, shift_2); + + x = ff->forward(ctx, x); + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale_3)); + + return x; + } + }; + + struct LTXv : public GGMLBlock { + // TODO: This seems to be closely related to Pixart Alpha models + // Support both here? + protected: + int64_t input_size = -1; + int64_t patch_size = 2; + int64_t in_channels = 128; + int64_t depth = 24; + float mlp_ratio = 4.0f; + int64_t adm_in_channels = 2048; + int64_t out_channels = 128; + int64_t pos_embed_max_size = 192; + int64_t num_patchs = 36864; // 192 * 192 + int64_t context_size = 4096; + int64_t hidden_size; + + void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "pos_embed") != tensor_types.end()) ? 
tensor_types[prefix + "pos_embed"] : GGML_TYPE_F32; + params["scale_shift_table"] = ggml_new_tensor_2d(ctx, wtype, hidden_size, 2); // scales and shifts for last layer + } + + public: + LTXv(std::map& tensor_types) { + // read tensors from tensor_types + for (auto pair : tensor_types) { + std::string tensor_name = pair.first; + if (tensor_name.find("model.diffusion_model.") == std::string::npos) + continue; + size_t jb = tensor_name.find("transformer_blocks."); + if (jb != std::string::npos) { + tensor_name = tensor_name.substr(jb); // remove prefix + int block_depth = atoi(tensor_name.substr(19, tensor_name.find(".", 19)).c_str()); + if (block_depth + 1 > depth) { + depth = block_depth + 1; + } + } + } + + LOG_INFO("Transformer layers: %d", depth); + + int64_t default_out_channels = in_channels; + hidden_size = 2048; + int64_t num_heads = depth; + + blocks["patchify_proj"] = std::shared_ptr(new Linear(in_channels, hidden_size)); + + blocks["adaln_single"] = std::shared_ptr(new AdaLnSingle(hidden_size)); + + blocks["caption_projection"] = std::shared_ptr(new EmbedProjection(hidden_size, context_size)); + + for (int i = 0; i < depth; i++) { + blocks["transformer_blocks." + std::to_string(i)] = std::shared_ptr(new TransformerBlock(hidden_size, + num_heads, + mlp_ratio, + "rms", + true)); + } + + // params["scale_shift_table"] (in init_params()) + blocks["proj_out"] = std::shared_ptr(new Linear(hidden_size, out_channels)); + LOG_INFO("Loaded"); + } + + struct ggml_tensor* forward_core(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* c_mod, + struct ggml_tensor* shift_scale, + struct ggml_tensor* context, + std::vector skip_layers = std::vector()) { + auto final_layer_proj = std::dynamic_pointer_cast(blocks["proj_out"]); + auto final_layer_modulation = params["scale_shift_table"]; // [hidden_size, 2] + + // TODO: figure out last layer modulation + // we need to slice shift_scale : [hidden_size, 6] => [hidden_size, 2] (which columns to keep?) + // then add to final_layer_modulation + + // auto scale = ggml_view_1d(ctx, final_layer_modulation, final_layer_modulation->ne[0], offset * 0); + // auto shift = ggml_view_1d(ctx, final_layer_modulation, final_layer_modulation->ne[0], offset * 1); + + for (int i = 0; i < depth; i++) { + // skip iteration if i is in skip_layers + if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) { + continue; + } + + auto block = std::dynamic_pointer_cast(blocks["transformer_blocks." + std::to_string(i)]); + + x = block->forward(ctx, context, x, c_mod, shift_scale); + } + + x = final_layer_proj->forward(ctx, x); // (N, T, patch_size ** 2 * out_channels) + + // TODO: before or after final proj? 
(probably after) + // x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); + // x = ggml_add(ctx, x, shift); + + return x; + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* t, + struct ggml_tensor* y = NULL, + struct ggml_tensor* context = NULL, + std::vector skip_layers = std::vector()) { + auto x_embedder = std::dynamic_pointer_cast(blocks["patchify_proj"]); + auto t_embedder = std::dynamic_pointer_cast(blocks["adaln_single"]); + + int64_t w = x->ne[0]; + int64_t h = x->ne[1]; + + auto hidden_states = x_embedder->forward(ctx, x); // [N, H*W, hidden_size] + + auto shift_scales = t_embedder->forward(ctx, t); // [hidden_size * 6] + shift_scales = ggml_reshape_2d(ctx, shift_scales, hidden_size, 6); // [hidden_size, 6] + + auto c = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + if (y != NULL && adm_in_channels != -1) { + auto y_embedder = std::dynamic_pointer_cast(blocks["caption_projection"]); + + y = y_embedder->forward(ctx, y); // [N, hidden_size] + c = ggml_add(ctx, c, y); + } + + // if (context != NULL) { + // auto context_embedder = std::dynamic_pointer_cast(blocks["context_embedder"]); + + // context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536] + // } + + x = forward_core(ctx, x, c, shift_scales, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels) + + // x = unpatchify(ctx, x, h, w); // [N, C, H, W] + + return x; + } + }; + struct LTXRunner : public GGMLRunner { + LTXv ltx; + + static std::map empty_tensor_types; + + LTXRunner(ggml_backend_t backend, + std::map& tensor_types = empty_tensor_types, + const std::string prefix = "") + : GGMLRunner(backend), ltx(tensor_types) { + ltx.init(params_ctx, tensor_types, prefix); + } + + std::string get_desc() { + return "ltx"; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + ltx.get_param_tensors(tensors, prefix); + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* y, + std::vector skip_layers = std::vector()) { + struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LTX_GRAPH_SIZE, false); + + x = to_backend(x); + context = to_backend(context); + y = to_backend(y); + timesteps = to_backend(timesteps); + + struct ggml_tensor* out = ltx.forward(compute_ctx, + x, + timesteps, + y, + context, + skip_layers); + + ggml_build_forward_expand(gf, out); + + return gf; + } + + void compute(int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* y, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + // x: [N, in_channels, h, w] + // timesteps: [N, ] + // context: [N, max_position, hidden_size]([N, 154, 4096]) or [1, max_position, hidden_size] + // y: [N, adm_in_channels] or [1, adm_in_channels] + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(x, timesteps, context, y, skip_layers); + }; + + GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); + } + + void test() { + struct ggml_init_params params; + params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB + params.mem_buffer = NULL; + params.no_alloc = false; + + struct ggml_context* work_ctx = ggml_init(params); + GGML_ASSERT(work_ctx != NULL); + + { + // cpu f16: pass + // cpu f32: pass + // cuda f16: pass + // cuda f32: pass + auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 128, 128, 
16, 1); + std::vector timesteps_vec(1, 999.f); + auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); + ggml_set_f32(x, 0.01f); + // print_ggml_tensor(x); + + auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 154, 1); + ggml_set_f32(context, 0.01f); + // print_ggml_tensor(context); + + auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 2048, 1); + ggml_set_f32(y, 0.01f); + // print_ggml_tensor(y); + + struct ggml_tensor* out = NULL; + + int t0 = ggml_time_ms(); + compute(8, x, timesteps, context, y, &out, work_ctx); + int t1 = ggml_time_ms(); + + print_ggml_tensor(out); + LOG_DEBUG("ltx test done in %dms", t1 - t0); + } + } + + static void load_from_file_and_test(const std::string& file_path) { + // ggml_backend_t backend = ggml_backend_cuda_init(0); + ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_type model_data_type = GGML_TYPE_F16; + std::shared_ptr ltx = std::shared_ptr(new LTXRunner(backend)); + { + LOG_INFO("loading from '%s'", file_path.c_str()); + + ltx->alloc_params_buffer(); + std::map tensors; + ltx->get_param_tensors(tensors, "model.diffusion_model"); + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path)) { + LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); + return; + } + + bool success = model_loader.load_tensors(tensors, backend); + + if (!success) { + LOG_ERROR("load tensors from model loader failed"); + return; + } + + LOG_INFO("ltx model loaded"); + } + ltx->test(); + } + }; +} // namespace Flux + +#endif // __LTX_HPP__ \ No newline at end of file diff --git a/model.cpp b/model.cpp index c90918ad..9dd5860b 100644 --- a/model.cpp +++ b/model.cpp @@ -1032,19 +1032,19 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const } int n_dims = (int)shape.size(); - int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1, 1}; for (int i = 0; i < n_dims; i++) { ne[i] = shape[i].get(); } - if (n_dims == 5) { - if (ne[3] == 1 && ne[4] == 1) { - n_dims = 4; - } else { - LOG_ERROR("invalid tensor '%s'", name.c_str()); - return false; - } - } + // if (n_dims == 5) { + // if (ne[3] == 1 && ne[4] == 1) { + // n_dims = 4; + // } else { + // LOG_ERROR("invalid tensor '%s'", name.c_str()); + // return false; + // } + // } // ggml_n_dims returns 1 for scalars if (n_dims == 0) { @@ -1475,6 +1475,9 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { return VERSION_SVD; } + if (tensor_storage.name.find("model.diffusion_model.transformer_blocks") != std::string::npos){ + return VERSION_LTXV; + } if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" || tensor_storage.name == "cond_stage_model.model.token_embedding.weight" || diff --git a/model.h b/model.h index 29d46c19..5602bdb8 100644 --- a/model.h +++ b/model.h @@ -15,7 +15,7 @@ #include "json.hpp" #include "zip.h" -#define SD_MAX_DIMS 5 +#define SD_MAX_DIMS 6 enum SDVersion { VERSION_SD1, @@ -24,6 +24,7 @@ enum SDVersion { VERSION_SVD, VERSION_SD3, VERSION_FLUX, + VERSION_LTXV, VERSION_COUNT, }; @@ -42,7 +43,7 @@ static inline bool sd_version_is_sd3(SDVersion version) { } static inline bool sd_version_is_dit(SDVersion version) { - if (sd_version_is_flux(version) || sd_version_is_sd3(version)) { + if (sd_version_is_flux(version) || sd_version_is_sd3(version) || version == VERSION_LTXV) { return true; } return false; @@ -59,7 +60,7 @@ struct 
TensorStorage { bool is_bf16 = false; bool is_f8_e4m3 = false; bool is_f8_e5m2 = false; - int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; + int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1, 1}; int n_dims = 0; size_t file_index = 0; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 5abc2950..12814a2c 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -30,7 +30,8 @@ const char* model_version_to_str[] = { "SDXL", "SVD", "SD3.x", - "Flux"}; + "Flux", + "LTX-Video"}; const char* sampling_methods_str[] = { "Euler A", @@ -330,6 +331,10 @@ class StableDiffusionGGML { } else if (sd_version_is_flux(version)) { cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, diffusion_flash_attn); + } else if (version == VERSION_LTXV) { + // TODO: cond for T5 only + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); + diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, diffusion_flash_attn); } else { if (id_embeddings_path.find("v2") != std::string::npos) { cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2); @@ -1541,6 +1546,31 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, return result_images; } +sd_image_t* txt2vid(sd_ctx_t* sd_ctx, + const char* prompt_c_str, + const char* negative_prompt_c_str, + int clip_skip, + float cfg_scale, + int width, + int height, + enum sample_method_t sample_method, + int sample_steps, + int64_t seed, + int batch_count, + int* skip_layers = NULL, + size_t skip_layers_count = 0, + float slg_scale = 0, + float skip_layer_start = 0.01, + float skip_layer_end = 0.2) { + std::vector skip_layers_vec(skip_layers, skip_layers + skip_layers_count); + LOG_DEBUG("txt2vid %dx%d", width, height); + if (sd_ctx == NULL) { + return NULL; + } + LOG_ERROR("Unimplemented"); + return NULL; +} + sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_image_t init_image, const char* prompt_c_str, diff --git a/stable-diffusion.h b/stable-diffusion.h index c67bc8a3..0650f9a5 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -197,6 +197,23 @@ SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, float skip_layer_start, float skip_layer_end); +SD_API sd_image_t* txt2vid(sd_ctx_t* sd_ctx, + const char* prompt, + const char* negative_prompt, + int clip_skip, + float cfg_scale, + int width, + int height, + enum sample_method_t sample_method, + int sample_steps, + int64_t seed, + int batch_count, + int* skip_layers, + size_t skip_layers_count, + float slg_scale, + float skip_layer_start, + float skip_layer_end); + SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, sd_image_t init_image, int width,