mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
fix image edit api
This commit is contained in:
parent
90215344ed
commit
e280695453
@ -1341,6 +1341,10 @@ struct SDGenerationParams {
|
||||
load_if_exists("skip_layers", skip_layers);
|
||||
load_if_exists("high_noise_skip_layers", high_noise_skip_layers);
|
||||
|
||||
load_if_exists("cfg_scale", sample_params.guidance.txt_cfg);
|
||||
load_if_exists("img_cfg_scale", sample_params.guidance.img_cfg);
|
||||
load_if_exists("guidance", sample_params.guidance.distilled_guidance);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -1627,6 +1631,7 @@ static std::string version_string() {
|
||||
|
||||
uint8_t* load_image_common(bool from_memory,
|
||||
const char* image_path_or_bytes,
|
||||
int len,
|
||||
int& width,
|
||||
int& height,
|
||||
int expected_width = 0,
|
||||
@ -1637,7 +1642,7 @@ uint8_t* load_image_common(bool from_memory,
|
||||
uint8_t* image_buffer = nullptr;
|
||||
if (from_memory) {
|
||||
image_path = "memory";
|
||||
image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel);
|
||||
image_buffer = (uint8_t*)stbi_load_from_memory((const stbi_uc*)image_path_or_bytes, len, &width, &height, &c, expected_channel);
|
||||
} else {
|
||||
image_path = image_path_or_bytes;
|
||||
image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel);
|
||||
@ -1733,14 +1738,15 @@ uint8_t* load_image_from_file(const char* image_path,
|
||||
int expected_width = 0,
|
||||
int expected_height = 0,
|
||||
int expected_channel = 3) {
|
||||
return load_image_common(false, image_path, width, height, expected_width, expected_height, expected_channel);
|
||||
return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
|
||||
}
|
||||
|
||||
uint8_t* load_image_from_memory(const char* image_bytes,
|
||||
int len,
|
||||
int& width,
|
||||
int& height,
|
||||
int expected_width = 0,
|
||||
int expected_height = 0,
|
||||
int expected_channel = 3) {
|
||||
return load_image_common(true, image_bytes, width, height, expected_width, expected_height, expected_channel);
|
||||
return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel);
|
||||
}
|
||||
|
||||
@ -492,37 +492,51 @@ int main(int argc, const char** argv) {
|
||||
|
||||
svr.Post("/v1/images/edits", [&](const httplib::Request& req, httplib::Response& res) {
|
||||
try {
|
||||
if (req.body.empty()) {
|
||||
if (!req.is_multipart_form_data()) {
|
||||
res.status = 400;
|
||||
res.set_content(R"({"error":"empty body"})", "application/json");
|
||||
res.set_content(R"({"error":"Content-Type must be multipart/form-data"})", "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
json j = json::parse(req.body);
|
||||
|
||||
std::string prompt = j.value("prompt", "");
|
||||
int n = std::max(1, j.value("n", 1));
|
||||
std::string size = j.value("size", "");
|
||||
std::string output_format = j.value("output_format", "png");
|
||||
int output_compression = j.value("output_compression", 100);
|
||||
|
||||
std::string ref_image_b64 = j.value("image", "");
|
||||
std::string mask_image_b64 = j.value("mask", "");
|
||||
|
||||
std::string prompt = req.form.get_field("prompt");
|
||||
if (prompt.empty()) {
|
||||
res.status = 400;
|
||||
res.set_content(R"({"error":"prompt required"})", "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
if (ref_image_b64.empty()) {
|
||||
std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(prompt);
|
||||
|
||||
size_t image_count = req.form.get_file_count("image[]");
|
||||
if (image_count == 0) {
|
||||
res.status = 400;
|
||||
res.set_content(R"({"error":"image required"})", "application/json");
|
||||
res.set_content(R"({"error":"at least one image[] required"})", "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
int width = 512;
|
||||
int height = 512;
|
||||
std::vector<std::vector<uint8_t>> images_bytes;
|
||||
for (size_t i = 0; i < image_count; i++) {
|
||||
auto file = req.form.get_file("image[]", i);
|
||||
images_bytes.emplace_back(file.content.begin(), file.content.end());
|
||||
}
|
||||
|
||||
std::vector<uint8_t> mask_bytes;
|
||||
if (req.form.has_field("mask")) {
|
||||
auto file = req.form.get_file("mask");
|
||||
mask_bytes.assign(file.content.begin(), file.content.end());
|
||||
}
|
||||
|
||||
int n = 1;
|
||||
if (req.form.has_field("n")) {
|
||||
try {
|
||||
n = std::stoi(req.form.get_field("n"));
|
||||
} catch (...) {
|
||||
}
|
||||
}
|
||||
n = std::clamp(n, 1, 8);
|
||||
|
||||
std::string size = req.form.get_field("size");
|
||||
int width = 512, height = 512;
|
||||
if (!size.empty()) {
|
||||
auto pos = size.find('x');
|
||||
if (pos != std::string::npos) {
|
||||
@ -534,53 +548,26 @@ int main(int argc, const char** argv) {
|
||||
}
|
||||
}
|
||||
|
||||
std::string output_format = "png";
|
||||
if (req.form.has_field("output_format"))
|
||||
output_format = req.form.get_field("output_format");
|
||||
if (output_format != "png" && output_format != "jpeg") {
|
||||
res.status = 400;
|
||||
res.set_content(R"({"error":"invalid output_format, must be one of [png, jpeg]"})", "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
if (n <= 0)
|
||||
n = 1;
|
||||
if (n > 8)
|
||||
n = 8;
|
||||
if (output_compression > 100)
|
||||
std::string output_compression_str = req.form.get_field("output_compression");
|
||||
int output_compression = 100;
|
||||
try {
|
||||
output_compression = std::stoi(output_compression_str);
|
||||
} catch (...) {
|
||||
}
|
||||
if (output_compression > 100) {
|
||||
output_compression = 100;
|
||||
if (output_compression < 0)
|
||||
}
|
||||
if (output_compression < 0) {
|
||||
output_compression = 0;
|
||||
|
||||
// base64 -> raw image
|
||||
std::vector<uint8_t> ref_image_bytes = base64_decode(ref_image_b64);
|
||||
int img_w = width;
|
||||
int img_h = height;
|
||||
uint8_t* raw_pixels = load_image_from_memory(
|
||||
reinterpret_cast<const char*>(ref_image_bytes.data()),
|
||||
img_w, img_h,
|
||||
width, height, 3);
|
||||
|
||||
sd_image_t ref_image;
|
||||
ref_image.width = img_w;
|
||||
ref_image.height = img_h;
|
||||
ref_image.channel = 3;
|
||||
ref_image.data = raw_pixels;
|
||||
|
||||
sd_image_t mask_image = {0};
|
||||
if (!mask_image_b64.empty()) {
|
||||
std::vector<uint8_t> mask_bytes = base64_decode(mask_image_b64);
|
||||
int mask_w = width, mask_h = height;
|
||||
uint8_t* mask_raw = load_image_from_memory(
|
||||
reinterpret_cast<const char*>(mask_bytes.data()),
|
||||
mask_w, mask_h,
|
||||
width, height, 1);
|
||||
mask_image.width = mask_w;
|
||||
mask_image.height = mask_h;
|
||||
mask_image.channel = 1;
|
||||
mask_image.data = mask_raw;
|
||||
} else {
|
||||
mask_image.width = width;
|
||||
mask_image.height = height;
|
||||
mask_image.channel = 1;
|
||||
mask_image.data = nullptr;
|
||||
}
|
||||
|
||||
SDGenerationParams gen_params;
|
||||
@ -589,11 +576,56 @@ int main(int argc, const char** argv) {
|
||||
gen_params.height = height;
|
||||
gen_params.batch_count = n;
|
||||
|
||||
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
|
||||
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
|
||||
std::vector<sd_image_t> ref_images = {ref_image};
|
||||
if (!sd_cpp_extra_args_str.empty() && !gen_params.from_json_str(sd_cpp_extra_args_str)) {
|
||||
res.status = 400;
|
||||
res.set_content(R"({"error":"invalid sd_cpp_extra_args"})", "application/json");
|
||||
return;
|
||||
}
|
||||
|
||||
if (svr_params.verbose) {
|
||||
printf("%s\n", gen_params.to_string().c_str());
|
||||
}
|
||||
|
||||
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
|
||||
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
|
||||
std::vector<sd_image_t> pmid_images;
|
||||
|
||||
std::vector<sd_image_t> ref_images;
|
||||
ref_images.reserve(images_bytes.size());
|
||||
for (auto& bytes : images_bytes) {
|
||||
int img_w = width;
|
||||
int img_h = height;
|
||||
uint8_t* raw_pixels = load_image_from_memory(
|
||||
reinterpret_cast<const char*>(bytes.data()),
|
||||
bytes.size(),
|
||||
img_w, img_h,
|
||||
width, height, 3);
|
||||
|
||||
if (!raw_pixels) {
|
||||
continue;
|
||||
}
|
||||
|
||||
sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels};
|
||||
ref_images.push_back(img);
|
||||
}
|
||||
|
||||
sd_image_t mask_image = {0};
|
||||
if (!mask_bytes.empty()) {
|
||||
int mask_w = width;
|
||||
int mask_h = height;
|
||||
uint8_t* mask_raw = load_image_from_memory(
|
||||
reinterpret_cast<const char*>(mask_bytes.data()),
|
||||
mask_bytes.size(),
|
||||
mask_w, mask_h,
|
||||
width, height, 1);
|
||||
mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw};
|
||||
} else {
|
||||
mask_image.width = width;
|
||||
mask_image.height = height;
|
||||
mask_image.channel = 1;
|
||||
mask_image.data = nullptr;
|
||||
}
|
||||
|
||||
sd_img_gen_params_t img_gen_params = {
|
||||
gen_params.lora_vec.data(),
|
||||
static_cast<uint32_t>(gen_params.lora_vec.size()),
|
||||
@ -662,6 +694,9 @@ int main(int argc, const char** argv) {
|
||||
if (mask_image.data) {
|
||||
stbi_image_free(mask_image.data);
|
||||
}
|
||||
for (auto ref_image : ref_images) {
|
||||
stbi_image_free(ref_image.data);
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
res.status = 500;
|
||||
json err;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user