optimize ggml_ext_chunk: add an optional `cont` flag (default true) so callers can skip the ggml_cont copies

This commit is contained in:
leejet 2025-12-13 01:16:58 +08:00
parent 11ab095230
commit 0835e5c227
2 changed files with 11 additions and 4 deletions

View File

@@ -194,10 +194,12 @@ public:
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]); auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2] x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2]
auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0); auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0, false);
x = x_vec[0]; // [ne3, ne2, ne1, dim_out] x = x_vec[0]; // [ne3, ne2, ne1, dim_out]
auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out] auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out]
gate = ggml_cont(ctx->ggml_ctx, gate);
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate); gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out] x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out]

View File

@@ -732,7 +732,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
__STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_context* ctx, __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_context* ctx,
struct ggml_tensor* x, struct ggml_tensor* x,
int num, int num,
int64_t dim) { int64_t dim,
bool cont = true) {
GGML_ASSERT(dim >= 0 && dim < 4); GGML_ASSERT(dim >= 0 && dim < 4);
GGML_ASSERT(x->ne[dim] % num == 0); GGML_ASSERT(x->ne[dim] % num == 0);
@@ -747,8 +748,10 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
if (dim != 3) { if (dim != 3) {
x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]); x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
if (cont) {
x = ggml_cont(ctx, x); x = ggml_cont(ctx, x);
} }
}
std::vector<struct ggml_tensor*> chunks; std::vector<struct ggml_tensor*> chunks;
int64_t chunk_size = x->ne[3] / num; int64_t chunk_size = x->ne[3] / num;
@@ -760,8 +763,10 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
if (dim != 3) { if (dim != 3) {
chunk = ggml_ext_torch_permute(ctx, chunk, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]); chunk = ggml_ext_torch_permute(ctx, chunk, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
if (cont) {
chunk = ggml_cont(ctx, chunk); chunk = ggml_cont(ctx, chunk);
} }
}
chunks.push_back(chunk); chunks.push_back(chunk);
} }