From 190c523cecd8b6e9042e19cb8514339fa8871ddf Mon Sep 17 00:00:00 2001
From: leejet
Date: Thu, 25 Dec 2025 22:35:28 +0800
Subject: [PATCH] add support for extra contexts

---
 conditioner.hpp      | 354 +++++++++++++++++++++++++------------------
 diffusion_model.hpp  |   5 +-
 stable-diffusion.cpp |  31 ++--
 z_image.hpp          |   2 +-
 4 files changed, 229 insertions(+), 163 deletions(-)

diff --git a/conditioner.hpp b/conditioner.hpp
index 45db314..7fb1f0c 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -10,9 +10,14 @@ struct SDCondition {
     struct ggml_tensor* c_vector    = nullptr;  // aka y
     struct ggml_tensor* c_concat    = nullptr;
 
+    std::vector<struct ggml_tensor*> extra_c_crossattns;
+
     SDCondition() = default;
-    SDCondition(struct ggml_tensor* c_crossattn, struct ggml_tensor* c_vector, struct ggml_tensor* c_concat)
-        : c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {}
+    SDCondition(struct ggml_tensor* c_crossattn,
+                struct ggml_tensor* c_vector,
+                struct ggml_tensor* c_concat,
+                const std::vector<struct ggml_tensor*>& extra_c_crossattns = {})
+        : c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat), extra_c_crossattns(extra_c_crossattns) {}
 };
 
 struct ConditionerParams {
@@ -1657,18 +1662,23 @@ struct LLMEmbedder : public Conditioner {
     }
 
     std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
-                                                              std::pair<int, int> attn_range,
+                                                              const std::pair<int, int>& attn_range,
                                                               size_t max_length = 0,
                                                               bool padding = false) {
        std::vector<std::pair<std::string, float>> parsed_attention;
-        parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
-        if (attn_range.second - attn_range.first > 0) {
-            auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
-            parsed_attention.insert(parsed_attention.end(),
-                                    new_parsed_attention.begin(),
-                                    new_parsed_attention.end());
+        if (attn_range.first >= 0 && attn_range.second > 0) {
+            parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
+            if (attn_range.second - attn_range.first > 0) {
+                auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
+                parsed_attention.insert(parsed_attention.end(),
+                                        new_parsed_attention.begin(),
+                                        new_parsed_attention.end());
+            }
+            parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
+        } else {
+            parsed_attention.emplace_back(text, 1.f);
         }
-        parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
+
         {
             std::stringstream ss;
             ss << "[";
@@ -1699,140 +1709,20 @@ struct LLMEmbedder : public Conditioner {
         return {tokens, weights};
     }
 
-    SDCondition get_learned_condition(ggml_context* work_ctx,
-                                      int n_threads,
-                                      const ConditionerParams& conditioner_params) override {
-        std::string prompt;
-        std::vector<std::pair<int, ggml_tensor*>> image_embeds;
-        std::pair<int, int> prompt_attn_range;
-        int prompt_template_encode_start_idx = 34;
-        int max_length                       = 0;
-        std::set<int> out_layers;
-        if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
-            LOG_INFO("QwenImageEditPlusPipeline");
-            prompt_template_encode_start_idx = 64;
-            int image_embed_idx              = 64 + 6;
-
-            int min_pixels          = 384 * 384;
-            int max_pixels          = 560 * 560;
-            std::string placeholder = "<|image_pad|>";
-            std::string img_prompt;
-
-            for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
-                sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
-                double factor        = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
-                int height           = image.height;
-                int width            = image.width;
-                int h_bar            = static_cast<int>(std::round(height / factor)) * factor;
-                int w_bar            = static_cast<int>(std::round(width / factor)) * factor;
-
-                if (static_cast<int64_t>(h_bar) * w_bar > max_pixels) {
-                    double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
-                    h_bar       = std::max(static_cast<int>(factor),
-                                           static_cast<int>(std::floor(height / beta / factor)) * static_cast<int>(factor));
-                    w_bar       = std::max(static_cast<int>(factor),
-                                           static_cast<int>(std::floor(width / beta / factor)) * static_cast<int>(factor));
-                } else if (static_cast<int64_t>(h_bar) * w_bar < min_pixels) {
-                    double beta = std::sqrt(static_cast<double>(min_pixels) / (height * width));
-                    h_bar       = static_cast<int>(std::ceil(height * beta / factor)) * static_cast<int>(factor);
-                    w_bar       = static_cast<int>(std::ceil(width * beta / factor)) * static_cast<int>(factor);
-                }
-
-                LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar);
-
-                sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar);
-                free(image.data);
-                image.data = nullptr;
-
-                ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
-                sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false);
-                free(resized_image.data);
-                resized_image.data = nullptr;
-
-                ggml_tensor* image_embed = nullptr;
-                llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
-                image_embeds.emplace_back(image_embed_idx, image_embed);
-                image_embed_idx += 1 + image_embed->ne[1] + 6;
-
-                img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>";  // [24669, 220, index, 25, 220, 151652]
-                int64_t num_image_tokens = image_embed->ne[1];
-                img_prompt.reserve(num_image_tokens * placeholder.size());
-                for (int j = 0; j < num_image_tokens; j++) {
-                    img_prompt += placeholder;
-                }
-                img_prompt += "<|vision_end|>";
-            }
-
-            prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
-            prompt += img_prompt;
-
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
-
-            prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (sd_version_is_flux2(version)) {
-            prompt_template_encode_start_idx = 0;
-            out_layers                       = {10, 20, 30};
-
-            prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
-
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
-
-            prompt += "[/INST]";
-        } else if (sd_version_is_z_image(version)) {
-            prompt_template_encode_start_idx = 0;
-            out_layers                       = {35};  // -2
-
-            prompt = "<|im_start|>user\n";
-
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
-
-            prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        } else if (version == VERSION_OVIS_IMAGE) {
-            prompt_template_encode_start_idx = 28;
-            max_length                       = prompt_template_encode_start_idx + 256;
-
-            prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
-
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += " " + conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
-
-            prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n";
-        } else {
-            prompt_template_encode_start_idx = 34;
-
-            prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
-
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
-
-            prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        }
-
+    ggml_tensor* encode_prompt(ggml_context* work_ctx,
+                               int n_threads,
+                               const std::string prompt,
+                               const std::pair<int, int>& prompt_attn_range,
+                               int max_length,
+                               int min_length,
+                               std::vector<std::pair<int, ggml_tensor*>> image_embeds,
+                               const std::set<int>& out_layers,
+                               int prompt_template_encode_start_idx) {
         auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
         auto& tokens            = std::get<0>(tokens_and_weights);
         auto& weights           = std::get<1>(tokens_and_weights);
 
-        int64_t t0 = ggml_time_ms();
-        struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, 3584]
+        struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, hidden_size]
 
         auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
@@ -1860,11 +1750,6 @@ struct LLMEmbedder : public Conditioner {
 
         GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
 
-        int64_t min_length = 0;
-        if (sd_version_is_flux2(version)) {
-            min_length = 512;
-        }
-
         int64_t zero_pad_len = 0;
         if (min_length > 0) {
             if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
@@ -1886,11 +1771,186 @@ struct LLMEmbedder : public Conditioner {
                 ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
             });
 
-        // print_ggml_tensor(new_hidden_states);
+        return new_hidden_states;
+    }
+
+    SDCondition get_learned_condition(ggml_context* work_ctx,
+                                      int n_threads,
+                                      const ConditionerParams& conditioner_params) override {
+        std::string prompt;
+        std::pair<int, int> prompt_attn_range;
+        std::vector<std::string> extra_prompts;
+        std::vector<std::pair<int, int>> extra_prompts_attn_range;
+        std::vector<std::pair<int, ggml_tensor*>> image_embeds;
+        int prompt_template_encode_start_idx = 34;
+        int max_length                       = 0;
+        int min_length                       = 0;
+        std::set<int> out_layers;
+
+        int64_t t0 = ggml_time_ms();
+
+        if (sd_version_is_qwen_image(version)) {
+            if (llm->enable_vision && !conditioner_params.ref_images.empty()) {
+                LOG_INFO("QwenImageEditPlusPipeline");
+                prompt_template_encode_start_idx = 64;
+                int image_embed_idx              = 64 + 6;
+
+                int min_pixels          = 384 * 384;
+                int max_pixels          = 560 * 560;
+                std::string placeholder = "<|image_pad|>";
+                std::string img_prompt;
+
+                for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
+                    sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
+                    double factor        = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
+                    int height           = image.height;
+                    int width            = image.width;
+                    int h_bar            = static_cast<int>(std::round(height / factor)) * factor;
+                    int w_bar            = static_cast<int>(std::round(width / factor)) * factor;
+
+                    if (static_cast<int64_t>(h_bar) * w_bar > max_pixels) {
+                        double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
+                        h_bar       = std::max(static_cast<int>(factor),
+                                               static_cast<int>(std::floor(height / beta / factor)) * static_cast<int>(factor));
+                        w_bar       = std::max(static_cast<int>(factor),
+                                               static_cast<int>(std::floor(width / beta / factor)) * static_cast<int>(factor));
+                    } else if (static_cast<int64_t>(h_bar) * w_bar < min_pixels) {
+                        double beta = std::sqrt(static_cast<double>(min_pixels) / (height * width));
+                        h_bar       = static_cast<int>(std::ceil(height * beta / factor)) * static_cast<int>(factor);
+                        w_bar       = static_cast<int>(std::ceil(width * beta / factor)) * static_cast<int>(factor);
+                    }
+
+                    LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar);
+
+                    sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar);
+                    free(image.data);
+                    image.data = nullptr;
+
+                    ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
+                    sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false);
+                    free(resized_image.data);
+                    resized_image.data = nullptr;
+
+                    ggml_tensor* image_embed = nullptr;
+                    llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
+                    image_embeds.emplace_back(image_embed_idx, image_embed);
+                    image_embed_idx += 1 + image_embed->ne[1] + 6;
+
+                    img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>";  // [24669, 220, index, 25, 220, 151652]
+                    int64_t num_image_tokens = image_embed->ne[1];
+                    img_prompt.reserve(num_image_tokens * placeholder.size());
+                    for (int j = 0; j < num_image_tokens; j++) {
+                        img_prompt += placeholder;
+                    }
+                    img_prompt += "<|vision_end|>";
+                }
+
+                prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
+                prompt += img_prompt;
+
+                prompt_attn_range.first = static_cast<int>(prompt.size());
+                prompt += conditioner_params.text;
+                prompt_attn_range.second = static_cast<int>(prompt.size());
+
+                prompt += "<|im_end|>\n<|im_start|>assistant\n";
+            } else {
+                prompt_template_encode_start_idx = 34;
+
+                prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
+
+                prompt_attn_range.first = static_cast<int>(prompt.size());
+                prompt += conditioner_params.text;
+                prompt_attn_range.second = static_cast<int>(prompt.size());
+
+                prompt += "<|im_end|>\n<|im_start|>assistant\n";
+            }
+        } else if (sd_version_is_flux2(version)) {
+            prompt_template_encode_start_idx = 0;
+            out_layers                       = {10, 20, 30};
+
+            prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
+
+            prompt_attn_range.first = static_cast<int>(prompt.size());
+            prompt += conditioner_params.text;
+            prompt_attn_range.second = static_cast<int>(prompt.size());
+
+            prompt += "[/INST]";
+
+            min_length = 512;
+        } else if (sd_version_is_z_image(version)) {
+            prompt_template_encode_start_idx = 0;
+            out_layers                       = {35};  // -2
+
+            if (!conditioner_params.ref_images.empty()) {
+                LOG_INFO("ZImageOmniPipeline");
+                // Split the prompt into segments around each reference image; every extra
+                // segment is encoded separately and becomes one extra context stream.
+                prompt = "<|im_start|>user\n<|vision_start|>";
+                for (int i = 0; i < conditioner_params.ref_images.size() - 1; i++) {
+                    extra_prompts.push_back("<|vision_end|><|vision_start|>");
+                }
+                extra_prompts.push_back("<|vision_end|>" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>");
+                extra_prompts.push_back("<|vision_end|><|im_end|>");
+                // no prompt weighting for the extra segments
+                extra_prompts_attn_range.resize(extra_prompts.size(), {-1, -1});
+            } else {
+                prompt = "<|im_start|>user\n";
+
+                prompt_attn_range.first = static_cast<int>(prompt.size());
+                prompt += conditioner_params.text;
+                prompt_attn_range.second = static_cast<int>(prompt.size());
+
+                prompt += "<|im_end|>\n<|im_start|>assistant\n";
+            }
+        } else if (version == VERSION_OVIS_IMAGE) {
+            prompt_template_encode_start_idx = 28;
+            max_length                       = prompt_template_encode_start_idx + 256;
+
+            prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:";
+
+            prompt_attn_range.first = static_cast<int>(prompt.size());
+            prompt += " " + conditioner_params.text;
+            prompt_attn_range.second = static_cast<int>(prompt.size());
+
+            prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n";
+        } else {
+            GGML_ABORT("unknown version %d", version);
+        }
+
+        auto hidden_states = encode_prompt(work_ctx,
+                                           n_threads,
+                                           prompt,
+                                           prompt_attn_range,
+                                           max_length,
+                                           min_length,
+                                           image_embeds,
+                                           out_layers,
+                                           prompt_template_encode_start_idx);
+
+        std::vector<ggml_tensor*> extra_hidden_states_vec;
+        for (int i = 0; i < extra_prompts.size(); i++) {
+            auto extra_hidden_states = encode_prompt(work_ctx,
+                                                     n_threads,
+                                                     extra_prompts[i],
+                                                     extra_prompts_attn_range[i],
+                                                     max_length,
+                                                     min_length,
+                                                     image_embeds,
+                                                     out_layers,
+                                                     prompt_template_encode_start_idx);
+            extra_hidden_states_vec.push_back(extra_hidden_states);
+        }
 
         int64_t t1 = ggml_time_ms();
         LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
-        return {new_hidden_states, nullptr, nullptr};
+        return {hidden_states, nullptr, nullptr, extra_hidden_states_vec};
     }
 };
 
diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index 3f735ac..5cf7717 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -23,6 +23,7 @@ struct DiffusionParams {
     struct ggml_tensor* vace_context = nullptr;
     float vace_strength              = 1.f;
     std::vector<int> skip_layers     = {};
+    std::vector<struct ggml_tensor*> extra_contexts;  // for z-image-omni
 };
 
 struct DiffusionModel {
@@ -436,10 +437,12 @@ struct ZImageModel : public DiffusionModel {
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output     = nullptr,
                  struct ggml_context* output_ctx = nullptr) override {
+        std::vector<struct ggml_tensor*> contexts = {diffusion_params.context};
+        contexts.insert(contexts.end(), diffusion_params.extra_contexts.begin(), diffusion_params.extra_contexts.end());
         return z_image.compute(n_threads,
                                diffusion_params.x,
                                diffusion_params.timesteps,
-                               {diffusion_params.context},
+                               contexts,
                                diffusion_params.ref_latents,
                                {},
                                output,
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 4b1c004..84c9ffe 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -1931,10 +1931,11 @@ public:
                 struct ggml_tensor** active_output = &out_cond;
                 if (start_merge_step == -1 || step <= start_merge_step) {
                     // cond
-                    diffusion_params.context  = cond.c_crossattn;
-                    diffusion_params.c_concat = cond.c_concat;
-                    diffusion_params.y        = cond.c_vector;
-                    active_condition          = &cond;
+                    diffusion_params.context        = cond.c_crossattn;
+                    diffusion_params.extra_contexts = cond.extra_c_crossattns;
+                    diffusion_params.c_concat       = cond.c_concat;
+                    diffusion_params.y              = cond.c_vector;
+                    active_condition                = &cond;
                 } else {
                     diffusion_params.context  = id_cond.c_crossattn;
                     diffusion_params.c_concat = cond.c_concat;
@@ -1965,12 +1966,13 @@ public:
                         LOG_ERROR("controlnet compute failed");
                     }
                 }
-                current_step_skipped      = cache_step_is_skipped();
-                diffusion_params.controls = controls;
-                diffusion_params.context  = uncond.c_crossattn;
-                diffusion_params.c_concat = uncond.c_concat;
-                diffusion_params.y        = uncond.c_vector;
-                bool skip_uncond          = cache_before_condition(&uncond, out_uncond);
+                current_step_skipped            = cache_step_is_skipped();
+                diffusion_params.controls       = controls;
+                diffusion_params.context        = uncond.c_crossattn;
+                diffusion_params.extra_contexts = uncond.extra_c_crossattns;
+                diffusion_params.c_concat       = uncond.c_concat;
+                diffusion_params.y              = uncond.c_vector;
+                bool skip_uncond                = cache_before_condition(&uncond, out_uncond);
                 if (!skip_uncond) {
                     if (!work_diffusion_model->compute(n_threads,
                                                        diffusion_params,
@@ -1985,10 +1987,11 @@ public:
 
                 float* img_cond_data = nullptr;
                 if (has_img_cond) {
-                    diffusion_params.context  = img_cond.c_crossattn;
-                    diffusion_params.c_concat = img_cond.c_concat;
-                    diffusion_params.y        = img_cond.c_vector;
-                    bool skip_img_cond        = cache_before_condition(&img_cond, out_img_cond);
+                    diffusion_params.context        = img_cond.c_crossattn;
+                    diffusion_params.extra_contexts = img_cond.extra_c_crossattns;
+                    diffusion_params.c_concat       = img_cond.c_concat;
+                    diffusion_params.y              = img_cond.c_vector;
+                    bool skip_img_cond              = cache_before_condition(&img_cond, out_img_cond);
                     if (!skip_img_cond) {
                         if (!work_diffusion_model->compute(n_threads,
                                                            diffusion_params,
diff --git a/z_image.hpp b/z_image.hpp
index 715b2d2..1f34c9f 100644
--- a/z_image.hpp
+++ b/z_image.hpp
@@ -644,7 +644,7 @@ namespace ZImage {
                 t_clean = t_embedder->forward(ctx,
                                               ggml_scale(ctx->ggml_ctx,
                                                          ggml_ext_ones(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]),
-                                                         0.f));
+                                                         1000.f));
             } else {
                 t_emb = t_embedder->forward(ctx, timestep);
             }