add support for extra contexts

leejet 2025-12-25 22:35:28 +08:00
parent 8004d32de2
commit 190c523cec
4 changed files with 229 additions and 163 deletions

View File

@@ -10,9 +10,14 @@ struct SDCondition {
     struct ggml_tensor* c_vector = nullptr;  // aka y
     struct ggml_tensor* c_concat = nullptr;
+    std::vector<struct ggml_tensor*> extra_c_crossattns;
     SDCondition() = default;
-    SDCondition(struct ggml_tensor* c_crossattn, struct ggml_tensor* c_vector, struct ggml_tensor* c_concat)
-        : c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {}
+    SDCondition(struct ggml_tensor* c_crossattn,
+                struct ggml_tensor* c_vector,
+                struct ggml_tensor* c_concat,
+                const std::vector<struct ggml_tensor*>& extra_c_crossattns = {})
+        : c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat), extra_c_crossattns(extra_c_crossattns) {}
 };
 
 struct ConditionerParams {
@@ -1657,18 +1662,23 @@ struct LLMEmbedder : public Conditioner {
     }
 
     std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
-                                                              std::pair<int, int> attn_range,
+                                                              const std::pair<int, int>& attn_range,
                                                               size_t max_length = 0,
                                                               bool padding = false) {
         std::vector<std::pair<std::string, float>> parsed_attention;
-        parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
-        if (attn_range.second - attn_range.first > 0) {
-            auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
-            parsed_attention.insert(parsed_attention.end(),
-                                    new_parsed_attention.begin(),
-                                    new_parsed_attention.end());
+        if (attn_range.first >= 0 && attn_range.second > 0) {
+            parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
+            if (attn_range.second - attn_range.first > 0) {
+                auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
+                parsed_attention.insert(parsed_attention.end(),
+                                        new_parsed_attention.begin(),
+                                        new_parsed_attention.end());
+            }
+            parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
+        } else {
+            parsed_attention.emplace_back(text, 1.f);
         }
-        parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
 
         {
             std::stringstream ss;
             ss << "[";
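A minimal usage sketch of the reworked tokenize() contract (the `embedder` variable and the example strings are illustrative, not part of this diff): a valid attn_range applies parse_prompt_attention() only to that substring, while a sentinel range such as {-1, -1} now tokenizes the whole text with a flat weight of 1.

    // hypothetical call site; `embedder` stands in for the surrounding LLMEmbedder instance
    std::string prefix = "<|im_start|>user\n";
    std::string user   = "a (red:1.2) car";
    std::string prompt = prefix + user + "<|im_end|>\n<|im_start|>assistant\n";
    // prompt-weight syntax is parsed only inside the [first, second) range covering the user text
    auto weighted = embedder->tokenize(prompt, {(int)prefix.size(), (int)(prefix.size() + user.size())});
    // sentinel range: no attention parsing, every token keeps weight 1.f
    auto flat = embedder->tokenize(prompt, {-1, -1});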
@@ -1699,140 +1709,20 @@ struct LLMEmbedder : public Conditioner {
         return {tokens, weights};
     }
 
-    SDCondition get_learned_condition(ggml_context* work_ctx,
-                                      int n_threads,
-                                      const ConditionerParams& conditioner_params) override {
-        std::string prompt;
-        std::vector<std::pair<int, ggml_tensor*>> image_embeds;
-        std::pair<int, int> prompt_attn_range;
-        int prompt_template_encode_start_idx = 34;
-        int max_length = 0;
-        std::set<int> out_layers;
-        if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
-            LOG_INFO("QwenImageEditPlusPipeline");
-            prompt_template_encode_start_idx = 64;
-            int image_embed_idx = 64 + 6;
-            int min_pixels = 384 * 384;
-            int max_pixels = 560 * 560;
-            std::string placeholder = "<|image_pad|>";
-            std::string img_prompt;
-            for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
-                sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
-                double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
-                int height = image.height;
-                int width = image.width;
-                int h_bar = static_cast<int>(std::round(height / factor)) * factor;
-                int w_bar = static_cast<int>(std::round(width / factor)) * factor;
-                if (static_cast<double>(h_bar) * w_bar > max_pixels) {
-                    double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
-                    h_bar = std::max(static_cast<int>(factor),
-                                     static_cast<int>(std::floor(height / beta / factor)) * static_cast<int>(factor));
-                    w_bar = std::max(static_cast<int>(factor),
-                                     static_cast<int>(std::floor(width / beta / factor)) * static_cast<int>(factor));
-                } else if (static_cast<double>(h_bar) * w_bar < min_pixels) {
-                    double beta = std::sqrt(static_cast<double>(min_pixels) / (height * width));
-                    h_bar = static_cast<int>(std::ceil(height * beta / factor)) * static_cast<int>(factor);
-                    w_bar = static_cast<int>(std::ceil(width * beta / factor)) * static_cast<int>(factor);
-                }
-                LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar);
-                sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar);
-                free(image.data);
-                image.data = nullptr;
-                ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
-                sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false);
-                free(resized_image.data);
-                resized_image.data = nullptr;
-                ggml_tensor* image_embed = nullptr;
-                llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
-                image_embeds.emplace_back(image_embed_idx, image_embed);
-                image_embed_idx += 1 + image_embed->ne[1] + 6;
-                img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>";  // [24669, 220, index, 25, 220, 151652]
-                int64_t num_image_tokens = image_embed->ne[1];
-                img_prompt.reserve(num_image_tokens * placeholder.size());
-                for (int j = 0; j < num_image_tokens; j++) {
-                    img_prompt += placeholder;
-                }
-                img_prompt += "<|vision_end|>";
-            }
-            prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
-            prompt += img_prompt;
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
-            prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (sd_version_is_flux2(version)) {
-            prompt_template_encode_start_idx = 0;
-            out_layers = {10, 20, 30};
-            prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
-            prompt += "[/INST]";
-        } else if (sd_version_is_z_image(version)) {
-            prompt_template_encode_start_idx = 0;
-            out_layers = {35};  // -2
-            prompt = "<|im_start|>user\n";
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
-            prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (version == VERSION_OVIS_IMAGE) {
-            prompt_template_encode_start_idx = 28;
-            max_length = prompt_template_encode_start_idx + 256;
-            prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += " " + conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
-            prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
-        } else {
-            prompt_template_encode_start_idx = 34;
-            prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
-            prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        }
+    ggml_tensor* encode_prompt(ggml_context* work_ctx,
+                               int n_threads,
+                               const std::string prompt,
+                               const std::pair<int, int>& prompt_attn_range,
+                               int max_length,
+                               int min_length,
+                               std::vector<std::pair<int, ggml_tensor*>> image_embeds,
+                               const std::set<int>& out_layers,
+                               int prompt_template_encode_start_idx) {
         auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
         auto& tokens  = std::get<0>(tokens_and_weights);
         auto& weights = std::get<1>(tokens_and_weights);
 
-        int64_t t0 = ggml_time_ms();
-        struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, 3584]
+        struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, hidden_size]
 
         auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
@@ -1860,11 +1750,6 @@ struct LLMEmbedder : public Conditioner {
         GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
 
-        int64_t min_length = 0;
-        if (sd_version_is_flux2(version)) {
-            min_length = 512;
-        }
-
         int64_t zero_pad_len = 0;
         if (min_length > 0) {
             if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
@@ -1886,11 +1771,186 @@ struct LLMEmbedder : public Conditioner {
                 ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
             });
 
-        // print_ggml_tensor(new_hidden_states);
+        return new_hidden_states;
+    }
+
+    SDCondition get_learned_condition(ggml_context* work_ctx,
+                                      int n_threads,
+                                      const ConditionerParams& conditioner_params) override {
+        std::string prompt;
+        std::pair<int, int> prompt_attn_range;
+        std::vector<std::string> extra_prompts;
+        std::vector<std::pair<int, int>> extra_prompts_attn_range;
+        std::vector<std::pair<int, ggml_tensor*>> image_embeds;
+        int prompt_template_encode_start_idx = 34;
+        int max_length = 0;
+        int min_length = 0;
+        std::set<int> out_layers;
+
+        int64_t t0 = ggml_time_ms();
+
+        if (sd_version_is_qwen_image(version)) {
+            if (llm->enable_vision && !conditioner_params.ref_images.empty()) {
+                LOG_INFO("QwenImageEditPlusPipeline");
+                prompt_template_encode_start_idx = 64;
+                int image_embed_idx = 64 + 6;
+                int min_pixels = 384 * 384;
+                int max_pixels = 560 * 560;
+                std::string placeholder = "<|image_pad|>";
+                std::string img_prompt;
+                for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
+                    sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
+                    double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
+                    int height = image.height;
+                    int width = image.width;
+                    int h_bar = static_cast<int>(std::round(height / factor)) * factor;
+                    int w_bar = static_cast<int>(std::round(width / factor)) * factor;
+                    if (static_cast<double>(h_bar) * w_bar > max_pixels) {
+                        double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
+                        h_bar = std::max(static_cast<int>(factor),
+                                         static_cast<int>(std::floor(height / beta / factor)) * static_cast<int>(factor));
+                        w_bar = std::max(static_cast<int>(factor),
+                                         static_cast<int>(std::floor(width / beta / factor)) * static_cast<int>(factor));
+                    } else if (static_cast<double>(h_bar) * w_bar < min_pixels) {
+                        double beta = std::sqrt(static_cast<double>(min_pixels) / (height * width));
+                        h_bar = static_cast<int>(std::ceil(height * beta / factor)) * static_cast<int>(factor);
+                        w_bar = static_cast<int>(std::ceil(width * beta / factor)) * static_cast<int>(factor);
+                    }
+                    LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar);
+                    sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar);
+                    free(image.data);
+                    image.data = nullptr;
+                    ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
+                    sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false);
+                    free(resized_image.data);
+                    resized_image.data = nullptr;
+                    ggml_tensor* image_embed = nullptr;
+                    llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
+                    image_embeds.emplace_back(image_embed_idx, image_embed);
+                    image_embed_idx += 1 + image_embed->ne[1] + 6;
+                    img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>";  // [24669, 220, index, 25, 220, 151652]
+                    int64_t num_image_tokens = image_embed->ne[1];
+                    img_prompt.reserve(num_image_tokens * placeholder.size());
+                    for (int j = 0; j < num_image_tokens; j++) {
+                        img_prompt += placeholder;
+                    }
+                    img_prompt += "<|vision_end|>";
+                }
+                prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
+                prompt += img_prompt;
+                prompt_attn_range.first = static_cast<int>(prompt.size());
+                prompt += conditioner_params.text;
+                prompt_attn_range.second = static_cast<int>(prompt.size());
+                prompt += "<|im_end|>\n<|im_start|>assistant\n";
+            } else {
+                prompt_template_encode_start_idx = 34;
+                prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
+                prompt_attn_range.first = static_cast<int>(prompt.size());
+                prompt += conditioner_params.text;
+                prompt_attn_range.second = static_cast<int>(prompt.size());
+                prompt += "<|im_end|>\n<|im_start|>assistant\n";
+            }
+        } else if (sd_version_is_flux2(version)) {
+            prompt_template_encode_start_idx = 0;
+            out_layers = {10, 20, 30};
+            prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
+            prompt_attn_range.first = static_cast<int>(prompt.size());
+            prompt += conditioner_params.text;
+            prompt_attn_range.second = static_cast<int>(prompt.size());
+            prompt += "[/INST]";
+            min_length = 512;
+        } else if (sd_version_is_z_image(version)) {
+            prompt_template_encode_start_idx = 0;
+            out_layers = {35};  // -2
+            if (!conditioner_params.ref_images.empty()) {
+                LOG_INFO("ZImageOmniPipeline");
+                prompt = "<|im_start|>user\n<|vision_start|>";
+                for (int i = 0; i < conditioner_params.ref_images.size() - 1; i++) {
+                    extra_prompts.push_back("<|vision_end|><|vision_start|>");
+                }
+                extra_prompts.push_back("<|vision_end|>" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>");
+                extra_prompts.push_back("<|vision_end|><|im_end|>");
+                extra_prompts_attn_range.resize(extra_prompts.size(), {-1, -1});  // sentinel: tokenize() skips attention parsing
+            } else {
+                prompt = "<|im_start|>user\n";
+                prompt_attn_range.first = static_cast<int>(prompt.size());
+                prompt += conditioner_params.text;
+                prompt_attn_range.second = static_cast<int>(prompt.size());
+                prompt += "<|im_end|>\n<|im_start|>assistant\n";
+            }
+        } else if (version == VERSION_OVIS_IMAGE) {
+            prompt_template_encode_start_idx = 28;
+            max_length = prompt_template_encode_start_idx + 256;
+            prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
+            prompt_attn_range.first = static_cast<int>(prompt.size());
+            prompt += " " + conditioner_params.text;
+            prompt_attn_range.second = static_cast<int>(prompt.size());
+            prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
+        } else {
+            GGML_ABORT("unknown version %d", version);
+        }
+
+        auto hidden_states = encode_prompt(work_ctx,
+                                           n_threads,
+                                           prompt,
+                                           prompt_attn_range,
+                                           max_length,
+                                           min_length,
+                                           image_embeds,
+                                           out_layers,
+                                           prompt_template_encode_start_idx);
+
+        std::vector<ggml_tensor*> extra_hidden_states_vec;
+        for (int i = 0; i < extra_prompts.size(); i++) {
+            auto extra_hidden_states = encode_prompt(work_ctx,
+                                                     n_threads,
+                                                     extra_prompts[i],
+                                                     extra_prompts_attn_range[i],
+                                                     max_length,
+                                                     min_length,
+                                                     image_embeds,
+                                                     out_layers,
+                                                     prompt_template_encode_start_idx);
+            extra_hidden_states_vec.push_back(extra_hidden_states);
+        }
 
         int64_t t1 = ggml_time_ms();
         LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
 
-        return {new_hidden_states, nullptr, nullptr};
+        return {hidden_states, nullptr, nullptr, extra_hidden_states_vec};
     }
 };
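For orientation, a worked trace of the Z-Image-Omni prompt segmentation introduced above, assuming two reference images (the index labels are illustrative, not part of the diff):

    // main prompt and extra_prompts produced by the ZImageOmniPipeline branch for 2 ref images:
    //   prompt           = "<|im_start|>user\n<|vision_start|>"
    //   extra_prompts[0] = "<|vision_end|><|vision_start|>"          (one separator per extra ref image)
    //   extra_prompts[1] = "<|vision_end|>" + text + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"
    //   extra_prompts[2] = "<|vision_end|><|im_end|>"
    // get_learned_condition() encodes each segment with encode_prompt(), yielding one main
    // context plus three extra contexts in SDCondition::extra_c_crossattns.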

View File

@@ -23,6 +23,7 @@ struct DiffusionParams {
     struct ggml_tensor* vace_context = nullptr;
     float vace_strength = 1.f;
     std::vector<int> skip_layers = {};
+    std::vector<struct ggml_tensor*> extra_contexts;  // for z-image-omni
 };
 
 struct DiffusionModel {
@@ -436,10 +437,12 @@ struct ZImageModel : public DiffusionModel {
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output = nullptr,
                  struct ggml_context* output_ctx = nullptr) override {
+        std::vector<ggml_tensor*> contexts = {diffusion_params.context};
+        contexts.insert(contexts.end(), diffusion_params.extra_contexts.begin(), diffusion_params.extra_contexts.end());
         return z_image.compute(n_threads,
                                diffusion_params.x,
                                diffusion_params.timesteps,
-                               {diffusion_params.context},
+                               contexts,
                                diffusion_params.ref_latents,
                                {},
                                output,

View File

@@ -1931,10 +1931,11 @@ public:
         struct ggml_tensor** active_output = &out_cond;
         if (start_merge_step == -1 || step <= start_merge_step) {
             // cond
             diffusion_params.context = cond.c_crossattn;
+            diffusion_params.extra_contexts = cond.extra_c_crossattns;
             diffusion_params.c_concat = cond.c_concat;
             diffusion_params.y = cond.c_vector;
             active_condition = &cond;
         } else {
             diffusion_params.context = id_cond.c_crossattn;
             diffusion_params.c_concat = cond.c_concat;
@@ -1965,12 +1966,13 @@ public:
                     LOG_ERROR("controlnet compute failed");
                 }
             }
             current_step_skipped = cache_step_is_skipped();
             diffusion_params.controls = controls;
             diffusion_params.context = uncond.c_crossattn;
+            diffusion_params.extra_contexts = uncond.extra_c_crossattns;
             diffusion_params.c_concat = uncond.c_concat;
             diffusion_params.y = uncond.c_vector;
             bool skip_uncond = cache_before_condition(&uncond, out_uncond);
             if (!skip_uncond) {
                 if (!work_diffusion_model->compute(n_threads,
                                                    diffusion_params,
@@ -1985,10 +1987,11 @@ public:
             float* img_cond_data = nullptr;
             if (has_img_cond) {
                 diffusion_params.context = img_cond.c_crossattn;
+                diffusion_params.extra_contexts = img_cond.extra_c_crossattns;
                 diffusion_params.c_concat = img_cond.c_concat;
                 diffusion_params.y = img_cond.c_vector;
                 bool skip_img_cond = cache_before_condition(&img_cond, out_img_cond);
                 if (!skip_img_cond) {
                     if (!work_diffusion_model->compute(n_threads,
                                                        diffusion_params,

View File

@@ -644,7 +644,7 @@ namespace ZImage {
            t_clean = t_embedder->forward(ctx,
                                          ggml_scale(ctx->ggml_ctx,
                                                     ggml_ext_ones(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]),
-                                                    0.f));
+                                                    1000.f));
        } else {
            t_emb = t_embedder->forward(ctx, timestep);
        }
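Taken together, the four files wire the new field end to end. The sketch below is illustrative glue only (the variable names are not from this commit) and shows how the extra contexts travel from the conditioner to the Z-Image DiT:

    // illustrative flow; the real call sites are in the hunks above
    SDCondition cond = conditioner->get_learned_condition(work_ctx, n_threads, cond_params);

    DiffusionParams dp;
    dp.context        = cond.c_crossattn;         // main text context
    dp.extra_contexts = cond.extra_c_crossattns;  // per-segment contexts (z-image-omni)

    // ZImageModel::compute() prepends the main context and forwards the whole list:
    //   std::vector<ggml_tensor*> contexts = {dp.context};
    //   contexts.insert(contexts.end(), dp.extra_contexts.begin(), dp.extra_contexts.end());
    struct ggml_tensor* out = nullptr;
    diffusion_model->compute(n_threads, dp, &out);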