diff --git a/qwen_image.hpp b/qwen_image.hpp index 726d24d..3ac32de 100644 --- a/qwen_image.hpp +++ b/qwen_image.hpp @@ -97,7 +97,10 @@ namespace Qwen { blocks["to_out.0"] = std::shared_ptr(new Linear(inner_dim, out_dim, out_bias)); // to_out.1 is nn.Dropout - blocks["to_add_out"] = std::shared_ptr(new Linear(inner_dim, out_context_dim, out_bias)); + float scale = 1.f / 32.f; + // The purpose of the scale here is to prevent NaN issues in certain situations. + // For example when using CUDA but the weights are k-quants (not all prompts). + blocks["to_add_out"] = std::shared_ptr(new Linear(inner_dim, out_context_dim, out_bias, false, false, scale)); } std::pair forward(struct ggml_context* ctx,