fix image edit api

This commit is contained in:
leejet 2025-12-12 01:08:25 +08:00
parent 90215344ed
commit e280695453
2 changed files with 103 additions and 62 deletions

View File

@ -1341,6 +1341,10 @@ struct SDGenerationParams {
load_if_exists("skip_layers", skip_layers);
load_if_exists("high_noise_skip_layers", high_noise_skip_layers);
load_if_exists("cfg_scale", sample_params.guidance.txt_cfg);
load_if_exists("img_cfg_scale", sample_params.guidance.img_cfg);
load_if_exists("guidance", sample_params.guidance.distilled_guidance);
return true;
}
@ -1627,6 +1631,7 @@ static std::string version_string() {
uint8_t* load_image_common(bool from_memory,
const char* image_path_or_bytes,
int len,
int& width,
int& height,
int expected_width = 0,
@ -1637,7 +1642,7 @@ uint8_t* load_image_common(bool from_memory,
uint8_t* image_buffer = nullptr;
if (from_memory) {
image_path = "memory";
image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel);
image_buffer = (uint8_t*)stbi_load_from_memory((const stbi_uc*)image_path_or_bytes, len, &width, &height, &c, expected_channel);
} else {
image_path = image_path_or_bytes;
image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel);
@ -1733,14 +1738,15 @@ uint8_t* load_image_from_file(const char* image_path,
int expected_width = 0,
int expected_height = 0,
int expected_channel = 3) {
return load_image_common(false, image_path, width, height, expected_width, expected_height, expected_channel);
return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
}
uint8_t* load_image_from_memory(const char* image_bytes,
int len,
int& width,
int& height,
int expected_width = 0,
int expected_height = 0,
int expected_channel = 3) {
return load_image_common(true, image_bytes, width, height, expected_width, expected_height, expected_channel);
return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel);
}

View File

@ -492,37 +492,51 @@ int main(int argc, const char** argv) {
svr.Post("/v1/images/edits", [&](const httplib::Request& req, httplib::Response& res) {
try {
if (req.body.empty()) {
if (!req.is_multipart_form_data()) {
res.status = 400;
res.set_content(R"({"error":"empty body"})", "application/json");
res.set_content(R"({"error":"Content-Type must be multipart/form-data"})", "application/json");
return;
}
json j = json::parse(req.body);
std::string prompt = j.value("prompt", "");
int n = std::max(1, j.value("n", 1));
std::string size = j.value("size", "");
std::string output_format = j.value("output_format", "png");
int output_compression = j.value("output_compression", 100);
std::string ref_image_b64 = j.value("image", "");
std::string mask_image_b64 = j.value("mask", "");
std::string prompt = req.form.get_field("prompt");
if (prompt.empty()) {
res.status = 400;
res.set_content(R"({"error":"prompt required"})", "application/json");
return;
}
if (ref_image_b64.empty()) {
std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(prompt);
size_t image_count = req.form.get_file_count("image[]");
if (image_count == 0) {
res.status = 400;
res.set_content(R"({"error":"image required"})", "application/json");
res.set_content(R"({"error":"at least one image[] required"})", "application/json");
return;
}
int width = 512;
int height = 512;
std::vector<std::vector<uint8_t>> images_bytes;
for (size_t i = 0; i < image_count; i++) {
auto file = req.form.get_file("image[]", i);
images_bytes.emplace_back(file.content.begin(), file.content.end());
}
std::vector<uint8_t> mask_bytes;
if (req.form.has_field("mask")) {
auto file = req.form.get_file("mask");
mask_bytes.assign(file.content.begin(), file.content.end());
}
int n = 1;
if (req.form.has_field("n")) {
try {
n = std::stoi(req.form.get_field("n"));
} catch (...) {
}
}
n = std::clamp(n, 1, 8);
std::string size = req.form.get_field("size");
int width = 512, height = 512;
if (!size.empty()) {
auto pos = size.find('x');
if (pos != std::string::npos) {
@ -534,53 +548,26 @@ int main(int argc, const char** argv) {
}
}
std::string output_format = "png";
if (req.form.has_field("output_format"))
output_format = req.form.get_field("output_format");
if (output_format != "png" && output_format != "jpeg") {
res.status = 400;
res.set_content(R"({"error":"invalid output_format, must be one of [png, jpeg]"})", "application/json");
return;
}
if (n <= 0)
n = 1;
if (n > 8)
n = 8;
if (output_compression > 100)
std::string output_compression_str = req.form.get_field("output_compression");
int output_compression = 100;
try {
output_compression = std::stoi(output_compression_str);
} catch (...) {
}
if (output_compression > 100) {
output_compression = 100;
if (output_compression < 0)
}
if (output_compression < 0) {
output_compression = 0;
// base64 -> raw image
std::vector<uint8_t> ref_image_bytes = base64_decode(ref_image_b64);
int img_w = width;
int img_h = height;
uint8_t* raw_pixels = load_image_from_memory(
reinterpret_cast<const char*>(ref_image_bytes.data()),
img_w, img_h,
width, height, 3);
sd_image_t ref_image;
ref_image.width = img_w;
ref_image.height = img_h;
ref_image.channel = 3;
ref_image.data = raw_pixels;
sd_image_t mask_image = {0};
if (!mask_image_b64.empty()) {
std::vector<uint8_t> mask_bytes = base64_decode(mask_image_b64);
int mask_w = width, mask_h = height;
uint8_t* mask_raw = load_image_from_memory(
reinterpret_cast<const char*>(mask_bytes.data()),
mask_w, mask_h,
width, height, 1);
mask_image.width = mask_w;
mask_image.height = mask_h;
mask_image.channel = 1;
mask_image.data = mask_raw;
} else {
mask_image.width = width;
mask_image.height = height;
mask_image.channel = 1;
mask_image.data = nullptr;
}
SDGenerationParams gen_params;
@ -589,11 +576,56 @@ int main(int argc, const char** argv) {
gen_params.height = height;
gen_params.batch_count = n;
if (!sd_cpp_extra_args_str.empty() && !gen_params.from_json_str(sd_cpp_extra_args_str)) {
res.status = 400;
res.set_content(R"({"error":"invalid sd_cpp_extra_args"})", "application/json");
return;
}
if (svr_params.verbose) {
printf("%s\n", gen_params.to_string().c_str());
}
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
std::vector<sd_image_t> ref_images = {ref_image};
std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> ref_images;
ref_images.reserve(images_bytes.size());
for (auto& bytes : images_bytes) {
int img_w = width;
int img_h = height;
uint8_t* raw_pixels = load_image_from_memory(
reinterpret_cast<const char*>(bytes.data()),
bytes.size(),
img_w, img_h,
width, height, 3);
if (!raw_pixels) {
continue;
}
sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels};
ref_images.push_back(img);
}
sd_image_t mask_image = {0};
if (!mask_bytes.empty()) {
int mask_w = width;
int mask_h = height;
uint8_t* mask_raw = load_image_from_memory(
reinterpret_cast<const char*>(mask_bytes.data()),
mask_bytes.size(),
mask_w, mask_h,
width, height, 1);
mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw};
} else {
mask_image.width = width;
mask_image.height = height;
mask_image.channel = 1;
mask_image.data = nullptr;
}
sd_img_gen_params_t img_gen_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
@ -662,6 +694,9 @@ int main(int argc, const char** argv) {
if (mask_image.data) {
stbi_image_free(mask_image.data);
}
for (auto ref_image : ref_images) {
stbi_image_free(ref_image.data);
}
} catch (const std::exception& e) {
res.status = 500;
json err;