feat: adapt res samplers for flow models for eta > 0 (#1436)

This commit is contained in:
Wagner Bruna 2026-05-06 10:49:06 -03:00 committed by GitHub
parent 9097ce5211
commit 586b6f1481
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -808,6 +808,18 @@ static std::tuple<float, float, float> get_ancestral_step_flow(float sigma_from,
return {sigma_down, sigma_up, alpha_scale}; return {sigma_down, sigma_up, alpha_scale};
} }
static std::tuple<float, float, float> get_ancestral_step(float sigma_from,
float sigma_to,
float eta,
bool is_flow_denoiser) {
if (is_flow_denoiser) {
return get_ancestral_step_flow(sigma_from, sigma_to, eta);
} else {
auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta);
return {sigma_down, sigma_up, 1.0f};
}
}
static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model, static sd::Tensor<float> sample_euler_ancestral(denoise_cb_t model,
sd::Tensor<float> x, sd::Tensor<float> x,
const std::vector<float>& sigmas, const std::vector<float>& sigmas,
@ -1247,6 +1259,7 @@ static sd::Tensor<float> sample_res_multistep(denoise_cb_t model,
sd::Tensor<float> x, sd::Tensor<float> x,
const std::vector<float>& sigmas, const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng, std::shared_ptr<RNG> rng,
bool is_flow_denoiser,
float eta) { float eta) {
sd::Tensor<float> old_denoised = x; sd::Tensor<float> old_denoised = x;
bool have_old_sigma = false; bool have_old_sigma = false;
@ -1278,7 +1291,8 @@ static sd::Tensor<float> sample_res_multistep(denoise_cb_t model,
float sigma_from = sigmas[i]; float sigma_from = sigmas[i];
float sigma_to = sigmas[i + 1]; float sigma_to = sigmas[i + 1];
auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta);
auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma_from, sigma_to, eta, is_flow_denoiser);
if (sigma_down == 0.0f || !have_old_sigma) { if (sigma_down == 0.0f || !have_old_sigma) {
x += ((x - denoised) / sigma_from) * (sigma_down - sigma_from); x += ((x - denoised) / sigma_from) * (sigma_down - sigma_from);
@ -1305,7 +1319,10 @@ static sd::Tensor<float> sample_res_multistep(denoise_cb_t model,
x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised); x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised);
} }
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) { if (sigma_to > 0.0f && sigma_up > 0.0f) {
if (is_flow_denoiser) {
x *= alpha_scale;
}
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up; x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
} }
@ -1320,6 +1337,7 @@ static sd::Tensor<float> sample_res_2s(denoise_cb_t model,
sd::Tensor<float> x, sd::Tensor<float> x,
const std::vector<float>& sigmas, const std::vector<float>& sigmas,
std::shared_ptr<RNG> rng, std::shared_ptr<RNG> rng,
bool is_flow_denoiser,
float eta) { float eta) {
const float c2 = 0.5f; const float c2 = 0.5f;
auto t_fn = [](float sigma) -> float { return -logf(sigma); }; auto t_fn = [](float sigma) -> float { return -logf(sigma); };
@ -1348,7 +1366,7 @@ static sd::Tensor<float> sample_res_2s(denoise_cb_t model,
} }
sd::Tensor<float> denoised = std::move(denoised_opt); sd::Tensor<float> denoised = std::move(denoised_opt);
auto [sigma_down, sigma_up] = get_ancestral_step(sigma_from, sigma_to, eta); auto [sigma_down, sigma_up, alpha_scale] = get_ancestral_step(sigma_from, sigma_to, eta, is_flow_denoiser);
sd::Tensor<float> x0 = x; sd::Tensor<float> x0 = x;
if (sigma_down == 0.0f || sigma_from == 0.0f) { if (sigma_down == 0.0f || sigma_from == 0.0f) {
@ -1377,7 +1395,10 @@ static sd::Tensor<float> sample_res_2s(denoise_cb_t model,
x = x0 + h * (b1 * eps1 + b2 * eps2); x = x0 + h * (b1 * eps1 + b2 * eps2);
} }
if (sigmas[i + 1] > 0 && sigma_up > 0.0f) { if (sigma_to > 0.0f && sigma_up > 0.0f) {
if (is_flow_denoiser) {
x *= alpha_scale;
}
x += sd::Tensor<float>::randn_like(x, rng) * sigma_up; x += sd::Tensor<float>::randn_like(x, rng) * sigma_up;
} }
} }
@ -1664,9 +1685,9 @@ static sd::Tensor<float> sample_k_diffusion(sample_method_t method,
case IPNDM_V_SAMPLE_METHOD: case IPNDM_V_SAMPLE_METHOD:
return sample_ipndm_v(model, std::move(x), sigmas); return sample_ipndm_v(model, std::move(x), sigmas);
case RES_MULTISTEP_SAMPLE_METHOD: case RES_MULTISTEP_SAMPLE_METHOD:
return sample_res_multistep(model, std::move(x), sigmas, rng, eta); return sample_res_multistep(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
case RES_2S_SAMPLE_METHOD: case RES_2S_SAMPLE_METHOD:
return sample_res_2s(model, std::move(x), sigmas, rng, eta); return sample_res_2s(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
case ER_SDE_SAMPLE_METHOD: case ER_SDE_SAMPLE_METHOD:
return sample_er_sde(model, std::move(x), sigmas, rng, is_flow_denoiser, eta); return sample_er_sde(model, std::move(x), sigmas, rng, is_flow_denoiser, eta);
case DDIM_TRAILING_SAMPLE_METHOD: case DDIM_TRAILING_SAMPLE_METHOD: