feat: add support for qwen image edit 2511 (#1096)
Parent: 3d5fdd7b37 · Commit: a0adcfb148
@@ -52,7 +52,7 @@ API and command-line option may change frequently.***
     - [Ovis-Image](./docs/ovis_image.md)
 - Image Edit Models
     - [FLUX.1-Kontext-dev](./docs/kontext.md)
-    - [Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
+    - [Qwen Image Edit series](./docs/qwen_image_edit.md)
 - Video Models
     - [Wan2.1/Wan2.2](./docs/wan.md)
 - [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support.
@@ -132,7 +132,7 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
 - [FLUX.1-Kontext-dev](./docs/kontext.md)
 - [Chroma](./docs/chroma.md)
 - [🔥Qwen Image](./docs/qwen_image.md)
-- [🔥Qwen Image Edit/Qwen Image Edit 2509](./docs/qwen_image_edit.md)
+- [🔥Qwen Image Edit series](./docs/qwen_image_edit.md)
 - [🔥Wan2.1/Wan2.2](./docs/wan.md)
 - [🔥Z-Image](./docs/z_image.md)
 - [Ovis-Image](./docs/ovis_image.md)
BIN  assets/qwen/qwen_image_edit_2511.png  (new file; binary not shown; 450 KiB)
@@ -320,8 +320,9 @@ struct QwenImageModel : public DiffusionModel {
                     bool offload_params_to_cpu,
                     const String2TensorStorage& tensor_storage_map = {},
                     const std::string prefix = "model.diffusion_model",
-                    SDVersion version = VERSION_QWEN_IMAGE)
-        : prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_storage_map, prefix, version) {
+                    SDVersion version = VERSION_QWEN_IMAGE,
+                    bool zero_cond_t = false)
+        : prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_storage_map, prefix, version, zero_cond_t) {
     }

     std::string get_desc() override {
@@ -9,6 +9,9 @@
 - Qwen Image Edit 2509
     - safetensors: https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/tree/main/split_files/diffusion_models
     - gguf: https://huggingface.co/QuantStack/Qwen-Image-Edit-2509-GGUF/tree/main
+- Qwen Image Edit 2511
+    - safetensors: https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/tree/main/split_files/diffusion_models
+    - gguf: https://huggingface.co/unsloth/Qwen-Image-Edit-2511-GGUF/tree/main
 - Download vae
     - safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/vae
 - Download qwen_2.5_vl 7b
@@ -33,3 +36,13 @@
 ```
 
 <img alt="qwen_image_edit_2509" src="../assets/qwen/qwen_image_edit_2509.png" />
+
+### Qwen Image Edit 2511
+
+To use the new Qwen Image Edit 2511 mode, the `--qwen-image-zero-cond-t` flag must be enabled; otherwise, image editing quality will degrade significantly.
+
+```
+.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-edit-2511-Q4_K_M.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_2.5_vl_7b.safetensors --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu --diffusion-fa --flow-shift 3 -r ..\assets\flux\flux1-dev-q8_0.png -p "change 'flux.cpp' to 'edit.cpp'" --qwen-image-zero-cond-t
+```
+
+<img alt="qwen_image_edit_2511" src="../assets/qwen/qwen_image_edit_2511.png" />
@@ -457,6 +457,8 @@ struct SDContextParams {
     bool chroma_use_t5_mask = false;
     int chroma_t5_mask_pad = 1;

+    bool qwen_image_zero_cond_t = false;
+
     prediction_t prediction = PREDICTION_COUNT;
     lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;

@@ -625,6 +627,10 @@ struct SDContextParams {
          "--chroma-disable-dit-mask",
          "disable dit mask for chroma",
          false, &chroma_use_dit_mask},
+        {"",
+         "--qwen-image-zero-cond-t",
+         "enable zero_cond_t for qwen image",
+         true, &qwen_image_zero_cond_t},
         {"",
          "--chroma-enable-t5-mask",
          "enable t5 mask for chroma",
@@ -888,6 +894,7 @@ struct SDContextParams {
            << " circular_x: " << (circular_x ? "true" : "false") << ",\n"
            << " circular_y: " << (circular_y ? "true" : "false") << ",\n"
            << " chroma_use_dit_mask: " << (chroma_use_dit_mask ? "true" : "false") << ",\n"
+           << " qwen_image_zero_cond_t: " << (qwen_image_zero_cond_t ? "true" : "false") << ",\n"
            << " chroma_use_t5_mask: " << (chroma_use_t5_mask ? "true" : "false") << ",\n"
            << " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n"
            << " prediction: " << sd_prediction_name(prediction) << ",\n"
@@ -953,6 +960,7 @@ struct SDContextParams {
             chroma_use_dit_mask,
             chroma_use_t5_mask,
             chroma_t5_mask_pad,
+            qwen_image_zero_cond_t,
             flow_shift,
         };
         return sd_ctx_params;

flux.hpp  (13 lines changed)
@@ -233,14 +233,17 @@ namespace Flux {
     __STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx,
                                                    struct ggml_tensor* x,
                                                    struct ggml_tensor* shift,
-                                                   struct ggml_tensor* scale) {
+                                                   struct ggml_tensor* scale,
+                                                   bool skip_reshape = false) {
         // x: [N, L, C]
         // scale: [N, C]
         // shift: [N, C]
-        scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]);  // [N, 1, C]
-        shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]);  // [N, 1, C]
-        x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
-        x = ggml_add(ctx, x, shift);
+        if (!skip_reshape) {
+            scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]);  // [N, 1, C]
+            shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]);  // [N, 1, C]
+        }
+        x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
+        x = ggml_add(ctx, x, shift);
         return x;
     }

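`Flux::modulate()` applies adaptive-layernorm-style modulation, `x * (1 + scale) + shift`. The new `skip_reshape` flag tells it that `scale`/`shift` already arrive per token as `[N, L, C]` (as produced by `get_mod_params_vec()` in qwen_image.hpp below when a modulate index is given), so the `[N, C] -> [N, 1, C]` broadcast reshape must be skipped. A minimal sketch of the two indexing modes in plain C++, without ggml; sizes and values are illustrative:

```
#include <cstdio>
#include <vector>

int main() {
    const int L = 2, C = 3;                          // tokens, channels (N = 1)
    std::vector<float> x(L * C, 1.0f);               // x: [L, C]
    std::vector<float> scale = {0.5f, 0.0f, -0.5f};  // [C], broadcast over L
    std::vector<float> shift = {0.1f, 0.2f, 0.3f};   // [C]

    // Default path: per-batch params, broadcast across all tokens
    // (ggml gets the broadcast by reshaping [N, C] to [N, 1, C]).
    for (int l = 0; l < L; l++)
        for (int c = 0; c < C; c++)
            x[l * C + c] = x[l * C + c] * (1.0f + scale[c]) + shift[c];

    // skip_reshape = true path: scale/shift already come in per token,
    // i.e. [L, C], and would be indexed as scale[l * C + c] directly.
    for (float v : x) printf("%.2f ", v);  // 1.60 1.20 0.80 1.60 1.20 0.80
    printf("\n");
    return 0;
}
```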
qwen_image.hpp  (111 lines changed)
@@ -191,11 +191,16 @@ namespace Qwen {
     };

     class QwenImageTransformerBlock : public GGMLBlock {
+    protected:
+        bool zero_cond_t;
+
     public:
         QwenImageTransformerBlock(int64_t dim,
                                   int64_t num_attention_heads,
                                   int64_t attention_head_dim,
-                                  float eps = 1e-6) {
+                                  float eps = 1e-6,
+                                  bool zero_cond_t = false)
+            : zero_cond_t(zero_cond_t) {
             // img_mod.0 is nn.SiLU()
             blocks["img_mod.1"] = std::shared_ptr<GGMLBlock>(new Linear(dim, 6 * dim, true));

@@ -220,11 +225,37 @@
                                                    eps));
         }

+        std::vector<ggml_tensor*> get_mod_params_vec(ggml_context* ctx, ggml_tensor* mod_params, ggml_tensor* index = nullptr) {
+            // index: [N, n_img_token]
+            // mod_params: [N, hidden_size * 12]
+            if (index == nullptr) {
+                return ggml_ext_chunk(ctx, mod_params, 6, 0);
+            }
+            mod_params = ggml_reshape_1d(ctx, mod_params, ggml_nelements(mod_params));
+            auto mod_params_vec = ggml_ext_chunk(ctx, mod_params, 12, 0);
+            index = ggml_reshape_3d(ctx, index, 1, index->ne[0], index->ne[1]);  // [N, n_img_token, 1]
+            index = ggml_repeat_4d(ctx, index, mod_params_vec[0]->ne[0], index->ne[1], index->ne[2], index->ne[3]);  // [N, n_img_token, hidden_size]
+            std::vector<ggml_tensor*> mod_results;
+            for (int i = 0; i < 6; i++) {
+                auto mod_0 = mod_params_vec[i];
+                auto mod_1 = mod_params_vec[i + 6];
+
+                // mod_result = torch.where(index == 0, mod_0, mod_1)
+                // mod_result = (1 - index)*mod_0 + index*mod_1
+                mod_0 = ggml_sub(ctx, ggml_repeat(ctx, mod_0, index), ggml_mul(ctx, index, mod_0));  // [N, n_img_token, hidden_size]
+                mod_1 = ggml_mul(ctx, index, mod_1);  // [N, n_img_token, hidden_size]
+                auto mod_result = ggml_add(ctx, mod_0, mod_1);
+                mod_results.push_back(mod_result);
+            }
+            return mod_results;
+        }
+
         virtual std::pair<ggml_tensor*, ggml_tensor*> forward(GGMLRunnerContext* ctx,
                                                               struct ggml_tensor* img,
                                                               struct ggml_tensor* txt,
                                                               struct ggml_tensor* t_emb,
-                                                              struct ggml_tensor* pe) {
+                                                              struct ggml_tensor* pe,
+                                                              struct ggml_tensor* modulate_index = nullptr) {
             // img: [N, n_img_token, hidden_size]
             // txt: [N, n_txt_token, hidden_size]
             // pe: [n_img_token + n_txt_token, d_head/2, 2, 2]
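The `(1 - index)*mod_0 + index*mod_1` line above is a branch-free stand-in for `torch.where(index == 0, mod_0, mod_1)`, since ggml has no select op: with `index` taking values in {0, 1}, it returns `mod_0` where the index is 0 and `mod_1` where it is 1. A self-contained sketch with made-up values:

```
#include <cstdio>

int main() {
    // mod_0: modulation derived from the real timestep (first 6 chunks);
    // mod_1: modulation derived from t = 0 (last 6 chunks).
    float mod_0 = 0.8f, mod_1 = -0.2f;
    float index[4] = {0.f, 0.f, 1.f, 1.f};  // 2 main tokens, 2 reference tokens
    for (int i = 0; i < 4; i++) {
        float mod = (1.0f - index[i]) * mod_0 + index[i] * mod_1;
        printf("token %d -> %.2f\n", i, mod);  // 0.80, 0.80, -0.20, -0.20
    }
    return 0;
}
```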
@@ -244,14 +275,18 @@

            auto img_mod_params = ggml_silu(ctx->ggml_ctx, t_emb);
            img_mod_params = img_mod_1->forward(ctx, img_mod_params);
-           auto img_mod_param_vec = ggml_ext_chunk(ctx->ggml_ctx, img_mod_params, 6, 0);
+           auto img_mod_param_vec = get_mod_params_vec(ctx->ggml_ctx, img_mod_params, modulate_index);

+           if (zero_cond_t) {
+               t_emb = ggml_ext_chunk(ctx->ggml_ctx, t_emb, 2, 1)[0];
+           }
+
            auto txt_mod_params = ggml_silu(ctx->ggml_ctx, t_emb);
            txt_mod_params = txt_mod_1->forward(ctx, txt_mod_params);
-           auto txt_mod_param_vec = ggml_ext_chunk(ctx->ggml_ctx, txt_mod_params, 6, 0);
+           auto txt_mod_param_vec = get_mod_params_vec(ctx->ggml_ctx, txt_mod_params);

            auto img_normed = img_norm1->forward(ctx, img);
-           auto img_modulated = Flux::modulate(ctx->ggml_ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1]);
+           auto img_modulated = Flux::modulate(ctx->ggml_ctx, img_normed, img_mod_param_vec[0], img_mod_param_vec[1], modulate_index != nullptr);
            auto img_gate1 = img_mod_param_vec[2];

            auto txt_normed = txt_norm1->forward(ctx, txt);
@@ -264,7 +299,7 @@
            txt = ggml_add(ctx->ggml_ctx, txt, ggml_mul(ctx->ggml_ctx, txt_attn_output, txt_gate1));

            auto img_normed2 = img_norm2->forward(ctx, img);
-           auto img_modulated2 = Flux::modulate(ctx->ggml_ctx, img_normed2, img_mod_param_vec[3], img_mod_param_vec[4]);
+           auto img_modulated2 = Flux::modulate(ctx->ggml_ctx, img_normed2, img_mod_param_vec[3], img_mod_param_vec[4], modulate_index != nullptr);
            auto img_gate2 = img_mod_param_vec[5];

            auto txt_normed2 = txt_norm2->forward(ctx, txt);
@@ -325,6 +360,7 @@
        float theta = 10000;
        std::vector<int> axes_dim = {16, 56, 56};
        int64_t axes_dim_sum = 128;
+       bool zero_cond_t = false;
    };

    class QwenImageModel : public GGMLBlock {
@@ -346,7 +382,8 @@
                auto block = std::shared_ptr<GGMLBlock>(new QwenImageTransformerBlock(inner_dim,
                                                                                      params.num_attention_heads,
                                                                                      params.attention_head_dim,
-                                                                                     1e-6f));
+                                                                                     1e-6f,
+                                                                                     params.zero_cond_t));
                blocks["transformer_blocks." + std::to_string(i)] = block;
            }

@@ -421,7 +458,8 @@
                                         struct ggml_tensor* x,
                                         struct ggml_tensor* timestep,
                                         struct ggml_tensor* context,
-                                        struct ggml_tensor* pe) {
+                                        struct ggml_tensor* pe,
+                                        struct ggml_tensor* modulate_index = nullptr) {
            auto time_text_embed = std::dynamic_pointer_cast<QwenTimestepProjEmbeddings>(blocks["time_text_embed"]);
            auto txt_norm = std::dynamic_pointer_cast<RMSNorm>(blocks["txt_norm"]);
            auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
@@ -430,18 +468,26 @@
            auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]);

            auto t_emb = time_text_embed->forward(ctx, timestep);
-           auto img = img_in->forward(ctx, x);
-           auto txt = txt_norm->forward(ctx, context);
-           txt = txt_in->forward(ctx, txt);
+           if (params.zero_cond_t) {
+               auto t_emb_0 = time_text_embed->forward(ctx, ggml_ext_zeros(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]));
+               t_emb = ggml_concat(ctx->ggml_ctx, t_emb, t_emb_0, 1);
+           }
+           auto img = img_in->forward(ctx, x);
+           auto txt = txt_norm->forward(ctx, context);
+           txt = txt_in->forward(ctx, txt);

            for (int i = 0; i < params.num_layers; i++) {
                auto block = std::dynamic_pointer_cast<QwenImageTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);

-               auto result = block->forward(ctx, img, txt, t_emb, pe);
+               auto result = block->forward(ctx, img, txt, t_emb, pe, modulate_index);
                img = result.first;
                txt = result.second;
            }

+           if (params.zero_cond_t) {
+               t_emb = ggml_ext_chunk(ctx->ggml_ctx, t_emb, 2, 1)[0];
+           }
+
            img = norm_out->forward(ctx, img, t_emb);
            img = proj_out->forward(ctx, img);

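The zero_cond_t path in `forward_orig()` embeds the timestep twice, once at the real `t` and once at `t = 0`, and concatenates the two so both halves reach every block (where `get_mod_params_vec()` selects per token); the real-t half is then split back out before `norm_out`. A minimal plain-C++ sketch of that concat/chunk round trip, with placeholder values and no ggml:

```
#include <cassert>
#include <vector>

int main() {
    std::vector<float> t_emb   = {1.f, 2.f};  // stands in for time_text_embed(timestep)
    std::vector<float> t_emb_0 = {9.f, 9.f};  // stands in for time_text_embed(0)

    // ggml_concat(..., 1): both halves travel through the transformer blocks.
    std::vector<float> both = t_emb;
    both.insert(both.end(), t_emb_0.begin(), t_emb_0.end());

    // ggml_ext_chunk(..., 2, 1)[0]: recover the real-t half for norm_out.
    std::vector<float> first_half(both.begin(), both.begin() + both.size() / 2);
    assert(first_half == t_emb);
    return 0;
}
```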
@@ -453,7 +499,8 @@
                                     struct ggml_tensor* timestep,
                                     struct ggml_tensor* context,
                                     struct ggml_tensor* pe,
-                                    std::vector<ggml_tensor*> ref_latents = {}) {
+                                    std::vector<ggml_tensor*> ref_latents = {},
+                                    struct ggml_tensor* modulate_index = nullptr) {
            // Forward pass of DiT.
            // x: [N, C, H, W]
            // timestep: [N,]
@@ -479,7 +526,7 @@
            int64_t h_len = ((H + (params.patch_size / 2)) / params.patch_size);
            int64_t w_len = ((W + (params.patch_size / 2)) / params.patch_size);

-           auto out = forward_orig(ctx, img, timestep, context, pe);  // [N, h_len*w_len, ph*pw*C]
+           auto out = forward_orig(ctx, img, timestep, context, pe, modulate_index);  // [N, h_len*w_len, ph*pw*C]

            if (out->ne[1] > img_tokens) {
                out = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, out, 0, 2, 1, 3));  // [num_tokens, N, C * patch_size * patch_size]
@@ -502,15 +549,19 @@
        QwenImageParams qwen_image_params;
        QwenImageModel qwen_image;
        std::vector<float> pe_vec;
+       std::vector<float> modulate_index_vec;
        SDVersion version;

        QwenImageRunner(ggml_backend_t backend,
                        bool offload_params_to_cpu,
                        const String2TensorStorage& tensor_storage_map = {},
                        const std::string prefix = "",
-                       SDVersion version = VERSION_QWEN_IMAGE)
+                       SDVersion version = VERSION_QWEN_IMAGE,
+                       bool zero_cond_t = false)
            : GGMLRunner(backend, offload_params_to_cpu) {
-           qwen_image_params.num_layers = 0;
+           qwen_image_params.num_layers = 0;
+           qwen_image_params.zero_cond_t = zero_cond_t;
+           LOG_DEBUG("zero_cond_t: %d", zero_cond_t);
            for (auto pair : tensor_storage_map) {
                std::string tensor_name = pair.first;
                if (tensor_name.find(prefix) == std::string::npos)
@@ -576,6 +627,31 @@
            // pe->data = nullptr;
            set_backend_tensor_data(pe, pe_vec.data());

+           ggml_tensor* modulate_index = nullptr;
+           if (qwen_image_params.zero_cond_t) {
+               modulate_index_vec.clear();
+
+               int64_t h_len = ((x->ne[1] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size);
+               int64_t w_len = ((x->ne[0] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size);
+               int64_t num_img_tokens = h_len * w_len;
+
+               modulate_index_vec.insert(modulate_index_vec.end(), num_img_tokens, 0.f);
+               int64_t num_ref_img_tokens = 0;
+               for (ggml_tensor* ref : ref_latents) {
+                   int64_t h_len = ((ref->ne[1] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size);
+                   int64_t w_len = ((ref->ne[0] + (qwen_image_params.patch_size / 2)) / qwen_image_params.patch_size);
+
+                   num_ref_img_tokens += h_len * w_len;
+               }
+
+               if (num_ref_img_tokens > 0) {
+                   modulate_index_vec.insert(modulate_index_vec.end(), num_ref_img_tokens, 1.f);
+               }
+
+               modulate_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, modulate_index_vec.size());
+               set_backend_tensor_data(modulate_index, modulate_index_vec.data());
+           }
+
            auto runner_ctx = get_context();

            struct ggml_tensor* out = qwen_image.forward(&runner_ctx,
@@ -583,7 +659,8 @@
                                                         timesteps,
                                                         context,
                                                         pe,
-                                                        ref_latents);
+                                                        ref_latents,
+                                                        modulate_index);

            ggml_build_forward_expand(gf, out);

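The runner lays out `modulate_index` as 0 for every token of the latent being denoised and 1 for every token of the appended reference latents, so reference images get the `t = 0` modulation inside each block. A sketch of that layout using the same rounding as the runner; plain C++, with made-up sizes:

```
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

static int64_t n_tokens(int64_t h, int64_t w, int64_t patch) {
    int64_t h_len = (h + patch / 2) / patch;  // same rounding as the runner
    int64_t w_len = (w + patch / 2) / patch;
    return h_len * w_len;
}

int main() {
    const int64_t patch = 2;
    std::vector<float> index(n_tokens(8, 8, patch), 0.f);      // main latent -> 0
    std::vector<std::pair<int64_t, int64_t>> refs = {{8, 8}};  // one reference latent
    for (const auto& r : refs)
        index.insert(index.end(), n_tokens(r.first, r.second, patch), 1.f);
    printf("%zu tokens, last = %.0f\n", index.size(), index.back());  // 32 tokens, last = 1
    return 0;
}
```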
@@ -520,7 +520,8 @@ public:
                                                                 offload_params_to_cpu,
                                                                 tensor_storage_map,
                                                                 "model.diffusion_model",
-                                                                version);
+                                                                version,
+                                                                sd_ctx_params->qwen_image_zero_cond_t);
            } else if (sd_version_is_z_image(version)) {
                cond_stage_model = std::make_shared<LLMEmbedder>(clip_backend,
                                                                 offload_params_to_cpu,
@@ -195,6 +195,7 @@ typedef struct {
    bool chroma_use_dit_mask;
    bool chroma_use_t5_mask;
    int chroma_t5_mask_pad;
+   bool qwen_image_zero_cond_t;
    float flow_shift;
} sd_ctx_params_t;

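With the new `qwen_image_zero_cond_t` field in `sd_ctx_params_t`, library users can opt in from the C API as well as from the CLI. A hedged sketch, assuming the usual `sd_ctx_params_init()` and `new_sd_ctx()` entry points declared in stable-diffusion.h (check the header of your checkout); only the field itself comes from this diff:

```
#include "stable-diffusion.h"

int main() {
    sd_ctx_params_t params;
    sd_ctx_params_init(&params);           // fill defaults (assumed helper)
    params.qwen_image_zero_cond_t = true;  // new field from this commit
    // ... set model, vae and text-encoder paths here, then:
    // sd_ctx_t* ctx = new_sd_ctx(&params);
    return 0;
}
```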