From 90215344edfdfd239c34b29bd4312b39d37413bd Mon Sep 17 00:00:00 2001 From: leejet Date: Thu, 11 Dec 2025 00:17:29 +0800 Subject: [PATCH] wip image edit api --- examples/cli/main.cpp | 114 ++---------------- examples/common/common.hpp | 136 ++++++++++++++++++++- examples/server/main.cpp | 236 +++++++++++++++++++++++++++++++++++-- 3 files changed, 369 insertions(+), 117 deletions(-) diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index a8d3be7..eaa2591 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -15,22 +15,10 @@ // #include "preprocessing.hpp" #include "stable-diffusion.h" -#define STB_IMAGE_IMPLEMENTATION -#define STB_IMAGE_STATIC -#include "stb_image.h" - -#define STB_IMAGE_WRITE_IMPLEMENTATION -#define STB_IMAGE_WRITE_STATIC -#include "stb_image_write.h" - -#define STB_IMAGE_RESIZE_IMPLEMENTATION -#define STB_IMAGE_RESIZE_STATIC -#include "stb_image_resize.h" +#include "common/common.hpp" #include "avi_writer.h" -#include "common/common.hpp" - const char* previews_str[] = { "none", "proj", @@ -335,94 +323,6 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { fflush(out_stream); } -uint8_t* load_image(const char* image_path, int& width, int& height, int expected_width = 0, int expected_height = 0, int expected_channel = 3) { - int c = 0; - uint8_t* image_buffer = (uint8_t*)stbi_load(image_path, &width, &height, &c, expected_channel); - if (image_buffer == nullptr) { - fprintf(stderr, "load image from '%s' failed\n", image_path); - return nullptr; - } - if (c < expected_channel) { - fprintf(stderr, - "the number of channels for the input image must be >= %d," - "but got %d channels, image_path = %s\n", - expected_channel, - c, - image_path); - free(image_buffer); - return nullptr; - } - if (width <= 0) { - fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path); - free(image_buffer); - return nullptr; - } - if (height <= 0) { - fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path); - free(image_buffer); - return nullptr; - } - - // Resize input image ... - if ((expected_width > 0 && expected_height > 0) && (height != expected_height || width != expected_width)) { - float dst_aspect = (float)expected_width / (float)expected_height; - float src_aspect = (float)width / (float)height; - - int crop_x = 0, crop_y = 0; - int crop_w = width, crop_h = height; - - if (src_aspect > dst_aspect) { - crop_w = (int)(height * dst_aspect); - crop_x = (width - crop_w) / 2; - } else if (src_aspect < dst_aspect) { - crop_h = (int)(width / dst_aspect); - crop_y = (height - crop_h) / 2; - } - - if (crop_x != 0 || crop_y != 0) { - printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path); - uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel); - if (cropped_image_buffer == nullptr) { - fprintf(stderr, "error: allocate memory for crop\n"); - free(image_buffer); - return nullptr; - } - for (int row = 0; row < crop_h; row++) { - uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel; - uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel; - memcpy(dst, src, crop_w * expected_channel); - } - - width = crop_w; - height = crop_h; - free(image_buffer); - image_buffer = cropped_image_buffer; - } - - printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height); - int resized_height = expected_height; - int resized_width = expected_width; - - uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel); - if (resized_image_buffer == nullptr) { - fprintf(stderr, "error: allocate memory for resize input image\n"); - free(image_buffer); - return nullptr; - } - stbir_resize(image_buffer, width, height, 0, - resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8, - expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0, - STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, - STBIR_FILTER_BOX, STBIR_FILTER_BOX, - STBIR_COLORSPACE_SRGB, nullptr); - width = resized_width; - height = resized_height; - free(image_buffer); - image_buffer = resized_image_buffer; - } - return image_buffer; -} - bool load_images_from_dir(const std::string dir, std::vector& images, int expected_width = 0, @@ -457,7 +357,7 @@ bool load_images_from_dir(const std::string dir, } int width = 0; int height = 0; - uint8_t* image_buffer = load_image(path.c_str(), width, height, expected_width, expected_height); + uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height, expected_width, expected_height); if (image_buffer == nullptr) { fprintf(stderr, "load image from '%s' failed\n", path.c_str()); return false; @@ -593,7 +493,7 @@ int main(int argc, const char* argv[]) { int width = 0; int height = 0; - init_image.data = load_image(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height); + init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height); if (init_image.data == nullptr) { fprintf(stderr, "load image from '%s' failed\n", gen_params.init_image_path.c_str()); release_all_resources(); @@ -606,7 +506,7 @@ int main(int argc, const char* argv[]) { int width = 0; int height = 0; - end_image.data = load_image(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height); + end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height); if (end_image.data == nullptr) { fprintf(stderr, "load image from '%s' failed\n", gen_params.end_image_path.c_str()); release_all_resources(); @@ -618,7 +518,7 @@ int main(int argc, const char* argv[]) { int c = 0; int width = 0; int height = 0; - mask_image.data = load_image(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1); + mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1); if (mask_image.data == nullptr) { fprintf(stderr, "load image from '%s' failed\n", gen_params.mask_image_path.c_str()); release_all_resources(); @@ -637,7 +537,7 @@ int main(int argc, const char* argv[]) { if (gen_params.control_image_path.size() > 0) { int width = 0; int height = 0; - control_image.data = load_image(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height); + control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height); if (control_image.data == nullptr) { fprintf(stderr, "load image from '%s' failed\n", gen_params.control_image_path.c_str()); release_all_resources(); @@ -658,7 +558,7 @@ int main(int argc, const char* argv[]) { for (auto& path : gen_params.ref_image_paths) { int width = 0; int height = 0; - uint8_t* image_buffer = load_image(path.c_str(), width, height); + uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height); if (image_buffer == nullptr) { fprintf(stderr, "load image from '%s' failed\n", path.c_str()); release_all_resources(); diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 85eb699..98d3528 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -1,4 +1,5 @@ +#include #include #include #include @@ -6,10 +7,9 @@ #include #include #include -#include #include -using json = nlohmann::json; +using json = nlohmann::json; namespace fs = std::filesystem; #if defined(_WIN32) @@ -19,6 +19,18 @@ namespace fs = std::filesystem; #include "stable-diffusion.h" +#define STB_IMAGE_IMPLEMENTATION +#define STB_IMAGE_STATIC +#include "stb_image.h" + +#define STB_IMAGE_WRITE_IMPLEMENTATION +#define STB_IMAGE_WRITE_STATIC +#include "stb_image_write.h" + +#define STB_IMAGE_RESIZE_IMPLEMENTATION +#define STB_IMAGE_RESIZE_STATIC +#include "stb_image_resize.h" + #define SAFE_STR(s) ((s) ? (s) : "") #define BOOL_STR(b) ((b) ? "true" : "false") @@ -1612,3 +1624,123 @@ struct SDGenerationParams { static std::string version_string() { return std::string("stable-diffusion.cpp version ") + sd_version() + ", commit " + sd_commit(); } + +uint8_t* load_image_common(bool from_memory, + const char* image_path_or_bytes, + int& width, + int& height, + int expected_width = 0, + int expected_height = 0, + int expected_channel = 3) { + int c = 0; + const char* image_path; + uint8_t* image_buffer = nullptr; + if (from_memory) { + image_path = "memory"; + image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel); + } else { + image_path = image_path_or_bytes; + image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel); + } + if (image_buffer == nullptr) { + fprintf(stderr, "load image from '%s' failed\n", image_path); + return nullptr; + } + if (c < expected_channel) { + fprintf(stderr, + "the number of channels for the input image must be >= %d," + "but got %d channels, image_path = %s\n", + expected_channel, + c, + image_path); + free(image_buffer); + return nullptr; + } + if (width <= 0) { + fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path); + free(image_buffer); + return nullptr; + } + if (height <= 0) { + fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path); + free(image_buffer); + return nullptr; + } + + // Resize input image ... + if ((expected_width > 0 && expected_height > 0) && (height != expected_height || width != expected_width)) { + float dst_aspect = (float)expected_width / (float)expected_height; + float src_aspect = (float)width / (float)height; + + int crop_x = 0, crop_y = 0; + int crop_w = width, crop_h = height; + + if (src_aspect > dst_aspect) { + crop_w = (int)(height * dst_aspect); + crop_x = (width - crop_w) / 2; + } else if (src_aspect < dst_aspect) { + crop_h = (int)(width / dst_aspect); + crop_y = (height - crop_h) / 2; + } + + if (crop_x != 0 || crop_y != 0) { + printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path); + uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel); + if (cropped_image_buffer == nullptr) { + fprintf(stderr, "error: allocate memory for crop\n"); + free(image_buffer); + return nullptr; + } + for (int row = 0; row < crop_h; row++) { + uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel; + uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel; + memcpy(dst, src, crop_w * expected_channel); + } + + width = crop_w; + height = crop_h; + free(image_buffer); + image_buffer = cropped_image_buffer; + } + + printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height); + int resized_height = expected_height; + int resized_width = expected_width; + + uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel); + if (resized_image_buffer == nullptr) { + fprintf(stderr, "error: allocate memory for resize input image\n"); + free(image_buffer); + return nullptr; + } + stbir_resize(image_buffer, width, height, 0, + resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8, + expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0, + STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, + STBIR_FILTER_BOX, STBIR_FILTER_BOX, + STBIR_COLORSPACE_SRGB, nullptr); + width = resized_width; + height = resized_height; + free(image_buffer); + image_buffer = resized_image_buffer; + } + return image_buffer; +} + +uint8_t* load_image_from_file(const char* image_path, + int& width, + int& height, + int expected_width = 0, + int expected_height = 0, + int expected_channel = 3) { + return load_image_common(false, image_path, width, height, expected_width, expected_height, expected_channel); +} + +uint8_t* load_image_from_memory(const char* image_bytes, + int& width, + int& height, + int expected_width = 0, + int expected_height = 0, + int expected_channel = 3) { + return load_image_common(true, image_bytes, width, height, expected_width, expected_height, expected_channel); +} diff --git a/examples/server/main.cpp b/examples/server/main.cpp index 2238581..d63a631 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -8,14 +8,6 @@ #include #include -#define STB_IMAGE_IMPLEMENTATION -#define STB_IMAGE_STATIC -#include "stb_image.h" - -#define STB_IMAGE_WRITE_IMPLEMENTATION -#define STB_IMAGE_WRITE_STATIC -#include "stb_image_write.h" - #include "httplib.h" #include "stable-diffusion.h" @@ -47,6 +39,53 @@ std::string base64_encode(const std::vector& bytes) { return ret; } +inline bool is_base64(unsigned char c) { + return (isalnum(c) || (c == '+') || (c == '/')); +} + +std::vector base64_decode(const std::string& encoded_string) { + int in_len = encoded_string.size(); + int i = 0; + int j = 0; + int in_ = 0; + uint8_t char_array_4[4], char_array_3[3]; + std::vector ret; + + while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { + char_array_4[i++] = encoded_string[in_]; + in_++; + if (i == 4) { + for (i = 0; i < 4; i++) + char_array_4[i] = static_cast(base64_chars.find(char_array_4[i])); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (i = 0; i < 3; i++) + ret.push_back(char_array_3[i]); + i = 0; + } + } + + if (i) { + for (j = i; j < 4; j++) + char_array_4[j] = 0; + + for (j = 0; j < 4; j++) + char_array_4[j] = static_cast(base64_chars.find(char_array_4[j])); + + char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4); + char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); + char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; + + for (j = 0; j < i - 1; j++) + ret.push_back(char_array_3[j]); + } + + return ret; +} + std::string iso_timestamp_now() { using namespace std::chrono; auto now = system_clock::now(); @@ -451,6 +490,187 @@ int main(int argc, const char** argv) { } }); + svr.Post("/v1/images/edits", [&](const httplib::Request& req, httplib::Response& res) { + try { + if (req.body.empty()) { + res.status = 400; + res.set_content(R"({"error":"empty body"})", "application/json"); + return; + } + + json j = json::parse(req.body); + + std::string prompt = j.value("prompt", ""); + int n = std::max(1, j.value("n", 1)); + std::string size = j.value("size", ""); + std::string output_format = j.value("output_format", "png"); + int output_compression = j.value("output_compression", 100); + + std::string ref_image_b64 = j.value("image", ""); + std::string mask_image_b64 = j.value("mask", ""); + + if (prompt.empty()) { + res.status = 400; + res.set_content(R"({"error":"prompt required"})", "application/json"); + return; + } + + if (ref_image_b64.empty()) { + res.status = 400; + res.set_content(R"({"error":"image required"})", "application/json"); + return; + } + + int width = 512; + int height = 512; + if (!size.empty()) { + auto pos = size.find('x'); + if (pos != std::string::npos) { + try { + width = std::stoi(size.substr(0, pos)); + height = std::stoi(size.substr(pos + 1)); + } catch (...) { + } + } + } + + if (output_format != "png" && output_format != "jpeg") { + res.status = 400; + res.set_content(R"({"error":"invalid output_format, must be one of [png, jpeg]"})", "application/json"); + return; + } + + if (n <= 0) + n = 1; + if (n > 8) + n = 8; + if (output_compression > 100) + output_compression = 100; + if (output_compression < 0) + output_compression = 0; + + // base64 -> raw image + std::vector ref_image_bytes = base64_decode(ref_image_b64); + int img_w = width; + int img_h = height; + uint8_t* raw_pixels = load_image_from_memory( + reinterpret_cast(ref_image_bytes.data()), + img_w, img_h, + width, height, 3); + + sd_image_t ref_image; + ref_image.width = img_w; + ref_image.height = img_h; + ref_image.channel = 3; + ref_image.data = raw_pixels; + + sd_image_t mask_image = {0}; + if (!mask_image_b64.empty()) { + std::vector mask_bytes = base64_decode(mask_image_b64); + int mask_w = width, mask_h = height; + uint8_t* mask_raw = load_image_from_memory( + reinterpret_cast(mask_bytes.data()), + mask_w, mask_h, + width, height, 1); + mask_image.width = mask_w; + mask_image.height = mask_h; + mask_image.channel = 1; + mask_image.data = mask_raw; + } else { + mask_image.width = width; + mask_image.height = height; + mask_image.channel = 1; + mask_image.data = nullptr; + } + + SDGenerationParams gen_params; + gen_params.prompt = prompt; + gen_params.width = width; + gen_params.height = height; + gen_params.batch_count = n; + + sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; + sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; + std::vector ref_images = {ref_image}; + std::vector pmid_images; + + sd_img_gen_params_t img_gen_params = { + gen_params.lora_vec.data(), + static_cast(gen_params.lora_vec.size()), + gen_params.prompt.c_str(), + gen_params.negative_prompt.c_str(), + gen_params.clip_skip, + init_image, + ref_images.data(), + (int)ref_images.size(), + gen_params.auto_resize_ref_image, + gen_params.increase_ref_index, + mask_image, + gen_params.width, + gen_params.height, + gen_params.sample_params, + gen_params.strength, + gen_params.seed, + gen_params.batch_count, + control_image, + gen_params.control_strength, + { + pmid_images.data(), + (int)pmid_images.size(), + gen_params.pm_id_embed_path.c_str(), + gen_params.pm_style_strength, + }, // pm_params + ctx_params.vae_tiling_params, + gen_params.easycache_params, + }; + + sd_image_t* results = nullptr; + int num_results = 0; + + { + std::lock_guard lock(sd_ctx_mutex); + results = generate_image(sd_ctx, &img_gen_params); + num_results = gen_params.batch_count; + } + + json out; + out["created"] = iso_timestamp_now(); + out["data"] = json::array(); + out["output_format"] = output_format; + + for (int i = 0; i < num_results; i++) { + if (results[i].data == nullptr) + continue; + auto image_bytes = write_image_to_vector(output_format == "jpeg" ? ImageFormat::JPEG : ImageFormat::PNG, + results[i].data, + results[i].width, + results[i].height, + results[i].channel, + output_compression); + std::string b64 = base64_encode(image_bytes); + json item; + item["b64_json"] = b64; + out["data"].push_back(item); + } + + res.set_content(out.dump(), "application/json"); + res.status = 200; + + if (init_image.data) { + stbi_image_free(init_image.data); + } + if (mask_image.data) { + stbi_image_free(mask_image.data); + } + } catch (const std::exception& e) { + res.status = 500; + json err; + err["error"] = "server_error"; + err["message"] = e.what(); + res.set_content(err.dump(), "application/json"); + } + }); + printf("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port); svr.listen(svr_params.listen_ip, svr_params.listen_port);