mirror of https://github.com/leejet/stable-diffusion.cpp.git (synced 2025-12-13 05:48:56 +00:00)

commit 887055edce (parent 40752b629f)
fix qwen image edit pipeline
@@ -1409,9 +1409,19 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
     }

     std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
                                                               size_t max_length = 0,
-                                                              bool padding = false) {
-        auto parsed_attention = parse_prompt_attention(text);
+                                                              size_t system_prompt_length = 0,
+                                                              bool padding = false) {
+        std::vector<std::pair<std::string, float>> parsed_attention;
+        if (system_prompt_length > 0) {
+            parsed_attention.emplace_back(text.substr(0, system_prompt_length), 1.f);
+            auto new_parsed_attention = parse_prompt_attention(text.substr(system_prompt_length, text.size() - system_prompt_length));
+            parsed_attention.insert(parsed_attention.end(),
+                                    new_parsed_attention.begin(),
+                                    new_parsed_attention.end());
+        } else {
+            parsed_attention = parse_prompt_attention(text);
+        }

         {
             std::stringstream ss;
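Why the tokenizer now splits at system_prompt_length: the Edit Plus system template itself contains parentheses ("(color, shape, size, texture, objects, background)"), which parse_prompt_attention would otherwise treat as emphasis syntax and rewrite. A self-contained sketch of the split; the parse_prompt_attention below is a trivial stub, not the real parser:

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Trivial stub standing in for the real parse_prompt_attention(), which
// scans for (emphasis:weight) syntax; here it just returns one chunk.
static std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text) {
    return {{text, 1.1f}};  // pretend the parser found an emphasized chunk
}

int main() {
    std::string prompt = "<|im_start|>system\n(template with parens)<|im_end|>\n<|im_start|>user\na (red:1.3) car";
    size_t system_prompt_length = prompt.find("user\n") + 5;  // end of the fixed template

    std::vector<std::pair<std::string, float>> parsed_attention;
    // The template is taken verbatim at weight 1.0, so its parentheses survive...
    parsed_attention.emplace_back(prompt.substr(0, system_prompt_length), 1.f);
    // ...and only the user's text goes through attention parsing.
    auto user_chunks = parse_prompt_attention(prompt.substr(system_prompt_length));
    parsed_attention.insert(parsed_attention.end(), user_chunks.begin(), user_chunks.end());

    for (const auto& chunk : parsed_attention)
        std::cout << "weight " << chunk.second << " | " << chunk.first << "\n";
}
```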
@@ -1436,7 +1446,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
         tokenizer.pad_tokens(tokens, weights, max_length, padding);

         // for (int i = 0; i < tokens.size(); i++) {
-        //     std::cout << tokens[i] << ":" << weights[i] << ", ";
+        //     std::cout << tokens[i] << ":" << weights[i] << ", " << i << std::endl;
         // }
         // std::cout << std::endl;

@@ -1448,12 +1458,13 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
                          const ConditionerParams& conditioner_params) {
         std::string prompt;
         std::vector<std::pair<int, ggml_tensor*>> image_embeds;
+        size_t system_prompt_length = 0;
         if (qwenvl->enable_vision && conditioner_params.ref_images.size() > 0) {
             LOG_INFO("QwenImageEditPlusPipeline");
             prompt_template_encode_start_idx = 64;
             int image_embed_idx = 64 + 6;

-            int min_pixels = 56 * 56;
+            int min_pixels = 384 * 384;
             int max_pixels = 560 * 560;
             std::string placeholder = "<|image_pad|>";
             std::string img_prompt;
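min_pixels jumps from 56 * 56 to 384 * 384, so small reference images are now upscaled before the vision encoder sees them instead of passing through at very low resolution. A hedged sketch of how a [min_pixels, max_pixels] budget is typically applied; the actual resize code is outside this hunk, so treat the formula as illustrative:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    int min_pixels = 384 * 384;  // was 56 * 56 before this commit
    int max_pixels = 560 * 560;

    int w = 256, h = 192;  // a small reference image
    double scale = 1.0;
    if ((double)w * h < min_pixels)
        scale = std::sqrt((double)min_pixels / ((double)w * h));  // upscale to reach the floor
    else if ((double)w * h > max_pixels)
        scale = std::sqrt((double)max_pixels / ((double)w * h));  // downscale to fit the cap

    printf("%dx%d -> %dx%d\n", w, h,
           (int)std::lround(w * scale), (int)std::lround(h * scale));
    // 256x192 = 49152 px is now upscaled toward the 147456 px floor, where
    // the old 56 * 56 = 3136 px floor would have left it untouched.
}
```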
@@ -1485,7 +1496,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
                 image.data = nullptr;

                 ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
-                sd_image_f32_to_tensor(resized_image.data, image_tensor, false);
+                sd_image_f32_to_tensor(resized_image, image_tensor, false);
                 free(resized_image.data);
                 resized_image.data = nullptr;

@@ -1505,6 +1516,8 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {

             prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";

+            system_prompt_length = prompt.size();
+
             prompt += img_prompt;
             prompt += conditioner_params.text;
             prompt += "<|im_end|>\n<|im_start|>assistant\n";
@@ -1512,7 +1525,7 @@ struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
             prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n";
         }

-        auto tokens_and_weights = tokenize(prompt, 0, false);
+        auto tokens_and_weights = tokenize(prompt, 0, system_prompt_length, false);
        auto& tokens  = std::get<0>(tokens_and_weights);
        auto& weights = std::get<1>(tokens_and_weights);

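The call-site change above matters more than it looks: the new size_t system_prompt_length parameter was inserted before bool padding, and since false converts to size_t, a stale three-argument call would still compile. A compilable sketch of the pitfall; tokenize here is a toy stand-in:

```cpp
#include <cstdio>
#include <string>

// Toy stand-in mirroring the new parameter order.
static void tokenize(std::string /*text*/,
                     size_t max_length = 0,
                     size_t system_prompt_length = 0,
                     bool padding = false) {
    printf("max_length=%zu system_prompt_length=%zu padding=%d\n",
           max_length, system_prompt_length, (int)padding);
}

int main() {
    size_t system_prompt_length = 123;                   // e.g. prompt.size() after the template
    tokenize("prompt", 0, system_prompt_length, false);  // new, explicit call
    tokenize("prompt", 0, false);                        // stale call still compiles: `false` quietly
                                                         // becomes system_prompt_length = 0 and
                                                         // padding takes its default
}
```

In the txt2img branch system_prompt_length stays 0, so that path keeps its old behavior.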
@@ -314,7 +314,7 @@ struct QwenImageModel : public DiffusionModel {
                                  diffusion_params.timesteps,
                                  diffusion_params.context,
                                  diffusion_params.ref_latents,
-                                 diffusion_params.increase_ref_index,
+                                 true, // increase_ref_index
                                  output,
                                  output_ctx);
     }
@@ -27,8 +27,6 @@

 #include "avi_writer.h"

-#include "qwenvl.hpp"
-
 #if defined(_WIN32)
 #define NOMINMAX
 #include <windows.h>
@@ -193,8 +193,11 @@ __STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int i
     return value;
 }

-__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic) {
+__STATIC_INLINE__ float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic, bool scale = true) {
     float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic);
+    if (scale) {
+        value /= 255.f;
+    }
     return value;
 }

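The getter gains a scale flag so callers can choose between the raw 0..255 float storage and normalized 0..1 values. A self-contained sketch; sd_image_f32_t is simplified here relative to the real struct:

```cpp
#include <cstdio>

struct sd_image_f32_t {
    int width, height, channel;
    float* data;
};

// Mirror of the new getter: pixels are stored as floats in 0..255,
// optionally normalized to 0..1 on read.
static float sd_image_get_f32(sd_image_f32_t image, int iw, int ih, int ic, bool scale = true) {
    float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic);
    if (scale) {
        value /= 255.f;
    }
    return value;
}

int main() {
    float px[3] = {255.f, 127.5f, 0.f};
    sd_image_f32_t img = {1, 1, 3, px};
    printf("scaled:   %.3f %.3f %.3f\n",
           sd_image_get_f32(img, 0, 0, 0),
           sd_image_get_f32(img, 0, 0, 1),
           sd_image_get_f32(img, 0, 0, 2));   // 1.000 0.500 0.000
    printf("unscaled: %.1f\n",
           sd_image_get_f32(img, 0, 0, 0, false));  // 255.0
}
```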
@@ -446,24 +449,18 @@ __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
     }
 }

-__STATIC_INLINE__ void sd_image_f32_to_tensor(const float* image_data,
-                                              struct ggml_tensor* output,
-                                              bool scale = true) {
-    int64_t width    = output->ne[0];
-    int64_t height   = output->ne[1];
-    int64_t channels = output->ne[2];
-    GGML_ASSERT(channels == 3 && output->type == GGML_TYPE_F32);
-    for (int iy = 0; iy < height; iy++) {
-        for (int ix = 0; ix < width; ix++) {
-            for (int k = 0; k < channels; k++) {
-                int value = *(image_data + iy * width * channels + ix * channels + k);
-                if (scale) {
-                    value /= 255.f;
-                }
-                ggml_tensor_set_f32(output, value, ix, iy, k);
-            }
-        }
-    }
-}
+__STATIC_INLINE__ void sd_image_f32_to_tensor(sd_image_f32_t image,
+                                              ggml_tensor* tensor,
+                                              bool scale = true) {
+    GGML_ASSERT(image.width == tensor->ne[0]);
+    GGML_ASSERT(image.height == tensor->ne[1]);
+    GGML_ASSERT(image.channel == tensor->ne[2]);
+    GGML_ASSERT(1 == tensor->ne[3]);
+    GGML_ASSERT(tensor->type == GGML_TYPE_F32);
+    ggml_tensor_iter(tensor, [&](ggml_tensor* tensor, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+        float value = sd_image_get_f32(image, i0, i1, i2, scale);
+        ggml_tensor_set_f32(tensor, value, i0, i1, i2, i3);
+    });
+}

 __STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,
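Beyond the ggml_tensor_iter restyling, note what the rewrite quietly fixes: the old loop read each pixel into an int, so `value /= 255.f` truncated nearly every scaled sample to 0. The call sites updated in this commit pass scale = false, so the bug was latent there, but the new reference-image path in generate_image (further down) appears to rely on the now-correct scale = true default. A minimal demonstration of the truncation the rewrite removes:

```cpp
#include <cstdio>

int main() {
    int old_value = 200;
    old_value /= 255.f;              // int = int / float: the 0.784 result truncates back to 0
    float new_value = 200.f / 255.f; // float all the way through, as in the new getter
    printf("old: %d, new: %.3f\n", old_value, new_value);  // old: 0, new: 0.784
}
```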
@@ -156,7 +156,7 @@ namespace Qwen {
         auto k = ggml_concat(ctx, txt_k, img_k, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]
         auto v = ggml_concat(ctx, txt_v, img_v, 2);  // [N, n_txt_token + n_img_token, n_head, d_head]

-        auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 128.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
+        auto attn = Rope::attention(ctx, backend, q, k, v, pe, mask, flash_attn, (1.0f / 256.f));  // [N, n_txt_token + n_img_token, n_head*d_head]
         attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3));  // [n_txt_token + n_img_token, N, hidden_size]
         auto txt_attn_out = ggml_view_3d(ctx,
                                          attn,
@@ -730,11 +730,10 @@ namespace Qwen {
                 input_embed = ggml_concat(ctx, input_embed, image_embed, 1);
             }

-            auto final_txt_embed = ggml_slice(ctx,
-                                              raw_x,
-                                              1,
-                                              image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1],
-                                              raw_x->ne[1]);
+            txt_token_start = image_embeds[image_embeds.size() - 1].first + image_embeds[image_embeds.size() - 1].second->ne[1];
+            txt_token_end   = raw_x->ne[1];
+
+            auto final_txt_embed = ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end);

             input_embed = ggml_concat(ctx, input_embed, final_txt_embed, 1);
             GGML_ASSERT(raw_x->ne[1] == input_embed->ne[1]);
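How the new txt_token_start/txt_token_end bounds fall out, with illustrative numbers; the names mirror the diff, but actual token counts depend on the prompt and images:

```cpp
#include <cstdio>

int main() {
    int last_embed_pos    = 70;   // image_embeds.back().first: where the last image embed was spliced in
    int last_embed_tokens = 256;  // image_embeds.back().second->ne[1]: tokens it contributes
    int raw_x_tokens      = 400;  // raw_x->ne[1]: full prompt embedding length

    int txt_token_start = last_embed_pos + last_embed_tokens;  // 326
    int txt_token_end   = raw_x_tokens;                        // 400

    // ggml_slice(ctx, raw_x, 1, txt_token_start, txt_token_end) keeps the
    // trailing text tokens (indices 326..399) along dimension 1.
    printf("trailing text tokens: [%d, %d) -> %d tokens\n",
           txt_token_start, txt_token_end, txt_token_end - txt_token_start);
}
```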
rope.hpp (4 lines changed)

@@ -222,8 +222,8 @@ namespace Rope {
                                                int context_len,
                                                const std::vector<ggml_tensor*>& ref_latents,
                                                bool increase_ref_index) {
-        int h_len = (h + (patch_size / 2)) / patch_size;
-        int w_len = (w + (patch_size / 2)) / patch_size;
+        int h_len = (h + (patch_size / 2)) / patch_size / 2;
+        int w_len = (w + (patch_size / 2)) / patch_size / 2;
         int txt_id_start = std::max(h_len, w_len);
         auto txt_ids = linspace<float>(txt_id_start, context_len + txt_id_start, context_len);
         std::vector<std::vector<float>> txt_ids_repeated(bs * context_len, std::vector<float>(3));
|||||||
@ -973,7 +973,7 @@ public:
|
|||||||
image.data = NULL;
|
image.data = NULL;
|
||||||
|
|
||||||
ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
|
ggml_tensor* pixel_values = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
|
||||||
sd_image_f32_to_tensor(resized_image.data, pixel_values, false);
|
sd_image_f32_to_tensor(resized_image, pixel_values, false);
|
||||||
free(resized_image.data);
|
free(resized_image.data);
|
||||||
resized_image.data = NULL;
|
resized_image.data = NULL;
|
||||||
|
|
||||||
@@ -1010,7 +1010,7 @@ public:
                 sd_image_f32_t resized_image = resize_sd_image_f32_t(image, width, height);
                 free(image.data);
                 image.data = NULL;
-                sd_image_f32_to_tensor(resized_image.data, init_img, false);
+                sd_image_f32_to_tensor(resized_image, init_img, false);
                 free(resized_image.data);
                 resized_image.data = NULL;
             } else {
@@ -2006,8 +2006,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
         seed = rand();
     }

-    print_ggml_tensor(init_latent, true, "init");
-
     // for (auto v : sigmas) {
     //     std::cout << v << " ";
     // }
@@ -2485,13 +2483,42 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g

         std::vector<ggml_tensor*> ref_latents;
         for (int i = 0; i < ref_images.size(); i++) {
-            ggml_tensor* img = ggml_new_tensor_4d(work_ctx,
-                                                  GGML_TYPE_F32,
-                                                  ref_images[i]->width,
-                                                  ref_images[i]->height,
-                                                  3,
-                                                  1);
-            sd_image_to_tensor(*ref_images[i], img);
+            ggml_tensor* img;
+            if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
+                sd_image_f32_t ref_image = sd_image_t_to_sd_image_f32_t(*ref_images[i]);
+                int VAE_IMAGE_SIZE = std::min(1024 * 1024, width * height);
+                double vae_width   = sqrt(VAE_IMAGE_SIZE * ref_image.width / ref_image.height);
+                double vae_height  = vae_width * ref_image.height / ref_image.width;

+                vae_height = round(vae_height / 32) * 32;
+                vae_width  = round(vae_width / 32) * 32;

+                sd_image_f32_t resized_image = resize_sd_image_f32_t(ref_image, static_cast<int>(vae_width), static_cast<int>(vae_height));
+                free(ref_image.data);
+                ref_image.data = nullptr;

+                LOG_DEBUG("resize vae ref image %d from %dx%d to %dx%d", i, ref_image.height, ref_image.width, resized_image.height, resized_image.width);

+                img = ggml_new_tensor_4d(work_ctx,
+                                         GGML_TYPE_F32,
+                                         resized_image.width,
+                                         resized_image.height,
+                                         3,
+                                         1);
+                sd_image_f32_to_tensor(resized_image, img);
+                free(resized_image.data);
+                resized_image.data = nullptr;
+            } else {
+                img = ggml_new_tensor_4d(work_ctx,
+                                         GGML_TYPE_F32,
+                                         ref_images[i]->width,
+                                         ref_images[i]->height,
+                                         3,
+                                         1);
+                sd_image_to_tensor(*ref_images[i], img);
+            }

+            // print_ggml_tensor(img, false, "img");

             ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
             ref_latents.push_back(latent);
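Worked example of the new reference sizing for the Qwen branch: the pixel budget is capped at min(1024 * 1024, width * height) of the requested output, the reference keeps its aspect ratio, and both sides snap to multiples of 32 (numbers below are illustrative):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    int out_w = 1328, out_h = 1328;  // requested output size
    int ref_w = 800, ref_h = 600;    // reference image

    int VAE_IMAGE_SIZE = std::min(1024 * 1024, out_w * out_h);  // 1048576: the 1 MP cap wins here

    double vae_width  = std::sqrt((double)VAE_IMAGE_SIZE * ref_w / ref_h);  // ~1182.4
    double vae_height = vae_width * ref_h / ref_w;                          // ~886.8

    vae_height = std::round(vae_height / 32) * 32;  // 896
    vae_width  = std::round(vae_width / 32) * 32;   // 1184

    printf("ref %dx%d -> vae %dx%d (budget %d px)\n",
           ref_w, ref_h, (int)vae_width, (int)vae_height, VAE_IMAGE_SIZE);
    // 1184 * 896 = 1060864 px, just over the budget after rounding to /32.
}
```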
util.cpp (22 lines changed)

@@ -299,7 +299,7 @@ std::string trim(const std::string& s) {
 static sd_log_cb_t sd_log_cb = NULL;
 void* sd_log_cb_data = NULL;

-#define LOG_BUFFER_SIZE 1024
+#define LOG_BUFFER_SIZE 4096

 void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...) {
     va_list args;
@@ -388,10 +388,10 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int
             float original_x = (float)x * image.width / target_width;
             float original_y = (float)y * image.height / target_height;

-            int x1 = (int)original_x;
-            int y1 = (int)original_y;
-            int x2 = x1 + 1;
-            int y2 = y1 + 1;
+            uint32_t x1 = (uint32_t)original_x;
+            uint32_t y1 = (uint32_t)original_y;
+            uint32_t x2 = std::min(x1 + 1, image.width - 1);
+            uint32_t y2 = std::min(y1 + 1, image.height - 1);

             for (int k = 0; k < image.channel; k++) {
                 float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
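What the uint32_t/std::min change fixes in both bilinear samplers (this hunk and the identical one in clip_preprocess below): for samples landing in the last row or column, x1 + 1 and y1 + 1 indexed one pixel past the image. Clamping the second tap to the edge turns the interpolation into clamp-to-edge there. A tiny repro of the indexing:

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t width = 4;       // 4 pixels: valid x indices are 0..3
    float original_x = 3.7f;  // a sample that lands inside the last cell

    uint32_t x1     = (uint32_t)original_x;         // 3
    uint32_t x2_old = x1 + 1;                       // 4: one past the edge
    uint32_t x2_new = std::min(x1 + 1, width - 1);  // 3: clamped to the edge

    printf("x1=%u, old x2=%u (out of bounds), new x2=%u\n", x1, x2_old, x2_new);
    // With x2 clamped onto x1, the lerp between the two taps simply repeats
    // the edge pixel, i.e. the usual clamp-to-edge behavior.
}
```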
@@ -444,10 +444,10 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int targe
             float original_x = (float)x * image.width / resized_width;
             float original_y = (float)y * image.height / resized_height;

-            int x1 = (int)original_x;
-            int y1 = (int)original_y;
-            int x2 = x1 + 1;
-            int y2 = y1 + 1;
+            uint32_t x1 = (uint32_t)original_x;
+            uint32_t y1 = (uint32_t)original_y;
+            uint32_t x2 = std::min(x1 + 1, image.width - 1);
+            uint32_t y2 = std::min(y1 + 1, image.height - 1);

             for (int k = 0; k < image.channel; k++) {
                 float v1 = *(image.data + y1 * image.width * image.channel + x1 * image.channel + k);
@@ -478,8 +478,10 @@ sd_image_f32_t clip_preprocess(sd_image_f32_t image, int target_width, int targe
     for (int k = 0; k < image.channel; k++) {
         for (int i = 0; i < result.height; i++) {
             for (int j = 0; j < result.width; j++) {
+                int src_y = std::min(i + h_offset, resized_height - 1);
+                int src_x = std::min(j + w_offset, resized_width - 1);
                 *(result.data + i * result.width * image.channel + j * image.channel + k) =
-                    fmin(fmax(*(resized_data + (i + h_offset) * resized_width * image.channel + (j + w_offset) * image.channel + k), 0.0f), 255.0f) / 255.0f;
+                    fmin(fmax(*(resized_data + src_y * resized_width * image.channel + src_x * image.channel + k), 0.0f), 255.0f) / 255.0f;
             }
         }
     }
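The crop/copy loop gets the same clamp-to-edge treatment: src_y and src_x now stick to the last valid row/column when the crop offset would step past the resized buffer. A sketch with illustrative sizes:

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    int resized_height = 224, h_offset = 3;
    int i = 222;  // a destination row near the bottom of the crop

    int src_y_old = i + h_offset;                                // 225: past the 0..223 range
    int src_y_new = std::min(i + h_offset, resized_height - 1);  // 223: clamped in bounds
    printf("src_y: %d -> %d\n", src_y_old, src_y_new);
}
```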