wip image edit api

This commit is contained in:
leejet 2025-12-11 00:17:29 +08:00
parent 96aea6340c
commit 90215344ed
3 changed files with 369 additions and 117 deletions

View File

@ -15,22 +15,10 @@
// #include "preprocessing.hpp"
#include "stable-diffusion.h"
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#define STB_IMAGE_WRITE_STATIC
#include "stb_image_write.h"
#define STB_IMAGE_RESIZE_IMPLEMENTATION
#define STB_IMAGE_RESIZE_STATIC
#include "stb_image_resize.h"
#include "common/common.hpp"
#include "avi_writer.h"
#include "common/common.hpp"
const char* previews_str[] = {
"none",
"proj",
@ -335,94 +323,6 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
fflush(out_stream);
}
uint8_t* load_image(const char* image_path, int& width, int& height, int expected_width = 0, int expected_height = 0, int expected_channel = 3) {
int c = 0;
uint8_t* image_buffer = (uint8_t*)stbi_load(image_path, &width, &height, &c, expected_channel);
if (image_buffer == nullptr) {
fprintf(stderr, "load image from '%s' failed\n", image_path);
return nullptr;
}
if (c < expected_channel) {
fprintf(stderr,
"the number of channels for the input image must be >= %d,"
"but got %d channels, image_path = %s\n",
expected_channel,
c,
image_path);
free(image_buffer);
return nullptr;
}
if (width <= 0) {
fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path);
free(image_buffer);
return nullptr;
}
if (height <= 0) {
fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path);
free(image_buffer);
return nullptr;
}
// Resize input image ...
if ((expected_width > 0 && expected_height > 0) && (height != expected_height || width != expected_width)) {
float dst_aspect = (float)expected_width / (float)expected_height;
float src_aspect = (float)width / (float)height;
int crop_x = 0, crop_y = 0;
int crop_w = width, crop_h = height;
if (src_aspect > dst_aspect) {
crop_w = (int)(height * dst_aspect);
crop_x = (width - crop_w) / 2;
} else if (src_aspect < dst_aspect) {
crop_h = (int)(width / dst_aspect);
crop_y = (height - crop_h) / 2;
}
if (crop_x != 0 || crop_y != 0) {
printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path);
uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel);
if (cropped_image_buffer == nullptr) {
fprintf(stderr, "error: allocate memory for crop\n");
free(image_buffer);
return nullptr;
}
for (int row = 0; row < crop_h; row++) {
uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel;
uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel;
memcpy(dst, src, crop_w * expected_channel);
}
width = crop_w;
height = crop_h;
free(image_buffer);
image_buffer = cropped_image_buffer;
}
printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height);
int resized_height = expected_height;
int resized_width = expected_width;
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel);
if (resized_image_buffer == nullptr) {
fprintf(stderr, "error: allocate memory for resize input image\n");
free(image_buffer);
return nullptr;
}
stbir_resize(image_buffer, width, height, 0,
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0,
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
STBIR_COLORSPACE_SRGB, nullptr);
width = resized_width;
height = resized_height;
free(image_buffer);
image_buffer = resized_image_buffer;
}
return image_buffer;
}
bool load_images_from_dir(const std::string dir,
std::vector<sd_image_t>& images,
int expected_width = 0,
@ -457,7 +357,7 @@ bool load_images_from_dir(const std::string dir,
}
int width = 0;
int height = 0;
uint8_t* image_buffer = load_image(path.c_str(), width, height, expected_width, expected_height);
uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height, expected_width, expected_height);
if (image_buffer == nullptr) {
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
return false;
@ -593,7 +493,7 @@ int main(int argc, const char* argv[]) {
int width = 0;
int height = 0;
init_image.data = load_image(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height);
init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height);
if (init_image.data == nullptr) {
fprintf(stderr, "load image from '%s' failed\n", gen_params.init_image_path.c_str());
release_all_resources();
@ -606,7 +506,7 @@ int main(int argc, const char* argv[]) {
int width = 0;
int height = 0;
end_image.data = load_image(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height);
end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height);
if (end_image.data == nullptr) {
fprintf(stderr, "load image from '%s' failed\n", gen_params.end_image_path.c_str());
release_all_resources();
@ -618,7 +518,7 @@ int main(int argc, const char* argv[]) {
int c = 0;
int width = 0;
int height = 0;
mask_image.data = load_image(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1);
mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1);
if (mask_image.data == nullptr) {
fprintf(stderr, "load image from '%s' failed\n", gen_params.mask_image_path.c_str());
release_all_resources();
@ -637,7 +537,7 @@ int main(int argc, const char* argv[]) {
if (gen_params.control_image_path.size() > 0) {
int width = 0;
int height = 0;
control_image.data = load_image(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height);
control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height);
if (control_image.data == nullptr) {
fprintf(stderr, "load image from '%s' failed\n", gen_params.control_image_path.c_str());
release_all_resources();
@ -658,7 +558,7 @@ int main(int argc, const char* argv[]) {
for (auto& path : gen_params.ref_image_paths) {
int width = 0;
int height = 0;
uint8_t* image_buffer = load_image(path.c_str(), width, height);
uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height);
if (image_buffer == nullptr) {
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
release_all_resources();

View File

@ -1,4 +1,5 @@
#include <filesystem>
#include <iostream>
#include <map>
#include <random>
@ -6,7 +7,6 @@
#include <sstream>
#include <string>
#include <vector>
#include <filesystem>
#include <json.hpp>
using json = nlohmann::json;
@ -19,6 +19,18 @@ namespace fs = std::filesystem;
#include "stable-diffusion.h"
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#define STB_IMAGE_WRITE_STATIC
#include "stb_image_write.h"
#define STB_IMAGE_RESIZE_IMPLEMENTATION
#define STB_IMAGE_RESIZE_STATIC
#include "stb_image_resize.h"
#define SAFE_STR(s) ((s) ? (s) : "")
#define BOOL_STR(b) ((b) ? "true" : "false")
@ -1612,3 +1624,123 @@ struct SDGenerationParams {
static std::string version_string() {
return std::string("stable-diffusion.cpp version ") + sd_version() + ", commit " + sd_commit();
}
uint8_t* load_image_common(bool from_memory,
const char* image_path_or_bytes,
int& width,
int& height,
int expected_width = 0,
int expected_height = 0,
int expected_channel = 3) {
int c = 0;
const char* image_path;
uint8_t* image_buffer = nullptr;
if (from_memory) {
image_path = "memory";
image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel);
} else {
image_path = image_path_or_bytes;
image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel);
}
if (image_buffer == nullptr) {
fprintf(stderr, "load image from '%s' failed\n", image_path);
return nullptr;
}
if (c < expected_channel) {
fprintf(stderr,
"the number of channels for the input image must be >= %d,"
"but got %d channels, image_path = %s\n",
expected_channel,
c,
image_path);
free(image_buffer);
return nullptr;
}
if (width <= 0) {
fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path);
free(image_buffer);
return nullptr;
}
if (height <= 0) {
fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path);
free(image_buffer);
return nullptr;
}
// Resize input image ...
if ((expected_width > 0 && expected_height > 0) && (height != expected_height || width != expected_width)) {
float dst_aspect = (float)expected_width / (float)expected_height;
float src_aspect = (float)width / (float)height;
int crop_x = 0, crop_y = 0;
int crop_w = width, crop_h = height;
if (src_aspect > dst_aspect) {
crop_w = (int)(height * dst_aspect);
crop_x = (width - crop_w) / 2;
} else if (src_aspect < dst_aspect) {
crop_h = (int)(width / dst_aspect);
crop_y = (height - crop_h) / 2;
}
if (crop_x != 0 || crop_y != 0) {
printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path);
uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel);
if (cropped_image_buffer == nullptr) {
fprintf(stderr, "error: allocate memory for crop\n");
free(image_buffer);
return nullptr;
}
for (int row = 0; row < crop_h; row++) {
uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel;
uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel;
memcpy(dst, src, crop_w * expected_channel);
}
width = crop_w;
height = crop_h;
free(image_buffer);
image_buffer = cropped_image_buffer;
}
printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height);
int resized_height = expected_height;
int resized_width = expected_width;
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel);
if (resized_image_buffer == nullptr) {
fprintf(stderr, "error: allocate memory for resize input image\n");
free(image_buffer);
return nullptr;
}
stbir_resize(image_buffer, width, height, 0,
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0,
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
STBIR_COLORSPACE_SRGB, nullptr);
width = resized_width;
height = resized_height;
free(image_buffer);
image_buffer = resized_image_buffer;
}
return image_buffer;
}
uint8_t* load_image_from_file(const char* image_path,
int& width,
int& height,
int expected_width = 0,
int expected_height = 0,
int expected_channel = 3) {
return load_image_common(false, image_path, width, height, expected_width, expected_height, expected_channel);
}
uint8_t* load_image_from_memory(const char* image_bytes,
int& width,
int& height,
int expected_width = 0,
int expected_height = 0,
int expected_channel = 3) {
return load_image_common(true, image_bytes, width, height, expected_width, expected_height, expected_channel);
}

View File

@ -8,14 +8,6 @@
#include <sstream>
#include <vector>
#define STB_IMAGE_IMPLEMENTATION
#define STB_IMAGE_STATIC
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#define STB_IMAGE_WRITE_STATIC
#include "stb_image_write.h"
#include "httplib.h"
#include "stable-diffusion.h"
@ -47,6 +39,53 @@ std::string base64_encode(const std::vector<uint8_t>& bytes) {
return ret;
}
inline bool is_base64(unsigned char c) {
return (isalnum(c) || (c == '+') || (c == '/'));
}
std::vector<uint8_t> base64_decode(const std::string& encoded_string) {
int in_len = encoded_string.size();
int i = 0;
int j = 0;
int in_ = 0;
uint8_t char_array_4[4], char_array_3[3];
std::vector<uint8_t> ret;
while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
char_array_4[i++] = encoded_string[in_];
in_++;
if (i == 4) {
for (i = 0; i < 4; i++)
char_array_4[i] = static_cast<uint8_t>(base64_chars.find(char_array_4[i]));
char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (i = 0; i < 3; i++)
ret.push_back(char_array_3[i]);
i = 0;
}
}
if (i) {
for (j = i; j < 4; j++)
char_array_4[j] = 0;
for (j = 0; j < 4; j++)
char_array_4[j] = static_cast<uint8_t>(base64_chars.find(char_array_4[j]));
char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (j = 0; j < i - 1; j++)
ret.push_back(char_array_3[j]);
}
return ret;
}
std::string iso_timestamp_now() {
using namespace std::chrono;
auto now = system_clock::now();
@ -451,6 +490,187 @@ int main(int argc, const char** argv) {
}
});
svr.Post("/v1/images/edits", [&](const httplib::Request& req, httplib::Response& res) {
try {
if (req.body.empty()) {
res.status = 400;
res.set_content(R"({"error":"empty body"})", "application/json");
return;
}
json j = json::parse(req.body);
std::string prompt = j.value("prompt", "");
int n = std::max(1, j.value("n", 1));
std::string size = j.value("size", "");
std::string output_format = j.value("output_format", "png");
int output_compression = j.value("output_compression", 100);
std::string ref_image_b64 = j.value("image", "");
std::string mask_image_b64 = j.value("mask", "");
if (prompt.empty()) {
res.status = 400;
res.set_content(R"({"error":"prompt required"})", "application/json");
return;
}
if (ref_image_b64.empty()) {
res.status = 400;
res.set_content(R"({"error":"image required"})", "application/json");
return;
}
int width = 512;
int height = 512;
if (!size.empty()) {
auto pos = size.find('x');
if (pos != std::string::npos) {
try {
width = std::stoi(size.substr(0, pos));
height = std::stoi(size.substr(pos + 1));
} catch (...) {
}
}
}
if (output_format != "png" && output_format != "jpeg") {
res.status = 400;
res.set_content(R"({"error":"invalid output_format, must be one of [png, jpeg]"})", "application/json");
return;
}
if (n <= 0)
n = 1;
if (n > 8)
n = 8;
if (output_compression > 100)
output_compression = 100;
if (output_compression < 0)
output_compression = 0;
// base64 -> raw image
std::vector<uint8_t> ref_image_bytes = base64_decode(ref_image_b64);
int img_w = width;
int img_h = height;
uint8_t* raw_pixels = load_image_from_memory(
reinterpret_cast<const char*>(ref_image_bytes.data()),
img_w, img_h,
width, height, 3);
sd_image_t ref_image;
ref_image.width = img_w;
ref_image.height = img_h;
ref_image.channel = 3;
ref_image.data = raw_pixels;
sd_image_t mask_image = {0};
if (!mask_image_b64.empty()) {
std::vector<uint8_t> mask_bytes = base64_decode(mask_image_b64);
int mask_w = width, mask_h = height;
uint8_t* mask_raw = load_image_from_memory(
reinterpret_cast<const char*>(mask_bytes.data()),
mask_w, mask_h,
width, height, 1);
mask_image.width = mask_w;
mask_image.height = mask_h;
mask_image.channel = 1;
mask_image.data = mask_raw;
} else {
mask_image.width = width;
mask_image.height = height;
mask_image.channel = 1;
mask_image.data = nullptr;
}
SDGenerationParams gen_params;
gen_params.prompt = prompt;
gen_params.width = width;
gen_params.height = height;
gen_params.batch_count = n;
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
std::vector<sd_image_t> ref_images = {ref_image};
std::vector<sd_image_t> pmid_images;
sd_img_gen_params_t img_gen_params = {
gen_params.lora_vec.data(),
static_cast<uint32_t>(gen_params.lora_vec.size()),
gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(),
gen_params.clip_skip,
init_image,
ref_images.data(),
(int)ref_images.size(),
gen_params.auto_resize_ref_image,
gen_params.increase_ref_index,
mask_image,
gen_params.width,
gen_params.height,
gen_params.sample_params,
gen_params.strength,
gen_params.seed,
gen_params.batch_count,
control_image,
gen_params.control_strength,
{
pmid_images.data(),
(int)pmid_images.size(),
gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength,
}, // pm_params
ctx_params.vae_tiling_params,
gen_params.easycache_params,
};
sd_image_t* results = nullptr;
int num_results = 0;
{
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}
json out;
out["created"] = iso_timestamp_now();
out["data"] = json::array();
out["output_format"] = output_format;
for (int i = 0; i < num_results; i++) {
if (results[i].data == nullptr)
continue;
auto image_bytes = write_image_to_vector(output_format == "jpeg" ? ImageFormat::JPEG : ImageFormat::PNG,
results[i].data,
results[i].width,
results[i].height,
results[i].channel,
output_compression);
std::string b64 = base64_encode(image_bytes);
json item;
item["b64_json"] = b64;
out["data"].push_back(item);
}
res.set_content(out.dump(), "application/json");
res.status = 200;
if (init_image.data) {
stbi_image_free(init_image.data);
}
if (mask_image.data) {
stbi_image_free(mask_image.data);
}
} catch (const std::exception& e) {
res.status = 500;
json err;
err["error"] = "server_error";
err["message"] = e.what();
res.set_content(err.dump(), "application/json");
}
});
printf("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
svr.listen(svr_params.listen_ip, svr_params.listen_port);