Compare commits

..

No commits in common. "8afbeb6ba9702c15d41a38296f2ab1fe5c829fa0" and "7397ddaa86f4e8837d5261724678cde0f36d4d89" have entirely different histories.

13 changed files with 2265 additions and 2468 deletions

View File

@ -1,7 +1,6 @@
set(TARGET sd-cli) set(TARGET sd-cli)
add_executable(${TARGET} add_executable(${TARGET}
../common/common.cpp
../common/log.cpp ../common/log.cpp
../common/media_io.cpp ../common/media_io.cpp
image_metadata.cpp image_metadata.cpp

View File

@ -15,13 +15,10 @@
// #include "preprocessing.hpp" // #include "preprocessing.hpp"
#include "stable-diffusion.h" #include "stable-diffusion.h"
#include "common/common.h" #include "common/common.hpp"
#include "common/media_io.h" #include "common/media_io.h"
#include "common/resource_owners.hpp"
#include "image_metadata.h" #include "image_metadata.h"
namespace fs = std::filesystem;
const char* previews_str[] = { const char* previews_str[] = {
"none", "none",
"proj", "proj",
@ -278,7 +275,7 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
} }
bool load_images_from_dir(const std::string dir, bool load_images_from_dir(const std::string dir,
SDImageVec& images, std::vector<sd_image_t>& images,
int expected_width = 0, int expected_width = 0,
int expected_height = 0, int expected_height = 0,
int max_image_num = 0, int max_image_num = 0,
@ -320,7 +317,7 @@ bool load_images_from_dir(const std::string dir,
3, 3,
image_buffer}); image_buffer});
if (max_image_num > 0 && static_cast<int>(images.size()) >= max_image_num) { if (max_image_num > 0 && images.size() >= max_image_num) {
break; break;
} }
} }
@ -558,16 +555,38 @@ int main(int argc, const char* argv[]) {
} }
bool vae_decode_only = true; bool vae_decode_only = true;
SDImageOwner init_image({0, 0, 3, nullptr}); sd_image_t init_image = {0, 0, 3, nullptr};
SDImageOwner end_image({0, 0, 3, nullptr}); sd_image_t end_image = {0, 0, 3, nullptr};
SDImageOwner control_image({0, 0, 3, nullptr}); sd_image_t control_image = {0, 0, 3, nullptr};
SDImageOwner mask_image({0, 0, 1, nullptr}); sd_image_t mask_image = {0, 0, 1, nullptr};
SDImageVec ref_images; std::vector<sd_image_t> ref_images;
SDImageVec pmid_images; std::vector<sd_image_t> pmid_images;
SDImageVec control_frames; std::vector<sd_image_t> control_frames;
auto release_all_resources = [&]() {
free(init_image.data);
free(end_image.data);
free(control_image.data);
free(mask_image.data);
for (auto image : ref_images) {
free(image.data);
image.data = nullptr;
}
ref_images.clear();
for (auto image : pmid_images) {
free(image.data);
image.data = nullptr;
}
pmid_images.clear();
for (auto image : control_frames) {
free(image.data);
image.data = nullptr;
}
control_frames.clear();
};
auto load_image_and_update_size = [&](const std::string& path, auto load_image_and_update_size = [&](const std::string& path,
SDImageOwner& image, sd_image_t& image,
bool resize_image = true, bool resize_image = true,
int expected_channel = 3) -> bool { int expected_channel = 3) -> bool {
int expected_width = 0; int expected_width = 0;
@ -577,12 +596,13 @@ int main(int argc, const char* argv[]) {
expected_height = gen_params.height; expected_height = gen_params.height;
} }
if (!load_sd_image_from_file(image.put(), path.c_str(), expected_width, expected_height, expected_channel)) { if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) {
LOG_ERROR("load image from '%s' failed", path.c_str()); LOG_ERROR("load image from '%s' failed", path.c_str());
release_all_resources();
return false; return false;
} }
gen_params.set_width_and_height_if_unset(image.get().width, image.get().height); gen_params.set_width_and_height_if_unset(image.width, image.height);
return true; return true;
}; };
@ -603,46 +623,47 @@ int main(int argc, const char* argv[]) {
if (gen_params.ref_image_paths.size() > 0) { if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false; vae_decode_only = false;
for (auto& path : gen_params.ref_image_paths) { for (auto& path : gen_params.ref_image_paths) {
SDImageOwner ref_image({0, 0, 3, nullptr}); sd_image_t ref_image = {0, 0, 3, nullptr};
if (!load_image_and_update_size(path, ref_image, false)) { if (!load_image_and_update_size(path, ref_image, false)) {
return 1; return 1;
} }
ref_images.push_back(std::move(ref_image)); ref_images.push_back(ref_image);
} }
} }
if (gen_params.mask_image_path.size() > 0) { if (gen_params.mask_image_path.size() > 0) {
if (!load_sd_image_from_file(mask_image.put(), if (!load_sd_image_from_file(&mask_image,
gen_params.mask_image_path.c_str(), gen_params.mask_image_path.c_str(),
gen_params.get_resolved_width(), gen_params.get_resolved_width(),
gen_params.get_resolved_height(), gen_params.get_resolved_height(),
1)) { 1)) {
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str()); LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
release_all_resources();
return 1; return 1;
} }
} else { } else {
sd_image_t generated_mask = {0, 0, 1, nullptr}; mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
generated_mask.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height()); if (mask_image.data == nullptr) {
if (generated_mask.data == nullptr) {
LOG_ERROR("malloc mask image failed"); LOG_ERROR("malloc mask image failed");
release_all_resources();
return 1; return 1;
} }
generated_mask.width = gen_params.get_resolved_width(); mask_image.width = gen_params.get_resolved_width();
generated_mask.height = gen_params.get_resolved_height(); mask_image.height = gen_params.get_resolved_height();
memset(generated_mask.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height()); memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
mask_image.reset(generated_mask);
} }
if (gen_params.control_image_path.size() > 0) { if (gen_params.control_image_path.size() > 0) {
if (!load_sd_image_from_file(control_image.put(), if (!load_sd_image_from_file(&control_image,
gen_params.control_image_path.c_str(), gen_params.control_image_path.c_str(),
gen_params.get_resolved_width(), gen_params.get_resolved_width(),
gen_params.get_resolved_height())) { gen_params.get_resolved_height())) {
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str()); LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
release_all_resources();
return 1; return 1;
} }
if (cli_params.canny_preprocess) { // apply preprocessor if (cli_params.canny_preprocess) { // apply preprocessor
preprocess_canny(control_image.get(), preprocess_canny(control_image,
0.08f, 0.08f,
0.08f, 0.08f,
0.8f, 0.8f,
@ -658,6 +679,7 @@ int main(int argc, const char* argv[]) {
gen_params.get_resolved_height(), gen_params.get_resolved_height(),
gen_params.video_frames, gen_params.video_frames,
cli_params.verbose)) { cli_params.verbose)) {
release_all_resources();
return 1; return 1;
} }
} }
@ -669,6 +691,7 @@ int main(int argc, const char* argv[]) {
0, 0,
0, 0,
cli_params.verbose)) { cli_params.verbose)) {
release_all_resources();
return 1; return 1;
} }
} }
@ -679,30 +702,39 @@ int main(int argc, const char* argv[]) {
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview); sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(vae_decode_only, true, cli_params.taesd_preview);
SDImageVec results; sd_image_t* results = nullptr;
int num_results = 0; int num_results = 0;
if (cli_params.mode == UPSCALE) { if (cli_params.mode == UPSCALE) {
num_results = 1; num_results = 1;
results.push_back(init_image.release()); results = (sd_image_t*)calloc(num_results, sizeof(sd_image_t));
if (results == nullptr) {
LOG_INFO("failed to allocate results array");
release_all_resources();
return 1;
}
results[0] = init_image;
init_image.data = nullptr;
} else { } else {
SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params)); sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
if (sd_ctx == nullptr) { if (sd_ctx == nullptr) {
LOG_INFO("new_sd_ctx_t failed"); LOG_INFO("new_sd_ctx_t failed");
release_all_resources();
return 1; return 1;
} }
if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) { if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get()); gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
} }
if (gen_params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) { if (gen_params.high_noise_sample_params.sample_method == SAMPLE_METHOD_COUNT) {
gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx.get()); gen_params.high_noise_sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
} }
if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) { if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx.get(), gen_params.sample_params.sample_method); gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method);
} }
if (cli_params.mode == IMG_GEN) { if (cli_params.mode == IMG_GEN) {
@ -712,19 +744,19 @@ int main(int argc, const char* argv[]) {
gen_params.prompt.c_str(), gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(), gen_params.negative_prompt.c_str(),
gen_params.clip_skip, gen_params.clip_skip,
init_image.get(), init_image,
ref_images.data(), ref_images.data(),
(int)ref_images.size(), (int)ref_images.size(),
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image.get(), mask_image,
gen_params.get_resolved_width(), gen_params.get_resolved_width(),
gen_params.get_resolved_height(), gen_params.get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
gen_params.batch_count, gen_params.batch_count,
control_image.get(), control_image,
gen_params.control_strength, gen_params.control_strength,
{ {
pmid_images.data(), pmid_images.data(),
@ -736,8 +768,8 @@ int main(int argc, const char* argv[]) {
gen_params.cache_params, gen_params.cache_params,
}; };
results = generate_image(sd_ctx, &img_gen_params);
num_results = gen_params.batch_count; num_results = gen_params.batch_count;
results.adopt(generate_image(sd_ctx.get(), &img_gen_params), num_results);
} else if (cli_params.mode == VID_GEN) { } else if (cli_params.mode == VID_GEN) {
sd_vid_gen_params_t vid_gen_params = { sd_vid_gen_params_t vid_gen_params = {
gen_params.lora_vec.data(), gen_params.lora_vec.data(),
@ -745,8 +777,8 @@ int main(int argc, const char* argv[]) {
gen_params.prompt.c_str(), gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(), gen_params.negative_prompt.c_str(),
gen_params.clip_skip, gen_params.clip_skip,
init_image.get(), init_image,
end_image.get(), end_image,
control_frames.data(), control_frames.data(),
(int)control_frames.size(), (int)control_frames.size(),
gen_params.get_resolved_width(), gen_params.get_resolved_width(),
@ -762,23 +794,25 @@ int main(int argc, const char* argv[]) {
gen_params.cache_params, gen_params.cache_params,
}; };
sd_image_t* generated_video = generate_video(sd_ctx.get(), &vid_gen_params, &num_results); results = generate_video(sd_ctx, &vid_gen_params, &num_results);
results.adopt(generated_video, num_results);
} }
if (!results) { if (results == nullptr) {
LOG_ERROR("generate failed"); LOG_ERROR("generate failed");
free_sd_ctx(sd_ctx);
return 1; return 1;
} }
free_sd_ctx(sd_ctx);
} }
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
if (ctx_params.esrgan_path.size() > 0 && gen_params.upscale_repeats > 0) { if (ctx_params.esrgan_path.size() > 0 && gen_params.upscale_repeats > 0) {
UpscalerCtxPtr upscaler_ctx(new_upscaler_ctx(ctx_params.esrgan_path.c_str(), upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(ctx_params.esrgan_path.c_str(),
ctx_params.offload_params_to_cpu, ctx_params.offload_params_to_cpu,
ctx_params.diffusion_conv_direct, ctx_params.diffusion_conv_direct,
ctx_params.n_threads, ctx_params.n_threads,
gen_params.upscale_tile_size)); gen_params.upscale_tile_size);
if (upscaler_ctx == nullptr) { if (upscaler_ctx == nullptr) {
LOG_ERROR("new_upscaler_ctx failed"); LOG_ERROR("new_upscaler_ctx failed");
@ -787,24 +821,32 @@ int main(int argc, const char* argv[]) {
if (results[i].data == nullptr) { if (results[i].data == nullptr) {
continue; continue;
} }
SDImageOwner current_image(results[i]); sd_image_t current_image = results[i];
results[i] = {0, 0, 0, nullptr};
for (int u = 0; u < gen_params.upscale_repeats; ++u) { for (int u = 0; u < gen_params.upscale_repeats; ++u) {
SDImageOwner upscaled_image(upscale(upscaler_ctx.get(), current_image.get(), upscale_factor)); sd_image_t upscaled_image = upscale(upscaler_ctx, current_image, upscale_factor);
if (upscaled_image.get().data == nullptr) { if (upscaled_image.data == nullptr) {
LOG_ERROR("upscale failed"); LOG_ERROR("upscale failed");
break; break;
} }
current_image = std::move(upscaled_image); free(current_image.data);
current_image = upscaled_image;
} }
results[i] = current_image.release(); // Set the final upscaled image as the result results[i] = current_image; // Set the final upscaled image as the result
} }
} }
} }
if (!save_results(cli_params, ctx_params, gen_params, results.data(), num_results)) { if (!save_results(cli_params, ctx_params, gen_params, results, num_results)) {
return 1; return 1;
} }
for (int i = 0; i < num_results; i++) {
free(results[i].data);
results[i].data = nullptr;
}
free(results);
release_all_resources();
return 0; return 0;
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,207 +0,0 @@
#ifndef __EXAMPLES_COMMON_COMMON_H__
#define __EXAMPLES_COMMON_COMMON_H__
#include <cmath>
#include <cstdint>
#include <functional>
#include <map>
#include <string>
#include <vector>
#include "log.h"
#include "stable-diffusion.h"
#define SAFE_STR(s) ((s) ? (s) : "")
#define BOOL_STR(b) ((b) ? "true" : "false")
extern const char* const modes_str[];
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale, metadata"
enum SDMode {
IMG_GEN,
VID_GEN,
CONVERT,
UPSCALE,
METADATA,
MODE_COUNT
};
struct StringOption {
std::string short_name;
std::string long_name;
std::string desc;
std::string* target;
};
struct IntOption {
std::string short_name;
std::string long_name;
std::string desc;
int* target;
};
struct FloatOption {
std::string short_name;
std::string long_name;
std::string desc;
float* target;
};
struct BoolOption {
std::string short_name;
std::string long_name;
std::string desc;
bool keep_true;
bool* target;
};
struct ManualOption {
std::string short_name;
std::string long_name;
std::string desc;
std::function<int(int argc, const char** argv, int index)> cb;
};
struct ArgOptions {
std::vector<StringOption> string_options;
std::vector<IntOption> int_options;
std::vector<FloatOption> float_options;
std::vector<BoolOption> bool_options;
std::vector<ManualOption> manual_options;
static std::string wrap_text(const std::string& text, size_t width, size_t indent);
void print() const;
};
bool parse_options(int argc, const char** argv, const std::vector<ArgOptions>& options_list);
struct SDContextParams {
int n_threads = -1;
std::string model_path;
std::string clip_l_path;
std::string clip_g_path;
std::string clip_vision_path;
std::string t5xxl_path;
std::string llm_path;
std::string llm_vision_path;
std::string diffusion_model_path;
std::string high_noise_diffusion_model_path;
std::string vae_path;
std::string taesd_path;
std::string esrgan_path;
std::string control_net_path;
std::string embedding_dir;
std::string photo_maker_path;
sd_type_t wtype = SD_TYPE_COUNT;
std::string tensor_type_rules;
std::string lora_model_dir = ".";
std::map<std::string, std::string> embedding_map;
std::vector<sd_embedding_t> embedding_vec;
rng_type_t rng_type = CUDA_RNG;
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
bool offload_params_to_cpu = false;
bool enable_mmap = false;
bool control_net_cpu = false;
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool flash_attn = false;
bool diffusion_flash_attn = false;
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
bool circular = false;
bool circular_x = false;
bool circular_y = false;
bool chroma_use_dit_mask = true;
bool chroma_use_t5_mask = false;
int chroma_t5_mask_pad = 1;
bool qwen_image_zero_cond_t = false;
prediction_t prediction = PREDICTION_COUNT;
lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
bool force_sdxl_vae_conv_scale = false;
float flow_shift = INFINITY;
ArgOptions get_options();
void build_embedding_map();
bool process_and_check(SDMode mode);
std::string to_string() const;
sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview);
};
struct SDGenerationParams {
std::string prompt;
std::string prompt_with_lora; // for metadata record only
std::string negative_prompt;
int clip_skip = -1; // <= 0 represents unspecified
int width = -1;
int height = -1;
int batch_count = 1;
std::string init_image_path;
std::string end_image_path;
std::string mask_image_path;
std::string control_image_path;
std::vector<std::string> ref_image_paths;
std::string control_video_path;
bool auto_resize_ref_image = true;
bool increase_ref_index = false;
bool embed_image_metadata = true;
std::vector<int> skip_layers = {7, 8, 9};
sd_sample_params_t sample_params;
std::vector<int> high_noise_skip_layers = {7, 8, 9};
sd_sample_params_t high_noise_sample_params;
std::vector<float> custom_sigmas;
std::string cache_mode;
std::string cache_option;
std::string scm_mask;
bool scm_policy_dynamic = true;
sd_cache_params_t cache_params{};
float moe_boundary = 0.875f;
int video_frames = 1;
int fps = 16;
float vace_strength = 1.f;
float strength = 0.75f;
float control_strength = 0.9f;
int64_t seed = 42;
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
// Photo Maker
std::string pm_id_images_dir;
std::string pm_id_embed_path;
float pm_style_strength = 20.f;
int upscale_repeats = 1;
int upscale_tile_size = 128;
std::map<std::string, float> lora_map;
std::map<std::string, float> high_noise_lora_map;
std::vector<sd_lora_t> lora_vec;
SDGenerationParams();
ArgOptions get_options();
bool from_json_str(const std::string& json_str);
void extract_and_remove_lora(const std::string& lora_model_dir);
bool width_and_height_are_set() const;
void set_width_and_height_if_unset(int w, int h);
int get_resolved_width() const;
int get_resolved_height() const;
bool process_and_check(SDMode mode, const std::string& lora_model_dir);
std::string to_string() const;
};
std::string version_string();
std::string get_image_params(const SDContextParams& ctx_params, const SDGenerationParams& gen_params, int64_t seed);
#endif // __EXAMPLES_COMMON_COMMON_H__

1902
examples/common/common.hpp Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,7 +1,5 @@
#include "log.h" #include "log.h"
#include <vector>
bool log_verbose = false; bool log_verbose = false;
bool log_color = false; bool log_color = false;
@ -36,12 +34,17 @@ void print_utf8(FILE* stream, const char* utf8) {
return; return;
} }
std::vector<wchar_t> wbuf(static_cast<size_t>(wlen)); wchar_t* wbuf = (wchar_t*)malloc(wlen * sizeof(wchar_t));
if (!wbuf) {
return;
}
MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wbuf.data(), wlen); MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wbuf, wlen);
DWORD written; DWORD written;
WriteConsoleW(h, wbuf.data(), wlen - 1, &written, NULL); WriteConsoleW(h, wbuf, wlen - 1, &written, NULL);
free(wbuf);
} else { } else {
DWORD written; DWORD written;
WriteFile(h, utf8, (DWORD)strlen(utf8), &written, NULL); WriteFile(h, utf8, (DWORD)strlen(utf8), &written, NULL);

View File

@ -1,6 +1,5 @@
#include "media_io.h"
#include "log.h" #include "log.h"
#include "resource_owners.hpp" #include "media_io.h"
#include <algorithm> #include <algorithm>
#include <cctype> #include <cctype>
@ -38,63 +37,7 @@
namespace fs = std::filesystem; namespace fs = std::filesystem;
#ifdef SD_USE_WEBP namespace {
struct WebPFreeDeleter {
void operator()(void* ptr) const {
if (ptr != nullptr) {
WebPFree(ptr);
}
}
};
struct WebPMuxDeleter {
void operator()(WebPMux* mux) const {
if (mux != nullptr) {
WebPMuxDelete(mux);
}
}
};
struct WebPAnimEncoderDeleter {
void operator()(WebPAnimEncoder* enc) const {
if (enc != nullptr) {
WebPAnimEncoderDelete(enc);
}
}
};
struct WebPDataGuard {
WebPDataGuard() {
WebPDataInit(&data);
}
~WebPDataGuard() {
WebPDataClear(&data);
}
WebPData data;
};
struct WebPPictureGuard {
WebPPictureGuard()
: initialized(WebPPictureInit(&picture) != 0) {
}
~WebPPictureGuard() {
if (initialized) {
WebPPictureFree(&picture);
}
}
WebPPicture picture;
bool initialized;
};
using WebPBufferPtr = std::unique_ptr<uint8_t, WebPFreeDeleter>;
using WebPMuxPtr = std::unique_ptr<WebPMux, WebPMuxDeleter>;
using WebPAnimEncoderPtr = std::unique_ptr<WebPAnimEncoder, WebPAnimEncoderDeleter>;
#endif
bool read_binary_file_bytes(const char* path, std::vector<uint8_t>& data) { bool read_binary_file_bytes(const char* path, std::vector<uint8_t>& data) {
std::ifstream fin(fs::path(path), std::ios::binary); std::ifstream fin(fs::path(path), std::ios::binary);
if (!fin) { if (!fin) {
@ -215,25 +158,27 @@ uint8_t* decode_webp_image_to_buffer(const uint8_t* data,
if (expected_channel == 1) { if (expected_channel == 1) {
int decoded_width = width; int decoded_width = width;
int decoded_height = height; int decoded_height = height;
WebPBufferPtr decoded(features.has_alpha uint8_t* decoded = features.has_alpha
? WebPDecodeRGBA(data, size, &decoded_width, &decoded_height) ? WebPDecodeRGBA(data, size, &decoded_width, &decoded_height)
: WebPDecodeRGB(data, size, &decoded_width, &decoded_height)); : WebPDecodeRGB(data, size, &decoded_width, &decoded_height);
if (decoded == nullptr) { if (decoded == nullptr) {
return nullptr; return nullptr;
} }
FreeUniquePtr<uint8_t> grayscale((uint8_t*)malloc(pixel_count)); uint8_t* grayscale = (uint8_t*)malloc(pixel_count);
if (grayscale == nullptr) { if (grayscale == nullptr) {
WebPFree(decoded);
return nullptr; return nullptr;
} }
const int decoded_channels = features.has_alpha ? 4 : 3; const int decoded_channels = features.has_alpha ? 4 : 3;
for (size_t i = 0; i < pixel_count; ++i) { for (size_t i = 0; i < pixel_count; ++i) {
const uint8_t* src = decoded.get() + i * decoded_channels; const uint8_t* src = decoded + i * decoded_channels;
grayscale.get()[i] = static_cast<uint8_t>((77 * src[0] + 150 * src[1] + 29 * src[2] + 128) >> 8); grayscale[i] = static_cast<uint8_t>((77 * src[0] + 150 * src[1] + 29 * src[2] + 128) >> 8);
} }
return grayscale.release(); WebPFree(decoded);
return grayscale;
} }
if (expected_channel != 3 && expected_channel != 4) { if (expected_channel != 3 && expected_channel != 4) {
@ -242,21 +187,23 @@ uint8_t* decode_webp_image_to_buffer(const uint8_t* data,
int decoded_width = width; int decoded_width = width;
int decoded_height = height; int decoded_height = height;
WebPBufferPtr decoded((expected_channel == 4) uint8_t* decoded = (expected_channel == 4)
? WebPDecodeRGBA(data, size, &decoded_width, &decoded_height) ? WebPDecodeRGBA(data, size, &decoded_width, &decoded_height)
: WebPDecodeRGB(data, size, &decoded_width, &decoded_height)); : WebPDecodeRGB(data, size, &decoded_width, &decoded_height);
if (decoded == nullptr) { if (decoded == nullptr) {
return nullptr; return nullptr;
} }
const size_t out_size = pixel_count * static_cast<size_t>(expected_channel); const size_t out_size = pixel_count * static_cast<size_t>(expected_channel);
FreeUniquePtr<uint8_t> output((uint8_t*)malloc(out_size)); uint8_t* output = (uint8_t*)malloc(out_size);
if (output == nullptr) { if (output == nullptr) {
WebPFree(decoded);
return nullptr; return nullptr;
} }
memcpy(output.get(), decoded.get(), out_size); memcpy(output, decoded, out_size);
return output.release(); WebPFree(decoded);
return output;
} }
std::string build_webp_xmp_packet(const std::string& parameters) { std::string build_webp_xmp_packet(const std::string& parameters) {
@ -308,29 +255,30 @@ bool encode_webp_image_to_vector(const uint8_t* image,
return false; return false;
} }
uint8_t* encoded_raw = nullptr; uint8_t* encoded = nullptr;
size_t encoded_size = (input_channels == 4) size_t encoded_size = (input_channels == 4)
? WebPEncodeRGBA(input_image, width, height, width * input_channels, static_cast<float>(quality), &encoded_raw) ? WebPEncodeRGBA(input_image, width, height, width * input_channels, static_cast<float>(quality), &encoded)
: WebPEncodeRGB(input_image, width, height, width * input_channels, static_cast<float>(quality), &encoded_raw); : WebPEncodeRGB(input_image, width, height, width * input_channels, static_cast<float>(quality), &encoded);
WebPBufferPtr encoded(encoded_raw);
if (encoded == nullptr || encoded_size == 0) { if (encoded == nullptr || encoded_size == 0) {
return false; return false;
} }
out.assign(encoded.get(), encoded.get() + encoded_size); out.assign(encoded, encoded + encoded_size);
WebPFree(encoded);
if (parameters.empty()) { if (parameters.empty()) {
return true; return true;
} }
WebPData image_data; WebPData image_data;
WebPData assembled_data;
WebPDataInit(&image_data); WebPDataInit(&image_data);
WebPDataGuard assembled_data; WebPDataInit(&assembled_data);
image_data.bytes = out.data(); image_data.bytes = out.data();
image_data.size = out.size(); image_data.size = out.size();
WebPMuxPtr mux(WebPMuxNew()); WebPMux* mux = WebPMuxNew();
if (mux == nullptr) { if (mux == nullptr) {
return false; return false;
} }
@ -341,14 +289,16 @@ bool encode_webp_image_to_vector(const uint8_t* image,
xmp_data.bytes = reinterpret_cast<const uint8_t*>(xmp_packet.data()); xmp_data.bytes = reinterpret_cast<const uint8_t*>(xmp_packet.data());
xmp_data.size = xmp_packet.size(); xmp_data.size = xmp_packet.size();
const bool ok = WebPMuxSetImage(mux.get(), &image_data, 1) == WEBP_MUX_OK && const bool ok = WebPMuxSetImage(mux, &image_data, 1) == WEBP_MUX_OK &&
WebPMuxSetChunk(mux.get(), "XMP ", &xmp_data, 1) == WEBP_MUX_OK && WebPMuxSetChunk(mux, "XMP ", &xmp_data, 1) == WEBP_MUX_OK &&
WebPMuxAssemble(mux.get(), &assembled_data.data) == WEBP_MUX_OK; WebPMuxAssemble(mux, &assembled_data) == WEBP_MUX_OK;
if (ok) { if (ok) {
out.assign(assembled_data.data.bytes, assembled_data.data.bytes + assembled_data.data.size); out.assign(assembled_data.bytes, assembled_data.bytes + assembled_data.size);
} }
WebPDataClear(&assembled_data);
WebPMuxDelete(mux);
return ok; return ok;
} }
@ -432,19 +382,19 @@ uint8_t* load_image_common(bool from_memory,
int expected_height, int expected_height,
int expected_channel) { int expected_channel) {
const char* image_path; const char* image_path;
FreeUniquePtr<uint8_t> image_buffer; uint8_t* image_buffer = nullptr;
int source_channel_count = 0; int source_channel_count = 0;
#ifdef SD_USE_WEBP #ifdef SD_USE_WEBP
if (from_memory) { if (from_memory) {
image_path = "memory"; image_path = "memory";
if (len > 0 && is_webp_signature(reinterpret_cast<const uint8_t*>(image_path_or_bytes), static_cast<size_t>(len))) { if (len > 0 && is_webp_signature(reinterpret_cast<const uint8_t*>(image_path_or_bytes), static_cast<size_t>(len))) {
image_buffer.reset(decode_webp_image_to_buffer(reinterpret_cast<const uint8_t*>(image_path_or_bytes), image_buffer = decode_webp_image_to_buffer(reinterpret_cast<const uint8_t*>(image_path_or_bytes),
static_cast<size_t>(len), static_cast<size_t>(len),
width, width,
height, height,
expected_channel, expected_channel,
source_channel_count)); source_channel_count);
} }
} else { } else {
image_path = image_path_or_bytes; image_path = image_path_or_bytes;
@ -458,12 +408,12 @@ uint8_t* load_image_common(bool from_memory,
LOG_ERROR("load image from '%s' failed", image_path_or_bytes); LOG_ERROR("load image from '%s' failed", image_path_or_bytes);
return nullptr; return nullptr;
} }
image_buffer.reset(decode_webp_image_to_buffer(file_bytes.data(), image_buffer = decode_webp_image_to_buffer(file_bytes.data(),
file_bytes.size(), file_bytes.size(),
width, width,
height, height,
expected_channel, expected_channel,
source_channel_count)); source_channel_count);
} }
} }
#endif #endif
@ -472,14 +422,14 @@ uint8_t* load_image_common(bool from_memory,
image_path = "memory"; image_path = "memory";
if (image_buffer == nullptr) { if (image_buffer == nullptr) {
int c = 0; int c = 0;
image_buffer.reset((uint8_t*)stbi_load_from_memory((const stbi_uc*)image_path_or_bytes, len, &width, &height, &c, expected_channel)); image_buffer = (uint8_t*)stbi_load_from_memory((const stbi_uc*)image_path_or_bytes, len, &width, &height, &c, expected_channel);
source_channel_count = c; source_channel_count = c;
} }
} else { } else {
image_path = image_path_or_bytes; image_path = image_path_or_bytes;
if (image_buffer == nullptr) { if (image_buffer == nullptr) {
int c = 0; int c = 0;
image_buffer.reset((uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel)); image_buffer = (uint8_t*)stbi_load(image_path_or_bytes, &width, &height, &c, expected_channel);
source_channel_count = c; source_channel_count = c;
} }
} }
@ -494,14 +444,17 @@ uint8_t* load_image_common(bool from_memory,
expected_channel, expected_channel,
source_channel_count, source_channel_count,
image_path); image_path);
free(image_buffer);
return nullptr; return nullptr;
} }
if (width <= 0) { if (width <= 0) {
LOG_ERROR("error: the width of image must be greater than 0, image_path = %s", image_path); LOG_ERROR("error: the width of image must be greater than 0, image_path = %s", image_path);
free(image_buffer);
return nullptr; return nullptr;
} }
if (height <= 0) { if (height <= 0) {
LOG_ERROR("error: the height of image must be greater than 0, image_path = %s", image_path); LOG_ERROR("error: the height of image must be greater than 0, image_path = %s", image_path);
free(image_buffer);
return nullptr; return nullptr;
} }
@ -522,39 +475,43 @@ uint8_t* load_image_common(bool from_memory,
if (crop_x != 0 || crop_y != 0) { if (crop_x != 0 || crop_y != 0) {
LOG_INFO("crop input image from %dx%d to %dx%d, image_path = %s", width, height, crop_w, crop_h, image_path); LOG_INFO("crop input image from %dx%d to %dx%d, image_path = %s", width, height, crop_w, crop_h, image_path);
FreeUniquePtr<uint8_t> cropped_image_buffer((uint8_t*)malloc(crop_w * crop_h * expected_channel)); uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel);
if (cropped_image_buffer == nullptr) { if (cropped_image_buffer == nullptr) {
LOG_ERROR("error: allocate memory for crop\n"); LOG_ERROR("error: allocate memory for crop\n");
free(image_buffer);
return nullptr; return nullptr;
} }
for (int row = 0; row < crop_h; row++) { for (int row = 0; row < crop_h; row++) {
uint8_t* src = image_buffer.get() + ((crop_y + row) * width + crop_x) * expected_channel; uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel;
uint8_t* dst = cropped_image_buffer.get() + (row * crop_w) * expected_channel; uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel;
memcpy(dst, src, crop_w * expected_channel); memcpy(dst, src, crop_w * expected_channel);
} }
width = crop_w; width = crop_w;
height = crop_h; height = crop_h;
image_buffer = std::move(cropped_image_buffer); free(image_buffer);
image_buffer = cropped_image_buffer;
} }
LOG_INFO("resize input image from %dx%d to %dx%d", width, height, expected_width, expected_height); LOG_INFO("resize input image from %dx%d to %dx%d", width, height, expected_width, expected_height);
FreeUniquePtr<uint8_t> resized_image_buffer((uint8_t*)malloc(expected_height * expected_width * expected_channel)); uint8_t* resized_image_buffer = (uint8_t*)malloc(expected_height * expected_width * expected_channel);
if (resized_image_buffer == nullptr) { if (resized_image_buffer == nullptr) {
LOG_ERROR("error: allocate memory for resize input image\n"); LOG_ERROR("error: allocate memory for resize input image\n");
free(image_buffer);
return nullptr; return nullptr;
} }
stbir_resize(image_buffer.get(), width, height, 0, stbir_resize(image_buffer, width, height, 0,
resized_image_buffer.get(), expected_width, expected_height, 0, STBIR_TYPE_UINT8, resized_image_buffer, expected_width, expected_height, 0, STBIR_TYPE_UINT8,
expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0, expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0,
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
STBIR_FILTER_BOX, STBIR_FILTER_BOX, STBIR_FILTER_BOX, STBIR_FILTER_BOX,
STBIR_COLORSPACE_SRGB, nullptr); STBIR_COLORSPACE_SRGB, nullptr);
width = expected_width; width = expected_width;
height = expected_height; height = expected_height;
image_buffer = std::move(resized_image_buffer); free(image_buffer);
image_buffer = resized_image_buffer;
} }
return image_buffer.release(); return image_buffer;
} }
typedef struct { typedef struct {
@ -569,6 +526,8 @@ void write_u32_le(FILE* f, uint32_t val) {
void write_u16_le(FILE* f, uint16_t val) { void write_u16_le(FILE* f, uint16_t val) {
fwrite(&val, 2, 1, f); fwrite(&val, 2, 1, f);
} }
} // namespace
EncodedImageFormat encoded_image_format_from_path(const std::string& path) { EncodedImageFormat encoded_image_format_from_path(const std::string& path) {
std::string ext = fs::path(path).extension().string(); std::string ext = fs::path(path).extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
@ -703,18 +662,18 @@ int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int
return -1; return -1;
} }
FilePtr file(fopen(filename, "wb")); FILE* f = fopen(filename, "wb");
if (!file) { if (!f) {
perror("Error opening file for writing"); perror("Error opening file for writing");
return -1; return -1;
} }
FILE* f = file.get();
uint32_t width = images[0].width; uint32_t width = images[0].width;
uint32_t height = images[0].height; uint32_t height = images[0].height;
uint32_t channels = images[0].channel; uint32_t channels = images[0].channel;
if (channels != 3 && channels != 4) { if (channels != 3 && channels != 4) {
fprintf(stderr, "Error: Unsupported channel count: %u\n", channels); fprintf(stderr, "Error: Unsupported channel count: %u\n", channels);
fclose(f);
return -1; return -1;
} }
@ -787,32 +746,41 @@ int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int
write_u32_le(f, 0); write_u32_le(f, 0);
fwrite("movi", 4, 1, f); fwrite("movi", 4, 1, f);
std::vector<avi_index_entry> index(static_cast<size_t>(num_images)); avi_index_entry* index = (avi_index_entry*)malloc(sizeof(avi_index_entry) * num_images);
std::vector<uint8_t> jpeg_data; if (!index) {
fclose(f);
for (int i = 0; i < num_images; i++) {
jpeg_data.clear();
auto write_to_buf = [](void* context, void* data, int size) {
auto* buffer = reinterpret_cast<std::vector<uint8_t>*>(context);
const uint8_t* src = reinterpret_cast<const uint8_t*>(data);
buffer->insert(buffer->end(), src, src + size);
};
if (!stbi_write_jpg_to_func(write_to_buf, &jpeg_data, images[i].width, images[i].height, channels, images[i].data, quality)) {
fprintf(stderr, "Error: Failed to encode JPEG frame.\n");
return -1; return -1;
} }
fwrite("00dc", 4, 1, f); struct {
write_u32_le(f, (uint32_t)jpeg_data.size()); uint8_t* buf;
index[i].offset = ftell(f) - 8; size_t size;
index[i].size = (uint32_t)jpeg_data.size(); } jpeg_data;
fwrite(jpeg_data.data(), 1, jpeg_data.size(), f);
if (jpeg_data.size() % 2) { for (int i = 0; i < num_images; i++) {
jpeg_data.buf = nullptr;
jpeg_data.size = 0;
auto write_to_buf = [](void* context, void* data, int size) {
auto jd = (decltype(jpeg_data)*)context;
jd->buf = (uint8_t*)realloc(jd->buf, jd->size + size);
memcpy(jd->buf + jd->size, data, size);
jd->size += size;
};
stbi_write_jpg_to_func(write_to_buf, &jpeg_data, images[i].width, images[i].height, channels, images[i].data, quality);
fwrite("00dc", 4, 1, f);
write_u32_le(f, (uint32_t)jpeg_data.size);
index[i].offset = ftell(f) - 8;
index[i].size = (uint32_t)jpeg_data.size;
fwrite(jpeg_data.buf, 1, jpeg_data.size, f);
if (jpeg_data.size % 2) {
fputc(0, f); fputc(0, f);
} }
free(jpeg_data.buf);
} }
long cur_pos = ftell(f); long cur_pos = ftell(f);
@ -836,6 +804,9 @@ int create_mjpg_avi_from_sd_images(const char* filename, sd_image_t* images, int
write_u32_le(f, file_size); write_u32_le(f, file_size);
fseek(f, cur_pos, SEEK_SET); fseek(f, cur_pos, SEEK_SET);
fclose(f);
free(index);
return 0; return 0;
} }
@ -876,7 +847,7 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images
return -1; return -1;
} }
WebPAnimEncoderPtr enc(WebPAnimEncoderNew(width, height, &anim_options)); WebPAnimEncoder* enc = WebPAnimEncoderNew(width, height, &anim_options);
if (enc == nullptr) { if (enc == nullptr) {
fprintf(stderr, "Error: Could not create WebPAnimEncoder object.\n"); fprintf(stderr, "Error: Could not create WebPAnimEncoder object.\n");
return -1; return -1;
@ -884,22 +855,23 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images
const int frame_duration_ms = std::max(1, static_cast<int>(std::lround(1000.0 / static_cast<double>(fps)))); const int frame_duration_ms = std::max(1, static_cast<int>(std::lround(1000.0 / static_cast<double>(fps))));
int timestamp_ms = 0; int timestamp_ms = 0;
int ret = -1;
for (int i = 0; i < num_images; ++i) { for (int i = 0; i < num_images; ++i) {
const sd_image_t& image = images[i]; const sd_image_t& image = images[i];
if (static_cast<int>(image.width) != width || static_cast<int>(image.height) != height) { if (static_cast<int>(image.width) != width || static_cast<int>(image.height) != height) {
fprintf(stderr, "Error: Frame dimensions do not match.\n"); fprintf(stderr, "Error: Frame dimensions do not match.\n");
return -1; goto cleanup;
} }
WebPPictureGuard picture; WebPPicture picture;
if (!picture.initialized) { if (!WebPPictureInit(&picture)) {
fprintf(stderr, "Error: Failed to initialize WebPPicture.\n"); fprintf(stderr, "Error: Failed to initialize WebPPicture.\n");
return -1; goto cleanup;
} }
picture.picture.use_argb = 1; picture.use_argb = 1;
picture.picture.width = width; picture.width = width;
picture.picture.height = height; picture.height = height;
bool picture_ok = false; bool picture_ok = false;
std::vector<uint8_t> rgb_buffer; std::vector<uint8_t> rgb_buffer;
@ -910,48 +882,64 @@ int create_animated_webp_from_sd_images(const char* filename, sd_image_t* images
rgb_buffer[p * 3 + 1] = image.data[p]; rgb_buffer[p * 3 + 1] = image.data[p];
rgb_buffer[p * 3 + 2] = image.data[p]; rgb_buffer[p * 3 + 2] = image.data[p];
} }
picture_ok = WebPPictureImportRGB(&picture.picture, rgb_buffer.data(), width * 3) != 0; picture_ok = WebPPictureImportRGB(&picture, rgb_buffer.data(), width * 3) != 0;
} else if (image.channel == 4) { } else if (image.channel == 4) {
picture_ok = WebPPictureImportRGBA(&picture.picture, image.data, width * 4) != 0; picture_ok = WebPPictureImportRGBA(&picture, image.data, width * 4) != 0;
} else { } else {
picture_ok = WebPPictureImportRGB(&picture.picture, image.data, width * 3) != 0; picture_ok = WebPPictureImportRGB(&picture, image.data, width * 3) != 0;
} }
if (!picture_ok) { if (!picture_ok) {
fprintf(stderr, "Error: Failed to import frame into WebPPicture.\n"); fprintf(stderr, "Error: Failed to import frame into WebPPicture.\n");
return -1; WebPPictureFree(&picture);
goto cleanup;
} }
if (!WebPAnimEncoderAdd(enc.get(), &picture.picture, timestamp_ms, &config)) { if (!WebPAnimEncoderAdd(enc, &picture, timestamp_ms, &config)) {
fprintf(stderr, "Error: Failed to add frame to animated WebP: %s\n", WebPAnimEncoderGetError(enc.get())); fprintf(stderr, "Error: Failed to add frame to animated WebP: %s\n", WebPAnimEncoderGetError(enc));
return -1; WebPPictureFree(&picture);
goto cleanup;
} }
WebPPictureFree(&picture);
timestamp_ms += frame_duration_ms; timestamp_ms += frame_duration_ms;
} }
if (!WebPAnimEncoderAdd(enc.get(), nullptr, timestamp_ms, nullptr)) { if (!WebPAnimEncoderAdd(enc, nullptr, timestamp_ms, nullptr)) {
fprintf(stderr, "Error: Failed to finalize animated WebP frames: %s\n", WebPAnimEncoderGetError(enc.get())); fprintf(stderr, "Error: Failed to finalize animated WebP frames: %s\n", WebPAnimEncoderGetError(enc));
return -1; goto cleanup;
} }
WebPDataGuard webp_data; {
if (!WebPAnimEncoderAssemble(enc.get(), &webp_data.data)) { WebPData webp_data;
fprintf(stderr, "Error: Failed to assemble animated WebP: %s\n", WebPAnimEncoderGetError(enc.get())); WebPDataInit(&webp_data);
return -1; if (!WebPAnimEncoderAssemble(enc, &webp_data)) {
fprintf(stderr, "Error: Failed to assemble animated WebP: %s\n", WebPAnimEncoderGetError(enc));
WebPDataClear(&webp_data);
goto cleanup;
} }
FilePtr f(fopen(filename, "wb")); FILE* f = fopen(filename, "wb");
if (!f) { if (!f) {
perror("Error opening file for writing"); perror("Error opening file for writing");
return -1; WebPDataClear(&webp_data);
goto cleanup;
} }
if (webp_data.data.size > 0 && fwrite(webp_data.data.bytes, 1, webp_data.data.size, f.get()) != webp_data.data.size) { if (webp_data.size > 0 && fwrite(webp_data.bytes, 1, webp_data.size, f) != webp_data.size) {
fprintf(stderr, "Error: Failed to write animated WebP file.\n"); fprintf(stderr, "Error: Failed to write animated WebP file.\n");
return -1; fclose(f);
WebPDataClear(&webp_data);
goto cleanup;
}
fclose(f);
WebPDataClear(&webp_data);
} }
return 0; ret = 0;
cleanup:
WebPAnimEncoderDelete(enc);
return ret;
} }
#endif #endif

View File

@ -1,207 +0,0 @@
#ifndef __EXAMPLE_RESOURCE_OWNERS_H__
#define __EXAMPLE_RESOURCE_OWNERS_H__
#include <cstdio>
#include <cstdlib>
#include <memory>
#include <utility>
#include <vector>
#include "stable-diffusion.h"
struct FreeDeleter {
void operator()(void* ptr) const {
free(ptr);
}
};
struct FileCloser {
void operator()(FILE* file) const {
if (file != nullptr) {
fclose(file);
}
}
};
struct SDCtxDeleter {
void operator()(sd_ctx_t* ctx) const {
if (ctx != nullptr) {
free_sd_ctx(ctx);
}
}
};
struct UpscalerCtxDeleter {
void operator()(upscaler_ctx_t* ctx) const {
if (ctx != nullptr) {
free_upscaler_ctx(ctx);
}
}
};
template <typename T>
using FreeUniquePtr = std::unique_ptr<T, FreeDeleter>;
using FilePtr = std::unique_ptr<FILE, FileCloser>;
using SDCtxPtr = std::unique_ptr<sd_ctx_t, SDCtxDeleter>;
using UpscalerCtxPtr = std::unique_ptr<upscaler_ctx_t, UpscalerCtxDeleter>;
class SDImageOwner {
public:
SDImageOwner() = default;
explicit SDImageOwner(sd_image_t image)
: image_(image) {
}
SDImageOwner(const SDImageOwner&) = delete;
SDImageOwner& operator=(const SDImageOwner&) = delete;
SDImageOwner(SDImageOwner&& other) noexcept
: image_(other.release()) {
}
SDImageOwner& operator=(SDImageOwner&& other) noexcept {
if (this != &other) {
reset();
image_ = other.release();
}
return *this;
}
~SDImageOwner() {
reset();
}
sd_image_t* put() {
if (image_.data != nullptr) {
free(image_.data);
image_.data = nullptr;
}
image_.width = 0;
image_.height = 0;
return &image_;
}
sd_image_t& get() {
return image_;
}
const sd_image_t& get() const {
return image_;
}
sd_image_t release() {
sd_image_t image = image_;
image_ = {0, 0, 0, nullptr};
return image;
}
void reset(sd_image_t image = {0, 0, 0, nullptr}) {
if (image_.data != nullptr) {
free(image_.data);
}
image_ = image;
}
private:
sd_image_t image_ = {0, 0, 0, nullptr};
};
class SDImageVec {
public:
SDImageVec() = default;
SDImageVec(const SDImageVec&) = delete;
SDImageVec& operator=(const SDImageVec&) = delete;
SDImageVec(SDImageVec&& other) noexcept
: images_(std::move(other.images_)) {
}
SDImageVec& operator=(SDImageVec&& other) noexcept {
if (this != &other) {
clear();
images_ = std::move(other.images_);
}
return *this;
}
~SDImageVec() {
clear();
}
void push_back(sd_image_t image) {
images_.push_back(image);
}
void push_back(SDImageOwner&& image) {
images_.push_back(image.release());
}
void reserve(size_t count) {
images_.reserve(count);
}
void adopt(sd_image_t* images, int count) {
clear();
if (images == nullptr || count <= 0) {
free(images);
return;
}
images_.reserve(static_cast<size_t>(count));
for (int i = 0; i < count; ++i) {
images_.push_back(images[i]);
}
free(images);
}
size_t size() const {
return images_.size();
}
bool empty() const {
return images_.empty();
}
explicit operator bool() const {
return !images_.empty();
}
sd_image_t* data() {
return images_.data();
}
const sd_image_t* data() const {
return images_.data();
}
sd_image_t& operator[](size_t index) {
return images_[index];
}
const sd_image_t& operator[](size_t index) const {
return images_[index];
}
std::vector<sd_image_t>& raw() {
return images_;
}
const std::vector<sd_image_t>& raw() const {
return images_;
}
void clear() {
for (sd_image_t& image : images_) {
free(image.data);
image.data = nullptr;
}
images_.clear();
}
private:
std::vector<sd_image_t> images_;
};
#endif // __EXAMPLE_RESOURCE_OWNERS_H__

View File

@ -57,7 +57,6 @@ else()
endif() endif()
add_executable(${TARGET} add_executable(${TARGET}
../common/common.cpp
../common/log.cpp ../common/log.cpp
../common/media_io.cpp ../common/media_io.cpp
main.cpp main.cpp

View File

@ -8,19 +8,16 @@
#include <sstream> #include <sstream>
#include <vector> #include <vector>
#include <json.hpp>
#include "httplib.h" #include "httplib.h"
#include "stable-diffusion.h" #include "stable-diffusion.h"
#include "common/common.h" #include "common/common.hpp"
#include "common/media_io.h" #include "common/media_io.h"
#include "common/resource_owners.hpp"
#ifdef HAVE_INDEX_HTML #ifdef HAVE_INDEX_HTML
#include "frontend/dist/gen_index_html.h" #include "frontend/dist/gen_index_html.h"
#endif #endif
using json = nlohmann::json;
namespace fs = std::filesystem; namespace fs = std::filesystem;
// ----------------------- helpers ----------------------- // ----------------------- helpers -----------------------
@ -289,6 +286,18 @@ std::string get_lora_full_path(ServerRuntime& rt, const std::string& path) {
return (it != rt.lora_cache->end()) ? it->fullpath : ""; return (it != rt.lora_cache->end()) ? it->fullpath : "";
} }
void free_results(sd_image_t* result_images, int num_results) {
if (result_images) {
for (int i = 0; i < num_results; ++i) {
if (result_images[i].data) {
free(result_images[i].data);
result_images[i].data = nullptr;
}
}
}
free(result_images);
}
void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) { void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) {
const std::string serve_html_path = svr_params.serve_html_path; const std::string serve_html_path = svr_params.serve_html_path;
svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) { svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) {
@ -396,10 +405,10 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
LOG_DEBUG("%s\n", gen_params.to_string().c_str()); LOG_DEBUG("%s\n", gen_params.to_string().c_str());
SDImageOwner init_image({(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}); sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
SDImageOwner control_image({(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}); sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
SDImageOwner mask_image({(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr}); sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
SDImageVec pmid_images; std::vector<sd_image_t> pmid_images;
sd_img_gen_params_t img_gen_params = { sd_img_gen_params_t img_gen_params = {
gen_params.lora_vec.data(), gen_params.lora_vec.data(),
@ -407,19 +416,19 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
gen_params.prompt.c_str(), gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(), gen_params.negative_prompt.c_str(),
gen_params.clip_skip, gen_params.clip_skip,
init_image.get(), init_image,
nullptr, nullptr,
0, 0,
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image.get(), mask_image,
gen_params.width, gen_params.width,
gen_params.height, gen_params.height,
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
gen_params.batch_count, gen_params.batch_count,
control_image.get(), control_image,
gen_params.control_strength, gen_params.control_strength,
{ {
pmid_images.data(), pmid_images.data(),
@ -431,19 +440,13 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
gen_params.cache_params, gen_params.cache_params,
}; };
SDImageVec results; sd_image_t* results = nullptr;
int num_results = 0; int num_results = 0;
{ {
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex); std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count; num_results = gen_params.batch_count;
results.adopt(generate_image(runtime->sd_ctx, &img_gen_params), num_results);
}
if (!results) {
res.status = 500;
res.set_content(R"({"error":"generate failed"})", "application/json");
return;
} }
for (int i = 0; i < num_results; i++) { for (int i = 0; i < num_results; i++) {
@ -474,6 +477,8 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
item["b64_json"] = b64; item["b64_json"] = b64;
out["data"].push_back(item); out["data"].push_back(item);
} }
free_results(results, num_results);
res.set_content(out.dump(), "application/json"); res.set_content(out.dump(), "application/json");
res.status = 200; res.status = 200;
@ -594,9 +599,9 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
LOG_DEBUG("%s\n", gen_params.to_string().c_str()); LOG_DEBUG("%s\n", gen_params.to_string().c_str());
SDImageOwner init_image({0, 0, 3, nullptr}); sd_image_t init_image = {0, 0, 3, nullptr};
SDImageOwner control_image({0, 0, 3, nullptr}); sd_image_t control_image = {0, 0, 3, nullptr};
SDImageVec pmid_images; std::vector<sd_image_t> pmid_images;
auto get_resolved_width = [&gen_params, runtime]() -> int { auto get_resolved_width = [&gen_params, runtime]() -> int {
if (gen_params.width > 0) if (gen_params.width > 0)
@ -613,7 +618,7 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
return 512; return 512;
}; };
SDImageVec ref_images; std::vector<sd_image_t> ref_images;
ref_images.reserve(images_bytes.size()); ref_images.reserve(images_bytes.size());
for (auto& bytes : images_bytes) { for (auto& bytes : images_bytes) {
int img_w; int img_w;
@ -629,12 +634,12 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
continue; continue;
} }
SDImageOwner img({(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels}); sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels};
gen_params.set_width_and_height_if_unset(img.get().width, img.get().height); gen_params.set_width_and_height_if_unset(img.width, img.height);
ref_images.push_back(std::move(img)); ref_images.push_back(img);
} }
SDImageOwner mask_image({0, 0, 1, nullptr}); sd_image_t mask_image = {0};
if (!mask_bytes.empty()) { if (!mask_bytes.empty()) {
int expected_width = 0; int expected_width = 0;
int expected_height = 0; int expected_height = 0;
@ -650,10 +655,13 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
static_cast<int>(mask_bytes.size()), static_cast<int>(mask_bytes.size()),
mask_w, mask_h, mask_w, mask_h,
expected_width, expected_height, 1); expected_width, expected_height, 1);
mask_image.reset({(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw}); mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw};
gen_params.set_width_and_height_if_unset(mask_image.get().width, mask_image.get().height); gen_params.set_width_and_height_if_unset(mask_image.width, mask_image.height);
} else { } else {
mask_image.reset({(uint32_t)get_resolved_width(), (uint32_t)get_resolved_height(), 1, nullptr}); mask_image.width = get_resolved_width();
mask_image.height = get_resolved_height();
mask_image.channel = 1;
mask_image.data = nullptr;
} }
sd_img_gen_params_t img_gen_params = { sd_img_gen_params_t img_gen_params = {
@ -662,19 +670,19 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
gen_params.prompt.c_str(), gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(), gen_params.negative_prompt.c_str(),
gen_params.clip_skip, gen_params.clip_skip,
init_image.get(), init_image,
ref_images.data(), ref_images.data(),
(int)ref_images.size(), (int)ref_images.size(),
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image.get(), mask_image,
get_resolved_width(), get_resolved_width(),
get_resolved_height(), get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
gen_params.batch_count, gen_params.batch_count,
control_image.get(), control_image,
gen_params.control_strength, gen_params.control_strength,
{ {
pmid_images.data(), pmid_images.data(),
@ -686,19 +694,13 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
gen_params.cache_params, gen_params.cache_params,
}; };
SDImageVec results; sd_image_t* results = nullptr;
int num_results = 0; int num_results = 0;
{ {
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex); std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count; num_results = gen_params.batch_count;
results.adopt(generate_image(runtime->sd_ctx, &img_gen_params), num_results);
}
if (!results) {
res.status = 500;
res.set_content(R"({"error":"generate failed"})", "application/json");
return;
} }
json out; json out;
@ -728,8 +730,20 @@ void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
item["b64_json"] = b64; item["b64_json"] = b64;
out["data"].push_back(item); out["data"].push_back(item);
} }
free_results(results, num_results);
res.set_content(out.dump(), "application/json"); res.set_content(out.dump(), "application/json");
res.status = 200; res.status = 200;
if (init_image.data) {
free(init_image.data);
}
if (mask_image.data) {
free(mask_image.data);
}
for (auto ref_image : ref_images) {
free(ref_image.data);
}
} catch (const std::exception& e) { } catch (const std::exception& e) {
res.status = 500; res.status = 500;
json err; json err;
@ -878,11 +892,12 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
LOG_DEBUG("%s\n", gen_params.to_string().c_str()); LOG_DEBUG("%s\n", gen_params.to_string().c_str());
SDImageOwner init_image({0, 0, 3, nullptr}); sd_image_t init_image = {0, 0, 3, nullptr};
SDImageOwner control_image({0, 0, 3, nullptr}); sd_image_t control_image = {0, 0, 3, nullptr};
SDImageOwner mask_image({0, 0, 1, nullptr}); sd_image_t mask_image = {0, 0, 1, nullptr};
SDImageVec pmid_images; std::vector<uint8_t> mask_data;
SDImageVec ref_images; std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> ref_images;
auto get_resolved_width = [&gen_params, runtime]() -> int { auto get_resolved_width = [&gen_params, runtime]() -> int {
if (gen_params.width > 0) if (gen_params.width > 0)
@ -899,7 +914,7 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
return 512; return 512;
}; };
auto decode_image = [&gen_params](SDImageOwner& image, std::string encoded) -> bool { auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool {
auto comma_pos = encoded.find(','); auto comma_pos = encoded.find(',');
if (comma_pos != std::string::npos) { if (comma_pos != std::string::npos) {
encoded = encoded.substr(comma_pos + 1); encoded = encoded.substr(comma_pos + 1);
@ -918,10 +933,10 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
uint8_t* raw_data = load_image_from_memory( uint8_t* raw_data = load_image_from_memory(
(const char*)img_data.data(), (int)img_data.size(), (const char*)img_data.data(), (int)img_data.size(),
img_w, img_h, img_w, img_h,
expected_width, expected_height, image.get().channel); expected_width, expected_height, image.channel);
if (raw_data) { if (raw_data) {
image.reset({(uint32_t)img_w, (uint32_t)img_h, image.get().channel, raw_data}); image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
gen_params.set_width_and_height_if_unset(image.get().width, image.get().height); gen_params.set_width_and_height_if_unset(image.width, image.height);
return true; return true;
} }
} }
@ -938,21 +953,19 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
std::string encoded = j["mask"].get<std::string>(); std::string encoded = j["mask"].get<std::string>();
decode_image(mask_image, encoded); decode_image(mask_image, encoded);
bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0; bool inpainting_mask_invert = j.value("inpainting_mask_invert", 0) != 0;
if (inpainting_mask_invert && mask_image.get().data != nullptr) { if (inpainting_mask_invert && mask_image.data != nullptr) {
for (uint32_t i = 0; i < mask_image.get().width * mask_image.get().height; i++) { for (uint32_t i = 0; i < mask_image.width * mask_image.height; i++) {
mask_image.get().data[i] = 255 - mask_image.get().data[i]; mask_image.data[i] = 255 - mask_image.data[i];
} }
} }
} else { } else {
int m_width = get_resolved_width(); int m_width = get_resolved_width();
int m_height = get_resolved_height(); int m_height = get_resolved_height();
sd_image_t generated_mask = {(uint32_t)m_width, (uint32_t)m_height, 1, nullptr}; mask_data = std::vector<uint8_t>(m_width * m_height, 255);
generated_mask.data = (uint8_t*)malloc(static_cast<size_t>(m_width) * static_cast<size_t>(m_height)); mask_image.width = m_width;
if (generated_mask.data == nullptr) { mask_image.height = m_height;
return bad("failed to allocate default mask"); mask_image.channel = 1;
} mask_image.data = mask_data.data();
memset(generated_mask.data, 255, static_cast<size_t>(m_width) * static_cast<size_t>(m_height));
mask_image.reset(generated_mask);
} }
float denoising_strength = j.value("denoising_strength", -1.f); float denoising_strength = j.value("denoising_strength", -1.f);
@ -965,9 +978,9 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
if (j.contains("extra_images") && j["extra_images"].is_array()) { if (j.contains("extra_images") && j["extra_images"].is_array()) {
for (auto extra_image : j["extra_images"]) { for (auto extra_image : j["extra_images"]) {
std::string encoded = extra_image.get<std::string>(); std::string encoded = extra_image.get<std::string>();
SDImageOwner tmp_image({(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}); sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
if (decode_image(tmp_image, encoded)) { if (decode_image(tmp_image, encoded)) {
ref_images.push_back(std::move(tmp_image)); ref_images.push_back(tmp_image);
} }
} }
} }
@ -978,19 +991,19 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
gen_params.prompt.c_str(), gen_params.prompt.c_str(),
gen_params.negative_prompt.c_str(), gen_params.negative_prompt.c_str(),
gen_params.clip_skip, gen_params.clip_skip,
init_image.get(), init_image,
ref_images.data(), ref_images.data(),
(int)ref_images.size(), (int)ref_images.size(),
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image.get(), mask_image,
get_resolved_width(), get_resolved_width(),
get_resolved_height(), get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
gen_params.batch_count, gen_params.batch_count,
control_image.get(), control_image,
gen_params.control_strength, gen_params.control_strength,
{ {
pmid_images.data(), pmid_images.data(),
@ -1002,19 +1015,13 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
gen_params.cache_params, gen_params.cache_params,
}; };
SDImageVec results; sd_image_t* results = nullptr;
int num_results = 0; int num_results = 0;
{ {
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex); std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count; num_results = gen_params.batch_count;
results.adopt(generate_image(runtime->sd_ctx, &img_gen_params), num_results);
}
if (!results) {
res.status = 500;
res.set_content(R"({"error":"generate failed"})", "application/json");
return;
} }
json out; json out;
@ -1045,9 +1052,21 @@ void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
std::string b64 = base64_encode(image_bytes); std::string b64 = base64_encode(image_bytes);
out["images"].push_back(b64); out["images"].push_back(b64);
} }
free_results(results, num_results);
res.set_content(out.dump(), "application/json"); res.set_content(out.dump(), "application/json");
res.status = 200; res.status = 200;
if (init_image.data) {
free(init_image.data);
}
if (mask_image.data && mask_data.empty()) {
free(mask_image.data);
}
for (auto ref_image : ref_images) {
free(ref_image.data);
}
} catch (const std::exception& e) { } catch (const std::exception& e) {
res.status = 500; res.status = 500;
json err; json err;
@ -1159,7 +1178,7 @@ int main(int argc, const char** argv) {
LOG_DEBUG("%s", default_gen_params.to_string().c_str()); LOG_DEBUG("%s", default_gen_params.to_string().c_str());
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false); sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false);
SDCtxPtr sd_ctx(new_sd_ctx(&sd_ctx_params)); sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
if (sd_ctx == nullptr) { if (sd_ctx == nullptr) {
LOG_ERROR("new_sd_ctx_t failed"); LOG_ERROR("new_sd_ctx_t failed");
@ -1171,7 +1190,7 @@ int main(int argc, const char** argv) {
std::vector<LoraEntry> lora_cache; std::vector<LoraEntry> lora_cache;
std::mutex lora_mutex; std::mutex lora_mutex;
ServerRuntime runtime = { ServerRuntime runtime = {
sd_ctx.get(), sd_ctx,
&sd_ctx_mutex, &sd_ctx_mutex,
&svr_params, &svr_params,
&ctx_params, &ctx_params,
@ -1212,5 +1231,6 @@ int main(int argc, const char** argv) {
LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port); LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
svr.listen(svr_params.listen_ip, svr_params.listen_port); svr.listen(svr_params.listen_ip, svr_params.listen_port);
free_sd_ctx(sd_ctx);
return 0; return 0;
} }

View File

@ -1,6 +1,4 @@
for f in src/*.cpp src/*.h src/*.hpp src/vocab/*.h src/vocab/*.cpp \ for f in src/*.cpp src/*.h src/*.hpp src/vocab/*.h src/vocab/*.cpp examples/cli/*.cpp examples/common/*.hpp examples/cli/*.h examples/server/*.cpp; do
examples/cli/*.cpp examples/cli/*.h examples/server/*.cpp \
examples/common/*.hpp examples/common/*.h examples/common/*.cpp; do
[[ "$f" == vocab* ]] && continue [[ "$f" == vocab* ]] && continue
echo "formatting '$f'" echo "formatting '$f'"
# if [ "$f" != "stable-diffusion.h" ]; then # if [ "$f" != "stable-diffusion.h" ]; then

View File

@ -1,4 +1,4 @@
#ifndef __T5_HPP__ #ifndef __T5_HPP__
#define __T5_HPP__ #define __T5_HPP__
#include <cfloat> #include <cfloat>

View File

@ -1,4 +1,4 @@
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>