diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 98d3528..e508aba 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -1341,6 +1341,10 @@ struct SDGenerationParams { load_if_exists("skip_layers", skip_layers); load_if_exists("high_noise_skip_layers", high_noise_skip_layers); + load_if_exists("cfg_scale", sample_params.guidance.txt_cfg); + load_if_exists("img_cfg_scale", sample_params.guidance.img_cfg); + load_if_exists("guidance", sample_params.guidance.distilled_guidance); + return true; } @@ -1627,6 +1631,7 @@ static std::string version_string() { uint8_t* load_image_common(bool from_memory, const char* image_path_or_bytes, + int len, int& width, int& height, int expected_width = 0, @@ -1637,7 +1642,7 @@ uint8_t* load_image_common(bool from_memory, uint8_t* image_buffer = nullptr; if (from_memory) { image_path = "memory"; - image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel); + image_buffer = (uint8_t*)stbi_load_from_memory((const stbi_uc*)image_path_or_bytes, len, &width, &height, &c, expected_channel); } else { image_path = image_path_or_bytes; image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel); @@ -1733,14 +1738,15 @@ uint8_t* load_image_from_file(const char* image_path, int expected_width = 0, int expected_height = 0, int expected_channel = 3) { - return load_image_common(false, image_path, width, height, expected_width, expected_height, expected_channel); + return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel); } uint8_t* load_image_from_memory(const char* image_bytes, + int len, int& width, int& height, int expected_width = 0, int expected_height = 0, int expected_channel = 3) { - return load_image_common(true, image_bytes, width, height, expected_width, expected_height, expected_channel); + return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel); } diff --git a/examples/server/main.cpp b/examples/server/main.cpp index d63a631..bcd7d17 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -492,37 +492,51 @@ int main(int argc, const char** argv) { svr.Post("/v1/images/edits", [&](const httplib::Request& req, httplib::Response& res) { try { - if (req.body.empty()) { + if (!req.is_multipart_form_data()) { res.status = 400; - res.set_content(R"({"error":"empty body"})", "application/json"); + res.set_content(R"({"error":"Content-Type must be multipart/form-data"})", "application/json"); return; } - json j = json::parse(req.body); - - std::string prompt = j.value("prompt", ""); - int n = std::max(1, j.value("n", 1)); - std::string size = j.value("size", ""); - std::string output_format = j.value("output_format", "png"); - int output_compression = j.value("output_compression", 100); - - std::string ref_image_b64 = j.value("image", ""); - std::string mask_image_b64 = j.value("mask", ""); - + std::string prompt = req.form.get_field("prompt"); if (prompt.empty()) { res.status = 400; res.set_content(R"({"error":"prompt required"})", "application/json"); return; } - if (ref_image_b64.empty()) { + std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(prompt); + + size_t image_count = req.form.get_file_count("image[]"); + if (image_count == 0) { res.status = 400; - res.set_content(R"({"error":"image required"})", "application/json"); + res.set_content(R"({"error":"at least one image[] required"})", "application/json"); return; } - int width = 512; - int height = 512; + std::vector> images_bytes; + for (size_t i = 0; i < image_count; i++) { + auto file = req.form.get_file("image[]", i); + images_bytes.emplace_back(file.content.begin(), file.content.end()); + } + + std::vector mask_bytes; + if (req.form.has_field("mask")) { + auto file = req.form.get_file("mask"); + mask_bytes.assign(file.content.begin(), file.content.end()); + } + + int n = 1; + if (req.form.has_field("n")) { + try { + n = std::stoi(req.form.get_field("n")); + } catch (...) { + } + } + n = std::clamp(n, 1, 8); + + std::string size = req.form.get_field("size"); + int width = 512, height = 512; if (!size.empty()) { auto pos = size.find('x'); if (pos != std::string::npos) { @@ -534,53 +548,26 @@ int main(int argc, const char** argv) { } } + std::string output_format = "png"; + if (req.form.has_field("output_format")) + output_format = req.form.get_field("output_format"); if (output_format != "png" && output_format != "jpeg") { res.status = 400; res.set_content(R"({"error":"invalid output_format, must be one of [png, jpeg]"})", "application/json"); return; } - if (n <= 0) - n = 1; - if (n > 8) - n = 8; - if (output_compression > 100) + std::string output_compression_str = req.form.get_field("output_compression"); + int output_compression = 100; + try { + output_compression = std::stoi(output_compression_str); + } catch (...) { + } + if (output_compression > 100) { output_compression = 100; - if (output_compression < 0) + } + if (output_compression < 0) { output_compression = 0; - - // base64 -> raw image - std::vector ref_image_bytes = base64_decode(ref_image_b64); - int img_w = width; - int img_h = height; - uint8_t* raw_pixels = load_image_from_memory( - reinterpret_cast(ref_image_bytes.data()), - img_w, img_h, - width, height, 3); - - sd_image_t ref_image; - ref_image.width = img_w; - ref_image.height = img_h; - ref_image.channel = 3; - ref_image.data = raw_pixels; - - sd_image_t mask_image = {0}; - if (!mask_image_b64.empty()) { - std::vector mask_bytes = base64_decode(mask_image_b64); - int mask_w = width, mask_h = height; - uint8_t* mask_raw = load_image_from_memory( - reinterpret_cast(mask_bytes.data()), - mask_w, mask_h, - width, height, 1); - mask_image.width = mask_w; - mask_image.height = mask_h; - mask_image.channel = 1; - mask_image.data = mask_raw; - } else { - mask_image.width = width; - mask_image.height = height; - mask_image.channel = 1; - mask_image.data = nullptr; } SDGenerationParams gen_params; @@ -589,11 +576,56 @@ int main(int argc, const char** argv) { gen_params.height = height; gen_params.batch_count = n; - sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; - sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; - std::vector ref_images = {ref_image}; + if (!sd_cpp_extra_args_str.empty() && !gen_params.from_json_str(sd_cpp_extra_args_str)) { + res.status = 400; + res.set_content(R"({"error":"invalid sd_cpp_extra_args"})", "application/json"); + return; + } + + if (svr_params.verbose) { + printf("%s\n", gen_params.to_string().c_str()); + } + + sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; + sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; std::vector pmid_images; + std::vector ref_images; + ref_images.reserve(images_bytes.size()); + for (auto& bytes : images_bytes) { + int img_w = width; + int img_h = height; + uint8_t* raw_pixels = load_image_from_memory( + reinterpret_cast(bytes.data()), + bytes.size(), + img_w, img_h, + width, height, 3); + + if (!raw_pixels) { + continue; + } + + sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels}; + ref_images.push_back(img); + } + + sd_image_t mask_image = {0}; + if (!mask_bytes.empty()) { + int mask_w = width; + int mask_h = height; + uint8_t* mask_raw = load_image_from_memory( + reinterpret_cast(mask_bytes.data()), + mask_bytes.size(), + mask_w, mask_h, + width, height, 1); + mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw}; + } else { + mask_image.width = width; + mask_image.height = height; + mask_image.channel = 1; + mask_image.data = nullptr; + } + sd_img_gen_params_t img_gen_params = { gen_params.lora_vec.data(), static_cast(gen_params.lora_vec.size()), @@ -662,6 +694,9 @@ int main(int argc, const char** argv) { if (mask_image.data) { stbi_image_free(mask_image.data); } + for (auto ref_image : ref_images) { + stbi_image_free(ref_image.data); + } } catch (const std::exception& e) { res.status = 500; json err;