mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-23 14:46:39 +00:00
feat: support guidance_schedule (#1684)
This commit is contained in:
parent
b395a6972d
commit
41f7acbfb0
@ -960,7 +960,7 @@ ArgOptions SDGenerationParams::get_options() {
|
|||||||
&hires_upscaler},
|
&hires_upscaler},
|
||||||
{"",
|
{"",
|
||||||
"--extra-sample-args",
|
"--extra-sample-args",
|
||||||
"extra sampler/scheduler/guidance args, key=value list. APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma",
|
"extra sampler/scheduler/guidance args, key=value list. CFG supports guidance_schedule; APG supports apg_eta, apg_momentum, apg_norm_threshold, apg_norm_threshold_smoothing; SLG supports slg_uncond; lcm supports noise_clip_std, noise_scale_start, noise_scale_end; ltx2 supports max_shift, base_shift, stretch, terminal; euler_ge supports gamma;",
|
||||||
(int)',',
|
(int)',',
|
||||||
&extra_sample_args},
|
&extra_sample_args},
|
||||||
{"",
|
{"",
|
||||||
|
|||||||
@ -5,6 +5,7 @@
|
|||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
#include "core/util.h"
|
#include "core/util.h"
|
||||||
|
|
||||||
@ -63,6 +64,82 @@ namespace sd::guidance {
|
|||||||
return uncond;
|
return uncond;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<float> parse_guidance_schedule_from_spec(std::string spec) {
|
||||||
|
std::vector<float> schedule;
|
||||||
|
|
||||||
|
while (!spec.empty()) {
|
||||||
|
auto sep = spec.find('+');
|
||||||
|
auto segment = spec.substr(0, sep);
|
||||||
|
|
||||||
|
auto x = segment.find('x');
|
||||||
|
if (x == std::string::npos) {
|
||||||
|
LOG_ERROR("Invalid guidance schedule segment: '%s' (expected <guidance>x<count>)", segment.c_str());
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
float guidance;
|
||||||
|
int count;
|
||||||
|
|
||||||
|
auto guidance_str = segment.substr(0, x);
|
||||||
|
auto count_str = segment.substr(x + 1);
|
||||||
|
|
||||||
|
try {
|
||||||
|
size_t idx = 0;
|
||||||
|
guidance = std::stof(guidance_str, &idx);
|
||||||
|
if (idx != guidance_str.size()) {
|
||||||
|
LOG_ERROR("Invalid guidance value in guidance schedule: '%s'", guidance_str.c_str());
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
} catch (const std::exception&) {
|
||||||
|
LOG_ERROR("Invalid guidance value in guidance schedule: '%s'", guidance_str.c_str());
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
size_t idx = 0;
|
||||||
|
count = std::stoi(count_str, &idx);
|
||||||
|
if (idx != count_str.size()) {
|
||||||
|
LOG_ERROR("Invalid count in guidance schedule: '%s'", count_str.c_str());
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
} catch (const std::exception&) {
|
||||||
|
LOG_ERROR("Invalid count in guidance schedule: '%s'", count_str.c_str());
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (count <= 0) {
|
||||||
|
LOG_ERROR("Guidance schedule count must be positive");
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
schedule.insert(schedule.end(), count, guidance);
|
||||||
|
|
||||||
|
if (sep == std::string::npos) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
spec = spec.substr(sep + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return schedule;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> parse_guidance_schedule(const char* extra_sample_args) {
|
||||||
|
std::vector<float> guidance_schedule;
|
||||||
|
std::string guidance_schedule_str = "";
|
||||||
|
for (const auto& [key, value] : parse_key_value_args(extra_sample_args, "extra sample arg")) {
|
||||||
|
float parsed = 0.0f;
|
||||||
|
if (key == "guidance_schedule") {
|
||||||
|
guidance_schedule_str = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!guidance_schedule_str.empty()) {
|
||||||
|
guidance_schedule = parse_guidance_schedule_from_spec(guidance_schedule_str);
|
||||||
|
}
|
||||||
|
return guidance_schedule;
|
||||||
|
}
|
||||||
|
|
||||||
ClassifierFreeGuidance::ClassifierFreeGuidance(float guidance_scale,
|
ClassifierFreeGuidance::ClassifierFreeGuidance(float guidance_scale,
|
||||||
float image_guidance_scale)
|
float image_guidance_scale)
|
||||||
: guidance_scale_(guidance_scale),
|
: guidance_scale_(guidance_scale),
|
||||||
@ -70,8 +147,10 @@ namespace sd::guidance {
|
|||||||
}
|
}
|
||||||
|
|
||||||
GuiderOutput ClassifierFreeGuidance::forward(const GuidanceInput& input,
|
GuiderOutput ClassifierFreeGuidance::forward(const GuidanceInput& input,
|
||||||
GuiderOutput previous) const {
|
GuiderOutput previous,
|
||||||
|
std::optional<float> scale_override) const {
|
||||||
(void)previous;
|
(void)previous;
|
||||||
|
float guidance_scale = scale_override.value_or(guidance_scale_);
|
||||||
|
|
||||||
GuiderOutput output;
|
GuiderOutput output;
|
||||||
if (!has_tensor(input.pred_cond)) {
|
if (!has_tensor(input.pred_cond)) {
|
||||||
@ -86,14 +165,14 @@ namespace sd::guidance {
|
|||||||
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
|
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
|
||||||
output.pred = pred_img_uncond +
|
output.pred = pred_img_uncond +
|
||||||
image_guidance_scale_ * (pred_uncond - pred_img_uncond) +
|
image_guidance_scale_ * (pred_uncond - pred_img_uncond) +
|
||||||
guidance_scale_ * (pred_cond - pred_uncond);
|
guidance_scale * (pred_cond - pred_uncond);
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
output.pred = pred_uncond + guidance_scale_ * (pred_cond - pred_uncond);
|
output.pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond);
|
||||||
}
|
}
|
||||||
} else if (has_tensor(input.pred_img_uncond)) {
|
} else if (has_tensor(input.pred_img_uncond)) {
|
||||||
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
|
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
|
||||||
output.pred = pred_img_uncond + guidance_scale_ * (pred_cond - pred_img_uncond);
|
output.pred = pred_img_uncond + guidance_scale * (pred_cond - pred_img_uncond);
|
||||||
}
|
}
|
||||||
|
|
||||||
return output;
|
return output;
|
||||||
@ -128,8 +207,10 @@ namespace sd::guidance {
|
|||||||
}
|
}
|
||||||
|
|
||||||
GuiderOutput AdaptiveProjectedGuidance::forward(const GuidanceInput& input,
|
GuiderOutput AdaptiveProjectedGuidance::forward(const GuidanceInput& input,
|
||||||
GuiderOutput previous) const {
|
GuiderOutput previous,
|
||||||
|
std::optional<float> scale_override) const {
|
||||||
(void)previous;
|
(void)previous;
|
||||||
|
float guidance_scale = scale_override.value_or(guidance_scale_);
|
||||||
|
|
||||||
GuiderOutput output;
|
GuiderOutput output;
|
||||||
if (!has_tensor(input.pred_cond)) {
|
if (!has_tensor(input.pred_cond)) {
|
||||||
@ -144,13 +225,13 @@ namespace sd::guidance {
|
|||||||
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
|
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
|
||||||
output.pred = pred_img_uncond +
|
output.pred = pred_img_uncond +
|
||||||
image_guidance_scale_ * (pred_uncond - pred_img_uncond) +
|
image_guidance_scale_ * (pred_uncond - pred_img_uncond) +
|
||||||
guidance_scale_ * (pred_cond - pred_uncond);
|
guidance_scale * (pred_cond - pred_uncond);
|
||||||
} else {
|
} else {
|
||||||
output.pred = pred_uncond + guidance_scale_ * (pred_cond - pred_uncond);
|
output.pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond);
|
||||||
}
|
}
|
||||||
} else if (has_tensor(input.pred_img_uncond)) {
|
} else if (has_tensor(input.pred_img_uncond)) {
|
||||||
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
|
const sd::Tensor<float>& pred_img_uncond = *input.pred_img_uncond;
|
||||||
output.pred = pred_img_uncond + guidance_scale_ * (pred_cond - pred_img_uncond);
|
output.pred = pred_img_uncond + guidance_scale * (pred_cond - pred_img_uncond);
|
||||||
}
|
}
|
||||||
if (!has_tensor(input.pred_uncond) && !has_tensor(input.pred_img_uncond)) {
|
if (!has_tensor(input.pred_uncond) && !has_tensor(input.pred_img_uncond)) {
|
||||||
return output;
|
return output;
|
||||||
@ -162,7 +243,7 @@ namespace sd::guidance {
|
|||||||
sd::Tensor<float> deltas = calculate_guidance_delta(pred_cond,
|
sd::Tensor<float> deltas = calculate_guidance_delta(pred_cond,
|
||||||
pred_uncond,
|
pred_uncond,
|
||||||
pred_img_uncond,
|
pred_img_uncond,
|
||||||
guidance_scale_,
|
guidance_scale,
|
||||||
image_guidance_scale_);
|
image_guidance_scale_);
|
||||||
if (params_.momentum != 0.0f) {
|
if (params_.momentum != 0.0f) {
|
||||||
if (momentum_buffer_.shape() != deltas.shape()) {
|
if (momentum_buffer_.shape() != deltas.shape()) {
|
||||||
@ -239,7 +320,8 @@ namespace sd::guidance {
|
|||||||
}
|
}
|
||||||
|
|
||||||
GuiderOutput SkipLayerGuidance::forward(const GuidanceInput& input,
|
GuiderOutput SkipLayerGuidance::forward(const GuidanceInput& input,
|
||||||
GuiderOutput output) const {
|
GuiderOutput output,
|
||||||
|
std::optional<float> /*scale_override*/) const {
|
||||||
if (scale_ == 0.0f || !is_enabled_for_step(input) || !input.predict_skip_layer) {
|
if (scale_ == 0.0f || !is_enabled_for_step(input) || !input.predict_skip_layer) {
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <optional>
|
||||||
|
|
||||||
#include "core/tensor.hpp"
|
#include "core/tensor.hpp"
|
||||||
|
|
||||||
@ -27,6 +28,7 @@ namespace sd::guidance {
|
|||||||
AdaptiveProjectedGuidanceParams parse_adaptive_projected_guidance_args(const char* extra_sample_args);
|
AdaptiveProjectedGuidanceParams parse_adaptive_projected_guidance_args(const char* extra_sample_args);
|
||||||
bool is_adaptive_projected_guidance_enabled(const AdaptiveProjectedGuidanceParams& params);
|
bool is_adaptive_projected_guidance_enabled(const AdaptiveProjectedGuidanceParams& params);
|
||||||
bool parse_skip_layer_guidance_uncond_arg(const char* extra_sample_args);
|
bool parse_skip_layer_guidance_uncond_arg(const char* extra_sample_args);
|
||||||
|
std::vector<float> parse_guidance_schedule(const char* extra_sample_args);
|
||||||
|
|
||||||
struct GuidanceInput {
|
struct GuidanceInput {
|
||||||
int step = 0;
|
int step = 0;
|
||||||
@ -42,7 +44,8 @@ namespace sd::guidance {
|
|||||||
public:
|
public:
|
||||||
virtual ~BaseGuidance() = default;
|
virtual ~BaseGuidance() = default;
|
||||||
virtual GuiderOutput forward(const GuidanceInput& input,
|
virtual GuiderOutput forward(const GuidanceInput& input,
|
||||||
GuiderOutput previous) const = 0;
|
GuiderOutput previous,
|
||||||
|
std::optional<float> scale_override = std::nullopt) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class ClassifierFreeGuidance : public BaseGuidance {
|
class ClassifierFreeGuidance : public BaseGuidance {
|
||||||
@ -54,7 +57,8 @@ namespace sd::guidance {
|
|||||||
float image_guidance_scale);
|
float image_guidance_scale);
|
||||||
|
|
||||||
GuiderOutput forward(const GuidanceInput& input,
|
GuiderOutput forward(const GuidanceInput& input,
|
||||||
GuiderOutput previous) const override;
|
GuiderOutput previous,
|
||||||
|
std::optional<float> scale_override = std::nullopt) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
class AdaptiveProjectedGuidance : public BaseGuidance {
|
class AdaptiveProjectedGuidance : public BaseGuidance {
|
||||||
@ -69,7 +73,8 @@ namespace sd::guidance {
|
|||||||
AdaptiveProjectedGuidanceParams params);
|
AdaptiveProjectedGuidanceParams params);
|
||||||
|
|
||||||
GuiderOutput forward(const GuidanceInput& input,
|
GuiderOutput forward(const GuidanceInput& input,
|
||||||
GuiderOutput previous) const override;
|
GuiderOutput previous,
|
||||||
|
std::optional<float> scale_override = std::nullopt) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
class SkipLayerGuidance : public BaseGuidance {
|
class SkipLayerGuidance : public BaseGuidance {
|
||||||
@ -88,7 +93,8 @@ namespace sd::guidance {
|
|||||||
const std::vector<int>& layers() const;
|
const std::vector<int>& layers() const;
|
||||||
|
|
||||||
GuiderOutput forward(const GuidanceInput& input,
|
GuiderOutput forward(const GuidanceInput& input,
|
||||||
GuiderOutput previous) const override;
|
GuiderOutput previous,
|
||||||
|
std::optional<float> scale_override = std::nullopt) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sd::guidance
|
} // namespace sd::guidance
|
||||||
|
|||||||
@ -1942,6 +1942,32 @@ public:
|
|||||||
float slg_scale = guidance.slg.scale;
|
float slg_scale = guidance.slg.scale;
|
||||||
bool slg_uncond = sd::guidance::parse_skip_layer_guidance_uncond_arg(extra_sample_args);
|
bool slg_uncond = sd::guidance::parse_skip_layer_guidance_uncond_arg(extra_sample_args);
|
||||||
|
|
||||||
|
std::vector<float> guidance_schedule = sd::guidance::parse_guidance_schedule(extra_sample_args);
|
||||||
|
if(!guidance_schedule.empty() && guidance_schedule.size() != sigmas.size() - 1) {
|
||||||
|
if(guidance_schedule.size() > sigmas.size()) {
|
||||||
|
LOG_WARN("guidance_schedule length (%zu) is greater than number of steps (%zu)", guidance_schedule.size(), sigmas.size() - 1);
|
||||||
|
LOG_WARN("truncating guidance_schedule to match step count");
|
||||||
|
guidance_schedule.resize(sigmas.size() - 1);
|
||||||
|
} else {
|
||||||
|
LOG_INFO("padding guidance_schedule with cfg_scale");
|
||||||
|
while(guidance_schedule.size() < sigmas.size() - 1) {
|
||||||
|
guidance_schedule.push_back(cfg_scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if(!guidance_schedule.empty()) {
|
||||||
|
std::string schedule_str = "[";
|
||||||
|
for(size_t i = 0; i < guidance_schedule.size(); ++i) {
|
||||||
|
schedule_str += std::to_string(guidance_schedule[i]);
|
||||||
|
if(i < guidance_schedule.size() - 1) {
|
||||||
|
schedule_str += ", ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
schedule_str += "]";
|
||||||
|
LOG_DEBUG("using guidance schedule: %s", schedule_str.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
sd_sample::SampleCacheRuntime cache_runtime = sd_sample::init_sample_cache_runtime(version,
|
sd_sample::SampleCacheRuntime cache_runtime = sd_sample::init_sample_cache_runtime(version,
|
||||||
cache_params,
|
cache_params,
|
||||||
denoiser.get(),
|
denoiser.get(),
|
||||||
@ -2182,7 +2208,9 @@ public:
|
|||||||
guidance_input.pred_uncond = uncond_out.empty() ? nullptr : &uncond_out;
|
guidance_input.pred_uncond = uncond_out.empty() ? nullptr : &uncond_out;
|
||||||
guidance_input.pred_img_uncond = img_uncond_out.empty() ? nullptr : &img_uncond_out;
|
guidance_input.pred_img_uncond = img_uncond_out.empty() ? nullptr : &img_uncond_out;
|
||||||
|
|
||||||
sd::guidance::GuiderOutput guided = primary_guidance.forward(guidance_input, {});
|
sd::guidance::GuiderOutput guided = guidance_schedule.empty()?
|
||||||
|
primary_guidance.forward(guidance_input, {}):
|
||||||
|
primary_guidance.forward(guidance_input, {}, guidance_schedule[guidance_schedule.size() - 1 - step]);
|
||||||
if (guided.pred.empty()) {
|
if (guided.pred.empty()) {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user