mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-02-04 10:53:34 +00:00
z-image-omni-base rope
This commit is contained in:
parent
b0e6680add
commit
8004d32de2
@ -439,9 +439,9 @@ struct ZImageModel : public DiffusionModel {
|
||||
return z_image.compute(n_threads,
|
||||
diffusion_params.x,
|
||||
diffusion_params.timesteps,
|
||||
diffusion_params.context,
|
||||
{diffusion_params.context},
|
||||
diffusion_params.ref_latents,
|
||||
true, // increase_ref_index
|
||||
{},
|
||||
output,
|
||||
output_ctx);
|
||||
}
|
||||
|
||||
115
rope.hpp
115
rope.hpp
@ -518,60 +518,117 @@ namespace Rope {
|
||||
return (m - (a % m)) % m;
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ std::vector<std::vector<float>> gen_z_image_ids(int h,
|
||||
int w,
|
||||
int patch_size,
|
||||
int bs,
|
||||
int context_len,
|
||||
int seq_multi_of,
|
||||
__STATIC_INLINE__ std::vector<std::vector<float>> gen_z_image_ids(ggml_tensor* x,
|
||||
const std::vector<ggml_tensor*>& contexts,
|
||||
const std::vector<ggml_tensor*>& ref_latents,
|
||||
bool increase_ref_index) {
|
||||
int padded_context_len = context_len + bound_mod(context_len, seq_multi_of);
|
||||
auto txt_ids = std::vector<std::vector<float>>(bs * padded_context_len, std::vector<float>(3, 0.0f));
|
||||
for (int i = 0; i < bs * padded_context_len; i++) {
|
||||
txt_ids[i][0] = (i % padded_context_len) + 1.f;
|
||||
const std::vector<ggml_tensor*>& siglip_feats,
|
||||
int patch_size,
|
||||
int seq_multi_of,
|
||||
int bs) {
|
||||
GGML_ASSERT(contexts.size() > ref_latents.size());
|
||||
GGML_ASSERT(contexts.size() >= siglip_feats.size());
|
||||
int context_cu_len = 1;
|
||||
std::vector<int> context_end_pos;
|
||||
std::vector<std::vector<float>> txt_ids;
|
||||
for (auto context : contexts) {
|
||||
int padded_context_len = context->ne[1] + bound_mod(context->ne[1], seq_multi_of);
|
||||
auto curr_txt_ids = std::vector<std::vector<float>>(bs * padded_context_len, std::vector<float>(3, 0.0f));
|
||||
for (int i = 0; i < bs * padded_context_len; i++) {
|
||||
curr_txt_ids[i][0] = static_cast<float>((i % padded_context_len) + context_cu_len);
|
||||
}
|
||||
context_cu_len += padded_context_len;
|
||||
context_end_pos.push_back(context_cu_len);
|
||||
context_cu_len += 2; // for image and siglip tokens
|
||||
txt_ids = concat_ids(txt_ids, curr_txt_ids, bs);
|
||||
}
|
||||
|
||||
int axes_dim_num = 3;
|
||||
int index = padded_context_len + 1;
|
||||
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, index);
|
||||
std::vector<std::vector<float>> img_ids;
|
||||
std::vector<ggml_tensor*> all_img = ref_latents;
|
||||
all_img.push_back(x);
|
||||
for (int i = 0; i < all_img.size(); i++) {
|
||||
int axes_dim_num = 3;
|
||||
int index = context_end_pos[i];
|
||||
auto curr_img_ids = gen_flux_img_ids(all_img[i]->ne[1], all_img[i]->ne[0], patch_size, bs, axes_dim_num, index);
|
||||
|
||||
int img_pad_len = bound_mod(static_cast<int>(img_ids.size() / bs), seq_multi_of);
|
||||
if (img_pad_len > 0) {
|
||||
std::vector<std::vector<float>> img_pad_ids(bs * img_pad_len, std::vector<float>(3, 0.f));
|
||||
img_ids = concat_ids(img_ids, img_pad_ids, bs);
|
||||
int img_pad_len = bound_mod(static_cast<int>(curr_img_ids.size() / bs), seq_multi_of);
|
||||
if (img_pad_len > 0) {
|
||||
std::vector<std::vector<float>> img_pad_ids(bs * img_pad_len, std::vector<float>(3, 0.f));
|
||||
curr_img_ids = concat_ids(curr_img_ids, img_pad_ids, bs);
|
||||
}
|
||||
img_ids = concat_ids(img_ids, curr_img_ids, bs);
|
||||
}
|
||||
|
||||
std::vector<std::vector<float>> sig_ids;
|
||||
for (int i = 0; i < siglip_feats.size(); i++) {
|
||||
int axes_dim_num = 3;
|
||||
int index = context_end_pos[i] + 1;
|
||||
int h_len = siglip_feats[i]->ne[1];
|
||||
int w_len = siglip_feats[i]->ne[0];
|
||||
|
||||
std::vector<std::vector<float>> curr_sig_ids(bs * h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
|
||||
|
||||
// scale position IDs to match img resolution
|
||||
std::vector<float> row_ids = linspace<float>(0, all_img[i]->ne[1] - 1, h_len);
|
||||
std::vector<float> col_ids = linspace<float>(0, all_img[i]->ne[0] - 1, w_len);
|
||||
|
||||
for (int ib = 0; ib < bs; ++ib) {
|
||||
for (int ih = 0; ih < h_len; ++ih) {
|
||||
for (int iw = 0; iw < w_len; ++iw) {
|
||||
curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][0] = index;
|
||||
curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][1] = row_ids[ih];
|
||||
curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][2] = col_ids[iw];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int sig_pad_len = bound_mod(static_cast<int>(curr_sig_ids.size() / bs), seq_multi_of);
|
||||
if (sig_pad_len > 0) {
|
||||
std::vector<std::vector<float>> sig_pad_ids(bs * sig_pad_len, std::vector<float>(3, 0.f));
|
||||
curr_sig_ids = concat_ids(curr_sig_ids, sig_pad_ids, bs);
|
||||
}
|
||||
sig_ids = concat_ids(sig_ids, curr_sig_ids, bs);
|
||||
}
|
||||
|
||||
auto ids = concat_ids(txt_ids, img_ids, bs);
|
||||
|
||||
// ignore ref_latents for now
|
||||
if (!sig_ids.empty()) {
|
||||
ids = concat_ids(ids, sig_ids, bs);
|
||||
}
|
||||
|
||||
return ids;
|
||||
}
|
||||
|
||||
// Generate z_image positional embeddings
|
||||
__STATIC_INLINE__ std::vector<float> gen_z_image_pe(int h,
|
||||
int w,
|
||||
int patch_size,
|
||||
int bs,
|
||||
int context_len,
|
||||
int seq_multi_of,
|
||||
__STATIC_INLINE__ std::vector<float> gen_z_image_pe(ggml_tensor* x,
|
||||
const std::vector<ggml_tensor*>& contexts,
|
||||
const std::vector<ggml_tensor*>& ref_latents,
|
||||
bool increase_ref_index,
|
||||
const std::vector<ggml_tensor*>& siglip_feats,
|
||||
int patch_size,
|
||||
int seq_multi_of,
|
||||
int theta,
|
||||
const std::vector<int>& axes_dim,
|
||||
bool circular_h,
|
||||
bool circular_w,
|
||||
const std::vector<int>& axes_dim) {
|
||||
std::vector<std::vector<float>> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index);
|
||||
int bs) {
|
||||
std::vector<std::vector<float>> ids = gen_z_image_ids(x, contexts, ref_latents, siglip_feats, patch_size, seq_multi_of, bs);
|
||||
std::vector<std::vector<int>> wrap_dims;
|
||||
if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
|
||||
int context_len = 0;
|
||||
for (auto context : contexts) {
|
||||
int padded_context_len = context->ne[1] + bound_mod(context->ne[1], seq_multi_of);
|
||||
context_len += padded_context_len;
|
||||
}
|
||||
int h = x->ne[1];
|
||||
int w = x->ne[0];
|
||||
int pad_h = (patch_size - (h % patch_size)) % patch_size;
|
||||
int pad_w = (patch_size - (w % patch_size)) % patch_size;
|
||||
int h_len = (h + pad_h) / patch_size;
|
||||
int w_len = (w + pad_w) / patch_size;
|
||||
|
||||
if (h_len > 0 && w_len > 0) {
|
||||
size_t pos_len = ids.size() / bs;
|
||||
wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
|
||||
size_t cursor = context_len + bound_mod(context_len, seq_multi_of); // skip text (and its padding)
|
||||
size_t cursor = context_len; // skip text (and its padding)
|
||||
size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
|
||||
for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
|
||||
if (circular_h) {
|
||||
|
||||
49
z_image.hpp
49
z_image.hpp
@ -774,34 +774,37 @@ namespace ZImage {
|
||||
z_image.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
|
||||
struct ggml_tensor* timesteps,
|
||||
struct ggml_tensor* context,
|
||||
std::vector<ggml_tensor*> ref_latents = {},
|
||||
bool increase_ref_index = false) {
|
||||
struct ggml_cgraph* build_graph(ggml_tensor* x,
|
||||
ggml_tensor* timesteps,
|
||||
std::vector<ggml_tensor*> contexts,
|
||||
std::vector<ggml_tensor*> ref_latents = {},
|
||||
std::vector<ggml_tensor*> siglip_feats = {}) {
|
||||
GGML_ASSERT(x->ne[3] == 1);
|
||||
struct ggml_cgraph* gf = new_graph_custom(Z_IMAGE_GRAPH_SIZE);
|
||||
|
||||
x = to_backend(x);
|
||||
context = to_backend(context);
|
||||
x = to_backend(x);
|
||||
|
||||
for (int i = 0; i < contexts.size(); i++) {
|
||||
contexts[i] = to_backend(contexts[i]);
|
||||
}
|
||||
|
||||
timesteps = to_backend(timesteps);
|
||||
|
||||
for (int i = 0; i < ref_latents.size(); i++) {
|
||||
ref_latents[i] = to_backend(ref_latents[i]);
|
||||
}
|
||||
|
||||
pe_vec = Rope::gen_z_image_pe(x->ne[1],
|
||||
x->ne[0],
|
||||
z_image_params.patch_size,
|
||||
x->ne[3],
|
||||
context->ne[1],
|
||||
SEQ_MULTI_OF,
|
||||
pe_vec = Rope::gen_z_image_pe(x,
|
||||
contexts,
|
||||
ref_latents,
|
||||
increase_ref_index,
|
||||
siglip_feats,
|
||||
z_image_params.patch_size,
|
||||
SEQ_MULTI_OF,
|
||||
z_image_params.theta,
|
||||
z_image_params.axes_dim,
|
||||
circular_y_enabled,
|
||||
circular_x_enabled,
|
||||
z_image_params.axes_dim);
|
||||
x->ne[3]);
|
||||
int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2;
|
||||
// LOG_DEBUG("pos_len %d", pos_len);
|
||||
auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, z_image_params.axes_dim_sum / 2, pos_len);
|
||||
@ -814,7 +817,7 @@ namespace ZImage {
|
||||
struct ggml_tensor* out = z_image.forward(&runner_ctx,
|
||||
x,
|
||||
timesteps,
|
||||
{context},
|
||||
contexts,
|
||||
pe,
|
||||
ref_latents);
|
||||
|
||||
@ -826,16 +829,16 @@ namespace ZImage {
|
||||
bool compute(int n_threads,
|
||||
struct ggml_tensor* x,
|
||||
struct ggml_tensor* timesteps,
|
||||
struct ggml_tensor* context,
|
||||
std::vector<ggml_tensor*> ref_latents = {},
|
||||
bool increase_ref_index = false,
|
||||
struct ggml_tensor** output = nullptr,
|
||||
struct ggml_context* output_ctx = nullptr) {
|
||||
std::vector<ggml_tensor*> contexts,
|
||||
std::vector<ggml_tensor*> ref_latents = {},
|
||||
std::vector<ggml_tensor*> siglip_feats = {},
|
||||
struct ggml_tensor** output = nullptr,
|
||||
struct ggml_context* output_ctx = nullptr) {
|
||||
// x: [N, in_channels, h, w]
|
||||
// timesteps: [N, ]
|
||||
// context: [N, max_position, hidden_size]
|
||||
auto get_graph = [&]() -> struct ggml_cgraph* {
|
||||
return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
|
||||
return build_graph(x, timesteps, contexts, ref_latents, siglip_feats);
|
||||
};
|
||||
|
||||
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
|
||||
@ -867,7 +870,7 @@ namespace ZImage {
|
||||
struct ggml_tensor* out = nullptr;
|
||||
|
||||
int t0 = ggml_time_ms();
|
||||
compute(8, x, timesteps, context, {}, false, &out, work_ctx);
|
||||
compute(8, x, timesteps, {context}, {}, {}, &out, work_ctx);
|
||||
int t1 = ggml_time_ms();
|
||||
|
||||
print_ggml_tensor(out);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user