mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
save more memory
This commit is contained in:
parent
4008102a93
commit
ceb0fcfae6
33
common.hpp
33
common.hpp
@ -181,39 +181,22 @@ class GEGLU : public UnaryBlock {
|
||||
protected:
|
||||
int64_t dim_in;
|
||||
int64_t dim_out;
|
||||
std::string prefix;
|
||||
|
||||
void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, std::string prefix = "") override {
|
||||
this->prefix = prefix;
|
||||
enum ggml_type wtype = get_type(prefix + "proj.weight", tensor_storage_map, GGML_TYPE_F32);
|
||||
enum ggml_type bias_wtype = GGML_TYPE_F32;
|
||||
params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2);
|
||||
params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2);
|
||||
}
|
||||
|
||||
public:
|
||||
GEGLU(int64_t dim_in, int64_t dim_out)
|
||||
: dim_in(dim_in), dim_out(dim_out) {}
|
||||
: dim_in(dim_in), dim_out(dim_out) {
|
||||
blocks["proj"] = std::shared_ptr<GGMLBlock>(new Linear(dim_in, dim_out * 2));
|
||||
}
|
||||
|
||||
struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) override {
|
||||
// x: [ne3, ne2, ne1, dim_in]
|
||||
// return: [ne3, ne2, ne1, dim_out]
|
||||
struct ggml_tensor* w = params["proj.weight"];
|
||||
struct ggml_tensor* b = params["proj.bias"];
|
||||
auto proj = std::dynamic_pointer_cast<Linear>(blocks["proj"]);
|
||||
|
||||
if (ctx->weight_adapter) {
|
||||
w = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, w, prefix + "proj.weight");
|
||||
b = ctx->weight_adapter->patch_weight(ctx->ggml_ctx, b, prefix + "proj.bias");
|
||||
}
|
||||
|
||||
auto x_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], 0); // [dim_out, dim_in]
|
||||
auto x_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, 0); // [dim_out, dim_in]
|
||||
auto gate_w = ggml_view_2d(ctx->ggml_ctx, w, w->ne[0], w->ne[1] / 2, w->nb[1], w->nb[1] * w->ne[1] / 2); // [dim_out, ]
|
||||
auto gate_b = ggml_view_1d(ctx->ggml_ctx, b, b->ne[0] / 2, b->nb[0] * b->ne[0] / 2); // [dim_out, ]
|
||||
|
||||
auto x_in = x;
|
||||
x = ggml_ext_linear(ctx->ggml_ctx, x_in, x_w, x_b); // [ne3, ne2, ne1, dim_out]
|
||||
auto gate = ggml_ext_linear(ctx->ggml_ctx, x_in, gate_w, gate_b); // [ne3, ne2, ne1, dim_out]
|
||||
x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2]
|
||||
auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0);
|
||||
x = x_vec[0]; // [ne3, ne2, ne1, dim_out]
|
||||
auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out]
|
||||
|
||||
gate = ggml_gelu_inplace(ctx->ggml_ctx, gate);
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user