diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 69cd574..6411857 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -313,6 +313,8 @@ struct QwenImageModel : public DiffusionModel { diffusion_params.x, diffusion_params.timesteps, diffusion_params.context, + diffusion_params.ref_latents, + diffusion_params.increase_ref_index, output, output_ctx); } diff --git a/qwen_image.hpp b/qwen_image.hpp index 2f5dad8..4fb0e47 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -256,7 +256,7 @@ namespace Qwen { auto txt_gate1 = txt_mod_param_vec[2]; auto [img_attn_output, txt_attn_output] = attn->forward(ctx, backend, img_modulated, txt_modulated, pe); - + img = ggml_add(ctx, img, ggml_mul(ctx, img_attn_output, img_gate1)); txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn_output, txt_gate1)); @@ -386,6 +386,13 @@ namespace Qwen { return x; } + struct ggml_tensor* process_img(struct ggml_context* ctx, + struct ggml_tensor* x) { + x = pad_to_patch_size(ctx, x); + x = patchify(ctx, x); + return x; + } + struct ggml_tensor* unpatchify(struct ggml_context* ctx, struct ggml_tensor* x, int64_t h, @@ -446,7 +453,8 @@ namespace Qwen { struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + std::vector<ggml_tensor*> ref_latents = {}) { // Forward pass of DiT.
// x: [N, C, H, W] // timestep: [N,] @@ -459,13 +467,26 @@ int64_t C = x->ne[2]; int64_t N = x->ne[3]; - x = pad_to_patch_size(ctx, x); - x = patchify(ctx, x); + auto img = process_img(ctx, x); + uint64_t img_tokens = img->ne[1]; + + if (ref_latents.size() > 0) { + for (ggml_tensor* ref : ref_latents) { + ref = process_img(ctx, ref); + img = ggml_concat(ctx, img, ref, 1); + } + } int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size); int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size); - auto out = forward_orig(ctx, backend, x, timestep, context, pe); // [N, h_len*w_len, ph*pw*C] + auto out = forward_orig(ctx, backend, img, timestep, context, pe); // [N, h_len*w_len, ph*pw*C] + + if (out->ne[1] > img_tokens) { + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] + out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0); + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] + } out = unpatchify(ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] @@ -506,7 +527,9 @@ namespace Qwen { struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, - struct ggml_tensor* context) { + struct ggml_tensor* context, + std::vector<ggml_tensor*> ref_latents = {}, + bool increase_ref_index = false) { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, QWEN_IMAGE_GRAPH_SIZE, false); @@ -514,18 +537,24 @@ namespace Qwen { context = to_backend(context); timesteps = to_backend(timesteps); + for (int i = 0; i < ref_latents.size(); i++) { + ref_latents[i] = to_backend(ref_latents[i]); + } + pe_vec = Rope::gen_qwen_image_pe(x->ne[1], x->ne[0], qwen_image_params.patch_size, x->ne[3], context->ne[1], + ref_latents, + increase_ref_index, qwen_image_params.theta, qwen_image_params.axes_dim); int pos_len = pe_vec.size() / qwen_image_params.axes_dim_sum / 2; //
LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, qwen_image_params.axes_dim_sum / 2, pos_len); // pe->data = pe_vec.data(); - // print_ggml_tensor(pe); + // print_ggml_tensor(pe, true, "pe"); // pe->data = NULL; set_backend_tensor_data(pe, pe_vec.data()); @@ -534,7 +563,8 @@ namespace Qwen { x, timesteps, context, - pe); + pe, + ref_latents); ggml_build_forward_expand(gf, out); @@ -545,13 +575,15 @@ struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { + std::vector<ggml_tensor*> ref_latents = {}, + bool increase_ref_index = false, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context); + return build_graph(x, timesteps, context, ref_latents, increase_ref_index); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -583,7 +615,7 @@ struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); - compute(8, x, timesteps, context, &out, work_ctx); + compute(8, x, timesteps, context, {}, false, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); diff --git a/rope.hpp b/rope.hpp index 5e3aaf9..8ecd818 100644 --- a/rope.hpp +++ b/rope.hpp @@ -151,17 +151,11 @@ struct Rope { return flatten(emb); } - static std::vector<std::vector<float>> gen_flux_ids(int h, - int w, - int patch_size, + static std::vector<std::vector<float>> gen_refs_ids(int patch_size, int bs, - int context_len, - std::vector<ggml_tensor*> ref_latents, + const std::vector<ggml_tensor*>& ref_latents, bool increase_ref_index) { - auto txt_ids = gen_txt_ids(bs, context_len); - auto img_ids = gen_img_ids(h, w, patch_size, bs); - - auto ids = concat_ids(txt_ids, img_ids, bs); + std::vector<std::vector<float>> ids; uint64_t curr_h_offset = 0; uint64_t curr_w_offset = 0; int index
= 1; @@ -189,13 +183,31 @@ struct Rope { return ids; } + static std::vector<std::vector<float>> gen_flux_ids(int h, + int w, + int patch_size, + int bs, + int context_len, + const std::vector<ggml_tensor*>& ref_latents, + bool increase_ref_index) { + auto txt_ids = gen_txt_ids(bs, context_len); + auto img_ids = gen_img_ids(h, w, patch_size, bs); + + auto ids = concat_ids(txt_ids, img_ids, bs); + if (ref_latents.size() > 0) { + auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index); + ids = concat_ids(ids, refs_ids, bs); + } + return ids; + } + // Generate flux positional embeddings static std::vector<float> gen_flux_pe(int h, int w, int patch_size, int bs, int context_len, - std::vector<ggml_tensor*> ref_latents, + const std::vector<ggml_tensor*>& ref_latents, bool increase_ref_index, int theta, const std::vector<int>& axes_dim) { @@ -207,7 +219,9 @@ struct Rope { int w, int patch_size, int bs, - int context_len) { + int context_len, + const std::vector<ggml_tensor*>& ref_latents, + bool increase_ref_index) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; int txt_id_start = std::max(h_len, w_len); @@ -220,6 +234,10 @@ struct Rope { } auto img_ids = gen_img_ids(h, w, patch_size, bs); auto ids = concat_ids(txt_ids_repeated, img_ids, bs); + if (ref_latents.size() > 0) { + auto refs_ids = gen_refs_ids(patch_size, bs, ref_latents, increase_ref_index); + ids = concat_ids(ids, refs_ids, bs); + } return ids; } @@ -229,9 +247,11 @@ struct Rope { int patch_size, int bs, int context_len, + const std::vector<ggml_tensor*>& ref_latents, + bool increase_ref_index, int theta, const std::vector<int>& axes_dim) { - std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len); + std::vector<std::vector<float>> ids = gen_qwen_image_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index); return embed_nd(ids, bs, theta, axes_dim); }