mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-31 05:39:42 +00:00
refactor(server): split server endpoint registration (#1376)
This commit is contained in:
parent
8d878872d9
commit
83e8f6f0af
@ -273,55 +273,20 @@ struct LoraEntry {
|
||||
std::string fullpath;
|
||||
};
|
||||
|
||||
void free_results(sd_image_t* result_images, int num_results) {
|
||||
if (result_images) {
|
||||
for (int i = 0; i < num_results; ++i) {
|
||||
if (result_images[i].data) {
|
||||
stbi_image_free(result_images[i].data);
|
||||
result_images[i].data = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
free(result_images);
|
||||
}
|
||||
struct ServerRuntime {
|
||||
sd_ctx_t* sd_ctx;
|
||||
std::mutex* sd_ctx_mutex;
|
||||
const SDSvrParams* svr_params;
|
||||
const SDContextParams* ctx_params;
|
||||
const SDGenerationParams* default_gen_params;
|
||||
std::vector<LoraEntry>* lora_cache;
|
||||
std::mutex* lora_mutex;
|
||||
};
|
||||
|
||||
int main(int argc, const char** argv) {
|
||||
if (argc > 1 && std::string(argv[1]) == "--version") {
|
||||
std::cout << version_string() << "\n";
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
SDSvrParams svr_params;
|
||||
SDContextParams ctx_params;
|
||||
SDGenerationParams default_gen_params;
|
||||
parse_args(argc, argv, svr_params, ctx_params, default_gen_params);
|
||||
|
||||
sd_set_log_callback(sd_log_cb, (void*)&svr_params);
|
||||
log_verbose = svr_params.verbose;
|
||||
log_color = svr_params.color;
|
||||
|
||||
LOG_DEBUG("version: %s", version_string().c_str());
|
||||
LOG_DEBUG("%s", sd_get_system_info());
|
||||
LOG_DEBUG("%s", svr_params.to_string().c_str());
|
||||
LOG_DEBUG("%s", ctx_params.to_string().c_str());
|
||||
LOG_DEBUG("%s", default_gen_params.to_string().c_str());
|
||||
|
||||
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false);
|
||||
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
|
||||
|
||||
if (sd_ctx == nullptr) {
|
||||
LOG_ERROR("new_sd_ctx_t failed");
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::mutex sd_ctx_mutex;
|
||||
|
||||
std::vector<LoraEntry> lora_cache;
|
||||
std::mutex lora_mutex;
|
||||
|
||||
auto refresh_lora_cache = [&]() {
|
||||
void refresh_lora_cache(ServerRuntime& rt) {
|
||||
std::vector<LoraEntry> new_cache;
|
||||
|
||||
fs::path lora_dir = ctx_params.lora_model_dir;
|
||||
fs::path lora_dir = rt.ctx_params->lora_model_dir;
|
||||
if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) {
|
||||
auto is_lora_ext = [](const fs::path& p) {
|
||||
auto ext = p.extension().string();
|
||||
@ -353,47 +318,35 @@ int main(int argc, const char** argv) {
|
||||
});
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(lora_mutex);
|
||||
lora_cache = std::move(new_cache);
|
||||
std::lock_guard<std::mutex> lock(*rt.lora_mutex);
|
||||
*rt.lora_cache = std::move(new_cache);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
auto get_lora_full_path = [&](const std::string& path) -> std::string {
|
||||
std::lock_guard<std::mutex> lock(lora_mutex);
|
||||
auto it = std::find_if(lora_cache.begin(), lora_cache.end(),
|
||||
std::string get_lora_full_path(ServerRuntime& rt, const std::string& path) {
|
||||
std::lock_guard<std::mutex> lock(*rt.lora_mutex);
|
||||
auto it = std::find_if(rt.lora_cache->begin(), rt.lora_cache->end(),
|
||||
[&](const LoraEntry& e) { return e.path == path; });
|
||||
return (it != lora_cache.end()) ? it->fullpath : "";
|
||||
};
|
||||
return (it != rt.lora_cache->end()) ? it->fullpath : "";
|
||||
}
|
||||
|
||||
httplib::Server svr;
|
||||
|
||||
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
|
||||
std::string origin = req.get_header_value("Origin");
|
||||
if (origin.empty()) {
|
||||
origin = "*";
|
||||
void free_results(sd_image_t* result_images, int num_results) {
|
||||
if (result_images) {
|
||||
for (int i = 0; i < num_results; ++i) {
|
||||
if (result_images[i].data) {
|
||||
stbi_image_free(result_images[i].data);
|
||||
result_images[i].data = nullptr;
|
||||
}
|
||||
res.set_header("Access-Control-Allow-Origin", origin);
|
||||
res.set_header("Access-Control-Allow-Credentials", "true");
|
||||
res.set_header("Access-Control-Allow-Methods", "*");
|
||||
res.set_header("Access-Control-Allow-Headers", "*");
|
||||
|
||||
if (req.method == "OPTIONS") {
|
||||
res.status = 204;
|
||||
return httplib::Server::HandlerResponse::Handled;
|
||||
}
|
||||
return httplib::Server::HandlerResponse::Unhandled;
|
||||
});
|
||||
}
|
||||
free(result_images);
|
||||
}
|
||||
|
||||
// index html
|
||||
std::string index_html;
|
||||
#ifdef HAVE_INDEX_HTML
|
||||
index_html.assign(reinterpret_cast<const char*>(index_html_bytes), index_html_size);
|
||||
#else
|
||||
index_html = "Stable Diffusion Server is running";
|
||||
#endif
|
||||
svr.Get("/", [&](const httplib::Request&, httplib::Response& res) {
|
||||
if (!svr_params.serve_html_path.empty()) {
|
||||
std::ifstream file(svr_params.serve_html_path);
|
||||
void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) {
|
||||
const std::string serve_html_path = svr_params.serve_html_path;
|
||||
svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) {
|
||||
if (!serve_html_path.empty()) {
|
||||
std::ifstream file(serve_html_path);
|
||||
if (file) {
|
||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||
res.set_content(content, "text/html");
|
||||
@ -405,17 +358,19 @@ int main(int argc, const char** argv) {
|
||||
res.set_content(index_html, "text/html");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// models endpoint (minimal)
|
||||
svr.Get("/v1/models", [&](const httplib::Request&, httplib::Response& res) {
|
||||
void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
|
||||
ServerRuntime* runtime = &rt;
|
||||
|
||||
svr.Get("/v1/models", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
json r;
|
||||
r["data"] = json::array();
|
||||
r["data"].push_back({{"id", "sd-cpp-local"}, {"object", "model"}, {"owned_by", "local"}});
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
|
||||
// core endpoint: /v1/images/generations
|
||||
svr.Post("/v1/images/generations", [&](const httplib::Request& req, httplib::Response& res) {
|
||||
svr.Post("/v1/images/generations", [runtime](const httplib::Request& req, httplib::Response& res) {
|
||||
try {
|
||||
if (req.body.empty()) {
|
||||
res.status = 400;
|
||||
@ -429,8 +384,8 @@ int main(int argc, const char** argv) {
|
||||
std::string size = j.value("size", "");
|
||||
std::string output_format = j.value("output_format", "png");
|
||||
int output_compression = j.value("output_compression", 100);
|
||||
int width = default_gen_params.width > 0 ? default_gen_params.width : 512;
|
||||
int height = default_gen_params.width > 0 ? default_gen_params.height : 512;
|
||||
int width = runtime->default_gen_params->width > 0 ? runtime->default_gen_params->width : 512;
|
||||
int height = runtime->default_gen_params->width > 0 ? runtime->default_gen_params->height : 512;
|
||||
if (!size.empty()) {
|
||||
auto pos = size.find('x');
|
||||
if (pos != std::string::npos) {
|
||||
@ -458,7 +413,7 @@ int main(int argc, const char** argv) {
|
||||
if (n <= 0)
|
||||
n = 1;
|
||||
if (n > 8)
|
||||
n = 8; // safety
|
||||
n = 8;
|
||||
if (output_compression > 100) {
|
||||
output_compression = 100;
|
||||
}
|
||||
@ -471,7 +426,7 @@ int main(int argc, const char** argv) {
|
||||
out["data"] = json::array();
|
||||
out["output_format"] = output_format;
|
||||
|
||||
SDGenerationParams gen_params = default_gen_params;
|
||||
SDGenerationParams gen_params = *runtime->default_gen_params;
|
||||
gen_params.prompt = prompt;
|
||||
gen_params.width = width;
|
||||
gen_params.height = height;
|
||||
@ -524,7 +479,7 @@ int main(int argc, const char** argv) {
|
||||
(int)pmid_images.size(),
|
||||
gen_params.pm_id_embed_path.c_str(),
|
||||
gen_params.pm_style_strength,
|
||||
}, // pm_params
|
||||
},
|
||||
gen_params.vae_tiling_params,
|
||||
gen_params.cache_params,
|
||||
};
|
||||
@ -533,8 +488,8 @@ int main(int argc, const char** argv) {
|
||||
int num_results = 0;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
|
||||
results = generate_image(sd_ctx, &img_gen_params);
|
||||
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
|
||||
results = generate_image(runtime->sd_ctx, &img_gen_params);
|
||||
num_results = gen_params.batch_count;
|
||||
}
|
||||
|
||||
@ -553,7 +508,6 @@ int main(int argc, const char** argv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// base64 encode
|
||||
std::string b64 = base64_encode(image_bytes);
|
||||
json item;
|
||||
item["b64_json"] = b64;
|
||||
@ -573,7 +527,7 @@ int main(int argc, const char** argv) {
|
||||
}
|
||||
});
|
||||
|
||||
svr.Post("/v1/images/edits", [&](const httplib::Request& req, httplib::Response& res) {
|
||||
svr.Post("/v1/images/edits", [runtime](const httplib::Request& req, httplib::Response& res) {
|
||||
try {
|
||||
if (!req.is_multipart_form_data()) {
|
||||
res.status = 400;
|
||||
@ -658,7 +612,7 @@ int main(int argc, const char** argv) {
|
||||
output_compression = 0;
|
||||
}
|
||||
|
||||
SDGenerationParams gen_params = default_gen_params;
|
||||
SDGenerationParams gen_params = *runtime->default_gen_params;
|
||||
gen_params.prompt = prompt;
|
||||
gen_params.width = width;
|
||||
gen_params.height = height;
|
||||
@ -685,18 +639,18 @@ int main(int argc, const char** argv) {
|
||||
sd_image_t control_image = {0, 0, 3, nullptr};
|
||||
std::vector<sd_image_t> pmid_images;
|
||||
|
||||
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
|
||||
auto get_resolved_width = [&gen_params, runtime]() -> int {
|
||||
if (gen_params.width > 0)
|
||||
return gen_params.width;
|
||||
if (default_gen_params.width > 0)
|
||||
return default_gen_params.width;
|
||||
if (runtime->default_gen_params->width > 0)
|
||||
return runtime->default_gen_params->width;
|
||||
return 512;
|
||||
};
|
||||
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
|
||||
auto get_resolved_height = [&gen_params, runtime]() -> int {
|
||||
if (gen_params.height > 0)
|
||||
return gen_params.height;
|
||||
if (default_gen_params.height > 0)
|
||||
return default_gen_params.height;
|
||||
if (runtime->default_gen_params->height > 0)
|
||||
return runtime->default_gen_params->height;
|
||||
return 512;
|
||||
};
|
||||
|
||||
@ -771,7 +725,7 @@ int main(int argc, const char** argv) {
|
||||
(int)pmid_images.size(),
|
||||
gen_params.pm_id_embed_path.c_str(),
|
||||
gen_params.pm_style_strength,
|
||||
}, // pm_params
|
||||
},
|
||||
gen_params.vae_tiling_params,
|
||||
gen_params.cache_params,
|
||||
};
|
||||
@ -780,8 +734,8 @@ int main(int argc, const char** argv) {
|
||||
int num_results = 0;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
|
||||
results = generate_image(sd_ctx, &img_gen_params);
|
||||
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
|
||||
results = generate_image(runtime->sd_ctx, &img_gen_params);
|
||||
num_results = gen_params.batch_count;
|
||||
}
|
||||
|
||||
@ -826,10 +780,12 @@ int main(int argc, const char** argv) {
|
||||
res.set_content(err.dump(), "application/json");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// sdapi endpoints (AUTOMATIC1111 / Forge)
|
||||
void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
|
||||
ServerRuntime* runtime = &rt;
|
||||
|
||||
auto sdapi_any2img = [&](const httplib::Request& req, httplib::Response& res, bool img2img) {
|
||||
auto sdapi_any2img = [runtime](const httplib::Request& req, httplib::Response& res, bool img2img) {
|
||||
try {
|
||||
if (req.body.empty()) {
|
||||
res.status = 400;
|
||||
@ -843,8 +799,8 @@ int main(int argc, const char** argv) {
|
||||
std::string negative_prompt = j.value("negative_prompt", "");
|
||||
int width = j.value("width", 512);
|
||||
int height = j.value("height", 512);
|
||||
int steps = j.value("steps", default_gen_params.sample_params.sample_steps);
|
||||
float cfg_scale = j.value("cfg_scale", default_gen_params.sample_params.guidance.txt_cfg);
|
||||
int steps = j.value("steps", runtime->default_gen_params->sample_params.sample_steps);
|
||||
float cfg_scale = j.value("cfg_scale", runtime->default_gen_params->sample_params.guidance.txt_cfg);
|
||||
int64_t seed = j.value("seed", -1);
|
||||
int batch_size = j.value("batch_size", 1);
|
||||
int clip_skip = j.value("clip_skip", -1);
|
||||
@ -894,7 +850,7 @@ int main(int argc, const char** argv) {
|
||||
return bad("lora.path required");
|
||||
}
|
||||
|
||||
std::string fullpath = get_lora_full_path(path);
|
||||
std::string fullpath = get_lora_full_path(*runtime, path);
|
||||
if (fullpath.empty()) {
|
||||
return bad("invalid lora path: " + path);
|
||||
}
|
||||
@ -912,7 +868,6 @@ int main(int argc, const char** argv) {
|
||||
auto get_sample_method = [](std::string name) -> enum sample_method_t {
|
||||
enum sample_method_t result = str_to_sample_method(name.c_str());
|
||||
if (result != SAMPLE_METHOD_COUNT) return result;
|
||||
// some applications use a hardcoded sampler list
|
||||
std::transform(name.begin(), name.end(), name.begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
static const std::unordered_map<std::string_view, sample_method_t> hardcoded{
|
||||
@ -938,10 +893,9 @@ int main(int argc, const char** argv) {
|
||||
};
|
||||
|
||||
enum sample_method_t sample_method = get_sample_method(sampler_name);
|
||||
|
||||
enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
|
||||
|
||||
SDGenerationParams gen_params = default_gen_params;
|
||||
SDGenerationParams gen_params = *runtime->default_gen_params;
|
||||
gen_params.prompt = prompt;
|
||||
gen_params.negative_prompt = negative_prompt;
|
||||
gen_params.seed = seed;
|
||||
@ -961,8 +915,6 @@ int main(int argc, const char** argv) {
|
||||
gen_params.sample_params.scheduler = scheduler;
|
||||
}
|
||||
|
||||
// re-read to avoid applying 512 as default before the provided
|
||||
// images and/or server command-line
|
||||
gen_params.width = j.value("width", -1);
|
||||
gen_params.height = j.value("height", -1);
|
||||
|
||||
@ -975,23 +927,22 @@ int main(int argc, const char** argv) {
|
||||
std::vector<sd_image_t> pmid_images;
|
||||
std::vector<sd_image_t> ref_images;
|
||||
|
||||
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
|
||||
auto get_resolved_width = [&gen_params, runtime]() -> int {
|
||||
if (gen_params.width > 0)
|
||||
return gen_params.width;
|
||||
if (default_gen_params.width > 0)
|
||||
return default_gen_params.width;
|
||||
if (runtime->default_gen_params->width > 0)
|
||||
return runtime->default_gen_params->width;
|
||||
return 512;
|
||||
};
|
||||
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
|
||||
auto get_resolved_height = [&gen_params, runtime]() -> int {
|
||||
if (gen_params.height > 0)
|
||||
return gen_params.height;
|
||||
if (default_gen_params.height > 0)
|
||||
return default_gen_params.height;
|
||||
if (runtime->default_gen_params->height > 0)
|
||||
return runtime->default_gen_params->height;
|
||||
return 512;
|
||||
};
|
||||
|
||||
auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool {
|
||||
// remove data URI prefix if present ("data:image/png;base64,")
|
||||
auto comma_pos = encoded.find(',');
|
||||
if (comma_pos != std::string::npos) {
|
||||
encoded = encoded.substr(comma_pos + 1);
|
||||
@ -1087,7 +1038,7 @@ int main(int argc, const char** argv) {
|
||||
(int)pmid_images.size(),
|
||||
gen_params.pm_id_embed_path.c_str(),
|
||||
gen_params.pm_style_strength,
|
||||
}, // pm_params
|
||||
},
|
||||
gen_params.vae_tiling_params,
|
||||
gen_params.cache_params,
|
||||
};
|
||||
@ -1096,14 +1047,14 @@ int main(int argc, const char** argv) {
|
||||
int num_results = 0;
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
|
||||
results = generate_image(sd_ctx, &img_gen_params);
|
||||
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
|
||||
results = generate_image(runtime->sd_ctx, &img_gen_params);
|
||||
num_results = gen_params.batch_count;
|
||||
}
|
||||
|
||||
json out;
|
||||
out["images"] = json::array();
|
||||
out["parameters"] = j; // TODO should return changed defaults
|
||||
out["parameters"] = j;
|
||||
out["info"] = "";
|
||||
|
||||
for (int i = 0; i < num_results; i++) {
|
||||
@ -1149,21 +1100,21 @@ int main(int argc, const char** argv) {
|
||||
}
|
||||
};
|
||||
|
||||
svr.Post("/sdapi/v1/txt2img", [&](const httplib::Request& req, httplib::Response& res) {
|
||||
svr.Post("/sdapi/v1/txt2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
|
||||
sdapi_any2img(req, res, false);
|
||||
});
|
||||
|
||||
svr.Post("/sdapi/v1/img2img", [&](const httplib::Request& req, httplib::Response& res) {
|
||||
svr.Post("/sdapi/v1/img2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
|
||||
sdapi_any2img(req, res, true);
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/loras", [&](const httplib::Request&, httplib::Response& res) {
|
||||
refresh_lora_cache();
|
||||
svr.Get("/sdapi/v1/loras", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
refresh_lora_cache(*runtime);
|
||||
|
||||
json result = json::array();
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(lora_mutex);
|
||||
for (const auto& e : lora_cache) {
|
||||
std::lock_guard<std::mutex> lock(*runtime->lora_mutex);
|
||||
for (const auto& e : *runtime->lora_cache) {
|
||||
json item;
|
||||
item["name"] = e.name;
|
||||
item["path"] = e.path;
|
||||
@ -1174,7 +1125,7 @@ int main(int argc, const char** argv) {
|
||||
res.set_content(result.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
|
||||
svr.Get("/sdapi/v1/samplers", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
std::vector<std::string> sampler_names;
|
||||
sampler_names.push_back("default");
|
||||
for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) {
|
||||
@ -1191,7 +1142,7 @@ int main(int argc, const char** argv) {
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/schedulers", [&](const httplib::Request&, httplib::Response& res) {
|
||||
svr.Get("/sdapi/v1/schedulers", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
std::vector<std::string> scheduler_names;
|
||||
scheduler_names.push_back("default");
|
||||
for (int i = 0; i < SCHEDULER_COUNT; i++) {
|
||||
@ -1207,8 +1158,8 @@ int main(int argc, const char** argv) {
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/sd-models", [&](const httplib::Request&, httplib::Response& res) {
|
||||
fs::path model_path = ctx_params.model_path;
|
||||
svr.Get("/sdapi/v1/sd-models", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
fs::path model_path = runtime->ctx_params->model_path;
|
||||
json entry;
|
||||
entry["title"] = model_path.stem();
|
||||
entry["model_name"] = model_path.stem();
|
||||
@ -1221,18 +1172,89 @@ int main(int argc, const char** argv) {
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
|
||||
svr.Get("/sdapi/v1/options", [&](const httplib::Request&, httplib::Response& res) {
|
||||
fs::path model_path = ctx_params.model_path;
|
||||
svr.Get("/sdapi/v1/options", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||
fs::path model_path = runtime->ctx_params->model_path;
|
||||
json r;
|
||||
r["samples_format"] = "png";
|
||||
r["sd_model_checkpoint"] = model_path.stem();
|
||||
res.set_content(r.dump(), "application/json");
|
||||
});
|
||||
}
|
||||
|
||||
int main(int argc, const char** argv) {
|
||||
if (argc > 1 && std::string(argv[1]) == "--version") {
|
||||
std::cout << version_string() << "\n";
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
SDSvrParams svr_params;
|
||||
SDContextParams ctx_params;
|
||||
SDGenerationParams default_gen_params;
|
||||
parse_args(argc, argv, svr_params, ctx_params, default_gen_params);
|
||||
|
||||
sd_set_log_callback(sd_log_cb, (void*)&svr_params);
|
||||
log_verbose = svr_params.verbose;
|
||||
log_color = svr_params.color;
|
||||
|
||||
LOG_DEBUG("version: %s", version_string().c_str());
|
||||
LOG_DEBUG("%s", sd_get_system_info());
|
||||
LOG_DEBUG("%s", svr_params.to_string().c_str());
|
||||
LOG_DEBUG("%s", ctx_params.to_string().c_str());
|
||||
LOG_DEBUG("%s", default_gen_params.to_string().c_str());
|
||||
|
||||
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false);
|
||||
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
|
||||
|
||||
if (sd_ctx == nullptr) {
|
||||
LOG_ERROR("new_sd_ctx_t failed");
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::mutex sd_ctx_mutex;
|
||||
|
||||
std::vector<LoraEntry> lora_cache;
|
||||
std::mutex lora_mutex;
|
||||
ServerRuntime runtime = {
|
||||
sd_ctx,
|
||||
&sd_ctx_mutex,
|
||||
&svr_params,
|
||||
&ctx_params,
|
||||
&default_gen_params,
|
||||
&lora_cache,
|
||||
&lora_mutex,
|
||||
};
|
||||
|
||||
httplib::Server svr;
|
||||
|
||||
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
|
||||
std::string origin = req.get_header_value("Origin");
|
||||
if (origin.empty()) {
|
||||
origin = "*";
|
||||
}
|
||||
res.set_header("Access-Control-Allow-Origin", origin);
|
||||
res.set_header("Access-Control-Allow-Credentials", "true");
|
||||
res.set_header("Access-Control-Allow-Methods", "*");
|
||||
res.set_header("Access-Control-Allow-Headers", "*");
|
||||
|
||||
if (req.method == "OPTIONS") {
|
||||
res.status = 204;
|
||||
return httplib::Server::HandlerResponse::Handled;
|
||||
}
|
||||
return httplib::Server::HandlerResponse::Unhandled;
|
||||
});
|
||||
|
||||
std::string index_html;
|
||||
#ifdef HAVE_INDEX_HTML
|
||||
index_html.assign(reinterpret_cast<const char*>(index_html_bytes), index_html_size);
|
||||
#else
|
||||
index_html = "Stable Diffusion Server is running";
|
||||
#endif
|
||||
register_index_endpoints(svr, svr_params, index_html);
|
||||
register_openai_api_endpoints(svr, runtime);
|
||||
register_sdapi_endpoints(svr, runtime);
|
||||
|
||||
LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
|
||||
svr.listen(svr_params.listen_ip, svr_params.listen_port);
|
||||
|
||||
// cleanup
|
||||
free_sd_ctx(sd_ctx);
|
||||
return 0;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user