mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-01-02 18:53:36 +00:00
perf: optimize ggml_ext_chunk (#1084)
This commit is contained in:
parent
8f05f5bc6e
commit
d96b4152d6
@ -194,10 +194,12 @@ public:
|
|||||||
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
||||||
|
|
||||||
x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2]
|
x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2]
|
||||||
auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0);
|
auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0, false);
|
||||||
x = x_vec[0]; // [ne3, ne2, ne1, dim_out]
|
x = x_vec[0]; // [ne3, ne2, ne1, dim_out]
|
||||||
auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out]
|
auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out]
|
||||||
|
|
||||||
|
gate = ggml_cont(ctx->ggml_ctx, gate);
|
||||||
|
|
||||||
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
|
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
|
||||||
|
|
||||||
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]
|
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]
|
||||||
|
|||||||
@ -732,34 +732,22 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
|
|||||||
__STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_context* ctx,
|
__STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_context* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int num,
|
int num,
|
||||||
int64_t dim) {
|
int64_t dim,
|
||||||
|
bool cont = true) {
|
||||||
GGML_ASSERT(dim >= 0 && dim < 4);
|
GGML_ASSERT(dim >= 0 && dim < 4);
|
||||||
GGML_ASSERT(x->ne[dim] % num == 0);
|
GGML_ASSERT(x->ne[dim] % num == 0);
|
||||||
|
|
||||||
int perm[4] = {0, 1, 2, 3};
|
|
||||||
for (int i = dim; i < 3; ++i)
|
|
||||||
perm[i] = perm[i + 1];
|
|
||||||
perm[3] = dim;
|
|
||||||
|
|
||||||
int inv_perm[4];
|
|
||||||
for (int i = 0; i < 4; ++i)
|
|
||||||
inv_perm[perm[i]] = i;
|
|
||||||
|
|
||||||
if (dim != 3) {
|
|
||||||
x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
|
|
||||||
x = ggml_cont(ctx, x);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<struct ggml_tensor*> chunks;
|
std::vector<struct ggml_tensor*> chunks;
|
||||||
int64_t chunk_size = x->ne[3] / num;
|
int64_t chunk_size = x->ne[dim] / num;
|
||||||
|
int64_t stride = chunk_size * x->nb[dim];
|
||||||
|
int64_t chunk_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]};
|
||||||
|
chunk_ne[dim] = chunk_size;
|
||||||
for (int i = 0; i < num; i++) {
|
for (int i = 0; i < num; i++) {
|
||||||
auto chunk = ggml_view_4d(
|
auto chunk = ggml_view_4d(
|
||||||
ctx, x,
|
ctx, x,
|
||||||
x->ne[0], x->ne[1], x->ne[2], chunk_size,
|
chunk_ne[0], chunk_ne[1], chunk_ne[2], chunk_ne[3],
|
||||||
x->nb[1], x->nb[2], x->nb[3], x->nb[3] * i * chunk_size);
|
x->nb[1], x->nb[2], x->nb[3], stride * i);
|
||||||
|
if (cont) {
|
||||||
if (dim != 3) {
|
|
||||||
chunk = ggml_ext_torch_permute(ctx, chunk, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
|
|
||||||
chunk = ggml_cont(ctx, chunk);
|
chunk = ggml_cont(ctx, chunk);
|
||||||
}
|
}
|
||||||
chunks.push_back(chunk);
|
chunks.push_back(chunk);
|
||||||
@ -772,7 +760,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor*
|
|||||||
// x: [ne3, ne2, ne1, ne0]
|
// x: [ne3, ne2, ne1, ne0]
|
||||||
// return: [ne3, ne2, ne1, ne0/2]
|
// return: [ne3, ne2, ne1, ne0/2]
|
||||||
|
|
||||||
auto x_vec = ggml_ext_chunk(ctx, x, 2, 0);
|
auto x_vec = ggml_ext_chunk(ctx, x, 2, 0, false);
|
||||||
ggml_tensor* gate;
|
ggml_tensor* gate;
|
||||||
if (gate_first) {
|
if (gate_first) {
|
||||||
gate = x_vec[0];
|
gate = x_vec[0];
|
||||||
@ -781,7 +769,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor*
|
|||||||
x = x_vec[0];
|
x = x_vec[0];
|
||||||
gate = x_vec[1];
|
gate = x_vec[1];
|
||||||
}
|
}
|
||||||
|
gate = ggml_cont(ctx, gate);
|
||||||
gate = ggml_silu_inplace(ctx, gate);
|
gate = ggml_silu_inplace(ctx, gate);
|
||||||
|
|
||||||
x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, ne0/2]
|
x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, ne0/2]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user