refactor: move all cache parameter defaults to the library (#1327)

This commit is contained in:
Wagner Bruna 2026-03-15 05:43:46 -03:00 committed by GitHub
parent f6968bc589
commit 630ee03f23
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 24 additions and 154 deletions

View File

@ -80,7 +80,7 @@ Uses Taylor series approximation to predict block outputs:
Combines DBCache and TaylorSeer:
```bash
--cache-mode cache-dit --cache-preset fast
--cache-mode cache-dit
```
#### Parameters
@ -92,14 +92,6 @@ Combines DBCache and TaylorSeer:
| `threshold` | L1 residual difference threshold | 0.08 |
| `warmup` | Steps before caching starts | 8 |
#### Presets
Available presets: `slow`, `medium`, `fast`, `ultra` (or `s`, `m`, `f`, `u`).
```bash
--cache-mode cache-dit --cache-preset fast
```
#### SCM Options
Steps Computation Mask controls which steps can be cached:

View File

@ -144,7 +144,6 @@ Generation Options:
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=;
spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples:
"threshold=0.25" or "threshold=1.5,reset=0" or "w=0.4,window=2"
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
--scm-policy SCM policy: 'dynamic' (default) or 'static'
```

View File

@ -1047,7 +1047,6 @@ struct SDGenerationParams {
std::string cache_mode;
std::string cache_option;
std::string cache_preset;
std::string scm_mask;
bool scm_policy_dynamic = true;
sd_cache_params_t cache_params{};
@ -1461,21 +1460,6 @@ struct SDGenerationParams {
return 1;
};
auto on_cache_preset_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
cache_preset = argv_to_utf8(index, argv);
if (cache_preset != "slow" && cache_preset != "s" && cache_preset != "S" &&
cache_preset != "medium" && cache_preset != "m" && cache_preset != "M" &&
cache_preset != "fast" && cache_preset != "f" && cache_preset != "F" &&
cache_preset != "ultra" && cache_preset != "u" && cache_preset != "U") {
fprintf(stderr, "error: invalid cache preset '%s', must be 'slow'/'s', 'medium'/'m', 'fast'/'f', or 'ultra'/'u'\n", cache_preset.c_str());
return -1;
}
return 1;
};
options.manual_options = {
{"-s",
"--seed",
@ -1519,10 +1503,6 @@ struct SDGenerationParams {
"--cache-option",
"named cache params (key=value format, comma-separated). easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=; spectrum: w=,m=,lam=,window=,flex=,warmup=,stop=. Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"",
on_cache_option_arg},
{"",
"--cache-preset",
"cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'",
on_cache_preset_arg},
{"",
"--scm-mask",
"SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache",
@ -1575,7 +1555,6 @@ struct SDGenerationParams {
load_if_exists("negative_prompt", negative_prompt);
load_if_exists("cache_mode", cache_mode);
load_if_exists("cache_option", cache_option);
load_if_exists("cache_preset", cache_preset);
load_if_exists("scm_mask", scm_mask);
load_if_exists("clip_skip", clip_skip);
@ -1811,47 +1790,16 @@ struct SDGenerationParams {
if (!cache_mode.empty()) {
if (cache_mode == "easycache") {
cache_params.mode = SD_CACHE_EASYCACHE;
cache_params.reuse_threshold = 0.2f;
cache_params.start_percent = 0.15f;
cache_params.end_percent = 0.95f;
cache_params.error_decay_rate = 1.0f;
cache_params.use_relative_threshold = true;
cache_params.reset_error_on_compute = true;
} else if (cache_mode == "ucache") {
cache_params.mode = SD_CACHE_UCACHE;
cache_params.reuse_threshold = 1.0f;
cache_params.start_percent = 0.15f;
cache_params.end_percent = 0.95f;
cache_params.error_decay_rate = 1.0f;
cache_params.use_relative_threshold = true;
cache_params.reset_error_on_compute = true;
} else if (cache_mode == "dbcache") {
cache_params.mode = SD_CACHE_DBCACHE;
cache_params.Fn_compute_blocks = 8;
cache_params.Bn_compute_blocks = 0;
cache_params.residual_diff_threshold = 0.08f;
cache_params.max_warmup_steps = 8;
} else if (cache_mode == "taylorseer") {
cache_params.mode = SD_CACHE_TAYLORSEER;
cache_params.Fn_compute_blocks = 8;
cache_params.Bn_compute_blocks = 0;
cache_params.residual_diff_threshold = 0.08f;
cache_params.max_warmup_steps = 8;
} else if (cache_mode == "cache-dit") {
cache_params.mode = SD_CACHE_CACHE_DIT;
cache_params.Fn_compute_blocks = 8;
cache_params.Bn_compute_blocks = 0;
cache_params.residual_diff_threshold = 0.08f;
cache_params.max_warmup_steps = 8;
} else if (cache_mode == "spectrum") {
cache_params.mode = SD_CACHE_SPECTRUM;
cache_params.spectrum_w = 0.40f;
cache_params.spectrum_m = 3;
cache_params.spectrum_lam = 1.0f;
cache_params.spectrum_window_size = 2;
cache_params.spectrum_flex_window = 0.50f;
cache_params.spectrum_warmup_steps = 4;
cache_params.spectrum_stop_percent = 0.9f;
}
if (!cache_option.empty()) {

View File

@ -133,7 +133,6 @@ Default Generation Options:
--cache-option named cache params (key=value format, comma-separated). easycache/ucache:
threshold=,start=,end=,decay=,relative=,reset=; dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=. Examples:
"threshold=0.25" or "threshold=1.5,reset=0"
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
--scm-mask SCM steps mask for cache-dit: comma-separated 0/1 (e.g., "1,1,1,0,0,1,0,0,1,0") - 1=compute, 0=can cache
--scm-policy SCM policy: 'dynamic' (default) or 'static'
```

View File

@ -603,87 +603,6 @@ inline std::vector<int> generate_scm_mask(
return mask;
}
inline std::vector<int> get_scm_preset(const std::string& preset, int total_steps) {
struct Preset {
std::vector<int> compute_bins;
std::vector<int> cache_bins;
};
Preset slow = {{8, 3, 3, 2, 1, 1}, {1, 2, 2, 2, 3}};
Preset medium = {{6, 2, 2, 2, 2, 1}, {1, 3, 3, 3, 3}};
Preset fast = {{6, 1, 1, 1, 1, 1}, {1, 3, 4, 5, 4}};
Preset ultra = {{4, 1, 1, 1, 1}, {2, 5, 6, 7}};
Preset* p = nullptr;
if (preset == "slow" || preset == "s" || preset == "S")
p = &slow;
else if (preset == "medium" || preset == "m" || preset == "M")
p = &medium;
else if (preset == "fast" || preset == "f" || preset == "F")
p = &fast;
else if (preset == "ultra" || preset == "u" || preset == "U")
p = &ultra;
else
return {};
if (total_steps != 28 && total_steps > 0) {
float scale = static_cast<float>(total_steps) / 28.0f;
std::vector<int> scaled_compute, scaled_cache;
for (int v : p->compute_bins) {
scaled_compute.push_back(std::max(1, static_cast<int>(v * scale + 0.5f)));
}
for (int v : p->cache_bins) {
scaled_cache.push_back(std::max(1, static_cast<int>(v * scale + 0.5f)));
}
return generate_scm_mask(scaled_compute, scaled_cache, total_steps);
}
return generate_scm_mask(p->compute_bins, p->cache_bins, total_steps);
}
inline float get_preset_threshold(const std::string& preset) {
if (preset == "slow" || preset == "s" || preset == "S")
return 0.20f;
if (preset == "medium" || preset == "m" || preset == "M")
return 0.25f;
if (preset == "fast" || preset == "f" || preset == "F")
return 0.30f;
if (preset == "ultra" || preset == "u" || preset == "U")
return 0.34f;
return 0.08f;
}
inline int get_preset_warmup(const std::string& preset) {
if (preset == "slow" || preset == "s" || preset == "S")
return 8;
if (preset == "medium" || preset == "m" || preset == "M")
return 6;
if (preset == "fast" || preset == "f" || preset == "F")
return 6;
if (preset == "ultra" || preset == "u" || preset == "U")
return 4;
return 8;
}
inline int get_preset_Fn(const std::string& preset) {
if (preset == "slow" || preset == "s" || preset == "S")
return 8;
if (preset == "medium" || preset == "m" || preset == "M")
return 8;
if (preset == "fast" || preset == "f" || preset == "F")
return 6;
if (preset == "ultra" || preset == "u" || preset == "U")
return 4;
return 8;
}
inline int get_preset_Bn(const std::string& preset) {
(void)preset;
return 0;
}
inline void parse_dbcache_options(const std::string& opts, DBCacheConfig& cfg) {
if (opts.empty())
return;

View File

@ -98,6 +98,19 @@ void suppress_pp(int step, int steps, float time, void* data) {
return;
}
static float get_cache_reuse_threshold(const sd_cache_params_t& params) {
float reuse_threshold = params.reuse_threshold;
if (reuse_threshold == INFINITY) {
if (params.mode == SD_CACHE_EASYCACHE) {
reuse_threshold = 0.2;
}
else if (params.mode == SD_CACHE_UCACHE) {
reuse_threshold = 1.0;
}
}
return std::max(0.0f, reuse_threshold);
}
/*=============================================== StableDiffusionGGML ================================================*/
class StableDiffusionGGML {
@ -1715,7 +1728,7 @@ public:
} else {
EasyCacheConfig easycache_config;
easycache_config.enabled = true;
easycache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold);
easycache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
easycache_config.start_percent = cache_params->start_percent;
easycache_config.end_percent = cache_params->end_percent;
easycache_state.init(easycache_config, denoiser.get());
@ -1736,7 +1749,7 @@ public:
} else {
UCacheConfig ucache_config;
ucache_config.enabled = true;
ucache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold);
ucache_config.reuse_threshold = get_cache_reuse_threshold(*cache_params);
ucache_config.start_percent = cache_params->start_percent;
ucache_config.end_percent = cache_params->end_percent;
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate));
@ -2983,7 +2996,7 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) {
void sd_cache_params_init(sd_cache_params_t* cache_params) {
*cache_params = {};
cache_params->mode = SD_CACHE_DISABLED;
cache_params->reuse_threshold = 1.0f;
cache_params->reuse_threshold = INFINITY;
cache_params->start_percent = 0.15f;
cache_params->end_percent = 0.95f;
cache_params->error_decay_rate = 1.0f;
@ -3229,7 +3242,7 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
snprintf(buf + strlen(buf), 4096 - strlen(buf),
"cache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
cache_mode_str,
sd_img_gen_params->cache.reuse_threshold,
get_cache_reuse_threshold(sd_img_gen_params->cache),
sd_img_gen_params->cache.start_percent,
sd_img_gen_params->cache.end_percent);
free(sample_params_str);