add ref latent support for qwen image

leejet 2025-09-23 23:34:51 +08:00
parent 94f4f295c1
commit 4e48e6b82b
3 changed files with 78 additions and 24 deletions

View File

@@ -313,6 +313,8 @@ struct QwenImageModel : public DiffusionModel {
                          diffusion_params.x,
                          diffusion_params.timesteps,
                          diffusion_params.context,
+                         diffusion_params.ref_latents,
+                         diffusion_params.increase_ref_index,
                          output,
                          output_ctx);
     }
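
Note: the hunk above relies on two new members on the shared diffusion-parameter struct. A minimal sketch of what that struct plausibly gains; only the ref_latents and increase_ref_index names are confirmed by this call site, everything else is illustrative:

    // Hedged sketch, not the repo's exact definition.
    struct DiffusionParams {
        struct ggml_tensor* x         = NULL;   // noised latent [N, C, H, W]
        struct ggml_tensor* timesteps = NULL;   // [N,]
        struct ggml_tensor* context   = NULL;   // text conditioning
        std::vector<ggml_tensor*> ref_latents;  // optional reference-image latents
        bool increase_ref_index = false;        // give each reference its own RoPE index
        // ...
    };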

View File

@@ -256,7 +256,7 @@ namespace Qwen {
            auto txt_gate1 = txt_mod_param_vec[2];
            auto [img_attn_output, txt_attn_output] = attn->forward(ctx, backend, img_modulated, txt_modulated, pe);
            img = ggml_add(ctx, img, ggml_mul(ctx, img_attn_output, img_gate1));
            txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn_output, txt_gate1));
@@ -386,6 +386,13 @@ namespace Qwen {
            return x;
        }
+       struct ggml_tensor* process_img(struct ggml_context* ctx,
+                                       struct ggml_tensor* x) {
+           x = pad_to_patch_size(ctx, x);
+           x = patchify(ctx, x);
+           return x;
+       }
        struct ggml_tensor* unpatchify(struct ggml_context* ctx,
                                       struct ggml_tensor* x,
                                       int64_t h,
@@ -446,7 +453,8 @@ namespace Qwen {
                                    struct ggml_tensor* x,
                                    struct ggml_tensor* timestep,
                                    struct ggml_tensor* context,
-                                   struct ggml_tensor* pe) {
+                                   struct ggml_tensor* pe,
+                                   std::vector<ggml_tensor*> ref_latents = {}) {
            // Forward pass of DiT.
            // x: [N, C, H, W]
            // timestep: [N,]
@@ -459,13 +467,26 @@ namespace Qwen {
            int64_t C = x->ne[2];
            int64_t N = x->ne[3];
-           x = pad_to_patch_size(ctx, x);
-           x = patchify(ctx, x);
+           auto img = process_img(ctx, x);
+           uint64_t img_tokens = img->ne[1];
+           if (ref_latents.size() > 0) {
+               for (ggml_tensor* ref : ref_latents) {
+                   ref = process_img(ctx, ref);
+                   img = ggml_concat(ctx, img, ref, 1);
+               }
+           }
            int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
            int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);
-           auto out = forward_orig(ctx, backend, x, timestep, context, pe);  // [N, h_len*w_len, ph*pw*C]
+           auto out = forward_orig(ctx, backend, img, timestep, context, pe);  // [N, h_len*w_len, ph*pw*C]
+           if (out->ne[1] > img_tokens) {
+               out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));  // [num_tokens, N, C * patch_size * patch_size]
+               out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
+               out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));  // [N, h*w, C * patch_size * patch_size]
+           }
            out = unpatchify(ctx, out, h_len, w_len);  // [N, C, H + pad_h, W + pad_w]
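
This hunk is the heart of the change: every reference latent is run through the same pad_to_patch_size + patchify path as the main latent (via the new process_img helper), concatenated along the token dimension, and after forward_orig the extra rows are sliced off again so unpatchify only ever sees the main image. A standalone sketch of the token bookkeeping, plain C++ without ggml and with hypothetical sizes:

    #include <cstdint>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Same rounding as the h_len/w_len computation above.
    static int64_t n_tokens(int64_t h, int64_t w, int64_t patch_size) {
        int64_t h_len = (h + patch_size / 2) / patch_size;
        int64_t w_len = (w + patch_size / 2) / patch_size;
        return h_len * w_len;
    }

    int main() {
        const int64_t patch_size = 2;
        int64_t img_tokens = n_tokens(128, 128, patch_size);  // main latent
        // hypothetical reference latents of differing sizes
        std::vector<std::pair<int64_t, int64_t>> refs = {{128, 128}, {96, 64}};
        int64_t total = img_tokens;
        for (const auto& r : refs)
            total += n_tokens(r.first, r.second, patch_size);  // ggml_concat along dim 1
        // forward_orig attends over `total` tokens; the ggml_view_3d above keeps
        // only the first `img_tokens` rows before unpatchify.
        printf("img_tokens=%lld total=%lld\n", (long long)img_tokens, (long long)total);
        return 0;
    }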
@@ -506,7 +527,9 @@ namespace Qwen {
        struct ggml_cgraph* build_graph(struct ggml_tensor* x,
                                        struct ggml_tensor* timesteps,
-                                       struct ggml_tensor* context) {
+                                       struct ggml_tensor* context,
+                                       std::vector<ggml_tensor*> ref_latents = {},
+                                       bool increase_ref_index = false) {
            GGML_ASSERT(x->ne[3] == 1);
            struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWEN_IMAGE_GRAPH_SIZE, false);
@@ -514,18 +537,24 @@ namespace Qwen {
            context = to_backend(context);
            timesteps = to_backend(timesteps);
+           for (int i = 0; i < ref_latents.size(); i++) {
+               ref_latents[i] = to_backend(ref_latents[i]);
+           }
            pe_vec = Rope::gen_qwen_image_pe(x->ne[1],
                                             x->ne[0],
                                             qwen_image_params.patch_size,
                                             x->ne[3],
                                             context->ne[1],
+                                            ref_latents,
+                                            increase_ref_index,
                                             qwen_image_params.theta,
                                             qwen_image_params.axes_dim);
            int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2;
            // LOG_DEBUG("pos_len %d", pos_len);
            auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len);
            // pe->data = pe_vec.data();
-           // print_ggml_tensor(pe);
+           // print_ggml_tensor(pe, true, "pe");
            // pe->data = NULL;
            set_backend_tensor_data(pe, pe_vec.data());
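
With references included, gen_qwen_image_pe now emits ids for the context tokens, the main image tokens, and every reference's tokens, so pos_len and the pe tensor grow accordingly. The pe tensor above holds 2 * 2 * (axes_dim_sum / 2) * pos_len floats, which is exactly why pos_len = pe_vec.size() / axes_dim_sum / 2. A small hedged sanity check:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Recover pos_len from the flattened embedding and verify that the layout
    // [2, 2, axes_dim_sum / 2, pos_len] accounts for every float.
    static int pe_pos_len(const std::vector<float>& pe_vec, int axes_dim_sum) {
        int pos_len = (int)(pe_vec.size() / axes_dim_sum / 2);
        assert((size_t)2 * 2 * (axes_dim_sum / 2) * pos_len == pe_vec.size());
        return pos_len;
    }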
@@ -534,7 +563,8 @@ namespace Qwen {
                                         x,
                                         timesteps,
                                         context,
-                                        pe);
+                                        pe,
+                                        ref_latents);
            ggml_build_forward_expand(gf, out);
@@ -545,13 +575,15 @@ namespace Qwen {
                     struct ggml_tensor* x,
                     struct ggml_tensor* timesteps,
                     struct ggml_tensor* context,
-                    struct ggml_tensor** output = NULL,
-                    struct ggml_context* output_ctx = NULL) {
+                    std::vector<ggml_tensor*> ref_latents = {},
+                    bool increase_ref_index = false,
+                    struct ggml_tensor** output = NULL,
+                    struct ggml_context* output_ctx = NULL) {
            // x: [N, in_channels, h, w]
            // timesteps: [N, ]
            // context: [N, max_position, hidden_size]
            auto get_graph = [&]() -> struct ggml_cgraph* {
-               return build_graph(x, timesteps, context);
+               return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
            };
            GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
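
Because ref_latents and increase_ref_index are inserted ahead of output and output_ctx rather than appended, every existing positional caller of compute() has to be updated in the same commit; the timing harness in the next hunk does exactly that, passing {} and false explicitly.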
@@ -583,7 +615,7 @@ namespace Qwen {
            struct ggml_tensor* out = NULL;
            int t0 = ggml_time_ms();
-           compute(8, x, timesteps, context, &out, work_ctx);
+           compute(8, x, timesteps, context, {}, false, &out, work_ctx);
            int t1 = ggml_time_ms();
            print_ggml_tensor(out);

View File

@@ -151,17 +151,11 @@ struct Rope {
        return flatten(emb);
    }
-   static std::vector<std::vector<float>> gen_flux_ids(int h,
-                                                       int w,
-                                                       int patch_size,
+   static std::vector<std::vector<float>> gen_refs_ids(int patch_size,
                                                        int bs,
-                                                       int context_len,
-                                                       std::vector<ggml_tensor*> ref_latents,
+                                                       const std::vector<ggml_tensor*>& ref_latents,
                                                        bool increase_ref_index) {
-       auto txt_ids = gen_txt_ids(bs, context_len);
-       auto img_ids = gen_img_ids(h, w, patch_size, bs);
-       auto ids = concat_ids(txt_ids, img_ids, bs);
+       std::vector<std::vector<float>> ids;
        uint64_t curr_h_offset = 0;
        uint64_t curr_w_offset = 0;
        int index = 1;
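
The body of the per-reference loop falls outside this hunk. Judging from the state that is visible (curr_h_offset, curr_w_offset, index) and from how flux reference latents already work, it plausibly behaves as sketched below; note that the gen_img_ids overload taking an index and spatial offsets is an assumption, not something this diff shows:

    // Plausible reconstruction only; the actual loop body is not in the diff.
    for (ggml_tensor* ref : ref_latents) {
        int64_t h = ref->ne[1];
        int64_t w = ref->ne[0];
        uint64_t h_offset = 0;
        uint64_t w_offset = 0;
        int ref_index = 1;
        if (increase_ref_index) {
            ref_index = index++;       // each reference gets its own id plane
        } else {
            h_offset = curr_h_offset;  // shared index; offset spatially instead
            w_offset = curr_w_offset;
        }
        // hypothetical overload: gen_img_ids(h, w, patch_size, bs, index, h_off, w_off)
        auto ref_ids = gen_img_ids(h, w, patch_size, bs, ref_index, h_offset, w_offset);
        ids = concat_ids(ids, ref_ids, bs);
        curr_h_offset = std::max(curr_h_offset, h_offset + (uint64_t)h);
        curr_w_offset = std::max(curr_w_offset, w_offset + (uint64_t)w);
    }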
@@ -189,13 +183,31 @@ struct Rope {
        return ids;
    }
+   static std::vector<std::vector<float>> gen_flux_ids(int h,
+                                                       int w,
+                                                       int patch_size,
+                                                       int bs,
+                                                       int context_len,
+                                                       const std::vector<ggml_tensor*>& ref_latents,
+                                                       bool increase_ref_index) {
+       auto txt_ids = gen_txt_ids(bs, context_len);
+       auto img_ids = gen_img_ids(h, w, patch_size, bs);
+       auto ids = concat_ids(txt_ids, img_ids, bs);
+       if (ref_latents.size() > 0) {
+           auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index);
+           ids = concat_ids(ids, refs_ids, bs);
+       }
+       return ids;
+   }
    // Generate flux positional embeddings
    static std::vector<float> gen_flux_pe(int h,
                                          int w,
                                          int patch_size,
                                          int bs,
                                          int context_len,
-                                         std::vector<ggml_tensor*> ref_latents,
+                                         const std::vector<ggml_tensor*>& ref_latents,
                                          bool increase_ref_index,
                                          int theta,
                                          const std::vector<int>& axes_dim) {
@@ -207,7 +219,9 @@ struct Rope {
                                                   int w,
                                                   int patch_size,
                                                   int bs,
-                                                  int context_len) {
+                                                  int context_len,
+                                                  const std::vector<ggml_tensor*>& ref_latents,
+                                                  bool increase_ref_index) {
        int h_len = (h + (patch_size / 2)) / patch_size;
        int w_len = (w + (patch_size / 2)) / patch_size;
        int txt_id_start = std::max(h_len, w_len);
@@ -220,6 +234,10 @@ struct Rope {
        }
        auto img_ids = gen_img_ids(h, w, patch_size, bs);
        auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
+       if (ref_latents.size() > 0) {
+           auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index);
+           ids = concat_ids(ids, refs_ids, bs);
+       }
        return ids;
    }
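
Worth noting: gen_qwen_image_ids lays out the (repeated) text ids first, starting at txt_id_start = std::max(h_len, w_len), then the main image ids, and now appends the reference ids last. That ordering matches the token sequence the transformer actually sees (context, main image, concatenated references), which is what lets QwenImageModel::forward assume the first img_tokens image rows belong to the main latent when it slices the output.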
@@ -229,9 +247,11 @@ struct Rope {
                                                    int patch_size,
                                                    int bs,
                                                    int context_len,
+                                                   const std::vector<ggml_tensor*>& ref_latents,
+                                                   bool increase_ref_index,
                                                    int theta,
                                                    const std::vector<int>& axes_dim) {
-       std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len);
+       std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index);
        return embed_nd(ids, bs, theta, axes_dim);
    }
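
For reference, a hedged example of calling the new signature (it mirrors the build_graph call earlier in this commit; the concrete values, including the axes_dim split, are illustrative rather than taken from this diff):

    std::vector<ggml_tensor*> ref_latents = { /* patchify-compatible latents */ };
    std::vector<float> pe_vec = Rope::gen_qwen_image_pe(
        /* h                  */ 128,
        /* w                  */ 128,
        /* patch_size         */ 2,
        /* bs                 */ 1,
        /* context_len        */ 256,
        ref_latents,
        /* increase_ref_index */ false,
        /* theta              */ 10000,
        /* axes_dim           */ {16, 56, 56});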