mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-17 03:37:20 +00:00
fix: correct mask shape for masked flash attention (#1625)
This commit is contained in:
parent
19bdfe22d2
commit
1b702a51e7
@ -1346,10 +1346,18 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_attention_ext(ggml_context* ctx,
|
||||
v_in = ggml_cast(ctx, v_in, GGML_TYPE_F16);
|
||||
|
||||
if (mask_in != nullptr) {
|
||||
mask_in = ggml_transpose(ctx, mask_in);
|
||||
// ggml_flash_attn_ext expects the mask as a contiguous F16 tensor shaped
|
||||
// [n_kv, n_q, (heads), (batch)] (ne0 = key length, ne1 = query length) and,
|
||||
// unlike the manual-attention path, does not broadcast the query dimension.
|
||||
// Some callers (e.g. Chroma/T5) pass a per-key padding mask broadcast over
|
||||
// queries ([n_kv, 1, ...]); materialize the query dimension to L_q so the
|
||||
// kernel indexes it correctly. (A bare ggml_transpose here produced a
|
||||
// [1, n_kv, ...] mask that the kernel silently misreads, yielding NaN/blank
|
||||
// output for masked flash attention.)
|
||||
if (mask_in->ne[1] != L_q) {
|
||||
mask_in = ggml_repeat(ctx, mask_in,
|
||||
ggml_new_tensor_4d(ctx, mask_in->type, mask_in->ne[0], L_q, mask_in->ne[2], mask_in->ne[3]));
|
||||
}
|
||||
|
||||
if (mask_in != nullptr) {
|
||||
mask_in = ggml_cast(ctx, mask_in, GGML_TYPE_F16);
|
||||
}
|
||||
|
||||
|
||||
@ -575,15 +575,6 @@ public:
|
||||
}
|
||||
}
|
||||
if (is_chroma) {
|
||||
if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
|
||||
LOG_WARN(
|
||||
"!!!It looks like you are using Chroma with flash attention. "
|
||||
"This is currently unsupported. "
|
||||
"If you find that the generated images are broken, "
|
||||
"try either disabling flash attention or specifying "
|
||||
"--chroma-disable-dit-mask as a workaround.");
|
||||
}
|
||||
|
||||
cond_stage_model = std::make_shared<T5CLIPEmbedder>(backend_for(SDBackendModule::TE),
|
||||
params_backend_for(SDBackendModule::TE),
|
||||
tensor_storage_map,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user