diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index b5f4274..796ae33 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1011,43 +1011,37 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));

     int kv_pad = 0;
-    // if (flash_attn) {
-    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    // }
-    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
-    // GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
+    if (flash_attn) {
+        // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+        bool can_use_flash_attn = true;
+        can_use_flash_attn = can_use_flash_attn && (d_head == 64 ||
+                                                    d_head == 80 ||
+                                                    d_head == 96 ||
+                                                    d_head == 112 ||
+                                                    d_head == 128 ||
+                                                    d_head == 256);
+        if (can_use_flash_attn && L_k % 256 != 0) {
+            // TODO(Green-Sky): might be worth just padding by default
+            if (L_k == 77 || L_k == 1560 || L_k == 4208 || L_k == 3952) {
+                kv_pad = GGML_PAD(L_k, 256) - L_k;
+            } else {
+                can_use_flash_attn = false;
+            }
+        }

-    bool can_use_flash_attn = true;
-    can_use_flash_attn = can_use_flash_attn && (d_head == 64 ||
-                                                d_head == 80 ||
-                                                d_head == 96 ||
-                                                d_head == 112 ||
-                                                d_head == 128 ||
-                                                d_head == 256);
-#if 0
-    can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
-#else
-    if (can_use_flash_attn && L_k % 256 != 0) {
-        // TODO(Green-Sky): might be worth just padding by default
-        if (L_k == 77 || L_k == 4208 || L_k == 3952) {
-            kv_pad = GGML_PAD(L_k, 256) - L_k;
-        } else {
-            can_use_flash_attn = false;
-        }
+        if (mask != nullptr) {
+            // TODO(Green-Sky): figure out if we can bend t5 to work too
+            can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
+            can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
+        }
+
+        if (!can_use_flash_attn) {
+            flash_attn = false;
+        }
     }
-#endif
-
-    if (mask != nullptr) {
-        // TODO(Green-Sky): figure out if we can bend t5 to work too
-        can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
-        can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
-    }
-
-    // TODO(Green-Sky): more pad or disable for funny tensor shapes

     ggml_tensor* kqv = nullptr;
-    // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
-    if (can_use_flash_attn && flash_attn) {
+    if (flash_attn) {
         // LOG_DEBUG(" uses flash attention");
         if (kv_pad != 0) {
             // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index c9d5c28..3ef767b 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -1092,7 +1092,6 @@ public:

         sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);

-        LOG_DEBUG("sigmas[sigmas.size() - 1] %f", sigmas[sigmas.size() - 1]);
         x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);

         if (control_net) {
diff --git a/wan.hpp b/wan.hpp
index 18dc07a..f031a23 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -939,7 +939,7 @@ namespace WAN {
             k = norm_k->forward(ctx, k);
             auto v = v_proj->forward(ctx, context);  // [N, n_context, dim]

-            x = ggml_nn_attention_ext(ctx, q, k, v, num_heads);  // [N, n_token, dim]
+            x = ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]

             x = o_proj->forward(ctx, x);  // [N, n_token, dim]
             return x;
@@ -1003,8 +1003,8 @@ namespace WAN {
             k_img = norm_k_img->forward(ctx, k_img);
             auto v_img = v_img_proj->forward(ctx, context_img);  // [N, context_img_len, dim]

-            auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads);  // [N, n_token, dim]
-            x = ggml_nn_attention_ext(ctx, q, k, v, num_heads);  // [N, n_token, dim]
+            auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
+            x = ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]

             x = ggml_add(ctx, x, img_x);

@@ -1033,16 +1033,16 @@ namespace WAN {
                           bool flash_attn = false)
             : dim(dim) {
             blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
-            blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps));
+            blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn));
             if (cross_attn_norm) {
                 blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, true));
             } else {
                 blocks["norm3"] = std::shared_ptr<GGMLBlock>(new Identity());
             }
             if (t2v_cross_attn) {
-                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps));
+                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
             } else {
-                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps));
+                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
             }

             blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
@@ -1215,6 +1215,7 @@ namespace WAN {
         // wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3072/24
         std::vector<int> axes_dim = {44, 42, 42};
         int64_t axes_dim_sum = 128;
+        bool flash_attn = false;
     };

     class Wan : public GGMLBlock {
@@ -1249,7 +1250,8 @@ namespace WAN {
                 params.num_heads,
                 params.qk_norm,
                 params.cross_attn_norm,
-                params.eps));
+                params.eps,
+                params.flash_attn));
             blocks["blocks." + std::to_string(i)] = block;
         }

@@ -1428,6 +1430,7 @@ namespace WAN {
                   SDVersion version = VERSION_WAN2,
                   bool flash_attn = false)
             : GGMLRunner(backend) {
+            wan_params.flash_attn = flash_attn;
             wan_params.num_layers = 0;
             for (auto pair : tensor_types) {
                 std::string tensor_name = pair.first;
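
For reference, the gate added to ggml_extend.hpp comes down to this: flash attention is only attempted for the listed head sizes, and a K/V length that is not a multiple of 256 is either padded up (for the whitelisted lengths 77, 1560, 4208, 3952) or causes a fallback to the regular attention path, with kv_pad recording how many rows to append to k and v along dim1 before the fused call. The sketch below is a standalone illustration of that decision logic, not code from the patch: gate_flash_attn and round_up are made-up names, round_up mirrors what GGML_PAD(L_k, 256) computes for these values, and the mask-shape check (mask->ne[2] == 1 && mask->ne[3] == 1) is omitted.

// Standalone sketch (not part of the patch): mirrors the flash-attention gate
// added to ggml_nn_attention_ext above. gate_flash_attn/round_up are
// illustrative names; the real code uses ggml's GGML_PAD macro and also
// checks the mask shape before enabling the fused path.
#include <cstdio>
#include <initializer_list>

// Round x up to the next multiple of n (what GGML_PAD(L_k, 256) yields here).
static int round_up(int x, int n) { return (x + n - 1) / n * n; }

struct FlashAttnDecision {
    bool use_flash_attn;  // fall back to the regular attention path if false
    int kv_pad;           // rows of padding to append to K/V along dim1
};

static FlashAttnDecision gate_flash_attn(int d_head, int L_k) {
    FlashAttnDecision d{true, 0};

    // The fused kernel is only wired up for these head sizes.
    d.use_flash_attn = (d_head == 64 || d_head == 80 || d_head == 96 ||
                        d_head == 112 || d_head == 128 || d_head == 256);

    // K/V length must be a multiple of 256: known-good lengths get padded,
    // anything else disables flash attention (same whitelist as the patch).
    if (d.use_flash_attn && L_k % 256 != 0) {
        if (L_k == 77 || L_k == 1560 || L_k == 4208 || L_k == 3952) {
            d.kv_pad = round_up(L_k, 256) - L_k;
        } else {
            d.use_flash_attn = false;
        }
    }
    return d;
}

int main() {
    // e.g. a CLIP-style context of 77 tokens gets padded to 256 (kv_pad = 179),
    // while an unlisted length like 100 falls back to regular attention.
    for (int L_k : {77, 1560, 4208, 100}) {
        FlashAttnDecision d = gate_flash_attn(64, L_k);
        std::printf("L_k=%4d -> flash=%d kv_pad=%d\n", L_k, d.use_flash_attn, d.kv_pad);
    }
    return 0;
}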