Mirror of https://github.com/leejet/stable-diffusion.cpp.git (synced 2025-12-13 05:48:56 +00:00)
feat: add sd3 flash attn support (#815)

commit fce6afcc6a
parent 49d6570c43
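
This commit threads a single flash_attn flag from the context parameters down through MMDiTModel, MMDiTRunner, MMDiT, JointBlock, DismantledBlock and SelfAttention to the ggml_nn_attention_ext calls. A rough usage sketch of how a caller opts in follows; only the diffusion_flash_attn field appears in this diff, so the surrounding names (sd_ctx_params_t, new_sd_ctx, model_path) are assumptions about the library API rather than part of the commit:

    // Sketch only: diffusion_flash_attn is the field read in the last hunk below; every
    // other name here (sd_ctx_params_t, new_sd_ctx, model_path) is an assumed API detail.
    sd_ctx_params_t params = {};
    params.model_path           = "sd3_medium.safetensors";  // hypothetical model file
    params.diffusion_flash_attn = true;                      // enable flash attention in the MMDiT
    sd_ctx_t* ctx = new_sd_ctx(&params);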
@@ -95,8 +95,9 @@ struct MMDiTModel : public DiffusionModel {

     MMDiTModel(ggml_backend_t backend,
                bool offload_params_to_cpu,
+               bool flash_attn = false,
                const String2GGMLType& tensor_types = {})
-        : mmdit(backend, offload_params_to_cpu, tensor_types, "model.diffusion_model") {
+        : mmdit(backend, offload_params_to_cpu, flash_attn, tensor_types, "model.diffusion_model") {
     }

     std::string get_desc() {

mmdit.hpp (49 changed lines)
@@ -147,14 +147,16 @@ public:
     int64_t num_heads;
     bool pre_only;
     std::string qk_norm;
+    bool flash_attn;

 public:
     SelfAttention(int64_t dim,
                   int64_t num_heads = 8,
                   std::string qk_norm = "",
                   bool qkv_bias = false,
-                  bool pre_only = false)
+                  bool pre_only = false,
+                  bool flash_attn = false)
-        : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) {
+        : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm), flash_attn(flash_attn) {
         int64_t d_head = dim / num_heads;
         blocks["qkv"] = std::shared_ptr<GGMLBlock>(new Linear(dim, dim * 3, qkv_bias));
         if (!pre_only) {
@@ -206,7 +208,7 @@ public:
                                 ggml_backend_t backend,
                                 struct ggml_tensor* x) {
         auto qkv = pre_attention(ctx, x);
-        x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
+        x = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, true); // [N, n_token, dim]
         x = post_attention(ctx, x); // [N, n_token, dim]
         return x;
     }
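
The trailing arguments appended to ggml_nn_attention_ext here and in the hunks below follow the pattern NULL, false, false, <flag>. A sketch of the declaration this implies is shown next; the parameter names after num_heads are assumptions inferred from these call sites, not copied from ggml_extend.hpp. Note that this particular call site hard-codes true for the final argument, while the later call sites pass the stored flash_attn flag.

    // Assumed shape of the attention helper, inferred from the call sites in this diff.
    // The parameter names after num_heads are illustrative assumptions.
    struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
                                              ggml_backend_t backend,
                                              struct ggml_tensor* q,
                                              struct ggml_tensor* k,
                                              struct ggml_tensor* v,
                                              int64_t num_heads,
                                              struct ggml_tensor* mask = NULL,
                                              bool diag_mask_inf       = false,
                                              bool skip_reshape        = false,
                                              bool flash_attn          = false);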
@@ -232,6 +234,7 @@ public:
     int64_t num_heads;
     bool pre_only;
     bool self_attn;
+    bool flash_attn;

 public:
     DismantledBlock(int64_t hidden_size,
@@ -240,16 +243,17 @@ public:
                     std::string qk_norm = "",
                     bool qkv_bias = false,
                     bool pre_only = false,
-                    bool self_attn = false)
+                    bool self_attn = false,
+                    bool flash_attn = false)
         : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) {
         // rmsnorm is always Flase
         // scale_mod_only is always Flase
         // swiglu is always Flase
         blocks["norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(hidden_size, 1e-06f, false));
-        blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only));
+        blocks["attn"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only, flash_attn));

         if (self_attn) {
-            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false));
+            blocks["attn2"] = std::shared_ptr<GGMLBlock>(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false, flash_attn));
         }

         if (!pre_only) {
@@ -435,8 +439,8 @@ public:
             auto qkv2 = std::get<1>(qkv_intermediates);
             auto intermediates = std::get<2>(qkv_intermediates);

-            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
-            auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads); // [N, n_token, dim]
+            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
+            auto attn2_out = ggml_nn_attention_ext(ctx, backend, qkv2[0], qkv2[1], qkv2[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
             x = post_attention_x(ctx,
                                  attn_out,
                                  attn2_out,
@@ -452,7 +456,7 @@ public:
             auto qkv = qkv_intermediates.first;
             auto intermediates = qkv_intermediates.second;

-            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim]
+            auto attn_out = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], num_heads, NULL, false, false, flash_attn); // [N, n_token, dim]
             x = post_attention(ctx,
                                attn_out,
                                intermediates[0],
@@ -468,6 +472,7 @@ public:
 __STATIC_INLINE__ std::pair<struct ggml_tensor*, struct ggml_tensor*>
 block_mixing(struct ggml_context* ctx,
              ggml_backend_t backend,
+             bool flash_attn,
              struct ggml_tensor* context,
              struct ggml_tensor* x,
              struct ggml_tensor* c,
@@ -497,7 +502,7 @@ block_mixing(struct ggml_context* ctx,
         qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1));
     }

-    auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads); // [N, n_context + n_token, hidden_size]
+    auto attn = ggml_nn_attention_ext(ctx, backend, qkv[0], qkv[1], qkv[2], x_block->num_heads, NULL, false, false, flash_attn); // [N, n_context + n_token, hidden_size]
     attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_context + n_token, N, hidden_size]
     auto context_attn = ggml_view_3d(ctx,
                                      attn,
@@ -556,6 +561,8 @@ block_mixing(struct ggml_context* ctx,
 }

 struct JointBlock : public GGMLBlock {
+    bool flash_attn;
+
 public:
     JointBlock(int64_t hidden_size,
                int64_t num_heads,
@@ -563,9 +570,11 @@ public:
                std::string qk_norm = "",
                bool qkv_bias = false,
                bool pre_only = false,
-               bool self_attn_x = false) {
-        blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only));
-        blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x));
+               bool self_attn_x = false,
+               bool flash_attn = false)
+        : flash_attn(flash_attn) {
+        blocks["context_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only, false, flash_attn));
+        blocks["x_block"] = std::shared_ptr<GGMLBlock>(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x, flash_attn));
     }

     std::pair<struct ggml_tensor*, struct ggml_tensor*> forward(struct ggml_context* ctx,
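
Every level repeats the same plumbing: a defaulted flash_attn parameter is stored on the block and forwarded to the children it constructs, so existing call sites compile unchanged with flash attention off. A minimal self-contained illustration of that pattern, using generic names rather than the repository's classes:

    #include <memory>

    // Generic illustration of the commit's plumbing pattern: a defaulted bool is stored
    // on the parent block and forwarded to every child it constructs.
    struct Attention {
        bool flash_attn;
        explicit Attention(bool flash_attn = false) : flash_attn(flash_attn) {}
    };

    struct Block {
        bool flash_attn;
        std::shared_ptr<Attention> attn;
        explicit Block(bool flash_attn = false)
            : flash_attn(flash_attn), attn(std::make_shared<Attention>(flash_attn)) {}
    };

    int main() {
        Block legacy;      // old call sites are untouched and keep flash_attn == false
        Block fast(true);  // new call sites opt in and the flag reaches the leaf Attention
        return (legacy.attn->flash_attn == false && fast.attn->flash_attn == true) ? 0 : 1;
    }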
@@ -576,7 +585,7 @@ public:
         auto context_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["context_block"]);
         auto x_block = std::dynamic_pointer_cast<DismantledBlock>(blocks["x_block"]);

-        return block_mixing(ctx, backend, context, x, c, context_block, x_block);
+        return block_mixing(ctx, backend, flash_attn, context, x, c, context_block, x_block);
     }
 };

@@ -634,6 +643,7 @@ protected:
     int64_t context_embedder_out_dim = 1536;
     int64_t hidden_size;
     std::string qk_norm;
+    bool flash_attn = false;

     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, std::string prefix = "") {
         enum ggml_type wtype = GGML_TYPE_F32;
@@ -641,7 +651,8 @@ protected:
     }

 public:
-    MMDiT(const String2GGMLType& tensor_types = {}) {
+    MMDiT(bool flash_attn = false, const String2GGMLType& tensor_types = {})
+        : flash_attn(flash_attn) {
         // input_size is always None
         // learn_sigma is always False
         // register_length is alwalys 0
@@ -709,7 +720,8 @@ public:
                                                                    qk_norm,
                                                                    true,
                                                                    i == depth - 1,
-                                                                   i <= d_self));
+                                                                   i <= d_self,
+                                                                   flash_attn));
         }

         blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new FinalLayer(hidden_size, patch_size, out_channels));
@@ -856,9 +868,10 @@ struct MMDiTRunner : public GGMLRunner {

     MMDiTRunner(ggml_backend_t backend,
                 bool offload_params_to_cpu,
+                bool flash_attn,
                 const String2GGMLType& tensor_types = {},
                 const std::string prefix = "")
-        : GGMLRunner(backend, offload_params_to_cpu), mmdit(tensor_types) {
+        : GGMLRunner(backend, offload_params_to_cpu), mmdit(flash_attn, tensor_types) {
         mmdit.init(params_ctx, tensor_types, prefix);
     }

@@ -957,7 +970,7 @@ struct MMDiTRunner : public GGMLRunner {
     // ggml_backend_t backend = ggml_backend_cuda_init(0);
     ggml_backend_t backend = ggml_backend_cpu_init();
     ggml_type model_data_type = GGML_TYPE_F16;
-    std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false));
+    std::shared_ptr<MMDiTRunner> mmdit = std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false, false));
     {
         LOG_INFO("loading from '%s'", file_path.c_str());

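
The new flash_attn parameter of MMDiTRunner has no default, which is why this standalone test construction had to be touched; it leaves flash attention off. Turning it on there is just a matter of the third argument, as in the sketch below (built only from the constructor shown above; tensor_types and prefix keep their defaults):

    // Hypothetical variant of the test setup with flash attention enabled.
    std::shared_ptr<MMDiTRunner> mmdit_fa =
        std::shared_ptr<MMDiTRunner>(new MMDiTRunner(backend, false, /*flash_attn=*/true));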

@@ -350,6 +350,7 @@ public:
                                                            model_loader.tensor_storages_types);
             diffusion_model = std::make_shared<MMDiTModel>(backend,
                                                            offload_params_to_cpu,
+                                                           sd_ctx_params->diffusion_flash_attn,
                                                            model_loader.tensor_storages_types);
         } else if (sd_version_is_flux(version)) {
             bool is_chroma = false;