LTX: first commit
Stéphane du Hamel committed Dec 1, 2024
1 parent 9578fdc commit c5e01af
Showing 8 changed files with 834 additions and 14 deletions.
167 changes: 167 additions & 0 deletions conditioner.hpp
@@ -1190,4 +1190,171 @@ struct FluxCLIPEmbedder : public Conditioner {
}
};

struct SimpleT5Embedder : public Conditioner {
T5UniGramTokenizer t5_tokenizer;
std::shared_ptr<T5Runner> t5;

SimpleT5Embedder(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
int clip_skip = -1) {
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
}

void alloc_params_buffer() {
t5->alloc_params_buffer();
}

void free_params_buffer() {
t5->free_params_buffer();
}

size_t get_params_buffer_size() {
size_t buffer_size = t5->get_params_buffer_size();
return buffer_size;
}

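// Tokenize the prompt with the T5 unigram tokenizer and return (token ids, per-token weights),
// where the weights come from prompt-attention syntax such as (word:1.2).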
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
size_t max_length = 0,
bool padding = false) {
auto parsed_attention = parse_prompt_attention(text);

{
std::stringstream ss;
ss << "[";
for (const auto& item : parsed_attention) {
ss << "['" << item.first << "', " << item.second << "], ";
}
ss << "]";
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
}

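// Placeholder callback for custom/trigger tokens; not used by the T5 tokenizer path here.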
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
return false;
};

std::vector<int> t5_tokens;
std::vector<float> t5_weights;
for (const auto& item : parsed_attention) {
const std::string& curr_text = item.first;
float curr_weight = item.second;

std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
}

t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);

// for (int i = 0; i < t5_tokens.size(); i++) {
//     std::cout << t5_tokens[i] << ":" << t5_weights[i] << ", ";
// }
// std::cout << std::endl;

return {t5_tokens, t5_weights};
}

SDCondition get_learned_condition_common(ggml_context* work_ctx,
int n_threads,
std::pair<std::vector<int>, std::vector<float>> token_and_weights,
int clip_skip,
bool force_zero_embeddings = false) {
auto& t5_tokens = token_and_weights.first;
auto& t5_weights = token_and_weights.second;

int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096]
struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096]
std::vector<float> hidden_states_vec;

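// Encode the prompt in fixed 256-token chunks; the per-chunk hidden states are concatenated below.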
size_t chunk_len = 256;
size_t chunk_count = t5_tokens.size() / chunk_len;
for (size_t chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
t5_weights.begin() + (chunk_idx + 1) * chunk_len);

auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);

t5->compute(n_threads,
input_ids,
&chunk_hidden_states,
work_ctx);
{
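// Scale each token's hidden state by its prompt weight, then rescale the whole
// tensor so its mean matches the unweighted mean.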
auto tensor = chunk_hidden_states;
float original_mean = ggml_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
value *= chunk_weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
}
}
}
float new_mean = ggml_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean));
}

int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
if (force_zero_embeddings) {
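// Zero out the embeddings, typically used for the empty/unconditional prompt.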
float* vec = (float*)chunk_hidden_states->data;
for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
vec[i] = 0;
}
}

hidden_states_vec.insert(hidden_states_vec.end(),
(float*)chunk_hidden_states->data,
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
}

hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
hidden_states = ggml_reshape_2d(work_ctx,
hidden_states,
chunk_hidden_states->ne[0],
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
return SDCondition(hidden_states, NULL, NULL);
}

SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int adm_in_channels = -1,
bool force_zero_embeddings = false) {
auto tokens_and_weights = tokenize(text, 256, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
}

std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int num_input_imgs,
int adm_in_channels = -1,
bool force_zero_embeddings = false) {
GGML_ASSERT(0 && "Not implemented yet!");
}

std::string remove_trigger_from_prompt(ggml_context* work_ctx,
const std::string& prompt) {
GGML_ASSERT(0 && "Not implemented yet!");
}
};

#endif
51 changes: 51 additions & 0 deletions diffusion_model.hpp
@@ -2,6 +2,7 @@
#define __DIFFUSION_MODEL_H__

#include "flux.hpp"
#include "ltx.hpp"
#include "mmdit.hpp"
#include "unet.hpp"

@@ -178,4 +179,54 @@ struct FluxModel : public DiffusionModel {
}
};

struct LTXModel : public DiffusionModel {
Ltx::LTXRunner ltx;

LTXModel(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
bool flash_attn = false)
: ltx(backend, tensor_types, "model.diffusion_model") {
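// Note: flash_attn is accepted for interface parity but is not forwarded to the LTX runner here.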
}

void alloc_params_buffer() {
ltx.alloc_params_buffer();
}

void free_params_buffer() {
ltx.free_params_buffer();
}

void free_compute_buffer() {
ltx.free_compute_buffer();
}

void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
ltx.get_param_tensors(tensors, "model.diffusion_model");
}

size_t get_params_buffer_size() {
return ltx.get_params_buffer_size();
}

int64_t get_adm_in_channels() {
return 768;
}

void compute(int n_threads,
struct ggml_tensor* x,
struct ggml_tensor* timesteps,
struct ggml_tensor* context,
struct ggml_tensor* c_concat,
struct ggml_tensor* y,
struct ggml_tensor* guidance,
int num_video_frames = -1,
std::vector<struct ggml_tensor*> controls = {},
float control_strength = 0.f,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL,
std::vector<int> skip_layers = std::vector<int>()) {
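// c_concat, guidance, controls, control_strength, and num_video_frames are not
// forwarded to the LTX runner here.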
return ltx.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers);
}
};

#endif
21 changes: 20 additions & 1 deletion examples/cli/main.cpp
@@ -54,13 +54,15 @@ const char* schedule_str[] = {
const char* modes_str[] = {
"txt2img",
"img2img",
"txt2vid",
"img2vid",
"convert",
};

enum SDMode {
TXT2IMG,
IMG2IMG,
TXT2VID,
IMG2VID,
CONVERT,
MODE_COUNT
@@ -264,7 +266,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
}
if (mode_found == -1) {
fprintf(stderr,
"error: invalid mode %s, must be one of [txt2img, img2img, img2vid, convert]\n",
"error: invalid mode %s, must be one of [txt2img, img2img, txt2vid, img2vid, convert]\n",
mode_selected);
exit(1);
}
@@ -931,6 +933,23 @@ int main(int argc, const char* argv[]) {
params.slg_scale,
params.skip_layer_start,
params.skip_layer_end);
} else if (params.mode == TXT2VID) {
results = txt2vid(sd_ctx,
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
params.cfg_scale,
params.width,
params.height,
params.sample_method,
params.sample_steps,
params.seed,
params.batch_count,
params.skip_layers.data(),
params.skip_layers.size(),
params.slg_scale,
params.skip_layer_start,
params.skip_layer_end);
} else {
sd_image_t input_image = {(uint32_t)params.width,
(uint32_t)params.height,
