refactor(server): split server endpoint registration (#1376)

This commit is contained in:
leejet 2026-03-31 00:02:03 +08:00 committed by GitHub
parent 8d878872d9
commit 83e8f6f0af
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -273,6 +273,63 @@ struct LoraEntry {
std::string fullpath;
};
struct ServerRuntime {
sd_ctx_t* sd_ctx;
std::mutex* sd_ctx_mutex;
const SDSvrParams* svr_params;
const SDContextParams* ctx_params;
const SDGenerationParams* default_gen_params;
std::vector<LoraEntry>* lora_cache;
std::mutex* lora_mutex;
};
void refresh_lora_cache(ServerRuntime& rt) {
std::vector<LoraEntry> new_cache;
fs::path lora_dir = rt.ctx_params->lora_model_dir;
if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) {
auto is_lora_ext = [](const fs::path& p) {
auto ext = p.extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors";
};
for (auto& entry : fs::recursive_directory_iterator(lora_dir)) {
if (!entry.is_regular_file())
continue;
const fs::path& p = entry.path();
if (!is_lora_ext(p))
continue;
LoraEntry e;
e.name = p.stem().u8string();
e.fullpath = p.u8string();
std::string rel = p.lexically_relative(lora_dir).u8string();
std::replace(rel.begin(), rel.end(), '\\', '/');
e.path = rel;
new_cache.push_back(std::move(e));
}
}
std::sort(new_cache.begin(), new_cache.end(),
[](const LoraEntry& a, const LoraEntry& b) {
return a.path < b.path;
});
{
std::lock_guard<std::mutex> lock(*rt.lora_mutex);
*rt.lora_cache = std::move(new_cache);
}
}
std::string get_lora_full_path(ServerRuntime& rt, const std::string& path) {
std::lock_guard<std::mutex> lock(*rt.lora_mutex);
auto it = std::find_if(rt.lora_cache->begin(), rt.lora_cache->end(),
[&](const LoraEntry& e) { return e.path == path; });
return (it != rt.lora_cache->end()) ? it->fullpath : "";
}
void free_results(sd_image_t* result_images, int num_results) {
if (result_images) {
for (int i = 0; i < num_results; ++i) {
@ -285,115 +342,11 @@ void free_results(sd_image_t* result_images, int num_results) {
free(result_images);
}
int main(int argc, const char** argv) {
if (argc > 1 && std::string(argv[1]) == "--version") {
std::cout << version_string() << "\n";
return EXIT_SUCCESS;
}
SDSvrParams svr_params;
SDContextParams ctx_params;
SDGenerationParams default_gen_params;
parse_args(argc, argv, svr_params, ctx_params, default_gen_params);
sd_set_log_callback(sd_log_cb, (void*)&svr_params);
log_verbose = svr_params.verbose;
log_color = svr_params.color;
LOG_DEBUG("version: %s", version_string().c_str());
LOG_DEBUG("%s", sd_get_system_info());
LOG_DEBUG("%s", svr_params.to_string().c_str());
LOG_DEBUG("%s", ctx_params.to_string().c_str());
LOG_DEBUG("%s", default_gen_params.to_string().c_str());
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false);
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
if (sd_ctx == nullptr) {
LOG_ERROR("new_sd_ctx_t failed");
return 1;
}
std::mutex sd_ctx_mutex;
std::vector<LoraEntry> lora_cache;
std::mutex lora_mutex;
auto refresh_lora_cache = [&]() {
std::vector<LoraEntry> new_cache;
fs::path lora_dir = ctx_params.lora_model_dir;
if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) {
auto is_lora_ext = [](const fs::path& p) {
auto ext = p.extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors";
};
for (auto& entry : fs::recursive_directory_iterator(lora_dir)) {
if (!entry.is_regular_file())
continue;
const fs::path& p = entry.path();
if (!is_lora_ext(p))
continue;
LoraEntry e;
e.name = p.stem().u8string();
e.fullpath = p.u8string();
std::string rel = p.lexically_relative(lora_dir).u8string();
std::replace(rel.begin(), rel.end(), '\\', '/');
e.path = rel;
new_cache.push_back(std::move(e));
}
}
std::sort(new_cache.begin(), new_cache.end(),
[](const LoraEntry& a, const LoraEntry& b) {
return a.path < b.path;
});
{
std::lock_guard<std::mutex> lock(lora_mutex);
lora_cache = std::move(new_cache);
}
};
auto get_lora_full_path = [&](const std::string& path) -> std::string {
std::lock_guard<std::mutex> lock(lora_mutex);
auto it = std::find_if(lora_cache.begin(), lora_cache.end(),
[&](const LoraEntry& e) { return e.path == path; });
return (it != lora_cache.end()) ? it->fullpath : "";
};
httplib::Server svr;
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
std::string origin = req.get_header_value("Origin");
if (origin.empty()) {
origin = "*";
}
res.set_header("Access-Control-Allow-Origin", origin);
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "*");
res.set_header("Access-Control-Allow-Headers", "*");
if (req.method == "OPTIONS") {
res.status = 204;
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
});
// index html
std::string index_html;
#ifdef HAVE_INDEX_HTML
index_html.assign(reinterpret_cast<const char*>(index_html_bytes), index_html_size);
#else
index_html = "Stable Diffusion Server is running";
#endif
svr.Get("/", [&](const httplib::Request&, httplib::Response& res) {
if (!svr_params.serve_html_path.empty()) {
std::ifstream file(svr_params.serve_html_path);
void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) {
const std::string serve_html_path = svr_params.serve_html_path;
svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) {
if (!serve_html_path.empty()) {
std::ifstream file(serve_html_path);
if (file) {
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
res.set_content(content, "text/html");
@ -405,17 +358,19 @@ int main(int argc, const char** argv) {
res.set_content(index_html, "text/html");
}
});
}
// models endpoint (minimal)
svr.Get("/v1/models", [&](const httplib::Request&, httplib::Response& res) {
void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
ServerRuntime* runtime = &rt;
svr.Get("/v1/models", [runtime](const httplib::Request&, httplib::Response& res) {
json r;
r["data"] = json::array();
r["data"].push_back({{"id", "sd-cpp-local"}, {"object", "model"}, {"owned_by", "local"}});
res.set_content(r.dump(), "application/json");
});
// core endpoint: /v1/images/generations
svr.Post("/v1/images/generations", [&](const httplib::Request& req, httplib::Response& res) {
svr.Post("/v1/images/generations", [runtime](const httplib::Request& req, httplib::Response& res) {
try {
if (req.body.empty()) {
res.status = 400;
@ -429,8 +384,8 @@ int main(int argc, const char** argv) {
std::string size = j.value("size", "");
std::string output_format = j.value("output_format", "png");
int output_compression = j.value("output_compression", 100);
int width = default_gen_params.width > 0 ? default_gen_params.width : 512;
int height = default_gen_params.width > 0 ? default_gen_params.height : 512;
int width = runtime->default_gen_params->width > 0 ? runtime->default_gen_params->width : 512;
int height = runtime->default_gen_params->width > 0 ? runtime->default_gen_params->height : 512;
if (!size.empty()) {
auto pos = size.find('x');
if (pos != std::string::npos) {
@ -458,7 +413,7 @@ int main(int argc, const char** argv) {
if (n <= 0)
n = 1;
if (n > 8)
n = 8; // safety
n = 8;
if (output_compression > 100) {
output_compression = 100;
}
@ -471,7 +426,7 @@ int main(int argc, const char** argv) {
out["data"] = json::array();
out["output_format"] = output_format;
SDGenerationParams gen_params = default_gen_params;
SDGenerationParams gen_params = *runtime->default_gen_params;
gen_params.prompt = prompt;
gen_params.width = width;
gen_params.height = height;
@ -524,7 +479,7 @@ int main(int argc, const char** argv) {
(int)pmid_images.size(),
gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength,
}, // pm_params
},
gen_params.vae_tiling_params,
gen_params.cache_params,
};
@ -533,8 +488,8 @@ int main(int argc, const char** argv) {
int num_results = 0;
{
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params);
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}
@ -553,7 +508,6 @@ int main(int argc, const char** argv) {
continue;
}
// base64 encode
std::string b64 = base64_encode(image_bytes);
json item;
item["b64_json"] = b64;
@ -573,7 +527,7 @@ int main(int argc, const char** argv) {
}
});
svr.Post("/v1/images/edits", [&](const httplib::Request& req, httplib::Response& res) {
svr.Post("/v1/images/edits", [runtime](const httplib::Request& req, httplib::Response& res) {
try {
if (!req.is_multipart_form_data()) {
res.status = 400;
@ -658,7 +612,7 @@ int main(int argc, const char** argv) {
output_compression = 0;
}
SDGenerationParams gen_params = default_gen_params;
SDGenerationParams gen_params = *runtime->default_gen_params;
gen_params.prompt = prompt;
gen_params.width = width;
gen_params.height = height;
@ -685,18 +639,18 @@ int main(int argc, const char** argv) {
sd_image_t control_image = {0, 0, 3, nullptr};
std::vector<sd_image_t> pmid_images;
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
auto get_resolved_width = [&gen_params, runtime]() -> int {
if (gen_params.width > 0)
return gen_params.width;
if (default_gen_params.width > 0)
return default_gen_params.width;
if (runtime->default_gen_params->width > 0)
return runtime->default_gen_params->width;
return 512;
};
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
auto get_resolved_height = [&gen_params, runtime]() -> int {
if (gen_params.height > 0)
return gen_params.height;
if (default_gen_params.height > 0)
return default_gen_params.height;
if (runtime->default_gen_params->height > 0)
return runtime->default_gen_params->height;
return 512;
};
@ -771,7 +725,7 @@ int main(int argc, const char** argv) {
(int)pmid_images.size(),
gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength,
}, // pm_params
},
gen_params.vae_tiling_params,
gen_params.cache_params,
};
@ -780,8 +734,8 @@ int main(int argc, const char** argv) {
int num_results = 0;
{
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params);
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}
@ -826,10 +780,12 @@ int main(int argc, const char** argv) {
res.set_content(err.dump(), "application/json");
}
});
}
// sdapi endpoints (AUTOMATIC1111 / Forge)
void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
ServerRuntime* runtime = &rt;
auto sdapi_any2img = [&](const httplib::Request& req, httplib::Response& res, bool img2img) {
auto sdapi_any2img = [runtime](const httplib::Request& req, httplib::Response& res, bool img2img) {
try {
if (req.body.empty()) {
res.status = 400;
@ -843,8 +799,8 @@ int main(int argc, const char** argv) {
std::string negative_prompt = j.value("negative_prompt", "");
int width = j.value("width", 512);
int height = j.value("height", 512);
int steps = j.value("steps", default_gen_params.sample_params.sample_steps);
float cfg_scale = j.value("cfg_scale", default_gen_params.sample_params.guidance.txt_cfg);
int steps = j.value("steps", runtime->default_gen_params->sample_params.sample_steps);
float cfg_scale = j.value("cfg_scale", runtime->default_gen_params->sample_params.guidance.txt_cfg);
int64_t seed = j.value("seed", -1);
int batch_size = j.value("batch_size", 1);
int clip_skip = j.value("clip_skip", -1);
@ -894,7 +850,7 @@ int main(int argc, const char** argv) {
return bad("lora.path required");
}
std::string fullpath = get_lora_full_path(path);
std::string fullpath = get_lora_full_path(*runtime, path);
if (fullpath.empty()) {
return bad("invalid lora path: " + path);
}
@ -912,7 +868,6 @@ int main(int argc, const char** argv) {
auto get_sample_method = [](std::string name) -> enum sample_method_t {
enum sample_method_t result = str_to_sample_method(name.c_str());
if (result != SAMPLE_METHOD_COUNT) return result;
// some applications use a hardcoded sampler list
std::transform(name.begin(), name.end(), name.begin(),
[](unsigned char c) { return std::tolower(c); });
static const std::unordered_map<std::string_view, sample_method_t> hardcoded{
@ -938,10 +893,9 @@ int main(int argc, const char** argv) {
};
enum sample_method_t sample_method = get_sample_method(sampler_name);
enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
SDGenerationParams gen_params = default_gen_params;
SDGenerationParams gen_params = *runtime->default_gen_params;
gen_params.prompt = prompt;
gen_params.negative_prompt = negative_prompt;
gen_params.seed = seed;
@ -961,8 +915,6 @@ int main(int argc, const char** argv) {
gen_params.sample_params.scheduler = scheduler;
}
// re-read to avoid applying 512 as default before the provided
// images and/or server command-line
gen_params.width = j.value("width", -1);
gen_params.height = j.value("height", -1);
@ -975,23 +927,22 @@ int main(int argc, const char** argv) {
std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> ref_images;
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
auto get_resolved_width = [&gen_params, runtime]() -> int {
if (gen_params.width > 0)
return gen_params.width;
if (default_gen_params.width > 0)
return default_gen_params.width;
if (runtime->default_gen_params->width > 0)
return runtime->default_gen_params->width;
return 512;
};
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
auto get_resolved_height = [&gen_params, runtime]() -> int {
if (gen_params.height > 0)
return gen_params.height;
if (default_gen_params.height > 0)
return default_gen_params.height;
if (runtime->default_gen_params->height > 0)
return runtime->default_gen_params->height;
return 512;
};
auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool {
// remove data URI prefix if present ("data:image/png;base64,")
auto comma_pos = encoded.find(',');
if (comma_pos != std::string::npos) {
encoded = encoded.substr(comma_pos + 1);
@ -1087,7 +1038,7 @@ int main(int argc, const char** argv) {
(int)pmid_images.size(),
gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength,
}, // pm_params
},
gen_params.vae_tiling_params,
gen_params.cache_params,
};
@ -1096,14 +1047,14 @@ int main(int argc, const char** argv) {
int num_results = 0;
{
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params);
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count;
}
json out;
out["images"] = json::array();
out["parameters"] = j; // TODO should return changed defaults
out["parameters"] = j;
out["info"] = "";
for (int i = 0; i < num_results; i++) {
@ -1149,21 +1100,21 @@ int main(int argc, const char** argv) {
}
};
svr.Post("/sdapi/v1/txt2img", [&](const httplib::Request& req, httplib::Response& res) {
svr.Post("/sdapi/v1/txt2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
sdapi_any2img(req, res, false);
});
svr.Post("/sdapi/v1/img2img", [&](const httplib::Request& req, httplib::Response& res) {
svr.Post("/sdapi/v1/img2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
sdapi_any2img(req, res, true);
});
svr.Get("/sdapi/v1/loras", [&](const httplib::Request&, httplib::Response& res) {
refresh_lora_cache();
svr.Get("/sdapi/v1/loras", [runtime](const httplib::Request&, httplib::Response& res) {
refresh_lora_cache(*runtime);
json result = json::array();
{
std::lock_guard<std::mutex> lock(lora_mutex);
for (const auto& e : lora_cache) {
std::lock_guard<std::mutex> lock(*runtime->lora_mutex);
for (const auto& e : *runtime->lora_cache) {
json item;
item["name"] = e.name;
item["path"] = e.path;
@ -1174,7 +1125,7 @@ int main(int argc, const char** argv) {
res.set_content(result.dump(), "application/json");
});
svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
svr.Get("/sdapi/v1/samplers", [runtime](const httplib::Request&, httplib::Response& res) {
std::vector<std::string> sampler_names;
sampler_names.push_back("default");
for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) {
@ -1191,7 +1142,7 @@ int main(int argc, const char** argv) {
res.set_content(r.dump(), "application/json");
});
svr.Get("/sdapi/v1/schedulers", [&](const httplib::Request&, httplib::Response& res) {
svr.Get("/sdapi/v1/schedulers", [runtime](const httplib::Request&, httplib::Response& res) {
std::vector<std::string> scheduler_names;
scheduler_names.push_back("default");
for (int i = 0; i < SCHEDULER_COUNT; i++) {
@ -1207,8 +1158,8 @@ int main(int argc, const char** argv) {
res.set_content(r.dump(), "application/json");
});
svr.Get("/sdapi/v1/sd-models", [&](const httplib::Request&, httplib::Response& res) {
fs::path model_path = ctx_params.model_path;
svr.Get("/sdapi/v1/sd-models", [runtime](const httplib::Request&, httplib::Response& res) {
fs::path model_path = runtime->ctx_params->model_path;
json entry;
entry["title"] = model_path.stem();
entry["model_name"] = model_path.stem();
@ -1221,18 +1172,89 @@ int main(int argc, const char** argv) {
res.set_content(r.dump(), "application/json");
});
svr.Get("/sdapi/v1/options", [&](const httplib::Request&, httplib::Response& res) {
fs::path model_path = ctx_params.model_path;
svr.Get("/sdapi/v1/options", [runtime](const httplib::Request&, httplib::Response& res) {
fs::path model_path = runtime->ctx_params->model_path;
json r;
r["samples_format"] = "png";
r["sd_model_checkpoint"] = model_path.stem();
res.set_content(r.dump(), "application/json");
});
}
int main(int argc, const char** argv) {
if (argc > 1 && std::string(argv[1]) == "--version") {
std::cout << version_string() << "\n";
return EXIT_SUCCESS;
}
SDSvrParams svr_params;
SDContextParams ctx_params;
SDGenerationParams default_gen_params;
parse_args(argc, argv, svr_params, ctx_params, default_gen_params);
sd_set_log_callback(sd_log_cb, (void*)&svr_params);
log_verbose = svr_params.verbose;
log_color = svr_params.color;
LOG_DEBUG("version: %s", version_string().c_str());
LOG_DEBUG("%s", sd_get_system_info());
LOG_DEBUG("%s", svr_params.to_string().c_str());
LOG_DEBUG("%s", ctx_params.to_string().c_str());
LOG_DEBUG("%s", default_gen_params.to_string().c_str());
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false);
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
if (sd_ctx == nullptr) {
LOG_ERROR("new_sd_ctx_t failed");
return 1;
}
std::mutex sd_ctx_mutex;
std::vector<LoraEntry> lora_cache;
std::mutex lora_mutex;
ServerRuntime runtime = {
sd_ctx,
&sd_ctx_mutex,
&svr_params,
&ctx_params,
&default_gen_params,
&lora_cache,
&lora_mutex,
};
httplib::Server svr;
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
std::string origin = req.get_header_value("Origin");
if (origin.empty()) {
origin = "*";
}
res.set_header("Access-Control-Allow-Origin", origin);
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "*");
res.set_header("Access-Control-Allow-Headers", "*");
if (req.method == "OPTIONS") {
res.status = 204;
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
});
std::string index_html;
#ifdef HAVE_INDEX_HTML
index_html.assign(reinterpret_cast<const char*>(index_html_bytes), index_html_size);
#else
index_html = "Stable Diffusion Server is running";
#endif
register_index_endpoints(svr, svr_params, index_html);
register_openai_api_endpoints(svr, runtime);
register_sdapi_endpoints(svr, runtime);
LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
svr.listen(svr_params.listen_ip, svr_params.listen_port);
// cleanup
free_sd_ctx(sd_ctx);
return 0;
}