From 887055edce3e4bb964c32326fad822ebb5488b91 Mon Sep 17 00:00:00 2001
From: leejet
Date: Thu, 9 Oct 2025 01:37:52 +0800
Subject: [PATCH] fix qwen image edit pipeline

---
 conditioner.hpp       | 27 +++++++++++++++++-------
 diffusion_model.hpp   |  2 +-
 examples/cli/main.cpp |  2 --
 ggml_extend.hpp       | 33 +++++++++++++----------------
 qwen_image.hpp        |  2 +-
 qwenvl.hpp            |  9 ++++----
 rope.hpp              |  4 ++--
 stable-diffusion.cpp  | 49 +++++++++++++++++++++++++++++++++----------
 util.cpp              | 22 ++++++++++---------
 9 files changed, 93 insertions(+), 57 deletions(-)

diff --git a/conditioner.hpp b/conditioner.hpp
index 959b41e..01d0759 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -1409,9 +1409,19 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
     }
 
     std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
-                                                              size_t max_length = 0,
-                                                              bool padding      = false) {
-        auto parsed_attention = parse_prompt_attention(text);
+                                                              size_t max_length           = 0,
+                                                              size_t system_prompt_length = 0,
+                                                              bool padding                = false) {
+        std::vector<std::pair<std::string, float>> parsed_attention;
+        if (system_prompt_length > 0) {
+            parsed_attention.emplace_back(text.substr(0, system_prompt_length), 1.f);
+            auto new_parsed_attention = parse_prompt_attention(text.substr(system_prompt_length, text.size() - system_prompt_length));
+            parsed_attention.insert(parsed_attention.end(),
+                                    new_parsed_attention.begin(),
+                                    new_parsed_attention.end());
+        } else {
+            parsed_attention = parse_prompt_attention(text);
+        }
 
         {
             std::stringstream ss;
@@ -1436,7 +1446,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
         tokenizer.pad_tokens(tokens, weights, max_length, padding);
 
         // for (int i = 0; i < tokens.size(); i++) {
-        //     std::cout << tokens[i] << ":" << weights[i] << ", ";
+        //     std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
         // }
         // std::cout << std::endl;
 
@@ -1448,12 +1458,13 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
                                       const ConditionerParams& conditioner_params) {
         std::string prompt;
         std::vector<std::pair<int, ggml_tensor*>> image_embeds;
+        size_t system_prompt_length = 0;
         if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) {
             LOG_INFO("QwenImageEditPlusPipeline");
             prompt_template_encode_start_idx = 64;
             int image_embed_idx              = 64 + 6;
-            int min_pixels                   = 56 * 56;
+            int min_pixels                   = 384 * 384;
             int max_pixels                   = 560 * 560;
             std::string placeholder          = "<|image_pad|>";
             std::string img_prompt;
@@ -1485,7 +1496,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
                 image.data = nullptr;
 
                 ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
-                sd_image_f32_to_tensor(resized_image.data, image_tensor, false);
+                sd_image_f32_to_tensor(resized_image, image_tensor, false);
 
                 free(resized_image.data);
                 resized_image.data = nullptr;
@@ -1505,6 +1516,8 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
             prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
+            system_prompt_length = prompt.size();
+
             prompt += img_prompt;
             prompt += conditioner_params.text;
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
@@ -1512,7 +1525,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
             prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n";
         }
 
-        auto tokens_and_weights = tokenize(prompt, 0, false);
+        auto tokens_and_weights = tokenize(prompt, 0, system_prompt_length, false);
         auto& tokens            = std::get<0>(tokens_and_weights);
         auto& weights           = std::get<1>(tokens_and_weights);
diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index 6411857..6c38b58 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -314,7 +314,7 @@ struct QwenImageModel : public DiffusionModel {
                                      diffusion_params.timesteps,
                                      diffusion_params.context,
                                      diffusion_params.ref_latents,
-                                     diffusion_params.increase_ref_index,
+                                     true,  // increase_ref_index
                                      output,
                                      output_ctx);
     }
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 79ae2eb..c0bd55b 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -27,8 +27,6 @@
 
 #include "avi_writer.h"
 
-#include "qwenvl.hpp"
-
 #if defined(_WIN32)
 #define NOMINMAX
 #include <windows.h>
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 65cbff7..e94950a 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -193,8 +193,11 @@ __STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int i
     return value;
 }
 
-__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic) {
+__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic, bool scale = true) {
     float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic);
+    if (scale) {
+        value /= 255.f;
+    }
     return value;
 }
 
@@ -446,24 +449,18 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
     }
 }
 
-__STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data,
-                                              struct ggml_tensor* output,
+__STATIC_INLINE__ void sd_image_f32_to_tensor(sd_image_f32_t image,
+                                              ggml_tensor* tensor,
                                               bool scale = true) {
-    int64_t width    = output->ne[0];
-    int64_t height   = output->ne[1];
-    int64_t channels = output->ne[2];
-    GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32);
-    for (int iy = 0; iy < height; iy++) {
-        for (int ix = 0; ix < width; ix++) {
-            for (int k = 0; k < channels; k++) {
-                int value = *(image_data + iy * width * channels + ix * channels + k);
-                if (scale) {
-                    value /= 255.f;
-                }
-                ggml_tensor_set_f32(output, value, ix, iy, k);
-            }
-        }
-    }
+    GGML_ASSERT(image.width == tensor->ne[0]);
+    GGML_ASSERT(image.height == tensor->ne[1]);
+    GGML_ASSERT(image.channel == tensor->ne[2]);
+    GGML_ASSERT(1 == tensor->ne[3]);
+    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
+    ggml_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+        float value = sd_image_get_f32(image, i0, i1, i2, scale);
+        ggml_tensor_set_f32(tensor, value, i0, i1, i2, i3);
+    });
 }
 
 __STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,
diff --git a/qwen_image.hpp b/qwen_image.hpp
index a7bdc3b..6b88af8 100644
--- a/qwen_image.hpp
+++ b/qwen_image.hpp
@@ -156,7 +156,7 @@ namespace Qwen {
             auto k = ggml_concat(ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
             auto v = ggml_concat(ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
 
-            auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
+            auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 256.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
 
             attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
             auto txt_attn_out = ggml_view_3d(ctx, attn,
diff --git a/qwenvl.hpp b/qwenvl.hpp
index 584afbd..f11ebca 100644
--- a/qwenvl.hpp
+++ b/qwenvl.hpp
@@ -730,11 +730,10 @@ namespace Qwen {
                 input_embed = ggml_concat(ctx, input_embed, image_embed, 1);
             }
 
-            auto final_txt_embed = ggml_slice(ctx,
-                                              raw_x,
-                                              1,
-                                              image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1],
-                                              raw_x->ne[1]);
+            txt_token_start = image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1];
+            txt_token_end   = raw_x->ne[1];
+
+            auto final_txt_embed = ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
             input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1);
 
             GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]);
diff --git a/rope.hpp b/rope.hpp
index 295c9a2..551c8ab 100644
--- a/rope.hpp
+++ b/rope.hpp
@@ -222,8 +222,8 @@ namespace Rope {
                                                       int context_len,
                                                       const std::vector<ggml_tensor*>& ref_latents,
                                                       bool increase_ref_index) {
-        int h_len = (h + (patch_size / 2)) / patch_size;
-        int w_len = (w + (patch_size / 2)) / patch_size;
+        int h_len = (h + (patch_size / 2)) / patch_size / 2;
+        int w_len = (w + (patch_size / 2)) / patch_size / 2;
         int txt_id_start = std::max(h_len, w_len);
         auto txt_ids     = linspace(txt_id_start, context_len + txt_id_start, context_len);
         std::vector<std::vector<float>> txt_ids_repeated(bs * context_len, std::vector<float>(3));
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 4eab7d5..8148b89 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -973,7 +973,7 @@ public:
             image.data = NULL;
 
             ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
-            sd_image_f32_to_tensor(resized_image.data, pixel_values, false);
+            sd_image_f32_to_tensor(resized_image, pixel_values, false);
             free(resized_image.data);
             resized_image.data = NULL;
 
@@ -1010,7 +1010,7 @@ public:
                 sd_image_f32_t resized_image = resize_sd_image_f32_t(image, width, height);
                 free(image.data);
                 image.data = NULL;
-                sd_image_f32_to_tensor(resized_image.data, init_img, false);
+                sd_image_f32_to_tensor(resized_image, init_img, false);
                 free(resized_image.data);
                 resized_image.data = NULL;
             } else {
@@ -2006,8 +2006,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
         seed = rand();
     }
 
-    print_ggml_tensor(init_latent, true, "init");
-
     // for (auto v : sigmas) {
     //     std::cout << v << " ";
     // }
@@ -2485,13 +2483,42 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
 
     std::vector<ggml_tensor*> ref_latents;
     for (int i = 0; i < ref_images.size(); i++) {
-        ggml_tensor* img = ggml_new_tensor_4d(work_ctx,
-                                              GGML_TYPE_F32,
-                                              ref_images[i]->width,
-                                              ref_images[i]->height,
-                                              3,
-                                              1);
-        sd_image_to_tensor(*ref_images[i], img);
+        ggml_tensor* img;
+        if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
+            sd_image_f32_t ref_image = sd_image_t_to_sd_image_f32_t(*ref_images[i]);
+            int VAE_IMAGE_SIZE       = std::min(1024 * 1024, width * height);
+            double vae_width         = sqrt(VAE_IMAGE_SIZE * ref_image.width / ref_image.height);
+            double vae_height        = vae_width * ref_image.height / ref_image.width;
+
+            vae_height = round(vae_height / 32) * 32;
+            vae_width  = round(vae_width / 32) * 32;
+
+            sd_image_f32_t resized_image = resize_sd_image_f32_t(ref_image, static_cast<int>(vae_width), static_cast<int>(vae_height));
+            free(ref_image.data);
+            ref_image.data = nullptr;
+
+            LOG_DEBUG("resize vae ref image %d from %dx%d to %dx%d", i, ref_image.height, ref_image.width, resized_image.height, resized_image.width);
+
+            img = ggml_new_tensor_4d(work_ctx,
+                                     GGML_TYPE_F32,
+                                     resized_image.width,
+                                     resized_image.height,
+                                     3,
+                                     1);
+            sd_image_f32_to_tensor(resized_image, img);
+            free(resized_image.data);
+            resized_image.data = nullptr;
+        } else {
+            img = ggml_new_tensor_4d(work_ctx,
+                                     GGML_TYPE_F32,
+                                     ref_images[i]->width,
+                                     ref_images[i]->height,
+                                     3,
+                                     1);
+            sd_image_to_tensor(*ref_images[i], img);
+        }
+
+        // print_ggml_tensor(img, false, "img");
         ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
         ref_latents.push_back(latent);
diff --git a/util.cpp b/util.cpp
index 7b20950..1d0bbd2 100644
--- a/util.cpp
+++ b/util.cpp
@@ -299,7 +299,7 @@ std::string trim(const std::string& s) {
 static sd_log_cb_t sd_log_cb = NULL;
 void* sd_log_cb_data         = NULL;
 
-#define LOG_BUFFER_SIZE 1024
+#define LOG_BUFFER_SIZE 4096
 
 void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...) {
     va_list args;
@@ -388,10 +388,10 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
             float original_x = (float)x * image.width / target_width;
             float original_y = (float)y * image.height / target_height;
 
-            int x1 = (int)original_x;
-            int y1 = (int)original_y;
-            int x2 = x1 + 1;
-            int y2 = y1 + 1;
+            uint32_t x1 = (uint32_t)original_x;
+            uint32_t y1 = (uint32_t)original_y;
+            uint32_t x2 = std::min(x1 + 1, image.width - 1);
+            uint32_t y2 = std::min(y1 + 1, image.height - 1);
 
             for (int k = 0; k < image.channel; k++) {
                 float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
@@ -444,10 +444,10 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int targe
             float original_x = (float)x * image.width / resized_width;
             float original_y = (float)y * image.height / resized_height;
 
-            int x1 = (int)original_x;
-            int y1 = (int)original_y;
-            int x2 = x1 + 1;
-            int y2 = y1 + 1;
+            uint32_t x1 = (uint32_t)original_x;
+            uint32_t y1 = (uint32_t)original_y;
+            uint32_t x2 = std::min(x1 + 1, image.width - 1);
+            uint32_t y2 = std::min(y1 + 1, image.height - 1);
 
             for (int k = 0; k < image.channel; k++) {
                 float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
@@ -478,8 +478,10 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int targe
     for (int k = 0; k < image.channel; k++) {
         for (int i = 0; i < result.height; i++) {
             for (int j = 0; j < result.width; j++) {
+                int src_y = std::min(i + h_offset, resized_height - 1);
+                int src_x = std::min(j + w_offset, resized_width - 1);
                 *(result.data + i * result.width * image.channel + j * image.channel + k) =
-                    fmin(fmax(*(resized_data + (i + h_offset) * resized_width * image.channel + (j + w_offset) * image.channel + k), 0.0f), 255.0f) / 255.0f;
+                    fmin(fmax(*(resized_data + src_y * resized_width * image.channel + src_x * image.channel + k), 0.0f), 255.0f) / 255.0f;
             }
         }
     }
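
Note on the new qwen-image reference path in generate_image: the target VAE resolution keeps the reference image's aspect ratio, matches its area to min(1024 * 1024, width * height), and snaps both sides to multiples of 32. Below is a minimal standalone sketch of that sizing rule, not code from the patch; the function and variable names (vae_ref_size, dst_w, dst_h) are illustrative.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Sketch: pick w and h so that w / h == src_w / src_h and
    // w * h ~= target_area, then round both sides to the nearest
    // multiple of 32 (the alignment the patch uses for the VAE).
    static void vae_ref_size(double src_w, double src_h, double target_area,
                             int* dst_w, int* dst_h) {
        double w = std::sqrt(target_area * src_w / src_h);
        double h = w * src_h / src_w;
        *dst_w   = static_cast<int>(std::round(w / 32) * 32);
        *dst_h   = static_cast<int>(std::round(h / 32) * 32);
    }

    int main() {
        int w = 0, h = 0;
        // 3:2 reference image, 1328x1328 output => area capped at 1024*1024
        vae_ref_size(1536, 1024, std::min(1024.0 * 1024.0, 1328.0 * 1328.0), &w, &h);
        std::printf("%dx%d\n", w, h);  // -> 1248x832: aspect stays 3:2, area ~1.0 MP
        return 0;
    }

One caveat: in the patch itself, VAE_IMAGE_SIZE * ref_image.width / ref_image.height is evaluated with the image's integer width/height fields, so the division truncates before sqrt(); the all-double sketch above shows the intended math rather than a bit-for-bit reproduction.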