refactor: simplify sample cache flow (#1350)

This commit is contained in:
leejet 2026-03-17 00:28:03 +08:00 committed by GitHub
parent 5265a5efa1
commit 545fac4f3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -103,6 +103,379 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) {
return std::max(0.0f, reuse_threshold);
}
enum class SampleCacheMode {
NONE,
EASYCACHE,
UCACHE,
CACHEDIT,
};
struct SampleCacheRuntime {
SampleCacheMode mode = SampleCacheMode::NONE;
EasyCacheState easycache;
UCacheState ucache;
CacheDitConditionState cachedit;
SpectrumState spectrum;
bool spectrum_enabled = false;
bool has_step_cache() const {
return mode != SampleCacheMode::NONE;
}
bool easycache_enabled() const {
return mode == SampleCacheMode::EASYCACHE;
}
bool ucache_enabled() const {
return mode == SampleCacheMode::UCACHE;
}
bool cachedit_enabled() const {
return mode == SampleCacheMode::CACHEDIT;
}
};
static bool has_valid_cache_percent_range(const sd_cache_params_t& cache_params) {
if (cache_params.mode != SD_CACHE_EASYCACHE && cache_params.mode != SD_CACHE_UCACHE) {
return true;
}
return cache_params.start_percent >= 0.0f &&
cache_params.start_percent < 1.0f &&
cache_params.end_percent > 0.0f &&
cache_params.end_percent <= 1.0f &&
cache_params.start_percent < cache_params.end_percent;
}
static void init_easycache_runtime(SampleCacheRuntime& runtime,
SDVersion version,
const sd_cache_params_t& cache_params,
Denoiser* denoiser) {
if (!sd_version_is_dit(version)) {
LOG_WARN("EasyCache requested but not supported for this model type");
return;
}
EasyCacheConfig config;
config.enabled = true;
config.reuse_threshold = get_cache_reuse_threshold(cache_params);
config.start_percent = cache_params.start_percent;
config.end_percent = cache_params.end_percent;
runtime.easycache.init(config, denoiser);
if (!runtime.easycache.enabled()) {
LOG_WARN("EasyCache requested but could not be initialized for this run");
return;
}
runtime.mode = SampleCacheMode::EASYCACHE;
LOG_INFO("EasyCache enabled - threshold: %.3f, start: %.2f, end: %.2f",
config.reuse_threshold,
config.start_percent,
config.end_percent);
}
static void init_ucache_runtime(SampleCacheRuntime& runtime,
SDVersion version,
const sd_cache_params_t& cache_params,
Denoiser* denoiser,
const std::vector<float>& sigmas) {
if (!sd_version_is_unet(version)) {
LOG_WARN("UCache requested but not supported for this model type (only UNET models)");
return;
}
UCacheConfig config;
config.enabled = true;
config.reuse_threshold = get_cache_reuse_threshold(cache_params);
config.start_percent = cache_params.start_percent;
config.end_percent = cache_params.end_percent;
config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params.error_decay_rate));
config.use_relative_threshold = cache_params.use_relative_threshold;
config.reset_error_on_compute = cache_params.reset_error_on_compute;
runtime.ucache.init(config, denoiser);
if (!runtime.ucache.enabled()) {
LOG_WARN("UCache requested but could not be initialized for this run");
return;
}
runtime.ucache.set_sigmas(sigmas);
runtime.mode = SampleCacheMode::UCACHE;
LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s, reset: %s",
config.reuse_threshold,
config.start_percent,
config.end_percent,
config.error_decay_rate,
config.use_relative_threshold ? "true" : "false",
config.reset_error_on_compute ? "true" : "false");
}
static void init_cachedit_runtime(SampleCacheRuntime& runtime,
SDVersion version,
const sd_cache_params_t& cache_params,
const std::vector<float>& sigmas) {
if (!sd_version_is_dit(version)) {
LOG_WARN("CacheDIT requested but not supported for this model type (only DiT models)");
return;
}
DBCacheConfig dbcfg;
dbcfg.enabled = (cache_params.mode == SD_CACHE_DBCACHE ||
cache_params.mode == SD_CACHE_CACHE_DIT);
dbcfg.Fn_compute_blocks = cache_params.Fn_compute_blocks;
dbcfg.Bn_compute_blocks = cache_params.Bn_compute_blocks;
dbcfg.residual_diff_threshold = cache_params.residual_diff_threshold;
dbcfg.max_warmup_steps = cache_params.max_warmup_steps;
dbcfg.max_cached_steps = cache_params.max_cached_steps;
dbcfg.max_continuous_cached_steps = cache_params.max_continuous_cached_steps;
if (cache_params.scm_mask != nullptr && strlen(cache_params.scm_mask) > 0) {
dbcfg.steps_computation_mask = parse_scm_mask(cache_params.scm_mask);
}
dbcfg.scm_policy_dynamic = cache_params.scm_policy_dynamic;
TaylorSeerConfig tcfg;
tcfg.enabled = (cache_params.mode == SD_CACHE_TAYLORSEER ||
cache_params.mode == SD_CACHE_CACHE_DIT);
tcfg.n_derivatives = cache_params.taylorseer_n_derivatives;
tcfg.skip_interval_steps = cache_params.taylorseer_skip_interval;
runtime.cachedit.init(dbcfg, tcfg);
if (!runtime.cachedit.enabled()) {
LOG_WARN("CacheDIT requested but could not be initialized for this run");
return;
}
runtime.cachedit.set_sigmas(sigmas);
runtime.mode = SampleCacheMode::CACHEDIT;
LOG_INFO("CacheDIT enabled - mode: %s, Fn: %d, Bn: %d, threshold: %.3f, warmup: %d",
cache_params.mode == SD_CACHE_CACHE_DIT ? "DBCache+TaylorSeer" : (cache_params.mode == SD_CACHE_DBCACHE ? "DBCache" : "TaylorSeer"),
dbcfg.Fn_compute_blocks,
dbcfg.Bn_compute_blocks,
dbcfg.residual_diff_threshold,
dbcfg.max_warmup_steps);
}
static void init_spectrum_runtime(SampleCacheRuntime& runtime,
SDVersion version,
const sd_cache_params_t& cache_params,
const std::vector<float>& sigmas) {
if (!sd_version_is_unet(version) && !sd_version_is_dit(version)) {
LOG_WARN("Spectrum requested but not supported for this model type (only UNET and DiT models)");
return;
}
SpectrumConfig config;
config.w = cache_params.spectrum_w;
config.m = cache_params.spectrum_m;
config.lam = cache_params.spectrum_lam;
config.window_size = cache_params.spectrum_window_size;
config.flex_window = cache_params.spectrum_flex_window;
config.warmup_steps = cache_params.spectrum_warmup_steps;
config.stop_percent = cache_params.spectrum_stop_percent;
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
runtime.spectrum.init(config, total_steps);
runtime.spectrum_enabled = true;
LOG_INFO("Spectrum enabled - w: %.2f, m: %d, lam: %.2f, window: %d, flex: %.2f, warmup: %d, stop: %.0f%%",
config.w, config.m, config.lam,
config.window_size, config.flex_window,
config.warmup_steps, config.stop_percent * 100.0f);
}
static SampleCacheRuntime init_sample_cache_runtime(SDVersion version,
const sd_cache_params_t* cache_params,
Denoiser* denoiser,
const std::vector<float>& sigmas) {
SampleCacheRuntime runtime;
if (cache_params == nullptr || cache_params->mode == SD_CACHE_DISABLED) {
return runtime;
}
if (!has_valid_cache_percent_range(*cache_params)) {
LOG_WARN("Cache disabled due to invalid percent range (start=%.3f, end=%.3f)",
cache_params->start_percent,
cache_params->end_percent);
return runtime;
}
switch (cache_params->mode) {
case SD_CACHE_EASYCACHE:
init_easycache_runtime(runtime, version, *cache_params, denoiser);
break;
case SD_CACHE_UCACHE:
init_ucache_runtime(runtime, version, *cache_params, denoiser, sigmas);
break;
case SD_CACHE_DBCACHE:
case SD_CACHE_TAYLORSEER:
case SD_CACHE_CACHE_DIT:
init_cachedit_runtime(runtime, version, *cache_params, sigmas);
break;
case SD_CACHE_SPECTRUM:
init_spectrum_runtime(runtime, version, *cache_params, sigmas);
break;
default:
break;
}
return runtime;
}
struct SampleStepCacheDispatcher {
SampleCacheRuntime& runtime;
int step;
float sigma;
int step_index;
SampleStepCacheDispatcher(SampleCacheRuntime& runtime, int step, float sigma)
: runtime(runtime), step(step), sigma(sigma), step_index(step > 0 ? (step - 1) : -1) {
if (step_index < 0) {
return;
}
switch (runtime.mode) {
case SampleCacheMode::EASYCACHE:
runtime.easycache.begin_step(step_index, sigma);
break;
case SampleCacheMode::UCACHE:
runtime.ucache.begin_step(step_index, sigma);
break;
case SampleCacheMode::CACHEDIT:
runtime.cachedit.begin_step(step_index, sigma);
break;
case SampleCacheMode::NONE:
break;
}
}
bool before_condition(const SDCondition* condition, ggml_tensor* input, ggml_tensor* output) {
if (step_index < 0 || condition == nullptr || input == nullptr || output == nullptr) {
return false;
}
switch (runtime.mode) {
case SampleCacheMode::EASYCACHE:
return runtime.easycache.before_condition(condition, input, output, sigma, step_index);
case SampleCacheMode::UCACHE:
return runtime.ucache.before_condition(condition, input, output, sigma, step_index);
case SampleCacheMode::CACHEDIT:
return runtime.cachedit.before_condition(condition, input, output, sigma, step_index);
case SampleCacheMode::NONE:
return false;
}
return false;
}
void after_condition(const SDCondition* condition, ggml_tensor* input, ggml_tensor* output) {
if (step_index < 0 || condition == nullptr || input == nullptr || output == nullptr) {
return;
}
switch (runtime.mode) {
case SampleCacheMode::EASYCACHE:
runtime.easycache.after_condition(condition, input, output);
break;
case SampleCacheMode::UCACHE:
runtime.ucache.after_condition(condition, input, output);
break;
case SampleCacheMode::CACHEDIT:
runtime.cachedit.after_condition(condition, input, output);
break;
case SampleCacheMode::NONE:
break;
}
}
bool is_step_skipped() const {
switch (runtime.mode) {
case SampleCacheMode::EASYCACHE:
return runtime.easycache.is_step_skipped();
case SampleCacheMode::UCACHE:
return runtime.ucache.is_step_skipped();
case SampleCacheMode::CACHEDIT:
return runtime.cachedit.is_step_skipped();
case SampleCacheMode::NONE:
return false;
}
return false;
}
};
static void log_sample_cache_summary(const SampleCacheRuntime& runtime, size_t total_steps) {
if (runtime.easycache_enabled()) {
if (runtime.easycache.total_steps_skipped > 0 && total_steps > 0) {
if (runtime.easycache.total_steps_skipped < static_cast<int>(total_steps)) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - runtime.easycache.total_steps_skipped);
LOG_INFO("EasyCache skipped %d/%zu steps (%.2fx estimated speedup)",
runtime.easycache.total_steps_skipped,
total_steps,
speedup);
} else {
LOG_INFO("EasyCache skipped %d/%zu steps",
runtime.easycache.total_steps_skipped,
total_steps);
}
} else if (total_steps > 0) {
LOG_INFO("EasyCache completed without skipping steps");
}
}
if (runtime.ucache_enabled()) {
if (runtime.ucache.total_steps_skipped > 0 && total_steps > 0) {
if (runtime.ucache.total_steps_skipped < static_cast<int>(total_steps)) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - runtime.ucache.total_steps_skipped);
LOG_INFO("UCache skipped %d/%zu steps (%.2fx estimated speedup)",
runtime.ucache.total_steps_skipped,
total_steps,
speedup);
} else {
LOG_INFO("UCache skipped %d/%zu steps",
runtime.ucache.total_steps_skipped,
total_steps);
}
} else if (total_steps > 0) {
LOG_INFO("UCache completed without skipping steps");
}
}
if (runtime.cachedit_enabled()) {
if (runtime.cachedit.total_steps_skipped > 0 && total_steps > 0) {
if (runtime.cachedit.total_steps_skipped < static_cast<int>(total_steps)) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - runtime.cachedit.total_steps_skipped);
LOG_INFO("CacheDIT skipped %d/%zu steps (%.2fx estimated speedup), accum_diff: %.4f",
runtime.cachedit.total_steps_skipped,
total_steps,
speedup,
runtime.cachedit.accumulated_residual_diff);
} else {
LOG_INFO("CacheDIT skipped %d/%zu steps, accum_diff: %.4f",
runtime.cachedit.total_steps_skipped,
total_steps,
runtime.cachedit.accumulated_residual_diff);
}
} else if (total_steps > 0) {
LOG_INFO("CacheDIT completed without skipping steps");
}
}
if (runtime.spectrum_enabled && runtime.spectrum.total_steps_skipped > 0 && total_steps > 0) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - runtime.spectrum.total_steps_skipped);
LOG_INFO("Spectrum skipped %d/%zu steps (%.2fx estimated speedup)",
runtime.spectrum.total_steps_skipped,
total_steps,
speedup);
}
}
/*=============================================== StableDiffusionGGML ================================================*/
class StableDiffusionGGML {
@ -1662,148 +2035,7 @@ public:
img_cfg_scale = cfg_scale;
}
EasyCacheState easycache_state;
UCacheState ucache_state;
CacheDitConditionState cachedit_state;
SpectrumState spectrum_state;
bool easycache_enabled = false;
bool ucache_enabled = false;
bool cachedit_enabled = false;
bool spectrum_enabled = false;
if (cache_params != nullptr && cache_params->mode != SD_CACHE_DISABLED) {
bool percent_valid = true;
if (cache_params->mode == SD_CACHE_EASYCACHE || cache_params->mode == SD_CACHE_UCACHE) {
percent_valid = cache_params->start_percent >= 0.0f &&
cache_params->start_percent < 1.0f &&
cache_params->end_percent > 0.0f &&
cache_params->end_percent <= 1.0f &&
cache_params->start_percent < cache_params->end_percent;
}
if (!percent_valid) {
LOG_WARN("Cache disabled due to invalid percent range (start=%.3f, end=%.3f)",
cache_params->start_percent,
cache_params->end_percent);
} else if (cache_params->mode == SD_CACHE_EASYCACHE) {
bool easycache_supported = sd_version_is_dit(version);
if (!easycache_supported) {
LOG_WARN("EasyCache requested but not supported for this model type");
} else {
EasyCacheConfig easycache_config;
easycache_config.enabled = true;
easycache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
easycache_config.start_percent = cache_params->start_percent;
easycache_config.end_percent = cache_params->end_percent;
easycache_state.init(easycache_config, denoiser.get());
if (easycache_state.enabled()) {
easycache_enabled = true;
LOG_INFO("EasyCache enabled - threshold: %.3f, start: %.2f, end: %.2f",
easycache_config.reuse_threshold,
easycache_config.start_percent,
easycache_config.end_percent);
} else {
LOG_WARN("EasyCache requested but could not be initialized for this run");
}
}
} else if (cache_params->mode == SD_CACHE_UCACHE) {
bool ucache_supported = sd_version_is_unet(version);
if (!ucache_supported) {
LOG_WARN("UCache requested but not supported for this model type (only UNET models)");
} else {
UCacheConfig ucache_config;
ucache_config.enabled = true;
ucache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
ucache_config.start_percent = cache_params->start_percent;
ucache_config.end_percent = cache_params->end_percent;
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate));
ucache_config.use_relative_threshold = cache_params->use_relative_threshold;
ucache_config.reset_error_on_compute = cache_params->reset_error_on_compute;
ucache_state.init(ucache_config, denoiser.get());
if (ucache_state.enabled()) {
ucache_enabled = true;
LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s, reset: %s",
ucache_config.reuse_threshold,
ucache_config.start_percent,
ucache_config.end_percent,
ucache_config.error_decay_rate,
ucache_config.use_relative_threshold ? "true" : "false",
ucache_config.reset_error_on_compute ? "true" : "false");
} else {
LOG_WARN("UCache requested but could not be initialized for this run");
}
}
} else if (cache_params->mode == SD_CACHE_DBCACHE ||
cache_params->mode == SD_CACHE_TAYLORSEER ||
cache_params->mode == SD_CACHE_CACHE_DIT) {
bool cachedit_supported = sd_version_is_dit(version);
if (!cachedit_supported) {
LOG_WARN("CacheDIT requested but not supported for this model type (only DiT models)");
} else {
DBCacheConfig dbcfg;
dbcfg.enabled = (cache_params->mode == SD_CACHE_DBCACHE ||
cache_params->mode == SD_CACHE_CACHE_DIT);
dbcfg.Fn_compute_blocks = cache_params->Fn_compute_blocks;
dbcfg.Bn_compute_blocks = cache_params->Bn_compute_blocks;
dbcfg.residual_diff_threshold = cache_params->residual_diff_threshold;
dbcfg.max_warmup_steps = cache_params->max_warmup_steps;
dbcfg.max_cached_steps = cache_params->max_cached_steps;
dbcfg.max_continuous_cached_steps = cache_params->max_continuous_cached_steps;
if (cache_params->scm_mask != nullptr && strlen(cache_params->scm_mask) > 0) {
dbcfg.steps_computation_mask = parse_scm_mask(cache_params->scm_mask);
}
dbcfg.scm_policy_dynamic = cache_params->scm_policy_dynamic;
TaylorSeerConfig tcfg;
tcfg.enabled = (cache_params->mode == SD_CACHE_TAYLORSEER ||
cache_params->mode == SD_CACHE_CACHE_DIT);
tcfg.n_derivatives = cache_params->taylorseer_n_derivatives;
tcfg.skip_interval_steps = cache_params->taylorseer_skip_interval;
cachedit_state.init(dbcfg, tcfg);
if (cachedit_state.enabled()) {
cachedit_enabled = true;
LOG_INFO("CacheDIT enabled - mode: %s, Fn: %d, Bn: %d, threshold: %.3f, warmup: %d",
cache_params->mode == SD_CACHE_CACHE_DIT ? "DBCache+TaylorSeer" : (cache_params->mode == SD_CACHE_DBCACHE ? "DBCache" : "TaylorSeer"),
dbcfg.Fn_compute_blocks,
dbcfg.Bn_compute_blocks,
dbcfg.residual_diff_threshold,
dbcfg.max_warmup_steps);
} else {
LOG_WARN("CacheDIT requested but could not be initialized for this run");
}
}
} else if (cache_params->mode == SD_CACHE_SPECTRUM) {
bool spectrum_supported = sd_version_is_unet(version) || sd_version_is_dit(version);
if (!spectrum_supported) {
LOG_WARN("Spectrum requested but not supported for this model type (only UNET and DiT models)");
} else {
SpectrumConfig spectrum_config;
spectrum_config.w = cache_params->spectrum_w;
spectrum_config.m = cache_params->spectrum_m;
spectrum_config.lam = cache_params->spectrum_lam;
spectrum_config.window_size = cache_params->spectrum_window_size;
spectrum_config.flex_window = cache_params->spectrum_flex_window;
spectrum_config.warmup_steps = cache_params->spectrum_warmup_steps;
spectrum_config.stop_percent = cache_params->spectrum_stop_percent;
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
spectrum_state.init(spectrum_config, total_steps);
spectrum_enabled = true;
LOG_INFO("Spectrum enabled - w: %.2f, m: %d, lam: %.2f, window: %d, flex: %.2f, warmup: %d, stop: %.0f%%",
spectrum_config.w, spectrum_config.m, spectrum_config.lam,
spectrum_config.window_size, spectrum_config.flex_window,
spectrum_config.warmup_steps, spectrum_config.stop_percent * 100.0f);
}
}
}
if (ucache_enabled) {
ucache_state.set_sigmas(sigmas);
}
if (cachedit_enabled) {
cachedit_state.set_sigmas(sigmas);
}
SampleCacheRuntime cache_runtime = init_sample_cache_runtime(version, cache_params, denoiser.get(), sigmas);
size_t steps = sigmas.size() - 1;
ggml_tensor* x = ggml_ext_dup_and_cpy_tensor(work_ctx, init_latent);
@ -1876,121 +2108,7 @@ public:
}
DiffusionParams diffusion_params;
const bool easycache_step_active = easycache_enabled && step > 0;
int easycache_step_index = easycache_step_active ? (step - 1) : -1;
if (easycache_step_active) {
easycache_state.begin_step(easycache_step_index, sigma);
}
auto easycache_before_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) -> bool {
if (!easycache_step_active || condition == nullptr || output_tensor == nullptr) {
return false;
}
return easycache_state.before_condition(condition,
diffusion_params.x,
output_tensor,
sigma,
easycache_step_index);
};
auto easycache_after_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) {
if (!easycache_step_active || condition == nullptr || output_tensor == nullptr) {
return;
}
easycache_state.after_condition(condition,
diffusion_params.x,
output_tensor);
};
auto easycache_step_is_skipped = [&]() {
return easycache_step_active && easycache_state.is_step_skipped();
};
const bool ucache_step_active = ucache_enabled && step > 0;
int ucache_step_index = ucache_step_active ? (step - 1) : -1;
if (ucache_step_active) {
ucache_state.begin_step(ucache_step_index, sigma);
}
auto ucache_before_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) -> bool {
if (!ucache_step_active || condition == nullptr || output_tensor == nullptr) {
return false;
}
return ucache_state.before_condition(condition,
diffusion_params.x,
output_tensor,
sigma,
ucache_step_index);
};
auto ucache_after_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) {
if (!ucache_step_active || condition == nullptr || output_tensor == nullptr) {
return;
}
ucache_state.after_condition(condition,
diffusion_params.x,
output_tensor);
};
auto ucache_step_is_skipped = [&]() {
return ucache_step_active && ucache_state.is_step_skipped();
};
const bool cachedit_step_active = cachedit_enabled && step > 0;
int cachedit_step_index = cachedit_step_active ? (step - 1) : -1;
if (cachedit_step_active) {
cachedit_state.begin_step(cachedit_step_index, sigma);
}
auto cachedit_before_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) -> bool {
if (!cachedit_step_active || condition == nullptr || output_tensor == nullptr) {
return false;
}
return cachedit_state.before_condition(condition,
diffusion_params.x,
output_tensor,
sigma,
cachedit_step_index);
};
auto cachedit_after_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) {
if (!cachedit_step_active || condition == nullptr || output_tensor == nullptr) {
return;
}
cachedit_state.after_condition(condition,
diffusion_params.x,
output_tensor);
};
auto cachedit_step_is_skipped = [&]() {
return cachedit_step_active && cachedit_state.is_step_skipped();
};
auto cache_before_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) -> bool {
if (easycache_step_active) {
return easycache_before_condition(condition, output_tensor);
} else if (ucache_step_active) {
return ucache_before_condition(condition, output_tensor);
} else if (cachedit_step_active) {
return cachedit_before_condition(condition, output_tensor);
}
return false;
};
auto cache_after_condition = [&](const SDCondition* condition, ggml_tensor* output_tensor) {
if (easycache_step_active) {
easycache_after_condition(condition, output_tensor);
} else if (ucache_step_active) {
ucache_after_condition(condition, output_tensor);
} else if (cachedit_step_active) {
cachedit_after_condition(condition, output_tensor);
}
};
auto cache_step_is_skipped = [&]() {
return easycache_step_is_skipped() || ucache_step_is_skipped() || cachedit_step_is_skipped();
};
SampleStepCacheDispatcher step_cache(cache_runtime, step, sigma);
std::vector<float> scaling = denoiser->get_scalings(sigma);
GGML_ASSERT(scaling.size() == 3);
@ -2017,8 +2135,8 @@ public:
timesteps_vec = process_timesteps(timesteps_vec, init_latent, denoise_mask);
if (spectrum_enabled && spectrum_state.should_predict()) {
spectrum_state.predict(denoised);
if (cache_runtime.spectrum_enabled && cache_runtime.spectrum.should_predict()) {
cache_runtime.spectrum.predict(denoised);
if (denoise_mask != nullptr) {
apply_mask(denoised, init_latent, denoise_mask);
@ -2077,6 +2195,22 @@ public:
diffusion_params.vace_context = vace_context;
diffusion_params.vace_strength = vace_strength;
auto run_diffusion_condition = [&](const SDCondition* condition, ggml_tensor** output_tensor) -> bool {
if (step_cache.before_condition(condition, diffusion_params.x, *output_tensor)) {
return true;
}
if (!work_diffusion_model->compute(n_threads,
diffusion_params,
output_tensor)) {
LOG_ERROR("diffusion model compute failed");
return false;
}
step_cache.after_condition(condition, diffusion_params.x, *output_tensor);
return true;
};
const SDCondition* active_condition = nullptr;
ggml_tensor** active_output = &out_cond;
if (start_merge_step == -1 || step <= start_merge_step) {
@ -2092,18 +2226,11 @@ public:
active_condition = &id_cond;
}
bool skip_model = cache_before_condition(active_condition, *active_output);
if (!skip_model) {
if (!work_diffusion_model->compute(n_threads,
diffusion_params,
active_output)) {
LOG_ERROR("diffusion model compute failed");
if (!run_diffusion_condition(active_condition, active_output)) {
return nullptr;
}
cache_after_condition(active_condition, *active_output);
}
bool current_step_skipped = cache_step_is_skipped();
bool current_step_skipped = step_cache.is_step_skipped();
float* negative_data = nullptr;
if (has_unconditioned) {
@ -2115,21 +2242,14 @@ public:
LOG_ERROR("controlnet compute failed");
}
}
current_step_skipped = cache_step_is_skipped();
current_step_skipped = step_cache.is_step_skipped();
diffusion_params.controls = controls;
diffusion_params.context = uncond.c_crossattn;
diffusion_params.c_concat = uncond.c_concat;
diffusion_params.y = uncond.c_vector;
bool skip_uncond = cache_before_condition(&uncond, out_uncond);
if (!skip_uncond) {
if (!work_diffusion_model->compute(n_threads,
diffusion_params,
&out_uncond)) {
LOG_ERROR("diffusion model compute failed");
if (!run_diffusion_condition(&uncond, &out_uncond)) {
return nullptr;
}
cache_after_condition(&uncond, out_uncond);
}
negative_data = (float*)out_uncond->data;
}
@ -2138,16 +2258,9 @@ public:
diffusion_params.context = img_cond.c_crossattn;
diffusion_params.c_concat = img_cond.c_concat;
diffusion_params.y = img_cond.c_vector;
bool skip_img_cond = cache_before_condition(&img_cond, out_img_cond);
if (!skip_img_cond) {
if (!work_diffusion_model->compute(n_threads,
diffusion_params,
&out_img_cond)) {
LOG_ERROR("diffusion model compute failed");
if (!run_diffusion_condition(&img_cond, &out_img_cond)) {
return nullptr;
}
cache_after_condition(&img_cond, out_img_cond);
}
img_cond_data = (float*)out_img_cond->data;
}
@ -2156,7 +2269,7 @@ public:
float* skip_layer_data = has_skiplayer ? (float*)out_skip->data : nullptr;
if (is_skiplayer_step) {
LOG_DEBUG("Skipping layers at step %d\n", step);
if (!cache_step_is_skipped()) {
if (!step_cache.is_step_skipped()) {
// skip layer (same as conditioned)
diffusion_params.context = cond.c_crossattn;
diffusion_params.c_concat = cond.c_concat;
@ -2211,8 +2324,8 @@ public:
vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip;
}
if (spectrum_enabled) {
spectrum_state.update(denoised);
if (cache_runtime.spectrum_enabled) {
cache_runtime.spectrum.update(denoised);
}
if (denoise_mask != nullptr) {
@ -2244,75 +2357,8 @@ public:
return NULL;
}
if (easycache_enabled) {
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
if (easycache_state.total_steps_skipped > 0 && total_steps > 0) {
if (easycache_state.total_steps_skipped < static_cast<int>(total_steps)) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - easycache_state.total_steps_skipped);
LOG_INFO("EasyCache skipped %d/%zu steps (%.2fx estimated speedup)",
easycache_state.total_steps_skipped,
total_steps,
speedup);
} else {
LOG_INFO("EasyCache skipped %d/%zu steps",
easycache_state.total_steps_skipped,
total_steps);
}
} else if (total_steps > 0) {
LOG_INFO("EasyCache completed without skipping steps");
}
}
if (ucache_enabled) {
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
if (ucache_state.total_steps_skipped > 0 && total_steps > 0) {
if (ucache_state.total_steps_skipped < static_cast<int>(total_steps)) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - ucache_state.total_steps_skipped);
LOG_INFO("UCache skipped %d/%zu steps (%.2fx estimated speedup)",
ucache_state.total_steps_skipped,
total_steps,
speedup);
} else {
LOG_INFO("UCache skipped %d/%zu steps",
ucache_state.total_steps_skipped,
total_steps);
}
} else if (total_steps > 0) {
LOG_INFO("UCache completed without skipping steps");
}
}
if (cachedit_enabled) {
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
if (cachedit_state.total_steps_skipped > 0 && total_steps > 0) {
if (cachedit_state.total_steps_skipped < static_cast<int>(total_steps)) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - cachedit_state.total_steps_skipped);
LOG_INFO("CacheDIT skipped %d/%zu steps (%.2fx estimated speedup), accum_diff: %.4f",
cachedit_state.total_steps_skipped,
total_steps,
speedup,
cachedit_state.accumulated_residual_diff);
} else {
LOG_INFO("CacheDIT skipped %d/%zu steps, accum_diff: %.4f",
cachedit_state.total_steps_skipped,
total_steps,
cachedit_state.accumulated_residual_diff);
}
} else if (total_steps > 0) {
LOG_INFO("CacheDIT completed without skipping steps");
}
}
if (spectrum_enabled && spectrum_state.total_steps_skipped > 0) {
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - spectrum_state.total_steps_skipped);
LOG_INFO("Spectrum skipped %d/%zu steps (%.2fx estimated speedup)",
spectrum_state.total_steps_skipped, total_steps, speedup);
}
log_sample_cache_summary(cache_runtime, total_steps);
if (inverse_noise_scaling) {
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);