mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-09 15:56:39 +00:00
feat: add Gradient Estimation sampler (#1484)
This commit is contained in:
parent
50134e51dd
commit
e7eb92fd84
@ -105,7 +105,7 @@ Generation Options:
|
|||||||
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
|
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
|
||||||
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
||||||
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
||||||
stretch, terminal
|
stretch, terminal; euler_ge supports gamma
|
||||||
-H, --height <int> image height, in pixel space (default: 512)
|
-H, --height <int> image height, in pixel space (default: 512)
|
||||||
-W, --width <int> image width, in pixel space (default: 512)
|
-W, --width <int> image width, in pixel space (default: 512)
|
||||||
--steps <int> number of sample steps (default: 20)
|
--steps <int> number of sample steps (default: 20)
|
||||||
|
|||||||
@ -833,7 +833,7 @@ ArgOptions SDGenerationParams::get_options() {
|
|||||||
&hires_upscaler},
|
&hires_upscaler},
|
||||||
{"",
|
{"",
|
||||||
"--extra-sample-args",
|
"--extra-sample-args",
|
||||||
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal",
|
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
|
||||||
&extra_sample_args},
|
&extra_sample_args},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -207,7 +207,7 @@ Default Generation Options:
|
|||||||
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
|
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
|
||||||
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
|
||||||
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
|
||||||
stretch, terminal
|
stretch, terminal; euler_ge supports gamma
|
||||||
-H, --height <int> image height, in pixel space (default: 512)
|
-H, --height <int> image height, in pixel space (default: 512)
|
||||||
-W, --width <int> image width, in pixel space (default: 512)
|
-W, --width <int> image width, in pixel space (default: 512)
|
||||||
--steps <int> number of sample steps (default: 20)
|
--steps <int> number of sample steps (default: 20)
|
||||||
|
|||||||
@ -53,6 +53,7 @@ enum sample_method_t {
|
|||||||
ER_SDE_SAMPLE_METHOD,
|
ER_SDE_SAMPLE_METHOD,
|
||||||
EULER_CFG_PP_SAMPLE_METHOD,
|
EULER_CFG_PP_SAMPLE_METHOD,
|
||||||
EULER_A_CFG_PP_SAMPLE_METHOD,
|
EULER_A_CFG_PP_SAMPLE_METHOD,
|
||||||
|
EULER_GE_SAMPLE_METHOD,
|
||||||
SAMPLE_METHOD_COUNT
|
SAMPLE_METHOD_COUNT
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
160
src/denoiser.hpp
160
src/denoiser.hpp
@ -1276,60 +1276,37 @@ static sd::Tensor<float> sample_dpmpp_2m_v2(denoise_cb_t model,
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
using SamplerExtraArgs = std::vector<std::pair<std::string, std::string>>;
|
||||||
|
|
||||||
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
|
static sd::Tensor<float> sample_lcm(denoise_cb_t model,
|
||||||
sd::Tensor<float> x,
|
sd::Tensor<float> x,
|
||||||
const std::vector<float>& sigmas,
|
const std::vector<float>& sigmas,
|
||||||
std::shared_ptr<RNG> rng,
|
std::shared_ptr<RNG> rng,
|
||||||
bool is_flow_denoiser,
|
bool is_flow_denoiser,
|
||||||
const char* extra_sample_args = nullptr) {
|
const SamplerExtraArgs& extra_sample_args) {
|
||||||
struct LCMSampleArgs {
|
struct LCMSampleArgs {
|
||||||
float noise_clip_std = 0.0f;
|
float noise_clip_std = 0.0f;
|
||||||
float noise_scale_start = 1.0f;
|
float noise_scale_start = 1.0f;
|
||||||
float noise_scale_end = 1.0f;
|
float noise_scale_end = 1.0f;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto trim = [](std::string value) -> std::string {
|
|
||||||
const char* whitespace = " \t\r\n";
|
|
||||||
size_t begin = value.find_first_not_of(whitespace);
|
|
||||||
if (begin == std::string::npos) {
|
|
||||||
return "";
|
|
||||||
}
|
|
||||||
size_t end = value.find_last_not_of(whitespace);
|
|
||||||
return value.substr(begin, end - begin + 1);
|
|
||||||
};
|
|
||||||
|
|
||||||
LCMSampleArgs args;
|
LCMSampleArgs args;
|
||||||
if (extra_sample_args != nullptr && extra_sample_args[0] != '\0') {
|
|
||||||
std::string raw(extra_sample_args);
|
|
||||||
size_t start = 0;
|
|
||||||
bool noise_scale_end_was_set = false;
|
bool noise_scale_end_was_set = false;
|
||||||
bool noise_scale_start_was_set = false;
|
bool noise_scale_start_was_set = false;
|
||||||
auto parse_arg = [&](const std::string& item) {
|
|
||||||
std::string token = trim(item);
|
|
||||||
if (token.empty()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
size_t eq = token.find('=');
|
|
||||||
if (eq == std::string::npos) {
|
|
||||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", token.c_str());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string key = trim(token.substr(0, eq));
|
for (const auto& [key, value] : extra_sample_args) {
|
||||||
std::string value = trim(token.substr(eq + 1));
|
|
||||||
float parsed = 0.0f;
|
float parsed = 0.0f;
|
||||||
try {
|
try {
|
||||||
size_t consumed = 0;
|
size_t consumed = 0;
|
||||||
parsed = std::stof(value, &consumed);
|
parsed = std::stof(value, &consumed);
|
||||||
if (trim(value.substr(consumed)).size() != 0) {
|
if (trim(value.substr(consumed)).size() != 0) {
|
||||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", token.c_str());
|
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", key.c_str());
|
||||||
return;
|
continue;
|
||||||
}
|
}
|
||||||
} catch (const std::exception&) {
|
} catch (const std::exception&) {
|
||||||
LOG_WARN("ignoring invalid lcm extra sample arg '%s'", token.c_str());
|
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str());
|
||||||
return;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (key == "noise_clip_std") {
|
if (key == "noise_clip_std") {
|
||||||
args.noise_clip_std = parsed;
|
args.noise_clip_std = parsed;
|
||||||
} else if (key == "noise_scale_start") {
|
} else if (key == "noise_scale_start") {
|
||||||
@ -1341,18 +1318,11 @@ static sd::Tensor<float> sample_lcm(denoise_cb_t model,
|
|||||||
} else {
|
} else {
|
||||||
LOG_WARN("ignoring unknown lcm extra sample arg '%s'", key.c_str());
|
LOG_WARN("ignoring unknown lcm extra sample arg '%s'", key.c_str());
|
||||||
}
|
}
|
||||||
};
|
}
|
||||||
|
|
||||||
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
|
||||||
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
|
|
||||||
parse_arg(raw.substr(start, pos - start));
|
|
||||||
start = pos + 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (noise_scale_start_was_set && !noise_scale_end_was_set) {
|
if (noise_scale_start_was_set && !noise_scale_end_was_set) {
|
||||||
args.noise_scale_end = args.noise_scale_start;
|
args.noise_scale_end = args.noise_scale_start;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
int steps = static_cast<int>(sigmas.size()) - 1;
|
int steps = static_cast<int>(sigmas.size()) - 1;
|
||||||
for (int i = 0; i < steps; i++) {
|
for (int i = 0; i < steps; i++) {
|
||||||
@ -1879,6 +1849,113 @@ static sd::Tensor<float> sample_euler_ancestral_cfg_pp(denoise_cb_t model,
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/ToyotaResearchInstitute/gradient-estimation-sampler
|
||||||
|
static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
|
||||||
|
sd::Tensor<float> x,
|
||||||
|
const std::vector<float>& sigmas,
|
||||||
|
std::shared_ptr<RNG> rng,
|
||||||
|
bool is_flow_denoiser,
|
||||||
|
float eta,
|
||||||
|
const SamplerExtraArgs& extra_sample_args) {
|
||||||
|
float ge_gamma = 2.0f;
|
||||||
|
|
||||||
|
for (const auto& [key, value] : extra_sample_args) {
|
||||||
|
float parsed = 0.0f;
|
||||||
|
try {
|
||||||
|
size_t consumed = 0;
|
||||||
|
parsed = std::stof(value, &consumed);
|
||||||
|
if (trim(value.substr(consumed)).size() != 0) {
|
||||||
|
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} catch (const std::exception&) {
|
||||||
|
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s'", key.c_str());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (key == "gamma") {
|
||||||
|
LOG_DEBUG("setting euler_ge gamma to %.2f", parsed);
|
||||||
|
ge_gamma = parsed;
|
||||||
|
} else {
|
||||||
|
LOG_WARN("ignoring unknown euler_ge extra sample arg '%s'", key.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int steps = static_cast<int>(sigmas.size()) - 1;
|
||||||
|
sd::Tensor<float> old_d;
|
||||||
|
bool has_old_d = false;
|
||||||
|
for (int i = 0; i < steps; i++) {
|
||||||
|
float sigma = sigmas[i];
|
||||||
|
float sigma_to = sigmas[i + 1];
|
||||||
|
auto denoised_opt = model(x, sigma, i + 1);
|
||||||
|
if (denoised_opt.pred.empty()) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
sd::Tensor<float> denoised = std::move(denoised_opt.pred);
|
||||||
|
if (sigma_to == 0.f) {
|
||||||
|
x = denoised;
|
||||||
|
} else {
|
||||||
|
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma, sigma_to, eta, is_flow_denoiser);
|
||||||
|
sd::Tensor<float> d = (x - denoised) / sigma;
|
||||||
|
float dt = sigma_down - sigma;
|
||||||
|
if (has_old_d) {
|
||||||
|
sd::Tensor<float> d_bar = d * ge_gamma + old_d * (1.0f - ge_gamma);
|
||||||
|
x += d_bar * dt;
|
||||||
|
} else {
|
||||||
|
x += d * dt;
|
||||||
|
}
|
||||||
|
old_d = std::move(d);
|
||||||
|
has_old_d = true;
|
||||||
|
if (sigma_up > 0.f) {
|
||||||
|
if (is_flow_denoiser) {
|
||||||
|
x *= alpha_scale;
|
||||||
|
}
|
||||||
|
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
static SamplerExtraArgs parse_sampler_args(const char* extra_sample_args) {
|
||||||
|
SamplerExtraArgs pairs;
|
||||||
|
|
||||||
|
if (extra_sample_args == nullptr || extra_sample_args[0] == '\0') {
|
||||||
|
return pairs;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto trim = [](std::string value) -> std::string {
|
||||||
|
const char* whitespace = " \t\r\n";
|
||||||
|
size_t begin = value.find_first_not_of(whitespace);
|
||||||
|
if (begin == std::string::npos) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
size_t end = value.find_last_not_of(whitespace);
|
||||||
|
return value.substr(begin, end - begin + 1);
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string raw(extra_sample_args);
|
||||||
|
size_t start = 0;
|
||||||
|
|
||||||
|
for (size_t pos = 0; pos <= raw.size(); ++pos) {
|
||||||
|
if (pos == raw.size() || raw[pos] == ',' || raw[pos] == ';') {
|
||||||
|
std::string item = raw.substr(start, pos - start);
|
||||||
|
std::string token = trim(item);
|
||||||
|
|
||||||
|
if (!token.empty()) {
|
||||||
|
size_t eq = token.find('=');
|
||||||
|
if (eq != std::string::npos) {
|
||||||
|
std::string key = trim(token.substr(0, eq));
|
||||||
|
std::string value = trim(token.substr(eq + 1));
|
||||||
|
pairs.emplace_back(std::move(key), std::move(value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
start = pos + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pairs;
|
||||||
|
}
|
||||||
|
|
||||||
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
|
// k diffusion reverse ODE: dx = (x - D(x;\sigma)) / \sigma dt; \sigma(t) = t
|
||||||
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
||||||
denoise_cb_t model,
|
denoise_cb_t model,
|
||||||
@ -1888,6 +1965,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
|||||||
float eta,
|
float eta,
|
||||||
bool is_flow_denoiser,
|
bool is_flow_denoiser,
|
||||||
const char* extra_sample_args) {
|
const char* extra_sample_args) {
|
||||||
|
SamplerExtraArgs extra_args = parse_sampler_args(extra_sample_args);
|
||||||
switch (method) {
|
switch (method) {
|
||||||
case EULER_A_SAMPLE_METHOD:
|
case EULER_A_SAMPLE_METHOD:
|
||||||
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
|
return sample_euler_ancestral(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
|
||||||
@ -1907,7 +1985,7 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
|||||||
case DPMPP2Mv2_SAMPLE_METHOD:
|
case DPMPP2Mv2_SAMPLE_METHOD:
|
||||||
return sample_dpmpp_2m_v2(model, std::move(x), sigmas);
|
return sample_dpmpp_2m_v2(model, std::move(x), sigmas);
|
||||||
case LCM_SAMPLE_METHOD:
|
case LCM_SAMPLE_METHOD:
|
||||||
return sample_lcm(model, std::move(x), sigmas, rng, is_flow_denoiser, extra_sample_args);
|
return sample_lcm(model, std::move(x), sigmas, rng, is_flow_denoiser, extra_args);
|
||||||
case IPNDM_SAMPLE_METHOD:
|
case IPNDM_SAMPLE_METHOD:
|
||||||
return sample_ipndm(model, std::move(x), sigmas);
|
return sample_ipndm(model, std::move(x), sigmas);
|
||||||
case IPNDM_V_SAMPLE_METHOD:
|
case IPNDM_V_SAMPLE_METHOD:
|
||||||
@ -1927,6 +2005,8 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
|
|||||||
return sample_euler_cfg_pp(model, std::move(x), sigmas);
|
return sample_euler_cfg_pp(model, std::move(x), sigmas);
|
||||||
case EULER_A_CFG_PP_SAMPLE_METHOD:
|
case EULER_A_CFG_PP_SAMPLE_METHOD:
|
||||||
return sample_euler_ancestral_cfg_pp(model, std::move(x), sigmas, rng, eta);
|
return sample_euler_ancestral_cfg_pp(model, std::move(x), sigmas, rng, eta);
|
||||||
|
case EULER_GE_SAMPLE_METHOD:
|
||||||
|
return sample_gradient_estimation(model, std::move(x), sigmas, rng, is_flow_denoiser, eta, extra_args);
|
||||||
default:
|
default:
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|||||||
@ -81,6 +81,7 @@ const char* sampling_methods_str[] = {
|
|||||||
"ER-SDE",
|
"ER-SDE",
|
||||||
"Euler CFG++",
|
"Euler CFG++",
|
||||||
"Euler A CFG++",
|
"Euler A CFG++",
|
||||||
|
"Euler GE",
|
||||||
};
|
};
|
||||||
|
|
||||||
/*================================================== Helper Functions ================================================*/
|
/*================================================== Helper Functions ================================================*/
|
||||||
@ -2282,6 +2283,7 @@ const char* sample_method_to_str[] = {
|
|||||||
"er_sde",
|
"er_sde",
|
||||||
"euler_cfg_pp",
|
"euler_cfg_pp",
|
||||||
"euler_a_cfg_pp",
|
"euler_a_cfg_pp",
|
||||||
|
"euler_ge",
|
||||||
};
|
};
|
||||||
|
|
||||||
const char* sd_sample_method_name(enum sample_method_t sample_method) {
|
const char* sd_sample_method_name(enum sample_method_t sample_method) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user