mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
refactor: unify extra argument parsing (#1540)
This commit is contained in:
parent
449165caf5
commit
3a8788cb7d
158
src/denoiser.hpp
158
src/denoiser.hpp
@ -496,84 +496,26 @@ struct LTX2Scheduler : SigmaScheduler {
|
|||||||
parse_extra_sample_args(extra_sample_args);
|
parse_extra_sample_args(extra_sample_args);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string trim(std::string value) {
|
|
||||||
const char* whitespace = " \t\r\n";
|
|
||||||
size_t begin = value.find_first_not_of(whitespace);
|
|
||||||
if (begin == std::string::npos) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
size_t end = value.find_last_not_of(whitespace);
|
|
||||||
return value.substr(begin, end - begin + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
void parse_extra_sample_args(const char* extra_sample_args) {
|
void parse_extra_sample_args(const char* extra_sample_args) {
|
||||||
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
|
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "ltx2 scheduler arg")) {
|
||||||
return;
|
if (key == "max_shift") {
|
||||||
}
|
if (!parse_strict_float(value, max_shift)) {
|
||||||
|
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
std::string raw(extra_sample_args);
|
|
||||||
size_t start = 0;
|
|
||||||
auto parse_arg = [&](const std::string& item) {
|
|
||||||
std::string token = trim(item);
|
|
||||||
if (token.empty()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
size_t eq = token.find('=');
|
|
||||||
if (eq == std::string::npos) {
|
|
||||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string key = trim(token.substr(0, eq));
|
|
||||||
std::string value = trim(token.substr(eq + 1));
|
|
||||||
auto parse_float = [&](float* out) -> bool {
|
|
||||||
try {
|
|
||||||
size_t consumed = 0;
|
|
||||||
float parsed = std::stof(value, &consumed);
|
|
||||||
if (!trim(value.substr(consumed)).empty()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
*out = parsed;
|
|
||||||
return true;
|
|
||||||
} catch (const std::exception&) {
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
};
|
} else if (key == "base_shift") {
|
||||||
try {
|
if (!parse_strict_float(value, base_shift)) {
|
||||||
if (key == "max_shift") {
|
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
if (!parse_float(&max_shift)) {
|
|
||||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
|
||||||
}
|
|
||||||
} else if (key == "base_shift") {
|
|
||||||
if (!parse_float(&base_shift)) {
|
|
||||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
|
||||||
}
|
|
||||||
} else if (key == "terminal") {
|
|
||||||
if (!parse_float(&terminal)) {
|
|
||||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
|
||||||
}
|
|
||||||
} else if (key == "stretch") {
|
|
||||||
std::string v = value;
|
|
||||||
std::transform(v.begin(), v.end(), v.begin(), [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
|
|
||||||
if (v == "1" || v == "true" || v == "yes" || v == "on") {
|
|
||||||
stretch = true;
|
|
||||||
} else if (v == "0" || v == "false" || v == "no" || v == "off") {
|
|
||||||
stretch = false;
|
|
||||||
} else {
|
|
||||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
|
|
||||||
}
|
}
|
||||||
} catch (const std::exception&) {
|
} else if (key == "terminal") {
|
||||||
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s'", token.c_str());
|
if (!parse_strict_float(value, terminal)) {
|
||||||
}
|
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
};
|
}
|
||||||
|
} else if (key == "stretch") {
|
||||||
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
if (!parse_strict_bool(value, stretch)) {
|
||||||
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
|
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
parse_arg(raw.substr(start, pos - start));
|
}
|
||||||
start = pos + 1;
|
} else {
|
||||||
|
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1276,7 +1218,7 @@ static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
using SamplerExtraArgs = std::vector<std::pair<std::string, std::string>>;
|
using SamplerExtraArgs = KeyValueArgs;
|
||||||
|
|
||||||
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
|
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
|
||||||
sd::Tensor<float> x,
|
sd::Tensor<float> x,
|
||||||
@ -1296,15 +1238,8 @@ static sd::Tensor<float> sample_lcm(denoise_cb_t model,
|
|||||||
|
|
||||||
for (const auto& [key, value] : extra_sample_args) {
|
for (const auto& [key, value] : extra_sample_args) {
|
||||||
float parsed = 0.0f;
|
float parsed = 0.0f;
|
||||||
try {
|
if (!parse_strict_float(value, parsed)) {
|
||||||
size_t consumed = 0;
|
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
parsed = std::stof(value, &consumed);
|
|
||||||
if (trim(value.substr(consumed)).size() != 0) {
|
|
||||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", key.c_str());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
} catch (const std::exception&) {
|
|
||||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str());
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (key == "noise_clip_std") {
|
if (key == "noise_clip_std") {
|
||||||
@ -1861,15 +1796,8 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
|
|||||||
|
|
||||||
for (const auto& [key, value] : extra_sample_args) {
|
for (const auto& [key, value] : extra_sample_args) {
|
||||||
float parsed = 0.0f;
|
float parsed = 0.0f;
|
||||||
try {
|
if (!parse_strict_float(value, parsed)) {
|
||||||
size_t consumed = 0;
|
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
parsed = std::stof(value, &consumed);
|
|
||||||
if (trim(value.substr(consumed)).size() != 0) {
|
|
||||||
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
} catch (const std::exception&) {
|
|
||||||
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (key == "gamma") {
|
if (key == "gamma") {
|
||||||
@ -1916,46 +1844,6 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
static SamplerExtraArgs parse_sampler_args(const char* extra_sample_args) {
|
|
||||||
SamplerExtraArgs pairs;
|
|
||||||
|
|
||||||
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
|
|
||||||
return pairs;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto trim = [](std::string value) -> std::string {
|
|
||||||
const char* whitespace = " \t\r\n";
|
|
||||||
size_t begin = value.find_first_not_of(whitespace);
|
|
||||||
if (begin == std::string::npos) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
size_t end = value.find_last_not_of(whitespace);
|
|
||||||
return value.substr(begin, end - begin + 1);
|
|
||||||
};
|
|
||||||
|
|
||||||
std::string raw(extra_sample_args);
|
|
||||||
size_t start = 0;
|
|
||||||
|
|
||||||
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
|
||||||
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
|
|
||||||
std::string item = raw.substr(start, pos - start);
|
|
||||||
std::string token = trim(item);
|
|
||||||
|
|
||||||
if (!token.empty()) {
|
|
||||||
size_t eq = token.find('=');
|
|
||||||
if (eq != std::string::npos) {
|
|
||||||
std::string key = trim(token.substr(0, eq));
|
|
||||||
std::string value = trim(token.substr(eq + 1));
|
|
||||||
pairs.emplace_back(std::move(key), std::move(value));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
start = pos + 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return pairs;
|
|
||||||
}
|
|
||||||
|
|
||||||
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
|
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
|
||||||
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
||||||
denoise_cb_t model,
|
denoise_cb_t model,
|
||||||
@ -1965,7 +1853,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
|||||||
float eta,
|
float eta,
|
||||||
bool is_flow_denoiser,
|
bool is_flow_denoiser,
|
||||||
const char* extra_sample_args) {
|
const char* extra_sample_args) {
|
||||||
SamplerExtraArgs extra_args = parse_sampler_args(extra_sample_args);
|
SamplerExtraArgs extra_args = parse_key_value_args(extra_sample_args, "extra sample arg");
|
||||||
switch (method) {
|
switch (method) {
|
||||||
case EULER_A_SAMPLE_METHOD:
|
case EULER_A_SAMPLE_METHOD:
|
||||||
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
|
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
|
||||||
|
|||||||
@ -1251,65 +1251,22 @@ struct LTXVideoVAE : public VAE {
|
|||||||
temporal_tiling_enabled = enabled;
|
temporal_tiling_enabled = enabled;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string trim_tiling_arg(std::string value) {
|
|
||||||
const char* whitespace = " \t\r\n";
|
|
||||||
size_t begin = value.find_first_not_of(whitespace);
|
|
||||||
if (begin == std::string::npos) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
size_t end = value.find_last_not_of(whitespace);
|
|
||||||
return value.substr(begin, end - begin + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool parse_tiling_int(const std::string& value, int& parsed) {
|
|
||||||
try {
|
|
||||||
size_t consumed = 0;
|
|
||||||
parsed = std::stoi(value, &consumed);
|
|
||||||
return trim_tiling_arg(value.substr(consumed)).empty();
|
|
||||||
} catch (...) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void set_tiling_params(const sd_tiling_params_t& params) override {
|
void set_tiling_params(const sd_tiling_params_t& params) override {
|
||||||
temporal_tiling_enabled = params.temporal_tiling;
|
temporal_tiling_enabled = params.temporal_tiling;
|
||||||
temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
|
temporal_tile_frames = DEFAULT_TEMPORAL_TILE_FRAMES;
|
||||||
temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
|
temporal_tile_overlap = DEFAULT_TEMPORAL_TILE_OVERLAP;
|
||||||
|
|
||||||
const char* extra_tiling_args = params.extra_tiling_args;
|
for (const auto& [key, value] : parse_key_value_args(params.extra_tiling_args, "LTX VAE extra tiling arg")) {
|
||||||
if (extra_tiling_args == nullptr || extra_tiling_args[0] == '\0') {
|
int parsed = 0;
|
||||||
return;
|
if (!parse_strict_int(value, parsed)) {
|
||||||
}
|
LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
|
||||||
|
} else if (key == "temporal_tile_frames") {
|
||||||
std::string raw(extra_tiling_args);
|
temporal_tile_frames = std::max(1, parsed);
|
||||||
size_t start = 0;
|
} else if (key == "temporal_tile_overlap") {
|
||||||
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
temporal_tile_overlap = std::max(0, parsed);
|
||||||
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') {
|
} else {
|
||||||
continue;
|
LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string token = trim_tiling_arg(raw.substr(start, pos - start));
|
|
||||||
if (!token.empty()) {
|
|
||||||
size_t eq = token.find('=');
|
|
||||||
if (eq == std::string::npos) {
|
|
||||||
LOG_WARN("ignoring malformed LTX VAE extra tiling arg '%s'", token.c_str());
|
|
||||||
} else {
|
|
||||||
std::string key = trim_tiling_arg(token.substr(0, eq));
|
|
||||||
std::string value = trim_tiling_arg(token.substr(eq + 1));
|
|
||||||
int parsed = 0;
|
|
||||||
if (!parse_tiling_int(value, parsed)) {
|
|
||||||
LOG_WARN("ignoring invalid LTX VAE extra tiling arg '%s=%s'", key.c_str(), value.c_str());
|
|
||||||
} else if (key == "temporal_tile_frames") {
|
|
||||||
temporal_tile_frames = std::max(1, parsed);
|
|
||||||
} else if (key == "temporal_tile_overlap") {
|
|
||||||
temporal_tile_overlap = std::max(0, parsed);
|
|
||||||
} else {
|
|
||||||
LOG_WARN("ignoring unknown LTX VAE extra tiling arg '%s'", key.c_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
start = pos + 1;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
84
src/util.cpp
84
src/util.cpp
@ -1,8 +1,10 @@
|
|||||||
#include "util.h"
|
#include "util.h"
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cctype>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <codecvt>
|
#include <codecvt>
|
||||||
#include <cstdarg>
|
#include <cstdarg>
|
||||||
|
#include <exception>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <locale>
|
#include <locale>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
@ -406,6 +408,88 @@ std::vector<std::string> split_string(const std::string& str, char delimiter) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
KeyValueArgs parse_key_value_args(const char* args, const char* context) {
|
||||||
|
KeyValueArgs pairs;
|
||||||
|
|
||||||
|
if (args == nullptr || args[0] == '\0') {
|
||||||
|
return pairs;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string raw(args);
|
||||||
|
size_t start = 0;
|
||||||
|
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
||||||
|
if (pos != raw.size() && raw[pos] != ',' && raw[pos] != ';') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string token = trim(raw.substr(start, pos - start));
|
||||||
|
if (!token.empty()) {
|
||||||
|
size_t eq = token.find('=');
|
||||||
|
if (eq == std::string::npos) {
|
||||||
|
const char* log_context = context ? context : "key=value arg";
|
||||||
|
LOG_WARN("ignoring malformed %s '%s'", log_context, token.c_str());
|
||||||
|
} else {
|
||||||
|
std::string key = trim(token.substr(0, eq));
|
||||||
|
std::string value = trim(token.substr(eq + 1));
|
||||||
|
pairs.emplace_back(std::move(key), std::move(value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
start = pos + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return pairs;
|
||||||
|
}
|
||||||
|
|
||||||
|
KeyValueArgs parse_key_value_args(const std::string& args, const char* context) {
|
||||||
|
return parse_key_value_args(args.c_str(), context);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool parse_strict_float(const std::string& text, float& value) {
|
||||||
|
try {
|
||||||
|
size_t consumed = 0;
|
||||||
|
float parsed = std::stof(text, &consumed);
|
||||||
|
if (!trim(text.substr(consumed)).empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
value = parsed;
|
||||||
|
return true;
|
||||||
|
} catch (const std::exception&) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool parse_strict_int(const std::string& text, int& value) {
|
||||||
|
try {
|
||||||
|
size_t consumed = 0;
|
||||||
|
int parsed = std::stoi(text, &consumed);
|
||||||
|
if (!trim(text.substr(consumed)).empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
value = parsed;
|
||||||
|
return true;
|
||||||
|
} catch (const std::exception&) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool parse_strict_bool(const std::string& text, bool& value) {
|
||||||
|
std::string lowered = trim(text);
|
||||||
|
std::transform(lowered.begin(), lowered.end(), lowered.begin(), [](unsigned char c) {
|
||||||
|
return static_cast<char>(std::tolower(c));
|
||||||
|
});
|
||||||
|
|
||||||
|
if (lowered == "1" || lowered == "true" || lowered == "yes" || lowered == "on") {
|
||||||
|
value = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (lowered == "0" || lowered == "false" || lowered == "no" || lowered == "off") {
|
||||||
|
value = false;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static std::string build_progress_bar(int step, int steps) {
|
static std::string build_progress_bar(int step, int steps) {
|
||||||
std::string progress = " |";
|
std::string progress = " |";
|
||||||
int max_progress = 50;
|
int max_progress = 50;
|
||||||
|
|||||||
10
src/util.h
10
src/util.h
@ -4,6 +4,7 @@
|
|||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "ggml-backend.h"
|
#include "ggml-backend.h"
|
||||||
@ -65,6 +66,15 @@ protected:
|
|||||||
|
|
||||||
std::string path_join(const std::string& p1, const std::string& p2);
|
std::string path_join(const std::string& p1, const std::string& p2);
|
||||||
std::vector<std::string> split_string(const std::string& str, char delimiter);
|
std::vector<std::string> split_string(const std::string& str, char delimiter);
|
||||||
|
|
||||||
|
using KeyValueArgs = std::vector<std::pair<std::string, std::string>>;
|
||||||
|
|
||||||
|
KeyValueArgs parse_key_value_args(const char* args, const char* context = "key=value arg");
|
||||||
|
KeyValueArgs parse_key_value_args(const std::string& args, const char* context = "key=value arg");
|
||||||
|
bool parse_strict_float(const std::string& text, float& value);
|
||||||
|
bool parse_strict_int(const std::string& text, int& value);
|
||||||
|
bool parse_strict_bool(const std::string& text, bool& value);
|
||||||
|
|
||||||
void pretty_progress(int step, int steps, float time);
|
void pretty_progress(int step, int steps, float time);
|
||||||
void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float elapsed_seconds);
|
void pretty_bytes_progress(int step, int steps, uint64_t bytes_processed, float elapsed_seconds);
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user