mirror of https://github.com/leejet/stable-diffusion.cpp.git
feat: add --fa option (#1242)
parent: c252e03c6b
commit: f957fa3d2a
@@ -34,6 +34,7 @@ struct Conditioner {
     virtual void free_params_buffer() = 0;
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
     virtual size_t get_params_buffer_size() = 0;
+    virtual void set_flash_attention_enabled(bool enabled) = 0;
     virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
     virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
                                                                                           int n_threads,
@@ -115,6 +116,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
         return buffer_size;
     }

+    void set_flash_attention_enabled(bool enabled) override {
+        text_model->set_flash_attention_enabled(enabled);
+        if (sd_version_is_sdxl(version)) {
+            text_model2->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         text_model->set_weight_adapter(adapter);
         if (sd_version_is_sdxl(version)) {
@@ -783,6 +791,18 @@ struct SD3CLIPEmbedder : public Conditioner {
         return buffer_size;
     }

+    void set_flash_attention_enabled(bool enabled) override {
+        if (clip_l) {
+            clip_l->set_flash_attention_enabled(enabled);
+        }
+        if (clip_g) {
+            clip_g->set_flash_attention_enabled(enabled);
+        }
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         if (clip_l) {
             clip_l->set_weight_adapter(adapter);
@@ -1191,6 +1211,15 @@ struct FluxCLIPEmbedder : public Conditioner {
         return buffer_size;
     }

+    void set_flash_attention_enabled(bool enabled) override {
+        if (clip_l) {
+            clip_l->set_flash_attention_enabled(enabled);
+        }
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
         if (clip_l) {
             clip_l->set_weight_adapter(adapter);
@@ -1440,6 +1469,12 @@ struct T5CLIPEmbedder : public Conditioner {
         return buffer_size;
     }

+    void set_flash_attention_enabled(bool enabled) override {
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         if (t5) {
             t5->set_weight_adapter(adapter);
@@ -1650,6 +1685,10 @@ struct LLMEmbedder : public Conditioner {
         return buffer_size;
     }

+    void set_flash_attention_enabled(bool enabled) override {
+        llm->set_flash_attention_enabled(enabled);
+    }
+
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         if (llm) {
             llm->set_weight_adapter(adapter);
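Every Conditioner now has to implement the new pure-virtual set_flash_attention_enabled and forward the flag to whichever sub-encoders it actually owns, as the hunks above do. A minimal stand-alone sketch of that pattern; FakeTextEncoder and FakeConditioner are illustrative stand-ins, not the project's classes:

#include <memory>

// Stand-in for a text encoder such as a CLIP text model, T5, or an LLM encoder.
struct FakeTextEncoder {
    bool flash_attn_enabled = false;
    void set_flash_attention_enabled(bool enabled) { flash_attn_enabled = enabled; }
};

// Stand-in for a Conditioner subclass: forward the flag to every encoder it owns,
// guarding optional encoders with a null check, exactly like the hunks above.
struct FakeConditioner {
    std::shared_ptr<FakeTextEncoder> clip_l = std::make_shared<FakeTextEncoder>();
    std::shared_ptr<FakeTextEncoder> t5;  // optional, may stay null

    void set_flash_attention_enabled(bool enabled) {
        if (clip_l) {
            clip_l->set_flash_attention_enabled(enabled);
        }
        if (t5) {
            t5->set_flash_attention_enabled(enabled);
        }
    }
};

int main() {
    FakeConditioner cond;
    cond.set_flash_attention_enabled(true);  // propagates only to non-null encoders
    return cond.clip_l->flash_attn_enabled ? 0 : 1;
}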
@@ -38,7 +38,7 @@ struct DiffusionModel {
     virtual size_t get_params_buffer_size() = 0;
     virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
     virtual int64_t get_adm_in_channels() = 0;
-    virtual void set_flash_attn_enabled(bool enabled) = 0;
+    virtual void set_flash_attention_enabled(bool enabled) = 0;
     virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
 };

@@ -84,7 +84,7 @@ struct UNetModel : public DiffusionModel {
         return unet.unet.adm_in_channels;
     }

-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         unet.set_flash_attention_enabled(enabled);
     }

@@ -149,7 +149,7 @@ struct MMDiTModel : public DiffusionModel {
         return 768 + 1280;
     }

-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         mmdit.set_flash_attention_enabled(enabled);
     }

@@ -215,7 +215,7 @@ struct FluxModel : public DiffusionModel {
         return 768;
     }

-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         flux.set_flash_attention_enabled(enabled);
     }

@@ -286,7 +286,7 @@ struct WanModel : public DiffusionModel {
         return 768;
     }

-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         wan.set_flash_attention_enabled(enabled);
     }

@@ -357,7 +357,7 @@ struct QwenImageModel : public DiffusionModel {
         return 768;
     }

-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         qwen_image.set_flash_attention_enabled(enabled);
     }

@@ -424,7 +424,7 @@ struct ZImageModel : public DiffusionModel {
         return 768;
     }

-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         z_image.set_flash_attention_enabled(enabled);
     }

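With the rename, every DiffusionModel wrapper (UNet, MMDiT, Flux, Wan, QwenImage, ZImage) exposes the same set_flash_attention_enabled name as the conditioners, so call sites no longer need a separate spelling for the diffusion side. A hypothetical helper, not part of the commit, that this uniformity makes possible:

#include <memory>

// Hypothetical convenience helper, only a sketch: with one method name shared by
// conditioners, diffusion-model wrappers, VAE, etc., a single template can enable
// flash attention on any of them and silently skip optional (null) models.
template <typename Model>
static void enable_flash_attention(const std::shared_ptr<Model>& model, bool enabled = true) {
    if (model) {
        model->set_flash_attention_enabled(enabled);
    }
}

struct DummyModel {  // illustrative stand-in for any wrapper exposing the method
    bool fa = false;
    void set_flash_attention_enabled(bool enabled) { fa = enabled; }
};

int main() {
    auto diffusion = std::make_shared<DummyModel>();
    std::shared_ptr<DummyModel> high_noise;  // may be absent
    enable_flash_attention(diffusion);
    enable_flash_attention(high_noise);      // no-op, null is skipped
    return diffusion->fa ? 0 : 1;
}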
@@ -52,7 +52,8 @@ Context Options:
   --control-net-cpu             keep controlnet in cpu (for low vram)
   --clip-on-cpu                 keep clip in cpu (for low vram)
   --vae-on-cpu                  keep vae in cpu (for low vram)
-  --diffusion-fa                use flash attention in the diffusion model
+  --fa                          use flash attention
+  --diffusion-fa                use flash attention in the diffusion model only
   --diffusion-conv-direct       use ggml_conv2d_direct in the diffusion model
   --vae-conv-direct             use ggml_conv2d_direct in the vae model
   --circular                    enable circular padding for convolutions
@@ -457,6 +457,7 @@ struct SDContextParams {
     bool control_net_cpu = false;
     bool clip_on_cpu = false;
     bool vae_on_cpu = false;
+    bool flash_attn = false;
     bool diffusion_flash_attn = false;
     bool diffusion_conv_direct = false;
     bool vae_conv_direct = false;
@@ -615,9 +616,13 @@ struct SDContextParams {
          "--vae-on-cpu",
          "keep vae in cpu (for low vram)",
          true, &vae_on_cpu},
+        {"",
+         "--fa",
+         "use flash attention",
+         true, &flash_attn},
         {"",
          "--diffusion-fa",
-         "use flash attention in the diffusion model",
+         "use flash attention in the diffusion model only",
          true, &diffusion_flash_attn},
         {"",
          "--diffusion-conv-direct",
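The entries above look like rows of an option table that tie a flag string and its help text to a bool* target; the exact struct layout is not shown in this hunk, so the field names and the parsing loop below are assumptions, sketched only to illustrate how "--fa" ends up setting flash_attn:

#include <string>
#include <vector>

// Assumed layout of one boolean option row (the real CLI struct may differ and
// carries an extra field not reproduced here).
struct BoolOption {
    std::string short_name;   // e.g. ""
    std::string long_name;    // e.g. "--fa"
    std::string description;  // e.g. "use flash attention"
    bool* target;             // e.g. &flash_attn
};

// Set *target to true for every flag that shows up on the command line.
static void parse_bool_flags(int argc, char** argv, const std::vector<BoolOption>& options) {
    for (int i = 1; i < argc; i++) {
        for (const auto& opt : options) {
            if (opt.long_name == argv[i] ||
                (!opt.short_name.empty() && opt.short_name == argv[i])) {
                *opt.target = true;
            }
        }
    }
}

int main(int argc, char** argv) {
    bool flash_attn           = false;
    bool diffusion_flash_attn = false;
    std::vector<BoolOption> options = {
        {"", "--fa", "use flash attention", &flash_attn},
        {"", "--diffusion-fa", "use flash attention in the diffusion model only", &diffusion_flash_attn},
    };
    parse_bool_flags(argc, argv, options);
    return (flash_attn || diffusion_flash_attn) ? 0 : 1;
}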
@@ -904,6 +909,7 @@ struct SDContextParams {
            << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
            << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
            << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
+           << " flash_attn: " << (flash_attn ? "true" : "false") << ",\n"
            << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
            << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
            << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
@@ -968,6 +974,7 @@ struct SDContextParams {
            clip_on_cpu,
            control_net_cpu,
            vae_on_cpu,
+           flash_attn,
            diffusion_flash_attn,
            taesd_preview,
            diffusion_conv_direct,

@@ -44,7 +44,8 @@ Context Options:
   --clip-on-cpu                 keep clip in cpu (for low vram)
   --vae-on-cpu                  keep vae in cpu (for low vram)
   --mmap                        whether to memory-map model
-  --diffusion-fa                use flash attention in the diffusion model
+  --fa                          use flash attention
+  --diffusion-fa                use flash attention in the diffusion model only
   --diffusion-conv-direct       use ggml_conv2d_direct in the diffusion model
   --vae-conv-direct             use ggml_conv2d_direct in the vae model
   --circular                    enable circular padding for convolutions
@@ -2623,7 +2623,7 @@ public:
             v = v_proj->forward(ctx, x);
         }

-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask);         // [N, n_token, embed_dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask, false);  // [N, n_token, embed_dim]

         x = out_proj->forward(ctx, x);  // [N, n_token, embed_dim]
         return x;
@@ -445,7 +445,7 @@ public:
             }
         }
         if (is_chroma) {
-            if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) {
+            if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
                 LOG_WARN(
                     "!!!It looks like you are using Chroma with flash attention. "
                     "This is currently unsupported. "
@@ -571,14 +571,6 @@ public:
             }
         }

-        if (sd_ctx_params->diffusion_flash_attn) {
-            LOG_INFO("Using flash attention in the diffusion model");
-            diffusion_model->set_flash_attn_enabled(true);
-            if (high_noise_diffusion_model) {
-                high_noise_diffusion_model->set_flash_attn_enabled(true);
-            }
-        }
-
         cond_stage_model->alloc_params_buffer();
         cond_stage_model->get_param_tensors(tensors);

@@ -725,6 +717,28 @@ public:
             pmid_model->get_param_tensors(tensors, "pmid");
         }

+        if (sd_ctx_params->flash_attn) {
+            LOG_INFO("Using flash attention");
+            cond_stage_model->set_flash_attention_enabled(true);
+            if (clip_vision) {
+                clip_vision->set_flash_attention_enabled(true);
+            }
+            if (first_stage_model) {
+                first_stage_model->set_flash_attention_enabled(true);
+            }
+            if (tae_first_stage) {
+                tae_first_stage->set_flash_attention_enabled(true);
+            }
+        }
+
+        if (sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) {
+            LOG_INFO("Using flash attention in the diffusion model");
+            diffusion_model->set_flash_attention_enabled(true);
+            if (high_noise_diffusion_model) {
+                high_noise_diffusion_model->set_flash_attention_enabled(true);
+            }
+        }
+
         diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
         if (high_noise_diffusion_model) {
             high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
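The net effect of the two new blocks: --fa turns flash attention on everywhere it is wired up here (text encoders, clip_vision, VAE/TAESD, and the diffusion model), while --diffusion-fa keeps the earlier behaviour of enabling it in the diffusion model only. A compact restatement of that logic with illustrative variable names, not code from the commit:

#include <cstdio>

// Illustrative only; mirrors the conditions used in the two blocks above.
struct FlagSketch {
    bool flash_attn;            // set by --fa
    bool diffusion_flash_attn;  // set by --diffusion-fa
};

int main() {
    const FlagSketch cases[] = {{false, false}, {true, false}, {false, true}};
    for (const FlagSketch& f : cases) {
        bool fa_everywhere_else = f.flash_attn;                            // conditioner, clip_vision, VAE/TAE
        bool fa_in_diffusion    = f.flash_attn || f.diffusion_flash_attn;  // diffusion model(s)
        std::printf("--fa=%d --diffusion-fa=%d -> others=%d diffusion=%d\n",
                    f.flash_attn, f.diffusion_flash_attn, fa_everywhere_else, fa_in_diffusion);
    }
    return 0;
}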
@@ -2942,6 +2956,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
         "keep_clip_on_cpu: %s\n"
         "keep_control_net_on_cpu: %s\n"
         "keep_vae_on_cpu: %s\n"
+        "flash_attn: %s\n"
         "diffusion_flash_attn: %s\n"
         "circular_x: %s\n"
         "circular_y: %s\n"
@@ -2973,6 +2988,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
         BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
         BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
         BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
+        BOOL_STR(sd_ctx_params->flash_attn),
         BOOL_STR(sd_ctx_params->diffusion_flash_attn),
         BOOL_STR(sd_ctx_params->circular_x),
         BOOL_STR(sd_ctx_params->circular_y),
@@ -189,6 +189,7 @@ typedef struct {
     bool keep_clip_on_cpu;
     bool keep_control_net_on_cpu;
     bool keep_vae_on_cpu;
+    bool flash_attn;
     bool diffusion_flash_attn;
     bool tae_preview_only;
     bool diffusion_conv_direct;
vae.hpp
@@ -141,7 +141,7 @@ public:
             v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n);  // [N, h * w, in_channels]
         }

-        h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);
+        h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);

         if (use_linear) {
             h_ = proj_out->forward(ctx, h_);  // [N, h * w, in_channels]
wan.hpp
@@ -572,8 +572,8 @@ namespace WAN {
             auto v = qkv_vec[2];
             v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n);  // [t, c, h * w]

-            v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3));  // [t, h * w, c]
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);  // [t, h * w, c]
+            v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3));  // [t, h * w, c]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);  // [t, h * w, c]

             x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));  // [t, c, h * w]
             x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n);  // [t, c, h, w]
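In vae.hpp and wan.hpp the last argument to ggml_ext_attention_ext changes from a hardcoded false to the runtime flag ctx->flash_attn_enabled, so --fa now also reaches these attention blocks. For reference, a plain-C++ sketch (not ggml code) of what the non-flash path computes, softmax(Q·Kᵀ/√d)·V with the full score matrix materialized; a flash-attention kernel produces the same output without ever storing that matrix, which is where the memory savings come from:

#include <algorithm>
#include <cmath>
#include <vector>

// q: [n_q, d], k: [n_k, d], v: [n_k, d] (row-major); returns [n_q, d].
std::vector<float> naive_attention(const std::vector<float>& q,
                                   const std::vector<float>& k,
                                   const std::vector<float>& v,
                                   int n_q, int n_k, int d) {
    std::vector<float> out(n_q * d, 0.0f);
    const float scale = 1.0f / std::sqrt((float)d);
    for (int i = 0; i < n_q; i++) {
        // scores[j] = (q[i] . k[j]) / sqrt(d) -- this is the [n_q, n_k] matrix FA avoids storing
        std::vector<float> scores(n_k);
        float max_s = -1e30f;
        for (int j = 0; j < n_k; j++) {
            float s = 0.0f;
            for (int t = 0; t < d; t++) {
                s += q[i * d + t] * k[j * d + t];
            }
            scores[j] = s * scale;
            max_s = std::max(max_s, scores[j]);
        }
        // numerically stabilized softmax over the row
        float sum = 0.0f;
        for (int j = 0; j < n_k; j++) {
            scores[j] = std::exp(scores[j] - max_s);
            sum += scores[j];
        }
        // out[i] = sum_j softmax(scores)[j] * v[j]
        for (int j = 0; j < n_k; j++) {
            float w = scores[j] / sum;
            for (int t = 0; t < d; t++) {
                out[i * d + t] += w * v[j * d + t];
            }
        }
    }
    return out;
}

int main() {
    // toy sizes: 2 query tokens, 3 key/value tokens, head dim 4
    std::vector<float> q(2 * 4, 0.1f), k(3 * 4, 0.2f), v(3 * 4, 0.3f);
    std::vector<float> out = naive_attention(q, k, v, 2, 3, 4);
    return out.size() == 2 * 4 ? 0 : 1;
}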