save more memory

leejet 2025-11-12 23:45:06 +08:00
parent 4008102a93
commit ceb0fcfae6

@@ -181,39 +181,22 @@ class GEGLU : public UnaryBlock {
 protected:
     int64_t dim_in;
     int64_t dim_out;
-    std::string prefix;
-
-    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
-        this->prefix = prefix;
-        enum ggml_type wtype      = get_type(prefix + "proj.weight", tensor_storage_map, GGML_TYPE_F32);
-        enum ggml_type bias_wtype = GGML_TYPE_F32;
-        params["proj.weight"]     = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
-        params["proj.bias"]       = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
-    }
-
 public:
     GEGLU(int64_t dim_in, int64_t dim_out)
-        : dim_in(dim_in), dim_out(dim_out) {}
+        : dim_in(dim_in), dim_out(dim_out) {
+        blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim_in, dim_out * 2));
+    }
 
     struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
         // x: [ne3, ne2, ne1, dim_in]
         // return: [ne3, ne2, ne1, dim_out]
-        struct ggml_tensor* w = params["proj.weight"];
-        struct ggml_tensor* b = params["proj.bias"];
+        auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
 
-        if (ctx->weight_adapter) {
-            w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "proj.weight");
-            b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "proj.bias");
-        }
-        auto x_w    = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0);                        // [dim_out, dim_in]
-        auto x_b    = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, 0);                                            // [dim_out]
-        auto gate_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2);  // [dim_out, dim_in]
-        auto gate_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2);                      // [dim_out]
-
-        auto x_in = x;
-        x         = ggml_ext_linear(ctx->ggml_ctx, x_in, x_w, x_b);        // [ne3, ne2, ne1, dim_out]
-        auto gate = ggml_ext_linear(ctx->ggml_ctx, x_in, gate_w, gate_b);  // [ne3, ne2, ne1, dim_out]
+        x          = proj->forward(ctx, x);  // [ne3, ne2, ne1, dim_out*2]
+        auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0);
+        x          = x_vec[0];  // [ne3, ne2, ne1, dim_out]
+        auto gate  = x_vec[1];  // [ne3, ne2, ne1, dim_out]
 
         gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
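
For reference, the old and new forward() compute the same GEGLU: project dim_in to 2*dim_out with a single weight matrix, split the result in half along the feature dimension (dim 0 in ggml's layout), and multiply the first half by GELU of the second. The plain-C++ sketch below, reduced to a single token, illustrates that math; the row-major layout, the tanh GELU approximation, and all names in it are illustrative assumptions, not the ggml implementation.

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // GELU, tanh approximation (assumed here; ggml's GELU behaves similarly).
    static float gelu(float v) {
        const float k = 0.7978845608f;  // sqrt(2 / pi)
        return 0.5f * v * (1.0f + std::tanh(k * (v + 0.044715f * v * v * v)));
    }

    // GEGLU for one token: one projection to 2 * dim_out, then chunk in two
    // and gate, mirroring x_vec[0] * GELU(x_vec[1]) in the forward() above.
    static std::vector<float> geglu(const std::vector<float>& x,  // [dim_in]
                                    const std::vector<float>& w,  // [2 * dim_out, dim_in], row-major
                                    const std::vector<float>& b,  // [2 * dim_out]
                                    int dim_in, int dim_out) {
        std::vector<float> proj(2 * dim_out);
        for (int i = 0; i < 2 * dim_out; i++) {  // single full-width linear
            proj[i] = b[i];
            for (int j = 0; j < dim_in; j++) {
                proj[i] += w[i * dim_in + j] * x[j];
            }
        }
        std::vector<float> out(dim_out);
        for (int i = 0; i < dim_out; i++) {  // chunk + gate
            out[i] = proj[i] * gelu(proj[dim_out + i]);
        }
        return out;
    }

    int main() {
        const int dim_in = 4, dim_out = 2;
        std::vector<float> x(dim_in, 1.0f);
        std::vector<float> w(2 * dim_out * dim_in, 0.1f);  // dummy weights
        std::vector<float> b(2 * dim_out, 0.0f);
        for (float v : geglu(x, w, b, dim_in, dim_out)) {
            std::printf("%f\n", v);
        }
        return 0;
    }

In the new code, x_vec[0] and x_vec[1] appear to be views into the single proj output rather than two separately materialized matmul results, which is consistent with the commit's stated aim of saving memory.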