Compare commits

...

5 Commits

8 changed files with 900 additions and 792 deletions

View File

@ -750,7 +750,7 @@ int main(int argc, const char* argv[]) {
gen_params.pm_id_embed_path.c_str(), gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength, gen_params.pm_style_strength,
}, // pm_params }, // pm_params
ctx_params.vae_tiling_params, gen_params.vae_tiling_params,
gen_params.cache_params, gen_params.cache_params,
}; };
@ -776,7 +776,7 @@ int main(int argc, const char* argv[]) {
gen_params.seed, gen_params.seed,
gen_params.video_frames, gen_params.video_frames,
gen_params.vace_strength, gen_params.vace_strength,
ctx_params.vae_tiling_params, gen_params.vae_tiling_params,
gen_params.cache_params, gen_params.cache_params,
}; };

View File

@ -475,7 +475,6 @@ struct SDContextParams {
prediction_t prediction = PREDICTION_COUNT; prediction_t prediction = PREDICTION_COUNT;
lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO; lora_apply_mode_t lora_apply_mode = LORA_APPLY_AUTO;
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
bool force_sdxl_vae_conv_scale = false; bool force_sdxl_vae_conv_scale = false;
float flow_shift = INFINITY; float flow_shift = INFINITY;
@ -576,18 +575,9 @@ struct SDContextParams {
&chroma_t5_mask_pad}, &chroma_t5_mask_pad},
}; };
options.float_options = { options.float_options = {};
{"",
"--vae-tile-overlap",
"tile overlap for vae tiling, in fraction of tile size (default: 0.5)",
&vae_tiling_params.target_overlap},
};
options.bool_options = { options.bool_options = {
{"",
"--vae-tiling",
"process vae in tiles to reduce memory usage",
true, &vae_tiling_params.enabled},
{"", {"",
"--force-sdxl-vae-conv-scale", "--force-sdxl-vae-conv-scale",
"force use of conv scale on sdxl vae", "force use of conv scale on sdxl vae",
@ -724,52 +714,6 @@ struct SDContextParams {
return 1; return 1;
}; };
auto on_tile_size_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
std::string tile_size_str = argv[index];
size_t x_pos = tile_size_str.find('x');
try {
if (x_pos != std::string::npos) {
std::string tile_x_str = tile_size_str.substr(0, x_pos);
std::string tile_y_str = tile_size_str.substr(x_pos + 1);
vae_tiling_params.tile_size_x = std::stoi(tile_x_str);
vae_tiling_params.tile_size_y = std::stoi(tile_y_str);
} else {
vae_tiling_params.tile_size_x = vae_tiling_params.tile_size_y = std::stoi(tile_size_str);
}
} catch (const std::invalid_argument&) {
return -1;
} catch (const std::out_of_range&) {
return -1;
}
return 1;
};
auto on_relative_tile_size_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
std::string rel_size_str = argv[index];
size_t x_pos = rel_size_str.find('x');
try {
if (x_pos != std::string::npos) {
std::string rel_x_str = rel_size_str.substr(0, x_pos);
std::string rel_y_str = rel_size_str.substr(x_pos + 1);
vae_tiling_params.rel_size_x = std::stof(rel_x_str);
vae_tiling_params.rel_size_y = std::stof(rel_y_str);
} else {
vae_tiling_params.rel_size_x = vae_tiling_params.rel_size_y = std::stof(rel_size_str);
}
} catch (const std::invalid_argument&) {
return -1;
} catch (const std::out_of_range&) {
return -1;
}
return 1;
};
options.manual_options = { options.manual_options = {
{"", {"",
"--type", "--type",
@ -796,14 +740,6 @@ struct SDContextParams {
"but it usually offers faster inference speed and, in some cases, lower memory usage. " "but it usually offers faster inference speed and, in some cases, lower memory usage. "
"The at_runtime mode, on the other hand, is exactly the opposite.", "The at_runtime mode, on the other hand, is exactly the opposite.",
on_lora_apply_mode_arg}, on_lora_apply_mode_arg},
{"",
"--vae-tile-size",
"tile size for vae tiling, format [X]x[Y] (default: 32x32)",
on_tile_size_arg},
{"",
"--vae-relative-tile-size",
"relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)",
on_relative_tile_size_arg},
}; };
return options; return options;
@ -917,13 +853,6 @@ struct SDContextParams {
<< " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n" << " chroma_t5_mask_pad: " << chroma_t5_mask_pad << ",\n"
<< " prediction: " << sd_prediction_name(prediction) << ",\n" << " prediction: " << sd_prediction_name(prediction) << ",\n"
<< " lora_apply_mode: " << sd_lora_apply_mode_name(lora_apply_mode) << ",\n" << " lora_apply_mode: " << sd_lora_apply_mode_name(lora_apply_mode) << ",\n"
<< " vae_tiling_params: { "
<< vae_tiling_params.enabled << ", "
<< vae_tiling_params.tile_size_x << ", "
<< vae_tiling_params.tile_size_y << ", "
<< vae_tiling_params.target_overlap << ", "
<< vae_tiling_params.rel_size_x << ", "
<< vae_tiling_params.rel_size_y << " },\n"
<< " force_sdxl_vae_conv_scale: " << (force_sdxl_vae_conv_scale ? "true" : "false") << "\n" << " force_sdxl_vae_conv_scale: " << (force_sdxl_vae_conv_scale ? "true" : "false") << "\n"
<< "}"; << "}";
return oss.str(); return oss.str();
@ -1061,6 +990,8 @@ struct SDGenerationParams {
int64_t seed = 42; int64_t seed = 42;
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
// Photo Maker // Photo Maker
std::string pm_id_images_dir; std::string pm_id_images_dir;
std::string pm_id_embed_path; std::string pm_id_embed_path;
@ -1251,6 +1182,10 @@ struct SDGenerationParams {
"--vace-strength", "--vace-strength",
"wan vace strength", "wan vace strength",
&vace_strength}, &vace_strength},
{"",
"--vae-tile-overlap",
"tile overlap for vae tiling, in fraction of tile size (default: 0.5)",
&vae_tiling_params.target_overlap},
}; };
options.bool_options = { options.bool_options = {
@ -1264,6 +1199,10 @@ struct SDGenerationParams {
"disable auto resize of ref images", "disable auto resize of ref images",
false, false,
&auto_resize_ref_image}, &auto_resize_ref_image},
{"",
"--vae-tiling",
"process vae in tiles to reduce memory usage",
true, &vae_tiling_params.enabled},
}; };
auto on_seed_arg = [&](int argc, const char** argv, int index) { auto on_seed_arg = [&](int argc, const char** argv, int index) {
@ -1460,6 +1399,52 @@ struct SDGenerationParams {
return 1; return 1;
}; };
auto on_tile_size_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
std::string tile_size_str = argv[index];
size_t x_pos = tile_size_str.find('x');
try {
if (x_pos != std::string::npos) {
std::string tile_x_str = tile_size_str.substr(0, x_pos);
std::string tile_y_str = tile_size_str.substr(x_pos + 1);
vae_tiling_params.tile_size_x = std::stoi(tile_x_str);
vae_tiling_params.tile_size_y = std::stoi(tile_y_str);
} else {
vae_tiling_params.tile_size_x = vae_tiling_params.tile_size_y = std::stoi(tile_size_str);
}
} catch (const std::invalid_argument&) {
return -1;
} catch (const std::out_of_range&) {
return -1;
}
return 1;
};
auto on_relative_tile_size_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
std::string rel_size_str = argv[index];
size_t x_pos = rel_size_str.find('x');
try {
if (x_pos != std::string::npos) {
std::string rel_x_str = rel_size_str.substr(0, x_pos);
std::string rel_y_str = rel_size_str.substr(x_pos + 1);
vae_tiling_params.rel_size_x = std::stof(rel_x_str);
vae_tiling_params.rel_size_y = std::stof(rel_y_str);
} else {
vae_tiling_params.rel_size_x = vae_tiling_params.rel_size_y = std::stof(rel_size_str);
}
} catch (const std::invalid_argument&) {
return -1;
} catch (const std::out_of_range&) {
return -1;
}
return 1;
};
options.manual_options = { options.manual_options = {
{"-s", {"-s",
"--seed", "--seed",
@ -1511,6 +1496,14 @@ struct SDGenerationParams {
"--scm-policy", "--scm-policy",
"SCM policy: 'dynamic' (default) or 'static'", "SCM policy: 'dynamic' (default) or 'static'",
on_scm_policy_arg}, on_scm_policy_arg},
{"",
"--vae-tile-size",
"tile size for vae tiling, format [X]x[Y] (default: 32x32)",
on_tile_size_arg},
{"",
"--vae-relative-tile-size",
"relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)",
on_relative_tile_size_arg},
}; };
@ -1946,6 +1939,13 @@ struct SDGenerationParams {
<< " seed: " << seed << ",\n" << " seed: " << seed << ",\n"
<< " upscale_repeats: " << upscale_repeats << ",\n" << " upscale_repeats: " << upscale_repeats << ",\n"
<< " upscale_tile_size: " << upscale_tile_size << ",\n" << " upscale_tile_size: " << upscale_tile_size << ",\n"
<< " vae_tiling_params: { "
<< vae_tiling_params.enabled << ", "
<< vae_tiling_params.tile_size_x << ", "
<< vae_tiling_params.tile_size_y << ", "
<< vae_tiling_params.target_overlap << ", "
<< vae_tiling_params.rel_size_x << ", "
<< vae_tiling_params.rel_size_y << " },\n"
<< "}"; << "}";
free(sample_params_str); free(sample_params_str);
free(high_noise_sample_params_str); free(high_noise_sample_params_str);

View File

@ -273,55 +273,20 @@ struct LoraEntry {
std::string fullpath; std::string fullpath;
}; };
void free_results(sd_image_t* result_images, int num_results) { struct ServerRuntime {
if (result_images) { sd_ctx_t* sd_ctx;
for (int i = 0; i < num_results; ++i) { std::mutex* sd_ctx_mutex;
if (result_images[i].data) { const SDSvrParams* svr_params;
stbi_image_free(result_images[i].data); const SDContextParams* ctx_params;
result_images[i].data = nullptr; const SDGenerationParams* default_gen_params;
} std::vector<LoraEntry>* lora_cache;
} std::mutex* lora_mutex;
} };
free(result_images);
}
int main(int argc, const char** argv) { void refresh_lora_cache(ServerRuntime& rt) {
if (argc > 1 && std::string(argv[1]) == "--version") {
std::cout << version_string() << "\n";
return EXIT_SUCCESS;
}
SDSvrParams svr_params;
SDContextParams ctx_params;
SDGenerationParams default_gen_params;
parse_args(argc, argv, svr_params, ctx_params, default_gen_params);
sd_set_log_callback(sd_log_cb, (void*)&svr_params);
log_verbose = svr_params.verbose;
log_color = svr_params.color;
LOG_DEBUG("version: %s", version_string().c_str());
LOG_DEBUG("%s", sd_get_system_info());
LOG_DEBUG("%s", svr_params.to_string().c_str());
LOG_DEBUG("%s", ctx_params.to_string().c_str());
LOG_DEBUG("%s", default_gen_params.to_string().c_str());
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false);
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
if (sd_ctx == nullptr) {
LOG_ERROR("new_sd_ctx_t failed");
return 1;
}
std::mutex sd_ctx_mutex;
std::vector<LoraEntry> lora_cache;
std::mutex lora_mutex;
auto refresh_lora_cache = [&]() {
std::vector<LoraEntry> new_cache; std::vector<LoraEntry> new_cache;
fs::path lora_dir = ctx_params.lora_model_dir; fs::path lora_dir = rt.ctx_params->lora_model_dir;
if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) { if (fs::exists(lora_dir) && fs::is_directory(lora_dir)) {
auto is_lora_ext = [](const fs::path& p) { auto is_lora_ext = [](const fs::path& p) {
auto ext = p.extension().string(); auto ext = p.extension().string();
@ -353,47 +318,35 @@ int main(int argc, const char** argv) {
}); });
{ {
std::lock_guard<std::mutex> lock(lora_mutex); std::lock_guard<std::mutex> lock(*rt.lora_mutex);
lora_cache = std::move(new_cache); *rt.lora_cache = std::move(new_cache);
} }
}; }
auto get_lora_full_path = [&](const std::string& path) -> std::string { std::string get_lora_full_path(ServerRuntime& rt, const std::string& path) {
std::lock_guard<std::mutex> lock(lora_mutex); std::lock_guard<std::mutex> lock(*rt.lora_mutex);
auto it = std::find_if(lora_cache.begin(), lora_cache.end(), auto it = std::find_if(rt.lora_cache->begin(), rt.lora_cache->end(),
[&](const LoraEntry& e) { return e.path == path; }); [&](const LoraEntry& e) { return e.path == path; });
return (it != lora_cache.end()) ? it->fullpath : ""; return (it != rt.lora_cache->end()) ? it->fullpath : "";
}; }
httplib::Server svr; void free_results(sd_image_t* result_images, int num_results) {
if (result_images) {
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) { for (int i = 0; i < num_results; ++i) {
std::string origin = req.get_header_value("Origin"); if (result_images[i].data) {
if (origin.empty()) { stbi_image_free(result_images[i].data);
origin = "*"; result_images[i].data = nullptr;
} }
res.set_header("Access-Control-Allow-Origin", origin);
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "*");
res.set_header("Access-Control-Allow-Headers", "*");
if (req.method == "OPTIONS") {
res.status = 204;
return httplib::Server::HandlerResponse::Handled;
} }
return httplib::Server::HandlerResponse::Unhandled; }
}); free(result_images);
}
// index html void register_index_endpoints(httplib::Server& svr, const SDSvrParams& svr_params, const std::string& index_html) {
std::string index_html; const std::string serve_html_path = svr_params.serve_html_path;
#ifdef HAVE_INDEX_HTML svr.Get("/", [serve_html_path, index_html](const httplib::Request&, httplib::Response& res) {
index_html.assign(reinterpret_cast<const char*>(index_html_bytes), index_html_size); if (!serve_html_path.empty()) {
#else std::ifstream file(serve_html_path);
index_html = "Stable Diffusion Server is running";
#endif
svr.Get("/", [&](const httplib::Request&, httplib::Response& res) {
if (!svr_params.serve_html_path.empty()) {
std::ifstream file(svr_params.serve_html_path);
if (file) { if (file) {
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>()); std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
res.set_content(content, "text/html"); res.set_content(content, "text/html");
@ -405,17 +358,19 @@ int main(int argc, const char** argv) {
res.set_content(index_html, "text/html"); res.set_content(index_html, "text/html");
} }
}); });
}
// models endpoint (minimal) void register_openai_api_endpoints(httplib::Server& svr, ServerRuntime& rt) {
svr.Get("/v1/models", [&](const httplib::Request&, httplib::Response& res) { ServerRuntime* runtime = &rt;
svr.Get("/v1/models", [runtime](const httplib::Request&, httplib::Response& res) {
json r; json r;
r["data"] = json::array(); r["data"] = json::array();
r["data"].push_back({{"id", "sd-cpp-local"}, {"object", "model"}, {"owned_by", "local"}}); r["data"].push_back({{"id", "sd-cpp-local"}, {"object", "model"}, {"owned_by", "local"}});
res.set_content(r.dump(), "application/json"); res.set_content(r.dump(), "application/json");
}); });
// core endpoint: /v1/images/generations svr.Post("/v1/images/generations", [runtime](const httplib::Request& req, httplib::Response& res) {
svr.Post("/v1/images/generations", [&](const httplib::Request& req, httplib::Response& res) {
try { try {
if (req.body.empty()) { if (req.body.empty()) {
res.status = 400; res.status = 400;
@ -429,8 +384,8 @@ int main(int argc, const char** argv) {
std::string size = j.value("size", ""); std::string size = j.value("size", "");
std::string output_format = j.value("output_format", "png"); std::string output_format = j.value("output_format", "png");
int output_compression = j.value("output_compression", 100); int output_compression = j.value("output_compression", 100);
int width = default_gen_params.width > 0 ? default_gen_params.width : 512; int width = runtime->default_gen_params->width > 0 ? runtime->default_gen_params->width : 512;
int height = default_gen_params.width > 0 ? default_gen_params.height : 512; int height = runtime->default_gen_params->width > 0 ? runtime->default_gen_params->height : 512;
if (!size.empty()) { if (!size.empty()) {
auto pos = size.find('x'); auto pos = size.find('x');
if (pos != std::string::npos) { if (pos != std::string::npos) {
@ -458,7 +413,7 @@ int main(int argc, const char** argv) {
if (n <= 0) if (n <= 0)
n = 1; n = 1;
if (n > 8) if (n > 8)
n = 8; // safety n = 8;
if (output_compression > 100) { if (output_compression > 100) {
output_compression = 100; output_compression = 100;
} }
@ -471,7 +426,7 @@ int main(int argc, const char** argv) {
out["data"] = json::array(); out["data"] = json::array();
out["output_format"] = output_format; out["output_format"] = output_format;
SDGenerationParams gen_params = default_gen_params; SDGenerationParams gen_params = *runtime->default_gen_params;
gen_params.prompt = prompt; gen_params.prompt = prompt;
gen_params.width = width; gen_params.width = width;
gen_params.height = height; gen_params.height = height;
@ -524,8 +479,8 @@ int main(int argc, const char** argv) {
(int)pmid_images.size(), (int)pmid_images.size(),
gen_params.pm_id_embed_path.c_str(), gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength, gen_params.pm_style_strength,
}, // pm_params },
ctx_params.vae_tiling_params, gen_params.vae_tiling_params,
gen_params.cache_params, gen_params.cache_params,
}; };
@ -533,8 +488,8 @@ int main(int argc, const char** argv) {
int num_results = 0; int num_results = 0;
{ {
std::lock_guard<std::mutex> lock(sd_ctx_mutex); std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params); results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count; num_results = gen_params.batch_count;
} }
@ -553,7 +508,6 @@ int main(int argc, const char** argv) {
continue; continue;
} }
// base64 encode
std::string b64 = base64_encode(image_bytes); std::string b64 = base64_encode(image_bytes);
json item; json item;
item["b64_json"] = b64; item["b64_json"] = b64;
@ -573,7 +527,7 @@ int main(int argc, const char** argv) {
} }
}); });
svr.Post("/v1/images/edits", [&](const httplib::Request& req, httplib::Response& res) { svr.Post("/v1/images/edits", [runtime](const httplib::Request& req, httplib::Response& res) {
try { try {
if (!req.is_multipart_form_data()) { if (!req.is_multipart_form_data()) {
res.status = 400; res.status = 400;
@ -658,7 +612,7 @@ int main(int argc, const char** argv) {
output_compression = 0; output_compression = 0;
} }
SDGenerationParams gen_params = default_gen_params; SDGenerationParams gen_params = *runtime->default_gen_params;
gen_params.prompt = prompt; gen_params.prompt = prompt;
gen_params.width = width; gen_params.width = width;
gen_params.height = height; gen_params.height = height;
@ -685,18 +639,18 @@ int main(int argc, const char** argv) {
sd_image_t control_image = {0, 0, 3, nullptr}; sd_image_t control_image = {0, 0, 3, nullptr};
std::vector<sd_image_t> pmid_images; std::vector<sd_image_t> pmid_images;
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int { auto get_resolved_width = [&gen_params, runtime]() -> int {
if (gen_params.width > 0) if (gen_params.width > 0)
return gen_params.width; return gen_params.width;
if (default_gen_params.width > 0) if (runtime->default_gen_params->width > 0)
return default_gen_params.width; return runtime->default_gen_params->width;
return 512; return 512;
}; };
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int { auto get_resolved_height = [&gen_params, runtime]() -> int {
if (gen_params.height > 0) if (gen_params.height > 0)
return gen_params.height; return gen_params.height;
if (default_gen_params.height > 0) if (runtime->default_gen_params->height > 0)
return default_gen_params.height; return runtime->default_gen_params->height;
return 512; return 512;
}; };
@ -771,8 +725,8 @@ int main(int argc, const char** argv) {
(int)pmid_images.size(), (int)pmid_images.size(),
gen_params.pm_id_embed_path.c_str(), gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength, gen_params.pm_style_strength,
}, // pm_params },
ctx_params.vae_tiling_params, gen_params.vae_tiling_params,
gen_params.cache_params, gen_params.cache_params,
}; };
@ -780,8 +734,8 @@ int main(int argc, const char** argv) {
int num_results = 0; int num_results = 0;
{ {
std::lock_guard<std::mutex> lock(sd_ctx_mutex); std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params); results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count; num_results = gen_params.batch_count;
} }
@ -826,10 +780,12 @@ int main(int argc, const char** argv) {
res.set_content(err.dump(), "application/json"); res.set_content(err.dump(), "application/json");
} }
}); });
}
// sdapi endpoints (AUTOMATIC1111 / Forge) void register_sdapi_endpoints(httplib::Server& svr, ServerRuntime& rt) {
ServerRuntime* runtime = &rt;
auto sdapi_any2img = [&](const httplib::Request& req, httplib::Response& res, bool img2img) { auto sdapi_any2img = [runtime](const httplib::Request& req, httplib::Response& res, bool img2img) {
try { try {
if (req.body.empty()) { if (req.body.empty()) {
res.status = 400; res.status = 400;
@ -843,8 +799,8 @@ int main(int argc, const char** argv) {
std::string negative_prompt = j.value("negative_prompt", ""); std::string negative_prompt = j.value("negative_prompt", "");
int width = j.value("width", 512); int width = j.value("width", 512);
int height = j.value("height", 512); int height = j.value("height", 512);
int steps = j.value("steps", default_gen_params.sample_params.sample_steps); int steps = j.value("steps", runtime->default_gen_params->sample_params.sample_steps);
float cfg_scale = j.value("cfg_scale", default_gen_params.sample_params.guidance.txt_cfg); float cfg_scale = j.value("cfg_scale", runtime->default_gen_params->sample_params.guidance.txt_cfg);
int64_t seed = j.value("seed", -1); int64_t seed = j.value("seed", -1);
int batch_size = j.value("batch_size", 1); int batch_size = j.value("batch_size", 1);
int clip_skip = j.value("clip_skip", -1); int clip_skip = j.value("clip_skip", -1);
@ -894,7 +850,7 @@ int main(int argc, const char** argv) {
return bad("lora.path required"); return bad("lora.path required");
} }
std::string fullpath = get_lora_full_path(path); std::string fullpath = get_lora_full_path(*runtime, path);
if (fullpath.empty()) { if (fullpath.empty()) {
return bad("invalid lora path: " + path); return bad("invalid lora path: " + path);
} }
@ -912,7 +868,6 @@ int main(int argc, const char** argv) {
auto get_sample_method = [](std::string name) -> enum sample_method_t { auto get_sample_method = [](std::string name) -> enum sample_method_t {
enum sample_method_t result = str_to_sample_method(name.c_str()); enum sample_method_t result = str_to_sample_method(name.c_str());
if (result != SAMPLE_METHOD_COUNT) return result; if (result != SAMPLE_METHOD_COUNT) return result;
// some applications use a hardcoded sampler list
std::transform(name.begin(), name.end(), name.begin(), std::transform(name.begin(), name.end(), name.begin(),
[](unsigned char c) { return std::tolower(c); }); [](unsigned char c) { return std::tolower(c); });
static const std::unordered_map<std::string_view, sample_method_t> hardcoded{ static const std::unordered_map<std::string_view, sample_method_t> hardcoded{
@ -938,10 +893,9 @@ int main(int argc, const char** argv) {
}; };
enum sample_method_t sample_method = get_sample_method(sampler_name); enum sample_method_t sample_method = get_sample_method(sampler_name);
enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str()); enum scheduler_t scheduler = str_to_scheduler(scheduler_name.c_str());
SDGenerationParams gen_params = default_gen_params; SDGenerationParams gen_params = *runtime->default_gen_params;
gen_params.prompt = prompt; gen_params.prompt = prompt;
gen_params.negative_prompt = negative_prompt; gen_params.negative_prompt = negative_prompt;
gen_params.seed = seed; gen_params.seed = seed;
@ -961,8 +915,6 @@ int main(int argc, const char** argv) {
gen_params.sample_params.scheduler = scheduler; gen_params.sample_params.scheduler = scheduler;
} }
// re-read to avoid applying 512 as default before the provided
// images and/or server command-line
gen_params.width = j.value("width", -1); gen_params.width = j.value("width", -1);
gen_params.height = j.value("height", -1); gen_params.height = j.value("height", -1);
@ -975,23 +927,22 @@ int main(int argc, const char** argv) {
std::vector<sd_image_t> pmid_images; std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> ref_images; std::vector<sd_image_t> ref_images;
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int { auto get_resolved_width = [&gen_params, runtime]() -> int {
if (gen_params.width > 0) if (gen_params.width > 0)
return gen_params.width; return gen_params.width;
if (default_gen_params.width > 0) if (runtime->default_gen_params->width > 0)
return default_gen_params.width; return runtime->default_gen_params->width;
return 512; return 512;
}; };
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int { auto get_resolved_height = [&gen_params, runtime]() -> int {
if (gen_params.height > 0) if (gen_params.height > 0)
return gen_params.height; return gen_params.height;
if (default_gen_params.height > 0) if (runtime->default_gen_params->height > 0)
return default_gen_params.height; return runtime->default_gen_params->height;
return 512; return 512;
}; };
auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool { auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool {
// remove data URI prefix if present ("data:image/png;base64,")
auto comma_pos = encoded.find(','); auto comma_pos = encoded.find(',');
if (comma_pos != std::string::npos) { if (comma_pos != std::string::npos) {
encoded = encoded.substr(comma_pos + 1); encoded = encoded.substr(comma_pos + 1);
@ -1087,8 +1038,8 @@ int main(int argc, const char** argv) {
(int)pmid_images.size(), (int)pmid_images.size(),
gen_params.pm_id_embed_path.c_str(), gen_params.pm_id_embed_path.c_str(),
gen_params.pm_style_strength, gen_params.pm_style_strength,
}, // pm_params },
ctx_params.vae_tiling_params, gen_params.vae_tiling_params,
gen_params.cache_params, gen_params.cache_params,
}; };
@ -1096,14 +1047,14 @@ int main(int argc, const char** argv) {
int num_results = 0; int num_results = 0;
{ {
std::lock_guard<std::mutex> lock(sd_ctx_mutex); std::lock_guard<std::mutex> lock(*runtime->sd_ctx_mutex);
results = generate_image(sd_ctx, &img_gen_params); results = generate_image(runtime->sd_ctx, &img_gen_params);
num_results = gen_params.batch_count; num_results = gen_params.batch_count;
} }
json out; json out;
out["images"] = json::array(); out["images"] = json::array();
out["parameters"] = j; // TODO should return changed defaults out["parameters"] = j;
out["info"] = ""; out["info"] = "";
for (int i = 0; i < num_results; i++) { for (int i = 0; i < num_results; i++) {
@ -1149,21 +1100,21 @@ int main(int argc, const char** argv) {
} }
}; };
svr.Post("/sdapi/v1/txt2img", [&](const httplib::Request& req, httplib::Response& res) { svr.Post("/sdapi/v1/txt2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
sdapi_any2img(req, res, false); sdapi_any2img(req, res, false);
}); });
svr.Post("/sdapi/v1/img2img", [&](const httplib::Request& req, httplib::Response& res) { svr.Post("/sdapi/v1/img2img", [sdapi_any2img](const httplib::Request& req, httplib::Response& res) {
sdapi_any2img(req, res, true); sdapi_any2img(req, res, true);
}); });
svr.Get("/sdapi/v1/loras", [&](const httplib::Request&, httplib::Response& res) { svr.Get("/sdapi/v1/loras", [runtime](const httplib::Request&, httplib::Response& res) {
refresh_lora_cache(); refresh_lora_cache(*runtime);
json result = json::array(); json result = json::array();
{ {
std::lock_guard<std::mutex> lock(lora_mutex); std::lock_guard<std::mutex> lock(*runtime->lora_mutex);
for (const auto& e : lora_cache) { for (const auto& e : *runtime->lora_cache) {
json item; json item;
item["name"] = e.name; item["name"] = e.name;
item["path"] = e.path; item["path"] = e.path;
@ -1174,7 +1125,7 @@ int main(int argc, const char** argv) {
res.set_content(result.dump(), "application/json"); res.set_content(result.dump(), "application/json");
}); });
svr.Get("/sdapi/v1/samplers", [&](const httplib::Request&, httplib::Response& res) { svr.Get("/sdapi/v1/samplers", [runtime](const httplib::Request&, httplib::Response& res) {
std::vector<std::string> sampler_names; std::vector<std::string> sampler_names;
sampler_names.push_back("default"); sampler_names.push_back("default");
for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) { for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) {
@ -1191,7 +1142,7 @@ int main(int argc, const char** argv) {
res.set_content(r.dump(), "application/json"); res.set_content(r.dump(), "application/json");
}); });
svr.Get("/sdapi/v1/schedulers", [&](const httplib::Request&, httplib::Response& res) { svr.Get("/sdapi/v1/schedulers", [runtime](const httplib::Request&, httplib::Response& res) {
std::vector<std::string> scheduler_names; std::vector<std::string> scheduler_names;
scheduler_names.push_back("default"); scheduler_names.push_back("default");
for (int i = 0; i < SCHEDULER_COUNT; i++) { for (int i = 0; i < SCHEDULER_COUNT; i++) {
@ -1207,8 +1158,8 @@ int main(int argc, const char** argv) {
res.set_content(r.dump(), "application/json"); res.set_content(r.dump(), "application/json");
}); });
svr.Get("/sdapi/v1/sd-models", [&](const httplib::Request&, httplib::Response& res) { svr.Get("/sdapi/v1/sd-models", [runtime](const httplib::Request&, httplib::Response& res) {
fs::path model_path = ctx_params.model_path; fs::path model_path = runtime->ctx_params->model_path;
json entry; json entry;
entry["title"] = model_path.stem(); entry["title"] = model_path.stem();
entry["model_name"] = model_path.stem(); entry["model_name"] = model_path.stem();
@ -1221,18 +1172,89 @@ int main(int argc, const char** argv) {
res.set_content(r.dump(), "application/json"); res.set_content(r.dump(), "application/json");
}); });
svr.Get("/sdapi/v1/options", [&](const httplib::Request&, httplib::Response& res) { svr.Get("/sdapi/v1/options", [runtime](const httplib::Request&, httplib::Response& res) {
fs::path model_path = ctx_params.model_path; fs::path model_path = runtime->ctx_params->model_path;
json r; json r;
r["samples_format"] = "png"; r["samples_format"] = "png";
r["sd_model_checkpoint"] = model_path.stem(); r["sd_model_checkpoint"] = model_path.stem();
res.set_content(r.dump(), "application/json"); res.set_content(r.dump(), "application/json");
}); });
}
int main(int argc, const char** argv) {
if (argc > 1 && std::string(argv[1]) == "--version") {
std::cout << version_string() << "\n";
return EXIT_SUCCESS;
}
SDSvrParams svr_params;
SDContextParams ctx_params;
SDGenerationParams default_gen_params;
parse_args(argc, argv, svr_params, ctx_params, default_gen_params);
sd_set_log_callback(sd_log_cb, (void*)&svr_params);
log_verbose = svr_params.verbose;
log_color = svr_params.color;
LOG_DEBUG("version: %s", version_string().c_str());
LOG_DEBUG("%s", sd_get_system_info());
LOG_DEBUG("%s", svr_params.to_string().c_str());
LOG_DEBUG("%s", ctx_params.to_string().c_str());
LOG_DEBUG("%s", default_gen_params.to_string().c_str());
sd_ctx_params_t sd_ctx_params = ctx_params.to_sd_ctx_params_t(false, false, false);
sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params);
if (sd_ctx == nullptr) {
LOG_ERROR("new_sd_ctx_t failed");
return 1;
}
std::mutex sd_ctx_mutex;
std::vector<LoraEntry> lora_cache;
std::mutex lora_mutex;
ServerRuntime runtime = {
sd_ctx,
&sd_ctx_mutex,
&svr_params,
&ctx_params,
&default_gen_params,
&lora_cache,
&lora_mutex,
};
httplib::Server svr;
svr.set_pre_routing_handler([](const httplib::Request& req, httplib::Response& res) {
std::string origin = req.get_header_value("Origin");
if (origin.empty()) {
origin = "*";
}
res.set_header("Access-Control-Allow-Origin", origin);
res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "*");
res.set_header("Access-Control-Allow-Headers", "*");
if (req.method == "OPTIONS") {
res.status = 204;
return httplib::Server::HandlerResponse::Handled;
}
return httplib::Server::HandlerResponse::Unhandled;
});
std::string index_html;
#ifdef HAVE_INDEX_HTML
index_html.assign(reinterpret_cast<const char*>(index_html_bytes), index_html_size);
#else
index_html = "Stable Diffusion Server is running";
#endif
register_index_endpoints(svr, svr_params, index_html);
register_openai_api_endpoints(svr, runtime);
register_sdapi_endpoints(svr, runtime);
LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port); LOG_INFO("listening on: %s:%d\n", svr_params.listen_ip.c_str(), svr_params.listen_port);
svr.listen(svr_params.listen_ip, svr_params.listen_port); svr.listen(svr_params.listen_ip, svr_params.listen_port);
// cleanup
free_sd_ctx(sd_ctx); free_sd_ctx(sd_ctx);
return 0; return 0;
} }

2
ggml

@ -1 +1 @@
Subproject commit a8db410a252c8c8f2d120c6f2e7133ebe032f35d Subproject commit 404fcb9d7c96989569e68c9e7881ee3465a05c50

View File

@ -120,7 +120,8 @@ enum sd_type_t {
// SD_TYPE_IQ4_NL_4_8 = 37, // SD_TYPE_IQ4_NL_4_8 = 37,
// SD_TYPE_IQ4_NL_8_8 = 38, // SD_TYPE_IQ4_NL_8_8 = 38,
SD_TYPE_MXFP4 = 39, // MXFP4 (1 block) SD_TYPE_MXFP4 = 39, // MXFP4 (1 block)
SD_TYPE_COUNT = 40, SD_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
SD_TYPE_COUNT = 41,
}; };
enum sd_log_level_t { enum sd_log_level_t {

View File

@ -1313,14 +1313,14 @@ struct T5CLIPEmbedder : public Conditioner {
std::shared_ptr<T5Runner> t5; std::shared_ptr<T5Runner> t5;
size_t chunk_len = 512; size_t chunk_len = 512;
bool use_mask = false; bool use_mask = false;
int mask_pad = 1; int mask_pad = 0;
bool is_umt5 = false; bool is_umt5 = false;
T5CLIPEmbedder(ggml_backend_t backend, T5CLIPEmbedder(ggml_backend_t backend,
bool offload_params_to_cpu, bool offload_params_to_cpu,
const String2TensorStorage& tensor_storage_map = {}, const String2TensorStorage& tensor_storage_map = {},
bool use_mask = false, bool use_mask = false,
int mask_pad = 1, int mask_pad = 0,
bool is_umt5 = false) bool is_umt5 = false)
: use_mask(use_mask), mask_pad(mask_pad), t5_tokenizer(is_umt5) { : use_mask(use_mask), mask_pad(mask_pad), t5_tokenizer(is_umt5) {
bool use_t5 = false; bool use_t5 = false;

View File

@ -2,6 +2,7 @@
#define __DENOISER_HPP__ #define __DENOISER_HPP__
#include <cmath> #include <cmath>
#include <utility>
#include "ggml_extend.hpp" #include "ggml_extend.hpp"
#include "gits_noise.inl" #include "gits_noise.inl"
@ -763,16 +764,33 @@ struct Flux2FlowDenoiser : public FluxFlowDenoiser {
typedef std::function<sd::Tensor<float>(const sd::Tensor<float>&, float, int)> denoise_cb_t; typedef std::function<sd::Tensor<float>(const sd::Tensor<float>&, float, int)> denoise_cb_t;
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t static std::pair<float, float> get_ancestral_step(float sigma_from,
static sd::Tensor<float> sample_k_diffusion(sample_method_t method, float sigma_to,
denoise_cb_t model, float eta = 1.0f) {
float sigma_up = 0.0f;
float sigma_down = sigma_to;
if (eta <= 0.0f) {
return {sigma_down, sigma_up};
}
float sigma_from_sq = sigma_from * sigma_from;
float sigma_to_sq = sigma_to * sigma_to;
if (sigma_from_sq > 0.0f) {
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
sigma_up = std::min(sigma_to, eta * std::sqrt(std::max(term, 0.0f)));
}
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
return {sigma_down, sigma_up};
}
static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
sd::Tensor<float> x, sd::Tensor<float> x,
std::vector<float> sigmas, const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng, std::shared_ptr<RNG> rng) {
float eta) { int steps = static_cast<int>(sigmas.size()) - 1;
size_t steps = sigmas.size() - 1;
switch (method) {
case EULER_A_SAMPLE_METHOD: {
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
float sigma = sigmas[i]; float sigma = sigmas[i];
auto denoised_opt = model(x, sigma, i + 1); auto denoised_opt = model(x, sigma, i + 1);
@ -781,18 +799,19 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
} }
sd::Tensor<float> denoised = std::move(denoised_opt); sd::Tensor<float> denoised = std::move(denoised_opt);
sd::Tensor<float> d = (x - denoised) / sigma; sd::Tensor<float> d = (x - denoised) / sigma;
float sigma_up = std::min(sigmas[i + 1], auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1]);
std::sqrt(sigmas[i + 1] * sigmas[i + 1] * (sigmas[i] * sigmas[i] - sigmas[i + 1] * sigmas[i + 1]) / (sigmas[i] * sigmas[i]))); x += d * (sigma_down - sigmas[i]);
float sigma_down = std::sqrt(sigmas[i + 1] * sigmas[i + 1] - sigma_up * sigma_up);
float dt = sigma_down - sigmas[i];
x += d * dt;
if (sigmas[i + 1] > 0) { if (sigmas[i + 1] > 0) {
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up; x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
} }
} }
return x; return x;
} }
case EULER_SAMPLE_METHOD: {
static sd::Tensor<float> sample_euler(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
float sigma = sigmas[i]; float sigma = sigmas[i];
auto denoised_opt = model(x, sigma, i + 1); auto denoised_opt = model(x, sigma, i + 1);
@ -801,12 +820,15 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
} }
sd::Tensor<float> denoised = std::move(denoised_opt); sd::Tensor<float> denoised = std::move(denoised_opt);
sd::Tensor<float> d = (x - denoised) / sigma; sd::Tensor<float> d = (x - denoised) / sigma;
float dt = sigmas[i + 1] - sigma; x += d * (sigmas[i + 1] - sigma);
x += d * dt;
} }
return x; return x;
} }
case HEUN_SAMPLE_METHOD: {
static sd::Tensor<float> sample_heun(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], -(i + 1)); auto denoised_opt = model(x, sigmas[i], -(i + 1));
if (denoised_opt.empty()) { if (denoised_opt.empty()) {
@ -829,8 +851,12 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
} }
} }
return x; return x;
} }
case DPM2_SAMPLE_METHOD: {
static sd::Tensor<float> sample_dpm2(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], -(i + 1)); auto denoised_opt = model(x, sigmas[i], -(i + 1));
if (denoised_opt.empty()) { if (denoised_opt.empty()) {
@ -839,8 +865,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
sd::Tensor<float> denoised = std::move(denoised_opt); sd::Tensor<float> denoised = std::move(denoised_opt);
sd::Tensor<float> d = (x - denoised) / sigmas[i]; sd::Tensor<float> d = (x - denoised) / sigmas[i];
if (sigmas[i + 1] == 0) { if (sigmas[i + 1] == 0) {
float dt = sigmas[i + 1] - sigmas[i]; x += d * (sigmas[i + 1] - sigmas[i]);
x += d * dt;
} else { } else {
float sigma_mid = exp(0.5f * (log(sigmas[i]) + log(sigmas[i + 1]))); float sigma_mid = exp(0.5f * (log(sigmas[i]) + log(sigmas[i + 1])));
float dt_1 = sigma_mid - sigmas[i]; float dt_1 = sigma_mid - sigmas[i];
@ -855,19 +880,23 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
} }
} }
return x; return x;
} }
case DPMPP2S_A_SAMPLE_METHOD: {
static sd::Tensor<float> sample_dpmpp_2s_ancestral(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng) {
auto t_fn = [](float sigma) -> float { return -log(sigma); };
auto sigma_fn = [](float t) -> float { return exp(-t); };
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], -(i + 1)); auto denoised_opt = model(x, sigmas[i], -(i + 1));
if (denoised_opt.empty()) { if (denoised_opt.empty()) {
return {}; return {};
} }
sd::Tensor<float> denoised = std::move(denoised_opt); sd::Tensor<float> denoised = std::move(denoised_opt);
float sigma_up = std::min(sigmas[i + 1], auto [sigma_down, sigma_up] = get_ancestral_step(sigmas[i], sigmas[i + 1]);
std::sqrt(sigmas[i + 1] * sigmas[i + 1] * (sigmas[i] * sigmas[i] - sigmas[i + 1] * sigmas[i + 1]) / (sigmas[i] * sigmas[i])));
float sigma_down = std::sqrt(sigmas[i + 1] * sigmas[i + 1] - sigma_up * sigma_up);
auto t_fn = [](float sigma) -> float { return -log(sigma); };
auto sigma_fn = [](float t) -> float { return exp(-t); };
if (sigma_down == 0) { if (sigma_down == 0) {
x = denoised; x = denoised;
@ -882,7 +911,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
return {}; return {};
} }
sd::Tensor<float> denoised2 = std::move(denoised2_opt); sd::Tensor<float> denoised2 = std::move(denoised2_opt);
x = (sigma_fn(t_next) / sigma_fn(t)) * (x) - (exp(-h) - 1) * denoised2; x = (sigma_fn(t_next) / sigma_fn(t)) * x - (exp(-h) - 1) * denoised2;
} }
if (sigmas[i + 1] > 0) { if (sigmas[i + 1] > 0) {
@ -890,10 +919,15 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
} }
} }
return x; return x;
} }
case DPMPP2M_SAMPLE_METHOD: {
static sd::Tensor<float> sample_dpmpp_2m(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
sd::Tensor<float> old_denoised = x; sd::Tensor<float> old_denoised = x;
auto t_fn = [](float sigma) -> float { return -log(sigma); }; auto t_fn = [](float sigma) -> float { return -log(sigma); };
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], i + 1); auto denoised_opt = model(x, sigmas[i], i + 1);
if (denoised_opt.empty()) { if (denoised_opt.empty()) {
@ -907,20 +941,25 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float b = exp(-h) - 1.f; float b = exp(-h) - 1.f;
if (i == 0 || sigmas[i + 1] == 0) { if (i == 0 || sigmas[i + 1] == 0) {
x = a * (x)-b * denoised; x = a * x - b * denoised;
} else { } else {
float h_last = t - t_fn(sigmas[i - 1]); float h_last = t - t_fn(sigmas[i - 1]);
float r = h_last / h; float r = h_last / h;
sd::Tensor<float> denoised_d = (1.f + 1.f / (2.f * r)) * denoised - (1.f / (2.f * r)) * old_denoised; sd::Tensor<float> denoised_d = (1.f + 1.f / (2.f * r)) * denoised - (1.f / (2.f * r)) * old_denoised;
x = a * (x)-b * denoised_d; x = a * x - b * denoised_d;
} }
old_denoised = denoised; old_denoised = denoised;
} }
return x; return x;
} }
case DPMPP2Mv2_SAMPLE_METHOD: {
static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
sd::Tensor<float> old_denoised = x; sd::Tensor<float> old_denoised = x;
auto t_fn = [](float sigma) -> float { return -log(sigma); }; auto t_fn = [](float sigma) -> float { return -log(sigma); };
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], i + 1); auto denoised_opt = model(x, sigmas[i], i + 1);
if (denoised_opt.empty()) { if (denoised_opt.empty()) {
@ -931,9 +970,10 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float t_next = t_fn(sigmas[i + 1]); float t_next = t_fn(sigmas[i + 1]);
float h = t_next - t; float h = t_next - t;
float a = sigmas[i + 1] / sigmas[i]; float a = sigmas[i + 1] / sigmas[i];
if (i == 0 || sigmas[i + 1] == 0) { if (i == 0 || sigmas[i + 1] == 0) {
float b = exp(-h) - 1.f; float b = exp(-h) - 1.f;
x = a * (x)-b * denoised; x = a * x - b * denoised;
} else { } else {
float h_last = t - t_fn(sigmas[i - 1]); float h_last = t - t_fn(sigmas[i - 1]);
float h_min = std::min(h_last, h); float h_min = std::min(h_last, h);
@ -942,30 +982,38 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float h_d = (h_max + h_min) / 2.f; float h_d = (h_max + h_min) / 2.f;
float b = exp(-h_d) - 1.f; float b = exp(-h_d) - 1.f;
sd::Tensor<float> denoised_d = (1.f + 1.f / (2.f * r)) * denoised - (1.f / (2.f * r)) * old_denoised; sd::Tensor<float> denoised_d = (1.f + 1.f / (2.f * r)) * denoised - (1.f / (2.f * r)) * old_denoised;
x = a * (x)-b * denoised_d; x = a * x - b * denoised_d;
} }
old_denoised = denoised; old_denoised = denoised;
} }
return x; return x;
} }
case LCM_SAMPLE_METHOD: {
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng) {
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], i + 1); auto denoised_opt = model(x, sigmas[i], i + 1);
if (denoised_opt.empty()) { if (denoised_opt.empty()) {
return {}; return {};
} }
sd::Tensor<float> denoised = std::move(denoised_opt); x = std::move(denoised_opt);
x = denoised;
if (sigmas[i + 1] > 0) { if (sigmas[i + 1] > 0) {
x += sd::Tensor<float>::randn_like(x, rng) * sigmas[i + 1]; x += sd::Tensor<float>::randn_like(x, rng) * sigmas[i + 1];
} }
} }
return x; return x;
} }
case IPNDM_SAMPLE_METHOD: {
int max_order = 4; static sd::Tensor<float> sample_ipndm(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
const int max_order = 4;
std::vector<sd::Tensor<float>> hist = {}; std::vector<sd::Tensor<float>> hist = {};
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
float sigma = sigmas[i]; float sigma = sigmas[i];
float sigma_next = sigmas[i + 1]; float sigma_next = sigmas[i + 1];
@ -1001,10 +1049,15 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
hist.push_back(std::move(d_cur)); hist.push_back(std::move(d_cur));
} }
return x; return x;
} }
case IPNDM_V_SAMPLE_METHOD: {
int max_order = 4; static sd::Tensor<float> sample_ipndm_v(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas) {
const int max_order = 4;
std::vector<sd::Tensor<float>> hist = {}; std::vector<sd::Tensor<float>> hist = {};
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
float sigma = sigmas[i]; float sigma = sigmas[i];
float t_next = sigmas[i + 1]; float t_next = sigmas[i + 1];
@ -1041,10 +1094,14 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
hist.push_back(std::move(d_cur)); hist.push_back(std::move(d_cur));
} }
return x; return x;
} }
case RES_MULTISTEP_SAMPLE_METHOD: {
sd::Tensor<float> old_denoised = x;
static sd::Tensor<float> sample_res_multistep(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {
sd::Tensor<float> old_denoised = x;
bool have_old_sigma = false; bool have_old_sigma = false;
float old_sigma_down = 0.0f; float old_sigma_down = 0.0f;
@ -1064,6 +1121,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
return (phi1_val - 1.0f) / t; return (phi1_val - 1.0f) / t;
}; };
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
auto denoised_opt = model(x, sigmas[i], i + 1); auto denoised_opt = model(x, sigmas[i], i + 1);
if (denoised_opt.empty()) { if (denoised_opt.empty()) {
@ -1073,22 +1131,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float sigma_from = sigmas[i]; float sigma_from = sigmas[i];
float sigma_to = sigmas[i + 1]; float sigma_to = sigmas[i + 1];
float sigma_up = 0.0f; auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta);
float sigma_down = sigma_to;
if (eta > 0.0f) {
float sigma_from_sq = sigma_from * sigma_from;
float sigma_to_sq = sigma_to * sigma_to;
if (sigma_from_sq > 0.0f) {
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
if (term > 0.0f) {
sigma_up = eta * std::sqrt(term);
}
}
sigma_up = std::min(sigma_up, sigma_to);
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
}
if (sigma_down == 0.0f || !have_old_sigma) { if (sigma_down == 0.0f || !have_old_sigma) {
x += ((x - denoised) / sigma_from) * (sigma_down - sigma_from); x += ((x - denoised) / sigma_from) * (sigma_down - sigma_from);
@ -1112,7 +1155,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
b2 = 0.0f; b2 = 0.0f;
} }
x = sigma_fn(h) * (x) + h * (b1 * denoised + b2 * old_denoised); x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised);
} }
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) { if (sigmas[i + 1] > 0 && sigma_up > 0.0f) {
@ -1124,8 +1167,13 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
have_old_sigma = true; have_old_sigma = true;
} }
return x; return x;
} }
case RES_2S_SAMPLE_METHOD: {
static sd::Tensor<float> sample_res_2s(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {
const float c2 = 0.5f; const float c2 = 0.5f;
auto t_fn = [](float sigma) -> float { return -logf(sigma); }; auto t_fn = [](float sigma) -> float { return -logf(sigma); };
auto phi1_fn = [](float t) -> float { auto phi1_fn = [](float t) -> float {
@ -1142,6 +1190,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
return (phi1_val - 1.0f) / t; return (phi1_val - 1.0f) / t;
}; };
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
float sigma_from = sigmas[i]; float sigma_from = sigmas[i];
float sigma_to = sigmas[i + 1]; float sigma_to = sigmas[i + 1];
@ -1152,21 +1201,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
} }
sd::Tensor<float> denoised = std::move(denoised_opt); sd::Tensor<float> denoised = std::move(denoised_opt);
float sigma_up = 0.0f; auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta);
float sigma_down = sigma_to;
if (eta > 0.0f) {
float sigma_from_sq = sigma_from * sigma_from;
float sigma_to_sq = sigma_to * sigma_to;
if (sigma_from_sq > 0.0f) {
float term = sigma_to_sq * (sigma_from_sq - sigma_to_sq) / sigma_from_sq;
if (term > 0.0f) {
sigma_up = eta * std::sqrt(term);
}
}
sigma_up = std::min(sigma_up, sigma_to);
float sigma_down_sq = sigma_to_sq - sigma_up * sigma_up;
sigma_down = sigma_down_sq > 0.0f ? std::sqrt(sigma_down_sq) : 0.0f;
}
sd::Tensor<float> x0 = x; sd::Tensor<float> x0 = x;
if (sigma_down == 0.0f || sigma_from == 0.0f) { if (sigma_down == 0.0f || sigma_from == 0.0f) {
@ -1200,8 +1235,13 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
} }
} }
return x; return x;
} }
case DDIM_TRAILING_SAMPLE_METHOD: {
static sd::Tensor<float> sample_ddim_trailing(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {
float beta_start = 0.00085f; float beta_start = 0.00085f;
float beta_end = 0.0120f; float beta_end = 0.0120f;
std::vector<double> alphas_cumprod(TIMESTEPS); std::vector<double> alphas_cumprod(TIMESTEPS);
@ -1218,9 +1258,10 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]);
} }
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
int timestep = static_cast<int>(roundf(TIMESTEPS - i * ((float)TIMESTEPS / steps))) - 1; int timestep = static_cast<int>(roundf(TIMESTEPS - i * ((float)TIMESTEPS / steps))) - 1;
int prev_timestep = timestep - TIMESTEPS / static_cast<int>(steps); int prev_timestep = timestep - TIMESTEPS / steps;
float sigma = static_cast<float>(compvis_sigmas[timestep]); float sigma = static_cast<float>(compvis_sigmas[timestep]);
if (i == 0) { if (i == 0) {
x *= std::sqrt(sigma * sigma + 1) / sigma; x *= std::sqrt(sigma * sigma + 1) / sigma;
@ -1256,8 +1297,13 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
} }
} }
return x; return x;
} }
case TCD_SAMPLE_METHOD: {
static sd::Tensor<float> sample_tcd(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
float eta) {
float beta_start = 0.00085f; float beta_start = 0.00085f;
float beta_end = 0.0120f; float beta_end = 0.0120f;
std::vector<double> alphas_cumprod(TIMESTEPS); std::vector<double> alphas_cumprod(TIMESTEPS);
@ -1273,7 +1319,9 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
compvis_sigmas[i] = compvis_sigmas[i] =
std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]); std::sqrt((1 - alphas_cumprod[i]) / alphas_cumprod[i]);
} }
int original_steps = 50; int original_steps = 50;
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) { for (int i = 0; i < steps; i++) {
int timestep = TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor(i * ((float)original_steps / steps)); int timestep = TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor(i * ((float)original_steps / steps));
int prev_timestep = i >= steps - 1 ? 0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps)); int prev_timestep = i >= steps - 1 ? 0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps));
@ -1307,12 +1355,49 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
std::sqrt(beta_prod_s) * model_output; std::sqrt(beta_prod_s) * model_output;
if (eta > 0 && i != steps - 1) { if (eta > 0 && i != steps - 1) {
x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * (x) + x = std::sqrt(alpha_prod_t_prev / alpha_prod_s) * x +
std::sqrt(1.0f - alpha_prod_t_prev / alpha_prod_s) * sd::Tensor<float>::randn_like(x, rng); std::sqrt(1.0f - alpha_prod_t_prev / alpha_prod_s) * sd::Tensor<float>::randn_like(x, rng);
} }
} }
return x; return x;
} }
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
denoise_cb_t model,
sd::Tensor<float> x,
std::vector<float> sigmas,
std::shared_ptr<RNG> rng,
float eta) {
switch (method) {
case EULER_A_SAMPLE_METHOD:
return sample_euler_ancestral(model, std::move(x), sigmas, rng);
case EULER_SAMPLE_METHOD:
return sample_euler(model, std::move(x), sigmas);
case HEUN_SAMPLE_METHOD:
return sample_heun(model, std::move(x), sigmas);
case DPM2_SAMPLE_METHOD:
return sample_dpm2(model, std::move(x), sigmas);
case DPMPP2S_A_SAMPLE_METHOD:
return sample_dpmpp_2s_ancestral(model, std::move(x), sigmas, rng);
case DPMPP2M_SAMPLE_METHOD:
return sample_dpmpp_2m(model, std::move(x), sigmas);
case DPMPP2Mv2_SAMPLE_METHOD:
return sample_dpmpp_2m_v2(model, std::move(x), sigmas);
case LCM_SAMPLE_METHOD:
return sample_lcm(model, std::move(x), sigmas, rng);
case IPNDM_SAMPLE_METHOD:
return sample_ipndm(model, std::move(x), sigmas);
case IPNDM_V_SAMPLE_METHOD:
return sample_ipndm_v(model, std::move(x), sigmas);
case RES_MULTISTEP_SAMPLE_METHOD:
return sample_res_multistep(model, std::move(x), sigmas, rng, eta);
case RES_2S_SAMPLE_METHOD:
return sample_res_2s(model, std::move(x), sigmas, rng, eta);
case DDIM_TRAILING_SAMPLE_METHOD:
return sample_ddim_trailing(model, std::move(x), sigmas, rng, eta);
case TCD_SAMPLE_METHOD:
return sample_tcd(model, std::move(x), sigmas, rng, eta);
default: default:
return {}; return {};
} }

View File

@ -493,7 +493,7 @@ public:
offload_params_to_cpu, offload_params_to_cpu,
tensor_storage_map, tensor_storage_map,
true, true,
1, 0,
true); true);
diffusion_model = std::make_shared<WanModel>(backend, diffusion_model = std::make_shared<WanModel>(backend,
offload_params_to_cpu, offload_params_to_cpu,