leejet 2025-10-11 01:04:14 +08:00
parent d19d4a5903
commit 6ea2a75929
3 changed files with 16 additions and 8 deletions

View File

@ -242,7 +242,8 @@ public:
FeedForward(int64_t dim,
int64_t dim_out,
int64_t mult = 4,
Activation activation = Activation::GEGLU) {
Activation activation = Activation::GEGLU,
bool force_prec_f32 = false) {
int64_t inner_dim = dim * mult;
if (activation == Activation::GELU) {
@ -252,7 +253,7 @@ public:
}
// net_1 is nn.Dropout(), skip for inference
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out));
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, force_prec_f32));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {

View File

@ -939,8 +939,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ct
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
struct ggml_tensor* b) {
struct ggml_tensor* b,
bool force_prec_f32 = false) {
x = ggml_mul_mat(ctx, w, x);
if (force_prec_f32) {
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
}
if (b != NULL) {
x = ggml_add_inplace(ctx, x, b);
}
@ -1953,6 +1957,7 @@ protected:
int64_t out_features;
bool bias;
bool force_f32;
bool force_prec_f32;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
@ -1969,12 +1974,14 @@ protected:
public:
Linear(int64_t in_features,
int64_t out_features,
bool bias = true,
bool force_f32 = false)
bool bias = true,
bool force_f32 = false,
bool force_prec_f32 = false)
: in_features(in_features),
out_features(out_features),
bias(bias),
force_f32(force_f32) {}
force_f32(force_f32),
force_prec_f32(force_prec_f32) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
@ -1982,7 +1989,7 @@ public:
if (bias) {
b = params["bias"];
}
return ggml_nn_linear(ctx, x, w, b);
return ggml_nn_linear(ctx, x, w, b, force_prec_f32);
}
};

View File

@ -196,7 +196,7 @@ namespace Qwen {
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU));
blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU, true));
// txt_mod.0 is nn.SiLU()
blocks["txt_mod.1"] = std::shared_ptr<GGMLBlock>(new Linear(dim, 6 * dim, true));