feat: add Chroma support (#696)

---------

Co-authored-by: Green Sky <Green-Sky@users.noreply.github.com>
Co-authored-by: leejet <leejet714@gmail.com>
stduhpf 2025-06-29 17:36:42 +02:00 committed by GitHub
parent 884e23eeeb
commit b1cc40c35c
8 changed files with 563 additions and 112 deletions

View File

@ -747,7 +747,7 @@ struct SD3CLIPEmbedder : public Conditioner {
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding);
clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding);
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
// for (int i = 0; i < clip_l_tokens.size(); i++) {
// std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@ -902,6 +902,7 @@ struct SD3CLIPEmbedder : public Conditioner {
t5->compute(n_threads,
input_ids,
NULL,
&chunk_hidden_states_t5,
work_ctx);
{
@ -1004,6 +1005,7 @@ struct FluxCLIPEmbedder : public Conditioner {
T5UniGramTokenizer t5_tokenizer;
std::shared_ptr<CLIPTextModelRunner> clip_l;
std::shared_ptr<T5Runner> t5;
size_t chunk_len = 256;
FluxCLIPEmbedder(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
@ -1077,7 +1079,7 @@ struct FluxCLIPEmbedder : public Conditioner {
}
clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding);
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding);
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding);
// for (int i = 0; i < clip_l_tokens.size(); i++) {
// std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
@ -1109,7 +1111,6 @@ struct FluxCLIPEmbedder : public Conditioner {
struct ggml_tensor* pooled = NULL; // [768,]
std::vector<float> hidden_states_vec;
size_t chunk_len = 256;
size_t chunk_count = t5_tokens.size() / chunk_len;
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
// clip_l
@ -1147,6 +1148,7 @@ struct FluxCLIPEmbedder : public Conditioner {
t5->compute(n_threads,
input_ids,
NULL,
&chunk_hidden_states,
work_ctx);
{
@ -1196,7 +1198,208 @@ struct FluxCLIPEmbedder : public Conditioner {
int height,
int adm_in_channels = -1,
bool force_zero_embeddings = false) {
auto tokens_and_weights = tokenize(text, 256, true);
auto tokens_and_weights = tokenize(text, chunk_len, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
}
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int num_input_imgs,
int adm_in_channels = -1,
bool force_zero_embeddings = false) {
GGML_ASSERT(0 && "Not implemented yet!");
}
std::string remove_trigger_from_prompt(ggml_context* work_ctx,
const std::string& prompt) {
GGML_ASSERT(0 && "Not implemented yet!");
}
};
struct PixArtCLIPEmbedder : public Conditioner {
T5UniGramTokenizer t5_tokenizer;
std::shared_ptr<T5Runner> t5;
size_t chunk_len = 512;
bool use_mask = false;
int mask_pad = 1;
PixArtCLIPEmbedder(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
int clip_skip = -1,
bool use_mask = false,
int mask_pad = 1) : use_mask(use_mask), mask_pad(mask_pad) {
t5 = std::make_shared<T5Runner>(backend, tensor_types, "text_encoders.t5xxl.transformer");
}
void set_clip_skip(int clip_skip) {
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer");
}
void alloc_params_buffer() {
t5->alloc_params_buffer();
}
void free_params_buffer() {
t5->free_params_buffer();
}
size_t get_params_buffer_size() {
size_t buffer_size = 0;
buffer_size += t5->get_params_buffer_size();
return buffer_size;
}
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
size_t max_length = 0,
bool padding = false) {
auto parsed_attention = parse_prompt_attention(text);
{
std::stringstream ss;
ss << "[";
for (const auto& item : parsed_attention) {
ss << "['" << item.first << "', " << item.second << "], ";
}
ss << "]";
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
}
auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
return false;
};
std::vector<int> t5_tokens;
std::vector<float> t5_weights;
std::vector<float> t5_mask;
for (const auto& item : parsed_attention) {
const std::string& curr_text = item.first;
float curr_weight = item.second;
std::vector<int> curr_tokens = t5_tokenizer.Encode(curr_text, true);
t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end());
t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight);
}
t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding);
return {t5_tokens, t5_weights, t5_mask};
}
void modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) {
float* mask_data = (float*)mask->data;
int num_pad = 0;
for (int64_t i = 0; i < max_seq_length; i++) {
if (num_pad >= num_extra_padding) {
break;
}
if (std::isinf(mask_data[i])) {
mask_data[i] = 0;
++num_pad;
}
}
// LOG_DEBUG("PAD: %d", num_pad);
}
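// Note (added): illustrative behaviour sketch, not part of the diff. pad_tokens() (t5.hpp)
// emits 0.0f for real tokens and -HUGE_VALF for padding, e.g.
//   [0, 0, 0, 0, -inf, -inf, -inf, -inf] for four real plus four padded tokens.
// With num_extra_padding = 2 (i.e. chroma_t5_mask_pad = 2), the loop above flips the first
// two -inf entries back to 0, giving [0, 0, 0, 0, 0, 0, -inf, -inf], so attention can still
// reach a couple of padding positions.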
SDCondition get_learned_condition_common(ggml_context* work_ctx,
int n_threads,
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> token_and_weights,
int clip_skip,
bool force_zero_embeddings = false) {
auto& t5_tokens = std::get<0>(token_and_weights);
auto& t5_weights = std::get<1>(token_and_weights);
auto& t5_attn_mask_vec = std::get<2>(token_and_weights);
int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096]
struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096]
struct ggml_tensor* pooled = NULL; // [768,]
struct ggml_tensor* t5_attn_mask = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec); // [n_token,]
std::vector<float> hidden_states_vec;
size_t chunk_count = t5_tokens.size() / chunk_len;
for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) {
// t5
std::vector<int> chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len,
t5_tokens.begin() + (chunk_idx + 1) * chunk_len);
std::vector<float> chunk_weights(t5_weights.begin() + chunk_idx * chunk_len,
t5_weights.begin() + (chunk_idx + 1) * chunk_len);
std::vector<float> chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len,
t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len);
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens);
auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL;
t5->compute(n_threads,
input_ids,
t5_attn_mask_chunk,
&chunk_hidden_states,
work_ctx);
{
auto tensor = chunk_hidden_states;
float original_mean = ggml_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
value *= chunk_weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
}
}
}
float new_mean = ggml_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean));
}
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
if (force_zero_embeddings) {
float* vec = (float*)chunk_hidden_states->data;
for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) {
vec[i] = 0;
}
}
hidden_states_vec.insert(hidden_states_vec.end(),
(float*)chunk_hidden_states->data,
((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states));
}
if (hidden_states_vec.size() > 0) {
hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec);
hidden_states = ggml_reshape_2d(work_ctx,
hidden_states,
chunk_hidden_states->ne[0],
ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]);
} else {
hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256);
ggml_set_f32(hidden_states, 0.f);
}
modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad);
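// Note (added): the T5 attention mask is returned in the SDCondition vector slot (second
// argument); it appears to reach FluxModel as `y`, which forward_orig() pads over the image
// tokens to build txt_img_mask when is_chroma is set (see flux.hpp below).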
return SDCondition(hidden_states, t5_attn_mask, NULL);
}
SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int adm_in_channels = -1,
bool force_zero_embeddings = false) {
auto tokens_and_weights = tokenize(text, chunk_len, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
}

View File

@ -137,8 +137,9 @@ struct FluxModel : public DiffusionModel {
FluxModel(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types,
SDVersion version = VERSION_FLUX,
bool flash_attn = false)
: flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
bool flash_attn = false,
bool use_mask = false)
: flux(backend, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) {
}
void alloc_params_buffer() {

View File

@ -132,6 +132,10 @@ struct SDParams {
float slg_scale = 0.f;
float skip_layer_start = 0.01f;
float skip_layer_end = 0.2f;
bool chroma_use_dit_mask = true;
bool chroma_use_t5_mask = false;
int chroma_t5_mask_pad = 1;
};
void print_params(SDParams params) {
@ -185,6 +189,9 @@ void print_params(SDParams params) {
printf(" batch_count: %d\n", params.batch_count);
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
printf(" upscale_repeats: %d\n", params.upscale_repeats);
printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
}
void print_usage(int argc, const char* argv[]) {
@ -252,6 +259,9 @@ void print_usage(int argc, const char* argv[]) {
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" --color colors the logging tags according to level\n");
printf(" --chroma-disable-dit-mask disable dit mask for chroma\n");
printf(" --chroma-enable-t5-mask enable t5 mask for chroma\n");
printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n");
printf(" -v, --verbose print extra info\n");
}
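// Usage note (added, not part of the diff): appending --chroma-enable-t5-mask --chroma-t5-mask-pad 4
// to an otherwise unchanged sd invocation enables the T5 attention mask for Chroma and leaves four
// extra padding tokens attendable; --chroma-disable-dit-mask turns off the DiT mask that is on by default.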
@ -643,6 +653,16 @@ void parse_args(int argc, const char** argv, SDParams& params) {
break;
}
params.ref_image_paths.push_back(argv[i]);
} else if (arg == "chroma-disable-dit-mask") {
params.chroma_use_dit_mask = false;
} else if (arg == "--chroma-use-t5-mask") {
params.chroma_use_t5_mask = true;
} else if (arg == "--chroma-t5-mask-pad") {
if (++i >= argc) {
invalid_arg = true;
break;
}
params.chroma_t5_mask_pad = std::stoi(argv[i]);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
print_usage(argc, argv);
@ -952,7 +972,10 @@ int main(int argc, const char* argv[]) {
params.clip_on_cpu,
params.control_net_cpu,
params.vae_on_cpu,
params.diffusion_flash_attn);
params.diffusion_flash_attn,
params.chroma_use_dit_mask,
params.chroma_use_t5_mask,
params.chroma_t5_mask_pad);
if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n");

flux.hpp (335 lines changed)
View File

@ -117,6 +117,7 @@ namespace Flux {
struct ggml_tensor* k,
struct ggml_tensor* v,
struct ggml_tensor* pe,
struct ggml_tensor* mask,
bool flash_attn) {
// q,k,v: [N, L, n_head, d_head]
// pe: [L, d_head/2, 2, 2]
@ -124,7 +125,7 @@ namespace Flux {
q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head]
k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head]
auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true, flash_attn); // [N, L, n_head*d_head]
auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], mask, false, true, flash_attn); // [N, L, n_head*d_head]
return x;
}
@ -167,13 +168,13 @@ namespace Flux {
return x;
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe) {
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask) {
// x: [N, n_token, dim]
// pe: [n_token, d_head/2, 2, 2]
// return [N, n_token, dim]
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, flash_attn); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head]
x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
return x;
}
};
@ -185,6 +186,13 @@ namespace Flux {
ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL)
: shift(shift), scale(scale), gate(gate) {}
ModulationOut(struct ggml_context* ctx, ggml_tensor* vec, int64_t offset) {
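// Note (added): vec is assumed to be laid out as [n_mods, N, dim] (ne[2] = number of
// modulation entries); stride below is the byte size of one [N, dim] slice, so
// offset + 0/1/2 pick the consecutive slices holding this shift/scale/gate triplet.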
int64_t stride = vec->nb[1] * vec->ne[1];
shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim]
scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim]
gate = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 2)); // [N, dim]
}
};
struct Modulation : public GGMLBlock {
@ -210,19 +218,12 @@ namespace Flux {
auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim]
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim]
int64_t offset = m->nb[1] * m->ne[1];
auto shift_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, dim]
auto scale_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, dim]
auto gate_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, dim]
ModulationOut m_0 = ModulationOut(ctx, m, 0);
if (is_double) {
auto shift_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, dim]
auto scale_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, dim]
auto gate_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, dim]
return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut(shift_1, scale_1, gate_1)};
return {m_0, ModulationOut(ctx, m, 3)};
}
return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut()};
return {m_0, ModulationOut()};
}
};
@ -242,25 +243,33 @@ namespace Flux {
struct DoubleStreamBlock : public GGMLBlock {
bool flash_attn;
bool prune_mod;
int idx = 0;
public:
DoubleStreamBlock(int64_t hidden_size,
int64_t num_heads,
float mlp_ratio,
int idx = 0,
bool qkv_bias = false,
bool flash_attn = false)
: flash_attn(flash_attn) {
bool flash_attn = false,
bool prune_mod = false)
: idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) {
int64_t mlp_hidden_dim = hidden_size * mlp_ratio;
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
if (!prune_mod) {
blocks["img_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["img_mlp.0"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, mlp_hidden_dim));
// img_mlp.1 is nn.GELU(approximate="tanh")
blocks["img_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size));
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
if (!prune_mod) {
blocks["txt_mod"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, true));
}
blocks["txt_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
blocks["txt_attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn));
@ -270,17 +279,34 @@ namespace Flux {
blocks["txt_mlp.2"] = std::shared_ptr<GGMLBlock>(new Linear(mlp_hidden_dim, hidden_size));
}
std::vector<ModulationOut> get_distil_img_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
// TODO: not hardcoded?
const int single_blocks_count = 38;
const int double_blocks_count = 19;
int64_t offset = 6 * idx + 3 * single_blocks_count;
return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)};
}
std::vector<ModulationOut> get_distil_txt_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
// TODO: not hardcoded?
const int single_blocks_count = 38;
const int double_blocks_count = 19;
int64_t offset = 6 * idx + 6 * double_blocks_count + 3 * single_blocks_count;
return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)};
}
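// Note (added): with the counts above, the pruned modulation vector appears to hold
//   3 * 38 (single blocks: one shift/scale/gate each)    = 114 entries
// + 6 * 19 (double blocks, img side: two triplets each)  = 114 entries
// + 6 * 19 (double blocks, txt side: two triplets each)  = 114 entries
// + 2      (final layer: shift + scale only)             =   2 entries
// = 344 total, matching mod_index_length = 344 in forward_orig() below.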
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
struct ggml_tensor* img,
struct ggml_tensor* txt,
struct ggml_tensor* vec,
struct ggml_tensor* pe) {
struct ggml_tensor* pe,
struct ggml_tensor* mask = NULL) {
// img: [N, n_img_token, hidden_size]
// txt: [N, n_txt_token, hidden_size]
// pe: [n_img_token + n_txt_token, d_head/2, 2, 2]
// return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size])
auto img_mod = std::dynamic_pointer_cast<Modulation>(blocks["img_mod"]);
auto img_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["img_norm1"]);
auto img_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["img_attn"]);
@ -288,7 +314,6 @@ namespace Flux {
auto img_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.0"]);
auto img_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["img_mlp.2"]);
auto txt_mod = std::dynamic_pointer_cast<Modulation>(blocks["txt_mod"]);
auto txt_norm1 = std::dynamic_pointer_cast<LayerNorm>(blocks["txt_norm1"]);
auto txt_attn = std::dynamic_pointer_cast<SelfAttention>(blocks["txt_attn"]);
@ -296,10 +321,22 @@ namespace Flux {
auto txt_mlp_0 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.0"]);
auto txt_mlp_2 = std::dynamic_pointer_cast<Linear>(blocks["txt_mlp.2"]);
auto img_mods = img_mod->forward(ctx, vec);
std::vector<ModulationOut> img_mods;
if (prune_mod) {
img_mods = get_distil_img_mod(ctx, vec);
} else {
auto img_mod = std::dynamic_pointer_cast<Modulation>(blocks["img_mod"]);
img_mods = img_mod->forward(ctx, vec);
}
ModulationOut img_mod1 = img_mods[0];
ModulationOut img_mod2 = img_mods[1];
auto txt_mods = txt_mod->forward(ctx, vec);
std::vector<ModulationOut> txt_mods;
if (prune_mod) {
txt_mods = get_distil_txt_mod(ctx, vec);
} else {
auto txt_mod = std::dynamic_pointer_cast<Modulation>(blocks["txt_mod"]);
txt_mods = txt_mod->forward(ctx, vec);
}
ModulationOut txt_mod1 = txt_mods[0];
ModulationOut txt_mod2 = txt_mods[1];
@ -324,7 +361,7 @@ namespace Flux {
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx,
attn,
@ -373,14 +410,18 @@ namespace Flux {
int64_t hidden_size;
int64_t mlp_hidden_dim;
bool flash_attn;
bool prune_mod;
int idx = 0;
public:
SingleStreamBlock(int64_t hidden_size,
int64_t num_heads,
float mlp_ratio = 4.0f,
int idx = 0,
float qk_scale = 0.f,
bool flash_attn = false)
: hidden_size(hidden_size), num_heads(num_heads), flash_attn(flash_attn) {
bool flash_attn = false,
bool prune_mod = false)
: hidden_size(hidden_size), num_heads(num_heads), idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) {
int64_t head_dim = hidden_size / num_heads;
float scale = qk_scale;
if (scale <= 0.f) {
@ -393,26 +434,37 @@ namespace Flux {
blocks["norm"] = std::shared_ptr<GGMLBlock>(new QKNorm(head_dim));
blocks["pre_norm"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-6f, false));
// mlp_act is nn.GELU(approximate="tanh")
blocks["modulation"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, false));
if (!prune_mod) {
blocks["modulation"] = std::shared_ptr<GGMLBlock>(new Modulation(hidden_size, false));
}
}
ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
int64_t offset = 3 * idx;
return ModulationOut(ctx, vec, offset);
}
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* vec,
struct ggml_tensor* pe) {
struct ggml_tensor* pe,
struct ggml_tensor* mask = NULL) {
// x: [N, n_token, hidden_size]
// pe: [n_token, d_head/2, 2, 2]
// return: [N, n_token, hidden_size]
auto linear1 = std::dynamic_pointer_cast<Linear>(blocks["linear1"]);
auto linear2 = std::dynamic_pointer_cast<Linear>(blocks["linear2"]);
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
auto pre_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_norm"]);
auto modulation = std::dynamic_pointer_cast<Modulation>(blocks["modulation"]);
auto mods = modulation->forward(ctx, vec);
ModulationOut mod = mods[0];
auto linear1 = std::dynamic_pointer_cast<Linear>(blocks["linear1"]);
auto linear2 = std::dynamic_pointer_cast<Linear>(blocks["linear2"]);
auto norm = std::dynamic_pointer_cast<QKNorm>(blocks["norm"]);
auto pre_norm = std::dynamic_pointer_cast<LayerNorm>(blocks["pre_norm"]);
ModulationOut mod;
if (prune_mod) {
mod = get_distil_mod(ctx, vec);
} else {
auto modulation = std::dynamic_pointer_cast<Modulation>(blocks["modulation"]);
mod = modulation->forward(ctx, vec)[0];
}
auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale);
auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim]
qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token]
@ -443,7 +495,7 @@ namespace Flux {
auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head]
q = norm->query_norm(ctx, q);
k = norm->key_norm(ctx, k);
auto attn = attention(ctx, q, k, v, pe, flash_attn); // [N, n_token, hidden_size]
auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size]
auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim]
auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size]
@ -454,13 +506,27 @@ namespace Flux {
};
struct LastLayer : public GGMLBlock {
bool prune_mod;
public:
LastLayer(int64_t hidden_size,
int64_t patch_size,
int64_t out_channels) {
blocks["norm_final"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels));
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
int64_t out_channels,
bool prune_mod = false) : prune_mod(prune_mod) {
blocks["norm_final"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["linear"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, patch_size * patch_size * out_channels));
if (!prune_mod) {
blocks["adaLN_modulation.1"] = std::shared_ptr<GGMLBlock>(new Linear(hidden_size, 2 * hidden_size));
}
}
ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
int64_t offset = vec->ne[2] - 2;
int64_t stride = vec->nb[1] * vec->ne[1];
auto shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim]
auto scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim]
// No gate
return ModulationOut(shift, scale, NULL);
}
struct ggml_tensor* forward(struct ggml_context* ctx,
@ -469,17 +535,24 @@ namespace Flux {
// x: [N, n_token, hidden_size]
// c: [N, hidden_size]
// return: [N, n_token, patch_size * patch_size * out_channels]
auto norm_final = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_final"]);
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
auto norm_final = std::dynamic_pointer_cast<LayerNorm>(blocks["norm_final"]);
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
struct ggml_tensor *shift, *scale;
if (prune_mod) {
auto mod = get_distil_mod(ctx, c);
shift = mod.shift;
scale = mod.scale;
} else {
auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size]
m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size]
m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size]
m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
int64_t offset = m->nb[1] * m->ne[1];
shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size]
scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size]
}
x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale);
x = linear->forward(ctx, x);
@ -488,6 +561,34 @@ namespace Flux {
}
};
struct ChromaApproximator : public GGMLBlock {
int64_t inner_size = 5120;
int64_t n_layers = 5;
ChromaApproximator(int64_t in_channels = 64, int64_t hidden_size = 3072) {
blocks["in_proj"] = std::shared_ptr<GGMLBlock>(new Linear(in_channels, inner_size, true));
for (int i = 0; i < n_layers; i++) {
blocks["norms." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new RMSNorm(inner_size));
blocks["layers." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(inner_size, inner_size));
}
blocks["out_proj"] = std::shared_ptr<GGMLBlock>(new Linear(inner_size, hidden_size, true));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
auto in_proj = std::dynamic_pointer_cast<Linear>(blocks["in_proj"]);
auto out_proj = std::dynamic_pointer_cast<Linear>(blocks["out_proj"]);
x = in_proj->forward(ctx, x);
for (int i = 0; i < n_layers; i++) {
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norms." + std::to_string(i)]);
auto embed = std::dynamic_pointer_cast<MLPEmbedder>(blocks["layers." + std::to_string(i)]);
x = ggml_add_inplace(ctx, x, embed->forward(ctx, norm->forward(ctx, x)));
}
x = out_proj->forward(ctx, x);
return x;
}
};
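// Note (added): for Chroma this approximator replaces the time_in/vector_in/guidance_in
// MLPEmbedders: a Linear in-projection from in_channels (64) up to 5120, five residual
// RMSNorm + MLPEmbedder blocks, then a Linear out-projection down to hidden_size
// (see the constructor changes further down).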
struct FluxParams {
int64_t in_channels = 64;
int64_t out_channels = 64;
@ -504,6 +605,7 @@ namespace Flux {
bool qkv_bias = true;
bool guidance_embed = true;
bool flash_attn = true;
bool is_chroma = false;
};
struct Flux : public GGMLBlock {
@ -642,6 +744,7 @@ namespace Flux {
return ids;
}
// Generate positional embeddings
std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, std::vector<ggml_tensor*> ref_latents, int theta, const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_ids(h, w, patch_size, bs, context_len, ref_latents);
@ -680,11 +783,15 @@ namespace Flux {
: params(params) {
int64_t pe_dim = params.hidden_size / params.num_heads;
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
blocks["vector_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(params.vec_in_dim, params.hidden_size));
if (params.guidance_embed) {
blocks["guidance_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
if (params.is_chroma) {
blocks["distilled_guidance_layer"] = std::shared_ptr<GGMLBlock>(new ChromaApproximator(params.in_channels, params.hidden_size));
} else {
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
blocks["vector_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(params.vec_in_dim, params.hidden_size));
if (params.guidance_embed) {
blocks["guidance_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
}
}
blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.context_in_dim, params.hidden_size, true));
@ -692,19 +799,23 @@ namespace Flux {
blocks["double_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new DoubleStreamBlock(params.hidden_size,
params.num_heads,
params.mlp_ratio,
i,
params.qkv_bias,
params.flash_attn));
params.flash_attn,
params.is_chroma));
}
for (int i = 0; i < params.depth_single_blocks; i++) {
blocks["single_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new SingleStreamBlock(params.hidden_size,
params.num_heads,
params.mlp_ratio,
i,
0.f,
params.flash_attn));
params.flash_attn,
params.is_chroma));
}
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels));
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels, params.is_chroma));
}
struct ggml_tensor* patchify(struct ggml_context* ctx,
@ -761,25 +872,55 @@ namespace Flux {
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor* pe,
struct ggml_tensor* mod_index_arange = NULL,
std::vector<int> skip_layers = {}) {
auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
auto txt_in = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]);
img = img_in->forward(ctx, img);
auto vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f));
img = img_in->forward(ctx, img);
struct ggml_tensor* vec;
struct ggml_tensor* txt_img_mask = NULL;
if (params.is_chroma) {
int64_t mod_index_length = 344;
auto approx = std::dynamic_pointer_cast<ChromaApproximator>(blocks["distilled_guidance_layer"]);
auto distill_timestep = ggml_nn_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f);
auto distill_guidance = ggml_nn_timestep_embedding(ctx, guidance, 16, 10000, 1000.f);
if (params.guidance_embed) {
GGML_ASSERT(guidance != NULL);
auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
// bf16 and fp16 result is different
auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in));
// auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1);
// ggml_arange not working on a lot of backends, precomputing it on CPU instead
GGML_ASSERT(mod_index_arange != NULL);
auto modulation_index = ggml_nn_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32]
// Batch broadcast (will it ever be useful)
modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32]
auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32]
timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32]
vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64]
// Permute for consistency with non-distilled modulation implementation
vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, 64]
vec = approx->forward(ctx, vec); // [344, N, hidden_size]
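// Note (added): each of the 344 modulation entries is fed a 64-value input:
// 16 timestep-embedding dims + 16 guidance-embedding dims + 32 index-embedding dims,
// which matches ChromaApproximator's default in_channels = 64.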
if (y != NULL) {
txt_img_mask = ggml_pad(ctx, y, img->ne[1], 0, 0, 0);
}
} else {
auto time_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["time_in"]);
auto vector_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["vector_in"]);
vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f));
if (params.guidance_embed) {
GGML_ASSERT(guidance != NULL);
auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
// bf16 and fp16 result is different
auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in));
}
vec = ggml_add(ctx, vec, vector_in->forward(ctx, y));
}
vec = ggml_add(ctx, vec, vector_in->forward(ctx, y));
txt = txt_in->forward(ctx, txt);
for (int i = 0; i < params.depth; i++) {
@ -789,7 +930,7 @@ namespace Flux {
auto block = std::dynamic_pointer_cast<DoubleStreamBlock>(blocks["double_blocks." + std::to_string(i)]);
auto img_txt = block->forward(ctx, img, txt, vec, pe);
auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask);
img = img_txt.first; // [N, n_img_token, hidden_size]
txt = img_txt.second; // [N, n_txt_token, hidden_size]
}
@ -801,7 +942,7 @@ namespace Flux {
}
auto block = std::dynamic_pointer_cast<SingleStreamBlock>(blocks["single_blocks." + std::to_string(i)]);
txt_img = block->forward(ctx, txt_img, vec, pe);
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask);
}
txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
@ -816,7 +957,6 @@ namespace Flux {
img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
return img;
}
@ -843,6 +983,7 @@ namespace Flux {
struct ggml_tensor* y,
struct ggml_tensor* guidance,
struct ggml_tensor* pe,
struct ggml_tensor* mod_index_arange = NULL,
std::vector<ggml_tensor*> ref_latents = {},
std::vector<int> skip_layers = {}) {
// Forward pass of DiT.
@ -884,7 +1025,7 @@ namespace Flux {
}
}
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
if (out->ne[1] > img_tokens) {
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
@ -904,14 +1045,18 @@ namespace Flux {
public:
FluxParams flux_params;
Flux flux;
std::vector<float> pe_vec; // for cache
std::vector<float> pe_vec;
std::vector<float> mod_index_arange_vec; // for cache
SDVersion version;
bool use_mask = false;
FluxRunner(ggml_backend_t backend,
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
const std::string prefix = "",
SDVersion version = VERSION_FLUX,
bool flash_attn = false)
: GGMLRunner(backend) {
bool flash_attn = false,
bool use_mask = false)
: GGMLRunner(backend), use_mask(use_mask) {
flux_params.flash_attn = flash_attn;
flux_params.guidance_embed = false;
flux_params.depth = 0;
@ -927,6 +1072,10 @@ namespace Flux {
// not schnell
flux_params.guidance_embed = true;
}
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
// Chroma
flux_params.is_chroma = true;
}
size_t db = tensor_name.find("double_blocks.");
if (db != std::string::npos) {
tensor_name = tensor_name.substr(db); // remove prefix
@ -946,7 +1095,9 @@ namespace Flux {
}
LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks);
if (!flux_params.guidance_embed) {
if (flux_params.is_chroma) {
LOG_INFO("Using pruned modulation (Chroma)");
} else if (!flux_params.guidance_embed) {
LOG_INFO("Flux guidance is disabled (Schnell mode)");
}
@ -969,18 +1120,33 @@ namespace Flux {
struct ggml_tensor* y,
struct ggml_tensor* guidance,
std::vector<ggml_tensor*> ref_latents = {},
std::vector<int> skip_layers = std::vector<int>()) {
std::vector<int> skip_layers = {}) {
GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
struct ggml_tensor* mod_index_arange = NULL;
x = to_backend(x);
context = to_backend(context);
if (c_concat != NULL) {
c_concat = to_backend(c_concat);
}
y = to_backend(y);
if (flux_params.is_chroma) {
guidance = ggml_set_f32(guidance, 0);
if (!use_mask) {
y = NULL;
}
// ggml_arange is not working on some backends, precompute it
mod_index_arange_vec = arange(0, 344);
mod_index_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, mod_index_arange_vec.size());
set_backend_tensor_data(mod_index_arange, mod_index_arange_vec.data());
}
y = to_backend(y);
timesteps = to_backend(timesteps);
if (flux_params.guidance_embed) {
if (flux_params.guidance_embed || flux_params.is_chroma) {
guidance = to_backend(guidance);
}
for (int i = 0; i < ref_latents.size(); i++) {
@ -1004,6 +1170,7 @@ namespace Flux {
y,
guidance,
pe,
mod_index_arange,
ref_latents,
skip_layers);

View File

@ -864,6 +864,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
v = ggml_cast(ctx, v, GGML_TYPE_F16);
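// Note (added, assumption about the ggml API): ggml_flash_attn_ext appears to expect the
// KQ mask in F16 with its second dimension padded to a multiple of GGML_KQ_MASK_PAD,
// hence the transpose, padding and cast in the block below.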
if (mask != nullptr) {
mask = ggml_transpose(ctx, mask);
if (mask->ne[1] < GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)) {
LOG_DEBUG("mask dims %ld, %ld, %ld, %ld\n", mask->ne[0], mask->ne[1], mask->ne[2], mask->ne[3]);
LOG_DEBUG("needs padding, padding from %ld to %ld\n", mask->ne[1], GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD));
mask = ggml_pad(ctx, mask, 0, GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) - mask->ne[1], 0, 0);
}
mask = ggml_cast(ctx, mask, GGML_TYPE_F16);
}
kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
@ -876,7 +888,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k]
kq = ggml_scale_inplace(ctx, kq, scale);
if (mask) {
kq = ggml_add(ctx, kq, mask);
kq = ggml_add_inplace(ctx, kq, mask);
}
if (diag_mask_inf) {
kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);

View File

@ -159,7 +159,10 @@ public:
bool clip_on_cpu,
bool control_net_cpu,
bool vae_on_cpu,
bool diffusion_flash_attn) {
bool diffusion_flash_attn,
bool chroma_use_dit_mask,
bool chroma_use_t5_mask,
int chroma_t5_mask_pad) {
use_tiny_autoencoder = taesd_path.size() > 0;
#ifdef SD_USE_CUDA
LOG_DEBUG("Using CUDA backend");
@ -334,8 +337,19 @@ public:
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
diffusion_model = std::make_shared<MMDiTModel>(backend, model_loader.tensor_storages_types);
} else if (sd_version_is_flux(version)) {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
bool is_chroma = false;
for (auto pair : model_loader.tensor_storages_types) {
if (pair.first.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
is_chroma = true;
break;
}
}
if (is_chroma) {
cond_stage_model = std::make_shared<PixArtCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types, -1, chroma_use_t5_mask, chroma_t5_mask_pad);
} else {
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
}
diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn, chroma_use_dit_mask);
} else {
if (id_embeddings_path.find("v2") != std::string::npos) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2);
@ -1135,7 +1149,10 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu,
bool diffusion_flash_attn) {
bool diffusion_flash_attn,
bool chroma_use_dit_mask,
bool chroma_use_t5_mask,
int chroma_t5_mask_pad) {
sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t));
if (sd_ctx == NULL) {
return NULL;
@ -1177,7 +1194,10 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str,
keep_clip_on_cpu,
keep_control_net_cpu,
keep_vae_on_cpu,
diffusion_flash_attn)) {
diffusion_flash_attn,
chroma_use_dit_mask,
chroma_use_t5_mask,
chroma_t5_mask_pad)) {
delete sd_ctx->sd;
sd_ctx->sd = NULL;
free(sd_ctx);

View File

@ -150,7 +150,10 @@ SD_API sd_ctx_t* new_sd_ctx(const char* model_path,
bool keep_clip_on_cpu,
bool keep_control_net_cpu,
bool keep_vae_on_cpu,
bool diffusion_flash_attn);
bool diffusion_flash_attn,
bool chroma_use_dit_mask,
bool chroma_use_t5_mask,
int chroma_t5_mask_pad);
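// Note (added): new_sd_ctx() gains three trailing parameters; passing true, false, 1
// reproduces the CLI defaults (DiT mask on, T5 mask off, one extra unmasked padding token).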
SD_API void free_sd_ctx(sd_ctx_t* sd_ctx);

t5.hpp (50 lines changed)
View File

@ -385,6 +385,7 @@ public:
void pad_tokens(std::vector<int>& tokens,
std::vector<float>& weights,
std::vector<float>* attention_mask,
size_t max_length = 0,
bool padding = false) {
if (max_length > 0 && padding) {
@ -397,11 +398,15 @@ public:
LOG_DEBUG("token length: %llu", length);
std::vector<int> new_tokens;
std::vector<float> new_weights;
std::vector<float> new_attention_mask;
int token_idx = 0;
for (int i = 0; i < length; i++) {
if (token_idx >= orig_token_num) {
break;
}
if (attention_mask != nullptr) {
new_attention_mask.push_back(0.0);
}
if (i % max_length == max_length - 1) {
new_tokens.push_back(eos_id_);
new_weights.push_back(1.0);
@ -414,13 +419,24 @@ public:
new_tokens.push_back(eos_id_);
new_weights.push_back(1.0);
if (attention_mask != nullptr) {
new_attention_mask.push_back(0.0);
}
tokens = new_tokens;
weights = new_weights;
if (attention_mask != nullptr) {
*attention_mask = new_attention_mask;
}
if (padding) {
int pad_token_id = pad_id_;
tokens.insert(tokens.end(), length - tokens.size(), pad_token_id);
weights.insert(weights.end(), length - weights.size(), 1.0);
if (attention_mask != nullptr) {
// maybe keep some padding tokens unmasked?
attention_mask->insert(attention_mask->end(), length - attention_mask->size(), -HUGE_VALF);
}
}
}
}
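// Note (added): worked example of the new attention_mask output, assuming a 3-token
// prompt, max_length = 8 and padding = true:
//   tokens  -> [t0, t1, t2, EOS, PAD, PAD, PAD, PAD]
//   weights -> [w0, w1, w2, 1,   1,   1,   1,   1  ]
//   mask    -> [0,  0,  0,  0,  -inf, -inf, -inf, -inf]
// (real tokens and the EOS get 0.0f; padded positions get -HUGE_VALF).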
@ -579,6 +595,7 @@ public:
}
if (past_bias != NULL) {
if (mask != NULL) {
mask = ggml_repeat(ctx, mask, past_bias);
mask = ggml_add(ctx, mask, past_bias);
} else {
mask = past_bias;
@ -739,15 +756,17 @@ struct T5Runner : public GGMLRunner {
struct ggml_tensor* forward(struct ggml_context* ctx,
struct ggml_tensor* input_ids,
struct ggml_tensor* relative_position_bucket) {
struct ggml_tensor* relative_position_bucket,
struct ggml_tensor* attention_mask = NULL) {
size_t N = input_ids->ne[1];
size_t n_token = input_ids->ne[0];
auto hidden_states = model.forward(ctx, input_ids, NULL, NULL, relative_position_bucket); // [N, n_token, model_dim]
auto hidden_states = model.forward(ctx, input_ids, NULL, attention_mask, relative_position_bucket); // [N, n_token, model_dim]
return hidden_states;
}
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) {
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
struct ggml_tensor* attention_mask = NULL) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
input_ids = to_backend(input_ids);
@ -767,7 +786,7 @@ struct T5Runner : public GGMLRunner {
input_ids->ne[0]);
set_backend_tensor_data(relative_position_bucket, relative_position_bucket_vec.data());
struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, relative_position_bucket);
struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, relative_position_bucket, attention_mask);
ggml_build_forward_expand(gf, hidden_states);
@ -776,10 +795,11 @@ struct T5Runner : public GGMLRunner {
void compute(const int n_threads,
struct ggml_tensor* input_ids,
struct ggml_tensor* attention_mask,
ggml_tensor** output,
ggml_context* output_ctx = NULL) {
ggml_context* output_ctx = NULL) {
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(input_ids);
return build_graph(input_ids, attention_mask);
};
GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
}
@ -877,9 +897,9 @@ struct T5Embedder {
model.alloc_params_buffer();
}
std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
size_t max_length = 0,
bool padding = false) {
std::tuple<std::vector<int>, std::vector<float>, std::vector<float>> tokenize(std::string text,
size_t max_length = 0,
bool padding = false) {
auto parsed_attention = parse_prompt_attention(text);
{
@ -906,14 +926,16 @@ struct T5Embedder {
tokens.push_back(EOS_TOKEN_ID);
weights.push_back(1.0);
tokenizer.pad_tokens(tokens, weights, max_length, padding);
std::vector<float> attention_mask;
tokenizer.pad_tokens(tokens, weights, &attention_mask, max_length, padding);
// for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", ";
// }
// std::cout << std::endl;
return {tokens, weights};
return {tokens, weights, attention_mask};
}
void test() {
@ -934,8 +956,8 @@ struct T5Embedder {
// TODO: fix cuda nan
std::string text("a lovely cat");
auto tokens_and_weights = tokenize(text, 77, true);
std::vector<int>& tokens = tokens_and_weights.first;
std::vector<float>& weights = tokens_and_weights.second;
std::vector<int>& tokens = std::get<0>(tokens_and_weights);
std::vector<float>& weights = std::get<1>(tokens_and_weights);
for (auto token : tokens) {
printf("%d ", token);
}
@ -944,7 +966,7 @@ struct T5Embedder {
struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms();
model.compute(8, input_ids, &out, work_ctx);
model.compute(8, input_ids, NULL, &out, work_ctx);
int t1 = ggml_time_ms();
print_ggml_tensor(out);