feat: add --fa option

Author: leejet
Date: 2026-02-01 21:33:09 +08:00
parent c252e03c6b
commit 580b9d1e61
8 changed files with 84 additions and 21 deletions


@@ -34,6 +34,7 @@ struct Conditioner {
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
@@ -115,6 +116,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
text_model->set_flash_attention_enabled(enabled);
if (sd_version_is_sdxl(version)) {
text_model2->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
text_model->set_weight_adapter(adapter);
if (sd_version_is_sdxl(version)) {
@@ -783,6 +791,18 @@ struct SD3CLIPEmbedder : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
}
if (clip_g) {
clip_g->set_flash_attention_enabled(enabled);
}
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (clip_l) {
clip_l->set_weight_adapter(adapter);
@@ -1191,6 +1211,15 @@ struct FluxCLIPEmbedder : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
}
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
if (clip_l) {
clip_l->set_weight_adapter(adapter);
@@ -1440,6 +1469,12 @@ struct T5CLIPEmbedder : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (t5) {
t5->set_weight_adapter(adapter);
@@ -1650,6 +1685,10 @@ struct LLMEmbedder : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
llm->set_flash_attention_enabled(enabled);
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (llm) {
llm->set_weight_adapter(adapter);


@@ -38,7 +38,7 @@ struct DiffusionModel {
virtual size_t get_params_buffer_size() = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
virtual int64_t get_adm_in_channels() = 0;
virtual void set_flash_attn_enabled(bool enabled) = 0;
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
};
@@ -84,7 +84,7 @@ struct UNetModel : public DiffusionModel {
return unet.unet.adm_in_channels;
}
void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
unet.set_flash_attention_enabled(enabled);
}
@@ -149,7 +149,7 @@ struct MMDiTModel : public DiffusionModel {
return 768 + 1280;
}
void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
mmdit.set_flash_attention_enabled(enabled);
}
@@ -215,7 +215,7 @@ struct FluxModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
flux.set_flash_attention_enabled(enabled);
}
@@ -286,7 +286,7 @@ struct WanModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
wan.set_flash_attention_enabled(enabled);
}
@@ -357,7 +357,7 @@ struct QwenImageModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
qwen_image.set_flash_attention_enabled(enabled);
}
@@ -424,7 +424,7 @@ struct ZImageModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
z_image.set_flash_attention_enabled(enabled);
}


@@ -457,6 +457,7 @@ struct SDContextParams {
bool control_net_cpu = false;
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool flash_attn = false;
bool diffusion_flash_attn = false;
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
@@ -615,9 +616,13 @@ struct SDContextParams {
"--vae-on-cpu",
"keep vae in cpu (for low vram)",
true, &vae_on_cpu},
{"",
"--fa",
"use flash attention",
true, &flash_attn},
{"",
"--diffusion-fa",
"use flash attention in the diffusion model",
"use flash attention in the diffusion model only",
true, &diffusion_flash_attn},
{"",
"--diffusion-conv-direct",
@@ -904,6 +909,7 @@ struct SDContextParams {
<< " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
<< " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
<< " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
<< " flash_attn: " << (flash_attn ? "true" : "false") << ",\n"
<< " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
<< " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
<< " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
@@ -968,6 +974,7 @@ struct SDContextParams {
clip_on_cpu,
control_net_cpu,
vae_on_cpu,
flash_attn,
diffusion_flash_attn,
taesd_preview,
diffusion_conv_direct,

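Usage note: with this commit, --fa turns flash attention on for every component (the conditioner/text encoders, CLIP vision, the VAE or TAESD, and the diffusion model), while the pre-existing --diffusion-fa keeps it limited to the diffusion model. A hypothetical invocation (binary name and model path are placeholders, not taken from this diff) would look like: sd --fa -m model.safetensors -p "a photo of a cat".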

@@ -2623,7 +2623,7 @@ public:
v = v_proj->forward(ctx, x);
}
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask); // [N, n_token, embed_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask, false); // [N, n_token, embed_dim]
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
return x;


@@ -445,7 +445,7 @@ public:
}
}
if (is_chroma) {
if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) {
if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
LOG_WARN(
"!!!It looks like you are using Chroma with flash attention. "
"This is currently unsupported. "
@@ -571,14 +571,6 @@ public:
}
}
if (sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attn_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attn_enabled(true);
}
}
cond_stage_model->alloc_params_buffer();
cond_stage_model->get_param_tensors(tensors);
@@ -725,6 +717,28 @@ public:
pmid_model->get_param_tensors(tensors, "pmid");
}
if (sd_ctx_params->flash_attn) {
LOG_INFO("Using flash attention");
cond_stage_model->set_flash_attention_enabled(true);
if (clip_vision) {
clip_vision->set_flash_attention_enabled(true);
}
if (first_stage_model) {
first_stage_model->set_flash_attention_enabled(true);
}
if (tae_first_stage) {
tae_first_stage->set_flash_attention_enabled(true);
}
}
if (sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attention_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attention_enabled(true);
}
}
diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
@@ -2942,6 +2956,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"keep_clip_on_cpu: %s\n"
"keep_control_net_on_cpu: %s\n"
"keep_vae_on_cpu: %s\n"
"flash_attn: %s\n"
"diffusion_flash_attn: %s\n"
"circular_x: %s\n"
"circular_y: %s\n"
@@ -2973,6 +2988,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
BOOL_STR(sd_ctx_params->flash_attn),
BOOL_STR(sd_ctx_params->diffusion_flash_attn),
BOOL_STR(sd_ctx_params->circular_x),
BOOL_STR(sd_ctx_params->circular_y),


@@ -189,6 +189,7 @@ typedef struct {
bool keep_clip_on_cpu;
bool keep_control_net_on_cpu;
bool keep_vae_on_cpu;
bool flash_attn;
bool diffusion_flash_attn;
bool tae_preview_only;
bool diffusion_conv_direct;

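For library users, the new flash_attn field lands next to diffusion_flash_attn in sd_ctx_params_t. Below is a minimal sketch of setting it through the C API; the sd_ctx_params_init(), new_sd_ctx() and free_sd_ctx() entry points and the model_path field are assumptions based on the existing stable-diffusion.h, not part of this diff.

#include "stable-diffusion.h"

int main() {
    sd_ctx_params_t params;
    sd_ctx_params_init(&params);              // assumed default-init helper
    params.model_path = "model.safetensors";  // hypothetical model path
    params.flash_attn = true;                 // new field: flash attention in all models
    // params.diffusion_flash_attn = true;    // older field: diffusion model only
    sd_ctx_t* sd_ctx = new_sd_ctx(&params);   // assumed constructor taking the params struct
    if (sd_ctx != nullptr) {
        free_sd_ctx(sd_ctx);
    }
    return 0;
}

Internally, context creation is where this flag fans out through the set_flash_attention_enabled() calls shown in the stable-diffusion.cpp hunks above.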

@@ -141,7 +141,7 @@ public:
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
}
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);
if (use_linear) {
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]


@@ -573,7 +573,7 @@ namespace WAN {
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false); // [t, h * w, c]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled); // [t, h * w, c]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]