feat: add sd3 flash attn support (#815)

This commit is contained in:
leejet 2025-09-11 23:24:29 +08:00 committed by GitHub
parent 49d6570c43
commit fce6afcc6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 36 additions and 21 deletions

View File

@ -95,8 +95,9 @@ struct MMDiTModel : public DiffusionModel {
MMDiTModel(ggml_backend_t backend, MMDiTModel(ggml_backend_t backend,
bool offload_params_to_cpu, bool offload_params_to_cpu,
bool flash_attn = false,
const String2GGMLType& tensor_types = {}) const String2GGMLType& tensor_types = {})
: mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") { : mmdit(backend, offload_params_to_cpu, flash_attn, tensor_types, "model.diffusion_model") {
} }
std::string get_desc() { std::string get_desc() {

View File

@ -147,14 +147,16 @@ public:
int64_t num_heads; int64_t num_heads;
bool pre_only; bool pre_only;
std::string qk_norm; std::string qk_norm;
bool flash_attn;
public: public:
SelfAttention(int64_t dim, SelfAttention(int64_t dim,
int64_t num_heads = 8, int64_t num_heads = 8,
std::string qk_norm = "", std::string qk_norm = "",
bool qkv_bias = false, bool qkv_bias = false,
bool pre_only = false) bool pre_only = false,
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) { bool flash_attn = false)
: num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm), flash_attn(flash_attn) {
int64_t d_head = dim / num_heads; int64_t d_head = dim / num_heads;
blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias)); blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
if (!pre_only) { if (!pre_only) {
@ -206,7 +208,7 @@ public:
ggml_backend_t backend, ggml_backend_t backend,
struct ggml_tensor* x) { struct ggml_tensor* x) {
auto qkv = pre_attention(ctx, x); auto qkv = pre_attention(ctx, x);
x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, true); // [N, n_token, dim]
x = post_attention(ctx, x); // [N, n_token, dim] x = post_attention(ctx, x); // [N, n_token, dim]
return x; return x;
} }
@ -232,6 +234,7 @@ public:
int64_t num_heads; int64_t num_heads;
bool pre_only; bool pre_only;
bool self_attn; bool self_attn;
bool flash_attn;
public: public:
DismantledBlock(int64_t hidden_size, DismantledBlock(int64_t hidden_size,
@ -240,16 +243,17 @@ public:
std::string qk_norm = "", std::string qk_norm = "",
bool qkv_bias = false, bool qkv_bias = false,
bool pre_only = false, bool pre_only = false,
bool self_attn = false) bool self_attn = false,
bool flash_attn = false)
: num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) { : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
// rmsnorm is always Flase // rmsnorm is always Flase
// scale_mod_only is always Flase // scale_mod_only is always Flase
// swiglu is always Flase // swiglu is always Flase
blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false)); blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only)); blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only, flash_attn));
if (self_attn) { if (self_attn) {
blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false)); blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false, flash_attn));
} }
if (!pre_only) { if (!pre_only) {
@ -435,8 +439,8 @@ public:
auto qkv2 = std::get<1>(qkv_intermediates); auto qkv2 = std::get<1>(qkv_intermediates);
auto intermediates = std::get<2>(qkv_intermediates); auto intermediates = std::get<2>(qkv_intermediates);
auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads); // [N, n_token, dim] auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
x = post_attention_x(ctx, x = post_attention_x(ctx,
attn_out, attn_out,
attn2_out, attn2_out,
@ -452,7 +456,7 @@ public:
auto qkv = qkv_intermediates.first; auto qkv = qkv_intermediates.first;
auto intermediates = qkv_intermediates.second; auto intermediates = qkv_intermediates.second;
auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
x = post_attention(ctx, x = post_attention(ctx,
attn_out, attn_out,
intermediates[0], intermediates[0],
@ -468,6 +472,7 @@ public:
__STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*> __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
block_mixing(struct ggml_context* ctx, block_mixing(struct ggml_context* ctx,
ggml_backend_t backend, ggml_backend_t backend,
bool flash_attn,
struct ggml_tensor* context, struct ggml_tensor* context,
struct ggml_tensor* x, struct ggml_tensor* x,
struct ggml_tensor* c, struct ggml_tensor* c,
@ -497,7 +502,7 @@ block_mixing(struct ggml_context* ctx,
qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1)); qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
} }
auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads); // [N, n_context + n_token, hidden_size] auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, NULL, false, false, flash_attn); // [N, n_context + n_token, hidden_size]
attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size] attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
auto context_attn = ggml_view_3d(ctx, auto context_attn = ggml_view_3d(ctx,
attn, attn,
@ -556,6 +561,8 @@ block_mixing(struct ggml_context* ctx,
} }
struct JointBlock : public GGMLBlock { struct JointBlock : public GGMLBlock {
bool flash_attn;
public: public:
JointBlock(int64_t hidden_size, JointBlock(int64_t hidden_size,
int64_t num_heads, int64_t num_heads,
@ -563,9 +570,11 @@ public:
std::string qk_norm = "", std::string qk_norm = "",
bool qkv_bias = false, bool qkv_bias = false,
bool pre_only = false, bool pre_only = false,
bool self_attn_x = false) { bool self_attn_x = false,
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only)); bool flash_attn = false)
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x)); : flash_attn(flash_attn) {
blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false, flash_attn));
blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x, flash_attn));
} }
std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx, std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
@ -576,7 +585,7 @@ public:
auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]); auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
auto x_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]); auto x_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);
return block_mixing(ctx, backend, context, x, c, context_block, x_block); return block_mixing(ctx, backend, flash_attn, context, x, c, context_block, x_block);
} }
}; };
@ -634,6 +643,7 @@ protected:
int64_t context_embedder_out_dim = 1536; int64_t context_embedder_out_dim = 1536;
int64_t hidden_size; int64_t hidden_size;
std::string qk_norm; std::string qk_norm;
bool flash_attn = false;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") { void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
enum ggml_type wtype = GGML_TYPE_F32; enum ggml_type wtype = GGML_TYPE_F32;
@ -641,7 +651,8 @@ protected:
} }
public: public:
MMDiT(const String2GGMLType& tensor_types = {}) { MMDiT(bool flash_attn = false, const String2GGMLType& tensor_types = {})
: flash_attn(flash_attn) {
// input_size is always None // input_size is always None
// learn_sigma is always False // learn_sigma is always False
// register_length is alwalys 0 // register_length is alwalys 0
@ -709,7 +720,8 @@ public:
qk_norm, qk_norm,
true, true,
i == depth - 1, i == depth - 1,
i <= d_self)); i <= d_self,
flash_attn));
} }
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels)); blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
@ -856,9 +868,10 @@ struct MMDiTRunner : public GGMLRunner {
MMDiTRunner(ggml_backend_t backend, MMDiTRunner(ggml_backend_t backend,
bool offload_params_to_cpu, bool offload_params_to_cpu,
bool flash_attn,
const String2GGMLType& tensor_types = {}, const String2GGMLType& tensor_types = {},
const std::string prefix = "") const std::string prefix = "")
: GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) { : GGMLRunner(backend, offload_params_to_cpu), mmdit(flash_attn, tensor_types) {
mmdit.init(params_ctx, tensor_types, prefix); mmdit.init(params_ctx, tensor_types, prefix);
} }
@ -957,7 +970,7 @@ struct MMDiTRunner : public GGMLRunner {
// ggml_backend_t backend = ggml_backend_cuda_init(0); // ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init(); ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_F16; ggml_type model_data_type = GGML_TYPE_F16;
std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false)); std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false, false));
{ {
LOG_INFO("loading from '%s'", file_path.c_str()); LOG_INFO("loading from '%s'", file_path.c_str());

View File

@ -350,6 +350,7 @@ public:
model_loader.tensor_storages_types); model_loader.tensor_storages_types);
diffusion_model = std::make_shared<MMDiTModel>(backend, diffusion_model = std::make_shared<MMDiTModel>(backend,
offload_params_to_cpu, offload_params_to_cpu,
sd_ctx_params->diffusion_flash_attn,
model_loader.tensor_storages_types); model_loader.tensor_storages_types);
} else if (sd_version_is_flux(version)) { } else if (sd_version_is_flux(version)) {
bool is_chroma = false; bool is_chroma = false;