enable flash attn by default

This commit is contained in:
leejet 2026-01-11 17:35:01 +08:00
parent 885e62ea82
commit 8283e1bade
18 changed files with 92 additions and 58 deletions

View File

@ -144,6 +144,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
- [Docker](./docs/docker.md)
- [Quantization and GGUF](./docs/quantization_and_gguf.md)
- [Inference acceleration via caching](./docs/caching.md)
- [Troubleshooting](./docs/troubleshooting.md)
## Bindings

View File

@ -34,6 +34,7 @@ struct Conditioner {
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
@ -115,6 +116,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
text_model->set_flash_attention_enabled(enabled);
if (sd_version_is_sdxl(version)) {
text_model2->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
text_model->set_weight_adapter(adapter);
if (sd_version_is_sdxl(version)) {
@ -783,6 +791,18 @@ struct SD3CLIPEmbedder : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
}
if (clip_g) {
clip_g->set_flash_attention_enabled(enabled);
}
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (clip_l) {
clip_l->set_weight_adapter(adapter);
@ -1191,6 +1211,15 @@ struct FluxCLIPEmbedder : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
}
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
if (clip_l) {
clip_l->set_weight_adapter(adapter);
@ -1440,6 +1469,12 @@ struct T5CLIPEmbedder : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (t5) {
t5->set_weight_adapter(adapter);
@ -1650,6 +1685,10 @@ struct LLMEmbedder : public Conditioner {
return buffer_size;
}
void set_flash_attention_enabled(bool enabled) override {
llm->set_flash_attention_enabled(enabled);
}
void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (llm) {
llm->set_weight_adapter(adapter);

View File

@ -38,7 +38,7 @@ struct DiffusionModel {
virtual size_t get_params_buffer_size() = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
virtual int64_t get_adm_in_channels() = 0;
virtual void set_flash_attn_enabled(bool enabled) = 0;
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
};
@ -84,7 +84,7 @@ struct UNetModel : public DiffusionModel {
return unet.unet.adm_in_channels;
}
void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
unet.set_flash_attention_enabled(enabled);
}
@ -149,7 +149,7 @@ struct MMDiTModel : public DiffusionModel {
return 768 + 1280;
}
void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
mmdit.set_flash_attention_enabled(enabled);
}
@ -215,7 +215,7 @@ struct FluxModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
flux.set_flash_attention_enabled(enabled);
}
@ -286,7 +286,7 @@ struct WanModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
wan.set_flash_attention_enabled(enabled);
}
@ -357,7 +357,7 @@ struct QwenImageModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
qwen_image.set_flash_attention_enabled(enabled);
}
@ -424,7 +424,7 @@ struct ZImageModel : public DiffusionModel {
return 768;
}
void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
z_image.set_flash_attention_enabled(enabled);
}

View File

@ -12,7 +12,7 @@
## Examples
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --offload-to-cpu
```
<img alt="flux2 example" src="../assets/flux2/example.png" />

View File

@ -13,7 +13,7 @@
## Examples
```
.\bin\Release\sd-cli.exe --diffusion-model ovis_image-Q4_0.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\ovis_2.5.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu --diffusion-fa
.\bin\Release\sd-cli.exe --diffusion-model ovis_image-Q4_0.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\ovis_2.5.safetensors -p "a lovely cat" --cfg-scale 5.0 -v --offload-to-cpu
```
<img alt="ovis image example" src="../assets/ovis_image/example.png" />

View File

@ -1,22 +1,3 @@
## Use Flash Attention to save memory and improve speed.
Enabling flash attention for the diffusion model reduces memory usage by varying amounts of MB.
eg.:
- flux 768x768 ~600mb
- SD2 768x768 ~1400mb
For most backends, it slows things down, but for cuda it generally speeds it up too.
At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal).
Run by adding `--diffusion-fa` to the arguments and watch for:
```
[INFO ] stable-diffusion.cpp:312 - Using flash attention in the diffusion model
```
and the compute buffer shrink in the debug log:
```
[DEBUG] ggml_extend.hpp:1004 - flux compute buffer size: 650.00 MB(VRAM)
```
## Offload weights to the CPU to save VRAM without reducing generation speed.
Using `--offload-to-cpu` allows you to offload weights to the CPU, saving VRAM without reducing generation speed.

View File

@ -14,7 +14,7 @@
## Examples
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线 探索视觉生成基础模型的极限开创理解与生成一体化的未来。二、Qwen-Image的模型特色1、复杂文字渲染。支持中英渲染、自动布局 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --diffusion-fa --flow-shift 3
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线 探索视觉生成基础模型的极限开创理解与生成一体化的未来。二、Qwen-Image的模型特色1、复杂文字渲染。支持中英渲染、自动布局 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --flow-shift 3
```
<img alt="qwen example" src="../assets/qwen/example.png" />

View File

@ -23,7 +23,7 @@
### Qwen Image Edit
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen_Image_Edit-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --seed 1118877715456453
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen_Image_Edit-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --seed 1118877715456453
```
<img alt="qwen_image_edit" src="../assets/qwen/qwen_image_edit.png" />
@ -32,7 +32,7 @@
### Qwen Image Edit 2509
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen-Image-Edit-2509-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf --llm_vision ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct.mmproj-Q8_0.gguf --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'Qwen Image Edit 2509'"
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Qwen-Image-Edit-2509-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf --llm_vision ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct.mmproj-Q8_0.gguf --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'Qwen Image Edit 2509'"
```
<img alt="qwen_image_edit_2509" src="../assets/qwen/qwen_image_edit_2509.png" />
@ -42,7 +42,7 @@
To use the new Qwen Image Edit 2511 mode, the `--qwen-image-zero-cond-t` flag must be enabled; otherwise, image editing quality will degrade significantly.
```
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-edit-2511-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --qwen-image-zero-cond-t
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-edit-2511-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --qwen-image-zero-cond-t
```
<img alt="qwen_image_edit_2509" src="../assets/qwen/qwen_image_edit_2511.png" />

3
docs/troubleshooting.md Normal file
View File

@ -0,0 +1,3 @@
## Try `--disable-fa`
By default, **stable-diffusion.cpp** uses Flash Attention to improve generation speed and optimize GPU memory usage. However, on some backends, Flash Attention may cause unexpected issues, such as generating completely black images. In such cases, you can try disabling Flash Attention by using `--disable-fa`.

View File

@ -16,7 +16,7 @@ You can run Z-Image with stable-diffusion.cpp on GPUs with 4GB of VRAM — or ev
## Examples
```
.\bin\Release\sd-cli.exe --diffusion-model z_image_turbo-Q3_K.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu --diffusion-fa -H 1024 -W 512
.\bin\Release\sd-cli.exe --diffusion-model z_image_turbo-Q3_K.gguf --vae ..\..\ComfyUI\models\vae\ae.sft --llm ..\..\ComfyUI\models\text_encoders\Qwen3-4B-Instruct-2507-Q4_K_M.gguf -p "A cinematic, melancholic photograph of a solitary hooded figure walking through a sprawling, rain-slicked metropolis at night. The city lights are a chaotic blur of neon orange and cool blue, reflecting on the wet asphalt. The scene evokes a sense of being a single component in a vast machine. Superimposed over the image in a sleek, modern, slightly glitched font is the philosophical quote: 'THE CITY IS A CIRCUIT BOARD, AND I AM A BROKEN TRANSISTOR.' -- moody, atmospheric, profound, dark academic" --cfg-scale 1.0 -v --offload-to-cpu -H 1024 -W 512
```
<img width="256" alt="z-image example" src="../assets/z_image/q3_K.png" />

View File

@ -52,7 +52,7 @@ Context Options:
--control-net-cpu keep controlnet in cpu (for low vram)
--clip-on-cpu keep clip in cpu (for low vram)
--vae-on-cpu keep vae in cpu (for low vram)
--diffusion-fa use flash attention in the diffusion model
--disable-fa disable flash attention
--diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
--vae-conv-direct use ggml_conv2d_direct in the vae model
--circular enable circular padding for convolutions

View File

@ -457,7 +457,7 @@ struct SDContextParams {
bool control_net_cpu = false;
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool diffusion_flash_attn = false;
bool flash_attn = true;
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
@ -616,9 +616,9 @@ struct SDContextParams {
"keep vae in cpu (for low vram)",
true, &vae_on_cpu},
{"",
"--diffusion-fa",
"use flash attention in the diffusion model",
true, &diffusion_flash_attn},
"--disable-fa",
"disable flash attention",
false, &flash_attn},
{"",
"--diffusion-conv-direct",
"use ggml_conv2d_direct in the diffusion model",
@ -904,7 +904,7 @@ struct SDContextParams {
<< " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
<< " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
<< " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
<< " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
<< " flash_attn: " << (flash_attn ? "true" : "false") << ",\n"
<< " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
<< " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
<< " circular: " << (circular ? "true" : "false") << ",\n"
@ -968,7 +968,7 @@ struct SDContextParams {
clip_on_cpu,
control_net_cpu,
vae_on_cpu,
diffusion_flash_attn,
flash_attn,
taesd_preview,
diffusion_conv_direct,
vae_conv_direct,

View File

@ -44,7 +44,7 @@ Context Options:
--clip-on-cpu keep clip in cpu (for low vram)
--vae-on-cpu keep vae in cpu (for low vram)
--mmap whether to memory-map model
--diffusion-fa use flash attention in the diffusion model
--disable-fa disable flash attention
--diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
--vae-conv-direct use ggml_conv2d_direct in the vae model
--circular enable circular padding for convolutions

View File

@ -2594,7 +2594,7 @@ public:
v = v_proj->forward(ctx, x);
}
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask); // [N, n_token, embed_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, nullptr, mask, false, false); // [N, n_token, embed_dim]
x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
return x;

View File

@ -435,7 +435,7 @@ public:
}
}
if (is_chroma) {
if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) {
if (sd_ctx_params->flash_attn && sd_ctx_params->chroma_use_dit_mask) {
LOG_WARN(
"!!!It looks like you are using Chroma with flash attention. "
"This is currently unsupported. "
@ -561,14 +561,6 @@ public:
}
}
if (sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attn_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attn_enabled(true);
}
}
cond_stage_model->alloc_params_buffer();
cond_stage_model->get_param_tensors(tensors);
@ -712,6 +704,24 @@ public:
pmid_model->get_param_tensors(tensors, "pmid");
}
if (sd_ctx_params->flash_attn) {
LOG_INFO("Using flash attention");
diffusion_model->set_flash_attention_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attention_enabled(true);
}
cond_stage_model->set_flash_attention_enabled(true);
if (clip_vision) {
clip_vision->set_flash_attention_enabled(true);
}
if (first_stage_model) {
first_stage_model->set_flash_attention_enabled(true);
}
if (tae_first_stage) {
tae_first_stage->set_flash_attention_enabled(true);
}
}
diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
@ -2884,7 +2894,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
sd_ctx_params->keep_clip_on_cpu = false;
sd_ctx_params->keep_control_net_on_cpu = false;
sd_ctx_params->keep_vae_on_cpu = false;
sd_ctx_params->diffusion_flash_attn = false;
sd_ctx_params->flash_attn = false;
sd_ctx_params->circular_x = false;
sd_ctx_params->circular_y = false;
sd_ctx_params->chroma_use_dit_mask = true;
@ -2925,7 +2935,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"keep_clip_on_cpu: %s\n"
"keep_control_net_on_cpu: %s\n"
"keep_vae_on_cpu: %s\n"
"diffusion_flash_attn: %s\n"
"flash_attn: %s\n"
"circular_x: %s\n"
"circular_y: %s\n"
"chroma_use_dit_mask: %s\n"
@ -2956,7 +2966,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
BOOL_STR(sd_ctx_params->diffusion_flash_attn),
BOOL_STR(sd_ctx_params->flash_attn),
BOOL_STR(sd_ctx_params->circular_x),
BOOL_STR(sd_ctx_params->circular_y),
BOOL_STR(sd_ctx_params->chroma_use_dit_mask),

View File

@ -186,7 +186,7 @@ typedef struct {
bool keep_clip_on_cpu;
bool keep_control_net_on_cpu;
bool keep_vae_on_cpu;
bool diffusion_flash_attn;
bool flash_attn;
bool tae_preview_only;
bool diffusion_conv_direct;
bool vae_conv_direct;

View File

@ -141,7 +141,7 @@ public:
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
}
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false);
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, ctx->flash_attn_enabled);
if (use_linear) {
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]

View File

@ -572,8 +572,8 @@ namespace WAN {
auto v = qkv_vec[2];
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); // [t, h * w, c]
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, ctx->flash_attn_enabled); // [t, h * w, c]
x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]