mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-31 13:48:54 +00:00
refactor(server): split server endpoint registration (#1376)
This commit is contained in:
parent
8d878872d9
commit
83e8f6f0af
@ -273,55 +273,20 @@ struct LoraEntry {
|
|||||||
std::string fullpath;
|
std::string fullpath;
|
||||||
};
|
};
|
||||||
|
|
||||||
void free_results(sd_image_t* result_images, int num_results) {
|
struct ServerRuntime {
|
||||||
if (result_images) {
|
sd_ctx_t* sd_ctx;
|
||||||
for (int i = 0; i < num_results; ++i) {
|
std::mutex* sd_ctx_mutex;
|
||||||
if (result_images[i].data) {
|
const SDSvrParams* svr_params;
|
||||||
stbi_image_free(result_images[i].data);
|
const SDContextParams* ctx_params;
|
||||||
result_images[i].data = nullptr;
|
const SDGenerationParams* default_gen_params;
|
||||||
}
|
std::vector<LoraEntry>* lora_cache;
|
||||||
}
|
std::mutex* lora_mutex;
|
||||||
}
|
};
|
||||||
free(result_images);
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(int argc, const char** argv) {
|
void refresh_lora_cache(ServerRuntime& rt) {
|
||||||
if (argc > 1 && std::string(argv[1]) == "--version") {
|
|
||||||
std::cout << version_string() << "\n";
|
|
||||||
return EXIT_SUCCESS;
|
|
||||||
}
|
|
||||||
SDSvrParams svr_params;
|
|
||||||
SDContextParams ctx_params;
|
|
||||||
SDGenerationParams default_gen_params;
|
|
||||||
parse_args(argc, argv, svr_params, ctx_params, default_gen_params);
|
|
||||||
|
|
||||||
sd_set_log_callback(sd_log_cb, (void*)&svr_params);
|
|
||||||
log_verbose = svr_params.verbose;
|
|
||||||
log_color = svr_params.color;
|
|
||||||
|
|
||||||
LOG_DEBUG("version: %s", version_string().c_str());
|
|
||||||
LOG_DEBUG("%s", sd_get_system_info());
|
|
||||||
LOG_DEBUG("%s", svr_params.to_string().c_str());
|
|
||||||
LOG_DEBUG("%s", ctx_params.to_string().c_str());
|
|
||||||
LOG_DEBUG("%s", default_gen_params.to_string().c_str());
|
|
||||||
|
|
||||||
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false);
|
|
||||||
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
|
|
||||||
|
|
||||||
if (sd_ctx == nullptr) {
|
|
||||||
LOG_ERROR("new_sd_ctx_t failed");
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::mutex sd_ctx_mutex;
|
|
||||||
|
|
||||||
std::vector<LoraEntry> lora_cache;
|
|
||||||
std::mutex lora_mutex;
|
|
||||||
|
|
||||||
auto refresh_lora_cache = [&]() {
|
|
||||||
std::vector<LoraEntry> new_cache;
|
std::vector<LoraEntry> new_cache;
|
||||||
|
|
||||||
fs::path lora_dir = ctx_params.lora_model_dir;
|
fs::path lora_dir = rt.ctx_params->lora_model_dir;
|
||||||
if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) {
|
if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) {
|
||||||
auto is_lora_ext = [](const fs::path& p) {
|
auto is_lora_ext = [](const fs::path& p) {
|
||||||
auto ext = p.extension().string();
|
auto ext = p.extension().string();
|
||||||
@ -353,47 +318,35 @@ int main(int argc, const char** argv) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(lora_mutex);
|
std::lock_guard<std::mutex> lock(*rt.lora_mutex);
|
||||||
lora_cache = std::move(new_cache);
|
*rt.lora_cache = std::move(new_cache);
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
auto get_lora_full_path = [&](const std::string& path) -> std::string {
|
std::string get_lora_full_path(ServerRuntime& rt, const std::string& path) {
|
||||||
std::lock_guard<std::mutex> lock(lora_mutex);
|
std::lock_guard<std::mutex> lock(*rt.lora_mutex);
|
||||||
auto it = std::find_if(lora_cache.begin(), lora_cache.end(),
|
auto it = std::find_if(rt.lora_cache->begin(), rt.lora_cache->end(),
|
||||||
[&](const LoraEntry& e) { return e.path == path; });
|
[&](const LoraEntry& e) { return e.path == path; });
|
||||||
return (it != lora_cache.end()) ? it->fullpath : "";
|
return (it != rt.lora_cache->end()) ? it->fullpath : "";
|
||||||
};
|
}
|
||||||
|
|
||||||
httplib::Server svr;
|
void free_results(sd_image_t* result_images, int num_results) {
|
||||||
|
if (result_images) {
|
||||||
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
|
for (int i = 0; i < num_results; ++i) {
|
||||||
std::string origin = req.get_header_value("Origin");
|
if (result_images[i].data) {
|
||||||
if (origin.empty()) {
|
stbi_image_free(result_images[i].data);
|
||||||
origin = "*";
|
result_images[i].data = nullptr;
|
||||||
}
|
}
|
||||||
res.set_header("Access-Control-Allow-Origin", origin);
|
|
||||||
res.set_header("Access-Control-Allow-Credentials", "true");
|
|
||||||
res.set_header("Access-Control-Allow-Methods", "*");
|
|
||||||
res.set_header("Access-Control-Allow-Headers", "*");
|
|
||||||
|
|
||||||
if (req.method == "OPTIONS") {
|
|
||||||
res.status = 204;
|
|
||||||
return httplib::Server::HandlerResponse::Handled;
|
|
||||||
}
|
}
|
||||||
return httplib::Server::HandlerResponse::Unhandled;
|
}
|
||||||
});
|
free(result_images);
|
||||||
|
}
|
||||||
|
|
||||||
// index html
|
void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) {
|
||||||
std::string index_html;
|
const std::string serve_html_path = svr_params.serve_html_path;
|
||||||
#ifdef HAVE_INDEX_HTML
|
svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) {
|
||||||
index_html.assign(reinterpret_cast<const char*>(index_html_bytes), index_html_size);
|
if (!serve_html_path.empty()) {
|
||||||
#else
|
std::ifstream file(serve_html_path);
|
||||||
index_html = "Stable Diffusion Server is running";
|
|
||||||
#endif
|
|
||||||
svr.Get("/", [&](const httplib::Request&, httplib::Response& res) {
|
|
||||||
if (!svr_params.serve_html_path.empty()) {
|
|
||||||
std::ifstream file(svr_params.serve_html_path);
|
|
||||||
if (file) {
|
if (file) {
|
||||||
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
||||||
res.set_content(content, "text/html");
|
res.set_content(content, "text/html");
|
||||||
@ -405,17 +358,19 @@ int main(int argc, const char** argv) {
|
|||||||
res.set_content(index_html, "text/html");
|
res.set_content(index_html, "text/html");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// models endpoint (minimal)
|
void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
|
||||||
svr.Get("/v1/models", [&](const httplib::Request&, httplib::Response& res) {
|
ServerRuntime* runtime = &rt;
|
||||||
|
|
||||||
|
svr.Get("/v1/models", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||||
json r;
|
json r;
|
||||||
r["data"] = json::array();
|
r["data"] = json::array();
|
||||||
r["data"].push_back({{"id", "sd-cpp-local"}, {"object", "model"}, {"owned_by", "local"}});
|
r["data"].push_back({{"id", "sd-cpp-local"}, {"object", "model"}, {"owned_by", "local"}});
|
||||||
res.set_content(r.dump(), "application/json");
|
res.set_content(r.dump(), "application/json");
|
||||||
});
|
});
|
||||||
|
|
||||||
// core endpoint: /v1/images/generations
|
svr.Post("/v1/images/generations", [runtime](const httplib::Request& req, httplib::Response& res) {
|
||||||
svr.Post("/v1/images/generations", [&](const httplib::Request& req, httplib::Response& res) {
|
|
||||||
try {
|
try {
|
||||||
if (req.body.empty()) {
|
if (req.body.empty()) {
|
||||||
res.status = 400;
|
res.status = 400;
|
||||||
@ -429,8 +384,8 @@ int main(int argc, const char** argv) {
|
|||||||
std::string size = j.value("size", "");
|
std::string size = j.value("size", "");
|
||||||
std::string output_format = j.value("output_format", "png");
|
std::string output_format = j.value("output_format", "png");
|
||||||
int output_compression = j.value("output_compression", 100);
|
int output_compression = j.value("output_compression", 100);
|
||||||
int width = default_gen_params.width > 0 ? default_gen_params.width : 512;
|
int width = runtime->default_gen_params->width > 0 ? runtime->default_gen_params->width : 512;
|
||||||
int height = default_gen_params.width > 0 ? default_gen_params.height : 512;
|
int height = runtime->default_gen_params->width > 0 ? runtime->default_gen_params->height : 512;
|
||||||
if (!size.empty()) {
|
if (!size.empty()) {
|
||||||
auto pos = size.find('x');
|
auto pos = size.find('x');
|
||||||
if (pos != std::string::npos) {
|
if (pos != std::string::npos) {
|
||||||
@ -458,7 +413,7 @@ int main(int argc, const char** argv) {
|
|||||||
if (n <= 0)
|
if (n <= 0)
|
||||||
n = 1;
|
n = 1;
|
||||||
if (n > 8)
|
if (n > 8)
|
||||||
n = 8; // safety
|
n = 8;
|
||||||
if (output_compression > 100) {
|
if (output_compression > 100) {
|
||||||
output_compression = 100;
|
output_compression = 100;
|
||||||
}
|
}
|
||||||
@ -471,7 +426,7 @@ int main(int argc, const char** argv) {
|
|||||||
out["data"] = json::array();
|
out["data"] = json::array();
|
||||||
out["output_format"] = output_format;
|
out["output_format"] = output_format;
|
||||||
|
|
||||||
SDGenerationParams gen_params = default_gen_params;
|
SDGenerationParams gen_params = *runtime->default_gen_params;
|
||||||
gen_params.prompt = prompt;
|
gen_params.prompt = prompt;
|
||||||
gen_params.width = width;
|
gen_params.width = width;
|
||||||
gen_params.height = height;
|
gen_params.height = height;
|
||||||
@ -524,7 +479,7 @@ int main(int argc, const char** argv) {
|
|||||||
(int)pmid_images.size(),
|
(int)pmid_images.size(),
|
||||||
gen_params.pm_id_embed_path.c_str(),
|
gen_params.pm_id_embed_path.c_str(),
|
||||||
gen_params.pm_style_strength,
|
gen_params.pm_style_strength,
|
||||||
}, // pm_params
|
},
|
||||||
gen_params.vae_tiling_params,
|
gen_params.vae_tiling_params,
|
||||||
gen_params.cache_params,
|
gen_params.cache_params,
|
||||||
};
|
};
|
||||||
@ -533,8 +488,8 @@ int main(int argc, const char** argv) {
|
|||||||
int num_results = 0;
|
int num_results = 0;
|
||||||
|
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
|
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
|
||||||
results = generate_image(sd_ctx, &img_gen_params);
|
results = generate_image(runtime->sd_ctx, &img_gen_params);
|
||||||
num_results = gen_params.batch_count;
|
num_results = gen_params.batch_count;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -553,7 +508,6 @@ int main(int argc, const char** argv) {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// base64 encode
|
|
||||||
std::string b64 = base64_encode(image_bytes);
|
std::string b64 = base64_encode(image_bytes);
|
||||||
json item;
|
json item;
|
||||||
item["b64_json"] = b64;
|
item["b64_json"] = b64;
|
||||||
@ -573,7 +527,7 @@ int main(int argc, const char** argv) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Post("/v1/images/edits", [&](const httplib::Request& req, httplib::Response& res) {
|
svr.Post("/v1/images/edits", [runtime](const httplib::Request& req, httplib::Response& res) {
|
||||||
try {
|
try {
|
||||||
if (!req.is_multipart_form_data()) {
|
if (!req.is_multipart_form_data()) {
|
||||||
res.status = 400;
|
res.status = 400;
|
||||||
@ -658,7 +612,7 @@ int main(int argc, const char** argv) {
|
|||||||
output_compression = 0;
|
output_compression = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
SDGenerationParams gen_params = default_gen_params;
|
SDGenerationParams gen_params = *runtime->default_gen_params;
|
||||||
gen_params.prompt = prompt;
|
gen_params.prompt = prompt;
|
||||||
gen_params.width = width;
|
gen_params.width = width;
|
||||||
gen_params.height = height;
|
gen_params.height = height;
|
||||||
@ -685,18 +639,18 @@ int main(int argc, const char** argv) {
|
|||||||
sd_image_t control_image = {0, 0, 3, nullptr};
|
sd_image_t control_image = {0, 0, 3, nullptr};
|
||||||
std::vector<sd_image_t> pmid_images;
|
std::vector<sd_image_t> pmid_images;
|
||||||
|
|
||||||
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
|
auto get_resolved_width = [&gen_params, runtime]() -> int {
|
||||||
if (gen_params.width > 0)
|
if (gen_params.width > 0)
|
||||||
return gen_params.width;
|
return gen_params.width;
|
||||||
if (default_gen_params.width > 0)
|
if (runtime->default_gen_params->width > 0)
|
||||||
return default_gen_params.width;
|
return runtime->default_gen_params->width;
|
||||||
return 512;
|
return 512;
|
||||||
};
|
};
|
||||||
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
|
auto get_resolved_height = [&gen_params, runtime]() -> int {
|
||||||
if (gen_params.height > 0)
|
if (gen_params.height > 0)
|
||||||
return gen_params.height;
|
return gen_params.height;
|
||||||
if (default_gen_params.height > 0)
|
if (runtime->default_gen_params->height > 0)
|
||||||
return default_gen_params.height;
|
return runtime->default_gen_params->height;
|
||||||
return 512;
|
return 512;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -771,7 +725,7 @@ int main(int argc, const char** argv) {
|
|||||||
(int)pmid_images.size(),
|
(int)pmid_images.size(),
|
||||||
gen_params.pm_id_embed_path.c_str(),
|
gen_params.pm_id_embed_path.c_str(),
|
||||||
gen_params.pm_style_strength,
|
gen_params.pm_style_strength,
|
||||||
}, // pm_params
|
},
|
||||||
gen_params.vae_tiling_params,
|
gen_params.vae_tiling_params,
|
||||||
gen_params.cache_params,
|
gen_params.cache_params,
|
||||||
};
|
};
|
||||||
@ -780,8 +734,8 @@ int main(int argc, const char** argv) {
|
|||||||
int num_results = 0;
|
int num_results = 0;
|
||||||
|
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
|
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
|
||||||
results = generate_image(sd_ctx, &img_gen_params);
|
results = generate_image(runtime->sd_ctx, &img_gen_params);
|
||||||
num_results = gen_params.batch_count;
|
num_results = gen_params.batch_count;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -826,10 +780,12 @@ int main(int argc, const char** argv) {
|
|||||||
res.set_content(err.dump(), "application/json");
|
res.set_content(err.dump(), "application/json");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
// sdapi endpoints (AUTOMATIC1111 / Forge)
|
void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
|
||||||
|
ServerRuntime* runtime = &rt;
|
||||||
|
|
||||||
auto sdapi_any2img = [&](const httplib::Request& req, httplib::Response& res, bool img2img) {
|
auto sdapi_any2img = [runtime](const httplib::Request& req, httplib::Response& res, bool img2img) {
|
||||||
try {
|
try {
|
||||||
if (req.body.empty()) {
|
if (req.body.empty()) {
|
||||||
res.status = 400;
|
res.status = 400;
|
||||||
@ -843,8 +799,8 @@ int main(int argc, const char** argv) {
|
|||||||
std::string negative_prompt = j.value("negative_prompt", "");
|
std::string negative_prompt = j.value("negative_prompt", "");
|
||||||
int width = j.value("width", 512);
|
int width = j.value("width", 512);
|
||||||
int height = j.value("height", 512);
|
int height = j.value("height", 512);
|
||||||
int steps = j.value("steps", default_gen_params.sample_params.sample_steps);
|
int steps = j.value("steps", runtime->default_gen_params->sample_params.sample_steps);
|
||||||
float cfg_scale = j.value("cfg_scale", default_gen_params.sample_params.guidance.txt_cfg);
|
float cfg_scale = j.value("cfg_scale", runtime->default_gen_params->sample_params.guidance.txt_cfg);
|
||||||
int64_t seed = j.value("seed", -1);
|
int64_t seed = j.value("seed", -1);
|
||||||
int batch_size = j.value("batch_size", 1);
|
int batch_size = j.value("batch_size", 1);
|
||||||
int clip_skip = j.value("clip_skip", -1);
|
int clip_skip = j.value("clip_skip", -1);
|
||||||
@ -894,7 +850,7 @@ int main(int argc, const char** argv) {
|
|||||||
return bad("lora.path required");
|
return bad("lora.path required");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string fullpath = get_lora_full_path(path);
|
std::string fullpath = get_lora_full_path(*runtime, path);
|
||||||
if (fullpath.empty()) {
|
if (fullpath.empty()) {
|
||||||
return bad("invalid lora path: " + path);
|
return bad("invalid lora path: " + path);
|
||||||
}
|
}
|
||||||
@ -912,7 +868,6 @@ int main(int argc, const char** argv) {
|
|||||||
auto get_sample_method = [](std::string name) -> enum sample_method_t {
|
auto get_sample_method = [](std::string name) -> enum sample_method_t {
|
||||||
enum sample_method_t result = str_to_sample_method(name.c_str());
|
enum sample_method_t result = str_to_sample_method(name.c_str());
|
||||||
if (result != SAMPLE_METHOD_COUNT) return result;
|
if (result != SAMPLE_METHOD_COUNT) return result;
|
||||||
// some applications use a hardcoded sampler list
|
|
||||||
std::transform(name.begin(), name.end(), name.begin(),
|
std::transform(name.begin(), name.end(), name.begin(),
|
||||||
[](unsigned char c) { return std::tolower(c); });
|
[](unsigned char c) { return std::tolower(c); });
|
||||||
static const std::unordered_map<std::string_view, sample_method_t> hardcoded{
|
static const std::unordered_map<std::string_view, sample_method_t> hardcoded{
|
||||||
@ -938,10 +893,9 @@ int main(int argc, const char** argv) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
enum sample_method_t sample_method = get_sample_method(sampler_name);
|
enum sample_method_t sample_method = get_sample_method(sampler_name);
|
||||||
|
|
||||||
enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
|
enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
|
||||||
|
|
||||||
SDGenerationParams gen_params = default_gen_params;
|
SDGenerationParams gen_params = *runtime->default_gen_params;
|
||||||
gen_params.prompt = prompt;
|
gen_params.prompt = prompt;
|
||||||
gen_params.negative_prompt = negative_prompt;
|
gen_params.negative_prompt = negative_prompt;
|
||||||
gen_params.seed = seed;
|
gen_params.seed = seed;
|
||||||
@ -961,8 +915,6 @@ int main(int argc, const char** argv) {
|
|||||||
gen_params.sample_params.scheduler = scheduler;
|
gen_params.sample_params.scheduler = scheduler;
|
||||||
}
|
}
|
||||||
|
|
||||||
// re-read to avoid applying 512 as default before the provided
|
|
||||||
// images and/or server command-line
|
|
||||||
gen_params.width = j.value("width", -1);
|
gen_params.width = j.value("width", -1);
|
||||||
gen_params.height = j.value("height", -1);
|
gen_params.height = j.value("height", -1);
|
||||||
|
|
||||||
@ -975,23 +927,22 @@ int main(int argc, const char** argv) {
|
|||||||
std::vector<sd_image_t> pmid_images;
|
std::vector<sd_image_t> pmid_images;
|
||||||
std::vector<sd_image_t> ref_images;
|
std::vector<sd_image_t> ref_images;
|
||||||
|
|
||||||
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
|
auto get_resolved_width = [&gen_params, runtime]() -> int {
|
||||||
if (gen_params.width > 0)
|
if (gen_params.width > 0)
|
||||||
return gen_params.width;
|
return gen_params.width;
|
||||||
if (default_gen_params.width > 0)
|
if (runtime->default_gen_params->width > 0)
|
||||||
return default_gen_params.width;
|
return runtime->default_gen_params->width;
|
||||||
return 512;
|
return 512;
|
||||||
};
|
};
|
||||||
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
|
auto get_resolved_height = [&gen_params, runtime]() -> int {
|
||||||
if (gen_params.height > 0)
|
if (gen_params.height > 0)
|
||||||
return gen_params.height;
|
return gen_params.height;
|
||||||
if (default_gen_params.height > 0)
|
if (runtime->default_gen_params->height > 0)
|
||||||
return default_gen_params.height;
|
return runtime->default_gen_params->height;
|
||||||
return 512;
|
return 512;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool {
|
auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool {
|
||||||
// remove data URI prefix if present ("data:image/png;base64,")
|
|
||||||
auto comma_pos = encoded.find(',');
|
auto comma_pos = encoded.find(',');
|
||||||
if (comma_pos != std::string::npos) {
|
if (comma_pos != std::string::npos) {
|
||||||
encoded = encoded.substr(comma_pos + 1);
|
encoded = encoded.substr(comma_pos + 1);
|
||||||
@ -1087,7 +1038,7 @@ int main(int argc, const char** argv) {
|
|||||||
(int)pmid_images.size(),
|
(int)pmid_images.size(),
|
||||||
gen_params.pm_id_embed_path.c_str(),
|
gen_params.pm_id_embed_path.c_str(),
|
||||||
gen_params.pm_style_strength,
|
gen_params.pm_style_strength,
|
||||||
}, // pm_params
|
},
|
||||||
gen_params.vae_tiling_params,
|
gen_params.vae_tiling_params,
|
||||||
gen_params.cache_params,
|
gen_params.cache_params,
|
||||||
};
|
};
|
||||||
@ -1096,14 +1047,14 @@ int main(int argc, const char** argv) {
|
|||||||
int num_results = 0;
|
int num_results = 0;
|
||||||
|
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
|
std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
|
||||||
results = generate_image(sd_ctx, &img_gen_params);
|
results = generate_image(runtime->sd_ctx, &img_gen_params);
|
||||||
num_results = gen_params.batch_count;
|
num_results = gen_params.batch_count;
|
||||||
}
|
}
|
||||||
|
|
||||||
json out;
|
json out;
|
||||||
out["images"] = json::array();
|
out["images"] = json::array();
|
||||||
out["parameters"] = j; // TODO should return changed defaults
|
out["parameters"] = j;
|
||||||
out["info"] = "";
|
out["info"] = "";
|
||||||
|
|
||||||
for (int i = 0; i < num_results; i++) {
|
for (int i = 0; i < num_results; i++) {
|
||||||
@ -1149,21 +1100,21 @@ int main(int argc, const char** argv) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
svr.Post("/sdapi/v1/txt2img", [&](const httplib::Request& req, httplib::Response& res) {
|
svr.Post("/sdapi/v1/txt2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
|
||||||
sdapi_any2img(req, res, false);
|
sdapi_any2img(req, res, false);
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Post("/sdapi/v1/img2img", [&](const httplib::Request& req, httplib::Response& res) {
|
svr.Post("/sdapi/v1/img2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
|
||||||
sdapi_any2img(req, res, true);
|
sdapi_any2img(req, res, true);
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Get("/sdapi/v1/loras", [&](const httplib::Request&, httplib::Response& res) {
|
svr.Get("/sdapi/v1/loras", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||||
refresh_lora_cache();
|
refresh_lora_cache(*runtime);
|
||||||
|
|
||||||
json result = json::array();
|
json result = json::array();
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> lock(lora_mutex);
|
std::lock_guard<std::mutex> lock(*runtime->lora_mutex);
|
||||||
for (const auto& e : lora_cache) {
|
for (const auto& e : *runtime->lora_cache) {
|
||||||
json item;
|
json item;
|
||||||
item["name"] = e.name;
|
item["name"] = e.name;
|
||||||
item["path"] = e.path;
|
item["path"] = e.path;
|
||||||
@ -1174,7 +1125,7 @@ int main(int argc, const char** argv) {
|
|||||||
res.set_content(result.dump(), "application/json");
|
res.set_content(result.dump(), "application/json");
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
|
svr.Get("/sdapi/v1/samplers", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||||
std::vector<std::string> sampler_names;
|
std::vector<std::string> sampler_names;
|
||||||
sampler_names.push_back("default");
|
sampler_names.push_back("default");
|
||||||
for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) {
|
for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) {
|
||||||
@ -1191,7 +1142,7 @@ int main(int argc, const char** argv) {
|
|||||||
res.set_content(r.dump(), "application/json");
|
res.set_content(r.dump(), "application/json");
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Get("/sdapi/v1/schedulers", [&](const httplib::Request&, httplib::Response& res) {
|
svr.Get("/sdapi/v1/schedulers", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||||
std::vector<std::string> scheduler_names;
|
std::vector<std::string> scheduler_names;
|
||||||
scheduler_names.push_back("default");
|
scheduler_names.push_back("default");
|
||||||
for (int i = 0; i < SCHEDULER_COUNT; i++) {
|
for (int i = 0; i < SCHEDULER_COUNT; i++) {
|
||||||
@ -1207,8 +1158,8 @@ int main(int argc, const char** argv) {
|
|||||||
res.set_content(r.dump(), "application/json");
|
res.set_content(r.dump(), "application/json");
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Get("/sdapi/v1/sd-models", [&](const httplib::Request&, httplib::Response& res) {
|
svr.Get("/sdapi/v1/sd-models", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||||
fs::path model_path = ctx_params.model_path;
|
fs::path model_path = runtime->ctx_params->model_path;
|
||||||
json entry;
|
json entry;
|
||||||
entry["title"] = model_path.stem();
|
entry["title"] = model_path.stem();
|
||||||
entry["model_name"] = model_path.stem();
|
entry["model_name"] = model_path.stem();
|
||||||
@ -1221,18 +1172,89 @@ int main(int argc, const char** argv) {
|
|||||||
res.set_content(r.dump(), "application/json");
|
res.set_content(r.dump(), "application/json");
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Get("/sdapi/v1/options", [&](const httplib::Request&, httplib::Response& res) {
|
svr.Get("/sdapi/v1/options", [runtime](const httplib::Request&, httplib::Response& res) {
|
||||||
fs::path model_path = ctx_params.model_path;
|
fs::path model_path = runtime->ctx_params->model_path;
|
||||||
json r;
|
json r;
|
||||||
r["samples_format"] = "png";
|
r["samples_format"] = "png";
|
||||||
r["sd_model_checkpoint"] = model_path.stem();
|
r["sd_model_checkpoint"] = model_path.stem();
|
||||||
res.set_content(r.dump(), "application/json");
|
res.set_content(r.dump(), "application/json");
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, const char** argv) {
|
||||||
|
if (argc > 1 && std::string(argv[1]) == "--version") {
|
||||||
|
std::cout << version_string() << "\n";
|
||||||
|
return EXIT_SUCCESS;
|
||||||
|
}
|
||||||
|
SDSvrParams svr_params;
|
||||||
|
SDContextParams ctx_params;
|
||||||
|
SDGenerationParams default_gen_params;
|
||||||
|
parse_args(argc, argv, svr_params, ctx_params, default_gen_params);
|
||||||
|
|
||||||
|
sd_set_log_callback(sd_log_cb, (void*)&svr_params);
|
||||||
|
log_verbose = svr_params.verbose;
|
||||||
|
log_color = svr_params.color;
|
||||||
|
|
||||||
|
LOG_DEBUG("version: %s", version_string().c_str());
|
||||||
|
LOG_DEBUG("%s", sd_get_system_info());
|
||||||
|
LOG_DEBUG("%s", svr_params.to_string().c_str());
|
||||||
|
LOG_DEBUG("%s", ctx_params.to_string().c_str());
|
||||||
|
LOG_DEBUG("%s", default_gen_params.to_string().c_str());
|
||||||
|
|
||||||
|
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false);
|
||||||
|
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
|
||||||
|
|
||||||
|
if (sd_ctx == nullptr) {
|
||||||
|
LOG_ERROR("new_sd_ctx_t failed");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::mutex sd_ctx_mutex;
|
||||||
|
|
||||||
|
std::vector<LoraEntry> lora_cache;
|
||||||
|
std::mutex lora_mutex;
|
||||||
|
ServerRuntime runtime = {
|
||||||
|
sd_ctx,
|
||||||
|
&sd_ctx_mutex,
|
||||||
|
&svr_params,
|
||||||
|
&ctx_params,
|
||||||
|
&default_gen_params,
|
||||||
|
&lora_cache,
|
||||||
|
&lora_mutex,
|
||||||
|
};
|
||||||
|
|
||||||
|
httplib::Server svr;
|
||||||
|
|
||||||
|
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
|
||||||
|
std::string origin = req.get_header_value("Origin");
|
||||||
|
if (origin.empty()) {
|
||||||
|
origin = "*";
|
||||||
|
}
|
||||||
|
res.set_header("Access-Control-Allow-Origin", origin);
|
||||||
|
res.set_header("Access-Control-Allow-Credentials", "true");
|
||||||
|
res.set_header("Access-Control-Allow-Methods", "*");
|
||||||
|
res.set_header("Access-Control-Allow-Headers", "*");
|
||||||
|
|
||||||
|
if (req.method == "OPTIONS") {
|
||||||
|
res.status = 204;
|
||||||
|
return httplib::Server::HandlerResponse::Handled;
|
||||||
|
}
|
||||||
|
return httplib::Server::HandlerResponse::Unhandled;
|
||||||
|
});
|
||||||
|
|
||||||
|
std::string index_html;
|
||||||
|
#ifdef HAVE_INDEX_HTML
|
||||||
|
index_html.assign(reinterpret_cast<const char*>(index_html_bytes), index_html_size);
|
||||||
|
#else
|
||||||
|
index_html = "Stable Diffusion Server is running";
|
||||||
|
#endif
|
||||||
|
register_index_endpoints(svr, svr_params, index_html);
|
||||||
|
register_openai_api_endpoints(svr, runtime);
|
||||||
|
register_sdapi_endpoints(svr, runtime);
|
||||||
|
|
||||||
LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
|
LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
|
||||||
svr.listen(svr_params.listen_ip, svr_params.listen_port);
|
svr.listen(svr_params.listen_ip, svr_params.listen_port);
|
||||||
|
|
||||||
// cleanup
|
|
||||||
free_sd_ctx(sd_ctx);
|
free_sd_ctx(sd_ctx);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user