fix qwen image edit pipeline

This commit is contained in:
leejet 2025-10-09 01:37:52 +08:00
parent 40752b629f
commit 887055edce
9 changed files with 93 additions and 57 deletions

View File

@ -1410,8 +1410,18 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text, std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
size_t max_length = 0, size_t max_length = 0,
size_t system_prompt_length = 0,
bool padding = false) { bool padding = false) {
auto parsed_attention = parse_prompt_attention(text); std::vector<std::pair<std::string, float>> parsed_attention;
if (system_prompt_length > 0) {
parsed_attention.emplace_back(text.substr(0, system_prompt_length), 1.f);
auto new_parsed_attention = parse_prompt_attention(text.substr(system_prompt_length, text.size() - system_prompt_length));
parsed_attention.insert(parsed_attention.end(),
new_parsed_attention.begin(),
new_parsed_attention.end());
} else {
parsed_attention = parse_prompt_attention(text);
}
{ {
std::stringstream ss; std::stringstream ss;
@ -1436,7 +1446,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
tokenizer.pad_tokens(tokens, weights, max_length, padding); tokenizer.pad_tokens(tokens, weights, max_length, padding);
// for (int i = 0; i < tokens.size(); i++) { // for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", "; // std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
// } // }
// std::cout << std::endl; // std::cout << std::endl;
@ -1448,12 +1458,13 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
const ConditionerParams& conditioner_params) { const ConditionerParams& conditioner_params) {
std::string prompt; std::string prompt;
std::vector<std::pair<int, ggml_tensor*>> image_embeds; std::vector<std::pair<int, ggml_tensor*>> image_embeds;
size_t system_prompt_length = 0;
if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) { if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) {
LOG_INFO("QwenImageEditPlusPipeline"); LOG_INFO("QwenImageEditPlusPipeline");
prompt_template_encode_start_idx = 64; prompt_template_encode_start_idx = 64;
int image_embed_idx = 64 + 6; int image_embed_idx = 64 + 6;
int min_pixels = 56 * 56; int min_pixels = 384 * 384;
int max_pixels = 560 * 560; int max_pixels = 560 * 560;
std::string placeholder = "<|image_pad|>"; std::string placeholder = "<|image_pad|>";
std::string img_prompt; std::string img_prompt;
@ -1485,7 +1496,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
image.data = nullptr; image.data = nullptr;
ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
sd_image_f32_to_tensor(resized_image.data, image_tensor, false); sd_image_f32_to_tensor(resized_image, image_tensor, false);
free(resized_image.data); free(resized_image.data);
resized_image.data = nullptr; resized_image.data = nullptr;
@ -1505,6 +1516,8 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
system_prompt_length = prompt.size();
prompt += img_prompt; prompt += img_prompt;
prompt += conditioner_params.text; prompt += conditioner_params.text;
prompt += "<|im_end|>\n<|im_start|>assistant\n"; prompt += "<|im_end|>\n<|im_start|>assistant\n";
@ -1512,7 +1525,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n"; prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n";
} }
auto tokens_and_weights = tokenize(prompt, 0, false); auto tokens_and_weights = tokenize(prompt, 0, system_prompt_length, false);
auto& tokens = std::get<0>(tokens_and_weights); auto& tokens = std::get<0>(tokens_and_weights);
auto& weights = std::get<1>(tokens_and_weights); auto& weights = std::get<1>(tokens_and_weights);

View File

@ -314,7 +314,7 @@ struct QwenImageModel : public DiffusionModel {
diffusion_params.timesteps, diffusion_params.timesteps,
diffusion_params.context, diffusion_params.context,
diffusion_params.ref_latents, diffusion_params.ref_latents,
diffusion_params.increase_ref_index, true, // increase_ref_index
output, output,
output_ctx); output_ctx);
} }

View File

@ -27,8 +27,6 @@
#include "avi_writer.h" #include "avi_writer.h"
#include "qwenvl.hpp"
#if defined(_WIN32) #if defined(_WIN32)
#define NOMINMAX #define NOMINMAX
#include <windows.h> #include <windows.h>

View File

@ -193,8 +193,11 @@ __STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int i
return value; return value;
} }
__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic) { __STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic, bool scale = true) {
float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic); float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic);
if (scale) {
value /= 255.f;
}
return value; return value;
} }
@ -446,24 +449,18 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
} }
} }
__STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data, __STATIC_INLINE__ void sd_image_f32_to_tensor(sd_image_f32_t image,
struct ggml_tensor* output, ggml_tensor* tensor,
bool scale = true) { bool scale = true) {
int64_t width = output->ne[0]; GGML_ASSERT(image.width == tensor->ne[0]);
int64_t height = output->ne[1]; GGML_ASSERT(image.height == tensor->ne[1]);
int64_t channels = output->ne[2]; GGML_ASSERT(image.channel == tensor->ne[2]);
GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32); GGML_ASSERT(1 == tensor->ne[3]);
for (int iy = 0; iy < height; iy++) { GGML_ASSERT(tensor->type == GGML_TYPE_F32);
for (int ix = 0; ix < width; ix++) { ggml_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
for (int k = 0; k < channels; k++) { float value = sd_image_get_f32(image, i0, i1, i2, scale);
int value = *(image_data + iy * width * channels + ix * channels + k); ggml_tensor_set_f32(tensor, value, i0, i1, i2, i3);
if (scale) { });
value /= 255.f;
}
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
}
} }
__STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input, __STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,

View File

@ -156,7 +156,7 @@ namespace Qwen {
auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head]
auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f)); // [N, n_txt_token + n_img_token, n_head*d_head] auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 256.f)); // [N, n_txt_token + n_img_token, n_head*d_head]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size]
auto txt_attn_out = ggml_view_3d(ctx, auto txt_attn_out = ggml_view_3d(ctx,
attn, attn,

View File

@ -730,11 +730,10 @@ namespace Qwen {
input_embed = ggml_concat(ctx, input_embed, image_embed, 1); input_embed = ggml_concat(ctx, input_embed, image_embed, 1);
} }
auto final_txt_embed = ggml_slice(ctx, txt_token_start = image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1];
raw_x, txt_token_end = raw_x->ne[1];
1,
image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1], auto final_txt_embed = ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);
raw_x->ne[1]);
input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1); input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1);
GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]); GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]);

View File

@ -222,8 +222,8 @@ namespace Rope {
int context_len, int context_len,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) { bool increase_ref_index) {
int h_len = (h + (patch_size / 2)) / patch_size; int h_len = (h + (patch_size / 2)) / patch_size / 2;
int w_len = (w + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size / 2;
int txt_id_start = std::max(h_len, w_len); int txt_id_start = std::max(h_len, w_len);
auto txt_ids = linspace<float>(txt_id_start, context_len + txt_id_start, context_len); auto txt_ids = linspace<float>(txt_id_start, context_len + txt_id_start, context_len);
std::vector<std::vector<float>> txt_ids_repeated(bs * context_len, std::vector<float>(3)); std::vector<std::vector<float>> txt_ids_repeated(bs * context_len, std::vector<float>(3));

View File

@ -973,7 +973,7 @@ public:
image.data = NULL; image.data = NULL;
ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
sd_image_f32_to_tensor(resized_image.data, pixel_values, false); sd_image_f32_to_tensor(resized_image, pixel_values, false);
free(resized_image.data); free(resized_image.data);
resized_image.data = NULL; resized_image.data = NULL;
@ -1010,7 +1010,7 @@ public:
sd_image_f32_t resized_image = resize_sd_image_f32_t(image, width, height); sd_image_f32_t resized_image = resize_sd_image_f32_t(image, width, height);
free(image.data); free(image.data);
image.data = NULL; image.data = NULL;
sd_image_f32_to_tensor(resized_image.data, init_img, false); sd_image_f32_to_tensor(resized_image, init_img, false);
free(resized_image.data); free(resized_image.data);
resized_image.data = NULL; resized_image.data = NULL;
} else { } else {
@ -2006,8 +2006,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
seed = rand(); seed = rand();
} }
print_ggml_tensor(init_latent, true, "init");
// for (auto v : sigmas) { // for (auto v : sigmas) {
// std::cout << v << " "; // std::cout << v << " ";
// } // }
@ -2485,13 +2483,42 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
std::vector<ggml_tensor*> ref_latents; std::vector<ggml_tensor*> ref_latents;
for (int i = 0; i < ref_images.size(); i++) { for (int i = 0; i < ref_images.size(); i++) {
ggml_tensor* img = ggml_new_tensor_4d(work_ctx, ggml_tensor* img;
if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
sd_image_f32_t ref_image = sd_image_t_to_sd_image_f32_t(*ref_images[i]);
int VAE_IMAGE_SIZE = std::min(1024 * 1024, width * height);
double vae_width = sqrt(VAE_IMAGE_SIZE * ref_image.width / ref_image.height);
double vae_height = vae_width * ref_image.height / ref_image.width;
vae_height = round(vae_height / 32) * 32;
vae_width = round(vae_width / 32) * 32;
sd_image_f32_t resized_image = resize_sd_image_f32_t(ref_image, static_cast<int>(vae_width), static_cast<int>(vae_height));
free(ref_image.data);
ref_image.data = nullptr;
LOG_DEBUG("resize vae ref image %d from %dx%d to %dx%d", i, ref_image.height, ref_image.width, resized_image.height, resized_image.width);
img = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
resized_image.width,
resized_image.height,
3,
1);
sd_image_f32_to_tensor(resized_image, img);
free(resized_image.data);
resized_image.data = nullptr;
} else {
img = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32, GGML_TYPE_F32,
ref_images[i]->width, ref_images[i]->width,
ref_images[i]->height, ref_images[i]->height,
3, 3,
1); 1);
sd_image_to_tensor(*ref_images[i], img); sd_image_to_tensor(*ref_images[i], img);
}
// print_ggml_tensor(img, false, "img");
ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img); ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
ref_latents.push_back(latent); ref_latents.push_back(latent);

View File

@ -299,7 +299,7 @@ std::string trim(const std::string& s) {
static sd_log_cb_t sd_log_cb = NULL; static sd_log_cb_t sd_log_cb = NULL;
void* sd_log_cb_data = NULL; void* sd_log_cb_data = NULL;
#define LOG_BUFFER_SIZE 1024 #define LOG_BUFFER_SIZE 4096
void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...) { void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...) {
va_list args; va_list args;
@ -388,10 +388,10 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
float original_x = (float)x * image.width / target_width; float original_x = (float)x * image.width / target_width;
float original_y = (float)y * image.height / target_height; float original_y = (float)y * image.height / target_height;
int x1 = (int)original_x; uint32_t x1 = (uint32_t)original_x;
int y1 = (int)original_y; uint32_t y1 = (uint32_t)original_y;
int x2 = x1 + 1; uint32_t x2 = std::min(x1 + 1, image.width - 1);
int y2 = y1 + 1; uint32_t y2 = std::min(y1 + 1, image.height - 1);
for (int k = 0; k < image.channel; k++) { for (int k = 0; k < image.channel; k++) {
float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k); float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
@ -444,10 +444,10 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int targe
float original_x = (float)x * image.width / resized_width; float original_x = (float)x * image.width / resized_width;
float original_y = (float)y * image.height / resized_height; float original_y = (float)y * image.height / resized_height;
int x1 = (int)original_x; uint32_t x1 = (uint32_t)original_x;
int y1 = (int)original_y; uint32_t y1 = (uint32_t)original_y;
int x2 = x1 + 1; uint32_t x2 = std::min(x1 + 1, image.width - 1);
int y2 = y1 + 1; uint32_t y2 = std::min(y1 + 1, image.height - 1);
for (int k = 0; k < image.channel; k++) { for (int k = 0; k < image.channel; k++) {
float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k); float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
@ -478,8 +478,10 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int targe
for (int k = 0; k < image.channel; k++) { for (int k = 0; k < image.channel; k++) {
for (int i = 0; i < result.height; i++) { for (int i = 0; i < result.height; i++) {
for (int j = 0; j < result.width; j++) { for (int j = 0; j < result.width; j++) {
int src_y = std::min(i + h_offset, resized_height - 1);
int src_x = std::min(j + w_offset, resized_width - 1);
*(result.data + i * result.width * image.channel + j * image.channel + k) = *(result.data + i * result.width * image.channel + j * image.channel + k) =
fmin(fmax(*(resized_data + (i + h_offset) * resized_width * image.channel + (j + w_offset) * image.channel + k), 0.0f), 255.0f) / 255.0f; fmin(fmax(*(resized_data + src_y * resized_width * image.channel + src_x * image.channel + k), 0.0f), 255.0f) / 255.0f;
} }
} }
} }