mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-23 17:58:58 +00:00
refactor: simplify sample cache flow (#1350)
This commit is contained in:
parent
5265a5efa1
commit
545fac4f3f
@ -103,6 +103,379 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) {
|
||||
return std::max(0.0f, reuse_threshold);
|
||||
}
|
||||
|
||||
enum class SampleCacheMode {
|
||||
NONE,
|
||||
EASYCACHE,
|
||||
UCACHE,
|
||||
CACHEDIT,
|
||||
};
|
||||
|
||||
struct SampleCacheRuntime {
|
||||
SampleCacheMode mode = SampleCacheMode::NONE;
|
||||
|
||||
EasyCacheState easycache;
|
||||
UCacheState ucache;
|
||||
CacheDitConditionState cachedit;
|
||||
SpectrumState spectrum;
|
||||
|
||||
bool spectrum_enabled = false;
|
||||
|
||||
bool has_step_cache() const {
|
||||
return mode != SampleCacheMode::NONE;
|
||||
}
|
||||
|
||||
bool easycache_enabled() const {
|
||||
return mode == SampleCacheMode::EASYCACHE;
|
||||
}
|
||||
|
||||
bool ucache_enabled() const {
|
||||
return mode == SampleCacheMode::UCACHE;
|
||||
}
|
||||
|
||||
bool cachedit_enabled() const {
|
||||
return mode == SampleCacheMode::CACHEDIT;
|
||||
}
|
||||
};
|
||||
|
||||
static bool has_valid_cache_percent_range(const sd_cache_params_t& cache_params) {
|
||||
if (cache_params.mode != SD_CACHE_EASYCACHE && cache_params.mode != SD_CACHE_UCACHE) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return cache_params.start_percent >= 0.0f &&
|
||||
cache_params.start_percent < 1.0f &&
|
||||
cache_params.end_percent > 0.0f &&
|
||||
cache_params.end_percent <= 1.0f &&
|
||||
cache_params.start_percent < cache_params.end_percent;
|
||||
}
|
||||
|
||||
static void init_easycache_runtime(SampleCacheRuntime& runtime,
|
||||
SDVersion version,
|
||||
const sd_cache_params_t& cache_params,
|
||||
Denoiser* denoiser) {
|
||||
if (!sd_version_is_dit(version)) {
|
||||
LOG_WARN("EasyCache requested but not supported for this model type");
|
||||
return;
|
||||
}
|
||||
|
||||
EasyCacheConfig config;
|
||||
config.enabled = true;
|
||||
config.reuse_threshold = get_cache_reuse_threshold(cache_params);
|
||||
config.start_percent = cache_params.start_percent;
|
||||
config.end_percent = cache_params.end_percent;
|
||||
|
||||
runtime.easycache.init(config, denoiser);
|
||||
if (!runtime.easycache.enabled()) {
|
||||
LOG_WARN("EasyCache requested but could not be initialized for this run");
|
||||
return;
|
||||
}
|
||||
|
||||
runtime.mode = SampleCacheMode::EASYCACHE;
|
||||
LOG_INFO("EasyCache enabled - threshold: %.3f, start: %.2f, end: %.2f",
|
||||
config.reuse_threshold,
|
||||
config.start_percent,
|
||||
config.end_percent);
|
||||
}
|
||||
|
||||
static void init_ucache_runtime(SampleCacheRuntime& runtime,
|
||||
SDVersion version,
|
||||
const sd_cache_params_t& cache_params,
|
||||
Denoiser* denoiser,
|
||||
const std::vector<float>& sigmas) {
|
||||
if (!sd_version_is_unet(version)) {
|
||||
LOG_WARN("UCache requested but not supported for this model type (only UNET models)");
|
||||
return;
|
||||
}
|
||||
|
||||
UCacheConfig config;
|
||||
config.enabled = true;
|
||||
config.reuse_threshold = get_cache_reuse_threshold(cache_params);
|
||||
config.start_percent = cache_params.start_percent;
|
||||
config.end_percent = cache_params.end_percent;
|
||||
config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params.error_decay_rate));
|
||||
config.use_relative_threshold = cache_params.use_relative_threshold;
|
||||
config.reset_error_on_compute = cache_params.reset_error_on_compute;
|
||||
|
||||
runtime.ucache.init(config, denoiser);
|
||||
if (!runtime.ucache.enabled()) {
|
||||
LOG_WARN("UCache requested but could not be initialized for this run");
|
||||
return;
|
||||
}
|
||||
|
||||
runtime.ucache.set_sigmas(sigmas);
|
||||
runtime.mode = SampleCacheMode::UCACHE;
|
||||
LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s, reset: %s",
|
||||
config.reuse_threshold,
|
||||
config.start_percent,
|
||||
config.end_percent,
|
||||
config.error_decay_rate,
|
||||
config.use_relative_threshold ? "true" : "false",
|
||||
config.reset_error_on_compute ? "true" : "false");
|
||||
}
|
||||
|
||||
static void init_cachedit_runtime(SampleCacheRuntime& runtime,
|
||||
SDVersion version,
|
||||
const sd_cache_params_t& cache_params,
|
||||
const std::vector<float>& sigmas) {
|
||||
if (!sd_version_is_dit(version)) {
|
||||
LOG_WARN("CacheDIT requested but not supported for this model type (only DiT models)");
|
||||
return;
|
||||
}
|
||||
|
||||
DBCacheConfig dbcfg;
|
||||
dbcfg.enabled = (cache_params.mode == SD_CACHE_DBCACHE ||
|
||||
cache_params.mode == SD_CACHE_CACHE_DIT);
|
||||
dbcfg.Fn_compute_blocks = cache_params.Fn_compute_blocks;
|
||||
dbcfg.Bn_compute_blocks = cache_params.Bn_compute_blocks;
|
||||
dbcfg.residual_diff_threshold = cache_params.residual_diff_threshold;
|
||||
dbcfg.max_warmup_steps = cache_params.max_warmup_steps;
|
||||
dbcfg.max_cached_steps = cache_params.max_cached_steps;
|
||||
dbcfg.max_continuous_cached_steps = cache_params.max_continuous_cached_steps;
|
||||
if (cache_params.scm_mask != nullptr && strlen(cache_params.scm_mask) > 0) {
|
||||
dbcfg.steps_computation_mask = parse_scm_mask(cache_params.scm_mask);
|
||||
}
|
||||
dbcfg.scm_policy_dynamic = cache_params.scm_policy_dynamic;
|
||||
|
||||
TaylorSeerConfig tcfg;
|
||||
tcfg.enabled = (cache_params.mode == SD_CACHE_TAYLORSEER ||
|
||||
cache_params.mode == SD_CACHE_CACHE_DIT);
|
||||
tcfg.n_derivatives = cache_params.taylorseer_n_derivatives;
|
||||
tcfg.skip_interval_steps = cache_params.taylorseer_skip_interval;
|
||||
|
||||
runtime.cachedit.init(dbcfg, tcfg);
|
||||
if (!runtime.cachedit.enabled()) {
|
||||
LOG_WARN("CacheDIT requested but could not be initialized for this run");
|
||||
return;
|
||||
}
|
||||
|
||||
runtime.cachedit.set_sigmas(sigmas);
|
||||
runtime.mode = SampleCacheMode::CACHEDIT;
|
||||
LOG_INFO("CacheDIT enabled - mode: %s, Fn: %d, Bn: %d, threshold: %.3f, warmup: %d",
|
||||
cache_params.mode == SD_CACHE_CACHE_DIT ? "DBCache+TaylorSeer" : (cache_params.mode == SD_CACHE_DBCACHE ? "DBCache" : "TaylorSeer"),
|
||||
dbcfg.Fn_compute_blocks,
|
||||
dbcfg.Bn_compute_blocks,
|
||||
dbcfg.residual_diff_threshold,
|
||||
dbcfg.max_warmup_steps);
|
||||
}
|
||||
|
||||
static void init_spectrum_runtime(SampleCacheRuntime& runtime,
|
||||
SDVersion version,
|
||||
const sd_cache_params_t& cache_params,
|
||||
const std::vector<float>& sigmas) {
|
||||
if (!sd_version_is_unet(version) && !sd_version_is_dit(version)) {
|
||||
LOG_WARN("Spectrum requested but not supported for this model type (only UNET and DiT models)");
|
||||
return;
|
||||
}
|
||||
|
||||
SpectrumConfig config;
|
||||
config.w = cache_params.spectrum_w;
|
||||
config.m = cache_params.spectrum_m;
|
||||
config.lam = cache_params.spectrum_lam;
|
||||
config.window_size = cache_params.spectrum_window_size;
|
||||
config.flex_window = cache_params.spectrum_flex_window;
|
||||
config.warmup_steps = cache_params.spectrum_warmup_steps;
|
||||
config.stop_percent = cache_params.spectrum_stop_percent;
|
||||
|
||||
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
|
||||
runtime.spectrum.init(config, total_steps);
|
||||
runtime.spectrum_enabled = true;
|
||||
|
||||
LOG_INFO("Spectrum enabled - w: %.2f, m: %d, lam: %.2f, window: %d, flex: %.2f, warmup: %d, stop: %.0f%%",
|
||||
config.w, config.m, config.lam,
|
||||
config.window_size, config.flex_window,
|
||||
config.warmup_steps, config.stop_percent * 100.0f);
|
||||
}
|
||||
|
||||
static SampleCacheRuntime init_sample_cache_runtime(SDVersion version,
|
||||
const sd_cache_params_t* cache_params,
|
||||
Denoiser* denoiser,
|
||||
const std::vector<float>& sigmas) {
|
||||
SampleCacheRuntime runtime;
|
||||
if (cache_params == nullptr || cache_params->mode == SD_CACHE_DISABLED) {
|
||||
return runtime;
|
||||
}
|
||||
|
||||
if (!has_valid_cache_percent_range(*cache_params)) {
|
||||
LOG_WARN("Cache disabled due to invalid percent range (start=%.3f, end=%.3f)",
|
||||
cache_params->start_percent,
|
||||
cache_params->end_percent);
|
||||
return runtime;
|
||||
}
|
||||
|
||||
switch (cache_params->mode) {
|
||||
case SD_CACHE_EASYCACHE:
|
||||
init_easycache_runtime(runtime, version, *cache_params, denoiser);
|
||||
break;
|
||||
case SD_CACHE_UCACHE:
|
||||
init_ucache_runtime(runtime, version, *cache_params, denoiser, sigmas);
|
||||
break;
|
||||
case SD_CACHE_DBCACHE:
|
||||
case SD_CACHE_TAYLORSEER:
|
||||
case SD_CACHE_CACHE_DIT:
|
||||
init_cachedit_runtime(runtime, version, *cache_params, sigmas);
|
||||
break;
|
||||
case SD_CACHE_SPECTRUM:
|
||||
init_spectrum_runtime(runtime, version, *cache_params, sigmas);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return runtime;
|
||||
}
|
||||
|
||||
struct SampleStepCacheDispatcher {
|
||||
SampleCacheRuntime& runtime;
|
||||
int step;
|
||||
float sigma;
|
||||
int step_index;
|
||||
|
||||
SampleStepCacheDispatcher(SampleCacheRuntime& runtime, int step, float sigma)
|
||||
: runtime(runtime), step(step), sigma(sigma), step_index(step > 0 ? (step - 1) : -1) {
|
||||
if (step_index < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (runtime.mode) {
|
||||
case SampleCacheMode::EASYCACHE:
|
||||
runtime.easycache.begin_step(step_index, sigma);
|
||||
break;
|
||||
case SampleCacheMode::UCACHE:
|
||||
runtime.ucache.begin_step(step_index, sigma);
|
||||
break;
|
||||
case SampleCacheMode::CACHEDIT:
|
||||
runtime.cachedit.begin_step(step_index, sigma);
|
||||
break;
|
||||
case SampleCacheMode::NONE:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool before_condition(const SDCondition* condition, ggml_tensor* input, ggml_tensor* output) {
|
||||
if (step_index < 0 || condition == nullptr || input == nullptr || output == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (runtime.mode) {
|
||||
case SampleCacheMode::EASYCACHE:
|
||||
return runtime.easycache.before_condition(condition, input, output, sigma, step_index);
|
||||
case SampleCacheMode::UCACHE:
|
||||
return runtime.ucache.before_condition(condition, input, output, sigma, step_index);
|
||||
case SampleCacheMode::CACHEDIT:
|
||||
return runtime.cachedit.before_condition(condition, input, output, sigma, step_index);
|
||||
case SampleCacheMode::NONE:
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void after_condition(const SDCondition* condition, ggml_tensor* input, ggml_tensor* output) {
|
||||
if (step_index < 0 || condition == nullptr || input == nullptr || output == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
switch (runtime.mode) {
|
||||
case SampleCacheMode::EASYCACHE:
|
||||
runtime.easycache.after_condition(condition, input, output);
|
||||
break;
|
||||
case SampleCacheMode::UCACHE:
|
||||
runtime.ucache.after_condition(condition, input, output);
|
||||
break;
|
||||
case SampleCacheMode::CACHEDIT:
|
||||
runtime.cachedit.after_condition(condition, input, output);
|
||||
break;
|
||||
case SampleCacheMode::NONE:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
bool is_step_skipped() const {
|
||||
switch (runtime.mode) {
|
||||
case SampleCacheMode::EASYCACHE:
|
||||
return runtime.easycache.is_step_skipped();
|
||||
case SampleCacheMode::UCACHE:
|
||||
return runtime.ucache.is_step_skipped();
|
||||
case SampleCacheMode::CACHEDIT:
|
||||
return runtime.cachedit.is_step_skipped();
|
||||
case SampleCacheMode::NONE:
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
static void log_sample_cache_summary(const SampleCacheRuntime& runtime, size_t total_steps) {
|
||||
if (runtime.easycache_enabled()) {
|
||||
if (runtime.easycache.total_steps_skipped > 0 && total_steps > 0) {
|
||||
if (runtime.easycache.total_steps_skipped < static_cast<int>(total_steps)) {
|
||||
double speedup = static_cast<double>(total_steps) /
|
||||
static_cast<double>(total_steps - runtime.easycache.total_steps_skipped);
|
||||
LOG_INFO("EasyCache skipped %d/%zu steps (%.2fx estimated speedup)",
|
||||
runtime.easycache.total_steps_skipped,
|
||||
total_steps,
|
||||
speedup);
|
||||
} else {
|
||||
LOG_INFO("EasyCache skipped %d/%zu steps",
|
||||
runtime.easycache.total_steps_skipped,
|
||||
total_steps);
|
||||
}
|
||||
} else if (total_steps > 0) {
|
||||
LOG_INFO("EasyCache completed without skipping steps");
|
||||
}
|
||||
}
|
||||
|
||||
if (runtime.ucache_enabled()) {
|
||||
if (runtime.ucache.total_steps_skipped > 0 && total_steps > 0) {
|
||||
if (runtime.ucache.total_steps_skipped < static_cast<int>(total_steps)) {
|
||||
double speedup = static_cast<double>(total_steps) /
|
||||
static_cast<double>(total_steps - runtime.ucache.total_steps_skipped);
|
||||
LOG_INFO("UCache skipped %d/%zu steps (%.2fx estimated speedup)",
|
||||
runtime.ucache.total_steps_skipped,
|
||||
total_steps,
|
||||
speedup);
|
||||
} else {
|
||||
LOG_INFO("UCache skipped %d/%zu steps",
|
||||
runtime.ucache.total_steps_skipped,
|
||||
total_steps);
|
||||
}
|
||||
} else if (total_steps > 0) {
|
||||
LOG_INFO("UCache completed without skipping steps");
|
||||
}
|
||||
}
|
||||
|
||||
if (runtime.cachedit_enabled()) {
|
||||
if (runtime.cachedit.total_steps_skipped > 0 && total_steps > 0) {
|
||||
if (runtime.cachedit.total_steps_skipped < static_cast<int>(total_steps)) {
|
||||
double speedup = static_cast<double>(total_steps) /
|
||||
static_cast<double>(total_steps - runtime.cachedit.total_steps_skipped);
|
||||
LOG_INFO("CacheDIT skipped %d/%zu steps (%.2fx estimated speedup), accum_diff: %.4f",
|
||||
runtime.cachedit.total_steps_skipped,
|
||||
total_steps,
|
||||
speedup,
|
||||
runtime.cachedit.accumulated_residual_diff);
|
||||
} else {
|
||||
LOG_INFO("CacheDIT skipped %d/%zu steps, accum_diff: %.4f",
|
||||
runtime.cachedit.total_steps_skipped,
|
||||
total_steps,
|
||||
runtime.cachedit.accumulated_residual_diff);
|
||||
}
|
||||
} else if (total_steps > 0) {
|
||||
LOG_INFO("CacheDIT completed without skipping steps");
|
||||
}
|
||||
}
|
||||
|
||||
if (runtime.spectrum_enabled && runtime.spectrum.total_steps_skipped > 0 && total_steps > 0) {
|
||||
double speedup = static_cast<double>(total_steps) /
|
||||
static_cast<double>(total_steps - runtime.spectrum.total_steps_skipped);
|
||||
LOG_INFO("Spectrum skipped %d/%zu steps (%.2fx estimated speedup)",
|
||||
runtime.spectrum.total_steps_skipped,
|
||||
total_steps,
|
||||
speedup);
|
||||
}
|
||||
}
|
||||
|
||||
/*=============================================== StableDiffusionGGML ================================================*/
|
||||
|
||||
class StableDiffusionGGML {
|
||||
@ -1662,148 +2035,7 @@ public:
|
||||
img_cfg_scale = cfg_scale;
|
||||
}
|
||||
|
||||
EasyCacheState easycache_state;
|
||||
UCacheState ucache_state;
|
||||
CacheDitConditionState cachedit_state;
|
||||
SpectrumState spectrum_state;
|
||||
bool easycache_enabled = false;
|
||||
bool ucache_enabled = false;
|
||||
bool cachedit_enabled = false;
|
||||
bool spectrum_enabled = false;
|
||||
|
||||
if (cache_params != nullptr && cache_params->mode != SD_CACHE_DISABLED) {
|
||||
bool percent_valid = true;
|
||||
if (cache_params->mode == SD_CACHE_EASYCACHE || cache_params->mode == SD_CACHE_UCACHE) {
|
||||
percent_valid = cache_params->start_percent >= 0.0f &&
|
||||
cache_params->start_percent < 1.0f &&
|
||||
cache_params->end_percent > 0.0f &&
|
||||
cache_params->end_percent <= 1.0f &&
|
||||
cache_params->start_percent < cache_params->end_percent;
|
||||
}
|
||||
|
||||
if (!percent_valid) {
|
||||
LOG_WARN("Cache disabled due to invalid percent range (start=%.3f, end=%.3f)",
|
||||
cache_params->start_percent,
|
||||
cache_params->end_percent);
|
||||
} else if (cache_params->mode == SD_CACHE_EASYCACHE) {
|
||||
bool easycache_supported = sd_version_is_dit(version);
|
||||
if (!easycache_supported) {
|
||||
LOG_WARN("EasyCache requested but not supported for this model type");
|
||||
} else {
|
||||
EasyCacheConfig easycache_config;
|
||||
easycache_config.enabled = true;
|
||||
easycache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
|
||||
easycache_config.start_percent = cache_params->start_percent;
|
||||
easycache_config.end_percent = cache_params->end_percent;
|
||||
easycache_state.init(easycache_config, denoiser.get());
|
||||
if (easycache_state.enabled()) {
|
||||
easycache_enabled = true;
|
||||
LOG_INFO("EasyCache enabled - threshold: %.3f, start: %.2f, end: %.2f",
|
||||
easycache_config.reuse_threshold,
|
||||
easycache_config.start_percent,
|
||||
easycache_config.end_percent);
|
||||
} else {
|
||||
LOG_WARN("EasyCache requested but could not be initialized for this run");
|
||||
}
|
||||
}
|
||||
} else if (cache_params->mode == SD_CACHE_UCACHE) {
|
||||
bool ucache_supported = sd_version_is_unet(version);
|
||||
if (!ucache_supported) {
|
||||
LOG_WARN("UCache requested but not supported for this model type (only UNET models)");
|
||||
} else {
|
||||
UCacheConfig ucache_config;
|
||||
ucache_config.enabled = true;
|
||||
ucache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
|
||||
ucache_config.start_percent = cache_params->start_percent;
|
||||
ucache_config.end_percent = cache_params->end_percent;
|
||||
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate));
|
||||
ucache_config.use_relative_threshold = cache_params->use_relative_threshold;
|
||||
ucache_config.reset_error_on_compute = cache_params->reset_error_on_compute;
|
||||
ucache_state.init(ucache_config, denoiser.get());
|
||||
if (ucache_state.enabled()) {
|
||||
ucache_enabled = true;
|
||||
LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s, reset: %s",
|
||||
ucache_config.reuse_threshold,
|
||||
ucache_config.start_percent,
|
||||
ucache_config.end_percent,
|
||||
ucache_config.error_decay_rate,
|
||||
ucache_config.use_relative_threshold ? "true" : "false",
|
||||
ucache_config.reset_error_on_compute ? "true" : "false");
|
||||
} else {
|
||||
LOG_WARN("UCache requested but could not be initialized for this run");
|
||||
}
|
||||
}
|
||||
} else if (cache_params->mode == SD_CACHE_DBCACHE ||
|
||||
cache_params->mode == SD_CACHE_TAYLORSEER ||
|
||||
cache_params->mode == SD_CACHE_CACHE_DIT) {
|
||||
bool cachedit_supported = sd_version_is_dit(version);
|
||||
if (!cachedit_supported) {
|
||||
LOG_WARN("CacheDIT requested but not supported for this model type (only DiT models)");
|
||||
} else {
|
||||
DBCacheConfig dbcfg;
|
||||
dbcfg.enabled = (cache_params->mode == SD_CACHE_DBCACHE ||
|
||||
cache_params->mode == SD_CACHE_CACHE_DIT);
|
||||
dbcfg.Fn_compute_blocks = cache_params->Fn_compute_blocks;
|
||||
dbcfg.Bn_compute_blocks = cache_params->Bn_compute_blocks;
|
||||
dbcfg.residual_diff_threshold = cache_params->residual_diff_threshold;
|
||||
dbcfg.max_warmup_steps = cache_params->max_warmup_steps;
|
||||
dbcfg.max_cached_steps = cache_params->max_cached_steps;
|
||||
dbcfg.max_continuous_cached_steps = cache_params->max_continuous_cached_steps;
|
||||
if (cache_params->scm_mask != nullptr && strlen(cache_params->scm_mask) > 0) {
|
||||
dbcfg.steps_computation_mask = parse_scm_mask(cache_params->scm_mask);
|
||||
}
|
||||
dbcfg.scm_policy_dynamic = cache_params->scm_policy_dynamic;
|
||||
|
||||
TaylorSeerConfig tcfg;
|
||||
tcfg.enabled = (cache_params->mode == SD_CACHE_TAYLORSEER ||
|
||||
cache_params->mode == SD_CACHE_CACHE_DIT);
|
||||
tcfg.n_derivatives = cache_params->taylorseer_n_derivatives;
|
||||
tcfg.skip_interval_steps = cache_params->taylorseer_skip_interval;
|
||||
|
||||
cachedit_state.init(dbcfg, tcfg);
|
||||
if (cachedit_state.enabled()) {
|
||||
cachedit_enabled = true;
|
||||
LOG_INFO("CacheDIT enabled - mode: %s, Fn: %d, Bn: %d, threshold: %.3f, warmup: %d",
|
||||
cache_params->mode == SD_CACHE_CACHE_DIT ? "DBCache+TaylorSeer" : (cache_params->mode == SD_CACHE_DBCACHE ? "DBCache" : "TaylorSeer"),
|
||||
dbcfg.Fn_compute_blocks,
|
||||
dbcfg.Bn_compute_blocks,
|
||||
dbcfg.residual_diff_threshold,
|
||||
dbcfg.max_warmup_steps);
|
||||
} else {
|
||||
LOG_WARN("CacheDIT requested but could not be initialized for this run");
|
||||
}
|
||||
}
|
||||
} else if (cache_params->mode == SD_CACHE_SPECTRUM) {
|
||||
bool spectrum_supported = sd_version_is_unet(version) || sd_version_is_dit(version);
|
||||
if (!spectrum_supported) {
|
||||
LOG_WARN("Spectrum requested but not supported for this model type (only UNET and DiT models)");
|
||||
} else {
|
||||
SpectrumConfig spectrum_config;
|
||||
spectrum_config.w = cache_params->spectrum_w;
|
||||
spectrum_config.m = cache_params->spectrum_m;
|
||||
spectrum_config.lam = cache_params->spectrum_lam;
|
||||
spectrum_config.window_size = cache_params->spectrum_window_size;
|
||||
spectrum_config.flex_window = cache_params->spectrum_flex_window;
|
||||
spectrum_config.warmup_steps = cache_params->spectrum_warmup_steps;
|
||||
spectrum_config.stop_percent = cache_params->spectrum_stop_percent;
|
||||
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
|
||||
spectrum_state.init(spectrum_config, total_steps);
|
||||
spectrum_enabled = true;
|
||||
LOG_INFO("Spectrum enabled - w: %.2f, m: %d, lam: %.2f, window: %d, flex: %.2f, warmup: %d, stop: %.0f%%",
|
||||
spectrum_config.w, spectrum_config.m, spectrum_config.lam,
|
||||
spectrum_config.window_size, spectrum_config.flex_window,
|
||||
spectrum_config.warmup_steps, spectrum_config.stop_percent * 100.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ucache_enabled) {
|
||||
ucache_state.set_sigmas(sigmas);
|
||||
}
|
||||
|
||||
if (cachedit_enabled) {
|
||||
cachedit_state.set_sigmas(sigmas);
|
||||
}
|
||||
SampleCacheRuntime cache_runtime = init_sample_cache_runtime(version, cache_params, denoiser.get(), sigmas);
|
||||
|
||||
size_t steps = sigmas.size() - 1;
|
||||
ggml_tensor* x = ggml_ext_dup_and_cpy_tensor(work_ctx, init_latent);
|
||||
@ -1876,121 +2108,7 @@ public:
|
||||
}
|
||||
|
||||
DiffusionParams diffusion_params;
|
||||
|
||||
const bool easycache_step_active = easycache_enabled && step > 0;
|
||||
int easycache_step_index = easycache_step_active ? (step - 1) : -1;
|
||||
if (easycache_step_active) {
|
||||
easycache_state.begin_step(easycache_step_index, sigma);
|
||||
}
|
||||
|
||||
auto easycache_before_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) -> bool {
|
||||
if (!easycache_step_active || condition == nullptr || output_tensor == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return easycache_state.before_condition(condition,
|
||||
diffusion_params.x,
|
||||
output_tensor,
|
||||
sigma,
|
||||
easycache_step_index);
|
||||
};
|
||||
|
||||
auto easycache_after_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) {
|
||||
if (!easycache_step_active || condition == nullptr || output_tensor == nullptr) {
|
||||
return;
|
||||
}
|
||||
easycache_state.after_condition(condition,
|
||||
diffusion_params.x,
|
||||
output_tensor);
|
||||
};
|
||||
|
||||
auto easycache_step_is_skipped = [&]() {
|
||||
return easycache_step_active && easycache_state.is_step_skipped();
|
||||
};
|
||||
|
||||
const bool ucache_step_active = ucache_enabled && step > 0;
|
||||
int ucache_step_index = ucache_step_active ? (step - 1) : -1;
|
||||
if (ucache_step_active) {
|
||||
ucache_state.begin_step(ucache_step_index, sigma);
|
||||
}
|
||||
|
||||
auto ucache_before_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) -> bool {
|
||||
if (!ucache_step_active || condition == nullptr || output_tensor == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return ucache_state.before_condition(condition,
|
||||
diffusion_params.x,
|
||||
output_tensor,
|
||||
sigma,
|
||||
ucache_step_index);
|
||||
};
|
||||
|
||||
auto ucache_after_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) {
|
||||
if (!ucache_step_active || condition == nullptr || output_tensor == nullptr) {
|
||||
return;
|
||||
}
|
||||
ucache_state.after_condition(condition,
|
||||
diffusion_params.x,
|
||||
output_tensor);
|
||||
};
|
||||
|
||||
auto ucache_step_is_skipped = [&]() {
|
||||
return ucache_step_active && ucache_state.is_step_skipped();
|
||||
};
|
||||
|
||||
const bool cachedit_step_active = cachedit_enabled && step > 0;
|
||||
int cachedit_step_index = cachedit_step_active ? (step - 1) : -1;
|
||||
if (cachedit_step_active) {
|
||||
cachedit_state.begin_step(cachedit_step_index, sigma);
|
||||
}
|
||||
|
||||
auto cachedit_before_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) -> bool {
|
||||
if (!cachedit_step_active || condition == nullptr || output_tensor == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return cachedit_state.before_condition(condition,
|
||||
diffusion_params.x,
|
||||
output_tensor,
|
||||
sigma,
|
||||
cachedit_step_index);
|
||||
};
|
||||
|
||||
auto cachedit_after_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) {
|
||||
if (!cachedit_step_active || condition == nullptr || output_tensor == nullptr) {
|
||||
return;
|
||||
}
|
||||
cachedit_state.after_condition(condition,
|
||||
diffusion_params.x,
|
||||
output_tensor);
|
||||
};
|
||||
|
||||
auto cachedit_step_is_skipped = [&]() {
|
||||
return cachedit_step_active && cachedit_state.is_step_skipped();
|
||||
};
|
||||
|
||||
auto cache_before_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) -> bool {
|
||||
if (easycache_step_active) {
|
||||
return easycache_before_condition(condition, output_tensor);
|
||||
} else if (ucache_step_active) {
|
||||
return ucache_before_condition(condition, output_tensor);
|
||||
} else if (cachedit_step_active) {
|
||||
return cachedit_before_condition(condition, output_tensor);
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
auto cache_after_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) {
|
||||
if (easycache_step_active) {
|
||||
easycache_after_condition(condition, output_tensor);
|
||||
} else if (ucache_step_active) {
|
||||
ucache_after_condition(condition, output_tensor);
|
||||
} else if (cachedit_step_active) {
|
||||
cachedit_after_condition(condition, output_tensor);
|
||||
}
|
||||
};
|
||||
|
||||
auto cache_step_is_skipped = [&]() {
|
||||
return easycache_step_is_skipped() || ucache_step_is_skipped() || cachedit_step_is_skipped();
|
||||
};
|
||||
SampleStepCacheDispatcher step_cache(cache_runtime, step, sigma);
|
||||
|
||||
std::vector<float> scaling = denoiser->get_scalings(sigma);
|
||||
GGML_ASSERT(scaling.size() == 3);
|
||||
@ -2017,8 +2135,8 @@ public:
|
||||
|
||||
timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask);
|
||||
|
||||
if (spectrum_enabled && spectrum_state.should_predict()) {
|
||||
spectrum_state.predict(denoised);
|
||||
if (cache_runtime.spectrum_enabled && cache_runtime.spectrum.should_predict()) {
|
||||
cache_runtime.spectrum.predict(denoised);
|
||||
|
||||
if (denoise_mask != nullptr) {
|
||||
apply_mask(denoised, init_latent, denoise_mask);
|
||||
@ -2077,6 +2195,22 @@ public:
|
||||
diffusion_params.vace_context = vace_context;
|
||||
diffusion_params.vace_strength = vace_strength;
|
||||
|
||||
auto run_diffusion_condition = [&](const SDCondition* condition, ggml_tensor** output_tensor) -> bool {
|
||||
if (step_cache.before_condition(condition, diffusion_params.x, *output_tensor)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
output_tensor)) {
|
||||
LOG_ERROR("diffusion model compute failed");
|
||||
return false;
|
||||
}
|
||||
|
||||
step_cache.after_condition(condition, diffusion_params.x, *output_tensor);
|
||||
return true;
|
||||
};
|
||||
|
||||
const SDCondition* active_condition = nullptr;
|
||||
ggml_tensor** active_output = &out_cond;
|
||||
if (start_merge_step == -1 || step <= start_merge_step) {
|
||||
@ -2092,18 +2226,11 @@ public:
|
||||
active_condition = &id_cond;
|
||||
}
|
||||
|
||||
bool skip_model = cache_before_condition(active_condition, *active_output);
|
||||
if (!skip_model) {
|
||||
if (!work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
active_output)) {
|
||||
LOG_ERROR("diffusion model compute failed");
|
||||
return nullptr;
|
||||
}
|
||||
cache_after_condition(active_condition, *active_output);
|
||||
if (!run_diffusion_condition(active_condition, active_output)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool current_step_skipped = cache_step_is_skipped();
|
||||
bool current_step_skipped = step_cache.is_step_skipped();
|
||||
|
||||
float* negative_data = nullptr;
|
||||
if (has_unconditioned) {
|
||||
@ -2115,20 +2242,13 @@ public:
|
||||
LOG_ERROR("controlnet compute failed");
|
||||
}
|
||||
}
|
||||
current_step_skipped = cache_step_is_skipped();
|
||||
current_step_skipped = step_cache.is_step_skipped();
|
||||
diffusion_params.controls = controls;
|
||||
diffusion_params.context = uncond.c_crossattn;
|
||||
diffusion_params.c_concat = uncond.c_concat;
|
||||
diffusion_params.y = uncond.c_vector;
|
||||
bool skip_uncond = cache_before_condition(&uncond, out_uncond);
|
||||
if (!skip_uncond) {
|
||||
if (!work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
&out_uncond)) {
|
||||
LOG_ERROR("diffusion model compute failed");
|
||||
return nullptr;
|
||||
}
|
||||
cache_after_condition(&uncond, out_uncond);
|
||||
if (!run_diffusion_condition(&uncond, &out_uncond)) {
|
||||
return nullptr;
|
||||
}
|
||||
negative_data = (float*)out_uncond->data;
|
||||
}
|
||||
@ -2138,15 +2258,8 @@ public:
|
||||
diffusion_params.context = img_cond.c_crossattn;
|
||||
diffusion_params.c_concat = img_cond.c_concat;
|
||||
diffusion_params.y = img_cond.c_vector;
|
||||
bool skip_img_cond = cache_before_condition(&img_cond, out_img_cond);
|
||||
if (!skip_img_cond) {
|
||||
if (!work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
&out_img_cond)) {
|
||||
LOG_ERROR("diffusion model compute failed");
|
||||
return nullptr;
|
||||
}
|
||||
cache_after_condition(&img_cond, out_img_cond);
|
||||
if (!run_diffusion_condition(&img_cond, &out_img_cond)) {
|
||||
return nullptr;
|
||||
}
|
||||
img_cond_data = (float*)out_img_cond->data;
|
||||
}
|
||||
@ -2156,7 +2269,7 @@ public:
|
||||
float* skip_layer_data = has_skiplayer ? (float*)out_skip->data : nullptr;
|
||||
if (is_skiplayer_step) {
|
||||
LOG_DEBUG("Skipping layers at step %d\n", step);
|
||||
if (!cache_step_is_skipped()) {
|
||||
if (!step_cache.is_step_skipped()) {
|
||||
// skip layer (same as conditioned)
|
||||
diffusion_params.context = cond.c_crossattn;
|
||||
diffusion_params.c_concat = cond.c_concat;
|
||||
@ -2211,8 +2324,8 @@ public:
|
||||
vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip;
|
||||
}
|
||||
|
||||
if (spectrum_enabled) {
|
||||
spectrum_state.update(denoised);
|
||||
if (cache_runtime.spectrum_enabled) {
|
||||
cache_runtime.spectrum.update(denoised);
|
||||
}
|
||||
|
||||
if (denoise_mask != nullptr) {
|
||||
@ -2244,75 +2357,8 @@ public:
|
||||
return NULL;
|
||||
}
|
||||
|
||||
if (easycache_enabled) {
|
||||
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
|
||||
if (easycache_state.total_steps_skipped > 0 && total_steps > 0) {
|
||||
if (easycache_state.total_steps_skipped < static_cast<int>(total_steps)) {
|
||||
double speedup = static_cast<double>(total_steps) /
|
||||
static_cast<double>(total_steps - easycache_state.total_steps_skipped);
|
||||
LOG_INFO("EasyCache skipped %d/%zu steps (%.2fx estimated speedup)",
|
||||
easycache_state.total_steps_skipped,
|
||||
total_steps,
|
||||
speedup);
|
||||
} else {
|
||||
LOG_INFO("EasyCache skipped %d/%zu steps",
|
||||
easycache_state.total_steps_skipped,
|
||||
total_steps);
|
||||
}
|
||||
} else if (total_steps > 0) {
|
||||
LOG_INFO("EasyCache completed without skipping steps");
|
||||
}
|
||||
}
|
||||
|
||||
if (ucache_enabled) {
|
||||
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
|
||||
if (ucache_state.total_steps_skipped > 0 && total_steps > 0) {
|
||||
if (ucache_state.total_steps_skipped < static_cast<int>(total_steps)) {
|
||||
double speedup = static_cast<double>(total_steps) /
|
||||
static_cast<double>(total_steps - ucache_state.total_steps_skipped);
|
||||
LOG_INFO("UCache skipped %d/%zu steps (%.2fx estimated speedup)",
|
||||
ucache_state.total_steps_skipped,
|
||||
total_steps,
|
||||
speedup);
|
||||
} else {
|
||||
LOG_INFO("UCache skipped %d/%zu steps",
|
||||
ucache_state.total_steps_skipped,
|
||||
total_steps);
|
||||
}
|
||||
} else if (total_steps > 0) {
|
||||
LOG_INFO("UCache completed without skipping steps");
|
||||
}
|
||||
}
|
||||
|
||||
if (cachedit_enabled) {
|
||||
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
|
||||
if (cachedit_state.total_steps_skipped > 0 && total_steps > 0) {
|
||||
if (cachedit_state.total_steps_skipped < static_cast<int>(total_steps)) {
|
||||
double speedup = static_cast<double>(total_steps) /
|
||||
static_cast<double>(total_steps - cachedit_state.total_steps_skipped);
|
||||
LOG_INFO("CacheDIT skipped %d/%zu steps (%.2fx estimated speedup), accum_diff: %.4f",
|
||||
cachedit_state.total_steps_skipped,
|
||||
total_steps,
|
||||
speedup,
|
||||
cachedit_state.accumulated_residual_diff);
|
||||
} else {
|
||||
LOG_INFO("CacheDIT skipped %d/%zu steps, accum_diff: %.4f",
|
||||
cachedit_state.total_steps_skipped,
|
||||
total_steps,
|
||||
cachedit_state.accumulated_residual_diff);
|
||||
}
|
||||
} else if (total_steps > 0) {
|
||||
LOG_INFO("CacheDIT completed without skipping steps");
|
||||
}
|
||||
}
|
||||
|
||||
if (spectrum_enabled && spectrum_state.total_steps_skipped > 0) {
|
||||
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
|
||||
double speedup = static_cast<double>(total_steps) /
|
||||
static_cast<double>(total_steps - spectrum_state.total_steps_skipped);
|
||||
LOG_INFO("Spectrum skipped %d/%zu steps (%.2fx estimated speedup)",
|
||||
spectrum_state.total_steps_skipped, total_steps, speedup);
|
||||
}
|
||||
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
|
||||
log_sample_cache_summary(cache_runtime, total_steps);
|
||||
|
||||
if (inverse_noise_scaling) {
|
||||
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user