z-image-omni-base rope

This commit is contained in:
leejet 2025-12-25 01:35:11 +08:00
parent b0e6680add
commit 8004d32de2
3 changed files with 114 additions and 54 deletions

View File

@ -439,9 +439,9 @@ struct ZImageModel : public DiffusionModel {
return z_image.compute(n_threads, return z_image.compute(n_threads,
diffusion_params.x, diffusion_params.x,
diffusion_params.timesteps, diffusion_params.timesteps,
diffusion_params.context, {diffusion_params.context},
diffusion_params.ref_latents, diffusion_params.ref_latents,
true, // increase_ref_index {},
output, output,
output_ctx); output_ctx);
} }

115
rope.hpp
View File

@ -518,60 +518,117 @@ namespace Rope {
return (m - (a % m)) % m; return (m - (a % m)) % m;
} }
__STATIC_INLINE__ std::vector<std::vector<float>> gen_z_image_ids(int h, __STATIC_INLINE__ std::vector<std::vector<float>> gen_z_image_ids(ggml_tensor* x,
int w, const std::vector<ggml_tensor*>& contexts,
int patch_size,
int bs,
int context_len,
int seq_multi_of,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index) { const std::vector<ggml_tensor*>& siglip_feats,
int padded_context_len = context_len + bound_mod(context_len, seq_multi_of); int patch_size,
auto txt_ids = std::vector<std::vector<float>>(bs * padded_context_len, std::vector<float>(3, 0.0f)); int seq_multi_of,
for (int i = 0; i < bs * padded_context_len; i++) { int bs) {
txt_ids[i][0] = (i % padded_context_len) + 1.f; GGML_ASSERT(contexts.size() > ref_latents.size());
GGML_ASSERT(contexts.size() >= siglip_feats.size());
int context_cu_len = 1;
std::vector<int> context_end_pos;
std::vector<std::vector<float>> txt_ids;
for (auto context : contexts) {
int padded_context_len = context->ne[1] + bound_mod(context->ne[1], seq_multi_of);
auto curr_txt_ids = std::vector<std::vector<float>>(bs * padded_context_len, std::vector<float>(3, 0.0f));
for (int i = 0; i < bs * padded_context_len; i++) {
curr_txt_ids[i][0] = static_cast<float>((i % padded_context_len) + context_cu_len);
}
context_cu_len += padded_context_len;
context_end_pos.push_back(context_cu_len);
context_cu_len += 2; // for image and siglip tokens
txt_ids = concat_ids(txt_ids, curr_txt_ids, bs);
} }
int axes_dim_num = 3; std::vector<std::vector<float>> img_ids;
int index = padded_context_len + 1; std::vector<ggml_tensor*> all_img = ref_latents;
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, index); all_img.push_back(x);
for (int i = 0; i < all_img.size(); i++) {
int axes_dim_num = 3;
int index = context_end_pos[i];
auto curr_img_ids = gen_flux_img_ids(all_img[i]->ne[1], all_img[i]->ne[0], patch_size, bs, axes_dim_num, index);
int img_pad_len = bound_mod(static_cast<int>(img_ids.size() / bs), seq_multi_of); int img_pad_len = bound_mod(static_cast<int>(curr_img_ids.size() / bs), seq_multi_of);
if (img_pad_len > 0) { if (img_pad_len > 0) {
std::vector<std::vector<float>> img_pad_ids(bs * img_pad_len, std::vector<float>(3, 0.f)); std::vector<std::vector<float>> img_pad_ids(bs * img_pad_len, std::vector<float>(3, 0.f));
img_ids = concat_ids(img_ids, img_pad_ids, bs); curr_img_ids = concat_ids(curr_img_ids, img_pad_ids, bs);
}
img_ids = concat_ids(img_ids, curr_img_ids, bs);
}
std::vector<std::vector<float>> sig_ids;
for (int i = 0; i < siglip_feats.size(); i++) {
int axes_dim_num = 3;
int index = context_end_pos[i] + 1;
int h_len = siglip_feats[i]->ne[1];
int w_len = siglip_feats[i]->ne[0];
std::vector<std::vector<float>> curr_sig_ids(bs * h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
// scale position IDs to match img resolution
std::vector<float> row_ids = linspace<float>(0, all_img[i]->ne[1] - 1, h_len);
std::vector<float> col_ids = linspace<float>(0, all_img[i]->ne[0] - 1, w_len);
for (int ib = 0; ib < bs; ++ib) {
for (int ih = 0; ih < h_len; ++ih) {
for (int iw = 0; iw < w_len; ++iw) {
curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][0] = index;
curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][1] = row_ids[ih];
curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][2] = col_ids[iw];
}
}
}
int sig_pad_len = bound_mod(static_cast<int>(curr_sig_ids.size() / bs), seq_multi_of);
if (sig_pad_len > 0) {
std::vector<std::vector<float>> sig_pad_ids(bs * sig_pad_len, std::vector<float>(3, 0.f));
curr_sig_ids = concat_ids(curr_sig_ids, sig_pad_ids, bs);
}
sig_ids = concat_ids(sig_ids, curr_sig_ids, bs);
} }
auto ids = concat_ids(txt_ids, img_ids, bs); auto ids = concat_ids(txt_ids, img_ids, bs);
// ignore ref_latents for now if (!sig_ids.empty()) {
ids = concat_ids(ids, sig_ids, bs);
}
return ids; return ids;
} }
// Generate z_image positional embeddings // Generate z_image positional embeddings
__STATIC_INLINE__ std::vector<float> gen_z_image_pe(int h, __STATIC_INLINE__ std::vector<float> gen_z_image_pe(ggml_tensor* x,
int w, const std::vector<ggml_tensor*>& contexts,
int patch_size,
int bs,
int context_len,
int seq_multi_of,
const std::vector<ggml_tensor*>& ref_latents, const std::vector<ggml_tensor*>& ref_latents,
bool increase_ref_index, const std::vector<ggml_tensor*>& siglip_feats,
int patch_size,
int seq_multi_of,
int theta, int theta,
const std::vector<int>& axes_dim,
bool circular_h, bool circular_h,
bool circular_w, bool circular_w,
const std::vector<int>& axes_dim) { int bs) {
std::vector<std::vector<float>> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index); std::vector<std::vector<float>> ids = gen_z_image_ids(x, contexts, ref_latents, siglip_feats, patch_size, seq_multi_of, bs);
std::vector<std::vector<int>> wrap_dims; std::vector<std::vector<int>> wrap_dims;
if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
int context_len = 0;
for (auto context : contexts) {
int padded_context_len = context->ne[1] + bound_mod(context->ne[1], seq_multi_of);
context_len += padded_context_len;
}
int h = x->ne[1];
int w = x->ne[0];
int pad_h = (patch_size - (h % patch_size)) % patch_size; int pad_h = (patch_size - (h % patch_size)) % patch_size;
int pad_w = (patch_size - (w % patch_size)) % patch_size; int pad_w = (patch_size - (w % patch_size)) % patch_size;
int h_len = (h + pad_h) / patch_size; int h_len = (h + pad_h) / patch_size;
int w_len = (w + pad_w) / patch_size; int w_len = (w + pad_w) / patch_size;
if (h_len > 0 && w_len > 0) { if (h_len > 0 && w_len > 0) {
size_t pos_len = ids.size() / bs; size_t pos_len = ids.size() / bs;
wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0)); wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
size_t cursor = context_len + bound_mod(context_len, seq_multi_of); // skip text (and its padding) size_t cursor = context_len; // skip text (and its padding)
size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len); size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
for (size_t token_i = 0; token_i < img_tokens; ++token_i) { for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
if (circular_h) { if (circular_h) {

View File

@ -774,34 +774,37 @@ namespace ZImage {
z_image.get_param_tensors(tensors, prefix); z_image.get_param_tensors(tensors, prefix);
} }
struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_cgraph* build_graph(ggml_tensor* x,
struct ggml_tensor* timesteps, ggml_tensor* timesteps,
struct ggml_tensor* context, std::vector<ggml_tensor*> contexts,
std::vector<ggml_tensor*> ref_latents = {}, std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false) { std::vector<ggml_tensor*> siglip_feats = {}) {
GGML_ASSERT(x->ne[3] == 1); GGML_ASSERT(x->ne[3] == 1);
struct ggml_cgraph* gf = new_graph_custom(Z_IMAGE_GRAPH_SIZE); struct ggml_cgraph* gf = new_graph_custom(Z_IMAGE_GRAPH_SIZE);
x = to_backend(x); x = to_backend(x);
context = to_backend(context);
for (int i = 0; i < contexts.size(); i++) {
contexts[i] = to_backend(contexts[i]);
}
timesteps = to_backend(timesteps); timesteps = to_backend(timesteps);
for (int i = 0; i < ref_latents.size(); i++) { for (int i = 0; i < ref_latents.size(); i++) {
ref_latents[i] = to_backend(ref_latents[i]); ref_latents[i] = to_backend(ref_latents[i]);
} }
pe_vec = Rope::gen_z_image_pe(x->ne[1], pe_vec = Rope::gen_z_image_pe(x,
x->ne[0], contexts,
z_image_params.patch_size,
x->ne[3],
context->ne[1],
SEQ_MULTI_OF,
ref_latents, ref_latents,
increase_ref_index, siglip_feats,
z_image_params.patch_size,
SEQ_MULTI_OF,
z_image_params.theta, z_image_params.theta,
z_image_params.axes_dim,
circular_y_enabled, circular_y_enabled,
circular_x_enabled, circular_x_enabled,
z_image_params.axes_dim); x->ne[3]);
int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2; int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2;
// LOG_DEBUG("pos_len %d", pos_len); // LOG_DEBUG("pos_len %d", pos_len);
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, z_image_params.axes_dim_sum / 2, pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, z_image_params.axes_dim_sum / 2, pos_len);
@ -814,7 +817,7 @@ namespace ZImage {
struct ggml_tensor* out = z_image.forward(&runner_ctx, struct ggml_tensor* out = z_image.forward(&runner_ctx,
x, x,
timesteps, timesteps,
{context}, contexts,
pe, pe,
ref_latents); ref_latents);
@ -826,16 +829,16 @@ namespace ZImage {
bool compute(int n_threads, bool compute(int n_threads,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* timesteps, struct ggml_tensor* timesteps,
struct ggml_tensor* context, std::vector<ggml_tensor*> contexts,
std::vector<ggml_tensor*> ref_latents = {}, std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false, std::vector<ggml_tensor*> siglip_feats = {},
struct ggml_tensor** output = nullptr, struct ggml_tensor** output = nullptr,
struct ggml_context* output_ctx = nullptr) { struct ggml_context* output_ctx = nullptr) {
// x: [N, in_channels, h, w] // x: [N, in_channels, h, w]
// timesteps: [N, ] // timesteps: [N, ]
// context: [N, max_position, hidden_size] // context: [N, max_position, hidden_size]
auto get_graph = [&]() -> struct ggml_cgraph* { auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(x, timesteps, context, ref_latents, increase_ref_index); return build_graph(x, timesteps, contexts, ref_latents, siglip_feats);
}; };
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@ -867,7 +870,7 @@ namespace ZImage {
struct ggml_tensor* out = nullptr; struct ggml_tensor* out = nullptr;
int t0 = ggml_time_ms(); int t0 = ggml_time_ms();
compute(8, x, timesteps, context, {}, false, &out, work_ctx); compute(8, x, timesteps, {context}, {}, {}, &out, work_ctx);
int t1 = ggml_time_ms(); int t1 = ggml_time_ms();
print_ggml_tensor(out); print_ggml_tensor(out);