feat: add extra_c_crossattns support for llm embedder (#1265)

This commit is contained in:
leejet 2026-02-10 00:00:17 +08:00 committed by GitHub
parent d60fb27560
commit 3296545090
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -10,9 +10,14 @@ struct SDCondition {
struct ggml_tensor* c_vector = nullptr; // aka y
struct ggml_tensor* c_concat = nullptr;
std::vector<struct ggml_tensor*> extra_c_crossattns;
SDCondition() = default;
SDCondition(struct ggml_tensor* c_crossattn, struct ggml_tensor* c_vector, struct ggml_tensor* c_concat)
: c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {}
SDCondition(struct ggml_tensor* c_crossattn,
struct ggml_tensor* c_vector,
struct ggml_tensor* c_concat,
const std::vector<struct ggml_tensor*>& extra_c_crossattns = {})
: c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat), extra_c_crossattns(extra_c_crossattns) {}
};
struct ConditionerParams {
@ -1696,10 +1701,11 @@ struct LLMEmbedder : public Conditioner {
}
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
std::pair<int, int> attn_range,
const std::pair<int, int>& attn_range,
size_t max_length = 0,
bool padding = false) {
std::vector<std::pair<std::string, float>> parsed_attention;
if (attn_range.first >= 0 && attn_range.second > 0) {
parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
if (attn_range.second - attn_range.first > 0) {
auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
@ -1708,6 +1714,10 @@ struct LLMEmbedder : public Conditioner {
new_parsed_attention.end());
}
parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
} else {
parsed_attention.emplace_back(text, 1.f);
}
{
std::stringstream ss;
ss << "[";
@ -1738,19 +1748,110 @@ struct LLMEmbedder : public Conditioner {
return {tokens, weights};
}
ggml_tensor* encode_prompt(ggml_context* work_ctx,
int n_threads,
const std::string prompt,
const std::pair<int, int>& prompt_attn_range,
int max_length,
int min_length,
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
const std::set<int>& out_layers,
int prompt_template_encode_start_idx) {
auto tokens_and_weights = tokenize(prompt, prompt_attn_range);
auto& tokens = std::get<0>(tokens_and_weights);
auto& weights = std::get<1>(tokens_and_weights);
std::vector<float> mask;
if (max_length > 0 && tokens.size() < max_length) {
mask.insert(mask.end(), tokens.size(), 1.f);
mask.insert(mask.end(), max_length - tokens.size(), 0.f);
tokenizer->pad_tokens(tokens, weights, max_length, true);
}
struct ggml_tensor* hidden_states = nullptr; // [N, n_token, hidden_size]
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
ggml_tensor* attention_mask = nullptr;
if (!mask.empty()) {
attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size());
ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = 0.f;
if (mask[i0] == 0.f) {
value = -INFINITY;
} else if (i0 > i1) {
value = -INFINITY;
}
ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3);
});
}
llm->compute(n_threads,
input_ids,
attention_mask,
image_embeds,
out_layers,
&hidden_states,
work_ctx);
{
auto tensor = hidden_states;
float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= weights[i1];
ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
}
}
}
float new_mean = ggml_ext_tensor_mean(tensor);
ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
}
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
int64_t zero_pad_len = 0;
if (min_length > 0) {
if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
}
}
ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
GGML_TYPE_F32,
hidden_states->ne[0],
hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
hidden_states->ne[2]);
ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = 0.f;
if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
}
ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
});
return new_hidden_states;
}
SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads,
const ConditionerParams& conditioner_params) override {
std::string prompt;
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
std::pair<int, int> prompt_attn_range;
std::vector<std::string> extra_prompts;
std::vector<std::pair<int, int>> extra_prompts_attn_range;
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
int prompt_template_encode_start_idx = 34;
int max_length = 0;
int max_length = 0; // pad tokens
int min_length = 0; // zero pad hidden_states
std::set<int> out_layers;
std::vector<int> tokens;
std::vector<float> weights;
std::vector<float> mask;
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
int64_t t0 = ggml_time_ms();
if (sd_version_is_qwen_image(version)) {
if (llm->enable_vision && !conditioner_params.ref_images.empty()) {
LOG_INFO("QwenImageEditPlusPipeline");
prompt_template_encode_start_idx = 64;
int image_embed_idx = 64 + 6;
@ -1813,8 +1914,20 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n";
} else {
prompt_template_encode_start_idx = 34;
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n";
}
} else if (version == VERSION_FLUX2) {
prompt_template_encode_start_idx = 0;
min_length = 512;
out_layers = {10, 20, 30};
prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
@ -1828,6 +1941,15 @@ struct LLMEmbedder : public Conditioner {
prompt_template_encode_start_idx = 0;
out_layers = {35}; // -2
if (!conditioner_params.ref_images.empty()) {
LOG_INFO("ZImageOmniPipeline");
prompt = "<|im_start|>user\n<|vision_start|>";
for (int i = 0; i < conditioner_params.ref_images.size() - 1; i++) {
extra_prompts.push_back("<|vision_end|><|vision_start|>");
}
extra_prompts.push_back("<|vision_end|>" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>");
extra_prompts.push_back("<|vision_end|><|im_end|>");
} else {
prompt = "<|im_start|>user\n";
prompt_attn_range.first = static_cast<int>(prompt.size());
@ -1835,6 +1957,7 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n";
}
} else if (version == VERSION_FLUX2_KLEIN) {
prompt_template_encode_start_idx = 0;
max_length = 512;
@ -1847,16 +1970,6 @@ struct LLMEmbedder : public Conditioner {
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
tokens = std::get<0>(tokens_and_weights);
weights = std::get<1>(tokens_and_weights);
mask.insert(mask.end(), tokens.size(), 1.f);
if (tokens.size() < max_length) {
mask.insert(mask.end(), max_length - tokens.size(), 0.f);
tokenizer->pad_tokens(tokens, weights, max_length, true);
}
} else if (version == VERSION_OVIS_IMAGE) {
prompt_template_encode_start_idx = 28;
max_length = prompt_template_encode_start_idx + 256;
@ -1869,98 +1982,36 @@ struct LLMEmbedder : public Conditioner {
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
} else {
prompt_template_encode_start_idx = 34;
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
prompt_attn_range.first = static_cast<int>(prompt.size());
prompt += conditioner_params.text;
prompt_attn_range.second = static_cast<int>(prompt.size());
prompt += "<|im_end|>\n<|im_start|>assistant\n";
GGML_ABORT("unknown version %d", version);
}
if (tokens.empty()) {
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
tokens = std::get<0>(tokens_and_weights);
weights = std::get<1>(tokens_and_weights);
}
int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = nullptr; // [N, n_token, 3584]
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
ggml_tensor* attention_mask = nullptr;
if (!mask.empty()) {
attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size());
ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = 0.f;
if (mask[i0] == 0.f) {
value = -INFINITY;
} else if (i0 > i1) {
value = -INFINITY;
}
ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3);
});
}
llm->compute(n_threads,
input_ids,
attention_mask,
auto hidden_states = encode_prompt(work_ctx,
n_threads,
prompt,
prompt_attn_range,
max_length,
min_length,
image_embeds,
out_layers,
&hidden_states,
work_ctx);
{
auto tensor = hidden_states;
float original_mean = ggml_ext_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
value *= weights[i1];
ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
}
}
}
float new_mean = ggml_ext_tensor_mean(tensor);
ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
}
prompt_template_encode_start_idx);
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
int64_t min_length = 0;
if (version == VERSION_FLUX2) {
min_length = 512;
std::vector<ggml_tensor*> extra_hidden_states_vec;
for (int i = 0; i < extra_prompts.size(); i++) {
auto extra_hidden_states = encode_prompt(work_ctx,
n_threads,
extra_prompts[i],
extra_prompts_attn_range[i],
max_length,
min_length,
image_embeds,
out_layers,
prompt_template_encode_start_idx);
extra_hidden_states_vec.push_back(extra_hidden_states);
}
int64_t zero_pad_len = 0;
if (min_length > 0) {
if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
}
}
ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
GGML_TYPE_F32,
hidden_states->ne[0],
hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
hidden_states->ne[2]);
ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = 0.f;
if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
}
ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
});
// print_ggml_tensor(new_hidden_states);
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
return {new_hidden_states, nullptr, nullptr};
return {hidden_states, nullptr, nullptr, extra_hidden_states_vec};
}
};