fix: prevent NaN issues with Z-Image on certain ROCm setups (#1034)

Author: Wagner Bruna, 2025-12-03 11:19:34 -03:00 (committed by GitHub)
Commit: 99e17232a4
Parent: 710169df5c


@@ -30,7 +30,12 @@ namespace ZImage {
         JointAttention(int64_t hidden_size, int64_t head_dim, int64_t num_heads, int64_t num_kv_heads, bool qk_norm)
             : head_dim(head_dim), num_heads(num_heads), num_kv_heads(num_kv_heads), qk_norm(qk_norm) {
             blocks["qkv"] = std::make_shared<Linear>(hidden_size, (num_heads + num_kv_heads * 2) * head_dim, false);
-            blocks["out"] = std::make_shared<Linear>(num_heads * head_dim, hidden_size, false);
+            float scale = 1.f;
+#if GGML_USE_HIP
+            // Prevent NaN issues with certain ROCm setups
+            scale = 1.f / 16.f;
+#endif
+            blocks["out"] = std::make_shared<Linear>(num_heads * head_dim, hidden_size, false, false, false, scale);
             if (qk_norm) {
                 blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim);
                 blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim);
@@ -93,7 +98,7 @@ namespace ZImage {
 #endif
             // The purpose of the scale here is to prevent NaN issues in certain situations.
             // For example, when using CUDA but the weights are k-quants.
-            blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false, false, force_prec_f32, 1.f / 128.f);
+            blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false, false, force_prec_f32, scale);
             blocks["w3"] = std::make_shared<Linear>(dim, hidden_dim, false);
         }
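
For context, the `scale` argument passed to `Linear` here is a numerical workaround rather than a model change: scaling the matmul input down and the result back up is mathematically a no-op, but it keeps intermediate accumulations small enough to avoid overflow (and the resulting NaNs) on reduced-precision paths, which matches the existing comment about CUDA with k-quant weights. The sketch below is not sd.cpp code; the helper name `scaled_dot` and the concrete sizes are hypothetical, and the exact point where `Linear` applies its `scale` parameter lives in its implementation, which this diff does not show.

// Minimal sketch (plain C++, independent of ggml/sd.cpp): the identity behind
// this kind of scale workaround, (W * (s*x)) / s == W * x, and why the scaled
// form keeps partial sums small enough for half precision.
#include <cmath>
#include <cstdio>
#include <vector>

// One output element of a matmul row, with an optional input scale that is
// undone on the result. Also reports the largest partial sum seen.
float scaled_dot(const std::vector<float>& w, const std::vector<float>& x,
                 float scale, float* max_partial) {
    float acc    = 0.f;
    *max_partial = 0.f;
    for (size_t i = 0; i < w.size(); ++i) {
        acc += w[i] * (x[i] * scale);  // every intermediate value is smaller by `scale`
        *max_partial = std::fmax(*max_partial, std::fabs(acc));
    }
    return acc / scale;  // undo the scale: the final result is unchanged
}

int main() {
    // Hypothetical activations chosen so the plain accumulation exceeds fp16 max (~65504).
    std::vector<float> w(4096, 1.0f), x(4096, 20.0f);
    float peak_plain = 0.f, peak_scaled = 0.f;
    float y0 = scaled_dot(w, x, 1.f, &peak_plain);           // unscaled
    float y1 = scaled_dot(w, x, 1.f / 16.f, &peak_scaled);   // same math, pre-scaled by 1/16
    printf("results: %g vs %g (identical)\n", y0, y1);
    printf("peak partial sum: %g (overflows fp16) vs %g (fits)\n", peak_plain, peak_scaled);
    return 0;
}

With fp32 accumulation the two runs above produce the same value; the point is only that the scaled variant would also survive a half-precision accumulator, which is presumably the failure mode the 1/16 factor guards against on the affected ROCm setups.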