diff --git a/flux.hpp b/flux.hpp index b2e7422..83a4a22 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1014,14 +1014,14 @@ namespace Flux { txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods); } - img = ggml_view_3d(ctx->ggml_ctx, - txt_img, - txt_img->ne[0], - img->ne[1], - txt_img->ne[2], - txt_img->nb[1], - txt_img->nb[2], - txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size] + img = ggml_view_3d(ctx->ggml_ctx, + txt_img, + txt_img->ne[0], + img->ne[1], + txt_img->ne[2], + txt_img->nb[1], + txt_img->nb[2], + txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size] if (final_layer) { img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels) diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 9d5ea31..692ba85 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -687,42 +687,19 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx, struct ggml_tensor* x, int dim, int64_t start, - int64_t end) { + int64_t end, + bool cont = true) { GGML_ASSERT(dim >= 0 && dim < 4); - if (x->ne[dim] == 1) { - return x; - } - while (start < 0) { - start = x->ne[dim] + start; - } - while (end < 0) { - end = x->ne[dim] + end; - } - GGML_ASSERT(end > start); - GGML_ASSERT(start >= 0 && start < x->ne[dim]); - GGML_ASSERT(end > start && end <= x->ne[dim]); - int perm[4] = {0, 1, 2, 3}; - for (int i = dim; i < 3; ++i) - perm[i] = perm[i + 1]; - perm[3] = dim; + int64_t slice_size = end - start; + int64_t slice_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]}; + slice_ne[dim] = slice_size; - int inv_perm[4]; - for (int i = 0; i < 4; ++i) - inv_perm[perm[i]] = i; + x = ggml_view_4d(ctx, x, + slice_ne[0], slice_ne[1], slice_ne[2], slice_ne[3], + x->nb[1], x->nb[2], x->nb[3], start * x->nb[dim]); - if (dim != 3) { - x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]); - x = ggml_cont(ctx, x); - } - - x = ggml_view_4d( - ctx, x, - x->ne[0], x->ne[1], x->ne[2], end - start, - x->nb[1], 
x->nb[2], x->nb[3], x->nb[3] * start); - - if (dim != 3) { - x = ggml_ext_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]); + if (cont) { x = ggml_cont(ctx, x); } diff --git a/z_image.hpp b/z_image.hpp index 0abc783..505fa7e 100644 --- a/z_image.hpp +++ b/z_image.hpp @@ -54,15 +54,37 @@ namespace ZImage { auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim] qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim] - qkv = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, qkv, 0, 2, 3, 1)); // [num_heads + num_kv_heads*2, N, n_token, head_dim] - auto q = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], 0); // [num_heads, N, n_token, head_dim] - auto k = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * num_heads); // [num_kv_heads, N, n_token, head_dim] - auto v = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * (num_heads + num_kv_heads)); // [num_kv_heads, N, n_token, head_dim] - - q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 3, 1, 2)); // [N, n_token, num_heads, head_dim] - k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim] - v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim] + auto q = ggml_view_4d(ctx->ggml_ctx, + qkv, + qkv->ne[0], + num_heads, + qkv->ne[2], + qkv->ne[3], + qkv->nb[1], + qkv->nb[2], + qkv->nb[3], + 0); // [N, n_token, num_heads, head_dim] + auto k = ggml_view_4d(ctx->ggml_ctx, + qkv, + qkv->ne[0], + num_kv_heads, + qkv->ne[2], + qkv->ne[3], + qkv->nb[1], + qkv->nb[2], 
+ qkv->nb[3], + num_heads * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim] + auto v = ggml_view_4d(ctx->ggml_ctx, + qkv, + qkv->ne[0], + num_kv_heads, + qkv->ne[2], + qkv->ne[3], + qkv->nb[1], + qkv->nb[2], + qkv->nb[3], + (num_heads + num_kv_heads) * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim] if (qk_norm) { auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);