feat: add Gradient Estimation sampler (#1484)

This commit is contained in:
Wagner Bruna 2026-05-17 11:54:28 -03:00 committed by GitHub
parent 50134e51dd
commit e7eb92fd84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 147 additions and 64 deletions

View File

@ -105,7 +105,7 @@ Generation Options:
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
stretch, terminal
stretch, terminal; euler_ge supports gamma
-H, --height <int> image height, in pixel space (default: 512)
-W, --width <int> image width, in pixel space (default: 512)
--steps <int> number of sample steps (default: 20)

View File

@ -833,7 +833,7 @@ ArgOptions SDGenerationParams::get_options() {
&hires_upscaler},
{"",
"--extra-sample-args",
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal",
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
&extra_sample_args},
};

View File

@ -207,7 +207,7 @@ Default Generation Options:
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
stretch, terminal
stretch, terminal; euler_ge supports gamma
-H, --height <int> image height, in pixel space (default: 512)
-W, --width <int> image width, in pixel space (default: 512)
--steps <int> number of sample steps (default: 20)

View File

@ -53,6 +53,7 @@ enum sample_method_t {
ER_SDE_SAMPLE_METHOD,
EULER_CFG_PP_SAMPLE_METHOD,
EULER_A_CFG_PP_SAMPLE_METHOD,
EULER_GE_SAMPLE_METHOD,
SAMPLE_METHOD_COUNT
};

View File

@ -1276,60 +1276,37 @@ static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
return x;
}
using SamplerExtraArgs = std::vector<std::pair<std::string, std::string>>;
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
bool is_flow_denoiser,
const char* extra_sample_args = nullptr) {
const SamplerExtraArgs& extra_sample_args) {
struct LCMSampleArgs {
float noise_clip_std = 0.0f;
float noise_scale_start = 1.0f;
float noise_scale_end = 1.0f;
};
auto trim = [](std::string value) -> std::string {
const char* whitespace = " \t\r\n";
size_t begin = value.find_first_not_of(whitespace);
if (begin == std::string::npos) {
return "";
}
size_t end = value.find_last_not_of(whitespace);
return value.substr(begin, end - begin + 1);
};
LCMSampleArgs args;
if (extra_sample_args != nullptr && extra_sample_args[0] != '\0') {
std::string raw(extra_sample_args);
size_t start = 0;
bool noise_scale_end_was_set = false;
bool noise_scale_start_was_set = false;
auto parse_arg = [&](const std::string& item) {
std::string token = trim(item);
if (token.empty()) {
return;
}
size_t eq = token.find('=');
if (eq == std::string::npos) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", token.c_str());
return;
}
std::string key = trim(token.substr(0, eq));
std::string value = trim(token.substr(eq + 1));
for (const auto& [key, value] : extra_sample_args) {
float parsed = 0.0f;
try {
size_t consumed = 0;
parsed = std::stof(value, &consumed);
if (trim(value.substr(consumed)).size() != 0) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", token.c_str());
return;
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", key.c_str());
continue;
}
} catch (const std::exception&) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", token.c_str());
return;
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str());
continue;
}
if (key == "noise_clip_std") {
args.noise_clip_std = parsed;
} else if (key == "noise_scale_start") {
@ -1341,18 +1318,11 @@ static sd::Tensor<float> sample_lcm(denoise_cb_t model,
} else {
LOG_WARN("ignoring unknown lcm extra sample arg '%s'", key.c_str());
}
};
}
for (size_t pos = 0; pos <= raw.size(); ++pos) {
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
parse_arg(raw.substr(start, pos - start));
start = pos + 1;
}
}
if (noise_scale_start_was_set && !noise_scale_end_was_set) {
args.noise_scale_end = args.noise_scale_start;
}
}
int steps = static_cast<int>(sigmas.size()) - 1;
for (int i = 0; i < steps; i++) {
@ -1879,6 +1849,113 @@ static sd::Tensor<float> sample_euler_ancestral_cfg_pp(denoise_cb_t model,
return x;
}
// https://github.com/ToyotaResearchInstitute/gradient-estimation-sampler
static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
sd::Tensor<float> x,
const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng,
bool is_flow_denoiser,
float eta,
const SamplerExtraArgs& extra_sample_args) {
float ge_gamma = 2.0f;
for (const auto& [key, value] : extra_sample_args) {
float parsed = 0.0f;
try {
size_t consumed = 0;
parsed = std::stof(value, &consumed);
if (trim(value.substr(consumed)).size() != 0) {
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
continue;
}
} catch (const std::exception&) {
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
continue;
}
if (key == "gamma") {
LOG_DEBUG("setting euler_ge gamma to %.2f", parsed);
ge_gamma = parsed;
} else {
LOG_WARN("ignoring unknown euler_ge extra sample arg '%s'", key.c_str());
}
}
int steps = static_cast<int>(sigmas.size()) - 1;
sd::Tensor<float> old_d;
bool has_old_d = false;
for (int i = 0; i < steps; i++) {
float sigma = sigmas[i];
float sigma_to = sigmas[i + 1];
auto denoised_opt = model(x, sigma, i + 1);
if (denoised_opt.pred.empty()) {
return {};
}
sd::Tensor<float> denoised = std::move(denoised_opt.pred);
if (sigma_to == 0.f) {
x = denoised;
} else {
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma, sigma_to, eta, is_flow_denoiser);
sd::Tensor<float> d = (x - denoised) / sigma;
float dt = sigma_down - sigma;
if (has_old_d) {
sd::Tensor<float> d_bar = d * ge_gamma + old_d * (1.0f - ge_gamma);
x += d_bar * dt;
} else {
x += d * dt;
}
old_d = std::move(d);
has_old_d = true;
if (sigma_up > 0.f) {
if (is_flow_denoiser) {
x *= alpha_scale;
}
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
}
}
}
return x;
}
static SamplerExtraArgs parse_sampler_args(const char* extra_sample_args) {
SamplerExtraArgs pairs;
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
return pairs;
}
auto trim = [](std::string value) -> std::string {
const char* whitespace = " \t\r\n";
size_t begin = value.find_first_not_of(whitespace);
if (begin == std::string::npos) {
return "";
}
size_t end = value.find_last_not_of(whitespace);
return value.substr(begin, end - begin + 1);
};
std::string raw(extra_sample_args);
size_t start = 0;
for (size_t pos = 0; pos <= raw.size(); ++pos) {
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
std::string item = raw.substr(start, pos - start);
std::string token = trim(item);
if (!token.empty()) {
size_t eq = token.find('=');
if (eq != std::string::npos) {
std::string key = trim(token.substr(0, eq));
std::string value = trim(token.substr(eq + 1));
pairs.emplace_back(std::move(key), std::move(value));
}
}
start = pos + 1;
}
}
return pairs;
}
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
denoise_cb_t model,
@ -1888,6 +1965,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
float eta,
bool is_flow_denoiser,
const char* extra_sample_args) {
SamplerExtraArgs extra_args = parse_sampler_args(extra_sample_args);
switch (method) {
case EULER_A_SAMPLE_METHOD:
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
@ -1907,7 +1985,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
case DPMPP2Mv2_SAMPLE_METHOD:
return sample_dpmpp_2m_v2(model, std::move(x), sigmas);
case LCM_SAMPLE_METHOD:
return sample_lcm(model, std::move(x), sigmas, rng, is_flow_denoiser, extra_sample_args);
return sample_lcm(model, std::move(x), sigmas, rng, is_flow_denoiser, extra_args);
case IPNDM_SAMPLE_METHOD:
return sample_ipndm(model, std::move(x), sigmas);
case IPNDM_V_SAMPLE_METHOD:
@ -1927,6 +2005,8 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
return sample_euler_cfg_pp(model, std::move(x), sigmas);
case EULER_A_CFG_PP_SAMPLE_METHOD:
return sample_euler_ancestral_cfg_pp(model, std::move(x), sigmas, rng, eta);
case EULER_GE_SAMPLE_METHOD:
return sample_gradient_estimation(model, std::move(x), sigmas, rng, is_flow_denoiser, eta, extra_args);
default:
return {};
}

View File

@ -81,6 +81,7 @@ const char* sampling_methods_str[] = {
"ER-SDE",
"Euler CFG++",
"Euler A CFG++",
"Euler GE",
};
/*================================================== Helper Functions ================================================*/
@ -2282,6 +2283,7 @@ const char* sample_method_to_str[] = {
"er_sde",
"euler_cfg_pp",
"euler_a_cfg_pp",
"euler_ge",
};
const char* sd_sample_method_name(enum sample_method_t sample_method) {