mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 10:18:51 +00:00
make z-image a litter faster
This commit is contained in:
parent
c7d4a6035d
commit
6f4b49239c
16
flux.hpp
16
flux.hpp
@ -1014,14 +1014,14 @@ namespace Flux {
|
||||
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
|
||||
}
|
||||
|
||||
img = ggml_view_3d(ctx->ggml_ctx,
|
||||
txt_img,
|
||||
txt_img->ne[0],
|
||||
img->ne[1],
|
||||
txt_img->ne[2],
|
||||
txt_img->nb[1],
|
||||
txt_img->nb[2],
|
||||
txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size]
|
||||
img = ggml_view_3d(ctx->ggml_ctx,
|
||||
txt_img,
|
||||
txt_img->ne[0],
|
||||
img->ne[1],
|
||||
txt_img->ne[2],
|
||||
txt_img->nb[1],
|
||||
txt_img->nb[2],
|
||||
txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size]
|
||||
|
||||
if (final_layer) {
|
||||
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
@ -687,42 +687,19 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x,
|
||||
int dim,
|
||||
int64_t start,
|
||||
int64_t end) {
|
||||
int64_t end,
|
||||
bool cont = true) {
|
||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||
if (x->ne[dim] == 1) {
|
||||
return x;
|
||||
}
|
||||
while (start < 0) {
|
||||
start = x->ne[dim] + start;
|
||||
}
|
||||
while (end < 0) {
|
||||
end = x->ne[dim] + end;
|
||||
}
|
||||
GGML_ASSERT(end > start);
|
||||
GGML_ASSERT(start >= 0 && start < x->ne[dim]);
|
||||
GGML_ASSERT(end > start && end <= x->ne[dim]);
|
||||
|
||||
int perm[4] = {0, 1, 2, 3};
|
||||
for (int i = dim; i < 3; ++i)
|
||||
perm[i] = perm[i + 1];
|
||||
perm[3] = dim;
|
||||
int64_t slice_size = end - start;
|
||||
int64_t slice_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]};
|
||||
slice_ne[dim] = slice_size;
|
||||
|
||||
int inv_perm[4];
|
||||
for (int i = 0; i < 4; ++i)
|
||||
inv_perm[perm[i]] = i;
|
||||
x = ggml_view_4d(ctx, x,
|
||||
slice_ne[0], slice_ne[1], slice_ne[2], slice_ne[3],
|
||||
x->nb[1], x->nb[2], x->nb[3], start * x->nb[dim]);
|
||||
|
||||
if (dim != 3) {
|
||||
x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
|
||||
x = ggml_cont(ctx, x);
|
||||
}
|
||||
|
||||
x = ggml_view_4d(
|
||||
ctx, x,
|
||||
x->ne[0], x->ne[1], x->ne[2], end - start,
|
||||
x->nb[1], x->nb[2], x->nb[3], x->nb[3] * start);
|
||||
|
||||
if (dim != 3) {
|
||||
x = ggml_ext_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
|
||||
if (cont) {
|
||||
x = ggml_cont(ctx, x);
|
||||
}
|
||||
|
||||
|
||||
38
z_image.hpp
38
z_image.hpp
@ -54,15 +54,37 @@ namespace ZImage {
|
||||
|
||||
auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim]
|
||||
qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim]
|
||||
qkv = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, qkv, 0, 2, 3, 1)); // [num_heads + num_kv_heads*2, N, n_token, head_dim]
|
||||
|
||||
auto q = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], 0); // [num_heads, N, n_token, head_dim]
|
||||
auto k = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * num_heads); // [num_kv_heads, N, n_token, head_dim]
|
||||
auto v = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * (num_heads + num_kv_heads)); // [num_kv_heads, N, n_token, head_dim]
|
||||
|
||||
q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 3, 1, 2)); // [N, n_token, num_heads, head_dim]
|
||||
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim]
|
||||
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim]
|
||||
auto q = ggml_view_4d(ctx->ggml_ctx,
|
||||
qkv,
|
||||
qkv->ne[0],
|
||||
num_heads,
|
||||
qkv->ne[2],
|
||||
qkv->ne[3],
|
||||
qkv->nb[1],
|
||||
qkv->nb[2],
|
||||
qkv->nb[3],
|
||||
0); // [N, n_token, num_heads, head_dim]
|
||||
auto k = ggml_view_4d(ctx->ggml_ctx,
|
||||
qkv,
|
||||
qkv->ne[0],
|
||||
num_kv_heads,
|
||||
qkv->ne[2],
|
||||
qkv->ne[3],
|
||||
qkv->nb[1],
|
||||
qkv->nb[2],
|
||||
qkv->nb[3],
|
||||
num_heads * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
|
||||
auto v = ggml_view_4d(ctx->ggml_ctx,
|
||||
qkv,
|
||||
qkv->ne[0],
|
||||
num_kv_heads,
|
||||
qkv->ne[2],
|
||||
qkv->ne[3],
|
||||
qkv->nb[1],
|
||||
qkv->nb[2],
|
||||
qkv->nb[3],
|
||||
(num_heads + num_kv_heads) * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
|
||||
|
||||
if (qk_norm) {
|
||||
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user