mirror of https://github.com/leejet/stable-diffusion.cpp.git (synced 2025-12-13 05:48:56 +00:00)
make flash attn work with wan
commit 00f790d0e9
parent 1d9ccea41a
@@ -1011,43 +1011,37 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));

     int kv_pad = 0;
-    if (flash_attn) {
-        // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-
-        bool can_use_flash_attn = true;
-        can_use_flash_attn = can_use_flash_attn && (d_head == 64 ||
-                                                    d_head == 80 ||
-                                                    d_head == 96 ||
-                                                    d_head == 112 ||
-                                                    d_head == 128 ||
-                                                    d_head == 256);
-#if 0
-        can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
-#else
-        if (can_use_flash_attn && L_k % 256 != 0) {
-            // TODO(Green-Sky): might be worth just padding by default
-            if (L_k == 77 || L_k == 4208 || L_k == 3952) {
-                kv_pad = GGML_PAD(L_k, 256) - L_k;
-            } else {
-                can_use_flash_attn = false;
-            }
-        }
-#endif
-
-        if (mask != nullptr) {
-            // TODO(Green-Sky): figure out if we can bend t5 to work too
-            can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
-            can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
-        }
-
-        if (!can_use_flash_attn) {
-            flash_attn = false;
-        }
-    }
+    // if (flash_attn) {
+    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
+    // }
+    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
+    // GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
+    bool can_use_flash_attn = true;
+    can_use_flash_attn = can_use_flash_attn && (d_head == 64 ||
+                                                d_head == 80 ||
+                                                d_head == 96 ||
+                                                d_head == 112 ||
+                                                d_head == 128 ||
+                                                d_head == 256);
+    if (can_use_flash_attn && L_k % 256 != 0) {
+        // TODO(Green-Sky): might be worth just padding by default
+        if (L_k == 77 || L_k == 1560 || L_k == 4208 || L_k == 3952) {
+            kv_pad = GGML_PAD(L_k, 256) - L_k;
+        } else {
+            can_use_flash_attn = false;
+        }
+    }
+
+    if (mask != nullptr) {
+        // TODO(Green-Sky): figure out if we can bend t5 to work too
+        can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1;
+        can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
+    }
     // TODO(Green-Sky): more pad or disable for funny tensor shapes

     ggml_tensor* kqv = nullptr;
-    if (flash_attn) {
+    // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
+    if (can_use_flash_attn && flash_attn) {
         // LOG_DEBUG(" uses flash attention");
         if (kv_pad != 0) {
             // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
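For reference (not part of the diff): the whitelist above pads K and V so that the key/value sequence length becomes a multiple of 256, which is what the flash-attention path here expects; 77 is the usual CLIP text context length, and 1560, 3952 and 4208 appear to be Wan-specific lengths. Below is a minimal standalone sketch of that arithmetic, with a local round_up helper standing in for GGML_PAD (assumed to round up to the next multiple of n).

#include <cstdio>

// Stand-in for GGML_PAD: round x up to the next multiple of n. This is an
// assumption about its behaviour; only the rounded-up result matters here.
static int round_up(int x, int n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    // L_k values whitelisted in the diff above.
    const int lengths[] = {77, 1560, 3952, 4208};
    for (int L_k : lengths) {
        int kv_pad = round_up(L_k, 256) - L_k;  // extra K/V rows to append
        std::printf("L_k=%4d -> padded to %4d (kv_pad=%3d)\n", L_k, L_k + kv_pad, kv_pad);
    }
    return 0;
}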
@@ -1092,7 +1092,6 @@ public:

         sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);

-        LOG_DEBUG("sigmas[sigmas.size() - 1] %f", sigmas[sigmas.size() - 1]);
         x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);

         if (control_net) {
wan.hpp (17 lines changed)
@@ -939,7 +939,7 @@ namespace WAN {
         k = norm_k->forward(ctx, k);
         auto v = v_proj->forward(ctx, context);  // [N, n_context, dim]

-        x = ggml_nn_attention_ext(ctx, q, k, v, num_heads);  // [N, n_token, dim]
+        x = ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]

         x = o_proj->forward(ctx, x);  // [N, n_token, dim]
         return x;
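The extra arguments at these call sites only exist to reach the trailing parameter: the defaults are passed explicitly so that flash_attn can be set. A sketch of the declaration this implies follows; only the function name, the ctx/q/k/v/num_heads arguments and the trailing flash_attn flag are visible in the diff, and the three middle parameter names are assumptions made for illustration.

#include <cstdint>

struct ggml_context;  // opaque, provided by ggml
struct ggml_tensor;   // opaque, provided by ggml

// Hypothetical declaration matching the call sites in this diff.
ggml_tensor* ggml_nn_attention_ext(ggml_context* ctx,
                                   ggml_tensor* q,
                                   ggml_tensor* k,
                                   ggml_tensor* v,
                                   int64_t n_head,
                                   ggml_tensor* mask  = nullptr,  // assumed name
                                   bool diag_mask_inf = false,    // assumed name
                                   bool skip_reshape  = false,    // assumed name
                                   bool flash_attn    = false);

Passing NULL, false, false for the middle parameters keeps the previous behaviour and only toggles the flash-attention path.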
@@ -1003,8 +1003,8 @@ namespace WAN {
         k_img = norm_k_img->forward(ctx, k_img);
         auto v_img = v_img_proj->forward(ctx, context_img);  // [N, context_img_len, dim]

-        auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads);  // [N, n_token, dim]
-        x = ggml_nn_attention_ext(ctx, q, k, v, num_heads);  // [N, n_token, dim]
+        auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
+        x = ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]

         x = ggml_add(ctx, x, img_x);

@@ -1033,16 +1033,16 @@ namespace WAN {
                           bool flash_attn = false)
             : dim(dim) {
             blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
-            blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps));
+            blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn));
             if (cross_attn_norm) {
                 blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, true));
             } else {
                 blocks["norm3"] = std::shared_ptr<GGMLBlock>(new Identity());
             }
             if (t2v_cross_attn) {
-                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps));
+                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
             } else {
-                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps));
+                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
             }

             blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
@@ -1215,6 +1215,7 @@ namespace WAN {
         // wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3074/24
         std::vector<int> axes_dim = {44, 42, 42};
         int64_t axes_dim_sum = 128;
+        bool flash_attn = false;
     };

     class Wan : public GGMLBlock {
@@ -1249,7 +1250,8 @@ namespace WAN {
                                              params.num_heads,
                                              params.qk_norm,
                                              params.cross_attn_norm,
-                                             params.eps));
+                                             params.eps,
+                                             params.flash_attn));
             blocks["blocks." + std::to_string(i)] = block;
         }

@@ -1428,6 +1430,7 @@ namespace WAN {
                   SDVersion version = VERSION_WAN2,
                   bool flash_attn = false)
             : GGMLRunner(backend) {
+            wan_params.flash_attn = flash_attn;
             wan_params.num_layers = 0;
             for (auto pair : tensor_types) {
                 std::string tensor_name = pair.first;
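Taken together, the wan.hpp hunks only thread the existing flash_attn option down to the attention calls. A condensed sketch of the flow after this commit, drawn from the hunks above (constructors simplified; the runner class name is not visible in this excerpt):

// runner ctor(..., bool flash_attn = false)
//     wan_params.flash_attn = flash_attn;                    // new in this commit
// Wan builds each block with params.flash_attn               // new ctor argument
// WanAttentionBlock(..., flash_attn)
//     -> WanSelfAttention(..., flash_attn)
//     -> WanT2VCrossAttention / WanI2VCrossAttention(..., flash_attn)
// cross-attention forward()
//     -> ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn);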