feat: add --fa option

leejet 2026-02-01 21:33:09 +08:00
parent c252e03c6b
commit 580b9d1e61
8 changed files with 84 additions and 21 deletions

View File

@@ -34,6 +34,7 @@ struct Conditioner {
     virtual void free_params_buffer() = 0;
     virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
     virtual size_t get_params_buffer_size() = 0;
+    virtual void set_flash_attention_enabled(bool enabled) = 0;
     virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
     virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
                                                                                           int n_threads,
@@ -115,6 +116,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
         return buffer_size;
     }
+    void set_flash_attention_enabled(bool enabled) override {
+        text_model->set_flash_attention_enabled(enabled);
+        if (sd_version_is_sdxl(version)) {
+            text_model2->set_flash_attention_enabled(enabled);
+        }
+    }
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         text_model->set_weight_adapter(adapter);
         if (sd_version_is_sdxl(version)) {
@@ -783,6 +791,18 @@ struct SD3CLIPEmbedder : public Conditioner {
         return buffer_size;
     }
+    void set_flash_attention_enabled(bool enabled) override {
+        if (clip_l) {
+            clip_l->set_flash_attention_enabled(enabled);
+        }
+        if (clip_g) {
+            clip_g->set_flash_attention_enabled(enabled);
+        }
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         if (clip_l) {
             clip_l->set_weight_adapter(adapter);
@@ -1191,6 +1211,15 @@ struct FluxCLIPEmbedder : public Conditioner {
         return buffer_size;
     }
+    void set_flash_attention_enabled(bool enabled) override {
+        if (clip_l) {
+            clip_l->set_flash_attention_enabled(enabled);
+        }
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
         if (clip_l) {
             clip_l->set_weight_adapter(adapter);
@@ -1440,6 +1469,12 @@ struct T5CLIPEmbedder : public Conditioner {
         return buffer_size;
     }
+    void set_flash_attention_enabled(bool enabled) override {
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         if (t5) {
             t5->set_weight_adapter(adapter);
@@ -1650,6 +1685,10 @@ struct LLMEmbedder : public Conditioner {
         return buffer_size;
     }
+    void set_flash_attention_enabled(bool enabled) override {
+        llm->set_flash_attention_enabled(enabled);
+    }
     void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
         if (llm) {
             llm->set_weight_adapter(adapter);

View File

@@ -38,7 +38,7 @@ struct DiffusionModel {
     virtual size_t get_params_buffer_size() = 0;
     virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
     virtual int64_t get_adm_in_channels() = 0;
-    virtual void set_flash_attn_enabled(bool enabled) = 0;
+    virtual void set_flash_attention_enabled(bool enabled) = 0;
     virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
 };
@@ -84,7 +84,7 @@ struct UNetModel : public DiffusionModel {
         return unet.unet.adm_in_channels;
     }
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         unet.set_flash_attention_enabled(enabled);
     }
@@ -149,7 +149,7 @@ struct MMDiTModel : public DiffusionModel {
         return 768 + 1280;
     }
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         mmdit.set_flash_attention_enabled(enabled);
     }
@@ -215,7 +215,7 @@ struct FluxModel : public DiffusionModel {
         return 768;
     }
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         flux.set_flash_attention_enabled(enabled);
     }
@@ -286,7 +286,7 @@ struct WanModel : public DiffusionModel {
         return 768;
     }
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         wan.set_flash_attention_enabled(enabled);
     }
@@ -357,7 +357,7 @@ struct QwenImageModel : public DiffusionModel {
         return 768;
     }
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         qwen_image.set_flash_attention_enabled(enabled);
     }
@@ -424,7 +424,7 @@ struct ZImageModel : public DiffusionModel {
         return 768;
     }
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         z_image.set_flash_attention_enabled(enabled);
     }

View File

@@ -457,6 +457,7 @@ struct SDContextParams {
     bool control_net_cpu = false;
     bool clip_on_cpu = false;
     bool vae_on_cpu = false;
+    bool flash_attn = false;
     bool diffusion_flash_attn = false;
     bool diffusion_conv_direct = false;
     bool vae_conv_direct = false;
@@ -615,9 +616,13 @@ struct SDContextParams {
          "--vae-on-cpu",
          "keep vae in cpu (for low vram)",
          true, &vae_on_cpu},
+        {"",
+         "--fa",
+         "use flash attention",
+         true, &flash_attn},
         {"",
          "--diffusion-fa",
-         "use flash attention in the diffusion model",
+         "use flash attention in the diffusion model only",
          true, &diffusion_flash_attn},
         {"",
          "--diffusion-conv-direct",
@@ -904,6 +909,7 @@ struct SDContextParams {
            << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
            << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
            << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
+           << " flash_attn: " << (flash_attn ? "true" : "false") << ",\n"
            << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
            << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
            << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
@@ -968,6 +974,7 @@ struct SDContextParams {
         clip_on_cpu,
         control_net_cpu,
         vae_on_cpu,
+        flash_attn,
         diffusion_flash_attn,
         taesd_preview,
         diffusion_conv_direct,

View File

@@ -2623,7 +2623,7 @@ public:
             v = v_proj->forward(ctx, x);
         }
-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask);  // [N, n_token, embed_dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask, false);  // [N, n_token, embed_dim]
         x = out_proj->forward(ctx, x);  // [N, n_token, embed_dim]
         return x;

View File

@@ -445,7 +445,7 @@ public:
             }
         }
         if (is_chroma) {
-            if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) {
+            if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
                 LOG_WARN(
                     "!!!It looks like you are using Chroma with flash attention. "
                     "This is currently unsupported. "
@@ -571,14 +571,6 @@ public:
             }
         }
-        if (sd_ctx_params->diffusion_flash_attn) {
-            LOG_INFO("Using flash attention in the diffusion model");
-            diffusion_model->set_flash_attn_enabled(true);
-            if (high_noise_diffusion_model) {
-                high_noise_diffusion_model->set_flash_attn_enabled(true);
-            }
-        }
         cond_stage_model->alloc_params_buffer();
         cond_stage_model->get_param_tensors(tensors);
@@ -725,6 +717,28 @@ public:
             pmid_model->get_param_tensors(tensors, "pmid");
         }
+        if (sd_ctx_params->flash_attn) {
+            LOG_INFO("Using flash attention");
+            cond_stage_model->set_flash_attention_enabled(true);
+            if (clip_vision) {
+                clip_vision->set_flash_attention_enabled(true);
+            }
+            if (first_stage_model) {
+                first_stage_model->set_flash_attention_enabled(true);
+            }
+            if (tae_first_stage) {
+                tae_first_stage->set_flash_attention_enabled(true);
+            }
+        }
+        if (sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) {
+            LOG_INFO("Using flash attention in the diffusion model");
+            diffusion_model->set_flash_attention_enabled(true);
+            if (high_noise_diffusion_model) {
+                high_noise_diffusion_model->set_flash_attention_enabled(true);
+            }
+        }
         diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
         if (high_noise_diffusion_model) {
             high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
@@ -2942,6 +2956,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
                              "keep_clip_on_cpu: %s\n"
                              "keep_control_net_on_cpu: %s\n"
                              "keep_vae_on_cpu: %s\n"
+                             "flash_attn: %s\n"
                              "diffusion_flash_attn: %s\n"
                              "circular_x: %s\n"
                              "circular_y: %s\n"
@@ -2973,6 +2988,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
                  BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
                  BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
                  BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
+                 BOOL_STR(sd_ctx_params->flash_attn),
                  BOOL_STR(sd_ctx_params->diffusion_flash_attn),
                  BOOL_STR(sd_ctx_params->circular_x),
                  BOOL_STR(sd_ctx_params->circular_y),
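
Taken together, the hunks above define the precedence between the two flags: --fa enables flash attention for the conditioner, CLIP vision and the VAE/TAE as well as the diffusion model(s), while --diffusion-fa keeps it limited to the diffusion model. A minimal standalone sketch of that decision logic follows; FaFlags and the variable names are hypothetical, illustrative only, and not code from this commit:

#include <cstdio>

// Illustrative flag container mirroring the two booleans handled above.
struct FaFlags {
    bool flash_attn           = false;  // set by --fa
    bool diffusion_flash_attn = false;  // set by --diffusion-fa
};

int main() {
    FaFlags f;
    f.flash_attn = true;  // as if the user passed --fa

    // Same precedence as the new enablement block above.
    bool fa_conditioner_vae = f.flash_attn;
    bool fa_diffusion       = f.flash_attn || f.diffusion_flash_attn;

    std::printf("conditioner/VAE flash attention: %s\n", fa_conditioner_vae ? "true" : "false");
    std::printf("diffusion model flash attention: %s\n", fa_diffusion ? "true" : "false");
    return 0;
}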

View File

@@ -189,6 +189,7 @@ typedef struct {
     bool keep_clip_on_cpu;
     bool keep_control_net_on_cpu;
     bool keep_vae_on_cpu;
+    bool flash_attn;
     bool diffusion_flash_attn;
     bool tae_preview_only;
     bool diffusion_conv_direct;
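
For library users, the new flash_attn field sits next to the existing diffusion_flash_attn in the public context parameters. A hedged usage sketch follows; the field names come from the struct above, while the header name and the helper function are assumptions for illustration, not part of this commit:

#include "stable-diffusion.h"  // assumed public header that declares sd_ctx_params_t

// Hypothetical helper: enable flash attention everywhere, matching the CLI's --fa.
static void enable_flash_attention(sd_ctx_params_t* params) {
    params->flash_attn           = true;   // new field; covers conditioner, CLIP vision, VAE/TAE and the diffusion model
    params->diffusion_flash_attn = false;  // set this one instead to restrict flash attention to the diffusion model
}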

View File

@@ -141,7 +141,7 @@ public:
             v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n);  // [N, h * w, in_channels]
         }
-        h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);
+        h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);
         if (use_linear) {
             h_ = proj_out->forward(ctx, h_);  // [N, h * w, in_channels]

View File

@@ -572,8 +572,8 @@ namespace WAN {
             auto v = qkv_vec[2];
             v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n);  // [t, c, h * w]
             v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3));  // [t, h * w, c]
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);  // [t, h * w, c]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);  // [t, h * w, c]
             x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));  // [t, c, h * w]
             x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n);  // [t, c, h, w]