refactor: unify extra argument parsing (#1540)

This commit is contained in:
leejet 2026-05-22 01:00:03 +08:00 committed by GitHub
parent 449165caf5
commit 3a8788cb7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 127 additions and 188 deletions

View File

@ -496,84 +496,26 @@ struct LTX2Scheduler : SigmaScheduler {
parse_extra_sample_args(extra_sample_args); parse_extra_sample_args(extra_sample_args);
} }
static std::string trim(std::string value) {
const char* whitespace = " \t\r\n";
size_t begin = value.find_first_not_of(whitespace);
if (begin == std::string::npos) {
return "";
}
size_t end = value.find_last_not_of(whitespace);
return value.substr(begin, end - begin + 1);
}
void parse_extra_sample_args(const char* extra_sample_args) { void parse_extra_sample_args(const char* extra_sample_args) {
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') { for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "ltx2 scheduler arg")) {
return; if (key == "max_shift") {
} if (!parse_strict_float(value, max_shift)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
std::string raw(extra_sample_args);
size_t start = 0;
auto parse_arg = [&](const std::string& item) {
std::string token = trim(item);
if (token.empty()) {
return;
}
size_t eq = token.find('=');
if (eq == std::string::npos) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
return;
}
std::string key = trim(token.substr(0, eq));
std::string value = trim(token.substr(eq + 1));
auto parse_float = [&](float* out) -> bool {
try {
size_t consumed = 0;
float parsed = std::stof(value, &consumed);
if (!trim(value.substr(consumed)).empty()) {
return false;
}
*out = parsed;
return true;
} catch (const std::exception&) {
return false;
} }
}; } else if (key == "base_shift") {
try { if (!parse_strict_float(value, base_shift)) {
if (key == "max_shift") { LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
if (!parse_float(&max_shift)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
} else if (key == "base_shift") {
if (!parse_float(&base_shift)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
} else if (key == "terminal") {
if (!parse_float(&terminal)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
} else if (key == "stretch") {
std::string v = value;
std::transform(v.begin(), v.end(), v.begin(), [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
if (v == "1" || v == "true" || v == "yes" || v == "on") {
stretch = true;
} else if (v == "0" || v == "false" || v == "no" || v == "off") {
stretch = false;
} else {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
}
} else {
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
} }
} catch (const std::exception&) { } else if (key == "terminal") {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str()); if (!parse_strict_float(value, terminal)) {
} LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
}; }
} else if (key == "stretch") {
for (size_t pos = 0; pos <= raw.size(); ++pos) { if (!parse_strict_bool(value, stretch)) {
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') { LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
parse_arg(raw.substr(start, pos - start)); }
start = pos + 1; } else {
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
} }
} }
} }
@ -1276,7 +1218,7 @@ static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
return x; return x;
} }
using SamplerExtraArgs = std::vector<std::pair<std::string, std::string>>; using SamplerExtraArgs = KeyValueArgs;
static sd::Tensor<float> sample_lcm(denoise_cb_t model, static sd::Tensor<float> sample_lcm(denoise_cb_t model,
sd::Tensor<float> x, sd::Tensor<float> x,
@ -1296,15 +1238,8 @@ static sd::Tensor<float> sample_lcm(denoise_cb_t model,
for (const auto& [key, value] : extra_sample_args) { for (const auto& [key, value] : extra_sample_args) {
float parsed = 0.0f; float parsed = 0.0f;
try { if (!parse_strict_float(value, parsed)) {
size_t consumed = 0; LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str(), value.c_str());
parsed = std::stof(value, &consumed);
if (trim(value.substr(consumed)).size() != 0) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", key.c_str());
continue;
}
} catch (const std::exception&) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str());
continue; continue;
} }
if (key == "noise_clip_std") { if (key == "noise_clip_std") {
@ -1861,15 +1796,8 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
for (const auto& [key, value] : extra_sample_args) { for (const auto& [key, value] : extra_sample_args) {
float parsed = 0.0f; float parsed = 0.0f;
try { if (!parse_strict_float(value, parsed)) {
size_t consumed = 0; LOG_WARN("ignoring invalid euler_ge extra sample arg '%s=%s'", key.c_str(), value.c_str());
parsed = std::stof(value, &consumed);
if (trim(value.substr(consumed)).size() != 0) {
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
continue;
}
} catch (const std::exception&) {
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
continue; continue;
} }
if (key == "gamma") { if (key == "gamma") {
@ -1916,46 +1844,6 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
return x; return x;
} }
static SamplerExtraArgs parse_sampler_args(const char* extra_sample_args) {
SamplerExtraArgs pairs;
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
return pairs;
}
auto trim = [](std::string value) -> std::string {
const char* whitespace = " \t\r\n";
size_t begin = value.find_first_not_of(whitespace);
if (begin == std::string::npos) {
return "";
}
size_t end = value.find_last_not_of(whitespace);
return value.substr(begin, end - begin + 1);
};
std::string raw(extra_sample_args);
size_t start = 0;
for (size_t pos = 0; pos <= raw.size(); ++pos) {
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
std::string item = raw.substr(start, pos - start);
std::string token = trim(item);
if (!token.empty()) {
size_t eq = token.find('=');
if (eq != std::string::npos) {
std::string key = trim(token.substr(0, eq));
std::string value = trim(token.substr(eq + 1));
pairs.emplace_back(std::move(key), std::move(value));
}
}
start = pos + 1;
}
}
return pairs;
}
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t // k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
static sd::Tensor<float> sample_k_diffusion(sample_method_t method, static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
denoise_cb_t model, denoise_cb_t model,
@ -1965,7 +1853,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float eta, float eta,
bool is_flow_denoiser, bool is_flow_denoiser,
const char* extra_sample_args) { const char* extra_sample_args) {
SamplerExtraArgs extra_args = parse_sampler_args(extra_sample_args); SamplerExtraArgs extra_args = parse_key_value_args(extra_sample_args, "extra sample arg");
switch (method) { switch (method) {
case EULER_A_SAMPLE_METHOD: case EULER_A_SAMPLE_METHOD:
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta); return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);

View File

@ -1251,65 +1251,22 @@ struct LTXVideoVAE : public VAE {
temporal_tiling_enabled = enabled; temporal_tiling_enabled = enabled;
} }
static std::string trim_tiling_arg(std::string value) {
const char* whitespace = " \t\r\n";
size_t begin = value.find_first_not_of(whitespace);
if (begin == std::string::npos) {
return "";
}
size_t end = value.find_last_not_of(whitespace);
return value.substr(begin, end - begin + 1);
}
static bool parse_tiling_int(const std::string& value, int& parsed) {
try {
size_t consumed = 0;
parsed = std::stoi(value, &consumed);
return trim_tiling_arg(value.substr(consumed)).empty();
} catch (...) {
return false;
}
}
void set_tiling_params(const sd_tiling_params_t& params) override { void set_tiling_params(const sd_tiling_params_t& params) override {
temporal_tiling_enabled = params.temporal_tiling; temporal_tiling_enabled = params.temporal_tiling;
temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES; temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP; temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
const char* extra_tiling_args = params.extra_tiling_args; for (const auto& [key, value] : parse_key_value_args(params.extra_tiling_args, "LTX VAE extra tiling arg")) {
if (extra_tiling_args == nullptr || extra_tiling_args[0] == '\0') { int parsed = 0;
return; if (!parse_strict_int(value, parsed)) {
} LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
} else if (key == "temporal_tile_frames") {
std::string raw(extra_tiling_args); temporal_tile_frames = std::max(1, parsed);
size_t start = 0; } else if (key == "temporal_tile_overlap") {
for (size_t pos = 0; pos <= raw.size(); ++pos) { temporal_tile_overlap = std::max(0, parsed);
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') { } else {
continue; LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
} }
std::string token = trim_tiling_arg(raw.substr(start, pos - start));
if (!token.empty()) {
size_t eq = token.find('=');
if (eq == std::string::npos) {
LOG_WARN("ignoring malformed LTX VAE extra tiling arg '%s'", token.c_str());
} else {
std::string key = trim_tiling_arg(token.substr(0, eq));
std::string value = trim_tiling_arg(token.substr(eq + 1));
int parsed = 0;
if (!parse_tiling_int(value, parsed)) {
LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
} else if (key == "temporal_tile_frames") {
temporal_tile_frames = std::max(1, parsed);
} else if (key == "temporal_tile_overlap") {
temporal_tile_overlap = std::max(0, parsed);
} else {
LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
}
}
}
start = pos + 1;
} }
} }

View File

@ -1,8 +1,10 @@
#include "util.h" #include "util.h"
#include <algorithm> #include <algorithm>
#include <cctype>
#include <cmath> #include <cmath>
#include <codecvt> #include <codecvt>
#include <cstdarg> #include <cstdarg>
#include <exception>
#include <fstream> #include <fstream>
#include <locale> #include <locale>
#include <regex> #include <regex>
@ -406,6 +408,88 @@ std::vector<std::string> split_string(const std::string& str, char delimiter) {
return result; return result;
} }
KeyValueArgs parse_key_value_args(const char* args, const char* context) {
KeyValueArgs pairs;
if (args == nullptr || args[0] == '\0') {
return pairs;
}
std::string raw(args);
size_t start = 0;
for (size_t pos = 0; pos <= raw.size(); ++pos) {
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') {
continue;
}
std::string token = trim(raw.substr(start, pos - start));
if (!token.empty()) {
size_t eq = token.find('=');
if (eq == std::string::npos) {
const char* log_context = context ? context : "key=value arg";
LOG_WARN("ignoring malformed %s '%s'", log_context, token.c_str());
} else {
std::string key = trim(token.substr(0, eq));
std::string value = trim(token.substr(eq + 1));
pairs.emplace_back(std::move(key), std::move(value));
}
}
start = pos + 1;
}
return pairs;
}
KeyValueArgs parse_key_value_args(const std::string& args, const char* context) {
return parse_key_value_args(args.c_str(), context);
}
bool parse_strict_float(const std::string& text, float& value) {
try {
size_t consumed = 0;
float parsed = std::stof(text, &consumed);
if (!trim(text.substr(consumed)).empty()) {
return false;
}
value = parsed;
return true;
} catch (const std::exception&) {
return false;
}
}
bool parse_strict_int(const std::string& text, int& value) {
try {
size_t consumed = 0;
int parsed = std::stoi(text, &consumed);
if (!trim(text.substr(consumed)).empty()) {
return false;
}
value = parsed;
return true;
} catch (const std::exception&) {
return false;
}
}
bool parse_strict_bool(const std::string& text, bool& value) {
std::string lowered = trim(text);
std::transform(lowered.begin(), lowered.end(), lowered.begin(), [](unsigned char c) {
return static_cast<char>(std::tolower(c));
});
if (lowered == "1" || lowered == "true" || lowered == "yes" || lowered == "on") {
value = true;
return true;
}
if (lowered == "0" || lowered == "false" || lowered == "no" || lowered == "off") {
value = false;
return true;
}
return false;
}
static std::string build_progress_bar(int step, int steps) { static std::string build_progress_bar(int step, int steps) {
std::string progress = " |"; std::string progress = " |";
int max_progress = 50; int max_progress = 50;

View File

@ -4,6 +4,7 @@
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>
#include "ggml-backend.h" #include "ggml-backend.h"
@ -65,6 +66,15 @@ protected:
std::string path_join(const std::string& p1, const std::string& p2); std::string path_join(const std::string& p1, const std::string& p2);
std::vector<std::string> split_string(const std::string& str, char delimiter); std::vector<std::string> split_string(const std::string& str, char delimiter);
using KeyValueArgs = std::vector<std::pair<std::string, std::string>>;
KeyValueArgs parse_key_value_args(const char* args, const char* context = "key=value arg");
KeyValueArgs parse_key_value_args(const std::string& args, const char* context = "key=value arg");
bool parse_strict_float(const std::string& text, float& value);
bool parse_strict_int(const std::string& text, int& value);
bool parse_strict_bool(const std::string& text, bool& value);
void pretty_progress(int step, int steps, float time); void pretty_progress(int step, int steps, float time);
void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float elapsed_seconds); void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float elapsed_seconds);