refactor: apply RAII ownership to examples (#1392)

This commit is contained in:
leejet 2026-04-06 20:33:46 +08:00 committed by GitHub
parent 7397ddaa86
commit 359eb8b8de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 510 additions and 358 deletions

View File

@ -17,6 +17,7 @@
#include "common/common.hpp" #include "common/common.hpp"
#include "common/media_io.h" #include "common/media_io.h"
#include "common/resource_owners.hpp"
#include "image_metadata.h" #include "image_metadata.h"
const char* previews_str[] = { const char* previews_str[] = {
@ -275,7 +276,7 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
} }
bool load_images_from_dir(const std::string dir, bool load_images_from_dir(const std::string dir,
std::vector<sd_image_t>& images, SDImageVec& images,
int expected_width = 0, int expected_width = 0,
int expected_height = 0, int expected_height = 0,
int max_image_num = 0, int max_image_num = 0,
@ -317,7 +318,7 @@ bool load_images_from_dir(const std::string dir,
3, 3,
image_buffer}); image_buffer});
if (max_image_num > 0 && images.size() >= max_image_num) { if (max_image_num > 0 && static_cast<int>(images.size()) >= max_image_num) {
break; break;
} }
} }
@ -555,38 +556,16 @@ int main(int argc, const char* argv[]) {
} }
bool vae_decode_only = true; bool vae_decode_only = true;
sd_image_t init_image = {0, 0, 3, nullptr}; SDImageOwner init_image({0, 0, 3, nullptr});
sd_image_t end_image = {0, 0, 3, nullptr}; SDImageOwner end_image({0, 0, 3, nullptr});
sd_image_t control_image = {0, 0, 3, nullptr}; SDImageOwner control_image({0, 0, 3, nullptr});
sd_image_t mask_image = {0, 0, 1, nullptr}; SDImageOwner mask_image({0, 0, 1, nullptr});
std::vector<sd_image_t> ref_images; SDImageVec ref_images;
std::vector<sd_image_t> pmid_images; SDImageVec pmid_images;
std::vector<sd_image_t> control_frames; SDImageVec control_frames;
auto release_all_resources = [&]() {
free(init_image.data);
free(end_image.data);
free(control_image.data);
free(mask_image.data);
for (auto image : ref_images) {
free(image.data);
image.data = nullptr;
}
ref_images.clear();
for (auto image : pmid_images) {
free(image.data);
image.data = nullptr;
}
pmid_images.clear();
for (auto image : control_frames) {
free(image.data);
image.data = nullptr;
}
control_frames.clear();
};
auto load_image_and_update_size = [&](const std::string& path, auto load_image_and_update_size = [&](const std::string& path,
sd_image_t& image, SDImageOwner& image,
bool resize_image = true, bool resize_image = true,
int expected_channel = 3) -> bool { int expected_channel = 3) -> bool {
int expected_width = 0; int expected_width = 0;
@ -596,13 +575,12 @@ int main(int argc, const char* argv[]) {
expected_height = gen_params.height; expected_height = gen_params.height;
} }
if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) { if (!load_sd_image_from_file(image.put(), path.c_str(), expected_width, expected_height, expected_channel)) {
LOG_ERROR("load image from '%s' failed", path.c_str()); LOG_ERROR("load image from '%s' failed", path.c_str());
release_all_resources();
return false; return false;
} }
gen_params.set_width_and_height_if_unset(image.width, image.height); gen_params.set_width_and_height_if_unset(image.get().width, image.get().height);
return true; return true;
}; };
@ -623,47 +601,46 @@ int main(int argc, const char* argv[]) {
if (gen_params.ref_image_paths.size() > 0) { if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false; vae_decode_only = false;
for (auto& path : gen_params.ref_image_paths) { for (auto& path : gen_params.ref_image_paths) {
sd_image_t ref_image = {0, 0, 3, nullptr}; SDImageOwner ref_image({0, 0, 3, nullptr});
if (!load_image_and_update_size(path, ref_image, false)) { if (!load_image_and_update_size(path, ref_image, false)) {
return 1; return 1;
} }
ref_images.push_back(ref_image); ref_images.push_back(std::move(ref_image));
} }
} }
if (gen_params.mask_image_path.size() > 0) { if (gen_params.mask_image_path.size() > 0) {
if (!load_sd_image_from_file(&mask_image, if (!load_sd_image_from_file(mask_image.put(),
gen_params.mask_image_path.c_str(), gen_params.mask_image_path.c_str(),
gen_params.get_resolved_width(), gen_params.get_resolved_width(),
gen_params.get_resolved_height(), gen_params.get_resolved_height(),
1)) { 1)) {
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str()); LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
release_all_resources();
return 1; return 1;
} }
} else { } else {
mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height()); sd_image_t generated_mask = {0, 0, 1, nullptr};
if (mask_image.data == nullptr) { generated_mask.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
if (generated_mask.data == nullptr) {
LOG_ERROR("malloc mask image failed"); LOG_ERROR("malloc mask image failed");
release_all_resources();
return 1; return 1;
} }
mask_image.width = gen_params.get_resolved_width(); generated_mask.width = gen_params.get_resolved_width();
mask_image.height = gen_params.get_resolved_height(); generated_mask.height = gen_params.get_resolved_height();
memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height()); memset(generated_mask.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
mask_image.reset(generated_mask);
} }
if (gen_params.control_image_path.size() > 0) { if (gen_params.control_image_path.size() > 0) {
if (!load_sd_image_from_file(&control_image, if (!load_sd_image_from_file(control_image.put(),
gen_params.control_image_path.c_str(), gen_params.control_image_path.c_str(),
gen_params.get_resolved_width(), gen_params.get_resolved_width(),
gen_params.get_resolved_height())) { gen_params.get_resolved_height())) {
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str()); LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
release_all_resources();
return 1; return 1;
} }
if (cli_params.canny_preprocess) { // apply preprocessor if (cli_params.canny_preprocess) { // apply preprocessor
preprocess_canny(control_image, preprocess_canny(control_image.get(),
0.08f, 0.08f,
0.08f, 0.08f,
0.8f, 0.8f,
@ -679,7 +656,6 @@ int main(int argc, const char* argv[]) {
gen_params.get_resolved_height(), gen_params.get_resolved_height(),
gen_params.video_frames, gen_params.video_frames,
cli_params.verbose)) { cli_params.verbose)) {
release_all_resources();
return 1; return 1;
} }
} }
@ -691,7 +667,6 @@ int main(int argc, const char* argv[]) {
0, 0,
0, 0,
cli_params.verbose)) { cli_params.verbose)) {
release_all_resources();
return 1; return 1;
} }
} }
@ -702,39 +677,30 @@ int main(int argc, const char* argv[]) {
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview); sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview);
sd_image_t* results = nullptr; SDImageVec results;
int num_results = 0; int num_results = 0;
if (cli_params.mode == UPSCALE) { if (cli_params.mode == UPSCALE) {
num_results = 1; num_results = 1;
results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t)); results.push_back(init_image.release());
if (results == nullptr) {
LOG_INFO("failed to allocate results array");
release_all_resources();
return 1;
}
results[0] = init_image;
init_image.data = nullptr;
} else { } else {
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params));
if (sd_ctx == nullptr) { if (sd_ctx == nullptr) {
LOG_INFO("new_sd_ctx_t failed"); LOG_INFO("new_sd_ctx_t failed");
release_all_resources();
return 1; return 1;
} }
if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) { if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx); gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get());
} }
if (gen_params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) { if (gen_params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx); gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get());
} }
if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) { if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method); gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx.get(), gen_params.sample_params.sample_method);
} }
if (cli_params.mode == IMG_GEN) { if (cli_params.mode == IMG_GEN) {
@ -744,19 +710,19 @@ int main(int argc, const char* argv[]) {
gen_params.prompt.c_str(), gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(), gen_params.negative_prompt.c_str(),
gen_params.clip_skip, gen_params.clip_skip,
init_image, init_image.get(),
ref_images.data(), ref_images.data(),
(int)ref_images.size(), (int)ref_images.size(),
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image, mask_image.get(),
gen_params.get_resolved_width(), gen_params.get_resolved_width(),
gen_params.get_resolved_height(), gen_params.get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
gen_params.batch_count, gen_params.batch_count,
control_image, control_image.get(),
gen_params.control_strength, gen_params.control_strength,
{ {
pmid_images.data(), pmid_images.data(),
@ -768,8 +734,8 @@ int main(int argc, const char* argv[]) {
gen_params.cache_params, gen_params.cache_params,
}; };
results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count; num_results = gen_params.batch_count;
results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results);
} else if (cli_params.mode == VID_GEN) { } else if (cli_params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = { sd_vid_gen_params_t vid_gen_params = {
gen_params.lora_vec.data(), gen_params.lora_vec.data(),
@ -777,8 +743,8 @@ int main(int argc, const char* argv[]) {
gen_params.prompt.c_str(), gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(), gen_params.negative_prompt.c_str(),
gen_params.clip_skip, gen_params.clip_skip,
init_image, init_image.get(),
end_image, end_image.get(),
control_frames.data(), control_frames.data(),
(int)control_frames.size(), (int)control_frames.size(),
gen_params.get_resolved_width(), gen_params.get_resolved_width(),
@ -794,25 +760,23 @@ int main(int argc, const char* argv[]) {
gen_params.cache_params, gen_params.cache_params,
}; };
results = generate_video(sd_ctx, &vid_gen_params, &num_results); sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results);
results.adopt(generated_video, num_results);
} }
if (results == nullptr) { if (!results) {
LOG_ERROR("generate failed"); LOG_ERROR("generate failed");
free_sd_ctx(sd_ctx);
return 1; return 1;
} }
free_sd_ctx(sd_ctx);
} }
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
if (ctx_params.esrgan_path.size() > 0 && gen_params.upscale_repeats > 0) { if (ctx_params.esrgan_path.size() > 0 && gen_params.upscale_repeats > 0) {
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(ctx_params.esrgan_path.c_str(), UpscalerCtxPtr upscaler_ctx(new_upscaler_ctx(ctx_params.esrgan_path.c_str(),
ctx_params.offload_params_to_cpu, ctx_params.offload_params_to_cpu,
ctx_params.diffusion_conv_direct, ctx_params.diffusion_conv_direct,
ctx_params.n_threads, ctx_params.n_threads,
gen_params.upscale_tile_size); gen_params.upscale_tile_size));
if (upscaler_ctx == nullptr) { if (upscaler_ctx == nullptr) {
LOG_ERROR("new_upscaler_ctx failed"); LOG_ERROR("new_upscaler_ctx failed");
@ -821,32 +785,24 @@ int main(int argc, const char* argv[]) {
if (results[i].data == nullptr) { if (results[i].data == nullptr) {
continue; continue;
} }
sd_image_t current_image = results[i]; SDImageOwner current_image(results[i]);
results[i] = {0, 0, 0, nullptr};
for (int u = 0; u < gen_params.upscale_repeats; ++u) { for (int u = 0; u < gen_params.upscale_repeats; ++u) {
sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor); SDImageOwner upscaled_image(upscale(upscaler_ctx.get(), current_image.get(), upscale_factor));
if (upscaled_image.data == nullptr) { if (upscaled_image.get().data == nullptr) {
LOG_ERROR("upscale failed"); LOG_ERROR("upscale failed");
break; break;
} }
free(current_image.data); current_image = std::move(upscaled_image);
current_image = upscaled_image;
} }
results[i] = current_image; // Set the final upscaled image as the result results[i] = current_image.release(); // Set the final upscaled image as the result
} }
} }
} }
if (!save_results(cli_params, ctx_params, gen_params, results, num_results)) { if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results)) {
return 1; return 1;
} }
for (int i = 0; i < num_results; i++) {
free(results[i].data);
results[i].data = nullptr;
}
free(results);
release_all_resources();
return 0; return 0;
} }

View File

@ -20,6 +20,7 @@ namespace fs = std::filesystem;
#endif // _WIN32 #endif // _WIN32
#include "log.h" #include "log.h"
#include "resource_owners.hpp"
#include "stable-diffusion.h" #include "stable-diffusion.h"
#define SAFE_STR(s) ((s) ? (s) : "") #define SAFE_STR(s) ((s) ? (s) : "")
@ -1751,8 +1752,8 @@ struct SDGenerationParams {
} }
std::string to_string() const { std::string to_string() const {
char* sample_params_str = sd_sample_params_to_str(&sample_params); FreeUniquePtr<char> sample_params_str(sd_sample_params_to_str(&sample_params));
char* high_noise_sample_params_str = sd_sample_params_to_str(&high_noise_sample_params); FreeUniquePtr<char> high_noise_sample_params_str(sd_sample_params_to_str(&high_noise_sample_params));
std::ostringstream lora_ss; std::ostringstream lora_ss;
lora_ss << "{\n"; lora_ss << "{\n";
@ -1801,9 +1802,9 @@ struct SDGenerationParams {
<< " pm_id_embed_path: \"" << pm_id_embed_path << "\",\n" << " pm_id_embed_path: \"" << pm_id_embed_path << "\",\n"
<< " pm_style_strength: " << pm_style_strength << ",\n" << " pm_style_strength: " << pm_style_strength << ",\n"
<< " skip_layers: " << vec_to_string(skip_layers) << ",\n" << " skip_layers: " << vec_to_string(skip_layers) << ",\n"
<< " sample_params: " << sample_params_str << ",\n" << " sample_params: " << SAFE_STR(sample_params_str.get()) << ",\n"
<< " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n" << " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n"
<< " high_noise_sample_params: " << high_noise_sample_params_str << ",\n" << " high_noise_sample_params: " << SAFE_STR(high_noise_sample_params_str.get()) << ",\n"
<< " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n" << " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n"
<< " cache_mode: \"" << cache_mode << "\",\n" << " cache_mode: \"" << cache_mode << "\",\n"
<< " cache_option: \"" << cache_option << "\",\n" << " cache_option: \"" << cache_option << "\",\n"
@ -1829,8 +1830,6 @@ struct SDGenerationParams {
<< vae_tiling_params.rel_size_x << ", " << vae_tiling_params.rel_size_x << ", "
<< vae_tiling_params.rel_size_y << " },\n" << vae_tiling_params.rel_size_y << " },\n"
<< "}"; << "}";
free(sample_params_str);
free(high_noise_sample_params_str);
return oss.str(); return oss.str();
} }
}; };

View File

@ -1,5 +1,7 @@
#include "log.h" #include "log.h"
#include <vector>
bool log_verbose = false; bool log_verbose = false;
bool log_color = false; bool log_color = false;
@ -34,17 +36,12 @@ void print_utf8(FILE* stream, const char* utf8) {
return; return;
} }
wchar_t* wbuf = (wchar_t*)malloc(wlen * sizeof(wchar_t)); std::vector<wchar_t> wbuf(static_cast<size_t>(wlen));
if (!wbuf) {
return;
}
MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wbuf, wlen); MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wbuf.data(), wlen);
DWORD written; DWORD written;
WriteConsoleW(h, wbuf, wlen - 1, &written, NULL); WriteConsoleW(h, wbuf.data(), wlen - 1, &written, NULL);
free(wbuf);
} else { } else {
DWORD written; DWORD written;
WriteFile(h, utf8, (DWORD)strlen(utf8), &written, NULL); WriteFile(h, utf8, (DWORD)strlen(utf8), &written, NULL);

View File

@ -1,5 +1,6 @@
#include "log.h" #include "log.h"
#include "media_io.h" #include "media_io.h"
#include "resource_owners.hpp"
#include <algorithm> #include <algorithm>
#include <cctype> #include <cctype>
@ -38,6 +39,63 @@
namespace fs = std::filesystem; namespace fs = std::filesystem;
namespace { namespace {
#ifdef SD_USE_WEBP
// Deleter functor: hands a non-null pointer to WebPFree; a null pointer is ignored.
struct WebPFreeDeleter {
    void operator()(void* buffer) const {
        if (buffer == nullptr) {
            return;
        }
        WebPFree(buffer);
    }
};
// Deleter functor: destroys a non-null WebPMux with WebPMuxDelete; a null pointer is ignored.
struct WebPMuxDeleter {
    void operator()(WebPMux* handle) const {
        if (handle == nullptr) {
            return;
        }
        WebPMuxDelete(handle);
    }
};
// Deleter functor: destroys a non-null WebPAnimEncoder with WebPAnimEncoderDelete;
// a null pointer is ignored.
struct WebPAnimEncoderDeleter {
    void operator()(WebPAnimEncoder* encoder) const {
        if (encoder == nullptr) {
            return;
        }
        WebPAnimEncoderDelete(encoder);
    }
};
// Scope guard around a WebPData: the constructor calls WebPDataInit on `data`
// and the destructor calls WebPDataClear on it.
//
// Copying is disabled: a copied guard would hold the same `bytes` pointer and
// both destructors would call WebPDataClear on it.
struct WebPDataGuard {
    WebPDataGuard() {
        WebPDataInit(&data);
    }
    ~WebPDataGuard() {
        WebPDataClear(&data);
    }

    WebPDataGuard(const WebPDataGuard&)            = delete;
    WebPDataGuard& operator=(const WebPDataGuard&) = delete;

    // The guarded value; pass `&data` to libwebp calls that fill it in.
    WebPData data;
};
// Scope guard around a WebPPicture: the constructor calls WebPPictureInit and
// records whether it returned non-zero; the destructor calls WebPPictureFree
// only when that initialization succeeded.
//
// Copying is disabled: a copied guard would call WebPPictureFree a second time
// on the same picture contents.
struct WebPPictureGuard {
    WebPPictureGuard()
        : initialized(WebPPictureInit(&picture) != 0) {
    }
    ~WebPPictureGuard() {
        if (initialized) {
            WebPPictureFree(&picture);
        }
    }

    WebPPictureGuard(const WebPPictureGuard&)            = delete;
    WebPPictureGuard& operator=(const WebPPictureGuard&) = delete;

    // The guarded picture; callers should check `initialized` before using it.
    WebPPicture picture;
    // True when WebPPictureInit returned a non-zero result in the constructor.
    bool initialized;
};
// Owning pointer to a uint8_t buffer, released through WebPFreeDeleter.
using WebPBufferPtr = std::unique_ptr<uint8_t, WebPFreeDeleter>;
// Owning pointer to a WebPMux, released through WebPMuxDeleter.
using WebPMuxPtr = std::unique_ptr<WebPMux, WebPMuxDeleter>;
// Owning pointer to a WebPAnimEncoder, released through WebPAnimEncoderDeleter.
using WebPAnimEncoderPtr = std::unique_ptr<WebPAnimEncoder, WebPAnimEncoderDeleter>;
#endif
bool read_binary_file_bytes(const char* path, std::vector<uint8_t>& data) { bool read_binary_file_bytes(const char* path, std::vector<uint8_t>& data) {
std::ifstream fin(fs::path(path), std::ios::binary); std::ifstream fin(fs::path(path), std::ios::binary);
if (!fin) { if (!fin) {
@ -158,27 +216,25 @@ uint8_t* decode_webp_image_to_buffer(const uint8_t* data,
if (expected_channel == 1) { if (expected_channel == 1) {
int decoded_width = width; int decoded_width = width;
int decoded_height = height; int decoded_height = height;
uint8_t* decoded = features.has_alpha WebPBufferPtr decoded(features.has_alpha
? WebPDecodeRGBA(data, size, &decoded_width, &decoded_height) ? WebPDecodeRGBA(data, size, &decoded_width, &decoded_height)
: WebPDecodeRGB(data, size, &decoded_width, &decoded_height); : WebPDecodeRGB(data, size, &decoded_width, &decoded_height));
if (decoded == nullptr) { if (decoded == nullptr) {
return nullptr; return nullptr;
} }
uint8_t* grayscale = (uint8_t*)malloc(pixel_count); FreeUniquePtr<uint8_t> grayscale((uint8_t*)malloc(pixel_count));
if (grayscale == nullptr) { if (grayscale == nullptr) {
WebPFree(decoded);
return nullptr; return nullptr;
} }
const int decoded_channels = features.has_alpha ? 4 : 3; const int decoded_channels = features.has_alpha ? 4 : 3;
for (size_t i = 0; i < pixel_count; ++i) { for (size_t i = 0; i < pixel_count; ++i) {
const uint8_t* src = decoded + i * decoded_channels; const uint8_t* src = decoded.get() + i * decoded_channels;
grayscale[i] = static_cast<uint8_t>((77 * src[0] + 150 * src[1] + 29 * src[2] + 128) >> 8); grayscale.get()[i] = static_cast<uint8_t>((77 * src[0] + 150 * src[1] + 29 * src[2] + 128) >> 8);
} }
WebPFree(decoded); return grayscale.release();
return grayscale;
} }
if (expected_channel != 3 && expected_channel != 4) { if (expected_channel != 3 && expected_channel != 4) {
@ -187,23 +243,21 @@ uint8_t* decode_webp_image_to_buffer(const uint8_t* data,
int decoded_width = width; int decoded_width = width;
int decoded_height = height; int decoded_height = height;
uint8_t* decoded = (expected_channel == 4) WebPBufferPtr decoded((expected_channel == 4)
? WebPDecodeRGBA(data, size, &decoded_width, &decoded_height) ? WebPDecodeRGBA(data, size, &decoded_width, &decoded_height)
: WebPDecodeRGB(data, size, &decoded_width, &decoded_height); : WebPDecodeRGB(data, size, &decoded_width, &decoded_height));
if (decoded == nullptr) { if (decoded == nullptr) {
return nullptr; return nullptr;
} }
const size_t out_size = pixel_count * static_cast<size_t>(expected_channel); const size_t out_size = pixel_count * static_cast<size_t>(expected_channel);
uint8_t* output = (uint8_t*)malloc(out_size); FreeUniquePtr<uint8_t> output((uint8_t*)malloc(out_size));
if (output == nullptr) { if (output == nullptr) {
WebPFree(decoded);
return nullptr; return nullptr;
} }
memcpy(output, decoded, out_size); memcpy(output.get(), decoded.get(), out_size);
WebPFree(decoded); return output.release();
return output;
} }
std::string build_webp_xmp_packet(const std::string& parameters) { std::string build_webp_xmp_packet(const std::string& parameters) {
@ -255,30 +309,29 @@ bool encode_webp_image_to_vector(const uint8_t* image,
return false; return false;
} }
uint8_t* encoded = nullptr; uint8_t* encoded_raw = nullptr;
size_t encoded_size = (input_channels == 4) size_t encoded_size = (input_channels == 4)
? WebPEncodeRGBA(input_image, width, height, width * input_channels, static_cast<float>(quality), &encoded) ? WebPEncodeRGBA(input_image, width, height, width * input_channels, static_cast<float>(quality), &encoded_raw)
: WebPEncodeRGB(input_image, width, height, width * input_channels, static_cast<float>(quality), &encoded); : WebPEncodeRGB(input_image, width, height, width * input_channels, static_cast<float>(quality), &encoded_raw);
WebPBufferPtr encoded(encoded_raw);
if (encoded == nullptr || encoded_size == 0) { if (encoded == nullptr || encoded_size == 0) {
return false; return false;
} }
out.assign(encoded, encoded + encoded_size); out.assign(encoded.get(), encoded.get() + encoded_size);
WebPFree(encoded);
if (parameters.empty()) { if (parameters.empty()) {
return true; return true;
} }
WebPData image_data; WebPData image_data;
WebPData assembled_data;
WebPDataInit(&image_data); WebPDataInit(&image_data);
WebPDataInit(&assembled_data); WebPDataGuard assembled_data;
image_data.bytes = out.data(); image_data.bytes = out.data();
image_data.size = out.size(); image_data.size = out.size();
WebPMux* mux = WebPMuxNew(); WebPMuxPtr mux(WebPMuxNew());
if (mux == nullptr) { if (mux == nullptr) {
return false; return false;
} }
@ -289,16 +342,14 @@ bool encode_webp_image_to_vector(const uint8_t* image,
xmp_data.bytes = reinterpret_cast<const uint8_t*>(xmp_packet.data()); xmp_data.bytes = reinterpret_cast<const uint8_t*>(xmp_packet.data());
xmp_data.size = xmp_packet.size(); xmp_data.size = xmp_packet.size();
const bool ok = WebPMuxSetImage(mux, &image_data, 1) == WEBP_MUX_OK && const bool ok = WebPMuxSetImage(mux.get(), &image_data, 1) == WEBP_MUX_OK &&
WebPMuxSetChunk(mux, "XMP ", &xmp_data, 1) == WEBP_MUX_OK && WebPMuxSetChunk(mux.get(), "XMP ", &xmp_data, 1) == WEBP_MUX_OK &&
WebPMuxAssemble(mux, &assembled_data) == WEBP_MUX_OK; WebPMuxAssemble(mux.get(), &assembled_data.data) == WEBP_MUX_OK;
if (ok) { if (ok) {
out.assign(assembled_data.bytes, assembled_data.bytes + assembled_data.size); out.assign(assembled_data.data.bytes, assembled_data.data.bytes + assembled_data.data.size);
} }
WebPDataClear(&assembled_data);
WebPMuxDelete(mux);
return ok; return ok;
} }
@ -382,19 +433,19 @@ uint8_t* load_image_common(bool from_memory,
int expected_height, int expected_height,
int expected_channel) { int expected_channel) {
const char* image_path; const char* image_path;
uint8_t* image_buffer = nullptr; FreeUniquePtr<uint8_t> image_buffer;
int source_channel_count = 0; int source_channel_count = 0;
#ifdef SD_USE_WEBP #ifdef SD_USE_WEBP
if (from_memory) { if (from_memory) {
image_path = "memory"; image_path = "memory";
if (len > 0 && is_webp_signature(reinterpret_cast<const uint8_t*>(image_path_or_bytes), static_cast<size_t>(len))) { if (len > 0 && is_webp_signature(reinterpret_cast<const uint8_t*>(image_path_or_bytes), static_cast<size_t>(len))) {
image_buffer = decode_webp_image_to_buffer(reinterpret_cast<const uint8_t*>(image_path_or_bytes), image_buffer.reset(decode_webp_image_to_buffer(reinterpret_cast<const uint8_t*>(image_path_or_bytes),
static_cast<size_t>(len), static_cast<size_t>(len),
width, width,
height, height,
expected_channel, expected_channel,
source_channel_count); source_channel_count));
} }
} else { } else {
image_path = image_path_or_bytes; image_path = image_path_or_bytes;
@ -408,12 +459,12 @@ uint8_t* load_image_common(bool from_memory,
LOG_ERROR("load image from '%s' failed", image_path_or_bytes); LOG_ERROR("load image from '%s' failed", image_path_or_bytes);
return nullptr; return nullptr;
} }
image_buffer = decode_webp_image_to_buffer(file_bytes.data(), image_buffer.reset(decode_webp_image_to_buffer(file_bytes.data(),
file_bytes.size(), file_bytes.size(),
width, width,
height, height,
expected_channel, expected_channel,
source_channel_count); source_channel_count));
} }
} }
#endif #endif
@ -422,14 +473,14 @@ uint8_t* load_image_common(bool from_memory,
image_path = "memory"; image_path = "memory";
if (image_buffer == nullptr) { if (image_buffer == nullptr) {
int c = 0; int c = 0;
image_buffer = (uint8_t*)stbi_load_from_memory((const stbi_uc*)image_path_or_bytes, len, &width, &height, &c, expected_channel); image_buffer.reset((uint8_t*)stbi_load_from_memory((const stbi_uc*)image_path_or_bytes, len, &width, &height, &c, expected_channel));
source_channel_count = c; source_channel_count = c;
} }
} else { } else {
image_path = image_path_or_bytes; image_path = image_path_or_bytes;
if (image_buffer == nullptr) { if (image_buffer == nullptr) {
int c = 0; int c = 0;
image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel); image_buffer.reset((uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel));
source_channel_count = c; source_channel_count = c;
} }
} }
@ -444,17 +495,14 @@ uint8_t* load_image_common(bool from_memory,
expected_channel, expected_channel,
source_channel_count, source_channel_count,
image_path); image_path);
free(image_buffer);
return nullptr; return nullptr;
} }
if (width <= 0) { if (width <= 0) {
LOG_ERROR("error: the width of image must be greater than 0, image_path = %s", image_path); LOG_ERROR("error: the width of image must be greater than 0, image_path = %s", image_path);
free(image_buffer);
return nullptr; return nullptr;
} }
if (height <= 0) { if (height <= 0) {
LOG_ERROR("error: the height of image must be greater than 0, image_path = %s", image_path); LOG_ERROR("error: the height of image must be greater than 0, image_path = %s", image_path);
free(image_buffer);
return nullptr; return nullptr;
} }
@ -475,43 +523,39 @@ uint8_t* load_image_common(bool from_memory,
if (crop_x != 0 || crop_y != 0) { if (crop_x != 0 || crop_y != 0) {
LOG_INFO("crop input image from %dx%d to %dx%d, image_path = %s", width, height, crop_w, crop_h, image_path); LOG_INFO("crop input image from %dx%d to %dx%d, image_path = %s", width, height, crop_w, crop_h, image_path);
uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel); FreeUniquePtr<uint8_t> cropped_image_buffer((uint8_t*)malloc(crop_w * crop_h * expected_channel));
if (cropped_image_buffer == nullptr) { if (cropped_image_buffer == nullptr) {
LOG_ERROR("error: allocate memory for crop\n"); LOG_ERROR("error: allocate memory for crop\n");
free(image_buffer);
return nullptr; return nullptr;
} }
for (int row = 0; row < crop_h; row++) { for (int row = 0; row < crop_h; row++) {
uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel; uint8_t* src = image_buffer.get() + ((crop_y + row) * width + crop_x) * expected_channel;
uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel; uint8_t* dst = cropped_image_buffer.get() + (row * crop_w) * expected_channel;
memcpy(dst, src, crop_w * expected_channel); memcpy(dst, src, crop_w * expected_channel);
} }
width = crop_w; width = crop_w;
height = crop_h; height = crop_h;
free(image_buffer); image_buffer = std::move(cropped_image_buffer);
image_buffer = cropped_image_buffer;
} }
LOG_INFO("resize input image from %dx%d to %dx%d", width, height, expected_width, expected_height); LOG_INFO("resize input image from %dx%d to %dx%d", width, height, expected_width, expected_height);
uint8_t* resized_image_buffer = (uint8_t*)malloc(expected_height * expected_width * expected_channel); FreeUniquePtr<uint8_t> resized_image_buffer((uint8_t*)malloc(expected_height * expected_width * expected_channel));
if (resized_image_buffer == nullptr) { if (resized_image_buffer == nullptr) {
LOG_ERROR("error: allocate memory for resize input image\n"); LOG_ERROR("error: allocate memory for resize input image\n");
free(image_buffer);
return nullptr; return nullptr;
} }
stbir_resize(image_buffer, width, height, 0, stbir_resize(image_buffer.get(), width, height, 0,
resized_image_buffer, expected_width, expected_height, 0, STBIR_TYPE_UINT8, resized_image_buffer.get(), expected_width, expected_height, 0, STBIR_TYPE_UINT8,
expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0, expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0,
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
STBIR_FILTER_BOX, STBIR_FILTER_BOX, STBIR_FILTER_BOX, STBIR_FILTER_BOX,
STBIR_COLORSPACE_SRGB, nullptr); STBIR_COLORSPACE_SRGB, nullptr);
width = expected_width; width = expected_width;
height = expected_height; height = expected_height;
free(image_buffer); image_buffer = std::move(resized_image_buffer);
image_buffer = resized_image_buffer;
} }
return image_buffer; return image_buffer.release();
} }
typedef struct { typedef struct {
@ -662,18 +706,18 @@ int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int
return -1; return -1;
} }
FILE* f = fopen(filename, "wb"); FilePtr file(fopen(filename, "wb"));
if (!f) { if (!file) {
perror("Error opening file for writing"); perror("Error opening file for writing");
return -1; return -1;
} }
FILE* f = file.get();
uint32_t width = images[0].width; uint32_t width = images[0].width;
uint32_t height = images[0].height; uint32_t height = images[0].height;
uint32_t channels = images[0].channel; uint32_t channels = images[0].channel;
if (channels != 3 && channels != 4) { if (channels != 3 && channels != 4) {
fprintf(stderr, "Error: Unsupported channel count: %u\n", channels); fprintf(stderr, "Error: Unsupported channel count: %u\n", channels);
fclose(f);
return -1; return -1;
} }
@ -746,41 +790,32 @@ int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int
write_u32_le(f, 0); write_u32_le(f, 0);
fwrite("movi", 4, 1, f); fwrite("movi", 4, 1, f);
avi_index_entry* index = (avi_index_entry*)malloc(sizeof(avi_index_entry) * num_images); std::vector<avi_index_entry> index(static_cast<size_t>(num_images));
if (!index) { std::vector<uint8_t> jpeg_data;
fclose(f);
for (int i = 0; i < num_images; i++) {
jpeg_data.clear();
auto write_to_buf = [](void* context, void* data, int size) {
auto* buffer = reinterpret_cast<std::vector<uint8_t>*>(context);
const uint8_t* src = reinterpret_cast<const uint8_t*>(data);
buffer->insert(buffer->end(), src, src + size);
};
if (!stbi_write_jpg_to_func(write_to_buf, &jpeg_data, images[i].width, images[i].height, channels, images[i].data, quality)) {
fprintf(stderr, "Error: Failed to encode JPEG frame.\n");
return -1; return -1;
} }
struct {
uint8_t* buf;
size_t size;
} jpeg_data;
for (int i = 0; i < num_images; i++) {
jpeg_data.buf = nullptr;
jpeg_data.size = 0;
auto write_to_buf = [](void* context, void* data, int size) {
auto jd = (decltype(jpeg_data)*)context;
jd->buf = (uint8_t*)realloc(jd->buf, jd->size + size);
memcpy(jd->buf + jd->size, data, size);
jd->size += size;
};
stbi_write_jpg_to_func(write_to_buf, &jpeg_data, images[i].width, images[i].height, channels, images[i].data, quality);
fwrite("00dc", 4, 1, f); fwrite("00dc", 4, 1, f);
write_u32_le(f, (uint32_t)jpeg_data.size); write_u32_le(f, (uint32_t)jpeg_data.size());
index[i].offset = ftell(f) - 8; index[i].offset = ftell(f) - 8;
index[i].size = (uint32_t)jpeg_data.size; index[i].size = (uint32_t)jpeg_data.size();
fwrite(jpeg_data.buf, 1, jpeg_data.size, f); fwrite(jpeg_data.data(), 1, jpeg_data.size(), f);
if (jpeg_data.size % 2) { if (jpeg_data.size() % 2) {
fputc(0, f); fputc(0, f);
} }
free(jpeg_data.buf);
} }
long cur_pos = ftell(f); long cur_pos = ftell(f);
@ -804,9 +839,6 @@ int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int
write_u32_le(f, file_size); write_u32_le(f, file_size);
fseek(f, cur_pos, SEEK_SET); fseek(f, cur_pos, SEEK_SET);
fclose(f);
free(index);
return 0; return 0;
} }
@ -847,7 +879,7 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images
return -1; return -1;
} }
WebPAnimEncoder* enc = WebPAnimEncoderNew(width, height, &anim_options); WebPAnimEncoderPtr enc(WebPAnimEncoderNew(width, height, &anim_options));
if (enc == nullptr) { if (enc == nullptr) {
fprintf(stderr, "Error: Could not create WebPAnimEncoder object.\n"); fprintf(stderr, "Error: Could not create WebPAnimEncoder object.\n");
return -1; return -1;
@ -855,23 +887,22 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images
const int frame_duration_ms = std::max(1, static_cast<int>(std::lround(1000.0 / static_cast<double>(fps)))); const int frame_duration_ms = std::max(1, static_cast<int>(std::lround(1000.0 / static_cast<double>(fps))));
int timestamp_ms = 0; int timestamp_ms = 0;
int ret = -1;
for (int i = 0; i < num_images; ++i) { for (int i = 0; i < num_images; ++i) {
const sd_image_t& image = images[i]; const sd_image_t& image = images[i];
if (static_cast<int>(image.width) != width || static_cast<int>(image.height) != height) { if (static_cast<int>(image.width) != width || static_cast<int>(image.height) != height) {
fprintf(stderr, "Error: Frame dimensions do not match.\n"); fprintf(stderr, "Error: Frame dimensions do not match.\n");
goto cleanup; return -1;
} }
WebPPicture picture; WebPPictureGuard picture;
if (!WebPPictureInit(&picture)) { if (!picture.initialized) {
fprintf(stderr, "Error: Failed to initialize WebPPicture.\n"); fprintf(stderr, "Error: Failed to initialize WebPPicture.\n");
goto cleanup; return -1;
} }
picture.use_argb = 1; picture.picture.use_argb = 1;
picture.width = width; picture.picture.width = width;
picture.height = height; picture.picture.height = height;
bool picture_ok = false; bool picture_ok = false;
std::vector<uint8_t> rgb_buffer; std::vector<uint8_t> rgb_buffer;
@ -882,64 +913,48 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images
rgb_buffer[p * 3 + 1] = image.data[p]; rgb_buffer[p * 3 + 1] = image.data[p];
rgb_buffer[p * 3 + 2] = image.data[p]; rgb_buffer[p * 3 + 2] = image.data[p];
} }
picture_ok = WebPPictureImportRGB(&picture, rgb_buffer.data(), width * 3) != 0; picture_ok = WebPPictureImportRGB(&picture.picture, rgb_buffer.data(), width * 3) != 0;
} else if (image.channel == 4) { } else if (image.channel == 4) {
picture_ok = WebPPictureImportRGBA(&picture, image.data, width * 4) != 0; picture_ok = WebPPictureImportRGBA(&picture.picture, image.data, width * 4) != 0;
} else { } else {
picture_ok = WebPPictureImportRGB(&picture, image.data, width * 3) != 0; picture_ok = WebPPictureImportRGB(&picture.picture, image.data, width * 3) != 0;
} }
if (!picture_ok) { if (!picture_ok) {
fprintf(stderr, "Error: Failed to import frame into WebPPicture.\n"); fprintf(stderr, "Error: Failed to import frame into WebPPicture.\n");
WebPPictureFree(&picture); return -1;
goto cleanup;
} }
if (!WebPAnimEncoderAdd(enc, &picture, timestamp_ms, &config)) { if (!WebPAnimEncoderAdd(enc.get(), &picture.picture, timestamp_ms, &config)) {
fprintf(stderr, "Error: Failed to add frame to animated WebP: %s\n", WebPAnimEncoderGetError(enc)); fprintf(stderr, "Error: Failed to add frame to animated WebP: %s\n", WebPAnimEncoderGetError(enc.get()));
WebPPictureFree(&picture); return -1;
goto cleanup;
} }
WebPPictureFree(&picture);
timestamp_ms += frame_duration_ms; timestamp_ms += frame_duration_ms;
} }
if (!WebPAnimEncoderAdd(enc, nullptr, timestamp_ms, nullptr)) { if (!WebPAnimEncoderAdd(enc.get(), nullptr, timestamp_ms, nullptr)) {
fprintf(stderr, "Error: Failed to finalize animated WebP frames: %s\n", WebPAnimEncoderGetError(enc)); fprintf(stderr, "Error: Failed to finalize animated WebP frames: %s\n", WebPAnimEncoderGetError(enc.get()));
goto cleanup; return -1;
} }
{ WebPDataGuard webp_data;
WebPData webp_data; if (!WebPAnimEncoderAssemble(enc.get(), &webp_data.data)) {
WebPDataInit(&webp_data); fprintf(stderr, "Error: Failed to assemble animated WebP: %s\n", WebPAnimEncoderGetError(enc.get()));
if (!WebPAnimEncoderAssemble(enc, &webp_data)) { return -1;
fprintf(stderr, "Error: Failed to assemble animated WebP: %s\n", WebPAnimEncoderGetError(enc));
WebPDataClear(&webp_data);
goto cleanup;
} }
FILE* f = fopen(filename, "wb"); FilePtr f(fopen(filename, "wb"));
if (!f) { if (!f) {
perror("Error opening file for writing"); perror("Error opening file for writing");
WebPDataClear(&webp_data); return -1;
goto cleanup;
} }
if (webp_data.size > 0 && fwrite(webp_data.bytes, 1, webp_data.size, f) != webp_data.size) { if (webp_data.data.size > 0 && fwrite(webp_data.data.bytes, 1, webp_data.data.size, f.get()) != webp_data.data.size) {
fprintf(stderr, "Error: Failed to write animated WebP file.\n"); fprintf(stderr, "Error: Failed to write animated WebP file.\n");
fclose(f); return -1;
WebPDataClear(&webp_data);
goto cleanup;
}
fclose(f);
WebPDataClear(&webp_data);
} }
ret = 0; return 0;
cleanup:
WebPAnimEncoderDelete(enc);
return ret;
} }
#endif #endif

View File

@ -0,0 +1,207 @@
#ifndef __EXAMPLE_RESOURCE_OWNERS_H__
#define __EXAMPLE_RESOURCE_OWNERS_H__
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>
#include "stable-diffusion.h"
// Deleter functor for smart pointers: passes the managed pointer to free().
struct FreeDeleter {
    void operator()(void* block) const {
        std::free(block);
    }
};
// Deleter functor for FILE handles: closes the stream with fclose(),
// doing nothing when handed a null pointer.
struct FileCloser {
    void operator()(FILE* stream) const {
        if (stream == nullptr) {
            return;
        }
        fclose(stream);
    }
};
// Deleter functor for sd_ctx_t handles: forwards non-null contexts to
// free_sd_ctx() and ignores null ones.
struct SDCtxDeleter {
    void operator()(sd_ctx_t* context) const {
        if (context == nullptr) {
            return;
        }
        free_sd_ctx(context);
    }
};
// Deleter functor for upscaler_ctx_t handles: forwards non-null contexts to
// free_upscaler_ctx() and ignores null ones.
struct UpscalerCtxDeleter {
    void operator()(upscaler_ctx_t* context) const {
        if (context == nullptr) {
            return;
        }
        free_upscaler_ctx(context);
    }
};
// Owning-pointer aliases: each pairs std::unique_ptr with one of the deleter
// functors so the resource is released when the pointer goes out of scope.

// Owning pointer to a T that is released through FreeDeleter.
template <typename T>
using FreeUniquePtr = std::unique_ptr<T, FreeDeleter>;
// Owning FILE handle that is released through FileCloser.
using FilePtr = std::unique_ptr<FILE, FileCloser>;
// Owning sd_ctx_t handle that is released through SDCtxDeleter.
using SDCtxPtr = std::unique_ptr<sd_ctx_t, SDCtxDeleter>;
// Owning upscaler_ctx_t handle that is released through UpscalerCtxDeleter.
using UpscalerCtxPtr = std::unique_ptr<upscaler_ctx_t, UpscalerCtxDeleter>;
class SDImageOwner {
public:
SDImageOwner() = default;
explicit SDImageOwner(sd_image_t image)
: image_(image) {
}
SDImageOwner(const SDImageOwner&) = delete;
SDImageOwner& operator=(const SDImageOwner&) = delete;
SDImageOwner(SDImageOwner&& other) noexcept
: image_(other.release()) {
}
SDImageOwner& operator=(SDImageOwner&& other) noexcept {
if (this != &other) {
reset();
image_ = other.release();
}
return *this;
}
~SDImageOwner() {
reset();
}
sd_image_t* put() {
if (image_.data != nullptr) {
free(image_.data);
image_.data = nullptr;
}
image_.width = 0;
image_.height = 0;
return &image_;
}
sd_image_t& get() {
return image_;
}
const sd_image_t& get() const {
return image_;
}
sd_image_t release() {
sd_image_t image = image_;
image_ = {0, 0, 0, nullptr};
return image;
}
void reset(sd_image_t image = {0, 0, 0, nullptr}) {
if (image_.data != nullptr) {
free(image_.data);
}
image_ = image;
}
private:
sd_image_t image_ = {0, 0, 0, nullptr};
};
class SDImageVec {
public:
SDImageVec() = default;
SDImageVec(const SDImageVec&) = delete;
SDImageVec& operator=(const SDImageVec&) = delete;
SDImageVec(SDImageVec&& other) noexcept
: images_(std::move(other.images_)) {
}
SDImageVec& operator=(SDImageVec&& other) noexcept {
if (this != &other) {
clear();
images_ = std::move(other.images_);
}
return *this;
}
~SDImageVec() {
clear();
}
void push_back(sd_image_t image) {
images_.push_back(image);
}
void push_back(SDImageOwner&& image) {
images_.push_back(image.release());
}
void reserve(size_t count) {
images_.reserve(count);
}
void adopt(sd_image_t* images, int count) {
clear();
if (images == nullptr || count <= 0) {
free(images);
return;
}
images_.reserve(static_cast<size_t>(count));
for (int i = 0; i < count; ++i) {
images_.push_back(images[i]);
}
free(images);
}
size_t size() const {
return images_.size();
}
bool empty() const {
return images_.empty();
}
explicit operator bool() const {
return !images_.empty();
}
sd_image_t* data() {
return images_.data();
}
const sd_image_t* data() const {
return images_.data();
}
sd_image_t& operator[](size_t index) {
return images_[index];
}
const sd_image_t& operator[](size_t index) const {
return images_[index];
}
std::vector<sd_image_t>& raw() {
return images_;
}
const std::vector<sd_image_t>& raw() const {
return images_;
}
void clear() {
for (sd_image_t& image : images_) {
free(image.data);
image.data = nullptr;
}
images_.clear();
}
private:
std::vector<sd_image_t> images_;
};
#endif // __EXAMPLE_RESOURCE_OWNERS_H__

View File

@ -13,6 +13,7 @@
#include "common/common.hpp" #include "common/common.hpp"
#include "common/media_io.h" #include "common/media_io.h"
#include "common/resource_owners.hpp"
#ifdef HAVE_INDEX_HTML #ifdef HAVE_INDEX_HTML
#include "frontend/dist/gen_index_html.h" #include "frontend/dist/gen_index_html.h"
@ -286,18 +287,6 @@ std::string get_lora_full_path(ServerRuntime& rt, const std::string& path) {
return (it != rt.lora_cache->end()) ? it->fullpath : ""; return (it != rt.lora_cache->end()) ? it->fullpath : "";
} }
void free_results(sd_image_t* result_images, int num_results) {
if (result_images) {
for (int i = 0; i < num_results; ++i) {
if (result_images[i].data) {
free(result_images[i].data);
result_images[i].data = nullptr;
}
}
}
free(result_images);
}
void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) { void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) {
const std::string serve_html_path = svr_params.serve_html_path; const std::string serve_html_path = svr_params.serve_html_path;
svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) { svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) {
@ -405,10 +394,10 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
LOG_DEBUG("%s\n", gen_params.to_string().c_str()); LOG_DEBUG("%s\n", gen_params.to_string().c_str());
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; SDImageOwner init_image({(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr});
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; SDImageOwner control_image({(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr});
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr}; SDImageOwner mask_image({(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr});
std::vector<sd_image_t> pmid_images; SDImageVec pmid_images;
sd_img_gen_params_t img_gen_params = { sd_img_gen_params_t img_gen_params = {
gen_params.lora_vec.data(), gen_params.lora_vec.data(),
@ -416,19 +405,19 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
gen_params.prompt.c_str(), gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(), gen_params.negative_prompt.c_str(),
gen_params.clip_skip, gen_params.clip_skip,
init_image, init_image.get(),
nullptr, nullptr,
0, 0,
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image, mask_image.get(),
gen_params.width, gen_params.width,
gen_params.height, gen_params.height,
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
gen_params.batch_count, gen_params.batch_count,
control_image, control_image.get(),
gen_params.control_strength, gen_params.control_strength,
{ {
pmid_images.data(), pmid_images.data(),
@ -440,13 +429,19 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
gen_params.cache_params, gen_params.cache_params,
}; };
sd_image_t* results = nullptr; SDImageVec results;
int num_results = 0; int num_results = 0;
{ {
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex); std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count; num_results = gen_params.batch_count;
results.adopt(generate_image(runtime->sd_ctx, &img_gen_params), num_results);
}
if (!results) {
res.status = 500;
res.set_content(R"({"error":"generate failed"})", "application/json");
return;
} }
for (int i = 0; i < num_results; i++) { for (int i = 0; i < num_results; i++) {
@ -477,8 +472,6 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
item["b64_json"] = b64; item["b64_json"] = b64;
out["data"].push_back(item); out["data"].push_back(item);
} }
free_results(results, num_results);
res.set_content(out.dump(), "application/json"); res.set_content(out.dump(), "application/json");
res.status = 200; res.status = 200;
@ -599,9 +592,9 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
LOG_DEBUG("%s\n", gen_params.to_string().c_str()); LOG_DEBUG("%s\n", gen_params.to_string().c_str());
sd_image_t init_image = {0, 0, 3, nullptr}; SDImageOwner init_image({0, 0, 3, nullptr});
sd_image_t control_image = {0, 0, 3, nullptr}; SDImageOwner control_image({0, 0, 3, nullptr});
std::vector<sd_image_t> pmid_images; SDImageVec pmid_images;
auto get_resolved_width = [&gen_params, runtime]() -> int { auto get_resolved_width = [&gen_params, runtime]() -> int {
if (gen_params.width > 0) if (gen_params.width > 0)
@ -618,7 +611,7 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
return 512; return 512;
}; };
std::vector<sd_image_t> ref_images; SDImageVec ref_images;
ref_images.reserve(images_bytes.size()); ref_images.reserve(images_bytes.size());
for (auto& bytes : images_bytes) { for (auto& bytes : images_bytes) {
int img_w; int img_w;
@ -634,12 +627,12 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
continue; continue;
} }
sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels}; SDImageOwner img({(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels});
gen_params.set_width_and_height_if_unset(img.width, img.height); gen_params.set_width_and_height_if_unset(img.get().width, img.get().height);
ref_images.push_back(img); ref_images.push_back(std::move(img));
} }
sd_image_t mask_image = {0}; SDImageOwner mask_image({0, 0, 1, nullptr});
if (!mask_bytes.empty()) { if (!mask_bytes.empty()) {
int expected_width = 0; int expected_width = 0;
int expected_height = 0; int expected_height = 0;
@ -655,13 +648,10 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
static_cast<int>(mask_bytes.size()), static_cast<int>(mask_bytes.size()),
mask_w, mask_h, mask_w, mask_h,
expected_width, expected_height, 1); expected_width, expected_height, 1);
mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw}; mask_image.reset({(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw});
gen_params.set_width_and_height_if_unset(mask_image.width, mask_image.height); gen_params.set_width_and_height_if_unset(mask_image.get().width, mask_image.get().height);
} else { } else {
mask_image.width = get_resolved_width(); mask_image.reset({(uint32_t)get_resolved_width(), (uint32_t)get_resolved_height(), 1, nullptr});
mask_image.height = get_resolved_height();
mask_image.channel = 1;
mask_image.data = nullptr;
} }
sd_img_gen_params_t img_gen_params = { sd_img_gen_params_t img_gen_params = {
@ -670,19 +660,19 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
gen_params.prompt.c_str(), gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(), gen_params.negative_prompt.c_str(),
gen_params.clip_skip, gen_params.clip_skip,
init_image, init_image.get(),
ref_images.data(), ref_images.data(),
(int)ref_images.size(), (int)ref_images.size(),
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image, mask_image.get(),
get_resolved_width(), get_resolved_width(),
get_resolved_height(), get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
gen_params.batch_count, gen_params.batch_count,
control_image, control_image.get(),
gen_params.control_strength, gen_params.control_strength,
{ {
pmid_images.data(), pmid_images.data(),
@ -694,13 +684,19 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
gen_params.cache_params, gen_params.cache_params,
}; };
sd_image_t* results = nullptr; SDImageVec results;
int num_results = 0; int num_results = 0;
{ {
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex); std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count; num_results = gen_params.batch_count;
results.adopt(generate_image(runtime->sd_ctx, &img_gen_params), num_results);
}
if (!results) {
res.status = 500;
res.set_content(R"({"error":"generate failed"})", "application/json");
return;
} }
json out; json out;
@ -730,20 +726,8 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
item["b64_json"] = b64; item["b64_json"] = b64;
out["data"].push_back(item); out["data"].push_back(item);
} }
free_results(results, num_results);
res.set_content(out.dump(), "application/json"); res.set_content(out.dump(), "application/json");
res.status = 200; res.status = 200;
if (init_image.data) {
free(init_image.data);
}
if (mask_image.data) {
free(mask_image.data);
}
for (auto ref_image : ref_images) {
free(ref_image.data);
}
} catch (const std::exception& e) { } catch (const std::exception& e) {
res.status = 500; res.status = 500;
json err; json err;
@ -892,12 +876,11 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
LOG_DEBUG("%s\n", gen_params.to_string().c_str()); LOG_DEBUG("%s\n", gen_params.to_string().c_str());
sd_image_t init_image = {0, 0, 3, nullptr}; SDImageOwner init_image({0, 0, 3, nullptr});
sd_image_t control_image = {0, 0, 3, nullptr}; SDImageOwner control_image({0, 0, 3, nullptr});
sd_image_t mask_image = {0, 0, 1, nullptr}; SDImageOwner mask_image({0, 0, 1, nullptr});
std::vector<uint8_t> mask_data; SDImageVec pmid_images;
std::vector<sd_image_t> pmid_images; SDImageVec ref_images;
std::vector<sd_image_t> ref_images;
auto get_resolved_width = [&gen_params, runtime]() -> int { auto get_resolved_width = [&gen_params, runtime]() -> int {
if (gen_params.width > 0) if (gen_params.width > 0)
@ -914,7 +897,7 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
return 512; return 512;
}; };
auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool { auto decode_image = [&gen_params](SDImageOwner& image, std::string encoded) -> bool {
auto comma_pos = encoded.find(','); auto comma_pos = encoded.find(',');
if (comma_pos != std::string::npos) { if (comma_pos != std::string::npos) {
encoded = encoded.substr(comma_pos + 1); encoded = encoded.substr(comma_pos + 1);
@ -933,10 +916,10 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
uint8_t* raw_data = load_image_from_memory( uint8_t* raw_data = load_image_from_memory(
(const char*)img_data.data(), (int)img_data.size(), (const char*)img_data.data(), (int)img_data.size(),
img_w, img_h, img_w, img_h,
expected_width, expected_height, image.channel); expected_width, expected_height, image.get().channel);
if (raw_data) { if (raw_data) {
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data}; image.reset({(uint32_t)img_w, (uint32_t)img_h, image.get().channel, raw_data});
gen_params.set_width_and_height_if_unset(image.width, image.height); gen_params.set_width_and_height_if_unset(image.get().width, image.get().height);
return true; return true;
} }
} }
@ -953,19 +936,21 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
std::string encoded = j["mask"].get<std::string>(); std::string encoded = j["mask"].get<std::string>();
decode_image(mask_image, encoded); decode_image(mask_image, encoded);
bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0; bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0;
if (inpainting_mask_invert && mask_image.data != nullptr) { if (inpainting_mask_invert && mask_image.get().data != nullptr) {
for (uint32_t i = 0; i < mask_image.width * mask_image.height; i++) { for (uint32_t i = 0; i < mask_image.get().width * mask_image.get().height; i++) {
mask_image.data[i] = 255 - mask_image.data[i]; mask_image.get().data[i] = 255 - mask_image.get().data[i];
} }
} }
} else { } else {
int m_width = get_resolved_width(); int m_width = get_resolved_width();
int m_height = get_resolved_height(); int m_height = get_resolved_height();
mask_data = std::vector<uint8_t>(m_width * m_height, 255); sd_image_t generated_mask = {(uint32_t)m_width, (uint32_t)m_height, 1, nullptr};
mask_image.width = m_width; generated_mask.data = (uint8_t*)malloc(static_cast<size_t>(m_width) * static_cast<size_t>(m_height));
mask_image.height = m_height; if (generated_mask.data == nullptr) {
mask_image.channel = 1; return bad("failed to allocate default mask");
mask_image.data = mask_data.data(); }
memset(generated_mask.data, 255, static_cast<size_t>(m_width) * static_cast<size_t>(m_height));
mask_image.reset(generated_mask);
} }
float denoising_strength = j.value("denoising_strength", -1.f); float denoising_strength = j.value("denoising_strength", -1.f);
@ -978,9 +963,9 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
if (j.contains("extra_images") && j["extra_images"].is_array()) { if (j.contains("extra_images") && j["extra_images"].is_array()) {
for (auto extra_image : j["extra_images"]) { for (auto extra_image : j["extra_images"]) {
std::string encoded = extra_image.get<std::string>(); std::string encoded = extra_image.get<std::string>();
sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; SDImageOwner tmp_image({(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr});
if (decode_image(tmp_image, encoded)) { if (decode_image(tmp_image, encoded)) {
ref_images.push_back(tmp_image); ref_images.push_back(std::move(tmp_image));
} }
} }
} }
@ -991,19 +976,19 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
gen_params.prompt.c_str(), gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(), gen_params.negative_prompt.c_str(),
gen_params.clip_skip, gen_params.clip_skip,
init_image, init_image.get(),
ref_images.data(), ref_images.data(),
(int)ref_images.size(), (int)ref_images.size(),
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image, mask_image.get(),
get_resolved_width(), get_resolved_width(),
get_resolved_height(), get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
gen_params.batch_count, gen_params.batch_count,
control_image, control_image.get(),
gen_params.control_strength, gen_params.control_strength,
{ {
pmid_images.data(), pmid_images.data(),
@ -1015,13 +1000,19 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
gen_params.cache_params, gen_params.cache_params,
}; };
sd_image_t* results = nullptr; SDImageVec results;
int num_results = 0; int num_results = 0;
{ {
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex); std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count; num_results = gen_params.batch_count;
results.adopt(generate_image(runtime->sd_ctx, &img_gen_params), num_results);
}
if (!results) {
res.status = 500;
res.set_content(R"({"error":"generate failed"})", "application/json");
return;
} }
json out; json out;
@ -1052,21 +1043,9 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
std::string b64 = base64_encode(image_bytes); std::string b64 = base64_encode(image_bytes);
out["images"].push_back(b64); out["images"].push_back(b64);
} }
free_results(results, num_results);
res.set_content(out.dump(), "application/json"); res.set_content(out.dump(), "application/json");
res.status = 200; res.status = 200;
if (init_image.data) {
free(init_image.data);
}
if (mask_image.data && mask_data.empty()) {
free(mask_image.data);
}
for (auto ref_image : ref_images) {
free(ref_image.data);
}
} catch (const std::exception& e) { } catch (const std::exception& e) {
res.status = 500; res.status = 500;
json err; json err;
@ -1178,7 +1157,7 @@ int main(int argc, const char** argv) {
LOG_DEBUG("%s", default_gen_params.to_string().c_str()); LOG_DEBUG("%s", default_gen_params.to_string().c_str());
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false); sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false);
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params));
if (sd_ctx == nullptr) { if (sd_ctx == nullptr) {
LOG_ERROR("new_sd_ctx_t failed"); LOG_ERROR("new_sd_ctx_t failed");
@ -1190,7 +1169,7 @@ int main(int argc, const char** argv) {
std::vector<LoraEntry> lora_cache; std::vector<LoraEntry> lora_cache;
std::mutex lora_mutex; std::mutex lora_mutex;
ServerRuntime runtime = { ServerRuntime runtime = {
sd_ctx, sd_ctx.get(),
&sd_ctx_mutex, &sd_ctx_mutex,
&svr_params, &svr_params,
&ctx_params, &ctx_params,
@ -1231,6 +1210,5 @@ int main(int argc, const char** argv) {
LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port); LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
svr.listen(svr_params.listen_ip, svr_params.listen_port); svr.listen(svr_params.listen_ip, svr_params.listen_port);
free_sd_ctx(sd_ctx);
return 0; return 0;
} }