diff --git a/README.md b/README.md
index fe64cea..0d2da62 100644
--- a/README.md
+++ b/README.md
@@ -286,7 +286,7 @@ usage: ./bin/sd [arguments]
 
 arguments:
   -h, --help                         show this help message and exit
-  -M, --mode [MODE]                  run mode, one of: [img_gen, vid_gen, convert], default: img_gen
+  -M, --mode [MODE]                  run mode, one of: [img_gen, vid_gen, upscale, convert], default: img_gen
   -t, --threads N                    number of threads to use during computation (default: -1)
                                      If threads <= 0, then threads will be set to the number of CPU physical cores
   --offload-to-cpu                   place the weights in RAM to save VRAM, and automatically load them into VRAM when needed
@@ -302,7 +302,7 @@ arguments:
   --taesd [TAESD_PATH]               path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
   --control-net [CONTROL_PATH]       path to control net model
   --embd-dir [EMBEDDING_PATH]        path to embeddings
-  --upscale-model [ESRGAN_PATH]      path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now
+  --upscale-model [ESRGAN_PATH]      path to esrgan model. For img_gen mode, upscale images after generation; only RealESRGAN_x4plus_anime_6B is supported for now
   --upscale-repeats                  Run the ESRGAN upscaler this many times (default 1)
   --type [TYPE]                      weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)
                                      If not specified, the default is the type of the weight file
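A minimal invocation of the new mode (a sketch; it uses only flags documented in the help text above, plus the CLI's usual `-o` output flag):

```sh
./bin/sd -M upscale --upscale-model RealESRGAN_x4plus_anime_6B.pth --init-img input.png -o upscaled.png
```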
diff --git a/esrgan.hpp b/esrgan.hpp
index 7ede2e4..fe5f16d 100644
--- a/esrgan.hpp
+++ b/esrgan.hpp
@@ -83,39 +83,44 @@ public:
 
 class RRDBNet : public GGMLBlock {
 protected:
-    int scale       = 4;  // default RealESRGAN_x4plus_anime_6B
-    int num_block   = 6;  // default RealESRGAN_x4plus_anime_6B
+    int scale       = 4;
+    int num_block   = 23;
     int num_in_ch   = 3;
     int num_out_ch  = 3;
-    int num_feat    = 64;  // default RealESRGAN_x4plus_anime_6B
-    int num_grow_ch = 32;  // default RealESRGAN_x4plus_anime_6B
+    int num_feat    = 64;
+    int num_grow_ch = 32;
 
 public:
-    RRDBNet() {
+    RRDBNet(int scale, int num_block, int num_in_ch, int num_out_ch, int num_feat, int num_grow_ch)
+        : scale(scale), num_block(num_block), num_in_ch(num_in_ch), num_out_ch(num_out_ch), num_feat(num_feat), num_grow_ch(num_grow_ch) {
         blocks["conv_first"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_in_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
         for (int i = 0; i < num_block; i++) {
            std::string name = "body." + std::to_string(i);
            blocks[name]     = std::shared_ptr<GGMLBlock>(new RRDB(num_feat, num_grow_ch));
         }
         blocks["conv_body"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
-        // upsample
-        blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
-        blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
+        if (scale >= 2) {
+            blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
+        }
+        if (scale == 4) {
+            blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
+        }
         blocks["conv_hr"]   = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
         blocks["conv_last"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_out_ch, {3, 3}, {1, 1}, {1, 1}));
     }
 
+    int get_scale() { return scale; }
+    int get_num_block() { return num_block; }
+
     struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
         return ggml_leaky_relu(ctx, x, 0.2f, true);
     }
 
     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
         // x: [n, num_in_ch, h, w]
-        // return: [n, num_out_ch, h*4, w*4]
+        // return: [n, num_out_ch, h*scale, w*scale]
         auto conv_first = std::dynamic_pointer_cast<Conv2d>(blocks["conv_first"]);
         auto conv_body  = std::dynamic_pointer_cast<Conv2d>(blocks["conv_body"]);
-        auto conv_up1   = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
-        auto conv_up2   = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
         auto conv_hr    = std::dynamic_pointer_cast<Conv2d>(blocks["conv_hr"]);
         auto conv_last  = std::dynamic_pointer_cast<Conv2d>(blocks["conv_last"]);
 
@@ -130,15 +135,22 @@ public:
         body_feat = conv_body->forward(ctx, body_feat);
         feat      = ggml_add(ctx, feat, body_feat);
         // upsample
-        feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
-        feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
+        if (scale >= 2) {
+            auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
+            feat          = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
+            if (scale == 4) {
+                auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
+                feat          = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
+            }
+        }
+        // for all scales
         auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat)));
         return out;
     }
 };
 
 struct ESRGAN : public GGMLRunner {
-    RRDBNet rrdb_net;
+    std::unique_ptr<RRDBNet> rrdb_net;
     int scale     = 4;
     int tile_size = 128;  // avoid cuda OOM for 4gb VRAM
 
@@ -146,12 +158,14 @@ struct ESRGAN : public GGMLRunner {
            bool offload_params_to_cpu,
            const String2GGMLType& tensor_types = {})
         : GGMLRunner(backend, offload_params_to_cpu) {
-        rrdb_net.init(params_ctx, tensor_types, "");
+        // rrdb_net will be created in load_from_file
     }
 
     void enable_conv2d_direct() {
+        if (!rrdb_net)
+            return;
         std::vector<GGMLBlock*> blocks;
-        rrdb_net.get_all_blocks(blocks);
+        rrdb_net->get_all_blocks(blocks);
         for (auto block : blocks) {
             if (block->get_desc() == "Conv2d") {
                 auto conv_block = (Conv2d*)block;
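With the constructor parameterized, one RRDBNet covers the common RealESRGAN variants. A sketch of the configurations this header now admits (the x4plus/anime_6B block counts are the well-known upstream values; the x2/x1 lines are hypothetical checkpoints):

```cpp
void example_configs() {
    RRDBNet x4plus(4, 23, 3, 3, 64, 32);   // RealESRGAN_x4plus: conv_up1 and conv_up2
    RRDBNet anime6b(4, 6, 3, 3, 64, 32);   // RealESRGAN_x4plus_anime_6B: fewer body blocks
    RRDBNet x2(2, 23, 3, 3, 64, 32);       // a x2 model: conv_up1 only, one 2x nearest upscale
    RRDBNet x1(1, 23, 3, 3, 64, 32);       // a x1 model: no upsample convolutions at all
}
```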
@@ -167,31 +181,185 @@ struct ESRGAN : public GGMLRunner {
 
     bool load_from_file(const std::string& file_path, int n_threads) {
         LOG_INFO("loading esrgan from '%s'", file_path.c_str());
 
-        alloc_params_buffer();
-        std::map<std::string, struct ggml_tensor*> esrgan_tensors;
-        rrdb_net.get_param_tensors(esrgan_tensors);
-
         ModelLoader model_loader;
         if (!model_loader.init_from_file(file_path)) {
             LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str());
             return false;
         }
 
-        bool success = model_loader.load_tensors(esrgan_tensors, {}, n_threads);
+        // Get tensor names
+        auto tensor_names = model_loader.get_tensor_names();
+
+        // Detect if it's ESRGAN format
+        bool is_ESRGAN = std::find(tensor_names.begin(), tensor_names.end(), "model.0.weight") != tensor_names.end();
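+        // The two layouts look like this (illustrative names, built from the mapping below):
+        //   old-style "ESRGAN":  model.0.weight, model.1.sub.0.RDB1.conv1.0.weight, ..., model.10.weight
+        //   native RealESRGAN:   conv_first.weight, body.0.rdb1.conv1.weight, ..., conv_last.weight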
+
+        // Detect parameters from tensor names
+        int detected_num_block = 0;
+        if (is_ESRGAN) {
+            for (const auto& name : tensor_names) {
+                if (name.find("model.1.sub.") == 0) {
+                    size_t first_dot = name.find('.', 12);
+                    if (first_dot != std::string::npos) {
+                        size_t second_dot = name.find('.', first_dot + 1);
+                        if (second_dot != std::string::npos && name.substr(first_dot + 1, 3) == "RDB") {
+                            try {
+                                int idx            = std::stoi(name.substr(12, first_dot - 12));
+                                detected_num_block = std::max(detected_num_block, idx + 1);
+                            } catch (...) {
+                            }
+                        }
+                    }
+                }
+            }
+        } else {
+            // Original format
+            for (const auto& name : tensor_names) {
+                if (name.find("body.") == 0) {
+                    size_t pos = name.find('.', 5);
+                    if (pos != std::string::npos) {
+                        try {
+                            int idx            = std::stoi(name.substr(5, pos - 5));
+                            detected_num_block = std::max(detected_num_block, idx + 1);
+                        } catch (...) {
+                        }
+                    }
+                }
+            }
+        }
+
+        int detected_scale = 4;  // default
+        if (is_ESRGAN) {
+            // For ESRGAN format, detect scale by highest model number
+            int max_model_num = 0;
+            for (const auto& name : tensor_names) {
+                if (name.find("model.") == 0) {
+                    size_t dot_pos = name.find('.', 6);
+                    if (dot_pos != std::string::npos) {
+                        try {
+                            int num       = std::stoi(name.substr(6, dot_pos - 6));
+                            max_model_num = std::max(max_model_num, num);
+                        } catch (...) {
+                        }
+                    }
+                }
+            }
+            if (max_model_num <= 4) {
+                detected_scale = 1;
+            } else if (max_model_num <= 7) {
+                detected_scale = 2;
+            } else {
+                detected_scale = 4;
+            }
+        } else {
+            // Original format
+            bool has_conv_up2 = std::any_of(tensor_names.begin(), tensor_names.end(), [](const std::string& name) {
+                return name == "conv_up2.weight";
+            });
+            bool has_conv_up1 = std::any_of(tensor_names.begin(), tensor_names.end(), [](const std::string& name) {
+                return name == "conv_up1.weight";
+            });
+            if (has_conv_up2) {
+                detected_scale = 4;
+            } else if (has_conv_up1) {
+                detected_scale = 2;
+            } else {
+                detected_scale = 1;
+            }
+        }
+
+        int detected_num_in_ch   = 3;
+        int detected_num_out_ch  = 3;
+        int detected_num_feat    = 64;
+        int detected_num_grow_ch = 32;
+
+        // Create RRDBNet with detected parameters
+        rrdb_net = std::make_unique<RRDBNet>(detected_scale, detected_num_block, detected_num_in_ch, detected_num_out_ch, detected_num_feat, detected_num_grow_ch);
+        rrdb_net->init(params_ctx, {}, "");
+
+        alloc_params_buffer();
+        std::map<std::string, struct ggml_tensor*> esrgan_tensors;
+        rrdb_net->get_param_tensors(esrgan_tensors);
+
+        bool success;
+        if (is_ESRGAN) {
+            // Build name mapping for ESRGAN format
+            std::map<std::string, std::string> expected_to_model;
+            expected_to_model["conv_first.weight"] = "model.0.weight";
+            expected_to_model["conv_first.bias"]   = "model.0.bias";
+
+            for (int i = 0; i < detected_num_block; i++) {
+                for (int j = 1; j <= 3; j++) {
+                    for (int k = 1; k <= 5; k++) {
+                        std::string expected_weight = "body." + std::to_string(i) + ".rdb" + std::to_string(j) + ".conv" + std::to_string(k) + ".weight";
+                        std::string model_weight    = "model.1.sub." + std::to_string(i) + ".RDB" + std::to_string(j) + ".conv" + std::to_string(k) + ".0.weight";
+                        expected_to_model[expected_weight] = model_weight;
+
+                        std::string expected_bias = "body." + std::to_string(i) + ".rdb" + std::to_string(j) + ".conv" + std::to_string(k) + ".bias";
+                        std::string model_bias    = "model.1.sub." + std::to_string(i) + ".RDB" + std::to_string(j) + ".conv" + std::to_string(k) + ".0.bias";
+                        expected_to_model[expected_bias] = model_bias;
+                    }
+                }
+            }
+
+            if (detected_scale == 1) {
+                expected_to_model["conv_body.weight"] = "model.1.sub." + std::to_string(detected_num_block) + ".weight";
+                expected_to_model["conv_body.bias"]   = "model.1.sub." + std::to_string(detected_num_block) + ".bias";
+                expected_to_model["conv_hr.weight"]   = "model.2.weight";
+                expected_to_model["conv_hr.bias"]     = "model.2.bias";
+                expected_to_model["conv_last.weight"] = "model.4.weight";
+                expected_to_model["conv_last.bias"]   = "model.4.bias";
+            } else {
+                expected_to_model["conv_body.weight"] = "model.1.sub." + std::to_string(detected_num_block) + ".weight";
+                expected_to_model["conv_body.bias"]   = "model.1.sub." + std::to_string(detected_num_block) + ".bias";
+                if (detected_scale >= 2) {
+                    expected_to_model["conv_up1.weight"] = "model.3.weight";
+                    expected_to_model["conv_up1.bias"]   = "model.3.bias";
+                }
+                if (detected_scale == 4) {
+                    expected_to_model["conv_up2.weight"]  = "model.6.weight";
+                    expected_to_model["conv_up2.bias"]    = "model.6.bias";
+                    expected_to_model["conv_hr.weight"]   = "model.8.weight";
+                    expected_to_model["conv_hr.bias"]     = "model.8.bias";
+                    expected_to_model["conv_last.weight"] = "model.10.weight";
+                    expected_to_model["conv_last.bias"]   = "model.10.bias";
+                } else if (detected_scale == 2) {
+                    expected_to_model["conv_hr.weight"]   = "model.5.weight";
+                    expected_to_model["conv_hr.bias"]     = "model.5.bias";
+                    expected_to_model["conv_last.weight"] = "model.7.weight";
+                    expected_to_model["conv_last.bias"]   = "model.7.bias";
+                }
+            }
+
+            std::map<std::string, struct ggml_tensor*> model_tensors;
+            for (auto& p : esrgan_tensors) {
+                auto it = expected_to_model.find(p.first);
+                if (it != expected_to_model.end()) {
+                    model_tensors[it->second] = p.second;
+                }
+            }
+
+            success = model_loader.load_tensors(model_tensors, {}, n_threads);
+        } else {
+            success = model_loader.load_tensors(esrgan_tensors, {}, n_threads);
+        }
 
         if (!success) {
             LOG_ERROR("load esrgan tensors from model loader failed");
             return false;
         }
 
-        LOG_INFO("esrgan model loaded");
+        scale = rrdb_net->get_scale();
+        LOG_INFO("esrgan model loaded with scale=%d, num_block=%d", scale, detected_num_block);
         return success;
     }
 
     struct ggml_cgraph* build_graph(struct ggml_tensor* x) {
-        struct ggml_cgraph* gf  = ggml_new_graph(compute_ctx);
-        x                       = to_backend(x);
-        struct ggml_tensor* out = rrdb_net.forward(compute_ctx, x);
+        if (!rrdb_net)
+            return nullptr;
+        constexpr int kGraphNodes = 1 << 16;  // 65k
+        struct ggml_cgraph* gf    = ggml_new_graph_custom(compute_ctx, kGraphNodes, /*grads*/ false);
+        x                         = to_backend(x);
+        struct ggml_tensor* out   = rrdb_net->forward(compute_ctx, x);
         ggml_build_forward_expand(gf, out);
         return gf;
     }
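On the `ggml_new_graph_custom` change at the end of this file: ggml's default graph size (`GGML_DEFAULT_GRAPH_SIZE`, 2048 nodes in current ggml) was enough for a fixed 6-block network, but gets tight once `num_block` can be 23. A rough, illustrative estimate of why:

```cpp
// 23 blocks x 3 RDBs x 5 convs = 345 convolutions, each lowering to several
// ggml ops, plus per-conv leaky_relu, dense concats and residual adds --
// right around or past the 2048-node default, so size the graph generously:
constexpr int kGraphNodes = 1 << 16;  // 65536 nodes
```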
printf(" -t, --threads N number of threads to use during computation (default: -1)\n"); printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n"); printf(" --offload-to-cpu place the weights in RAM to save VRAM, and automatically load them into VRAM when needed\n"); @@ -222,7 +224,7 @@ void print_usage(int argc, const char* argv[]) { printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); printf(" --control-net [CONTROL_PATH] path to control net model\n"); printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n"); - printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n"); + printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. For img_gen mode, upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n"); printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n"); printf(" If not specified, the default is the type of the weight file\n"); @@ -821,13 +823,13 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.n_threads = get_num_physical_cores(); } - if (params.mode != CONVERT && params.mode != VID_GEN && params.prompt.length() == 0) { + if ((params.mode == IMG_GEN || params.mode == VID_GEN) && params.prompt.length() == 0) { fprintf(stderr, "error: the following arguments are required: prompt\n"); print_usage(argc, argv); exit(1); } - if (params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) { + if (params.mode != UPSCALE && params.model_path.length() == 0 && params.diffusion_model_path.length() == 0) { fprintf(stderr, "error: the following arguments are required: model_path/diffusion_model\n"); print_usage(argc, argv); exit(1); @@ -887,6 +889,17 @@ void parse_args(int argc, const char** argv, SDParams& params) { exit(1); } + if (params.mode == UPSCALE) { + if (params.esrgan_path.length() == 0) { + fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n"); + exit(1); + } + if (params.init_image_path.length() == 0) { + fprintf(stderr, "error: upscale mode needs an init image (--init-img)\n"); + exit(1); + } + } + if (params.seed < 0) { srand((int)time(NULL)); params.seed = rand(); @@ -897,14 +910,6 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.output_path = "output.gguf"; } } - - if (!isfinite(params.sample_params.guidance.img_cfg)) { - params.sample_params.guidance.img_cfg = params.sample_params.guidance.txt_cfg; - } - - if (!isfinite(params.high_noise_sample_params.guidance.img_cfg)) { - params.high_noise_sample_params.guidance.img_cfg = params.high_noise_sample_params.guidance.txt_cfg; - } } static std::string sd_basename(const std::string& path) { @@ -1357,76 +1362,92 @@ int main(int argc, const char* argv[]) { params.flow_shift, }; - sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); + sd_image_t* results = nullptr; + int num_results = 0; - if (sd_ctx == NULL) { - printf("new_sd_ctx_t failed\n"); - release_all_resources(); - return 1; - } + if (params.mode == UPSCALE) { + num_results = 1; + results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t)); + if (results == NULL) { + printf("failed to allocate results array\n"); + release_all_resources(); + return 1; + } - if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) { - 
@@ -1357,76 +1362,92 @@ int main(int argc, const char* argv[]) {
         params.flow_shift,
     };
 
-    sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
+    sd_image_t* results = nullptr;
+    int num_results     = 0;
 
-    if (sd_ctx == NULL) {
-        printf("new_sd_ctx_t failed\n");
-        release_all_resources();
-        return 1;
-    }
+    if (params.mode == UPSCALE) {
+        num_results = 1;
+        results     = (sd_image_t*)calloc(num_results, sizeof(sd_image_t));
+        if (results == NULL) {
+            printf("failed to allocate results array\n");
+            release_all_resources();
+            return 1;
+        }
 
-    if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
-        params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
-    }
+        results[0]      = init_image;
+        init_image.data = NULL;
+    } else {
+        sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
 
-    sd_image_t* results;
-    int num_results = 1;
-    if (params.mode == IMG_GEN) {
-        sd_img_gen_params_t img_gen_params = {
-            params.prompt.c_str(),
-            params.negative_prompt.c_str(),
-            params.clip_skip,
-            init_image,
-            ref_images.data(),
-            (int)ref_images.size(),
-            params.increase_ref_index,
-            mask_image,
-            params.width,
-            params.height,
-            params.sample_params,
-            params.strength,
-            params.seed,
-            params.batch_count,
-            control_image,
-            params.control_strength,
-            {
-                pmid_images.data(),
-                (int)pmid_images.size(),
-                params.pm_id_embed_path.c_str(),
-                params.pm_style_strength,
-            },  // pm_params
-            params.vae_tiling_params,
-        };
+        if (sd_ctx == NULL) {
+            printf("new_sd_ctx_t failed\n");
+            release_all_resources();
+            return 1;
+        }
 
-        results     = generate_image(sd_ctx, &img_gen_params);
-        num_results = params.batch_count;
-    } else if (params.mode == VID_GEN) {
-        sd_vid_gen_params_t vid_gen_params = {
-            params.prompt.c_str(),
-            params.negative_prompt.c_str(),
-            params.clip_skip,
-            init_image,
-            end_image,
-            control_frames.data(),
-            (int)control_frames.size(),
-            params.width,
-            params.height,
-            params.sample_params,
-            params.high_noise_sample_params,
-            params.moe_boundary,
-            params.strength,
-            params.seed,
-            params.video_frames,
-            params.vace_strength,
-        };
+        if (params.sample_params.sample_method == SAMPLE_METHOD_DEFAULT) {
+            params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
+        }
 
-        results = generate_video(sd_ctx, &vid_gen_params, &num_results);
-    }
+        if (params.mode == IMG_GEN) {
+            sd_img_gen_params_t img_gen_params = {
+                params.prompt.c_str(),
+                params.negative_prompt.c_str(),
+                params.clip_skip,
+                init_image,
+                ref_images.data(),
+                (int)ref_images.size(),
+                params.increase_ref_index,
+                mask_image,
+                params.width,
+                params.height,
+                params.sample_params,
+                params.strength,
+                params.seed,
+                params.batch_count,
+                control_image,
+                params.control_strength,
+                {
+                    pmid_images.data(),
+                    (int)pmid_images.size(),
+                    params.pm_id_embed_path.c_str(),
+                    params.pm_style_strength,
+                },  // pm_params
+                params.vae_tiling_params,
+            };
+
+            results     = generate_image(sd_ctx, &img_gen_params);
+            num_results = params.batch_count;
+        } else if (params.mode == VID_GEN) {
+            sd_vid_gen_params_t vid_gen_params = {
+                params.prompt.c_str(),
+                params.negative_prompt.c_str(),
+                params.clip_skip,
+                init_image,
+                end_image,
+                control_frames.data(),
+                (int)control_frames.size(),
+                params.width,
+                params.height,
+                params.sample_params,
+                params.high_noise_sample_params,
+                params.moe_boundary,
+                params.strength,
+                params.seed,
+                params.video_frames,
+                params.vace_strength,
+            };
+
+            results = generate_video(sd_ctx, &vid_gen_params, &num_results);
+        }
+
+        if (results == NULL) {
+            printf("generate failed\n");
+            free_sd_ctx(sd_ctx);
+            return 1;
+        }
 
-    if (results == NULL) {
-        printf("generate failed\n");
         free_sd_ctx(sd_ctx);
-        return 1;
     }
 
     int upscale_factor = 4;  // unused for RealESRGAN_x4plus_anime_6B.pth
@@ -1439,7 +1460,7 @@ int main(int argc, const char* argv[]) {
         if (upscaler_ctx == NULL) {
            printf("new_upscaler_ctx failed\n");
         } else {
-            for (int i = 0; i < params.batch_count; i++) {
+            for (int i = 0; i < num_results; i++) {
                 if (results[i].data == NULL) {
                     continue;
                 }
@@ -1525,7 +1546,6 @@ int main(int argc, const char* argv[]) {
         results[i].data = NULL;
     }
     free(results);
-    free_sd_ctx(sd_ctx);
 
     release_all_resources();
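One subtlety in the UPSCALE branch above: the init image's pixel buffer is handed to `results[0]` by shallow copy, and the source pointer is then cleared. Presumably the exit path frees both the results array entries and the originally loaded images, so the null assignment is what prevents a double free:

```cpp
results[0]      = init_image;  // shallow copy: results[0] now owns the pixels
init_image.data = NULL;        // cleanup of init_image must not free them again
```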
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 9f7d0b3..a125357 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -483,12 +483,15 @@ __STATIC_INLINE__ void ggml_split_tensor_2d(struct ggml_tensor* input,
     int64_t width    = output->ne[0];
     int64_t height   = output->ne[1];
     int64_t channels = output->ne[2];
+    int64_t ne3      = output->ne[3];
 
     GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32);
     for (int iy = 0; iy < height; iy++) {
         for (int ix = 0; ix < width; ix++) {
             for (int k = 0; k < channels; k++) {
-                float value = ggml_tensor_get_f32(input, ix + x, iy + y, k);
-                ggml_tensor_set_f32(output, value, ix, iy, k);
+                for (int l = 0; l < ne3; l++) {
+                    float value = ggml_tensor_get_f32(input, ix + x, iy + y, k, l);
+                    ggml_tensor_set_f32(output, value, ix, iy, k, l);
+                }
             }
         }
     }
@@ -511,6 +514,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
     int64_t width    = input->ne[0];
     int64_t height   = input->ne[1];
     int64_t channels = input->ne[2];
+    int64_t ne3      = input->ne[3];
 
     int64_t img_width  = output->ne[0];
     int64_t img_height = output->ne[1];
@@ -519,24 +523,26 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input,
     for (int iy = y_skip; iy < height; iy++) {
         for (int ix = x_skip; ix < width; ix++) {
             for (int k = 0; k < channels; k++) {
-                float new_value = ggml_tensor_get_f32(input, ix, iy, k);
-                if (overlap_x > 0 || overlap_y > 0) {  // blend colors in overlapped area
-                    float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k);
+                for (int l = 0; l < ne3; l++) {
+                    float new_value = ggml_tensor_get_f32(input, ix, iy, k, l);
+                    if (overlap_x > 0 || overlap_y > 0) {  // blend colors in overlapped area
+                        float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k, l);
 
-                    const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
-                    const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
-                    const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
-                    const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;
+                        const float x_f_0 = (overlap_x > 0 && x > 0) ? (ix - x_skip) / float(overlap_x) : 1;
+                        const float x_f_1 = (overlap_x > 0 && x < (img_width - width)) ? (width - ix) / float(overlap_x) : 1;
+                        const float y_f_0 = (overlap_y > 0 && y > 0) ? (iy - y_skip) / float(overlap_y) : 1;
+                        const float y_f_1 = (overlap_y > 0 && y < (img_height - height)) ? (height - iy) / float(overlap_y) : 1;
 
-                    const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
-                    const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
+                        const float x_f = std::min(std::min(x_f_0, x_f_1), 1.f);
+                        const float y_f = std::min(std::min(y_f_0, y_f_1), 1.f);
 
-                    ggml_tensor_set_f32(
-                        output,
-                        old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f),
-                        x + ix, y + iy, k);
-                } else {
-                    ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k);
+                        ggml_tensor_set_f32(
+                            output,
+                            old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f),
+                            x + ix, y + iy, k, l);
+                    } else {
+                        ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k, l);
+                    }
                 }
             }
         }
@@ -852,8 +858,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
     }
 
     struct ggml_init_params params = {};
-    params.mem_size += input_tile_size_x * input_tile_size_y * input->ne[2] * sizeof(float);    // input chunk
-    params.mem_size += output_tile_size_x * output_tile_size_y * output->ne[2] * sizeof(float);  // output chunk
+    params.mem_size += input_tile_size_x * input_tile_size_y * input->ne[2] * input->ne[3] * sizeof(float);      // input chunk
+    params.mem_size += output_tile_size_x * output_tile_size_y * output->ne[2] * output->ne[3] * sizeof(float);  // output chunk
     params.mem_size += 3 * ggml_tensor_overhead();
     params.mem_buffer = NULL;
     params.no_alloc   = false;
@@ -868,8 +874,8 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
     }
 
     // tiling
-    ggml_tensor* input_tile  = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], 1);
-    ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], 1);
+    ggml_tensor* input_tile  = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]);
+    ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]);
     int num_tiles            = num_tiles_x * num_tiles_y;
     LOG_INFO("processing %i tiles", num_tiles);
     pretty_progress(0, num_tiles, 0.0f);
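The tiling helpers previously pinned `ne[3]` to 1; with the added `l` loops and the `ne[3]` factors above, batched tensors tile correctly too. For the f32 scratch buffers the per-tile cost is `tile_x * tile_y * ne2 * ne3 * sizeof(float)`; an illustrative check:

```cpp
// e.g. a 128x128 tile with 4 channels and a batch of 2:
int64_t tile_bytes = 128LL * 128 * 4 * 2 * sizeof(float);  // 524288 bytes (~0.5 MiB)
```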
diff --git a/model.h b/model.h
index 628639c..069bb0c 100644
--- a/model.h
+++ b/model.h
@@ -269,6 +269,14 @@ public:
                       std::set<std::string> ignore_tensors = {},
                       int n_threads                        = 0);
 
+    std::vector<std::string> get_tensor_names() const {
+        std::vector<std::string> names;
+        for (const auto& ts : tensor_storages) {
+            names.push_back(ts.name);
+        }
+        return names;
+    }
+
     bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
     bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
     int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
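`get_tensor_names()` is what lets esrgan.hpp inspect a checkpoint's layout before allocating anything. A minimal sketch of its use (the file name is hypothetical):

```cpp
ModelLoader model_loader;
if (model_loader.init_from_file("RealESRGAN_x2plus.pth")) {  // hypothetical checkpoint
    for (const auto& name : model_loader.get_tensor_names()) {
        printf("%s\n", name.c_str());  // e.g. "conv_first.weight", "body.0.rdb1.conv1.weight"
    }
}
```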
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 6c64720..3a44d32 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -1086,7 +1086,7 @@ public:
        std::vector<int> skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count);
 
         float cfg_scale     = guidance.txt_cfg;
-        float img_cfg_scale = guidance.img_cfg;
+        float img_cfg_scale = isfinite(guidance.img_cfg) ? guidance.img_cfg : guidance.txt_cfg;
         float slg_scale     = guidance.slg.scale;
 
         if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) {
@@ -1430,10 +1430,23 @@ public:
         if (vae_tiling_params.enabled && !encode_video) {
             // TODO wan2.2 vae support?
             int C = sd_version_is_dit(version) ? 16 : 4;
-            if (!use_tiny_autoencoder) {
-                C *= 2;
+            int ne2;
+            int ne3;
+            if (sd_version_is_qwen_image(version)) {
+                ne2 = 1;
+                ne3 = C * x->ne[3];
+            } else {
+                if (!use_tiny_autoencoder) {
+                    C *= 2;
+                }
+                ne2 = C;
+                ne3 = x->ne[3];
             }
-            result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, x->ne[3]);
+            result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
+        }
+
+        if (sd_version_is_qwen_image(version)) {
+            x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
         }
 
         if (sd_version_is_qwen_image(version)) {
@@ -1825,7 +1838,9 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
              "eta: %.2f, "
              "shifted_timestep: %d)",
              sample_params->guidance.txt_cfg,
-             sample_params->guidance.img_cfg,
+             isfinite(sample_params->guidance.img_cfg)
+                 ? sample_params->guidance.img_cfg
+                 : sample_params->guidance.txt_cfg,
              sample_params->guidance.distilled_guidance,
              sample_params->guidance.slg.layer_count,
              sample_params->guidance.slg.layer_start,
@@ -1986,7 +2001,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
         seed = rand();
     }
 
-    print_ggml_tensor(init_latent, true, "init");
+    if (!isfinite(guidance.img_cfg)) {
+        guidance.img_cfg = guidance.txt_cfg;
+    }
 
     // for (auto v : sigmas) {
     //     std::cout << v << " ";
diff --git a/stable-diffusion.h b/stable-diffusion.h
index 90b4e8c..4711b45 100644
--- a/stable-diffusion.h
+++ b/stable-diffusion.h
@@ -284,6 +284,8 @@ SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx,
                           sd_image_t input_image,
                           uint32_t upscale_factor);
 
+SD_API int get_upscale_factor(upscaler_ctx_t* upscaler_ctx);
+
 SD_API bool convert(const char* input_path,
                     const char* vae_path,
                     const char* output_path,
diff --git a/upscaler.cpp b/upscaler.cpp
index 4c138ea..d304237 100644
--- a/upscaler.cpp
+++ b/upscaler.cpp
@@ -138,6 +138,13 @@ sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_
     return upscaler_ctx->upscaler->upscale(input_image, upscale_factor);
 }
 
+int get_upscale_factor(upscaler_ctx_t* upscaler_ctx) {
+    if (upscaler_ctx == NULL || upscaler_ctx->upscaler == NULL || upscaler_ctx->upscaler->esrgan_upscaler == NULL) {
+        return 1;
+    }
+    return upscaler_ctx->upscaler->esrgan_upscaler->scale;
+}
+
 void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx) {
     if (upscaler_ctx->upscaler != NULL) {
         delete upscaler_ctx->upscaler;
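Taken together, the new accessor lets API consumers honor the checkpoint's native scale instead of hard-coding 4 (a sketch; `upscaler_ctx` is assumed to come from `new_upscaler_ctx` as in the CLI):

```cpp
int factor = get_upscale_factor(upscaler_ctx);  // 1, 2 or 4 from the loaded model; 1 if ctx is NULL
sd_image_t out = upscale(upscaler_ctx, input_image, (uint32_t)factor);
```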