diff --git a/examples/server/main.cpp b/examples/server/main.cpp index 8681f2e..25ae4a8 100644 --- a/examples/server/main.cpp +++ b/examples/server/main.cpp @@ -273,6 +273,63 @@ struct LoraEntry { std::string fullpath; }; +struct ServerRuntime { + sd_ctx_t* sd_ctx; + std::mutex* sd_ctx_mutex; + const SDSvrParams* svr_params; + const SDContextParams* ctx_params; + const SDGenerationParams* default_gen_params; + std::vector* lora_cache; + std::mutex* lora_mutex; +}; + +void refresh_lora_cache(ServerRuntime& rt) { + std::vector new_cache; + + fs::path lora_dir = rt.ctx_params->lora_model_dir; + if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) { + auto is_lora_ext = [](const fs::path& p) { + auto ext = p.extension().string(); + std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); + return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors"; + }; + + for (auto& entry : fs::recursive_directory_iterator(lora_dir)) { + if (!entry.is_regular_file()) + continue; + const fs::path& p = entry.path(); + if (!is_lora_ext(p)) + continue; + + LoraEntry e; + e.name = p.stem().u8string(); + e.fullpath = p.u8string(); + std::string rel = p.lexically_relative(lora_dir).u8string(); + std::replace(rel.begin(), rel.end(), '\\', '/'); + e.path = rel; + + new_cache.push_back(std::move(e)); + } + } + + std::sort(new_cache.begin(), new_cache.end(), + [](const LoraEntry& a, const LoraEntry& b) { + return a.path < b.path; + }); + + { + std::lock_guard lock(*rt.lora_mutex); + *rt.lora_cache = std::move(new_cache); + } +} + +std::string get_lora_full_path(ServerRuntime& rt, const std::string& path) { + std::lock_guard lock(*rt.lora_mutex); + auto it = std::find_if(rt.lora_cache->begin(), rt.lora_cache->end(), + [&](const LoraEntry& e) { return e.path == path; }); + return (it != rt.lora_cache->end()) ? it->fullpath : ""; +} + void free_results(sd_image_t* result_images, int num_results) { if (result_images) { for (int i = 0; i < num_results; ++i) { @@ -285,115 +342,11 @@ void free_results(sd_image_t* result_images, int num_results) { free(result_images); } -int main(int argc, const char** argv) { - if (argc > 1 && std::string(argv[1]) == "--version") { - std::cout << version_string() << "\n"; - return EXIT_SUCCESS; - } - SDSvrParams svr_params; - SDContextParams ctx_params; - SDGenerationParams default_gen_params; - parse_args(argc, argv, svr_params, ctx_params, default_gen_params); - - sd_set_log_callback(sd_log_cb, (void*)&svr_params); - log_verbose = svr_params.verbose; - log_color = svr_params.color; - - LOG_DEBUG("version: %s", version_string().c_str()); - LOG_DEBUG("%s", sd_get_system_info()); - LOG_DEBUG("%s", svr_params.to_string().c_str()); - LOG_DEBUG("%s", ctx_params.to_string().c_str()); - LOG_DEBUG("%s", default_gen_params.to_string().c_str()); - - sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false); - sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); - - if (sd_ctx == nullptr) { - LOG_ERROR("new_sd_ctx_t failed"); - return 1; - } - - std::mutex sd_ctx_mutex; - - std::vector lora_cache; - std::mutex lora_mutex; - - auto refresh_lora_cache = [&]() { - std::vector new_cache; - - fs::path lora_dir = ctx_params.lora_model_dir; - if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) { - auto is_lora_ext = [](const fs::path& p) { - auto ext = p.extension().string(); - std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower); - return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors"; - }; - - for (auto& entry : fs::recursive_directory_iterator(lora_dir)) { - if (!entry.is_regular_file()) - continue; - const fs::path& p = entry.path(); - if (!is_lora_ext(p)) - continue; - - LoraEntry e; - e.name = p.stem().u8string(); - e.fullpath = p.u8string(); - std::string rel = p.lexically_relative(lora_dir).u8string(); - std::replace(rel.begin(), rel.end(), '\\', '/'); - e.path = rel; - - new_cache.push_back(std::move(e)); - } - } - - std::sort(new_cache.begin(), new_cache.end(), - [](const LoraEntry& a, const LoraEntry& b) { - return a.path < b.path; - }); - - { - std::lock_guard lock(lora_mutex); - lora_cache = std::move(new_cache); - } - }; - - auto get_lora_full_path = [&](const std::string& path) -> std::string { - std::lock_guard lock(lora_mutex); - auto it = std::find_if(lora_cache.begin(), lora_cache.end(), - [&](const LoraEntry& e) { return e.path == path; }); - return (it != lora_cache.end()) ? it->fullpath : ""; - }; - - httplib::Server svr; - - svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) { - std::string origin = req.get_header_value("Origin"); - if (origin.empty()) { - origin = "*"; - } - res.set_header("Access-Control-Allow-Origin", origin); - res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "*"); - res.set_header("Access-Control-Allow-Headers", "*"); - - if (req.method == "OPTIONS") { - res.status = 204; - return httplib::Server::HandlerResponse::Handled; - } - return httplib::Server::HandlerResponse::Unhandled; - }); - - // index html - std::string index_html; -#ifdef HAVE_INDEX_HTML - index_html.assign(reinterpret_cast(index_html_bytes), index_html_size); -#else - index_html = "Stable Diffusion Server is running"; -#endif - svr.Get("/", [&](const httplib::Request&, httplib::Response& res) { - if (!svr_params.serve_html_path.empty()) { - std::ifstream file(svr_params.serve_html_path); +void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) { + const std::string serve_html_path = svr_params.serve_html_path; + svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) { + if (!serve_html_path.empty()) { + std::ifstream file(serve_html_path); if (file) { std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator()); res.set_content(content, "text/html"); @@ -405,17 +358,19 @@ int main(int argc, const char** argv) { res.set_content(index_html, "text/html"); } }); +} - // models endpoint (minimal) - svr.Get("/v1/models", [&](const httplib::Request&, httplib::Response& res) { +void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) { + ServerRuntime* runtime = &rt; + + svr.Get("/v1/models", [runtime](const httplib::Request&, httplib::Response& res) { json r; r["data"] = json::array(); r["data"].push_back({{"id", "sd-cpp-local"}, {"object", "model"}, {"owned_by", "local"}}); res.set_content(r.dump(), "application/json"); }); - // core endpoint: /v1/images/generations - svr.Post("/v1/images/generations", [&](const httplib::Request& req, httplib::Response& res) { + svr.Post("/v1/images/generations", [runtime](const httplib::Request& req, httplib::Response& res) { try { if (req.body.empty()) { res.status = 400; @@ -429,8 +384,8 @@ int main(int argc, const char** argv) { std::string size = j.value("size", ""); std::string output_format = j.value("output_format", "png"); int output_compression = j.value("output_compression", 100); - int width = default_gen_params.width > 0 ? default_gen_params.width : 512; - int height = default_gen_params.width > 0 ? default_gen_params.height : 512; + int width = runtime->default_gen_params->width > 0 ? runtime->default_gen_params->width : 512; + int height = runtime->default_gen_params->width > 0 ? runtime->default_gen_params->height : 512; if (!size.empty()) { auto pos = size.find('x'); if (pos != std::string::npos) { @@ -458,7 +413,7 @@ int main(int argc, const char** argv) { if (n <= 0) n = 1; if (n > 8) - n = 8; // safety + n = 8; if (output_compression > 100) { output_compression = 100; } @@ -471,7 +426,7 @@ int main(int argc, const char** argv) { out["data"] = json::array(); out["output_format"] = output_format; - SDGenerationParams gen_params = default_gen_params; + SDGenerationParams gen_params = *runtime->default_gen_params; gen_params.prompt = prompt; gen_params.width = width; gen_params.height = height; @@ -524,7 +479,7 @@ int main(int argc, const char** argv) { (int)pmid_images.size(), gen_params.pm_id_embed_path.c_str(), gen_params.pm_style_strength, - }, // pm_params + }, gen_params.vae_tiling_params, gen_params.cache_params, }; @@ -533,8 +488,8 @@ int main(int argc, const char** argv) { int num_results = 0; { - std::lock_guard lock(sd_ctx_mutex); - results = generate_image(sd_ctx, &img_gen_params); + std::lock_guard lock(*runtime->sd_ctx_mutex); + results = generate_image(runtime->sd_ctx, &img_gen_params); num_results = gen_params.batch_count; } @@ -553,7 +508,6 @@ int main(int argc, const char** argv) { continue; } - // base64 encode std::string b64 = base64_encode(image_bytes); json item; item["b64_json"] = b64; @@ -573,7 +527,7 @@ int main(int argc, const char** argv) { } }); - svr.Post("/v1/images/edits", [&](const httplib::Request& req, httplib::Response& res) { + svr.Post("/v1/images/edits", [runtime](const httplib::Request& req, httplib::Response& res) { try { if (!req.is_multipart_form_data()) { res.status = 400; @@ -658,7 +612,7 @@ int main(int argc, const char** argv) { output_compression = 0; } - SDGenerationParams gen_params = default_gen_params; + SDGenerationParams gen_params = *runtime->default_gen_params; gen_params.prompt = prompt; gen_params.width = width; gen_params.height = height; @@ -685,18 +639,18 @@ int main(int argc, const char** argv) { sd_image_t control_image = {0, 0, 3, nullptr}; std::vector pmid_images; - auto get_resolved_width = [&gen_params, &default_gen_params]() -> int { + auto get_resolved_width = [&gen_params, runtime]() -> int { if (gen_params.width > 0) return gen_params.width; - if (default_gen_params.width > 0) - return default_gen_params.width; + if (runtime->default_gen_params->width > 0) + return runtime->default_gen_params->width; return 512; }; - auto get_resolved_height = [&gen_params, &default_gen_params]() -> int { + auto get_resolved_height = [&gen_params, runtime]() -> int { if (gen_params.height > 0) return gen_params.height; - if (default_gen_params.height > 0) - return default_gen_params.height; + if (runtime->default_gen_params->height > 0) + return runtime->default_gen_params->height; return 512; }; @@ -771,7 +725,7 @@ int main(int argc, const char** argv) { (int)pmid_images.size(), gen_params.pm_id_embed_path.c_str(), gen_params.pm_style_strength, - }, // pm_params + }, gen_params.vae_tiling_params, gen_params.cache_params, }; @@ -780,8 +734,8 @@ int main(int argc, const char** argv) { int num_results = 0; { - std::lock_guard lock(sd_ctx_mutex); - results = generate_image(sd_ctx, &img_gen_params); + std::lock_guard lock(*runtime->sd_ctx_mutex); + results = generate_image(runtime->sd_ctx, &img_gen_params); num_results = gen_params.batch_count; } @@ -826,10 +780,12 @@ int main(int argc, const char** argv) { res.set_content(err.dump(), "application/json"); } }); +} - // sdapi endpoints (AUTOMATIC1111 / Forge) +void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) { + ServerRuntime* runtime = &rt; - auto sdapi_any2img = [&](const httplib::Request& req, httplib::Response& res, bool img2img) { + auto sdapi_any2img = [runtime](const httplib::Request& req, httplib::Response& res, bool img2img) { try { if (req.body.empty()) { res.status = 400; @@ -843,8 +799,8 @@ int main(int argc, const char** argv) { std::string negative_prompt = j.value("negative_prompt", ""); int width = j.value("width", 512); int height = j.value("height", 512); - int steps = j.value("steps", default_gen_params.sample_params.sample_steps); - float cfg_scale = j.value("cfg_scale", default_gen_params.sample_params.guidance.txt_cfg); + int steps = j.value("steps", runtime->default_gen_params->sample_params.sample_steps); + float cfg_scale = j.value("cfg_scale", runtime->default_gen_params->sample_params.guidance.txt_cfg); int64_t seed = j.value("seed", -1); int batch_size = j.value("batch_size", 1); int clip_skip = j.value("clip_skip", -1); @@ -894,7 +850,7 @@ int main(int argc, const char** argv) { return bad("lora.path required"); } - std::string fullpath = get_lora_full_path(path); + std::string fullpath = get_lora_full_path(*runtime, path); if (fullpath.empty()) { return bad("invalid lora path: " + path); } @@ -912,7 +868,6 @@ int main(int argc, const char** argv) { auto get_sample_method = [](std::string name) -> enum sample_method_t { enum sample_method_t result = str_to_sample_method(name.c_str()); if (result != SAMPLE_METHOD_COUNT) return result; - // some applications use a hardcoded sampler list std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) { return std::tolower(c); }); static const std::unordered_map hardcoded{ @@ -938,10 +893,9 @@ int main(int argc, const char** argv) { }; enum sample_method_t sample_method = get_sample_method(sampler_name); + enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str()); - enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str()); - - SDGenerationParams gen_params = default_gen_params; + SDGenerationParams gen_params = *runtime->default_gen_params; gen_params.prompt = prompt; gen_params.negative_prompt = negative_prompt; gen_params.seed = seed; @@ -961,8 +915,6 @@ int main(int argc, const char** argv) { gen_params.sample_params.scheduler = scheduler; } - // re-read to avoid applying 512 as default before the provided - // images and/or server command-line gen_params.width = j.value("width", -1); gen_params.height = j.value("height", -1); @@ -975,23 +927,22 @@ int main(int argc, const char** argv) { std::vector pmid_images; std::vector ref_images; - auto get_resolved_width = [&gen_params, &default_gen_params]() -> int { + auto get_resolved_width = [&gen_params, runtime]() -> int { if (gen_params.width > 0) return gen_params.width; - if (default_gen_params.width > 0) - return default_gen_params.width; + if (runtime->default_gen_params->width > 0) + return runtime->default_gen_params->width; return 512; }; - auto get_resolved_height = [&gen_params, &default_gen_params]() -> int { + auto get_resolved_height = [&gen_params, runtime]() -> int { if (gen_params.height > 0) return gen_params.height; - if (default_gen_params.height > 0) - return default_gen_params.height; + if (runtime->default_gen_params->height > 0) + return runtime->default_gen_params->height; return 512; }; auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool { - // remove data URI prefix if present ("data:image/png;base64,") auto comma_pos = encoded.find(','); if (comma_pos != std::string::npos) { encoded = encoded.substr(comma_pos + 1); @@ -1087,7 +1038,7 @@ int main(int argc, const char** argv) { (int)pmid_images.size(), gen_params.pm_id_embed_path.c_str(), gen_params.pm_style_strength, - }, // pm_params + }, gen_params.vae_tiling_params, gen_params.cache_params, }; @@ -1096,14 +1047,14 @@ int main(int argc, const char** argv) { int num_results = 0; { - std::lock_guard lock(sd_ctx_mutex); - results = generate_image(sd_ctx, &img_gen_params); + std::lock_guard lock(*runtime->sd_ctx_mutex); + results = generate_image(runtime->sd_ctx, &img_gen_params); num_results = gen_params.batch_count; } json out; out["images"] = json::array(); - out["parameters"] = j; // TODO should return changed defaults + out["parameters"] = j; out["info"] = ""; for (int i = 0; i < num_results; i++) { @@ -1149,21 +1100,21 @@ int main(int argc, const char** argv) { } }; - svr.Post("/sdapi/v1/txt2img", [&](const httplib::Request& req, httplib::Response& res) { + svr.Post("/sdapi/v1/txt2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) { sdapi_any2img(req, res, false); }); - svr.Post("/sdapi/v1/img2img", [&](const httplib::Request& req, httplib::Response& res) { + svr.Post("/sdapi/v1/img2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) { sdapi_any2img(req, res, true); }); - svr.Get("/sdapi/v1/loras", [&](const httplib::Request&, httplib::Response& res) { - refresh_lora_cache(); + svr.Get("/sdapi/v1/loras", [runtime](const httplib::Request&, httplib::Response& res) { + refresh_lora_cache(*runtime); json result = json::array(); { - std::lock_guard lock(lora_mutex); - for (const auto& e : lora_cache) { + std::lock_guard lock(*runtime->lora_mutex); + for (const auto& e : *runtime->lora_cache) { json item; item["name"] = e.name; item["path"] = e.path; @@ -1174,7 +1125,7 @@ int main(int argc, const char** argv) { res.set_content(result.dump(), "application/json"); }); - svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) { + svr.Get("/sdapi/v1/samplers", [runtime](const httplib::Request&, httplib::Response& res) { std::vector sampler_names; sampler_names.push_back("default"); for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) { @@ -1191,7 +1142,7 @@ int main(int argc, const char** argv) { res.set_content(r.dump(), "application/json"); }); - svr.Get("/sdapi/v1/schedulers", [&](const httplib::Request&, httplib::Response& res) { + svr.Get("/sdapi/v1/schedulers", [runtime](const httplib::Request&, httplib::Response& res) { std::vector scheduler_names; scheduler_names.push_back("default"); for (int i = 0; i < SCHEDULER_COUNT; i++) { @@ -1207,8 +1158,8 @@ int main(int argc, const char** argv) { res.set_content(r.dump(), "application/json"); }); - svr.Get("/sdapi/v1/sd-models", [&](const httplib::Request&, httplib::Response& res) { - fs::path model_path = ctx_params.model_path; + svr.Get("/sdapi/v1/sd-models", [runtime](const httplib::Request&, httplib::Response& res) { + fs::path model_path = runtime->ctx_params->model_path; json entry; entry["title"] = model_path.stem(); entry["model_name"] = model_path.stem(); @@ -1221,18 +1172,89 @@ int main(int argc, const char** argv) { res.set_content(r.dump(), "application/json"); }); - svr.Get("/sdapi/v1/options", [&](const httplib::Request&, httplib::Response& res) { - fs::path model_path = ctx_params.model_path; + svr.Get("/sdapi/v1/options", [runtime](const httplib::Request&, httplib::Response& res) { + fs::path model_path = runtime->ctx_params->model_path; json r; r["samples_format"] = "png"; r["sd_model_checkpoint"] = model_path.stem(); res.set_content(r.dump(), "application/json"); }); +} + +int main(int argc, const char** argv) { + if (argc > 1 && std::string(argv[1]) == "--version") { + std::cout << version_string() << "\n"; + return EXIT_SUCCESS; + } + SDSvrParams svr_params; + SDContextParams ctx_params; + SDGenerationParams default_gen_params; + parse_args(argc, argv, svr_params, ctx_params, default_gen_params); + + sd_set_log_callback(sd_log_cb, (void*)&svr_params); + log_verbose = svr_params.verbose; + log_color = svr_params.color; + + LOG_DEBUG("version: %s", version_string().c_str()); + LOG_DEBUG("%s", sd_get_system_info()); + LOG_DEBUG("%s", svr_params.to_string().c_str()); + LOG_DEBUG("%s", ctx_params.to_string().c_str()); + LOG_DEBUG("%s", default_gen_params.to_string().c_str()); + + sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false); + sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); + + if (sd_ctx == nullptr) { + LOG_ERROR("new_sd_ctx_t failed"); + return 1; + } + + std::mutex sd_ctx_mutex; + + std::vector lora_cache; + std::mutex lora_mutex; + ServerRuntime runtime = { + sd_ctx, + &sd_ctx_mutex, + &svr_params, + &ctx_params, + &default_gen_params, + &lora_cache, + &lora_mutex, + }; + + httplib::Server svr; + + svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) { + std::string origin = req.get_header_value("Origin"); + if (origin.empty()) { + origin = "*"; + } + res.set_header("Access-Control-Allow-Origin", origin); + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Allow-Methods", "*"); + res.set_header("Access-Control-Allow-Headers", "*"); + + if (req.method == "OPTIONS") { + res.status = 204; + return httplib::Server::HandlerResponse::Handled; + } + return httplib::Server::HandlerResponse::Unhandled; + }); + + std::string index_html; +#ifdef HAVE_INDEX_HTML + index_html.assign(reinterpret_cast(index_html_bytes), index_html_size); +#else + index_html = "Stable Diffusion Server is running"; +#endif + register_index_endpoints(svr, svr_params, index_html); + register_openai_api_endpoints(svr, runtime); + register_sdapi_endpoints(svr, runtime); LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port); svr.listen(svr_params.listen_ip, svr_params.listen_port); - // cleanup free_sd_ctx(sd_ctx); return 0; }