diff --git a/common.hpp b/common.hpp
index 33d499f..74b218a 100644
--- a/common.hpp
+++ b/common.hpp
@@ -194,10 +194,12 @@ public:
         auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
 
         x = proj->forward(ctx, x);  // [ne3, ne2, ne1, dim_out*2]
-        auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0);
+        auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0, false);
         x          = x_vec[0];  // [ne3, ne2, ne1, dim_out]
         auto gate  = x_vec[1];  // [ne3, ne2, ne1, dim_out]
 
+        gate = ggml_cont(ctx->ggml_ctx, gate);
+
         gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
 
         x = ggml_mul(ctx->ggml_ctx, x, gate);  // [ne3, ne2, ne1, dim_out]
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 07b9bfb..01dc4c4 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -732,7 +732,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx,
 __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_context* ctx,
                                                                   struct ggml_tensor* x,
                                                                   int num,
-                                                                  int64_t dim) {
+                                                                  int64_t dim,
+                                                                  bool cont = true) {
     GGML_ASSERT(dim >= 0 && dim < 4);
     GGML_ASSERT(x->ne[dim] % num == 0);
 
@@ -747,7 +748,9 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
 
     if (dim != 3) {
         x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]);
-        x = ggml_cont(ctx, x);
+        if (cont) {
+            x = ggml_cont(ctx, x);
+        }
     }
 
     std::vector<struct ggml_tensor*> chunks;
@@ -760,7 +763,9 @@ __STATIC_INLINE__ std::vector<struct ggml_tensor*> ggml_ext_chunk(struct ggml_co
 
         if (dim != 3) {
             chunk = ggml_ext_torch_permute(ctx, chunk, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]);
-            chunk = ggml_cont(ctx, chunk);
+            if (cont) {
+                chunk = ggml_cont(ctx, chunk);
+            }
         }
         chunks.push_back(chunk);
     }