mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
add qwen image t2i pipeline
This commit is contained in:
parent
d232509b6e
commit
cf19c6e759
1
clip.hpp
1
clip.hpp
@ -398,7 +398,6 @@ public:
|
||||
}
|
||||
for (auto& token : matches) {
|
||||
std::string token_str = token.str();
|
||||
LOG_DEBUG("%s", token_str.c_str());
|
||||
std::u32string utf32_token;
|
||||
for (int i = 0; i < token_str.length(); i++) {
|
||||
unsigned char b = token_str[i];
|
||||
|
||||
192
conditioner.hpp
192
conditioner.hpp
@ -2,6 +2,7 @@
|
||||
#define __CONDITIONER_HPP__
|
||||
|
||||
#include "clip.hpp"
|
||||
#include "qwenvl.hpp"
|
||||
#include "t5.hpp"
|
||||
|
||||
struct SDCondition {
|
||||
@ -22,11 +23,11 @@ struct Conditioner {
|
||||
int width,
|
||||
int height,
|
||||
int adm_in_channels = -1,
|
||||
bool zero_out_masked = false) = 0;
|
||||
virtual void alloc_params_buffer() = 0;
|
||||
virtual void free_params_buffer() = 0;
|
||||
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
|
||||
virtual size_t get_params_buffer_size() = 0;
|
||||
bool zero_out_masked = false) = 0;
|
||||
virtual void alloc_params_buffer() = 0;
|
||||
virtual void free_params_buffer() = 0;
|
||||
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
|
||||
virtual size_t get_params_buffer_size() = 0;
|
||||
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
||||
int n_threads,
|
||||
const std::string& text,
|
||||
@ -35,9 +36,13 @@ struct Conditioner {
|
||||
int height,
|
||||
int num_input_imgs,
|
||||
int adm_in_channels = -1,
|
||||
bool zero_out_masked = false) = 0;
|
||||
bool zero_out_masked = false) {
|
||||
GGML_ABORT("Not implemented yet!");
|
||||
}
|
||||
virtual std::string remove_trigger_from_prompt(ggml_context* work_ctx,
|
||||
const std::string& prompt) = 0;
|
||||
const std::string& prompt) {
|
||||
GGML_ABORT("Not implemented yet!");
|
||||
}
|
||||
};
|
||||
|
||||
// ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
@ -978,23 +983,6 @@ struct SD3CLIPEmbedder : public Conditioner {
|
||||
auto tokens_and_weights = tokenize(text, 77, true);
|
||||
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
||||
}
|
||||
|
||||
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
||||
int n_threads,
|
||||
const std::string& text,
|
||||
int clip_skip,
|
||||
int width,
|
||||
int height,
|
||||
int num_input_imgs,
|
||||
int adm_in_channels = -1,
|
||||
bool zero_out_masked = false) {
|
||||
GGML_ASSERT(0 && "Not implemented yet!");
|
||||
}
|
||||
|
||||
std::string remove_trigger_from_prompt(ggml_context* work_ctx,
|
||||
const std::string& prompt) {
|
||||
GGML_ASSERT(0 && "Not implemented yet!");
|
||||
}
|
||||
};
|
||||
|
||||
struct FluxCLIPEmbedder : public Conditioner {
|
||||
@ -1195,23 +1183,6 @@ struct FluxCLIPEmbedder : public Conditioner {
|
||||
auto tokens_and_weights = tokenize(text, chunk_len, true);
|
||||
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
||||
}
|
||||
|
||||
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
||||
int n_threads,
|
||||
const std::string& text,
|
||||
int clip_skip,
|
||||
int width,
|
||||
int height,
|
||||
int num_input_imgs,
|
||||
int adm_in_channels = -1,
|
||||
bool zero_out_masked = false) {
|
||||
GGML_ASSERT(0 && "Not implemented yet!");
|
||||
}
|
||||
|
||||
std::string remove_trigger_from_prompt(ggml_context* work_ctx,
|
||||
const std::string& prompt) {
|
||||
GGML_ASSERT(0 && "Not implemented yet!");
|
||||
}
|
||||
};
|
||||
|
||||
struct T5CLIPEmbedder : public Conditioner {
|
||||
@ -1398,22 +1369,135 @@ struct T5CLIPEmbedder : public Conditioner {
|
||||
auto tokens_and_weights = tokenize(text, chunk_len, true);
|
||||
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
||||
}
|
||||
};
|
||||
|
||||
std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
|
||||
int n_threads,
|
||||
const std::string& text,
|
||||
int clip_skip,
|
||||
int width,
|
||||
int height,
|
||||
int num_input_imgs,
|
||||
int adm_in_channels = -1,
|
||||
bool zero_out_masked = false) {
|
||||
GGML_ASSERT(0 && "Not implemented yet!");
|
||||
struct Qwen2_5_VLCLIPEmbedder : public Conditioner {
|
||||
Qwen::Qwen2Tokenizer tokenizer;
|
||||
std::shared_ptr<Qwen::Qwen2_5_VLRunner> qwenvl;
|
||||
int prompt_template_encode_start_idx = 34;
|
||||
|
||||
Qwen2_5_VLCLIPEmbedder(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
const std::string prefix = "") {
|
||||
qwenvl = std::make_shared<Qwen::Qwen2_5_VLRunner>(backend, offload_params_to_cpu, tensor_types, "text_encoders.qwen2vl");
|
||||
}
|
||||
|
||||
std::string remove_trigger_from_prompt(ggml_context* work_ctx,
|
||||
const std::string& prompt) {
|
||||
GGML_ASSERT(0 && "Not implemented yet!");
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
|
||||
qwenvl->get_param_tensors(tensors, "text_encoders.qwen2vl");
|
||||
}
|
||||
|
||||
void alloc_params_buffer() {
|
||||
qwenvl->alloc_params_buffer();
|
||||
}
|
||||
|
||||
void free_params_buffer() {
|
||||
qwenvl->free_params_buffer();
|
||||
}
|
||||
|
||||
size_t get_params_buffer_size() {
|
||||
size_t buffer_size = 0;
|
||||
buffer_size += qwenvl->get_params_buffer_size();
|
||||
return buffer_size;
|
||||
}
|
||||
|
||||
std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
|
||||
size_t max_length = 0,
|
||||
bool padding = false) {
|
||||
auto parsed_attention = parse_prompt_attention(text);
|
||||
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << "[";
|
||||
for (const auto& item : parsed_attention) {
|
||||
ss << "['" << item.first << "', " << item.second << "], ";
|
||||
}
|
||||
ss << "]";
|
||||
LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
|
||||
}
|
||||
|
||||
std::vector<int> tokens;
|
||||
std::vector<float> weights;
|
||||
for (const auto& item : parsed_attention) {
|
||||
const std::string& curr_text = item.first;
|
||||
float curr_weight = item.second;
|
||||
std::vector<int> curr_tokens = tokenizer.tokenize(curr_text, nullptr);
|
||||
tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
|
||||
weights.insert(weights.end(), curr_tokens.size(), curr_weight);
|
||||
}
|
||||
|
||||
tokenizer.pad_tokens(tokens, weights, max_length, padding);
|
||||
|
||||
// for (int i = 0; i < tokens.size(); i++) {
|
||||
// std::cout << tokens[i] << ":" << weights[i] << ", ";
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
|
||||
return {tokens, weights};
|
||||
}
|
||||
|
||||
SDCondition get_learned_condition_common(ggml_context* work_ctx,
|
||||
int n_threads,
|
||||
std::tuple<std::vector<int>, std::vector<float>> token_and_weights,
|
||||
int clip_skip,
|
||||
bool zero_out_masked = false) {
|
||||
auto& tokens = std::get<0>(token_and_weights);
|
||||
auto& weights = std::get<1>(token_and_weights);
|
||||
|
||||
int64_t t0 = ggml_time_ms();
|
||||
struct ggml_tensor* hidden_states = NULL; // [N, n_token, 3584]
|
||||
|
||||
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
|
||||
|
||||
qwenvl->compute(n_threads,
|
||||
input_ids,
|
||||
&hidden_states,
|
||||
work_ctx);
|
||||
{
|
||||
auto tensor = hidden_states;
|
||||
float original_mean = ggml_tensor_mean(tensor);
|
||||
for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
|
||||
for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
|
||||
for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
|
||||
float value = ggml_tensor_get_f32(tensor, i0, i1, i2);
|
||||
value *= weights[i1];
|
||||
ggml_tensor_set_f32(tensor, value, i0, i1, i2);
|
||||
}
|
||||
}
|
||||
}
|
||||
float new_mean = ggml_tensor_mean(tensor);
|
||||
ggml_tensor_scale(tensor, (original_mean / new_mean));
|
||||
}
|
||||
|
||||
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
|
||||
|
||||
ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
|
||||
GGML_TYPE_F32,
|
||||
hidden_states->ne[0],
|
||||
hidden_states->ne[1] - prompt_template_encode_start_idx,
|
||||
hidden_states->ne[2]);
|
||||
|
||||
ggml_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = ggml_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
|
||||
ggml_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
|
||||
});
|
||||
|
||||
int64_t t1 = ggml_time_ms();
|
||||
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
|
||||
return SDCondition(new_hidden_states, nullptr, nullptr);
|
||||
}
|
||||
|
||||
SDCondition get_learned_condition(ggml_context* work_ctx,
|
||||
int n_threads,
|
||||
const std::string& text,
|
||||
int clip_skip,
|
||||
int width,
|
||||
int height,
|
||||
int adm_in_channels = -1,
|
||||
bool zero_out_masked = false) {
|
||||
std::string prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n" + text + "<|im_end|>\n<|im_start|>assistant\n";
|
||||
auto tokens_and_weights = tokenize(prompt, 0, false);
|
||||
return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, zero_out_masked);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
#include "flux.hpp"
|
||||
#include "mmdit.hpp"
|
||||
#include "qwen_image.hpp"
|
||||
#include "unet.hpp"
|
||||
#include "wan.hpp"
|
||||
|
||||
@ -263,4 +264,58 @@ struct WanModel : public DiffusionModel {
|
||||
}
|
||||
};
|
||||
|
||||
struct QwenImageModel : public DiffusionModel {
|
||||
std::string prefix;
|
||||
Qwen::QwenImageRunner qwen_image;
|
||||
|
||||
QwenImageModel(ggml_backend_t backend,
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
const std::string prefix = "model.diffusion_model",
|
||||
SDVersion version = VERSION_QWEN_IMAGE,
|
||||
bool flash_attn = false)
|
||||
: prefix(prefix), qwen_image(backend, offload_params_to_cpu, tensor_types, prefix, version, flash_attn) {
|
||||
}
|
||||
|
||||
std::string get_desc() {
|
||||
return qwen_image.get_desc();
|
||||
}
|
||||
|
||||
void alloc_params_buffer() {
|
||||
qwen_image.alloc_params_buffer();
|
||||
}
|
||||
|
||||
void free_params_buffer() {
|
||||
qwen_image.free_params_buffer();
|
||||
}
|
||||
|
||||
void free_compute_buffer() {
|
||||
qwen_image.free_compute_buffer();
|
||||
}
|
||||
|
||||
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) {
|
||||
qwen_image.get_param_tensors(tensors, prefix);
|
||||
}
|
||||
|
||||
size_t get_params_buffer_size() {
|
||||
return qwen_image.get_params_buffer_size();
|
||||
}
|
||||
|
||||
int64_t get_adm_in_channels() {
|
||||
return 768;
|
||||
}
|
||||
|
||||
void compute(int n_threads,
|
||||
DiffusionParams diffusion_params,
|
||||
struct ggml_tensor** output = NULL,
|
||||
struct ggml_context* output_ctx = NULL) {
|
||||
return qwen_image.compute(n_threads,
|
||||
diffusion_params.x,
|
||||
diffusion_params.timesteps,
|
||||
diffusion_params.context,
|
||||
output,
|
||||
output_ctx);
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@ -27,8 +27,6 @@
|
||||
|
||||
#include "avi_writer.h"
|
||||
|
||||
#include "qwen_image.hpp"
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define NOMINMAX
|
||||
#include <windows.h>
|
||||
@ -61,6 +59,7 @@ struct SDParams {
|
||||
std::string clip_g_path;
|
||||
std::string clip_vision_path;
|
||||
std::string t5xxl_path;
|
||||
std::string qwen2vl_path;
|
||||
std::string diffusion_model_path;
|
||||
std::string high_noise_diffusion_model_path;
|
||||
std::string vae_path;
|
||||
@ -146,6 +145,7 @@ void print_params(SDParams params) {
|
||||
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
|
||||
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
|
||||
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
|
||||
printf(" qwen2vl_path: %s\n", params.qwen2vl_path.c_str());
|
||||
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
|
||||
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
|
||||
printf(" vae_path: %s\n", params.vae_path.c_str());
|
||||
@ -217,6 +217,7 @@ void print_usage(int argc, const char* argv[]) {
|
||||
printf(" --clip_g path to the clip-g text encoder\n");
|
||||
printf(" --clip_vision path to the clip-vision encoder\n");
|
||||
printf(" --t5xxl path to the t5xxl text encoder\n");
|
||||
printf(" --qwen2vl path to the qwen2vl text encoder\n");
|
||||
printf(" --vae [VAE] path to vae\n");
|
||||
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
|
||||
printf(" --control-net [CONTROL_PATH] path to control net model\n");
|
||||
@ -486,6 +487,7 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
||||
{"", "--clip_g", "", ¶ms.clip_g_path},
|
||||
{"", "--clip_vision", "", ¶ms.clip_vision_path},
|
||||
{"", "--t5xxl", "", ¶ms.t5xxl_path},
|
||||
{"", "--qwen2vl", "", ¶ms.qwen2vl_path},
|
||||
{"", "--diffusion-model", "", ¶ms.diffusion_model_path},
|
||||
{"", "--high-noise-diffusion-model", "", ¶ms.high_noise_diffusion_model_path},
|
||||
{"", "--vae", "", ¶ms.vae_path},
|
||||
@ -945,7 +947,7 @@ std::string get_image_params(SDParams params, int64_t seed) {
|
||||
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
|
||||
}
|
||||
parameter_string += ", ";
|
||||
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path}) {
|
||||
for (const auto& te : {params.clip_l_path, params.clip_g_path, params.t5xxl_path, params.qwen2vl_path}) {
|
||||
if (!te.empty()) {
|
||||
parameter_string += "TE: " + sd_basename(te) + ", ";
|
||||
}
|
||||
@ -1140,10 +1142,6 @@ bool load_images_from_dir(const std::string dir,
|
||||
|
||||
int main(int argc, const char* argv[]) {
|
||||
SDParams params;
|
||||
params.verbose = true;
|
||||
sd_set_log_callback(sd_log_cb, (void*)¶ms);
|
||||
Qwen::QwenImageRunner::load_from_file_and_test(argv[1]);
|
||||
exit(1);
|
||||
parse_args(argc, argv, params);
|
||||
params.sample_params.guidance.slg.layers = params.skip_layers.data();
|
||||
params.sample_params.guidance.slg.layer_count = params.skip_layers.size();
|
||||
@ -1323,6 +1321,7 @@ int main(int argc, const char* argv[]) {
|
||||
params.clip_g_path.c_str(),
|
||||
params.clip_vision_path.c_str(),
|
||||
params.t5xxl_path.c_str(),
|
||||
params.qwen2vl_path.c_str(),
|
||||
params.diffusion_model_path.c_str(),
|
||||
params.high_noise_diffusion_model_path.c_str(),
|
||||
params.vae_path.c_str(),
|
||||
|
||||
@ -110,9 +110,9 @@ const char* unused_tensors[] = {
|
||||
"embedding_manager",
|
||||
"denoiser.sigmas",
|
||||
"text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training
|
||||
"qwen2vl.output.weight",
|
||||
"qwen2vl.lm_head.",
|
||||
"qwen2vl.visual.",
|
||||
"text_encoders.qwen2vl.output.weight",
|
||||
"text_encoders.qwen2vl.lm_head.",
|
||||
"text_encoders.qwen2vl.visual.",
|
||||
};
|
||||
|
||||
bool is_unused_tensor(std::string name) {
|
||||
@ -1762,6 +1762,9 @@ SDVersion ModelLoader::get_sd_version() {
|
||||
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
|
||||
return VERSION_SD3;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.transformer_blocks.0.img_mod.1.weight") != std::string::npos) {
|
||||
return VERSION_QWEN_IMAGE;
|
||||
}
|
||||
if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) {
|
||||
is_wan = true;
|
||||
}
|
||||
|
||||
13
model.h
13
model.h
@ -34,6 +34,7 @@ enum SDVersion {
|
||||
VERSION_WAN2,
|
||||
VERSION_WAN2_2_I2V,
|
||||
VERSION_WAN2_2_TI2V,
|
||||
VERSION_QWEN_IMAGE,
|
||||
VERSION_COUNT,
|
||||
};
|
||||
|
||||
@ -79,6 +80,13 @@ static inline bool sd_version_is_wan(SDVersion version) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline bool sd_version_is_qwen_image(SDVersion version) {
|
||||
if (version == VERSION_QWEN_IMAGE) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static inline bool sd_version_is_inpaint(SDVersion version) {
|
||||
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
|
||||
return true;
|
||||
@ -87,7 +95,10 @@ static inline bool sd_version_is_inpaint(SDVersion version) {
|
||||
}
|
||||
|
||||
static inline bool sd_version_is_dit(SDVersion version) {
|
||||
if (sd_version_is_flux(version) || sd_version_is_sd3(version) || sd_version_is_wan(version)) {
|
||||
if (sd_version_is_flux(version) ||
|
||||
sd_version_is_sd3(version) ||
|
||||
sd_version_is_wan(version) ||
|
||||
sd_version_is_qwen_image(version)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
||||
@ -54,7 +54,7 @@ namespace Qwen {
|
||||
// return: [N, embedding_dim]
|
||||
auto timestep_embedder = std::dynamic_pointer_cast<TimestepEmbedding>(blocks["timestep_embedder"]);
|
||||
|
||||
auto timesteps_proj = ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f);
|
||||
auto timesteps_proj = ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1.f);
|
||||
auto timesteps_emb = timestep_embedder->forward(ctx, timesteps_proj);
|
||||
return timesteps_emb;
|
||||
}
|
||||
@ -423,13 +423,9 @@ namespace Qwen {
|
||||
auto proj_out = std::dynamic_pointer_cast<Linear>(blocks["proj_out"]);
|
||||
|
||||
auto t_emb = time_text_embed->forward(ctx, timestep);
|
||||
LOG_DEBUG("xxx");
|
||||
auto img = img_in->forward(ctx, x);
|
||||
LOG_DEBUG("xxx");
|
||||
auto txt = txt_norm->forward(ctx, context);
|
||||
LOG_DEBUG("xxx");
|
||||
txt = txt_in->forward(ctx, txt);
|
||||
LOG_DEBUG("xxx");
|
||||
auto img = img_in->forward(ctx, x);
|
||||
auto txt = txt_norm->forward(ctx, context);
|
||||
txt = txt_in->forward(ctx, txt);
|
||||
|
||||
for (int i = 0; i < params.num_layers; i++) {
|
||||
auto block = std::dynamic_pointer_cast<QwenImageTransformerBlock>(blocks["transformer_blocks." + std::to_string(i)]);
|
||||
@ -492,7 +488,7 @@ namespace Qwen {
|
||||
bool offload_params_to_cpu,
|
||||
const String2GGMLType& tensor_types = {},
|
||||
const std::string prefix = "",
|
||||
SDVersion version = VERSION_FLUX,
|
||||
SDVersion version = VERSION_QWEN_IMAGE,
|
||||
bool flash_attn = false)
|
||||
: GGMLRunner(backend, offload_params_to_cpu) {
|
||||
qwen_image_params.flash_attn = flash_attn;
|
||||
@ -571,13 +567,12 @@ namespace Qwen {
|
||||
GGML_ASSERT(work_ctx != NULL);
|
||||
|
||||
{
|
||||
// cpu f16:
|
||||
// auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1);
|
||||
// ggml_set_f32(x, 0.01f);
|
||||
auto x = load_tensor_from_file(work_ctx, "./qwen_image_x.bin");
|
||||
print_ggml_tensor(x);
|
||||
|
||||
std::vector<float> timesteps_vec(1, 1.f);
|
||||
std::vector<float> timesteps_vec(1, 1000.f);
|
||||
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
|
||||
|
||||
// auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 3584, 256, 1);
|
||||
@ -598,6 +593,7 @@ namespace Qwen {
|
||||
|
||||
static void load_from_file_and_test(const std::string& file_path) {
|
||||
// cuda q8: pass
|
||||
// cuda q8 fa: nan
|
||||
// ggml_backend_t backend = ggml_backend_cuda_init(0);
|
||||
ggml_backend_t backend = ggml_backend_cpu_init();
|
||||
ggml_type model_data_type = GGML_TYPE_Q8_0;
|
||||
@ -619,7 +615,9 @@ namespace Qwen {
|
||||
std::shared_ptr<QwenImageRunner> qwen_image = std::shared_ptr<QwenImageRunner>(new QwenImageRunner(backend,
|
||||
false,
|
||||
tensor_types,
|
||||
"model.diffusion_model"));
|
||||
"model.diffusion_model",
|
||||
VERSION_QWEN_IMAGE,
|
||||
true));
|
||||
|
||||
qwen_image->alloc_params_buffer();
|
||||
std::map<std::string, ggml_tensor*> tensors;
|
||||
|
||||
@ -331,7 +331,7 @@ namespace Qwen {
|
||||
ss << "\"" << token << "\", ";
|
||||
}
|
||||
ss << "]";
|
||||
LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
|
||||
// LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str());
|
||||
// printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str());
|
||||
return bpe_tokens;
|
||||
}
|
||||
|
||||
@ -40,6 +40,7 @@ const char* model_version_to_str[] = {
|
||||
"Wan 2.x",
|
||||
"Wan 2.2 I2V",
|
||||
"Wan 2.2 TI2V",
|
||||
"Qwen Image",
|
||||
};
|
||||
|
||||
const char* sampling_methods_str[] = {
|
||||
@ -251,6 +252,13 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
if (strlen(SAFE_STR(sd_ctx_params->qwen2vl_path)) > 0) {
|
||||
LOG_INFO("loading qwen2vl from '%s'", sd_ctx_params->qwen2vl_path);
|
||||
if (!model_loader.init_from_file(sd_ctx_params->qwen2vl_path, "text_encoders.qwen2vl.")) {
|
||||
LOG_WARN("loading qwen2vl from '%s' failed", sd_ctx_params->qwen2vl_path);
|
||||
}
|
||||
}
|
||||
|
||||
if (strlen(SAFE_STR(sd_ctx_params->vae_path)) > 0) {
|
||||
LOG_INFO("loading vae from '%s'", sd_ctx_params->vae_path);
|
||||
if (!model_loader.init_from_file(sd_ctx_params->vae_path, "vae.")) {
|
||||
@ -316,7 +324,7 @@ public:
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
scale_factor = 0.3611f;
|
||||
// TODO: shift_factor
|
||||
} else if (sd_version_is_wan(version)) {
|
||||
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||
scale_factor = 1.0f;
|
||||
}
|
||||
|
||||
@ -325,7 +333,7 @@ public:
|
||||
{
|
||||
clip_backend = backend;
|
||||
bool use_t5xxl = false;
|
||||
if (sd_version_is_dit(version)) {
|
||||
if (sd_version_is_dit(version) && !sd_version_is_qwen_image(version)) {
|
||||
use_t5xxl = true;
|
||||
}
|
||||
if (!clip_on_cpu && !ggml_backend_is_cpu(backend) && use_t5xxl) {
|
||||
@ -411,6 +419,16 @@ public:
|
||||
clip_vision->alloc_params_buffer();
|
||||
clip_vision->get_param_tensors(tensors);
|
||||
}
|
||||
} else if (sd_version_is_qwen_image(version)) {
|
||||
cond_stage_model = std::make_shared<Qwen2_5_VLCLIPEmbedder>(clip_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types);
|
||||
diffusion_model = std::make_shared<QwenImageModel>(backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
"model.diffusion_model",
|
||||
version,
|
||||
sd_ctx_params->diffusion_flash_attn);
|
||||
} else { // SD1.x SD2.x SDXL
|
||||
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
|
||||
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
|
||||
@ -459,7 +477,7 @@ public:
|
||||
vae_backend = backend;
|
||||
}
|
||||
|
||||
if (sd_version_is_wan(version)) {
|
||||
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
|
||||
offload_params_to_cpu,
|
||||
model_loader.tensor_storages_types,
|
||||
@ -704,6 +722,13 @@ public:
|
||||
shift = 5.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
} else if (sd_version_is_qwen_image(version)) {
|
||||
LOG_INFO("running in FLOW mode");
|
||||
float shift = sd_ctx_params->flow_shift;
|
||||
if (shift == INFINITY) {
|
||||
shift = 3.0;
|
||||
}
|
||||
denoiser = std::make_shared<DiscreteFlowDenoiser>(shift);
|
||||
} else if (is_using_v_parameterization) {
|
||||
LOG_INFO("running in v-prediction mode");
|
||||
denoiser = std::make_shared<CompVisVDenoiser>();
|
||||
@ -1402,7 +1427,7 @@ public:
|
||||
}
|
||||
|
||||
void process_latent_in(ggml_tensor* latent) {
|
||||
if (sd_version_is_wan(version)) {
|
||||
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
|
||||
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
||||
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
||||
@ -1442,7 +1467,7 @@ public:
|
||||
}
|
||||
|
||||
void process_latent_out(ggml_tensor* latent) {
|
||||
if (sd_version_is_wan(version)) {
|
||||
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
|
||||
std::vector<float> latents_mean_vec = {-0.7571f, -0.7089f, -0.9113f, 0.1075f, -0.1745f, 0.9653f, -0.1517f, 1.5508f,
|
||||
0.4134f, -0.0715f, 0.5517f, -0.3632f, -0.1922f, -0.9497f, 0.2503f, -0.2921f};
|
||||
@ -1511,6 +1536,9 @@ public:
|
||||
}
|
||||
int64_t t0 = ggml_time_ms();
|
||||
if (!use_tiny_autoencoder) {
|
||||
if (sd_version_is_qwen_image(version)) {
|
||||
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
|
||||
}
|
||||
process_latent_out(x);
|
||||
// x = load_tensor_from_file(work_ctx, "wan_vae_z.bin");
|
||||
if (vae_tiling_params.enabled && !decode_video) {
|
||||
@ -1682,6 +1710,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
"clip_g_path: %s\n"
|
||||
"clip_vision_path: %s\n"
|
||||
"t5xxl_path: %s\n"
|
||||
"qwen2vl_path: %s\n"
|
||||
"diffusion_model_path: %s\n"
|
||||
"high_noise_diffusion_model_path: %s\n"
|
||||
"vae_path: %s\n"
|
||||
@ -1709,6 +1738,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
|
||||
SAFE_STR(sd_ctx_params->clip_g_path),
|
||||
SAFE_STR(sd_ctx_params->clip_vision_path),
|
||||
SAFE_STR(sd_ctx_params->t5xxl_path),
|
||||
SAFE_STR(sd_ctx_params->qwen2vl_path),
|
||||
SAFE_STR(sd_ctx_params->diffusion_model_path),
|
||||
SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path),
|
||||
SAFE_STR(sd_ctx_params->vae_path),
|
||||
@ -2066,6 +2096,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
C = 16;
|
||||
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
||||
C = 16;
|
||||
} else if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
|
||||
C = 16;
|
||||
}
|
||||
int W = width / 8;
|
||||
int H = height / 8;
|
||||
@ -2215,6 +2247,8 @@ ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
|
||||
C = 16;
|
||||
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
||||
C = 16;
|
||||
} else if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
|
||||
C = 16;
|
||||
} else if (sd_version_is_wan(sd_ctx->sd->version)) {
|
||||
C = 16;
|
||||
T = ((T - 1) / 4) + 1;
|
||||
|
||||
@ -131,6 +131,7 @@ typedef struct {
|
||||
const char* clip_g_path;
|
||||
const char* clip_vision_path;
|
||||
const char* t5xxl_path;
|
||||
const char* qwen2vl_path;
|
||||
const char* diffusion_model_path;
|
||||
const char* high_noise_diffusion_model_path;
|
||||
const char* vae_path;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user