fix: avoid precision issues on vulkan backend (#980)

leejet 2025-11-16 20:57:08 +08:00 committed by GitHub
parent d5b05f70c6
commit f532972d60
2 changed files with 12 additions and 4 deletions


@@ -242,14 +242,18 @@ public:
         }
         // net_1 is nn.Dropout(), skip for inference
+        bool force_prec_f32 = false;
         float scale = 1.f;
         if (precision_fix) {
             scale = 1.f / 128.f;
+#ifdef SD_USE_VULKAN
+            force_prec_f32 = true;
+#endif
         }
         // The purpose of the scale here is to prevent NaN issues in certain situations.
         // For example, when using Vulkan without enabling force_prec_f32,
         // or when using CUDA but the weights are k-quants.
-        blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, false, scale));
+        blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, force_prec_f32, scale));
     }

     struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x) {
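For context, the sketch below shows one plausible way a Linear block can act on the two extra constructor arguments used above (force_prec_f32 and scale) when it builds its compute graph. This is not the project's actual Linear implementation; linear_forward_sketch and its parameters are invented for illustration, and only the ggml calls (ggml_scale, ggml_mul_mat, ggml_mul_mat_set_prec, ggml_add) are real library API.

// Hedged sketch, not the real Linear::forward: how force_prec_f32 and scale
// could be applied when building the ggml graph for a linear layer.
#include "ggml.h"

static struct ggml_tensor* linear_forward_sketch(struct ggml_context* ctx,
                                                 struct ggml_tensor* x,  // input activations
                                                 struct ggml_tensor* w,  // weight, possibly quantized
                                                 struct ggml_tensor* b,  // bias, may be nullptr
                                                 bool force_prec_f32,
                                                 float scale) {
    if (scale != 1.f) {
        // Shrink the input before the matmul so half-precision accumulators
        // stay far from the f16 limit, then undo the scale on the result.
        x = ggml_scale(ctx, x, scale);
    }
    struct ggml_tensor* out = ggml_mul_mat(ctx, w, x);
    if (force_prec_f32) {
        // Request f32 accumulation for this matmul; this is what the
        // SD_USE_VULKAN branch above turns on for the Vulkan backend.
        ggml_mul_mat_set_prec(out, GGML_PREC_F32);
    }
    if (scale != 1.f) {
        out = ggml_scale(ctx, out, 1.f / scale);
    }
    if (b != nullptr) {
        out = ggml_add(ctx, out, b);
    }
    return out;
}

Scaling and forcing f32 attack the same symptom from different ends: the scale keeps intermediate magnitudes small enough for f16 arithmetic, while GGML_PREC_F32 gives the accumulator enough range outright; the commit enables the latter only where it is needed, behind SD_USE_VULKAN.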


@@ -95,9 +95,13 @@ namespace Qwen {
         blocks["norm_added_k"] = std::shared_ptr<GGMLBlock>(new RMSNorm(dim_head, eps));

         float scale = 1.f / 32.f;
+        bool force_prec_f32 = false;
+#ifdef SD_USE_VULKAN
+        force_prec_f32 = true;
+#endif
         // The purpose of the scale here is to prevent NaN issues in certain situations.
         // For example when using CUDA but the weights are k-quants (not all prompts).
-        blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_dim, out_bias, false, false, scale));
+        blocks["to_out.0"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_dim, out_bias, false, force_prec_f32, scale));
         // to_out.1 is nn.Dropout
         blocks["to_add_out"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, out_context_dim, out_bias, false, false, scale));