From 8004d32de243dbe5044a85f6f766d5e240b2c262 Mon Sep 17 00:00:00 2001 From: leejet Date: Thu, 25 Dec 2025 01:35:11 +0800 Subject: [PATCH] z-image-omni-base rope --- diffusion_model.hpp | 4 +- rope.hpp | 115 +++++++++++++++++++++++++++++++++----------- z_image.hpp | 49 ++++++++++--------- 3 files changed, 114 insertions(+), 54 deletions(-) diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 06cbecc..3f735ac 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -439,9 +439,9 @@ struct ZImageModel : public DiffusionModel { return z_image.compute(n_threads, diffusion_params.x, diffusion_params.timesteps, - diffusion_params.context, + {diffusion_params.context}, diffusion_params.ref_latents, - true, // increase_ref_index + {}, output, output_ctx); } diff --git a/rope.hpp b/rope.hpp index 4e6136c..eb53260 100644 --- a/rope.hpp +++ b/rope.hpp @@ -518,60 +518,117 @@ namespace Rope { return (m - (a % m)) % m; } - __STATIC_INLINE__ std::vector> gen_z_image_ids(int h, - int w, - int patch_size, - int bs, - int context_len, - int seq_multi_of, + __STATIC_INLINE__ std::vector> gen_z_image_ids(ggml_tensor* x, + const std::vector& contexts, const std::vector& ref_latents, - bool increase_ref_index) { - int padded_context_len = context_len + bound_mod(context_len, seq_multi_of); - auto txt_ids = std::vector>(bs * padded_context_len, std::vector(3, 0.0f)); - for (int i = 0; i < bs * padded_context_len; i++) { - txt_ids[i][0] = (i % padded_context_len) + 1.f; + const std::vector& siglip_feats, + int patch_size, + int seq_multi_of, + int bs) { + GGML_ASSERT(contexts.size() > ref_latents.size()); + GGML_ASSERT(contexts.size() >= siglip_feats.size()); + int context_cu_len = 1; + std::vector context_end_pos; + std::vector> txt_ids; + for (auto context : contexts) { + int padded_context_len = context->ne[1] + bound_mod(context->ne[1], seq_multi_of); + auto curr_txt_ids = std::vector>(bs * padded_context_len, std::vector(3, 0.0f)); + for (int i = 0; i < bs * padded_context_len; i++) { + curr_txt_ids[i][0] = static_cast((i % padded_context_len) + context_cu_len); + } + context_cu_len += padded_context_len; + context_end_pos.push_back(context_cu_len); + context_cu_len += 2; // for image and siglip tokens + txt_ids = concat_ids(txt_ids, curr_txt_ids, bs); } - int axes_dim_num = 3; - int index = padded_context_len + 1; - auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, index); + std::vector> img_ids; + std::vector all_img = ref_latents; + all_img.push_back(x); + for (int i = 0; i < all_img.size(); i++) { + int axes_dim_num = 3; + int index = context_end_pos[i]; + auto curr_img_ids = gen_flux_img_ids(all_img[i]->ne[1], all_img[i]->ne[0], patch_size, bs, axes_dim_num, index); - int img_pad_len = bound_mod(static_cast(img_ids.size() / bs), seq_multi_of); - if (img_pad_len > 0) { - std::vector> img_pad_ids(bs * img_pad_len, std::vector(3, 0.f)); - img_ids = concat_ids(img_ids, img_pad_ids, bs); + int img_pad_len = bound_mod(static_cast(curr_img_ids.size() / bs), seq_multi_of); + if (img_pad_len > 0) { + std::vector> img_pad_ids(bs * img_pad_len, std::vector(3, 0.f)); + curr_img_ids = concat_ids(curr_img_ids, img_pad_ids, bs); + } + img_ids = concat_ids(img_ids, curr_img_ids, bs); + } + + std::vector> sig_ids; + for (int i = 0; i < siglip_feats.size(); i++) { + int axes_dim_num = 3; + int index = context_end_pos[i] + 1; + int h_len = siglip_feats[i]->ne[1]; + int w_len = siglip_feats[i]->ne[0]; + + std::vector> curr_sig_ids(bs * h_len * w_len, std::vector(axes_dim_num, 0.0)); + + // scale position IDs to match img resolution + std::vector row_ids = linspace(0, all_img[i]->ne[1] - 1, h_len); + std::vector col_ids = linspace(0, all_img[i]->ne[0] - 1, w_len); + + for (int ib = 0; ib < bs; ++ib) { + for (int ih = 0; ih < h_len; ++ih) { + for (int iw = 0; iw < w_len; ++iw) { + curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][0] = index; + curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][1] = row_ids[ih]; + curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][2] = col_ids[iw]; + } + } + } + + int sig_pad_len = bound_mod(static_cast(curr_sig_ids.size() / bs), seq_multi_of); + if (sig_pad_len > 0) { + std::vector> sig_pad_ids(bs * sig_pad_len, std::vector(3, 0.f)); + curr_sig_ids = concat_ids(curr_sig_ids, sig_pad_ids, bs); + } + sig_ids = concat_ids(sig_ids, curr_sig_ids, bs); } auto ids = concat_ids(txt_ids, img_ids, bs); - // ignore ref_latents for now + if (!sig_ids.empty()) { + ids = concat_ids(ids, sig_ids, bs); + } + return ids; } // Generate z_image positional embeddings - __STATIC_INLINE__ std::vector gen_z_image_pe(int h, - int w, - int patch_size, - int bs, - int context_len, - int seq_multi_of, + __STATIC_INLINE__ std::vector gen_z_image_pe(ggml_tensor* x, + const std::vector& contexts, const std::vector& ref_latents, - bool increase_ref_index, + const std::vector& siglip_feats, + int patch_size, + int seq_multi_of, int theta, + const std::vector& axes_dim, bool circular_h, bool circular_w, - const std::vector& axes_dim) { - std::vector> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index); + int bs) { + std::vector> ids = gen_z_image_ids(x, contexts, ref_latents, siglip_feats, patch_size, seq_multi_of, bs); std::vector> wrap_dims; if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { + int context_len = 0; + for (auto context : contexts) { + int padded_context_len = context->ne[1] + bound_mod(context->ne[1], seq_multi_of); + context_len += padded_context_len; + } + int h = x->ne[1]; + int w = x->ne[0]; int pad_h = (patch_size - (h % patch_size)) % patch_size; int pad_w = (patch_size - (w % patch_size)) % patch_size; int h_len = (h + pad_h) / patch_size; int w_len = (w + pad_w) / patch_size; + if (h_len > 0 && w_len > 0) { size_t pos_len = ids.size() / bs; wrap_dims.assign(axes_dim.size(), std::vector(pos_len, 0)); - size_t cursor = context_len + bound_mod(context_len, seq_multi_of); // skip text (and its padding) + size_t cursor = context_len; // skip text (and its padding) size_t img_tokens = static_cast(h_len) * static_cast(w_len); for (size_t token_i = 0; token_i < img_tokens; ++token_i) { if (circular_h) { diff --git a/z_image.hpp b/z_image.hpp index b3765dd..715b2d2 100644 --- a/z_image.hpp +++ b/z_image.hpp @@ -774,34 +774,37 @@ namespace ZImage { z_image.get_param_tensors(tensors, prefix); } - struct ggml_cgraph* build_graph(struct ggml_tensor* x, - struct ggml_tensor* timesteps, - struct ggml_tensor* context, - std::vector ref_latents = {}, - bool increase_ref_index = false) { + struct ggml_cgraph* build_graph(ggml_tensor* x, + ggml_tensor* timesteps, + std::vector contexts, + std::vector ref_latents = {}, + std::vector siglip_feats = {}) { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = new_graph_custom(Z_IMAGE_GRAPH_SIZE); - x = to_backend(x); - context = to_backend(context); + x = to_backend(x); + + for (int i = 0; i < contexts.size(); i++) { + contexts[i] = to_backend(contexts[i]); + } + timesteps = to_backend(timesteps); for (int i = 0; i < ref_latents.size(); i++) { ref_latents[i] = to_backend(ref_latents[i]); } - pe_vec = Rope::gen_z_image_pe(x->ne[1], - x->ne[0], - z_image_params.patch_size, - x->ne[3], - context->ne[1], - SEQ_MULTI_OF, + pe_vec = Rope::gen_z_image_pe(x, + contexts, ref_latents, - increase_ref_index, + siglip_feats, + z_image_params.patch_size, + SEQ_MULTI_OF, z_image_params.theta, + z_image_params.axes_dim, circular_y_enabled, circular_x_enabled, - z_image_params.axes_dim); + x->ne[3]); int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, z_image_params.axes_dim_sum / 2, pos_len); @@ -814,7 +817,7 @@ namespace ZImage { struct ggml_tensor* out = z_image.forward(&runner_ctx, x, timesteps, - {context}, + contexts, pe, ref_latents); @@ -826,16 +829,16 @@ namespace ZImage { bool compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, - struct ggml_tensor* context, - std::vector ref_latents = {}, - bool increase_ref_index = false, - struct ggml_tensor** output = nullptr, - struct ggml_context* output_ctx = nullptr) { + std::vector contexts, + std::vector ref_latents = {}, + std::vector siglip_feats = {}, + struct ggml_tensor** output = nullptr, + struct ggml_context* output_ctx = nullptr) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, ref_latents, increase_ref_index); + return build_graph(x, timesteps, contexts, ref_latents, siglip_feats); }; return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -867,7 +870,7 @@ namespace ZImage { struct ggml_tensor* out = nullptr; int t0 = ggml_time_ms(); - compute(8, x, timesteps, context, {}, false, &out, work_ctx); + compute(8, x, timesteps, {context}, {}, {}, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out);