make flash attn work with wan

leejet 2025-08-10 17:52:59 +08:00
parent 1d9ccea41a
commit 00f790d0e9
3 changed files with 36 additions and 40 deletions


@@ -1011,12 +1011,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));
     int kv_pad  = 0;
-    // if (flash_attn) {
-    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    // }
-    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
-    // GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
+    if (flash_attn) {
+        // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
     bool can_use_flash_attn = true;
     can_use_flash_attn      = can_use_flash_attn && (d_head == 64 ||
                                                      d_head == 80 ||
@@ -1024,18 +1020,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
                                                      d_head == 112 ||
                                                      d_head == 128 ||
                                                      d_head == 256);
-#if 0
-    can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
-#else
     if (can_use_flash_attn && L_k % 256 != 0) {
         // TODO(Green-Sky): might be worth just padding by default
-        if (L_k == 77 || L_k == 4208 || L_k == 3952) {
+        if (L_k == 77 || L_k == 1560 || L_k == 4208 || L_k == 3952) {
             kv_pad = GGML_PAD(L_k, 256) - L_k;
         } else {
             can_use_flash_attn = false;
         }
     }
-#endif

     if (mask != nullptr) {
         // TODO(Green-Sky): figure out if we can bend t5 to work too
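
The one functional change in this hunk is the extra L_k == 1560 case added for Wan, which lets that key length take the same pad-to-a-multiple-of-256 route as the lengths that were already whitelisted. For reference, GGML_PAD(x, n) in ggml.h rounds x up to the next multiple of n, so the resulting pads work out as follows (worked numbers):

// kv_pad = GGML_PAD(L_k, 256) - L_k
// L_k == 77    ->  GGML_PAD(77,   256) == 256   ->  kv_pad == 179
// L_k == 1560  ->  GGML_PAD(1560, 256) == 1792  ->  kv_pad == 232
// L_k == 3952  ->  GGML_PAD(3952, 256) == 4096  ->  kv_pad == 144
// L_k == 4208  ->  GGML_PAD(4208, 256) == 4352  ->  kv_pad == 144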
@@ -1043,11 +1035,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
     }

-    // TODO(Green-Sky): more pad or disable for funny tensor shapes
+        if (!can_use_flash_attn) {
+            flash_attn = false;
+        }
+    }

     ggml_tensor* kqv = nullptr;
-    // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
-    if (can_use_flash_attn && flash_attn) {
+    if (flash_attn) {
         // LOG_DEBUG(" uses flash attention");
         if (kv_pad != 0) {
             // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);


@@ -1092,7 +1092,6 @@ public:
         sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
-        LOG_DEBUG("sigmas[sigmas.size() - 1] %f", sigmas[sigmas.size() - 1]);
         x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);

         if (control_net) {

wan.hpp

@@ -939,7 +939,7 @@ namespace WAN {
             k      = norm_k->forward(ctx, k);
             auto v = v_proj->forward(ctx, context);  // [N, n_context, dim]

-            x = ggml_nn_attention_ext(ctx, q, k, v, num_heads);  // [N, n_token, dim]
+            x = ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]

             x = o_proj->forward(ctx, x);  // [N, n_token, dim]
             return x;
@@ -1003,8 +1003,8 @@ namespace WAN {
             k_img      = norm_k_img->forward(ctx, k_img);
             auto v_img = v_img_proj->forward(ctx, context_img);  // [N, context_img_len, dim]

-            auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads);  // [N, n_token, dim]
-            x          = ggml_nn_attention_ext(ctx, q, k, v, num_heads);          // [N, n_token, dim]
+            auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
+            x          = ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn);          // [N, n_token, dim]

             x = ggml_add(ctx, x, img_x);
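
Because the Wan blocks previously leaned on the helper's defaulted trailing parameters, reaching the new flash_attn argument means spelling out the intermediate ones. Only the argument positions are visible in this diff; the names of the two boolean parameters in the middle below are guesses, so treat this as an assumed shape of the parameter list rather than the actual prototype:

// Assumed parameter list of ggml_nn_attention_ext, inferred from the call sites above.
struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
                                          struct ggml_tensor* q,
                                          struct ggml_tensor* k,
                                          struct ggml_tensor* v,
                                          int64_t n_head,
                                          struct ggml_tensor* mask = NULL,   // Wan passes NULL: no attention mask
                                          bool diag_mask_inf       = false,  // guess: causal-masking toggle
                                          bool skip_reshape        = false,  // guess: input-layout toggle
                                          bool flash_attn          = false); // the flag this commit threads through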
@@ -1033,16 +1033,16 @@ namespace WAN {
                           bool flash_attn = false)
             : dim(dim) {
             blocks["norm1"]     = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
-            blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps));
+            blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn));
             if (cross_attn_norm) {
                 blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, true));
             } else {
                 blocks["norm3"] = std::shared_ptr<GGMLBlock>(new Identity());
             }
             if (t2v_cross_attn) {
-                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps));
+                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
             } else {
-                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps));
+                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
             }
             blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
@@ -1215,6 +1215,7 @@ namespace WAN {
         // wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3074/24
         std::vector<int> axes_dim = {44, 42, 42};
         int64_t axes_dim_sum      = 128;
+        bool flash_attn           = false;
     };

     class Wan : public GGMLBlock {
@@ -1249,7 +1250,8 @@ namespace WAN {
                                   params.num_heads,
                                   params.qk_norm,
                                   params.cross_attn_norm,
-                                  params.eps));
+                                  params.eps,
+                                  params.flash_attn));
                 blocks["blocks." + std::to_string(i)] = block;
             }
@@ -1428,6 +1430,7 @@ namespace WAN {
               SDVersion version = VERSION_WAN2,
               bool flash_attn   = false)
             : GGMLRunner(backend) {
+            wan_params.flash_attn = flash_attn;
             wan_params.num_layers = 0;
             for (auto pair : tensor_types) {
                 std::string tensor_name = pair.first;
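
End to end, the flag introduced here flows from the runner constructor down to every attention call. Summarized from the hunks above (a paraphrase, using only identifiers that appear in this diff):

// runner ctor (@@ -1428):           wan_params.flash_attn = flash_attn;
// block construction (@@ -1249):    ..., params.eps, params.flash_attn
// attention block ctor (@@ -1033):  WanSelfAttention / WanT2VCrossAttention / WanI2VCrossAttention(..., flash_attn)
// cross-attention forward (@@ -939, @@ -1003):
//                                   ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn)
// shared helper (first file):       pads K/V dim1 up to a multiple of 256 when the shape allows it,
//                                   otherwise clears flash_attn and falls back to regular attention.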