feat: add sd3 flash attn support (#815)

This commit is contained in:
leejet 2025-09-11 23:24:29 +08:00 committed by GitHub
parent 49d6570c43
commit fce6afcc6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 36 additions and 21 deletions

View File

@ -95,8 +95,9 @@ struct MMDiTModel : public DiffusionModel {
MMDiTModel(ggml_backend_t backend,
bool offload_params_to_cpu,
bool flash_attn = false,
const String2GGMLType& tensor_types = {})
: mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
: mmdit(backend, offload_params_to_cpu, flash_attn, tensor_types, "model.diffusion_model") {
}
std::string get_desc() {

View File

@ -147,14 +147,16 @@ public:
int64_t num_heads;
bool pre_only;
std::string qk_norm;
bool flash_attn;
public:
SelfAttention(int64_t dim,
int64_t num_heads = 8,
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false)
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
bool pre_only = false,
bool flash_attn = false)
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm), flash_attn(flash_attn) {
int64_t d_head = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
if (!pre_only) {
@ -206,8 +208,8 @@ public:
ggml_backend_t backend,
struct ggml_tensor* x) {
auto qkv = pre_attention(ctx, x);
x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, true); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim]
return x;
}
};
@ -232,6 +234,7 @@ public:
int64_t num_heads;
bool pre_only;
bool self_attn;
bool flash_attn;
public:
DismantledBlock(int64_t hidden_size,
@ -240,16 +243,17 @@ public:
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false,
bool self_attn = false)
bool self_attn = false,
bool flash_attn = false)
: num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
// rmsnorm is always Flase
// scale_mod_only is always Flase
// swiglu is always Flase
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only, flash_attn));
if (self_attn) {
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false, flash_attn));
}
if (!pre_only) {
@ -435,8 +439,8 @@ public:
auto qkv2 = std::get<1>(qkv_intermediates);
auto intermediates = std::get<2>(qkv_intermediates);
auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads); // [N, n_token, dim]
auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
x = post_attention_x(ctx,
attn_out,
attn2_out,
@ -452,7 +456,7 @@ public:
auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second;
auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
x = post_attention(ctx,
attn_out,
intermediates[0],
@ -468,6 +472,7 @@ public:
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
block_mixing(struct ggml_context* ctx,
ggml_backend_t backend,
bool flash_attn,
struct ggml_tensor* context,
struct ggml_tensor* x,
struct ggml_tensor* c,
@ -497,8 +502,8 @@ block_mixing(struct ggml_context* ctx,
qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
}
auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads); // [N, n_context + n_token, hidden_size]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, NULL, false, false, flash_attn); // [N, n_context + n_token, hidden_size]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
auto context_attn = ggml_view_3d(ctx,
attn,
attn->ne[0],
@ -556,6 +561,8 @@ block_mixing(struct ggml_context* ctx,
}
struct JointBlock : public GGMLBlock {
bool flash_attn;
public:
JointBlock(int64_t hidden_size,
int64_t num_heads,
@ -563,9 +570,11 @@ public:
std::string qk_norm = "",
bool qkv_bias = false,
bool pre_only = false,
bool self_attn_x = false) {
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
bool self_attn_x = false,
bool flash_attn = false)
: flash_attn(flash_attn) {
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false, flash_attn));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x, flash_attn));
}
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@ -576,7 +585,7 @@ public:
auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
auto x_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);
return block_mixing(ctx, backend, context, x, c, context_block, x_block);
return block_mixing(ctx, backend, flash_attn, context, x, c, context_block, x_block);
}
};
@ -634,6 +643,7 @@ protected:
int64_t context_embedder_out_dim = 1536;
int64_t hidden_size;
std::string qk_norm;
bool flash_attn = false;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
enum ggml_type wtype = GGML_TYPE_F32;
@ -641,7 +651,8 @@ protected:
}
public:
MMDiT(const String2GGMLType& tensor_types = {}) {
MMDiT(bool flash_attn = false, const String2GGMLType& tensor_types = {})
: flash_attn(flash_attn) {
// input_size is always None
// learn_sigma is always False
// register_length is alwalys 0
@ -709,7 +720,8 @@ public:
qk_norm,
true,
i == depth - 1,
i <= d_self));
i <= d_self,
flash_attn));
}
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
@ -856,9 +868,10 @@ struct MMDiTRunner : public GGMLRunner {
MMDiTRunner(ggml_backend_t backend,
bool offload_params_to_cpu,
bool flash_attn,
const String2GGMLType& tensor_types = {},
const std::string prefix = "")
: GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
: GGMLRunner(backend, offload_params_to_cpu), mmdit(flash_attn, tensor_types) {
mmdit.init(params_ctx, tensor_types, prefix);
}
@ -957,7 +970,7 @@ struct MMDiTRunner : public GGMLRunner {
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_F16;
std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false));
std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false, false));
{
LOG_INFO("loading from '%s'", file_path.c_str());

View File

@ -350,6 +350,7 @@ public:
model_loader.tensor_storages_types);
diffusion_model = std::make_shared<MMDiTModel>(backend,
offload_params_to_cpu,
sd_ctx_params->diffusion_flash_attn,
model_loader.tensor_storages_types);
} else if (sd_version_is_flux(version)) {
bool is_chroma = false;