mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 18:28:57 +00:00
make z-image a litter faster
This commit is contained in:
parent
c7d4a6035d
commit
6f4b49239c
16
flux.hpp
16
flux.hpp
@ -1014,14 +1014,14 @@ namespace Flux {
|
|||||||
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
|
txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask, ss_mods);
|
||||||
}
|
}
|
||||||
|
|
||||||
img = ggml_view_3d(ctx->ggml_ctx,
|
img = ggml_view_3d(ctx->ggml_ctx,
|
||||||
txt_img,
|
txt_img,
|
||||||
txt_img->ne[0],
|
txt_img->ne[0],
|
||||||
img->ne[1],
|
img->ne[1],
|
||||||
txt_img->ne[2],
|
txt_img->ne[2],
|
||||||
txt_img->nb[1],
|
txt_img->nb[1],
|
||||||
txt_img->nb[2],
|
txt_img->nb[2],
|
||||||
txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size]
|
txt->ne[1] * txt_img->nb[1]); // [N, n_img_token, hidden_size]
|
||||||
|
|
||||||
if (final_layer) {
|
if (final_layer) {
|
||||||
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
|
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
|
||||||
|
|||||||
@ -687,42 +687,19 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
|
|||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int dim,
|
int dim,
|
||||||
int64_t start,
|
int64_t start,
|
||||||
int64_t end) {
|
int64_t end,
|
||||||
|
bool cont = true) {
|
||||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||||
if (x->ne[dim] == 1) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
while (start < 0) {
|
|
||||||
start = x->ne[dim] + start;
|
|
||||||
}
|
|
||||||
while (end < 0) {
|
|
||||||
end = x->ne[dim] + end;
|
|
||||||
}
|
|
||||||
GGML_ASSERT(end > start);
|
|
||||||
GGML_ASSERT(start >= 0 && start < x->ne[dim]);
|
|
||||||
GGML_ASSERT(end > start && end <= x->ne[dim]);
|
|
||||||
|
|
||||||
int perm[4] = {0, 1, 2, 3};
|
int64_t slice_size = end - start;
|
||||||
for (int i = dim; i < 3; ++i)
|
int64_t slice_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]};
|
||||||
perm[i] = perm[i + 1];
|
slice_ne[dim] = slice_size;
|
||||||
perm[3] = dim;
|
|
||||||
|
|
||||||
int inv_perm[4];
|
x = ggml_view_4d(ctx, x,
|
||||||
for (int i = 0; i < 4; ++i)
|
slice_ne[0], slice_ne[1], slice_ne[2], slice_ne[3],
|
||||||
inv_perm[perm[i]] = i;
|
x->nb[1], x->nb[2], x->nb[3], start * x->nb[dim]);
|
||||||
|
|
||||||
if (dim != 3) {
|
if (cont) {
|
||||||
x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
|
|
||||||
x = ggml_cont(ctx, x);
|
|
||||||
}
|
|
||||||
|
|
||||||
x = ggml_view_4d(
|
|
||||||
ctx, x,
|
|
||||||
x->ne[0], x->ne[1], x->ne[2], end - start,
|
|
||||||
x->nb[1], x->nb[2], x->nb[3], x->nb[3] * start);
|
|
||||||
|
|
||||||
if (dim != 3) {
|
|
||||||
x = ggml_ext_torch_permute(ctx, x, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
|
|
||||||
x = ggml_cont(ctx, x);
|
x = ggml_cont(ctx, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
38
z_image.hpp
38
z_image.hpp
@ -54,15 +54,37 @@ namespace ZImage {
|
|||||||
|
|
||||||
auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim]
|
auto qkv = qkv_proj->forward(ctx, x); // [N, n_token, (num_heads + num_kv_heads*2)*head_dim]
|
||||||
qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim]
|
qkv = ggml_reshape_4d(ctx->ggml_ctx, qkv, head_dim, num_heads + num_kv_heads * 2, qkv->ne[1], qkv->ne[2]); // [N, n_token, num_heads + num_kv_heads*2, head_dim]
|
||||||
qkv = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, qkv, 0, 2, 3, 1)); // [num_heads + num_kv_heads*2, N, n_token, head_dim]
|
|
||||||
|
|
||||||
auto q = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], 0); // [num_heads, N, n_token, head_dim]
|
auto q = ggml_view_4d(ctx->ggml_ctx,
|
||||||
auto k = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * num_heads); // [num_kv_heads, N, n_token, head_dim]
|
qkv,
|
||||||
auto v = ggml_view_4d(ctx->ggml_ctx, qkv, qkv->ne[0], qkv->ne[1], qkv->ne[2], num_kv_heads, qkv->nb[1], qkv->nb[2], qkv->nb[3], qkv->nb[3] * (num_heads + num_kv_heads)); // [num_kv_heads, N, n_token, head_dim]
|
qkv->ne[0],
|
||||||
|
num_heads,
|
||||||
q = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, q, 0, 3, 1, 2)); // [N, n_token, num_heads, head_dim]
|
qkv->ne[2],
|
||||||
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim]
|
qkv->ne[3],
|
||||||
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 0, 3, 1, 2)); // [N, n_token, num_kv_heads, head_dim]
|
qkv->nb[1],
|
||||||
|
qkv->nb[2],
|
||||||
|
qkv->nb[3],
|
||||||
|
0); // [N, n_token, num_heads, head_dim]
|
||||||
|
auto k = ggml_view_4d(ctx->ggml_ctx,
|
||||||
|
qkv,
|
||||||
|
qkv->ne[0],
|
||||||
|
num_kv_heads,
|
||||||
|
qkv->ne[2],
|
||||||
|
qkv->ne[3],
|
||||||
|
qkv->nb[1],
|
||||||
|
qkv->nb[2],
|
||||||
|
qkv->nb[3],
|
||||||
|
num_heads * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
|
||||||
|
auto v = ggml_view_4d(ctx->ggml_ctx,
|
||||||
|
qkv,
|
||||||
|
qkv->ne[0],
|
||||||
|
num_kv_heads,
|
||||||
|
qkv->ne[2],
|
||||||
|
qkv->ne[3],
|
||||||
|
qkv->nb[1],
|
||||||
|
qkv->nb[2],
|
||||||
|
qkv->nb[3],
|
||||||
|
(num_heads + num_kv_heads) * qkv->nb[1]); // [N, n_token, num_kv_heads, head_dim]
|
||||||
|
|
||||||
if (qk_norm) {
|
if (qk_norm) {
|
||||||
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
|
auto q_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["q_norm"]);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user