mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 10:18:51 +00:00
feat(server): add lora support to sdapi (#1256)
This commit is contained in:
parent
9f56833e14
commit
f0f641a142
@ -263,6 +263,11 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
|
|||||||
log_print(level, log, svr_params->verbose, svr_params->color);
|
log_print(level, log, svr_params->verbose, svr_params->color);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct LoraEntry {
|
||||||
|
std::string name;
|
||||||
|
std::string path;
|
||||||
|
};
|
||||||
|
|
||||||
int main(int argc, const char** argv) {
|
int main(int argc, const char** argv) {
|
||||||
if (argc > 1 && std::string(argv[1]) == "--version") {
|
if (argc > 1 && std::string(argv[1]) == "--version") {
|
||||||
std::cout << version_string() << "\n";
|
std::cout << version_string() << "\n";
|
||||||
@ -293,6 +298,54 @@ int main(int argc, const char** argv) {
|
|||||||
|
|
||||||
std::mutex sd_ctx_mutex;
|
std::mutex sd_ctx_mutex;
|
||||||
|
|
||||||
|
std::vector<LoraEntry> lora_cache;
|
||||||
|
std::mutex lora_mutex;
|
||||||
|
|
||||||
|
auto refresh_lora_cache = [&]() {
|
||||||
|
std::vector<LoraEntry> new_cache;
|
||||||
|
|
||||||
|
fs::path lora_dir = ctx_params.lora_model_dir;
|
||||||
|
if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) {
|
||||||
|
auto is_lora_ext = [](const fs::path& p) {
|
||||||
|
auto ext = p.extension().string();
|
||||||
|
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
|
||||||
|
return ext == ".gguf" || ext == ".pt" || ext == ".pth" || ext == ".safetensors";
|
||||||
|
};
|
||||||
|
|
||||||
|
for (auto& entry : fs::recursive_directory_iterator(lora_dir)) {
|
||||||
|
if (!entry.is_regular_file())
|
||||||
|
continue;
|
||||||
|
const fs::path& p = entry.path();
|
||||||
|
if (!is_lora_ext(p))
|
||||||
|
continue;
|
||||||
|
|
||||||
|
LoraEntry e;
|
||||||
|
e.name = p.stem().u8string();
|
||||||
|
std::string rel = fs::relative(p, lora_dir).u8string();
|
||||||
|
std::replace(rel.begin(), rel.end(), '\\', '/');
|
||||||
|
e.path = rel;
|
||||||
|
|
||||||
|
new_cache.push_back(std::move(e));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(new_cache.begin(), new_cache.end(),
|
||||||
|
[](const LoraEntry& a, const LoraEntry& b) {
|
||||||
|
return a.path < b.path;
|
||||||
|
});
|
||||||
|
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(lora_mutex);
|
||||||
|
lora_cache = std::move(new_cache);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto is_valid_lora_path = [&](const std::string& path) -> bool {
|
||||||
|
std::lock_guard<std::mutex> lock(lora_mutex);
|
||||||
|
return std::any_of(lora_cache.begin(), lora_cache.end(),
|
||||||
|
[&](const LoraEntry& e) { return e.path == path; });
|
||||||
|
};
|
||||||
|
|
||||||
httplib::Server svr;
|
httplib::Server svr;
|
||||||
|
|
||||||
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
|
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
|
||||||
@ -312,7 +365,7 @@ int main(int argc, const char** argv) {
|
|||||||
return httplib::Server::HandlerResponse::Unhandled;
|
return httplib::Server::HandlerResponse::Unhandled;
|
||||||
});
|
});
|
||||||
|
|
||||||
// health
|
// root
|
||||||
svr.Get("/", [&](const httplib::Request&, httplib::Response& res) {
|
svr.Get("/", [&](const httplib::Request&, httplib::Response& res) {
|
||||||
if (!svr_params.serve_html_path.empty()) {
|
if (!svr_params.serve_html_path.empty()) {
|
||||||
std::ifstream file(svr_params.serve_html_path);
|
std::ifstream file(svr_params.serve_html_path);
|
||||||
@ -767,6 +820,37 @@ int main(int argc, const char** argv) {
|
|||||||
return bad("prompt required");
|
return bad("prompt required");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<sd_lora_t> sd_loras;
|
||||||
|
std::vector<std::string> lora_path_storage;
|
||||||
|
|
||||||
|
if (j.contains("lora") && j["lora"].is_array()) {
|
||||||
|
for (const auto& item : j["lora"]) {
|
||||||
|
if (!item.is_object()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string path = item.value("path", "");
|
||||||
|
float multiplier = item.value("multiplier", 1.0f);
|
||||||
|
bool is_high_noise = item.value("is_high_noise", false);
|
||||||
|
|
||||||
|
if (path.empty()) {
|
||||||
|
return bad("lora.path required");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_valid_lora_path(path)) {
|
||||||
|
return bad("invalid lora path: " + path);
|
||||||
|
}
|
||||||
|
|
||||||
|
lora_path_storage.push_back(path);
|
||||||
|
sd_lora_t l;
|
||||||
|
l.is_high_noise = is_high_noise;
|
||||||
|
l.multiplier = multiplier;
|
||||||
|
l.path = lora_path_storage.back().c_str();
|
||||||
|
|
||||||
|
sd_loras.push_back(l);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
auto get_sample_method = [](std::string name) -> enum sample_method_t {
|
auto get_sample_method = [](std::string name) -> enum sample_method_t {
|
||||||
enum sample_method_t result = str_to_sample_method(name.c_str());
|
enum sample_method_t result = str_to_sample_method(name.c_str());
|
||||||
if (result != SAMPLE_METHOD_COUNT) return result;
|
if (result != SAMPLE_METHOD_COUNT) return result;
|
||||||
@ -894,8 +978,8 @@ int main(int argc, const char** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sd_img_gen_params_t img_gen_params = {
|
sd_img_gen_params_t img_gen_params = {
|
||||||
gen_params.lora_vec.data(),
|
sd_loras.data(),
|
||||||
static_cast<uint32_t>(gen_params.lora_vec.size()),
|
static_cast<uint32_t>(sd_loras.size()),
|
||||||
gen_params.prompt.c_str(),
|
gen_params.prompt.c_str(),
|
||||||
gen_params.negative_prompt.c_str(),
|
gen_params.negative_prompt.c_str(),
|
||||||
gen_params.clip_skip,
|
gen_params.clip_skip,
|
||||||
@ -987,6 +1071,23 @@ int main(int argc, const char** argv) {
|
|||||||
sdapi_any2img(req, res, true);
|
sdapi_any2img(req, res, true);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
svr.Get("/sdapi/v1/loras", [&](const httplib::Request&, httplib::Response& res) {
|
||||||
|
refresh_lora_cache();
|
||||||
|
|
||||||
|
json result = json::array();
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(lora_mutex);
|
||||||
|
for (const auto& e : lora_cache) {
|
||||||
|
json item;
|
||||||
|
item["name"] = e.name;
|
||||||
|
item["path"] = e.path;
|
||||||
|
result.push_back(item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res.set_content(result.dump(), "application/json");
|
||||||
|
});
|
||||||
|
|
||||||
svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
|
svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) {
|
||||||
std::vector<std::string> sampler_names;
|
std::vector<std::string> sampler_names;
|
||||||
sampler_names.push_back("default");
|
sampler_names.push_back("default");
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user