add qwen image t2i pipeline

This commit is contained in:
leejet 2025-09-22 21:18:20 +08:00
parent d232509b6e
commit cf19c6e759
10 changed files with 268 additions and 84 deletions

View File

@ -398,7 +398,6 @@ public:
}
for (auto& token : matches) {
std::string token_str = token.str();
LOG_DEBUG("%s", token_str.c_str());
std::u32string utf32_token;
for (int i = 0; i < token_str.length(); i++) {
unsigned char b = token_str[i];

View File

@ -2,6 +2,7 @@
#define __CONDITIONER_HPP__
#include "clip.hpp"
#include "qwenvl.hpp"
#include "t5.hpp"
struct SDCondition {
@ -22,11 +23,11 @@ struct Conditioner {
int width,
int height,
int adm_in_channels = -1,
bool zero_out_masked = false) = 0;
virtual void alloc_params_buffer() = 0;
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
bool zero_out_masked = false) = 0;
virtual void alloc_params_buffer() = 0;
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
const std::string& text,
@ -35,9 +36,13 @@ struct Conditioner {
int height,
int num_input_imgs,
int adm_in_channels = -1,
bool zero_out_masked = false) = 0;
bool zero_out_masked = false) {
GGML_ABORT("Not implemented yet!");
}
virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
const std::string& prompt) = 0;
const std::string& prompt) {
GGML_ABORT("Not implemented yet!");
}
};
// ldm.modules.encoders.modules.FrozenCLIPEmbedder
@ -978,23 +983,6 @@ struct SD3CLIPEmbedder : public Conditioner {
auto tokens_and_weights = tokenize(text, 77, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
}
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int num_input_imgs,
int adm_in_channels = -1,
bool zero_out_masked = false) {
GGML_ASSERT(0 && "Not implemented yet!");
}
std::string remove_trigger_from_prompt(ggml_context* work_ctx,
const std::string& prompt) {
GGML_ASSERT(0 && "Not implemented yet!");
}
};
struct FluxCLIPEmbedder : public Conditioner {
@ -1195,23 +1183,6 @@ struct FluxCLIPEmbedder : public Conditioner {
auto tokens_and_weights = tokenize(text, chunk_len, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
}
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int num_input_imgs,
int adm_in_channels = -1,
bool zero_out_masked = false) {
GGML_ASSERT(0 && "Not implemented yet!");
}
std::string remove_trigger_from_prompt(ggml_context* work_ctx,
const std::string& prompt) {
GGML_ASSERT(0 && "Not implemented yet!");
}
};
struct T5CLIPEmbedder : public Conditioner {
@ -1398,22 +1369,135 @@ struct T5CLIPEmbedder : public Conditioner {
auto tokens_and_weights = tokenize(text, chunk_len, true);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
}
};
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int num_input_imgs,
int adm_in_channels = -1,
bool zero_out_masked = false) {
GGML_ASSERT(0 && "Not implemented yet!");
struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
Qwen::Qwen2Tokenizer tokenizer;
std::shared_ptr<Qwen::Qwen2_5_VLRunner> qwenvl;
int prompt_template_encode_start_idx = 34;
Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "") {
qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.qwen2vl");
}
std::string remove_trigger_from_prompt(ggml_context* work_ctx,
const std::string& prompt) {
GGML_ASSERT(0 && "Not implemented yet!");
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
qwenvl->get_param_tensors(tensors, "text_encoders.qwen2vl");
}
void alloc_params_buffer() {
qwenvl->alloc_params_buffer();
}
void free_params_buffer() {
qwenvl->free_params_buffer();
}
size_t get_params_buffer_size() {
size_t buffer_size = 0;
buffer_size += qwenvl->get_params_buffer_size();
return buffer_size;
}
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
size_t max_length = 0,
bool padding = false) {
auto parsed_attention = parse_prompt_attention(text);
{
std::stringstream ss;
ss << "[";
for (const auto& item : parsed_attention) {
ss << "['" << item.first << "', " << item.second << "], ";
}
ss << "]";
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
}
std::vector<int> tokens;
std::vector<float> weights;
for (const auto& item : parsed_attention) {
const std::string& curr_text = item.first;
float curr_weight = item.second;
std::vector<int> curr_tokens = tokenizer.tokenize(curr_text, nullptr);
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
}
tokenizer.pad_tokens(tokens, weights, max_length, padding);
// for (int i = 0; i < tokens.size(); i++) {
// std::cout << tokens[i] << ":" << weights[i] << ", ";
// }
// std::cout << std::endl;
return {tokens, weights};
}
SDCondition get_learned_condition_common(ggml_context* work_ctx,
int n_threads,
std::tuple<std::vector<int>, std::vector<float>> token_and_weights,
int clip_skip,
bool zero_out_masked = false) {
auto& tokens = std::get<0>(token_and_weights);
auto& weights = std::get<1>(token_and_weights);
int64_t t0 = ggml_time_ms();
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 3584]
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
qwenvl->compute(n_threads,
input_ids,
&hidden_states,
work_ctx);
{
auto tensor = hidden_states;
float original_mean = ggml_tensor_mean(tensor);
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
value *= weights[i1];
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
}
}
}
float new_mean = ggml_tensor_mean(tensor);
ggml_tensor_scale(tensor, (original_mean / new_mean));
}
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
GGML_TYPE_F32,
hidden_states->ne[0],
hidden_states->ne[1] - prompt_template_encode_start_idx,
hidden_states->ne[2]);
ggml_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
float value = ggml_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
ggml_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
});
int64_t t1 = ggml_time_ms();
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
return SDCondition(new_hidden_states, nullptr, nullptr);
}
SDCondition get_learned_condition(ggml_context* work_ctx,
int n_threads,
const std::string& text,
int clip_skip,
int width,
int height,
int adm_in_channels = -1,
bool zero_out_masked = false) {
std::string prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + text + "<|im_end|>\n<|im_start|>assistant\n";
auto tokens_and_weights = tokenize(prompt, 0, false);
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
}
};

View File

@ -3,6 +3,7 @@
#include "flux.hpp"
#include "mmdit.hpp"
#include "qwen_image.hpp"
#include "unet.hpp"
#include "wan.hpp"
@ -263,4 +264,58 @@ struct WanModel : public DiffusionModel {
}
};
struct QwenImageModel : public DiffusionModel {
std::string prefix;
Qwen::QwenImageRunner qwen_image;
QwenImageModel(ggml_backend_t backend,
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "model.diffusion_model",
SDVersion version = VERSION_QWEN_IMAGE,
bool flash_attn = false)
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
}
std::string get_desc() {
return qwen_image.get_desc();
}
void alloc_params_buffer() {
qwen_image.alloc_params_buffer();
}
void free_params_buffer() {
qwen_image.free_params_buffer();
}
void free_compute_buffer() {
qwen_image.free_compute_buffer();
}
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
qwen_image.get_param_tensors(tensors, prefix);
}
size_t get_params_buffer_size() {
return qwen_image.get_params_buffer_size();
}
int64_t get_adm_in_channels() {
return 768;
}
void compute(int n_threads,
DiffusionParams diffusion_params,
struct ggml_tensor** output = NULL,
struct ggml_context* output_ctx = NULL) {
return qwen_image.compute(n_threads,
diffusion_params.x,
diffusion_params.timesteps,
diffusion_params.context,
output,
output_ctx);
}
};
#endif

View File

@ -27,8 +27,6 @@
#include "avi_writer.h"
#include "qwen_image.hpp"
#if defined(_WIN32)
#define NOMINMAX
#include <windows.h>
@ -61,6 +59,7 @@ struct SDParams {
std::string clip_g_path;
std::string clip_vision_path;
std::string t5xxl_path;
std::string qwen2vl_path;
std::string diffusion_model_path;
std::string high_noise_diffusion_model_path;
std::string vae_path;
@ -146,6 +145,7 @@ void print_params(SDParams params) {
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
printf(" qwen2vl_path: %s\n", params.qwen2vl_path.c_str());
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
printf(" vae_path: %s\n", params.vae_path.c_str());
@ -217,6 +217,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --clip_g path to the clip-g text encoder\n");
printf(" --clip_vision path to the clip-vision encoder\n");
printf(" --t5xxl path to the t5xxl text encoder\n");
printf(" --qwen2vl path to the qwen2vl text encoder\n");
printf(" --vae [VAE] path to vae\n");
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
printf(" --control-net [CONTROL_PATH] path to control net model\n");
@ -486,6 +487,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--clip_g", "", &params.clip_g_path},
{"", "--clip_vision", "", &params.clip_vision_path},
{"", "--t5xxl", "", &params.t5xxl_path},
{"", "--qwen2vl", "", &params.qwen2vl_path},
{"", "--diffusion-model", "", &params.diffusion_model_path},
{"", "--high-noise-diffusion-model", "", &params.high_noise_diffusion_model_path},
{"", "--vae", "", &params.vae_path},
@ -945,7 +947,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
}
parameter_string += ", ";
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path}) {
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path}) {
if (!te.empty()) {
parameter_string += "TE: " + sd_basename(te) + ", ";
}
@ -1140,10 +1142,6 @@ bool load_images_from_dir(const std::string dir,
int main(int argc, const char* argv[]) {
SDParams params;
params.verbose = true;
sd_set_log_callback(sd_log_cb, (void*)&params);
Qwen::QwenImageRunner::load_from_file_and_test(argv[1]);
exit(1);
parse_args(argc, argv, params);
params.sample_params.guidance.slg.layers = params.skip_layers.data();
params.sample_params.guidance.slg.layer_count = params.skip_layers.size();
@ -1323,6 +1321,7 @@ int main(int argc, const char* argv[]) {
params.clip_g_path.c_str(),
params.clip_vision_path.c_str(),
params.t5xxl_path.c_str(),
params.qwen2vl_path.c_str(),
params.diffusion_model_path.c_str(),
params.high_noise_diffusion_model_path.c_str(),
params.vae_path.c_str(),

View File

@ -110,9 +110,9 @@ const char* unused_tensors[] = {
"embedding_manager",
"denoiser.sigmas",
"text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
"qwen2vl.output.weight",
"qwen2vl.lm_head.",
"qwen2vl.visual.",
"text_encoders.qwen2vl.output.weight",
"text_encoders.qwen2vl.lm_head.",
"text_encoders.qwen2vl.visual.",
};
bool is_unused_tensor(std::string name) {
@ -1762,6 +1762,9 @@ SDVersion ModelLoader::get_sd_version() {
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
return VERSION_SD3;
}
if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) {
return VERSION_QWEN_IMAGE;
}
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
is_wan = true;
}

13
model.h
View File

@ -34,6 +34,7 @@ enum SDVersion {
VERSION_WAN2,
VERSION_WAN2_2_I2V,
VERSION_WAN2_2_TI2V,
VERSION_QWEN_IMAGE,
VERSION_COUNT,
};
@ -79,6 +80,13 @@ static inline bool sd_version_is_wan(SDVersion version) {
return false;
}
static inline bool sd_version_is_qwen_image(SDVersion version) {
if (version == VERSION_QWEN_IMAGE) {
return true;
}
return false;
}
static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
return true;
@ -87,7 +95,10 @@ static inline bool sd_version_is_inpaint(SDVersion version) {
}
static inline bool sd_version_is_dit(SDVersion version) {
if (sd_version_is_flux(version) || sd_version_is_sd3(version) || sd_version_is_wan(version)) {
if (sd_version_is_flux(version) ||
sd_version_is_sd3(version) ||
sd_version_is_wan(version) ||
sd_version_is_qwen_image(version)) {
return true;
}
return false;

View File

@ -54,7 +54,7 @@ namespace Qwen {
// return: [N, embedding_dim]
auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]);
auto timesteps_proj = ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f);
auto timesteps_proj = ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1.f);
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
return timesteps_emb;
}
@ -423,13 +423,9 @@ namespace Qwen {
auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]);
auto t_emb = time_text_embed->forward(ctx, timestep);
LOG_DEBUG("xxx");
auto img = img_in->forward(ctx, x);
LOG_DEBUG("xxx");
auto txt = txt_norm->forward(ctx, context);
LOG_DEBUG("xxx");
txt = txt_in->forward(ctx, txt);
LOG_DEBUG("xxx");
auto img = img_in->forward(ctx, x);
auto txt = txt_norm->forward(ctx, context);
txt = txt_in->forward(ctx, txt);
for (int i = 0; i < params.num_layers; i++) {
auto block = std::dynamic_pointer_cast<QwenImageTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
@ -492,7 +488,7 @@ namespace Qwen {
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {},
const std::string prefix = "",
SDVersion version = VERSION_FLUX,
SDVersion version = VERSION_QWEN_IMAGE,
bool flash_attn = false)
: GGMLRunner(backend, offload_params_to_cpu) {
qwen_image_params.flash_attn = flash_attn;
@ -571,13 +567,12 @@ namespace Qwen {
GGML_ASSERT(work_ctx != NULL);
{
// cpu f16:
// auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1);
// ggml_set_f32(x, 0.01f);
auto x = load_tensor_from_file(work_ctx, "./qwen_image_x.bin");
print_ggml_tensor(x);
std::vector<float> timesteps_vec(1, 1.f);
std::vector<float> timesteps_vec(1, 1000.f);
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
// auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 3584, 256, 1);
@ -598,6 +593,7 @@ namespace Qwen {
static void load_from_file_and_test(const std::string& file_path) {
// cuda q8: pass
// cuda q8 fa: nan
// ggml_backend_t backend = ggml_backend_cuda_init(0);
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_type model_data_type = GGML_TYPE_Q8_0;
@ -619,7 +615,9 @@ namespace Qwen {
std::shared_ptr<QwenImageRunner> qwen_image = std::shared_ptr<QwenImageRunner>(new QwenImageRunner(backend,
false,
tensor_types,
"model.diffusion_model"));
"model.diffusion_model",
VERSION_QWEN_IMAGE,
true));
qwen_image->alloc_params_buffer();
std::map<std::string, ggml_tensor*> tensors;

View File

@ -331,7 +331,7 @@ namespace Qwen {
ss << "\"" << token << "\", ";
}
ss << "]";
LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
// LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
return bpe_tokens;
}

View File

@ -40,6 +40,7 @@ const char* model_version_to_str[] = {
"Wan 2.x",
"Wan 2.2 I2V",
"Wan 2.2 TI2V",
"Qwen Image",
};
const char* sampling_methods_str[] = {
@ -251,6 +252,13 @@ public:
}
}
if (strlen(SAFE_STR(sd_ctx_params->qwen2vl_path)) > 0) {
LOG_INFO("loading qwen2vl from '%s'", sd_ctx_params->qwen2vl_path);
if (!model_loader.init_from_file(sd_ctx_params->qwen2vl_path, "text_encoders.qwen2vl.")) {
LOG_WARN("loading qwen2vl from '%s' failed", sd_ctx_params->qwen2vl_path);
}
}
if (strlen(SAFE_STR(sd_ctx_params->vae_path)) > 0) {
LOG_INFO("loading vae from '%s'", sd_ctx_params->vae_path);
if (!model_loader.init_from_file(sd_ctx_params->vae_path, "vae.")) {
@ -316,7 +324,7 @@ public:
} else if (sd_version_is_flux(version)) {
scale_factor = 0.3611f;
// TODO: shift_factor
} else if (sd_version_is_wan(version)) {
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
scale_factor = 1.0f;
}
@ -325,7 +333,7 @@ public:
{
clip_backend = backend;
bool use_t5xxl = false;
if (sd_version_is_dit(version)) {
if (sd_version_is_dit(version) && !sd_version_is_qwen_image(version)) {
use_t5xxl = true;
}
if (!clip_on_cpu && !ggml_backend_is_cpu(backend) && use_t5xxl) {
@ -411,6 +419,16 @@ public:
clip_vision->alloc_params_buffer();
clip_vision->get_param_tensors(tensors);
}
} else if (sd_version_is_qwen_image(version)) {
cond_stage_model = std::make_shared<Qwen2_5_VLCLIPEmbedder>(clip_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types);
diffusion_model = std::make_shared<QwenImageModel>(backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
"model.diffusion_model",
version,
sd_ctx_params->diffusion_flash_attn);
} else { // SD1.x SD2.x SDXL
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
@ -459,7 +477,7 @@ public:
vae_backend = backend;
}
if (sd_version_is_wan(version)) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
offload_params_to_cpu,
model_loader.tensor_storages_types,
@ -704,6 +722,13 @@ public:
shift = 5.0;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
} else if (sd_version_is_qwen_image(version)) {
LOG_INFO("running in FLOW mode");
float shift = sd_ctx_params->flow_shift;
if (shift == INFINITY) {
shift = 3.0;
}
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
} else if (is_using_v_parameterization) {
LOG_INFO("running in v-prediction mode");
denoiser = std::make_shared<CompVisVDenoiser>();
@ -1402,7 +1427,7 @@ public:
}
void process_latent_in(ggml_tensor* latent) {
if (sd_version_is_wan(version)) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
@ -1442,7 +1467,7 @@ public:
}
void process_latent_out(ggml_tensor* latent) {
if (sd_version_is_wan(version)) {
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
@ -1511,6 +1536,9 @@ public:
}
int64_t t0 = ggml_time_ms();
if (!use_tiny_autoencoder) {
if (sd_version_is_qwen_image(version)) {
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
}
process_latent_out(x);
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
if (vae_tiling_params.enabled && !decode_video) {
@ -1682,6 +1710,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"clip_g_path: %s\n"
"clip_vision_path: %s\n"
"t5xxl_path: %s\n"
"qwen2vl_path: %s\n"
"diffusion_model_path: %s\n"
"high_noise_diffusion_model_path: %s\n"
"vae_path: %s\n"
@ -1709,6 +1738,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->clip_g_path),
SAFE_STR(sd_ctx_params->clip_vision_path),
SAFE_STR(sd_ctx_params->t5xxl_path),
SAFE_STR(sd_ctx_params->qwen2vl_path),
SAFE_STR(sd_ctx_params->diffusion_model_path),
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
SAFE_STR(sd_ctx_params->vae_path),
@ -2066,6 +2096,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
C = 16;
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16;
} else if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
C = 16;
}
int W = width / 8;
int H = height / 8;
@ -2215,6 +2247,8 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
C = 16;
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
C = 16;
} else if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
C = 16;
} else if (sd_version_is_wan(sd_ctx->sd->version)) {
C = 16;
T = ((T - 1) / 4) + 1;

View File

@ -131,6 +131,7 @@ typedef struct {
const char* clip_g_path;
const char* clip_vision_path;
const char* t5xxl_path;
const char* qwen2vl_path;
const char* diffusion_model_path;
const char* high_noise_diffusion_model_path;
const char* vae_path;