diff --git a/z_image.hpp b/z_image.hpp index 55c6125..b692a14 100644 --- a/z_image.hpp +++ b/z_image.hpp @@ -85,7 +85,15 @@ namespace ZImage { } hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) / multiple_of); blocks["w1"] = std::make_shared(dim, hidden_dim, false); - blocks["w2"] = std::make_shared(hidden_dim, dim, false); + + bool force_prec_f32 = false; + float scale = 1.f / 128.f; +#ifdef SD_USE_VULKAN + force_prec_f32 = true; +#endif + // The purpose of the scale here is to prevent NaN issues in certain situations. + // For example, when using CUDA but the weights are k-quants. + blocks["w2"] = std::make_shared(hidden_dim, dim, false, false, force_prec_f32, 1.f / 128.f); blocks["w3"] = std::make_shared(dim, hidden_dim, false); }