fix: correct mask shape for masked flash attention (#1625)

This commit is contained in:
RapidMark 2026-06-12 22:01:20 -07:00 committed by GitHub
parent 19bdfe22d2
commit 1b702a51e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 13 deletions

View File

@ -1346,10 +1346,18 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16); v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
if (mask_in != nullptr) { if (mask_in != nullptr) {
mask_in = ggml_transpose(ctx, mask_in); // ggml_flash_attn_ext expects the mask as a contiguous F16 tensor shaped
} // [n_kv, n_q, (heads), (batch)] (ne0 = key length, ne1 = query length) and,
// unlike the manual-attention path, does not broadcast the query dimension.
if (mask_in != nullptr) { // Some callers (e.g. Chroma/T5) pass a per-key padding mask broadcast over
// queries ([n_kv, 1, ...]); materialize the query dimension to L_q so the
// kernel indexes it correctly. (A bare ggml_transpose here produced a
// [1, n_kv, ...] mask that the kernel silently misreads, yielding NaN/blank
// output for masked flash attention.)
if (mask_in->ne[1] != L_q) {
mask_in = ggml_repeat(ctx, mask_in,
ggml_new_tensor_4d(ctx, mask_in->type, mask_in->ne[0], L_q, mask_in->ne[2], mask_in->ne[3]));
}
mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16); mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
} }

View File

@ -575,15 +575,6 @@ public:
} }
} }
if (is_chroma) { if (is_chroma) {
if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
LOG_WARN(
"!!!It looks like you are using Chroma with flash attention. "
"This is currently unsupported. "
"If you find that the generated images are broken, "
"try either disabling flash attention or specifying "
"--chroma-disable-dit-mask as a workaround.");
}
cond_stage_model = std::make_shared<T5CLIPEmbedder>(backend_for(SDBackendModule::TE), cond_stage_model = std::make_shared<T5CLIPEmbedder>(backend_for(SDBackendModule::TE),
params_backend_for(SDBackendModule::TE), params_backend_for(SDBackendModule::TE),
tensor_storage_map, tensor_storage_map,