add ref latent support for qwen image

This commit is contained in:
leejet 2025-09-23 23:34:51 +08:00
parent 94f4f295c1
commit 4e48e6b82b
3 changed files with 78 additions and 24 deletions

View File

@ -313,6 +313,8 @@ struct QwenImageModel : public DiffusionModel {
diffusion_params.x, diffusion_params.x,
diffusion_params.timesteps, diffusion_params.timesteps,
diffusion_params.context, diffusion_params.context,
diffusion_params.ref_latents,
diffusion_params.increase_ref_index,
output, output,
output_ctx); output_ctx);
} }

View File

@ -386,6 +386,13 @@ namespace Qwen {
return x; return x;
} }
struct ggml_tensor* process_img(struct ggml_context* ctx,
struct ggml_tensor* x) {
x = pad_to_patch_size(ctx, x);
x = patchify(ctx, x);
return x;
}
struct ggml_tensor* unpatchify(struct ggml_context* ctx, struct ggml_tensor* unpatchify(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
int64_t h, int64_t h,
@ -446,7 +453,8 @@ namespace Qwen {
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* timestep, struct ggml_tensor* timestep,
struct ggml_tensor* context, struct ggml_tensor* context,
struct ggml_tensor* pe) { struct ggml_tensor* pe,
std::vector<ggml_tensor*> ref_latents = {}) {
// Forward pass of DiT. // Forward pass of DiT.
// x: [N, C, H, W] // x: [N, C, H, W]
// timestep: [N,] // timestep: [N,]
@ -459,13 +467,26 @@ namespace Qwen {
int64_t C = x->ne[2]; int64_t C = x->ne[2];
int64_t N = x->ne[3]; int64_t N = x->ne[3];
x = pad_to_patch_size(ctx, x); auto img = process_img(ctx, x);
x = patchify(ctx, x); uint64_t img_tokens = img->ne[1];
if (ref_latents.size() > 0) {
for (ggml_tensor* ref : ref_latents) {
ref = process_img(ctx, ref);
img = ggml_concat(ctx, img, ref, 1);
}
}
int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size); int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size); int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
auto out = forward_orig(ctx, backend, x, timestep, context, pe); // [N, h_len*w_len, ph*pw*C] auto out = forward_orig(ctx, backend, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C]
if (out->ne[1] > img_tokens) {
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size]
}
out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w]
@ -506,7 +527,9 @@ namespace Qwen {
struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_cgraph* build_graph(struct ggml_tensor* x,
struct ggml_tensor* timesteps, struct ggml_tensor* timesteps,
struct ggml_tensor* context) { struct ggml_tensor* context,
std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false) {
GGML_ASSERT(x->ne[3] == 1); GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWEN_IMAGE_GRAPH_SIZE, false); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWEN_IMAGE_GRAPH_SIZE, false);
@ -514,18 +537,24 @@ namespace Qwen {
context = to_backend(context); context = to_backend(context);
timesteps = to_backend(timesteps); timesteps = to_backend(timesteps);
for (int i = 0; i < ref_latents.size(); i++) {
ref_latents[i] = to_backend(ref_latents[i]);
}
pe_vec = Rope::gen_qwen_image_pe(x->ne[1], pe_vec = Rope::gen_qwen_image_pe(x->ne[1],
x->ne[0], x->ne[0],
qwen_image_params.patch_size, qwen_image_params.patch_size,
x->ne[3], x->ne[3],
context->ne[1], context->ne[1],
ref_latents,
increase_ref_index,
qwen_image_params.theta, qwen_image_params.theta,
qwen_image_params.axes_dim); qwen_image_params.axes_dim);
int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2; int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len); // LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len);
// pe->data = pe_vec.data(); // pe->data = pe_vec.data();
// print_ggml_tensor(pe); // print_ggml_tensor(pe, true, "pe");
// pe->data = NULL; // pe->data = NULL;
set_backend_tensor_data(pe, pe_vec.data()); set_backend_tensor_data(pe, pe_vec.data());
@ -534,7 +563,8 @@ namespace Qwen {
x, x,
timesteps, timesteps,
context, context,
pe); pe,
ref_latents);
ggml_build_forward_expand(gf, out); ggml_build_forward_expand(gf, out);
@ -545,13 +575,15 @@ namespace Qwen {
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* timesteps, struct ggml_tensor* timesteps,
struct ggml_tensor* context, struct ggml_tensor* context,
struct ggml_tensor** output = NULL, std::vector<ggml_tensor*> ref_latents = {},
struct ggml_context* output_ctx = NULL) { bool increase_ref_index = false,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
// x: [N, in_channels, h, w] // x: [N, in_channels, h, w]
// timesteps: [N, ] // timesteps: [N, ]
// context: [N, max_position, hidden_size] // context: [N, max_position, hidden_size]
auto get_graph = [&]() -> struct ggml_cgraph* { auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context); return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
}; };
GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@ -583,7 +615,7 @@ namespace Qwen {
struct ggml_tensor* out = NULL; struct ggml_tensor* out = NULL;
int t0 = ggml_time_ms(); int t0 = ggml_time_ms();
compute(8, x, timesteps, context, &out, work_ctx); compute(8, x, timesteps, context, {}, false, &out, work_ctx);
int t1 = ggml_time_ms(); int t1 = ggml_time_ms();
print_ggml_tensor(out); print_ggml_tensor(out);

View File

@ -151,17 +151,11 @@ struct Rope {
return flatten(emb); return flatten(emb);
} }
static std::vector<std::vector<float>> gen_flux_ids(int h, static std::vector<std::vector<float>> gen_refs_ids(int patch_size,
int w,
int patch_size,
int bs, int bs,
int context_len, const std::vector<ggml_tensor*>& ref_latents,
std::vector<ggml_tensor*> ref_latents,
bool increase_ref_index) { bool increase_ref_index) {
auto txt_ids = gen_txt_ids(bs, context_len); std::vector<std::vector<float>> ids;
auto img_ids = gen_img_ids(h, w, patch_size, bs);
auto ids = concat_ids(txt_ids, img_ids, bs);
uint64_t curr_h_offset = 0; uint64_t curr_h_offset = 0;
uint64_t curr_w_offset = 0; uint64_t curr_w_offset = 0;
int index = 1; int index = 1;
@ -189,13 +183,31 @@ struct Rope {
return ids; return ids;
} }
static std::vector<std::vector<float>> gen_flux_ids(int h,
int w,
int patch_size,
int bs,
int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
auto txt_ids = gen_txt_ids(bs, context_len);
auto img_ids = gen_img_ids(h, w, patch_size, bs);
auto ids = concat_ids(txt_ids, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index);
ids = concat_ids(ids, refs_ids, bs);
}
return ids;
}
// Generate flux positional embeddings // Generate flux positional embeddings
static std::vector<float> gen_flux_pe(int h, static std::vector<float> gen_flux_pe(int h,
int w, int w,
int patch_size, int patch_size,
int bs, int bs,
int context_len, int context_len,
std::vector<ggml_tensor*> ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index, bool increase_ref_index,
int theta, int theta,
const std::vector<int>& axes_dim) { const std::vector<int>& axes_dim) {
@ -207,7 +219,9 @@ struct Rope {
int w, int w,
int patch_size, int patch_size,
int bs, int bs,
int context_len) { int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) {
int h_len = (h + (patch_size / 2)) / patch_size; int h_len = (h + (patch_size / 2)) / patch_size;
int w_len = (w + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size;
int txt_id_start = std::max(h_len, w_len); int txt_id_start = std::max(h_len, w_len);
@ -220,6 +234,10 @@ struct Rope {
} }
auto img_ids = gen_img_ids(h, w, patch_size, bs); auto img_ids = gen_img_ids(h, w, patch_size, bs);
auto ids = concat_ids(txt_ids_repeated, img_ids, bs); auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
if (ref_latents.size() > 0) {
auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index);
ids = concat_ids(ids, refs_ids, bs);
}
return ids; return ids;
} }
@ -229,9 +247,11 @@ struct Rope {
int patch_size, int patch_size,
int bs, int bs,
int context_len, int context_len,
const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index,
int theta, int theta,
const std::vector<int>& axes_dim) { const std::vector<int>& axes_dim) {
std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len); std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
return embed_nd(ids, bs, theta, axes_dim); return embed_nd(ids, bs, theta, axes_dim);
} }