mirror of https://github.com/leejet/stable-diffusion.cpp.git
commit 00f790d0e9 (parent 1d9ccea41a)

    make flash attn work with wan
@@ -1011,12 +1011,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     float scale = (1.0f / sqrt((float)d_head));
 
     int kv_pad = 0;
-    // if (flash_attn) {
-    //     LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
-    // }
-    // is there anything oddly shaped?? ping Green-Sky if you can trip this assert
-    // GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
-
+    if (flash_attn) {
+        // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
         bool can_use_flash_attn = true;
         can_use_flash_attn      = can_use_flash_attn && (d_head == 64 ||
                                                          d_head == 80 ||
@@ -1024,18 +1020,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
                                                          d_head == 112 ||
                                                          d_head == 128 ||
                                                          d_head == 256);
-#if 0
-        can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
-#else
         if (can_use_flash_attn && L_k % 256 != 0) {
             // TODO(Green-Sky): might be worth just padding by default
-            if (L_k == 77 || L_k == 4208 || L_k == 3952) {
+            if (L_k == 77 || L_k == 1560 || L_k == 4208 || L_k == 3952) {
                 kv_pad = GGML_PAD(L_k, 256) - L_k;
             } else {
                 can_use_flash_attn = false;
             }
         }
-#endif
 
         if (mask != nullptr) {
             // TODO(Green-Sky): figure out if we can bend t5 to work too
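Note on kv_pad: GGML_PAD(L_k, 256) - L_k is just the number of zero rows needed to round the K/V length up to the next multiple of 256, the granularity the flash path here is written around. A quick standalone check of the whitelisted lengths, assuming the usual GGML_PAD(x, n) round-up macro from ggml.h:

#include <stdio.h>

// Same rounding as ggml's GGML_PAD(x, n): round x up to a multiple of n.
#define PAD_UP(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main(void) {
    const int lens[] = {77, 1560, 3952, 4208};  // the L_k values special-cased above
    for (int i = 0; i < 4; i++) {
        int L_k = lens[i];
        printf("L_k = %4d -> kv_pad = %3d (padded to %4d)\n",
               L_k, PAD_UP(L_k, 256) - L_k, PAD_UP(L_k, 256));
    }
    return 0;
}

So the newly whitelisted L_k == 1560 now gets padded up to 1792 instead of falling back to the non-flash path, which is presumably the WAN-specific case this commit targets.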
@@ -1043,11 +1035,13 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
             can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1;
         }
 
-    // TODO(Green-Sky): more pad or disable for funny tensor shapes
+        if (!can_use_flash_attn) {
+            flash_attn = false;
+        }
+    }
 
     ggml_tensor* kqv = nullptr;
-    // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
-    if (can_use_flash_attn && flash_attn) {
+    if (flash_attn) {
         // LOG_DEBUG(" uses flash attention");
         if (kv_pad != 0) {
             // LOG_DEBUG(" padding k and v dim1 by %d", kv_pad);
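The hunk cuts off right where k and v get padded; that padding code is unchanged by this commit and not shown here. For orientation, a minimal sketch of that kind of code, assuming q/k/v are already in the [d_head, L, n_head, N] layout that ggml_flash_attn_ext expects; this is not the repo's exact implementation:

#include "ggml.h"

// Sketch only: pad K/V rows up to a multiple of 256 and call ggml's fused
// attention op. Assumes `mask` (if any) already covers the padded key
// positions with -INFINITY, which is what keeps the zero-padded keys out of
// the softmax.
static ggml_tensor* flash_attn_padded(ggml_context* ctx,
                                      ggml_tensor* q,
                                      ggml_tensor* k,
                                      ggml_tensor* v,
                                      ggml_tensor* mask,
                                      int kv_pad,
                                      float scale) {
    if (kv_pad != 0) {
        // ggml_pad(ctx, t, p0, p1, p2, p3) zero-pads dim i by pi at the end,
        // so this grows the sequence dimension (dim 1) of k and v by kv_pad.
        k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
        v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
    }
    // max_bias and logit_softcap left at 0; some backends additionally want
    // f16 K/V and a padded f16 mask, which the real code has to handle.
    return ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0.0f, 0.0f);
}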
@@ -1092,7 +1092,6 @@ public:
 
         sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
 
-        LOG_DEBUG("sigmas[sigmas.size() - 1] %f", sigmas[sigmas.size() - 1]);
         x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
 
         if (control_net) {
 wan.hpp | 17 ++++++++++-------
@@ -939,7 +939,7 @@ namespace WAN {
             k      = norm_k->forward(ctx, k);
             auto v = v_proj->forward(ctx, context);  // [N, n_context, dim]
 
-            x = ggml_nn_attention_ext(ctx, q, k, v, num_heads);  // [N, n_token, dim]
+            x = ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
 
             x = o_proj->forward(ctx, x);  // [N, n_token, dim]
             return x;
@@ -1003,8 +1003,8 @@ namespace WAN {
             k_img      = norm_k_img->forward(ctx, k_img);
             auto v_img = v_img_proj->forward(ctx, context_img);  // [N, context_img_len, dim]
 
-            auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads);  // [N, n_token, dim]
-            x          = ggml_nn_attention_ext(ctx, q, k, v, num_heads);          // [N, n_token, dim]
+            auto img_x = ggml_nn_attention_ext(ctx, q, k_img, v_img, num_heads, NULL, false, false, flash_attn);  // [N, n_token, dim]
+            x          = ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn);          // [N, n_token, dim]
 
             x = ggml_add(ctx, x, img_x);
 
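On these call-site changes: the three extra arguments exist only to reach the trailing flash_attn parameter; NULL and the two false values merely restate the helper's defaults, since C++ has no named arguments. Judging from the hunk header and these calls, the declaration is along the following lines (only the parameter order and defaults are implied by the diff; the two bool names are guesses, not taken from the source):

// Assumed shape of the helper's declaration.
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
                                                            struct ggml_tensor* q,
                                                            struct ggml_tensor* k,
                                                            struct ggml_tensor* v,
                                                            int64_t n_head,
                                                            struct ggml_tensor* mask = NULL,
                                                            bool diag_mask_inf       = false,  // guessed name
                                                            bool skip_reshape        = false,  // guessed name
                                                            bool flash_attn          = false);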
@@ -1033,16 +1033,16 @@ namespace WAN {
                           bool flash_attn = false)
             : dim(dim) {
             blocks["norm1"]     = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
-            blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps));
+            blocks["self_attn"] = std::shared_ptr<GGMLBlock>(new WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn));
             if (cross_attn_norm) {
                 blocks["norm3"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, true));
             } else {
                 blocks["norm3"] = std::shared_ptr<GGMLBlock>(new Identity());
             }
             if (t2v_cross_attn) {
-                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps));
+                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanT2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
             } else {
-                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps));
+                blocks["cross_attn"] = std::shared_ptr<GGMLBlock>(new WanI2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn));
             }
 
             blocks["norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
@@ -1215,6 +1215,7 @@ namespace WAN {
         // wan2.1 1.3B: 1536/12, wan2.1/2.2 14B: 5120/40, wan2.2 5B: 3074/24
         std::vector<int> axes_dim = {44, 42, 42};
         int64_t axes_dim_sum      = 128;
+        bool flash_attn           = false;
     };
 
     class Wan : public GGMLBlock {
@@ -1249,7 +1250,8 @@ namespace WAN {
                                                            params.num_heads,
                                                            params.qk_norm,
                                                            params.cross_attn_norm,
-                                                           params.eps));
+                                                           params.eps,
+                                                           params.flash_attn));
                 blocks["blocks." + std::to_string(i)] = block;
             }
 
@@ -1428,6 +1430,7 @@ namespace WAN {
               SDVersion version = VERSION_WAN2,
               bool flash_attn   = false)
         : GGMLRunner(backend) {
+        wan_params.flash_attn = flash_attn;
         wan_params.num_layers = 0;
         for (auto pair : tensor_types) {
             std::string tensor_name = pair.first;
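Taken together, the wan.hpp side of the commit is pure plumbing: one boolean travels from the runner constructor (whose flash_attn parameter already existed but previously went unused here) down to every attention call. Roughly, with class names that are not visible in the hunks described generically:

// flash_attn plumbing on the wan.hpp side, top to bottom:
//   runner ctor(..., bool flash_attn = false)
//       wan_params.flash_attn = flash_attn;                   // new in this commit
//   per-layer block ctor(..., params.eps, params.flash_attn)  // now passed explicitly
//       WanSelfAttention(dim, num_heads, qk_norm, eps, flash_attn)
//       WanT2VCrossAttention / WanI2VCrossAttention(dim, num_heads, qk_norm, eps, flash_attn)
//   attention forward():
//       ggml_nn_attention_ext(ctx, q, k, v, num_heads, NULL, false, false, flash_attn);

In the CLI this switch corresponds to the --diffusion-fa option, so with this commit that option actually reaches the WAN attention calls, which is the point of the commit title.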