fix NaN issue that occurs when using CUDA with k-quants weights

This commit is contained in:
leejet 2025-11-30 22:54:13 +08:00
parent 2fec01d2b3
commit 1798ec02ba

View File

@ -85,7 +85,15 @@ namespace ZImage {
}
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) / multiple_of);
blocks["w1"] = std::make_shared<Linear>(dim, hidden_dim, false);
blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false);
bool force_prec_f32 = false;
float scale = 1.f / 128.f;
#ifdef SD_USE_VULKAN
force_prec_f32 = true;
#endif
// The scale here prevents NaN issues in certain situations,
// e.g. when running on CUDA with k-quants weights.
blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false, false, force_prec_f32, 1.f / 128.f);
blocks["w3"] = std::make_shared<Linear>(dim, hidden_dim, false);
}