diff --git a/common.hpp b/common.hpp index 9c8aba1..a197e8f 100644 --- a/common.hpp +++ b/common.hpp @@ -242,7 +242,8 @@ public: FeedForward(int64_t dim, int64_t dim_out, int64_t mult = 4, - Activation activation = Activation::GEGLU) { + Activation activation = Activation::GEGLU, + bool force_prec_f32 = false) { int64_t inner_dim = dim * mult; if (activation == Activation::GELU) { @@ -252,7 +253,7 @@ public: } // net_1 is nn.Dropout(), skip for inference - blocks["net.2"] = std::shared_ptr(new Linear(inner_dim, dim_out)); + blocks["net.2"] = std::shared_ptr(new Linear(inner_dim, dim_out, true, false, force_prec_f32)); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 15d80f9..b64dc85 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -939,8 +939,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ct __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* w, - struct ggml_tensor* b) { + struct ggml_tensor* b, + bool force_prec_f32 = false) { x = ggml_mul_mat(ctx, w, x); + if (force_prec_f32) { + ggml_mul_mat_set_prec(x, GGML_PREC_F32); + } if (b != NULL) { x = ggml_add_inplace(ctx, x, b); } @@ -1953,6 +1957,7 @@ protected: int64_t out_features; bool bias; bool force_f32; + bool force_prec_f32; void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") { enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32); @@ -1969,12 +1974,14 @@ protected: public: Linear(int64_t in_features, int64_t out_features, - bool bias = true, - bool force_f32 = false) + bool bias = true, + bool force_f32 = false, + bool force_prec_f32 = false) : in_features(in_features), out_features(out_features), bias(bias), - force_f32(force_f32) {} + force_f32(force_f32), + force_prec_f32(force_prec_f32) {} struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { struct ggml_tensor* w = params["weight"]; @@ -1982,7 +1989,7 @@ public: if (bias) { b = params["bias"]; } - return ggml_nn_linear(ctx, x, w, b); + return ggml_nn_linear(ctx, x, w, b, force_prec_f32); } }; diff --git a/qwen_image.hpp b/qwen_image.hpp index 2f5dad8..ab16b82 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -196,7 +196,7 @@ namespace Qwen { blocks["img_norm1"] = std::shared_ptr(new LayerNorm(dim, eps, false)); blocks["img_norm2"] = std::shared_ptr(new LayerNorm(dim, eps, false)); - blocks["img_mlp"] = std::shared_ptr(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU)); + blocks["img_mlp"] = std::shared_ptr(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU, true)); // txt_mod.0 is nn.SiLU() blocks["txt_mod.1"] = std::shared_ptr(new Linear(dim, 6 * dim, true));