diff --git a/common.hpp b/common.hpp
index 443f3dc..dd8281f 100644
--- a/common.hpp
+++ b/common.hpp
@@ -181,39 +181,22 @@ class GEGLU : public UnaryBlock {
 protected:
     int64_t dim_in;
     int64_t dim_out;
-    std::string prefix;
-
-    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
-        this->prefix = prefix;
-        enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_storage_map, GGML_TYPE_F32);
-        enum ggml_type bias_wtype = GGML_TYPE_F32;
-        params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
-        params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
-    }
 
 public:
     GEGLU(int64_t dim_in, int64_t dim_out)
-        : dim_in(dim_in), dim_out(dim_out) {}
+        : dim_in(dim_in), dim_out(dim_out) {
+        blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim_in, dim_out * 2));
+    }
 
     struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
         // x: [ne3, ne2, ne1, dim_in]
         // return: [ne3, ne2, ne1, dim_out]
-        struct ggml_tensor* w = params["proj.weight"];
-        struct ggml_tensor* b = params["proj.bias"];
+        auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
 
-        if (ctx->weight_adapter) {
-            w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "proj.weight");
-            b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "proj.bias");
-        }
-
-        auto x_w    = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0);                       // [dim_out, dim_in]
-        auto x_b    = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, 0);                                           // [dim_out,]
-        auto gate_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); // [dim_out, dim_in]
-        auto gate_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2);                     // [dim_out,]
-
-        auto x_in = x;
-        x         = ggml_ext_linear(ctx->ggml_ctx, x_in, x_w, x_b);       // [ne3, ne2, ne1, dim_out]
-        auto gate = ggml_ext_linear(ctx->ggml_ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out]
+        x          = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2]
+        auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0);
+        x          = x_vec[0]; // [ne3, ne2, ne1, dim_out]
+        auto gate  = x_vec[1]; // [ne3, ne2, ne1, dim_out]
        gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
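
For reference, here is a minimal standalone sketch (not part of the patch) of what the new `forward()` computes: the `proj` Linear maps `dim_in -> 2*dim_out`, the output is split in half along the feature axis, and the second half gates the first through GELU. The names `geglu_ref` and `gelu` are hypothetical; it uses plain `std::vector` instead of ggml tensors, assumes a row-major `[2*dim_out, dim_in]` weight, and uses the exact (erf-based) GELU, which may differ slightly from ggml's approximation.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Exact GELU: 0.5 * v * (1 + erf(v / sqrt(2))).
static float gelu(float v) {
    return 0.5f * v * (1.0f + std::erf(v / std::sqrt(2.0f)));
}

// x: [dim_in], w: [2*dim_out, dim_in] row-major, b: [2*dim_out]
// returns: [dim_out] = first_half(proj) * gelu(second_half(proj))
std::vector<float> geglu_ref(const std::vector<float>& x,
                             const std::vector<float>& w,
                             const std::vector<float>& b,
                             size_t dim_in, size_t dim_out) {
    // proj = x @ W^T + b, the single fused projection done by blocks["proj"]
    std::vector<float> proj(2 * dim_out);
    for (size_t i = 0; i < 2 * dim_out; i++) {
        float acc = b[i];
        for (size_t j = 0; j < dim_in; j++)
            acc += w[i * dim_in + j] * x[j];
        proj[i] = acc;
    }
    // chunk into two halves along the feature axis, then gate
    std::vector<float> out(dim_out);
    for (size_t i = 0; i < dim_out; i++)
        out[i] = proj[i] * gelu(proj[dim_out + i]);
    return out;
}
```

Under this convention the first chunk carries the values and the second the gate, mirroring `x_vec[0]` and `x_vec[1]` in the diff (ggml's `ne[0]` is the innermost/feature dimension, so chunking along dim 0 with 2 chunks splits the features in half).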