mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
feat: add Gradient Estimation sampler (#1484)
This commit is contained in:
parent
50134e51dd
commit
e7eb92fd84
@ -105,7 +105,7 @@ Generation Options:
|
||||
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
|
||||
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
||||
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
||||
stretch, terminal
|
||||
stretch, terminal; euler_ge supports gamma
|
||||
-H, --height <int> image height, in pixel space (default: 512)
|
||||
-W, --width <int> image width, in pixel space (default: 512)
|
||||
--steps <int> number of sample steps (default: 20)
|
||||
|
||||
@ -833,7 +833,7 @@ ArgOptions SDGenerationParams::get_options() {
|
||||
&hires_upscaler},
|
||||
{"",
|
||||
"--extra-sample-args",
|
||||
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal",
|
||||
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
|
||||
&extra_sample_args},
|
||||
};
|
||||
|
||||
|
||||
@ -207,7 +207,7 @@ Default Generation Options:
|
||||
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
|
||||
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
||||
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
||||
stretch, terminal
|
||||
stretch, terminal; euler_ge supports gamma
|
||||
-H, --height <int> image height, in pixel space (default: 512)
|
||||
-W, --width <int> image width, in pixel space (default: 512)
|
||||
--steps <int> number of sample steps (default: 20)
|
||||
|
||||
@ -53,6 +53,7 @@ enum sample_method_t {
|
||||
ER_SDE_SAMPLE_METHOD,
|
||||
EULER_CFG_PP_SAMPLE_METHOD,
|
||||
EULER_A_CFG_PP_SAMPLE_METHOD,
|
||||
EULER_GE_SAMPLE_METHOD,
|
||||
SAMPLE_METHOD_COUNT
|
||||
};
|
||||
|
||||
|
||||
202
src/denoiser.hpp
202
src/denoiser.hpp
@ -1276,84 +1276,54 @@ static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
|
||||
return x;
|
||||
}
|
||||
|
||||
using SamplerExtraArgs = std::vector<std::pair<std::string, std::string>>;
|
||||
|
||||
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
|
||||
sd::Tensor<float> x,
|
||||
const std::vector<float>& sigmas,
|
||||
std::shared_ptr<RNG> rng,
|
||||
bool is_flow_denoiser,
|
||||
const char* extra_sample_args = nullptr) {
|
||||
const SamplerExtraArgs& extra_sample_args) {
|
||||
struct LCMSampleArgs {
|
||||
float noise_clip_std = 0.0f;
|
||||
float noise_scale_start = 1.0f;
|
||||
float noise_scale_end = 1.0f;
|
||||
};
|
||||
|
||||
auto trim = [](std::string value) -> std::string {
|
||||
const char* whitespace = " \t\r\n";
|
||||
size_t begin = value.find_first_not_of(whitespace);
|
||||
if (begin == std::string::npos) {
|
||||
return "";
|
||||
}
|
||||
size_t end = value.find_last_not_of(whitespace);
|
||||
return value.substr(begin, end - begin + 1);
|
||||
};
|
||||
|
||||
LCMSampleArgs args;
|
||||
if (extra_sample_args != nullptr && extra_sample_args[0] != '\0') {
|
||||
std::string raw(extra_sample_args);
|
||||
size_t start = 0;
|
||||
bool noise_scale_end_was_set = false;
|
||||
bool noise_scale_start_was_set = false;
|
||||
auto parse_arg = [&](const std::string& item) {
|
||||
std::string token = trim(item);
|
||||
if (token.empty()) {
|
||||
return;
|
||||
}
|
||||
size_t eq = token.find('=');
|
||||
if (eq == std::string::npos) {
|
||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", token.c_str());
|
||||
return;
|
||||
}
|
||||
bool noise_scale_end_was_set = false;
|
||||
bool noise_scale_start_was_set = false;
|
||||
|
||||
std::string key = trim(token.substr(0, eq));
|
||||
std::string value = trim(token.substr(eq + 1));
|
||||
float parsed = 0.0f;
|
||||
try {
|
||||
size_t consumed = 0;
|
||||
parsed = std::stof(value, &consumed);
|
||||
if (trim(value.substr(consumed)).size() != 0) {
|
||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", token.c_str());
|
||||
return;
|
||||
}
|
||||
} catch (const std::exception&) {
|
||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", token.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
if (key == "noise_clip_std") {
|
||||
args.noise_clip_std = parsed;
|
||||
} else if (key == "noise_scale_start") {
|
||||
args.noise_scale_start = parsed;
|
||||
noise_scale_start_was_set = true;
|
||||
} else if (key == "noise_scale_end") {
|
||||
args.noise_scale_end = parsed;
|
||||
noise_scale_end_was_set = true;
|
||||
} else {
|
||||
LOG_WARN("ignoring unknown lcm extra sample arg '%s'", key.c_str());
|
||||
}
|
||||
};
|
||||
|
||||
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
||||
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
|
||||
parse_arg(raw.substr(start, pos - start));
|
||||
start = pos + 1;
|
||||
for (const auto& [key, value] : extra_sample_args) {
|
||||
float parsed = 0.0f;
|
||||
try {
|
||||
size_t consumed = 0;
|
||||
parsed = std::stof(value, &consumed);
|
||||
if (trim(value.substr(consumed)).size() != 0) {
|
||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", key.c_str());
|
||||
continue;
|
||||
}
|
||||
} catch (const std::exception&) {
|
||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str());
|
||||
continue;
|
||||
}
|
||||
if (noise_scale_start_was_set && !noise_scale_end_was_set) {
|
||||
args.noise_scale_end = args.noise_scale_start;
|
||||
if (key == "noise_clip_std") {
|
||||
args.noise_clip_std = parsed;
|
||||
} else if (key == "noise_scale_start") {
|
||||
args.noise_scale_start = parsed;
|
||||
noise_scale_start_was_set = true;
|
||||
} else if (key == "noise_scale_end") {
|
||||
args.noise_scale_end = parsed;
|
||||
noise_scale_end_was_set = true;
|
||||
} else {
|
||||
LOG_WARN("ignoring unknown lcm extra sample arg '%s'", key.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (noise_scale_start_was_set && !noise_scale_end_was_set) {
|
||||
args.noise_scale_end = args.noise_scale_start;
|
||||
}
|
||||
|
||||
int steps = static_cast<int>(sigmas.size()) - 1;
|
||||
for (int i = 0; i < steps; i++) {
|
||||
auto denoised_opt = model(x, sigmas[i], i + 1);
|
||||
@ -1879,6 +1849,113 @@ static sd::Tensor<float> sample_euler_ancestral_cfg_pp(denoise_cb_t model,
|
||||
return x;
|
||||
}
|
||||
|
||||
// https://github.com/ToyotaResearchInstitute/gradient-estimation-sampler
|
||||
static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
|
||||
sd::Tensor<float> x,
|
||||
const std::vector<float>& sigmas,
|
||||
std::shared_ptr<RNG> rng,
|
||||
bool is_flow_denoiser,
|
||||
float eta,
|
||||
const SamplerExtraArgs& extra_sample_args) {
|
||||
float ge_gamma = 2.0f;
|
||||
|
||||
for (const auto& [key, value] : extra_sample_args) {
|
||||
float parsed = 0.0f;
|
||||
try {
|
||||
size_t consumed = 0;
|
||||
parsed = std::stof(value, &consumed);
|
||||
if (trim(value.substr(consumed)).size() != 0) {
|
||||
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
|
||||
continue;
|
||||
}
|
||||
} catch (const std::exception&) {
|
||||
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
|
||||
continue;
|
||||
}
|
||||
if (key == "gamma") {
|
||||
LOG_DEBUG("setting euler_ge gamma to %.2f", parsed);
|
||||
ge_gamma = parsed;
|
||||
} else {
|
||||
LOG_WARN("ignoring unknown euler_ge extra sample arg '%s'", key.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
int steps = static_cast<int>(sigmas.size()) - 1;
|
||||
sd::Tensor<float> old_d;
|
||||
bool has_old_d = false;
|
||||
for (int i = 0; i < steps; i++) {
|
||||
float sigma = sigmas[i];
|
||||
float sigma_to = sigmas[i + 1];
|
||||
auto denoised_opt = model(x, sigma, i + 1);
|
||||
if (denoised_opt.pred.empty()) {
|
||||
return {};
|
||||
}
|
||||
sd::Tensor<float> denoised = std::move(denoised_opt.pred);
|
||||
if (sigma_to == 0.f) {
|
||||
x = denoised;
|
||||
} else {
|
||||
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma, sigma_to, eta, is_flow_denoiser);
|
||||
sd::Tensor<float> d = (x - denoised) / sigma;
|
||||
float dt = sigma_down - sigma;
|
||||
if (has_old_d) {
|
||||
sd::Tensor<float> d_bar = d * ge_gamma + old_d * (1.0f - ge_gamma);
|
||||
x += d_bar * dt;
|
||||
} else {
|
||||
x += d * dt;
|
||||
}
|
||||
old_d = std::move(d);
|
||||
has_old_d = true;
|
||||
if (sigma_up > 0.f) {
|
||||
if (is_flow_denoiser) {
|
||||
x *= alpha_scale;
|
||||
}
|
||||
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
|
||||
}
|
||||
}
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
static SamplerExtraArgs parse_sampler_args(const char* extra_sample_args) {
|
||||
SamplerExtraArgs pairs;
|
||||
|
||||
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
|
||||
return pairs;
|
||||
}
|
||||
|
||||
auto trim = [](std::string value) -> std::string {
|
||||
const char* whitespace = " \t\r\n";
|
||||
size_t begin = value.find_first_not_of(whitespace);
|
||||
if (begin == std::string::npos) {
|
||||
return "";
|
||||
}
|
||||
size_t end = value.find_last_not_of(whitespace);
|
||||
return value.substr(begin, end - begin + 1);
|
||||
};
|
||||
|
||||
std::string raw(extra_sample_args);
|
||||
size_t start = 0;
|
||||
|
||||
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
||||
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
|
||||
std::string item = raw.substr(start, pos - start);
|
||||
std::string token = trim(item);
|
||||
|
||||
if (!token.empty()) {
|
||||
size_t eq = token.find('=');
|
||||
if (eq != std::string::npos) {
|
||||
std::string key = trim(token.substr(0, eq));
|
||||
std::string value = trim(token.substr(eq + 1));
|
||||
pairs.emplace_back(std::move(key), std::move(value));
|
||||
}
|
||||
}
|
||||
start = pos + 1;
|
||||
}
|
||||
}
|
||||
|
||||
return pairs;
|
||||
}
|
||||
|
||||
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
|
||||
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
||||
denoise_cb_t model,
|
||||
@ -1888,6 +1965,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
||||
float eta,
|
||||
bool is_flow_denoiser,
|
||||
const char* extra_sample_args) {
|
||||
SamplerExtraArgs extra_args = parse_sampler_args(extra_sample_args);
|
||||
switch (method) {
|
||||
case EULER_A_SAMPLE_METHOD:
|
||||
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
|
||||
@ -1907,7 +1985,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
||||
case DPMPP2Mv2_SAMPLE_METHOD:
|
||||
return sample_dpmpp_2m_v2(model, std::move(x), sigmas);
|
||||
case LCM_SAMPLE_METHOD:
|
||||
return sample_lcm(model, std::move(x), sigmas, rng, is_flow_denoiser, extra_sample_args);
|
||||
return sample_lcm(model, std::move(x), sigmas, rng, is_flow_denoiser, extra_args);
|
||||
case IPNDM_SAMPLE_METHOD:
|
||||
return sample_ipndm(model, std::move(x), sigmas);
|
||||
case IPNDM_V_SAMPLE_METHOD:
|
||||
@ -1927,6 +2005,8 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
||||
return sample_euler_cfg_pp(model, std::move(x), sigmas);
|
||||
case EULER_A_CFG_PP_SAMPLE_METHOD:
|
||||
return sample_euler_ancestral_cfg_pp(model, std::move(x), sigmas, rng, eta);
|
||||
case EULER_GE_SAMPLE_METHOD:
|
||||
return sample_gradient_estimation(model, std::move(x), sigmas, rng, is_flow_denoiser, eta, extra_args);
|
||||
default:
|
||||
return {};
|
||||
}
|
||||
|
||||
@ -81,6 +81,7 @@ const char* sampling_methods_str[] = {
|
||||
"ER-SDE",
|
||||
"Euler CFG++",
|
||||
"Euler A CFG++",
|
||||
"Euler GE",
|
||||
};
|
||||
|
||||
/*================================================== Helper Functions ================================================*/
|
||||
@ -2282,6 +2283,7 @@ const char* sample_method_to_str[] = {
|
||||
"er_sde",
|
||||
"euler_cfg_pp",
|
||||
"euler_a_cfg_pp",
|
||||
"euler_ge",
|
||||
};
|
||||
|
||||
const char* sd_sample_method_name(enum sample_method_t sample_method) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user