feat: add support for APG (adaptive projected guidance) + unconditionnal SLG (#593)

This commit is contained in:
stduhpf 2026-05-31 18:55:49 +02:00 committed by GitHub
parent 20901f6d8e
commit be65ac7511
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 267 additions and 30 deletions

View File

@ -104,9 +104,10 @@ Generation Options:
--hires-upscaler <string> highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent
(nearest-exact), Latent (antialiased), Latent (bicubic), Latent (bicubic
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
stretch, terminal; euler_ge supports gamma
--extra-sample-args <string> extra sampler/scheduler/guidance args, key=value list. APG supports apg_eta,
apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports
slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end;
ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma
--extra-tiling-args <string> extra VAE tiling args, key=value list. LTX video VAE supports
temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)
-H, --height <int> image height, in pixel space (default: 512)

View File

@ -860,7 +860,7 @@ ArgOptions SDGenerationParams::get_options() {
&hires_upscaler},
{"",
"--extra-sample-args",
"extra sampler/scheduler args, key=value list. lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
"extra sampler/scheduler/guidance args, key=value list. APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
&extra_sample_args},
{"",
"--extra-tiling-args",

View File

@ -206,9 +206,10 @@ Default Generation Options:
--hires-upscaler <string> highres fix upscaler, Lanczos, Nearest, Latent, Latent (nearest), Latent
(nearest-exact), Latent (antialiased), Latent (bicubic), Latent (bicubic
antialiased), or a model name under --hires-upscalers-dir (default: Latent)
--extra-sample-args <string> extra sampler/scheduler args, key=value list. lcm supports noise_clip_std,
noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift,
stretch, terminal; euler_ge supports gamma
--extra-sample-args <string> extra sampler/scheduler/guidance args, key=value list. APG supports apg_eta,
apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports
slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end;
ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma
--extra-tiling-args <string> extra VAE tiling args, key=value list. LTX video VAE supports
temporal_tile_frames (default: 4), temporal_tile_overlap (default: 1)
-H, --height <int> image height, in pixel space (default: 512)

View File

@ -514,8 +514,6 @@ struct LTX2Scheduler : SigmaScheduler {
if (!parse_strict_bool(value, stretch)) {
LOG_WARN("ignoring invalid ltx2 scheduler arg '%s=%s'", key.c_str(), value.c_str());
}
} else {
LOG_WARN("ignoring unknown ltx2 scheduler arg '%s'", key.c_str());
}
}
}
@ -1238,20 +1236,26 @@ static sd::Tensor<float> sample_lcm(denoise_cb_t model,
for (const auto& [key, value] : extra_sample_args) {
float parsed = 0.0f;
if (key == "noise_clip_std") {
if (!parse_strict_float(value, parsed)) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str(), value.c_str());
continue;
}
if (key == "noise_clip_std") {
args.noise_clip_std = parsed;
} else if (key == "noise_scale_start") {
if (!parse_strict_float(value, parsed)) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str(), value.c_str());
continue;
}
args.noise_scale_start = parsed;
noise_scale_start_was_set = true;
} else if (key == "noise_scale_end") {
if (!parse_strict_float(value, parsed)) {
LOG_WARN("ignoring invalid lcm extra sample arg '%s=%s'", key.c_str(), value.c_str());
continue;
}
args.noise_scale_end = parsed;
noise_scale_end_was_set = true;
} else {
LOG_WARN("ignoring unknown lcm extra sample arg '%s'", key.c_str());
}
}
@ -1795,16 +1799,14 @@ static sd::Tensor<float> sample_gradient_estimation(denoise_cb_t model,
float ge_gamma = 2.0f;
for (const auto& [key, value] : extra_sample_args) {
if (key == "gamma") {
float parsed = 0.0f;
if (!parse_strict_float(value, parsed)) {
LOG_WARN("ignoring invalid euler_ge extra sample arg '%s=%s'", key.c_str(), value.c_str());
continue;
}
if (key == "gamma") {
LOG_DEBUG("setting euler_ge gamma to %.2f", parsed);
ge_gamma = parsed;
} else {
LOG_WARN("ignoring unknown euler_ge extra sample arg '%s'", key.c_str());
}
}

View File

@ -1,13 +1,68 @@
#include "guidance.h"
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <string>
#include <utility>
#include "util.h"
namespace sd::guidance {
static bool has_tensor(const sd::Tensor<float>* tensor) {
return tensor != nullptr && !tensor->empty();
}
bool is_adaptive_projected_guidance_enabled(const AdaptiveProjectedGuidanceParams& params) {
return params.eta != 1.0f || params.momentum != 0.0f || params.norm_threshold > 0.0f;
}
AdaptiveProjectedGuidanceParams parse_adaptive_projected_guidance_args(const char* extra_sample_args) {
AdaptiveProjectedGuidanceParams params;
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "extra sample arg")) {
float parsed = 0.0f;
if (key == "apg_eta") {
if (!parse_strict_float(value, parsed)) {
LOG_WARN("ignoring invalid APG extra sample arg '%s=%s'", key.c_str(), value.c_str());
continue;
}
params.eta = parsed;
} else if (key == "apg_momentum") {
if (!parse_strict_float(value, parsed)) {
LOG_WARN("ignoring invalid APG extra sample arg '%s=%s'", key.c_str(), value.c_str());
continue;
}
params.momentum = parsed;
} else if (key == "apg_norm_threshold") {
if (!parse_strict_float(value, parsed)) {
LOG_WARN("ignoring invalid APG extra sample arg '%s=%s'", key.c_str(), value.c_str());
continue;
}
params.norm_threshold = parsed;
} else if (key == "apg_norm_threshold_smoothing") {
if (!parse_strict_float(value, parsed)) {
LOG_WARN("ignoring invalid APG extra sample arg '%s=%s'", key.c_str(), value.c_str());
continue;
}
params.norm_threshold_smoothing = parsed;
}
}
return params;
}
bool parse_skip_layer_guidance_uncond_arg(const char* extra_sample_args) {
bool uncond = false;
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "extra sample arg")) {
if (key == "slg_uncond") {
if (!parse_strict_bool(value, uncond)) {
LOG_WARN("ignoring invalid SLG extra sample arg '%s=%s'", key.c_str(), value.c_str());
}
}
}
return uncond;
}
ClassifierFreeGuidance::ClassifierFreeGuidance(float guidance_scale,
float image_guidance_scale)
: guidance_scale_(guidance_scale),
@ -43,6 +98,120 @@ namespace sd::guidance {
return output;
}
AdaptiveProjectedGuidance::AdaptiveProjectedGuidance(float guidance_scale,
float image_guidance_scale,
AdaptiveProjectedGuidanceParams params)
: guidance_scale_(guidance_scale),
image_guidance_scale_(image_guidance_scale),
params_(params) {
}
static sd::Tensor<float> calculate_guidance_delta(const sd::Tensor<float>& pred_cond,
const sd::Tensor<float>* pred_uncond,
const sd::Tensor<float>* pred_img_cond,
float guidance_scale,
float image_guidance_scale) {
if (pred_img_cond != nullptr) {
if (pred_uncond != nullptr && guidance_scale == 1.0f) {
return *pred_img_cond - *pred_uncond;
}
if (pred_uncond != nullptr) {
return pred_cond +
(*pred_uncond * (1.0f - image_guidance_scale) +
*pred_img_cond * (image_guidance_scale - guidance_scale)) /
(guidance_scale - 1.0f);
}
return pred_cond - *pred_img_cond;
}
return pred_cond - *pred_uncond;
}
GuiderOutput AdaptiveProjectedGuidance::forward(const GuidanceInput& input,
GuiderOutput previous) const {
(void)previous;
GuiderOutput output;
if (!has_tensor(input.pred_cond)) {
return output;
}
const sd::Tensor<float>& pred_cond = *input.pred_cond;
output.pred = pred_cond;
if (has_tensor(input.pred_uncond)) {
const sd::Tensor<float>& pred_uncond = *input.pred_uncond;
if (has_tensor(input.pred_img_cond)) {
const sd::Tensor<float>& pred_img_cond = *input.pred_img_cond;
output.pred = pred_uncond +
image_guidance_scale_ * (pred_img_cond - pred_uncond) +
guidance_scale_ * (pred_cond - pred_img_cond);
} else {
output.pred = pred_uncond + guidance_scale_ * (pred_cond - pred_uncond);
}
} else if (has_tensor(input.pred_img_cond)) {
const sd::Tensor<float>& pred_img_cond = *input.pred_img_cond;
output.pred = pred_img_cond + guidance_scale_ * (pred_cond - pred_img_cond);
}
if (!has_tensor(input.pred_uncond) && !has_tensor(input.pred_img_cond)) {
return output;
}
const sd::Tensor<float>* pred_uncond = input.pred_uncond;
const sd::Tensor<float>* pred_img_cond = input.pred_img_cond;
sd::Tensor<float> deltas = calculate_guidance_delta(pred_cond,
pred_uncond,
pred_img_cond,
guidance_scale_,
image_guidance_scale_);
if (params_.momentum != 0.0f) {
if (momentum_buffer_.shape() != deltas.shape()) {
momentum_buffer_ = sd::Tensor<float>::zeros_like(deltas);
}
deltas += params_.momentum * momentum_buffer_;
momentum_buffer_ = deltas;
}
float diff_norm = 0.0f;
if (params_.norm_threshold > 0.0f) {
diff_norm = std::sqrt((deltas * deltas).sum());
}
float apg_scale_factor = 1.0f;
if (params_.norm_threshold > 0.0f) {
if (diff_norm > 0.0f) {
if (params_.norm_threshold_smoothing <= 0.0f) {
apg_scale_factor = std::min(1.0f, params_.norm_threshold / diff_norm);
} else {
float x = params_.norm_threshold / diff_norm;
apg_scale_factor = x / std::pow(1.0f + std::pow(x, 1.0f / params_.norm_threshold_smoothing),
params_.norm_threshold_smoothing);
}
}
}
deltas *= apg_scale_factor;
if (params_.eta != 1.0f) {
float cond_norm_sq = (pred_cond * pred_cond).sum();
if (cond_norm_sq != 0.0f) {
float projection_scale = (pred_cond * deltas).sum() / cond_norm_sq;
deltas += (params_.eta - 1.0f) * (projection_scale * pred_cond);
}
}
output.pred = pred_cond;
if (pred_uncond != nullptr) {
if (guidance_scale_ != 1.0f) {
output.pred = pred_cond + (guidance_scale_ - 1.0f) * deltas;
} else if (pred_img_cond != nullptr) {
output.pred = pred_cond + (image_guidance_scale_ - 1.0f) * deltas;
}
} else if (pred_img_cond != nullptr) {
output.pred = *pred_img_cond + guidance_scale_ * deltas;
}
return output;
}
SkipLayerGuidance::SkipLayerGuidance(std::vector<int> layers,
float scale,
float start,
@ -54,7 +223,7 @@ namespace sd::guidance {
}
bool SkipLayerGuidance::is_enabled_for_step(const GuidanceInput& input) const {
if (scale_ == 0.0f || layers_.empty() || input.schedule_size == 0) {
if (layers_.empty() || input.schedule_size == 0) {
return false;
}
@ -69,7 +238,7 @@ namespace sd::guidance {
GuiderOutput SkipLayerGuidance::forward(const GuidanceInput& input,
GuiderOutput output) const {
if (!is_enabled_for_step(input) || !input.predict_skip_layer) {
if (scale_ == 0.0f || !is_enabled_for_step(input) || !input.predict_skip_layer) {
return output;
}

View File

@ -17,6 +17,17 @@ namespace sd::guidance {
sd::Tensor<float> pred_skip_layer;
};
struct AdaptiveProjectedGuidanceParams {
float eta = 1.0f;
float momentum = 0.0f;
float norm_threshold = 0.0f;
float norm_threshold_smoothing = 0.0f;
};
AdaptiveProjectedGuidanceParams parse_adaptive_projected_guidance_args(const char* extra_sample_args);
bool is_adaptive_projected_guidance_enabled(const AdaptiveProjectedGuidanceParams& params);
bool parse_skip_layer_guidance_uncond_arg(const char* extra_sample_args);
struct GuidanceInput {
int step = 0;
size_t schedule_size = 0;
@ -46,6 +57,21 @@ namespace sd::guidance {
GuiderOutput previous) const override;
};
class AdaptiveProjectedGuidance : public BaseGuidance {
float guidance_scale_ = 1.0f;
float image_guidance_scale_ = 1.0f;
AdaptiveProjectedGuidanceParams params_;
mutable sd::Tensor<float> momentum_buffer_;
public:
AdaptiveProjectedGuidance(float guidance_scale,
float image_guidance_scale,
AdaptiveProjectedGuidanceParams params);
GuiderOutput forward(const GuidanceInput& input,
GuiderOutput previous) const override;
};
class SkipLayerGuidance : public BaseGuidance {
std::vector<int> layers_;
float scale_ = 0.0f;

View File

@ -1941,6 +1941,7 @@ public:
float cfg_scale = guidance.txt_cfg;
float img_cfg_scale = guidance.img_cfg;
float slg_scale = guidance.slg.scale;
bool slg_uncond = sd::guidance::parse_skip_layer_guidance_uncond_arg(extra_sample_args);
sd_sample::SampleCacheRuntime cache_runtime = sd_sample::init_sample_cache_runtime(version,
cache_params,
@ -1957,12 +1958,21 @@ public:
}
size_t steps = sigmas.size() - 1;
bool has_skiplayer = slg_scale != 0.0f && !skip_layers.empty();
bool has_skiplayer = (slg_scale != 0.0f || slg_uncond) && !skip_layers.empty();
if (has_skiplayer && !sd_version_is_dit(version)) {
has_skiplayer = false;
LOG_WARN("SLG is incompatible with this model type");
}
sd::guidance::AdaptiveProjectedGuidanceParams apg_params = sd::guidance::parse_adaptive_projected_guidance_args(extra_sample_args);
bool use_apg_guidance = sd::guidance::is_adaptive_projected_guidance_enabled(apg_params);
if (use_apg_guidance) {
LOG_INFO("using Adaptive Projected Guidance (APG)");
}
sd::guidance::ClassifierFreeGuidance classifier_free_guidance(cfg_scale, img_cfg_scale);
sd::guidance::AdaptiveProjectedGuidance adaptive_projected_guidance(cfg_scale, img_cfg_scale, apg_params);
const sd::guidance::BaseGuidance& primary_guidance = use_apg_guidance
? static_cast<const sd::guidance::BaseGuidance&>(adaptive_projected_guidance)
: static_cast<const sd::guidance::BaseGuidance&>(classifier_free_guidance);
sd::guidance::SkipLayerGuidance skip_layer_guidance(has_skiplayer ? skip_layers : std::vector<int>(),
has_skiplayer ? slg_scale : 0.0f,
guidance.slg.layer_start,
@ -2038,6 +2048,10 @@ public:
diffusion_params.x = &noised_input;
diffusion_params.timesteps = &timesteps_tensor;
diffusion_params.increase_ref_index = increase_ref_index;
sd::guidance::GuidanceInput step_guidance_input;
step_guidance_input.step = step;
step_guidance_input.schedule_size = sigmas.size();
bool is_skiplayer_step = skip_layer_guidance.is_enabled_for_step(step_guidance_input);
compute_sample_controls(control_image,
noised_input,
@ -2121,7 +2135,12 @@ public:
uncond,
&controls);
}
uncond_out = run_condition(uncond);
const std::vector<int>* uncond_skip_layers = nullptr;
if (is_skiplayer_step && slg_uncond) {
LOG_DEBUG("Skipping layers at uncond step %d\n", step);
uncond_skip_layers = &skip_layer_guidance.layers();
}
uncond_out = run_condition(uncond, nullptr, uncond_skip_layers);
if (uncond_out.empty()) {
return {};
}
@ -2140,12 +2159,12 @@ public:
guidance_input.pred_uncond = uncond_out.empty() ? nullptr : &uncond_out;
guidance_input.pred_img_cond = img_cond_out.empty() ? nullptr : &img_cond_out;
sd::guidance::GuiderOutput guided = classifier_free_guidance.forward(guidance_input, {});
sd::guidance::GuiderOutput guided = primary_guidance.forward(guidance_input, {});
if (guided.pred.empty()) {
return {};
}
if (skip_layer_guidance.is_enabled_for_step(guidance_input)) {
if (is_skiplayer_step && slg_scale != 0.0f) {
LOG_DEBUG("Skipping layers at step %d\n", step);
if (!step_cache.is_step_skipped()) {
guidance_input.predict_skip_layer = [&]() -> sd::Tensor<float> {

View File

@ -235,6 +235,7 @@ namespace sd {
Tensor& masked_fill_(const Tensor<uint8_t>& mask, const T& value);
T sum() const;
T mean() const;
static Tensor zeros(std::vector<int64_t> shape) {
@ -327,6 +328,24 @@ namespace sd {
std::vector<int64_t> shape_;
};
template <typename T>
inline T Tensor<T>::sum() const {
T total = T{};
for (const T& value : data_) {
total += value;
}
return total;
}
template <>
inline float Tensor<float>::sum() const {
double total = 0.0;
for (float value : data_) {
total += static_cast<double>(value);
}
return static_cast<float>(total);
}
template <typename T>
inline T Tensor<T>::mean() const {
if (empty()) {