Merge branch 't5_fix' into qwen_image_edit

This commit is contained in:
leejet 2025-10-13 00:04:58 +08:00
commit 162d5cef64
11 changed files with 387 additions and 143 deletions

View File

@ -21,6 +21,7 @@ API and command-line option may change frequently.***
- [SD3/SD3.5](./docs/sd3.md)
- [Flux-dev/Flux-schnell](./docs/flux.md)
- [Chroma](./docs/chroma.md)
- [Qwen Image](./docs/qwen_image.md)
- Image Edit Models
- [FLUX.1-Kontext-dev](./docs/kontext.md)
- Video Models
@ -285,7 +286,7 @@ usage: ./bin/sd [arguments]
arguments:
-h, --help show this help message and exit
-M, --mode [MODE] run mode, one of: [img_gen, vid_gen, convert], default: img_gen
-M, --mode [MODE] run mode, one of: [img_gen, vid_gen, upscale, convert], default: img_gen
-t, --threads N number of threads to use during computation (default: -1)
If threads <= 0, then threads will be set to the number of CPU physical cores
--offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
@ -296,11 +297,12 @@ arguments:
--clip_g path to the clip-g text encoder
--clip_vision path to the clip-vision encoder
--t5xxl path to the t5xxl text encoder
--qwen2vl path to the qwen2vl text encoder
--vae [VAE] path to vae
--taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
--control-net [CONTROL_PATH] path to control net model
--embd-dir [EMBEDDING_PATH] path to embeddings
--upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
--upscale-model [ESRGAN_PATH] path to esrgan model. For img_gen mode, upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
--upscale-repeats Run the ESRGAN upscaler this many times (default 1)
--type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)
If not specified, the default is the type of the weight file
@ -464,6 +466,7 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp
## References
- [ggml](https://github.com/ggerganov/ggml)
- [diffusers](https://github.com/huggingface/diffusers)
- [stable-diffusion](https://github.com/CompVis/stable-diffusion)
- [sd3-ref](https://github.com/Stability-AI/sd3-ref)
- [stable-diffusion-stability-ai](https://github.com/Stability-AI/stablediffusion)

BIN
assets/qwen/example.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

23
docs/qwen_image.md Normal file
View File

@ -0,0 +1,23 @@
# How to Use
## Download weights
- Download Qwen Image
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/diffusion_models
- gguf: https://huggingface.co/QuantStack/Qwen-Image-GGUF/tree/main
- Download vae
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/vae
- Download qwen_2.5_vl 7b
- safetensors: https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/tree/main/split_files/text_encoders
- gguf: https://huggingface.co/mradermacher/Qwen2.5-VL-7B-Instruct-GGUF/tree/main
## Examples
```
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\qwen-image-Q8_0.gguf --vae ..\..\ComfyUI\models\vae\qwen_image_vae.safetensors --qwen2vl ..\..\ComfyUI\models\text_encoders\Qwen2.5-VL-7B-Instruct-Q8_0.gguf -p '一个穿着"QWEN"标志的T恤的中国美女正拿着黑色的马克笔面相镜头微笑。她身后的玻璃板上手写体写着 “一、Qwen-Image的技术路线 探索视觉生成基础模型的极限开创理解与生成一体化的未来。二、Qwen-Image的模型特色1、复杂文字渲染。支持中英渲染、自动布局 2、精准图像编辑。支持文字编辑、物体增减、风格变换。三、Qwen-Image的未来愿景赋能专业内容创作、助力生成式AI发展。”' --cfg-scale 2.5 --sampling-method euler -v --offload-to-cpu -H 1024 -W 1024 --diffusion-fa --flow-shift 3
```
<img alt="qwen example" src="../assets/qwen/example.png" />

View File

@ -83,39 +83,44 @@ public:
class RRDBNet : public GGMLBlock {
protected:
int scale = 4; // default RealESRGAN_x4plus_anime_6B
int num_block = 6; // default RealESRGAN_x4plus_anime_6B
int scale = 4;
int num_block = 23;
int num_in_ch = 3;
int num_out_ch = 3;
int num_feat = 64; // default RealESRGAN_x4plus_anime_6B
int num_grow_ch = 32; // default RealESRGAN_x4plus_anime_6B
int num_feat = 64;
int num_grow_ch = 32;
public:
RRDBNet() {
RRDBNet(int scale, int num_block, int num_in_ch, int num_out_ch, int num_feat, int num_grow_ch)
: scale(scale), num_block(num_block), num_in_ch(num_in_ch), num_out_ch(num_out_ch), num_feat(num_feat), num_grow_ch(num_grow_ch) {
blocks["conv_first"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_in_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
for (int i = 0; i < num_block; i++) {
std::string name = "body." + std::to_string(i);
blocks[name] = std::shared_ptr<GGMLBlock>(new RRDB(num_feat, num_grow_ch));
}
blocks["conv_body"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
// upsample
blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
if (scale >= 2) {
blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
}
if (scale == 4) {
blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
}
blocks["conv_hr"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
blocks["conv_last"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_out_ch, {3, 3}, {1, 1}, {1, 1}));
}
int get_scale() { return scale; }
int get_num_block() { return num_block; }
struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
return ggml_leaky_relu(ctx, x, 0.2f, true);
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
// x: [n, num_in_ch, h, w]
// return: [n, num_out_ch, h*4, w*4]
// return: [n, num_out_ch, h*scale, w*scale]
auto conv_first = std::dynamic_pointer_cast<Conv2d>(blocks["conv_first"]);
auto conv_body = std::dynamic_pointer_cast<Conv2d>(blocks["conv_body"]);
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
auto conv_hr = std::dynamic_pointer_cast<Conv2d>(blocks["conv_hr"]);
auto conv_last = std::dynamic_pointer_cast<Conv2d>(blocks["conv_last"]);
@ -130,15 +135,22 @@ public:
body_feat = conv_body->forward(ctx, body_feat);
feat = ggml_add(ctx, feat, body_feat);
// upsample
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
if (scale >= 2) {
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
if (scale == 4) {
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
}
}
// for all scales
auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat)));
return out;
}
};
struct ESRGAN : public GGMLRunner {
RRDBNet rrdb_net;
std::unique_ptr<RRDBNet> rrdb_net;
int scale = 4;
int tile_size = 128; // avoid cuda OOM for 4gb VRAM
@ -146,12 +158,14 @@ struct ESRGAN : public GGMLRunner {
bool offload_params_to_cpu,
const String2GGMLType& tensor_types = {})
: GGMLRunner(backend, offload_params_to_cpu) {
rrdb_net.init(params_ctx, tensor_types, "");
// rrdb_net will be created in load_from_file
}
void enable_conv2d_direct() {
if (!rrdb_net)
return;
std::vector<GGMLBlock*> blocks;
rrdb_net.get_all_blocks(blocks);
rrdb_net->get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
@ -167,31 +181,185 @@ struct ESRGAN : public GGMLRunner {
bool load_from_file(const std::string& file_path, int n_threads) {
LOG_INFO("loading esrgan from '%s'", file_path.c_str());
alloc_params_buffer();
std::map<std::string, ggml_tensor*> esrgan_tensors;
rrdb_net.get_param_tensors(esrgan_tensors);
ModelLoader model_loader;
if (!model_loader.init_from_file(file_path)) {
LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str());
return false;
}
bool success = model_loader.load_tensors(esrgan_tensors, {}, n_threads);
// Get tensor names
auto tensor_names = model_loader.get_tensor_names();
// Detect if it's ESRGAN format
bool is_ESRGAN = std::find(tensor_names.begin(), tensor_names.end(), "model.0.weight") != tensor_names.end();
// Detect parameters from tensor names
int detected_num_block = 0;
if (is_ESRGAN) {
for (const auto& name : tensor_names) {
if (name.find("model.1.sub.") == 0) {
size_t first_dot = name.find('.', 12);
if (first_dot != std::string::npos) {
size_t second_dot = name.find('.', first_dot + 1);
if (second_dot != std::string::npos && name.substr(first_dot + 1, 3) == "RDB") {
try {
int idx = std::stoi(name.substr(12, first_dot - 12));
detected_num_block = std::max(detected_num_block, idx + 1);
} catch (...) {
}
}
}
}
}
} else {
// Original format
for (const auto& name : tensor_names) {
if (name.find("body.") == 0) {
size_t pos = name.find('.', 5);
if (pos != std::string::npos) {
try {
int idx = std::stoi(name.substr(5, pos - 5));
detected_num_block = std::max(detected_num_block, idx + 1);
} catch (...) {
}
}
}
}
}
int detected_scale = 4; // default
if (is_ESRGAN) {
// For ESRGAN format, detect scale by highest model number
int max_model_num = 0;
for (const auto& name : tensor_names) {
if (name.find("model.") == 0) {
size_t dot_pos = name.find('.', 6);
if (dot_pos != std::string::npos) {
try {
int num = std::stoi(name.substr(6, dot_pos - 6));
max_model_num = std::max(max_model_num, num);
} catch (...) {
}
}
}
}
if (max_model_num <= 4) {
detected_scale = 1;
} else if (max_model_num <= 7) {
detected_scale = 2;
} else {
detected_scale = 4;
}
} else {
// Original format
bool has_conv_up2 = std::any_of(tensor_names.begin(), tensor_names.end(), [](const std::string& name) {
return name == "conv_up2.weight";
});
bool has_conv_up1 = std::any_of(tensor_names.begin(), tensor_names.end(), [](const std::string& name) {
return name == "conv_up1.weight";
});
if (has_conv_up2) {
detected_scale = 4;
} else if (has_conv_up1) {
detected_scale = 2;
} else {
detected_scale = 1;
}
}
int detected_num_in_ch = 3;
int detected_num_out_ch = 3;
int detected_num_feat = 64;
int detected_num_grow_ch = 32;
// Create RRDBNet with detected parameters
rrdb_net = std::make_unique<RRDBNet>(detected_scale, detected_num_block, detected_num_in_ch, detected_num_out_ch, detected_num_feat, detected_num_grow_ch);
rrdb_net->init(params_ctx, {}, "");
alloc_params_buffer();
std::map<std::string, ggml_tensor*> esrgan_tensors;
rrdb_net->get_param_tensors(esrgan_tensors);
bool success;
if (is_ESRGAN) {
// Build name mapping for ESRGAN format
std::map<std::string, std::string> expected_to_model;
expected_to_model["conv_first.weight"] = "model.0.weight";
expected_to_model["conv_first.bias"] = "model.0.bias";
for (int i = 0; i < detected_num_block; i++) {
for (int j = 1; j <= 3; j++) {
for (int k = 1; k <= 5; k++) {
std::string expected_weight = "body." + std::to_string(i) + ".rdb" + std::to_string(j) + ".conv" + std::to_string(k) + ".weight";
std::string model_weight = "model.1.sub." + std::to_string(i) + ".RDB" + std::to_string(j) + ".conv" + std::to_string(k) + ".0.weight";
expected_to_model[expected_weight] = model_weight;
std::string expected_bias = "body." + std::to_string(i) + ".rdb" + std::to_string(j) + ".conv" + std::to_string(k) + ".bias";
std::string model_bias = "model.1.sub." + std::to_string(i) + ".RDB" + std::to_string(j) + ".conv" + std::to_string(k) + ".0.bias";
expected_to_model[expected_bias] = model_bias;
}
}
}
if (detected_scale == 1) {
expected_to_model["conv_body.weight"] = "model.1.sub." + std::to_string(detected_num_block) + ".weight";
expected_to_model["conv_body.bias"] = "model.1.sub." + std::to_string(detected_num_block) + ".bias";
expected_to_model["conv_hr.weight"] = "model.2.weight";
expected_to_model["conv_hr.bias"] = "model.2.bias";
expected_to_model["conv_last.weight"] = "model.4.weight";
expected_to_model["conv_last.bias"] = "model.4.bias";
} else {
expected_to_model["conv_body.weight"] = "model.1.sub." + std::to_string(detected_num_block) + ".weight";
expected_to_model["conv_body.bias"] = "model.1.sub." + std::to_string(detected_num_block) + ".bias";
if (detected_scale >= 2) {
expected_to_model["conv_up1.weight"] = "model.3.weight";
expected_to_model["conv_up1.bias"] = "model.3.bias";
}
if (detected_scale == 4) {
expected_to_model["conv_up2.weight"] = "model.6.weight";
expected_to_model["conv_up2.bias"] = "model.6.bias";
expected_to_model["conv_hr.weight"] = "model.8.weight";
expected_to_model["conv_hr.bias"] = "model.8.bias";
expected_to_model["conv_last.weight"] = "model.10.weight";
expected_to_model["conv_last.bias"] = "model.10.bias";
} else if (detected_scale == 2) {
expected_to_model["conv_hr.weight"] = "model.5.weight";
expected_to_model["conv_hr.bias"] = "model.5.bias";
expected_to_model["conv_last.weight"] = "model.7.weight";
expected_to_model["conv_last.bias"] = "model.7.bias";
}
}
std::map<std::string, ggml_tensor*> model_tensors;
for (auto& p : esrgan_tensors) {
auto it = expected_to_model.find(p.first);
if (it != expected_to_model.end()) {
model_tensors[it->second] = p.second;
}
}
success = model_loader.load_tensors(model_tensors, {}, n_threads);
} else {
success = model_loader.load_tensors(esrgan_tensors, {}, n_threads);
}
if (!success) {
LOG_ERROR("load esrgan tensors from model loader failed");
return false;
}
LOG_INFO("esrgan model loaded");
scale = rrdb_net->get_scale();
LOG_INFO("esrgan model loaded with scale=%d, num_block=%d", scale, detected_num_block);
return success;
}
struct ggml_cgraph* build_graph(struct ggml_tensor* x) {
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
x = to_backend(x);
struct ggml_tensor* out = rrdb_net.forward(compute_ctx, x);
if (!rrdb_net)
return nullptr;
constexpr int kGraphNodes = 1 << 16; // 65k
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, kGraphNodes, /*grads*/ false);
x = to_backend(x);
struct ggml_tensor* out = rrdb_net->forward(compute_ctx, x);
ggml_build_forward_expand(gf, out);
return gf;
}

View File

@ -41,13 +41,15 @@ const char* modes_str[] = {
"img_gen",
"vid_gen",
"convert",
"upscale",
};
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert"
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale"
enum SDMode {
IMG_GEN,
VID_GEN,
CONVERT,
UPSCALE,
MODE_COUNT
};
@ -208,7 +210,7 @@ void print_usage(int argc, const char* argv[]) {
printf("\n");
printf("arguments:\n");
printf(" -h, --help show this help message and exit\n");
printf(" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, convert], default: img_gen\n");
printf(" -M, --mode [MODE] run mode, one of: [img_gen, vid_gen, upscale, convert], default: img_gen\n");
printf(" -t, --threads N number of threads to use during computation (default: -1)\n");
printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n");
printf(" --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed\n");
@ -225,7 +227,7 @@ void print_usage(int argc, const char* argv[]) {
printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n");
printf(" --control-net [CONTROL_PATH] path to control net model\n");
printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n");
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n");
printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. For img_gen mode, upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n");
printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n");
printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n");
printf(" If not specified, the default is the type of the weight file\n");
@ -825,13 +827,13 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.n_threads = get_num_physical_cores();
}
if (params.mode != CONVERT && params.mode != VID_GEN && params.prompt.length() == 0) {
if ((params.mode == IMG_GEN || params.mode == VID_GEN) && params.prompt.length() == 0) {
fprintf(stderr, "error: the following arguments are required: prompt\n");
print_usage(argc, argv);
exit(1);
}
if (params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) {
if (params.mode != UPSCALE && params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) {
fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n");
print_usage(argc, argv);
exit(1);
@ -891,6 +893,17 @@ void parse_args(int argc, const char** argv, SDParams& params) {
exit(1);
}
if (params.mode == UPSCALE) {
if (params.esrgan_path.length() == 0) {
fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n");
exit(1);
}
if (params.init_image_path.length() == 0) {
fprintf(stderr, "error: upscale mode needs an init image (--init-img)\n");
exit(1);
}
}
if (params.seed < 0) {
srand((int)time(NULL));
params.seed = rand();
@ -901,14 +914,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
params.output_path = "output.gguf";
}
}
if (!isfinite(params.sample_params.guidance.img_cfg)) {
params.sample_params.guidance.img_cfg = params.sample_params.guidance.txt_cfg;
}
if (!isfinite(params.high_noise_sample_params.guidance.img_cfg)) {
params.high_noise_sample_params.guidance.img_cfg = params.high_noise_sample_params.guidance.txt_cfg;
}
}
static std::string sd_basename(const std::string& path) {
@ -1362,76 +1367,92 @@ int main(int argc, const char* argv[]) {
params.flow_shift,
};
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
sd_image_t* results = nullptr;
int num_results = 0;
if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n");
release_all_resources();
return 1;
}
if (params.mode == UPSCALE) {
num_results = 1;
results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t));
if (results == NULL) {
printf("failed to allocate results array\n");
release_all_resources();
return 1;
}
if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
}
results[0] = init_image;
init_image.data = NULL;
} else {
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
sd_image_t* results;
int num_results = 1;
if (params.mode == IMG_GEN) {
sd_img_gen_params_t img_gen_params = {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
init_image,
ref_images.data(),
(int)ref_images.size(),
params.increase_ref_index,
mask_image,
params.width,
params.height,
params.sample_params,
params.strength,
params.seed,
params.batch_count,
control_image,
params.control_strength,
{
pmid_images.data(),
(int)pmid_images.size(),
params.pm_id_embed_path.c_str(),
params.pm_style_strength,
}, // pm_params
params.vae_tiling_params,
};
if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n");
release_all_resources();
return 1;
}
results = generate_image(sd_ctx, &img_gen_params);
num_results = params.batch_count;
} else if (params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
init_image,
end_image,
control_frames.data(),
(int)control_frames.size(),
params.width,
params.height,
params.sample_params,
params.high_noise_sample_params,
params.moe_boundary,
params.strength,
params.seed,
params.video_frames,
params.vace_strength,
};
if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
}
results = generate_video(sd_ctx, &vid_gen_params, &num_results);
}
if (params.mode == IMG_GEN) {
sd_img_gen_params_t img_gen_params = {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
init_image,
ref_images.data(),
(int)ref_images.size(),
params.increase_ref_index,
mask_image,
params.width,
params.height,
params.sample_params,
params.strength,
params.seed,
params.batch_count,
control_image,
params.control_strength,
{
pmid_images.data(),
(int)pmid_images.size(),
params.pm_id_embed_path.c_str(),
params.pm_style_strength,
}, // pm_params
params.vae_tiling_params,
};
results = generate_image(sd_ctx, &img_gen_params);
num_results = params.batch_count;
} else if (params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = {
params.prompt.c_str(),
params.negative_prompt.c_str(),
params.clip_skip,
init_image,
end_image,
control_frames.data(),
(int)control_frames.size(),
params.width,
params.height,
params.sample_params,
params.high_noise_sample_params,
params.moe_boundary,
params.strength,
params.seed,
params.video_frames,
params.vace_strength,
};
results = generate_video(sd_ctx, &vid_gen_params, &num_results);
}
if (results == NULL) {
printf("generate failed\n");
free_sd_ctx(sd_ctx);
return 1;
}
if (results == NULL) {
printf("generate failed\n");
free_sd_ctx(sd_ctx);
return 1;
}
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
@ -1444,7 +1465,7 @@ int main(int argc, const char* argv[]) {
if (upscaler_ctx == NULL) {
printf("new_upscaler_ctx failed\n");
} else {
for (int i = 0; i < params.batch_count; i++) {
for (int i = 0; i < num_results; i++) {
if (results[i].data == NULL) {
continue;
}
@ -1530,7 +1551,6 @@ int main(int argc, const char* argv[]) {
results[i].data = NULL;
}
free(results);
free_sd_ctx(sd_ctx);
release_all_resources();

View File

@ -480,12 +480,15 @@ __STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
int64_t ne3 = output->ne[3];
GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32);
for (int iy = 0; iy < height; iy++) {
for (int ix = 0; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
float value = ggml_tensor_get_f32(input, ix + x, iy + y, k);
ggml_tensor_set_f32(output, value, ix, iy, k);
for (int l = 0; l < ne3; l++) {
float value = ggml_tensor_get_f32(input, ix + x, iy + y, k, l);
ggml_tensor_set_f32(output, value, ix, iy, k, l);
}
}
}
}
@ -508,6 +511,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
int64_t width = input->ne[0];
int64_t height = input->ne[1];
int64_t channels = input->ne[2];
int64_t ne3 = input->ne[3];
int64_t img_width = output->ne[0];
int64_t img_height = output->ne[1];
@ -516,24 +520,26 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
for (int iy = y_skip; iy < height; iy++) {
for (int ix = x_skip; ix < width; ix++) {
for (int k = 0; k < channels; k++) {
float new_value = ggml_tensor_get_f32(input, ix, iy, k);
if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area
float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k);
for (int l = 0; l < ne3; l++) {
float new_value = ggml_tensor_get_f32(input, ix, iy, k, l);
if (overlap_x > 0 || overlap_y > 0) { // blend colors in overlapped area
float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k, l);
const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;
const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;
const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
ggml_tensor_set_f32(
output,
old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f),
x + ix, y + iy, k);
} else {
ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k);
ggml_tensor_set_f32(
output,
old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f),
x + ix, y + iy, k, l);
} else {
ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k, l);
}
}
}
}
@ -849,8 +855,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
}
struct ggml_init_params params = {};
params.mem_size += input_tile_size_x * input_tile_size_y * input->ne[2] * sizeof(float); // input chunk
params.mem_size += output_tile_size_x * output_tile_size_y * output->ne[2] * sizeof(float); // output chunk
params.mem_size += input_tile_size_x * input_tile_size_y * input->ne[2] * input->ne[3] * sizeof(float); // input chunk
params.mem_size += output_tile_size_x * output_tile_size_y * output->ne[2] * output->ne[3] * sizeof(float); // output chunk
params.mem_size += 3 * ggml_tensor_overhead();
params.mem_buffer = NULL;
params.no_alloc = false;
@ -865,8 +871,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
}
// tiling
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], 1);
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], 1);
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]);
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]);
int num_tiles = num_tiles_x * num_tiles_y;
LOG_INFO("processing %i tiles", num_tiles);
pretty_progress(0, num_tiles, 0.0f);

View File

@ -269,6 +269,14 @@ public:
std::set<std::string> ignore_tensors = {},
int n_threads = 0);
std::vector<std::string> get_tensor_names() const {
std::vector<std::string> names;
for (const auto& ts : tensor_storages) {
names.push_back(ts.name);
}
return names;
}
bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);

View File

@ -354,17 +354,7 @@ public:
bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
{
clip_backend = backend;
bool use_t5xxl = false;
if (sd_version_is_dit(version) && !sd_version_is_qwen_image(version)) {
use_t5xxl = true;
}
if (!clip_on_cpu && !ggml_backend_is_cpu(backend) && use_t5xxl) {
LOG_WARN(
"!!!It appears that you are using the T5 model. Some backends may encounter issues with it."
"If you notice that the generated images are completely black,"
"try running the T5 model on the CPU using the --clip-on-cpu parameter.");
}
clip_backend = backend;
if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("CLIP: Using CPU backend");
clip_backend = ggml_backend_cpu_init();
@ -1120,7 +1110,7 @@ public:
std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
float cfg_scale = guidance.txt_cfg;
float img_cfg_scale = guidance.img_cfg;
float img_cfg_scale = isfinite(guidance.img_cfg) ? guidance.img_cfg : guidance.txt_cfg;
float slg_scale = guidance.slg.scale;
if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) {
@ -1464,10 +1454,19 @@ public:
if (vae_tiling_params.enabled && !encode_video) {
// TODO wan2.2 vae support?
int C = sd_version_is_dit(version) ? 16 : 4;
if (!use_tiny_autoencoder) {
C *= 2;
int ne2;
int ne3;
if (sd_version_is_qwen_image(version)) {
ne2 = 1;
ne3 = C*x->ne[3];
} else {
if (!use_tiny_autoencoder) {
C *= 2;
}
ne2 = C;
ne3 = x->ne[3];
}
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, x->ne[3]);
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
}
if (sd_version_is_qwen_image(version)) {
@ -1861,7 +1860,9 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
"eta: %.2f, "
"shifted_timestep: %d)",
sample_params->guidance.txt_cfg,
sample_params->guidance.img_cfg,
isfinite(sample_params->guidance.img_cfg)
? sample_params->guidance.img_cfg
: sample_params->guidance.txt_cfg,
sample_params->guidance.distilled_guidance,
sample_params->guidance.slg.layer_count,
sample_params->guidance.slg.layer_start,
@ -2023,6 +2024,10 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
seed = rand();
}
if (!isfinite(guidance.img_cfg)) {
guidance.img_cfg = guidance.txt_cfg;
}
// for (auto v : sigmas) {
// std::cout << v << " ";
// }

View File

@ -285,6 +285,8 @@ SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx,
sd_image_t input_image,
uint32_t upscale_factor);
SD_API int get_upscale_factor(upscaler_ctx_t* upscaler_ctx);
SD_API bool convert(const char* input_path,
const char* vae_path,
const char* output_path,

4
t5.hpp
View File

@ -504,7 +504,9 @@ public:
T5DenseGatedActDense(int64_t model_dim, int64_t ff_dim) {
blocks["wi_0"] = std::shared_ptr<GGMLBlock>(new Linear(model_dim, ff_dim, false));
blocks["wi_1"] = std::shared_ptr<GGMLBlock>(new Linear(model_dim, ff_dim, false));
blocks["wo"] = std::shared_ptr<GGMLBlock>(new Linear(ff_dim, model_dim, false));
float scale = 1.f / 32.f;
// The purpose of the scale here is to prevent NaN issues on some backends(CUDA, ...).
blocks["wo"] = std::shared_ptr<GGMLBlock>(new Linear(ff_dim, model_dim, false, false, false, scale));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {

View File

@ -138,6 +138,13 @@ sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_
return upscaler_ctx->upscaler->upscale(input_image, upscale_factor);
}
int get_upscale_factor(upscaler_ctx_t* upscaler_ctx) {
if (upscaler_ctx == NULL || upscaler_ctx->upscaler == NULL || upscaler_ctx->upscaler->esrgan_upscaler == NULL) {
return 1;
}
return upscaler_ctx->upscaler->esrgan_upscaler->scale;
}
void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx) {
if (upscaler_ctx->upscaler != NULL) {
delete upscaler_ctx->upscaler;