diff --git a/conditioner.hpp b/conditioner.hpp
index a4e84aa..b187695 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -34,6 +34,7 @@ struct Conditioner {
     virtual void free_params_buffer() = 0;
     virtual void get_param_tensors(std::map& tensors) = 0;
     virtual size_t get_params_buffer_size() = 0;
+    virtual void set_flash_attention_enabled(bool enabled) = 0;
     virtual void set_weight_adapter(const std::shared_ptr& adapter) {}
     virtual std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx,
                                                            int n_threads,
@@ -115,6 +116,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        text_model->set_flash_attention_enabled(enabled);
+        if (sd_version_is_sdxl(version)) {
+            text_model2->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr& adapter) override {
         text_model->set_weight_adapter(adapter);
         if (sd_version_is_sdxl(version)) {
@@ -783,6 +791,18 @@ struct SD3CLIPEmbedder : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        if (clip_l) {
+            clip_l->set_flash_attention_enabled(enabled);
+        }
+        if (clip_g) {
+            clip_g->set_flash_attention_enabled(enabled);
+        }
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr& adapter) override {
         if (clip_l) {
             clip_l->set_weight_adapter(adapter);
@@ -1191,6 +1211,15 @@ struct FluxCLIPEmbedder : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        if (clip_l) {
+            clip_l->set_flash_attention_enabled(enabled);
+        }
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr& adapter) {
         if (clip_l) {
             clip_l->set_weight_adapter(adapter);
@@ -1440,6 +1469,12 @@ struct T5CLIPEmbedder : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        if (t5) {
+            t5->set_flash_attention_enabled(enabled);
+        }
+    }
+
     void set_weight_adapter(const std::shared_ptr& adapter) override {
         if (t5) {
             t5->set_weight_adapter(adapter);
@@ -1650,6 +1685,10 @@ struct LLMEmbedder : public Conditioner {
         return buffer_size;
     }
 
+    void set_flash_attention_enabled(bool enabled) override {
+        llm->set_flash_attention_enabled(enabled);
+    }
+
     void set_weight_adapter(const std::shared_ptr& adapter) override {
         if (llm) {
             llm->set_weight_adapter(adapter);
diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index 06cbecc..3293ba9 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -38,7 +38,7 @@ struct DiffusionModel {
     virtual size_t get_params_buffer_size() = 0;
     virtual void set_weight_adapter(const std::shared_ptr& adapter){};
    virtual int64_t get_adm_in_channels() = 0;
-    virtual void set_flash_attn_enabled(bool enabled) = 0;
+    virtual void set_flash_attention_enabled(bool enabled) = 0;
     virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
 };
 
@@ -84,7 +84,7 @@ struct UNetModel : public DiffusionModel {
         return unet.unet.adm_in_channels;
     }
 
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         unet.set_flash_attention_enabled(enabled);
     }
 
@@ -149,7 +149,7 @@ struct MMDiTModel : public DiffusionModel {
         return 768 + 1280;
     }
 
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         mmdit.set_flash_attention_enabled(enabled);
     }
 
@@ -215,7 +215,7 @@ struct FluxModel : public DiffusionModel {
         return 768;
     }
 
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         flux.set_flash_attention_enabled(enabled);
     }
 
@@ -286,7 +286,7 @@ struct WanModel : public DiffusionModel {
         return 768;
     }
 
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         wan.set_flash_attention_enabled(enabled);
     }
 
@@ -357,7 +357,7 @@ struct QwenImageModel : public DiffusionModel {
         return 768;
     }
 
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         qwen_image.set_flash_attention_enabled(enabled);
     }
 
@@ -424,7 +424,7 @@ struct ZImageModel : public DiffusionModel {
         return 768;
     }
 
-    void set_flash_attn_enabled(bool enabled) {
+    void set_flash_attention_enabled(bool enabled) {
         z_image.set_flash_attention_enabled(enabled);
     }
 
diff --git a/examples/common/common.hpp b/examples/common/common.hpp
index 11d30a3..50f35ae 100644
--- a/examples/common/common.hpp
+++ b/examples/common/common.hpp
@@ -457,6 +457,7 @@ struct SDContextParams {
     bool control_net_cpu = false;
     bool clip_on_cpu = false;
     bool vae_on_cpu = false;
+    bool flash_attn = false;
     bool diffusion_flash_attn = false;
     bool diffusion_conv_direct = false;
     bool vae_conv_direct = false;
@@ -615,9 +616,13 @@ struct SDContextParams {
          "--vae-on-cpu",
          "keep vae in cpu (for low vram)",
          true, &vae_on_cpu},
+        {"",
+         "--fa",
+         "use flash attention",
+         true, &flash_attn},
         {"",
          "--diffusion-fa",
-         "use flash attention in the diffusion model",
+         "use flash attention in the diffusion model only",
          true, &diffusion_flash_attn},
         {"",
          "--diffusion-conv-direct",
@@ -904,6 +909,7 @@ struct SDContextParams {
            << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
            << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
            << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
+           << " flash_attn: " << (flash_attn ? "true" : "false") << ",\n"
            << " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
            << " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
            << " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
@@ -968,6 +974,7 @@ struct SDContextParams {
         clip_on_cpu,
         control_net_cpu,
         vae_on_cpu,
+        flash_attn,
         diffusion_flash_attn,
         taesd_preview,
         diffusion_conv_direct,
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 7dac037..193a2c3 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -2623,7 +2623,7 @@ public:
             v = v_proj->forward(ctx, x);
         }
 
-        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask);  // [N, n_token, embed_dim]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask, false);  // [N, n_token, embed_dim]
 
         x = out_proj->forward(ctx, x);  // [N, n_token, embed_dim]
         return x;
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 44412ed..f5c82b2 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -445,7 +445,7 @@ public:
             }
         }
         if (is_chroma) {
-            if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) {
+            if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
                 LOG_WARN(
                     "!!!It looks like you are using Chroma with flash attention. "
                     "This is currently unsupported. "
@@ -571,14 +571,6 @@ public:
             }
         }
 
-        if (sd_ctx_params->diffusion_flash_attn) {
-            LOG_INFO("Using flash attention in the diffusion model");
-            diffusion_model->set_flash_attn_enabled(true);
-            if (high_noise_diffusion_model) {
-                high_noise_diffusion_model->set_flash_attn_enabled(true);
-            }
-        }
-
         cond_stage_model->alloc_params_buffer();
         cond_stage_model->get_param_tensors(tensors);
 
@@ -725,6 +717,28 @@ public:
             pmid_model->get_param_tensors(tensors, "pmid");
         }
 
+        if (sd_ctx_params->flash_attn) {
+            LOG_INFO("Using flash attention");
+            cond_stage_model->set_flash_attention_enabled(true);
+            if (clip_vision) {
+                clip_vision->set_flash_attention_enabled(true);
+            }
+            if (first_stage_model) {
+                first_stage_model->set_flash_attention_enabled(true);
+            }
+            if (tae_first_stage) {
+                tae_first_stage->set_flash_attention_enabled(true);
+            }
+        }
+
+        if (sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) {
+            LOG_INFO("Using flash attention in the diffusion model");
+            diffusion_model->set_flash_attention_enabled(true);
+            if (high_noise_diffusion_model) {
+                high_noise_diffusion_model->set_flash_attention_enabled(true);
+            }
+        }
+
         diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
         if (high_noise_diffusion_model) {
             high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
@@ -2942,6 +2956,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
                          "keep_clip_on_cpu: %s\n"
                          "keep_control_net_on_cpu: %s\n"
                          "keep_vae_on_cpu: %s\n"
+                         "flash_attn: %s\n"
                          "diffusion_flash_attn: %s\n"
                          "circular_x: %s\n"
                          "circular_y: %s\n"
@@ -2973,6 +2988,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
                          BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
                          BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
                          BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
+                         BOOL_STR(sd_ctx_params->flash_attn),
                          BOOL_STR(sd_ctx_params->diffusion_flash_attn),
                          BOOL_STR(sd_ctx_params->circular_x),
                          BOOL_STR(sd_ctx_params->circular_y),
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 85768f4..cb966d7 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -189,6 +189,7 @@ typedef struct {
     bool keep_clip_on_cpu;
     bool keep_control_net_on_cpu;
     bool keep_vae_on_cpu;
+    bool flash_attn;
     bool diffusion_flash_attn;
     bool tae_preview_only;
     bool diffusion_conv_direct;
diff --git a/vae.hpp b/vae.hpp
index 01b99e8..0108134 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -141,7 +141,7 @@ public:
             v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n);  // [N, h * w, in_channels]
         }
 
-        h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);
+        h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);
 
         if (use_linear) {
             h_ = proj_out->forward(ctx, h_);  // [N, h * w, in_channels]
diff --git a/wan.hpp b/wan.hpp
index 81959ef..7b10597 100644
--- a/wan.hpp
+++ b/wan.hpp
@@ -572,8 +572,8 @@ namespace WAN {
             auto v = qkv_vec[2];
 
             v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n);  // [t, c, h * w]
-            v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3));  // [t, h * w, c]
-            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);  // [t, h * w, c]
+            v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3));                          // [t, h * w, c]
+            x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);  // [t, h * w, c]
 
             x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3));  // [t, c, h * w]
             x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n);                             // [t, c, h, w]
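
Usage sketch (not part of the patch): a minimal example of driving the new flash_attn context parameter through the public C API declared in stable-diffusion.h. The sd_ctx_params_init(), new_sd_ctx(), free_sd_ctx() calls and the model_path field are assumed from the existing header surface rather than introduced here, and the model path is illustrative.

    #include "stable-diffusion.h"

    int main() {
        sd_ctx_params_t params;
        sd_ctx_params_init(&params);              // assumed default-init helper from stable-diffusion.h

        params.model_path = "model.safetensors";  // illustrative model path
        params.flash_attn = true;                 // new field: flash attention in conditioners, CLIP vision, VAE/TAE and the diffusion model
        // params.diffusion_flash_attn = true;    // existing field: diffusion model only (CLI: --diffusion-fa)

        sd_ctx_t* ctx = new_sd_ctx(&params);      // assumed constructor from the same header
        if (ctx == nullptr) {
            return 1;
        }
        // ... run generation as usual ...
        free_sd_ctx(ctx);
        return 0;
    }

On the CLI side (examples/common/common.hpp) the same split is exposed as --fa for the global switch versus --diffusion-fa for the diffusion model only.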