mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-01-02 10:43:35 +00:00
feat: add more caching methods (#1066)
This commit is contained in:
parent
30a91138f8
commit
298b11069f
975
cache_dit.hpp
Normal file
975
cache_dit.hpp
Normal file
@ -0,0 +1,975 @@
|
||||
#ifndef __CACHE_DIT_HPP__
|
||||
#define __CACHE_DIT_HPP__
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "ggml_extend.hpp"
|
||||
|
||||
/// User-facing knobs for the residual-difference block cache (DBCache).
struct DBCacheConfig {
    /// Switches the DBCache path on.
    bool enabled{false};
    /// Count of leading blocks treated as the "Fn" group (block index < this value).
    int Fn_compute_blocks{8};
    /// Count of trailing blocks treated as the "Bn" group; 0 means there is no such group.
    int Bn_compute_blocks{0};
    /// A step is only reused when the measured relative difference is below this value.
    float residual_diff_threshold{0.08f};
    /// Initial value of the warm-up countdown.
    int max_warmup_steps{8};
    /// Upper bound on the number of cached steps; negative means unlimited.
    int max_cached_steps{-1};
    /// Upper bound on back-to-back cached steps; negative means unlimited.
    int max_continuous_cached_steps{-1};
    /// Upper bound on the summed difference of cached steps; negative means unlimited.
    float max_accumulated_residual_diff{-1.0f};
    /// Optional per-step schedule: 0 marks a step as cacheable, 1 as must-compute.
    /// An empty vector means no schedule is applied.
    std::vector<int> steps_computation_mask;
    /// When false, a step the mask marks as cacheable is flagged cacheable without
    /// consulting the other limits (see CacheDitState::begin_step).
    bool scm_policy_dynamic{true};
};
|
||||
|
||||
/// User-facing knobs for the TaylorSeer extrapolation path.
struct TaylorSeerConfig {
    /// Switches the TaylorSeer path on.
    bool enabled{false};
    /// Number of finite-difference levels kept on top of the raw value.
    int n_derivatives{1};
    /// Steps with an index below this value never use extrapolation.
    int max_warmup_steps{2};
    /// Spacing parameter of the extrapolation schedule; values <= 0 are treated as 1.
    int skip_interval_steps{1};
};
|
||||
|
||||
struct CacheDitConfig {
|
||||
DBCacheConfig dbcache;
|
||||
TaylorSeerConfig taylorseer;
|
||||
int double_Fn_blocks = -1;
|
||||
int double_Bn_blocks = -1;
|
||||
int single_Fn_blocks = -1;
|
||||
int single_Bn_blocks = -1;
|
||||
};
|
||||
|
||||
/// Finite-difference history used to extrapolate a feature vector over skipped steps.
///
/// Level 0 of a history holds the values handed to update_derivatives(); level k
/// (k >= 1) holds the difference between level k-1 of the two most recent updates,
/// divided by the step gap between them. Populated levels always form a prefix.
struct TaylorSeerState {
    int n_derivatives      = 1;   ///< number of difference levels above level 0
    int current_step       = -1;  ///< step index passed to the latest update
    int last_computed_step = -1;  ///< step index passed to the latest update
    std::vector<std::vector<float>> dY_prev;     ///< levels as of the update before the latest
    std::vector<std::vector<float>> dY_current;  ///< levels as of the latest update

    /// Sizes both histories for `n_deriv` difference levels and forgets all samples.
    /// Negative `n_deriv` is treated as 0. `hidden_size` is accepted for interface
    /// compatibility; buffers are sized lazily by update_derivatives().
    void init(int n_deriv, size_t hidden_size) {
        (void)hidden_size;
        n_derivatives       = n_deriv < 0 ? 0 : n_deriv;
        const size_t levels = static_cast<size_t>(n_derivatives) + 1;
        dY_prev.assign(levels, std::vector<float>());
        dY_current.assign(levels, std::vector<float>());
        current_step       = -1;
        last_computed_step = -1;
    }

    /// Drops all samples but keeps the number of levels.
    void reset() {
        for (auto& level : dY_prev) {
            level.clear();
        }
        for (auto& level : dY_current) {
            level.clear();
        }
        current_step       = -1;
        last_computed_step = -1;
    }

    /// True once at least two samples were recorded (so a first difference exists)
    /// and the latest recorded step index is at least `n_derivatives`.
    bool can_approximate() const {
        return last_computed_step >= n_derivatives && !dY_prev.empty() && !dY_prev[0].empty();
    }

    /// Records `size` values from `Y` as the sample of `step` and rebuilds the
    /// difference levels against the previous sample.
    void update_derivatives(const float* Y, size_t size, int step) {
        const int n_diff    = n_derivatives < 0 ? 0 : n_derivatives;
        const size_t levels = static_cast<size_t>(n_diff) + 1;

        // The old "current" becomes "previous". Every level of the new "current" is
        // rewritten below, so swapping avoids a deep copy of all buffers.
        dY_prev.swap(dY_current);
        // Also covers use without a prior init(): both histories get the needed levels.
        dY_prev.resize(levels);
        dY_current.resize(levels);

        dY_current[0].assign(Y, Y + size);

        int window = step - last_computed_step;
        if (window <= 0) {
            window = 1;  // repeated or out-of-order step: avoid dividing by <= 0
        }
        const float gap = static_cast<float>(window);

        for (int d = 0; d < n_diff; ++d) {
            const std::vector<float>& older = dY_prev[d];
            const std::vector<float>& newer = dY_current[d];
            std::vector<float>& slope       = dY_current[d + 1];
            if (older.empty() || older.size() != size || newer.size() != size) {
                slope.clear();  // no compatible history at this level
                continue;
            }
            slope.resize(size);
            for (size_t i = 0; i < size; ++i) {
                slope[i] = (newer[i] - older[i]) / gap;
            }
        }

        current_step       = step;
        last_computed_step = step;
    }

    /// Writes the Taylor extrapolation for `target_step` into `output[0..size)`.
    ///
    /// The expansion is taken around the latest sample (`dY_current`), matching the
    /// point `last_computed_step` from which the elapsed step count is measured.
    /// `output` is left untouched when can_approximate() is false or when `size`
    /// does not match the recorded sample.
    void approximate(float* output, size_t size, int target_step) const {
        if (!can_approximate() || dY_current.empty() || dY_current[0].size() != size) {
            return;
        }

        int elapsed = target_step - last_computed_step;
        if (elapsed <= 0) {
            elapsed = 1;
        }
        const float step_gap = static_cast<float>(elapsed);

        std::copy(dY_current[0].begin(), dY_current[0].end(), output);

        float power     = 1.0f;  // elapsed^k
        float factorial = 1.0f;  // k!
        for (size_t k = 1; k < dY_current.size(); ++k) {
            const std::vector<float>& term = dY_current[k];
            if (term.empty() || term.size() != size) {
                break;  // populated levels form a prefix; nothing usable beyond here
            }
            power *= step_gap;
            factorial *= static_cast<float>(k);
            const float coeff = power / factorial;
            for (size_t i = 0; i < size; ++i) {
                output[i] += coeff * term[i];
            }
        }
    }
};
|
||||
|
||||
/// Cached data for one transformer block.
/// The *_img / *_txt members are filled by the double-block store routine;
/// `residual` / `prev_output` by the single-block store routine.
struct BlockCacheEntry {
    std::vector<float> residual_img;  // block output minus block input, image stream
    std::vector<float> residual_txt;  // block output minus block input, text stream
    std::vector<float> residual;      // block output minus block input, single stream
    std::vector<float> prev_img;      // last stored image-stream output
    std::vector<float> prev_txt;      // last stored text-stream output
    std::vector<float> prev_output;   // last stored single-stream output
    bool has_prev{false};             // set once a store routine has filled the entry
};
|
||||
|
||||
struct CacheDitState {
|
||||
CacheDitConfig config;
|
||||
bool initialized = false;
|
||||
|
||||
int total_double_blocks = 0;
|
||||
int total_single_blocks = 0;
|
||||
size_t hidden_size = 0;
|
||||
|
||||
int current_step = -1;
|
||||
int total_steps = 0;
|
||||
int warmup_remaining = 0;
|
||||
std::vector<int> cached_steps;
|
||||
int continuous_cached_steps = 0;
|
||||
float accumulated_residual_diff = 0.0f;
|
||||
|
||||
std::vector<BlockCacheEntry> double_block_cache;
|
||||
std::vector<BlockCacheEntry> single_block_cache;
|
||||
|
||||
std::vector<float> Fn_residual_img;
|
||||
std::vector<float> Fn_residual_txt;
|
||||
std::vector<float> prev_Fn_residual_img;
|
||||
std::vector<float> prev_Fn_residual_txt;
|
||||
bool has_prev_Fn_residual = false;
|
||||
|
||||
std::vector<float> Bn_buffer_img;
|
||||
std::vector<float> Bn_buffer_txt;
|
||||
std::vector<float> Bn_buffer;
|
||||
bool has_Bn_buffer = false;
|
||||
|
||||
TaylorSeerState taylor_state;
|
||||
|
||||
bool can_cache_this_step = false;
|
||||
bool is_caching_this_step = false;
|
||||
|
||||
int total_blocks_computed = 0;
|
||||
int total_blocks_cached = 0;
|
||||
|
||||
void init(const CacheDitConfig& cfg, int num_double_blocks, int num_single_blocks, size_t h_size) {
|
||||
config = cfg;
|
||||
total_double_blocks = num_double_blocks;
|
||||
total_single_blocks = num_single_blocks;
|
||||
hidden_size = h_size;
|
||||
|
||||
initialized = cfg.dbcache.enabled || cfg.taylorseer.enabled;
|
||||
|
||||
if (!initialized)
|
||||
return;
|
||||
|
||||
warmup_remaining = cfg.dbcache.max_warmup_steps;
|
||||
double_block_cache.resize(total_double_blocks);
|
||||
single_block_cache.resize(total_single_blocks);
|
||||
|
||||
if (cfg.taylorseer.enabled) {
|
||||
taylor_state.init(cfg.taylorseer.n_derivatives, h_size);
|
||||
}
|
||||
|
||||
reset_runtime();
|
||||
}
|
||||
|
||||
void reset_runtime() {
|
||||
current_step = -1;
|
||||
total_steps = 0;
|
||||
warmup_remaining = config.dbcache.max_warmup_steps;
|
||||
cached_steps.clear();
|
||||
continuous_cached_steps = 0;
|
||||
accumulated_residual_diff = 0.0f;
|
||||
|
||||
for (auto& entry : double_block_cache) {
|
||||
entry.residual_img.clear();
|
||||
entry.residual_txt.clear();
|
||||
entry.prev_img.clear();
|
||||
entry.prev_txt.clear();
|
||||
entry.has_prev = false;
|
||||
}
|
||||
|
||||
for (auto& entry : single_block_cache) {
|
||||
entry.residual.clear();
|
||||
entry.prev_output.clear();
|
||||
entry.has_prev = false;
|
||||
}
|
||||
|
||||
Fn_residual_img.clear();
|
||||
Fn_residual_txt.clear();
|
||||
prev_Fn_residual_img.clear();
|
||||
prev_Fn_residual_txt.clear();
|
||||
has_prev_Fn_residual = false;
|
||||
|
||||
Bn_buffer_img.clear();
|
||||
Bn_buffer_txt.clear();
|
||||
Bn_buffer.clear();
|
||||
has_Bn_buffer = false;
|
||||
|
||||
taylor_state.reset();
|
||||
|
||||
can_cache_this_step = false;
|
||||
is_caching_this_step = false;
|
||||
|
||||
total_blocks_computed = 0;
|
||||
total_blocks_cached = 0;
|
||||
}
|
||||
|
||||
bool enabled() const {
|
||||
return initialized && (config.dbcache.enabled || config.taylorseer.enabled);
|
||||
}
|
||||
|
||||
void begin_step(int step_index, float sigma = 0.0f) {
|
||||
if (!enabled())
|
||||
return;
|
||||
if (step_index == current_step)
|
||||
return;
|
||||
|
||||
current_step = step_index;
|
||||
total_steps++;
|
||||
|
||||
bool in_warmup = warmup_remaining > 0;
|
||||
if (in_warmup) {
|
||||
warmup_remaining--;
|
||||
}
|
||||
|
||||
bool scm_allows_cache = true;
|
||||
if (!config.dbcache.steps_computation_mask.empty()) {
|
||||
if (step_index < static_cast<int>(config.dbcache.steps_computation_mask.size())) {
|
||||
scm_allows_cache = (config.dbcache.steps_computation_mask[step_index] == 0);
|
||||
if (!config.dbcache.scm_policy_dynamic && scm_allows_cache) {
|
||||
can_cache_this_step = true;
|
||||
is_caching_this_step = false;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool max_cached_ok = (config.dbcache.max_cached_steps < 0) ||
|
||||
(static_cast<int>(cached_steps.size()) < config.dbcache.max_cached_steps);
|
||||
|
||||
bool max_cont_ok = (config.dbcache.max_continuous_cached_steps < 0) ||
|
||||
(continuous_cached_steps < config.dbcache.max_continuous_cached_steps);
|
||||
|
||||
bool accum_ok = (config.dbcache.max_accumulated_residual_diff < 0.0f) ||
|
||||
(accumulated_residual_diff < config.dbcache.max_accumulated_residual_diff);
|
||||
|
||||
can_cache_this_step = !in_warmup && scm_allows_cache && max_cached_ok && max_cont_ok && accum_ok && has_prev_Fn_residual;
|
||||
is_caching_this_step = false;
|
||||
}
|
||||
|
||||
void end_step(bool was_cached) {
|
||||
if (was_cached) {
|
||||
cached_steps.push_back(current_step);
|
||||
continuous_cached_steps++;
|
||||
} else {
|
||||
continuous_cached_steps = 0;
|
||||
}
|
||||
}
|
||||
|
||||
static float calculate_residual_diff(const float* prev, const float* curr, size_t size) {
|
||||
if (size == 0)
|
||||
return 0.0f;
|
||||
|
||||
float sum_diff = 0.0f;
|
||||
float sum_abs = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
sum_diff += std::fabs(prev[i] - curr[i]);
|
||||
sum_abs += std::fabs(prev[i]);
|
||||
}
|
||||
|
||||
return sum_diff / (sum_abs + 1e-6f);
|
||||
}
|
||||
|
||||
static float calculate_residual_diff(const std::vector<float>& prev, const std::vector<float>& curr) {
|
||||
if (prev.size() != curr.size() || prev.empty())
|
||||
return 1.0f;
|
||||
return calculate_residual_diff(prev.data(), curr.data(), prev.size());
|
||||
}
|
||||
|
||||
int get_double_Fn_blocks() const {
|
||||
return (config.double_Fn_blocks >= 0) ? config.double_Fn_blocks : config.dbcache.Fn_compute_blocks;
|
||||
}
|
||||
|
||||
int get_double_Bn_blocks() const {
|
||||
return (config.double_Bn_blocks >= 0) ? config.double_Bn_blocks : config.dbcache.Bn_compute_blocks;
|
||||
}
|
||||
|
||||
int get_single_Fn_blocks() const {
|
||||
return (config.single_Fn_blocks >= 0) ? config.single_Fn_blocks : config.dbcache.Fn_compute_blocks;
|
||||
}
|
||||
|
||||
int get_single_Bn_blocks() const {
|
||||
return (config.single_Bn_blocks >= 0) ? config.single_Bn_blocks : config.dbcache.Bn_compute_blocks;
|
||||
}
|
||||
|
||||
bool is_Fn_double_block(int block_idx) const {
|
||||
return block_idx < get_double_Fn_blocks();
|
||||
}
|
||||
|
||||
bool is_Bn_double_block(int block_idx) const {
|
||||
int Bn = get_double_Bn_blocks();
|
||||
return Bn > 0 && block_idx >= (total_double_blocks - Bn);
|
||||
}
|
||||
|
||||
bool is_Mn_double_block(int block_idx) const {
|
||||
return !is_Fn_double_block(block_idx) && !is_Bn_double_block(block_idx);
|
||||
}
|
||||
|
||||
bool is_Fn_single_block(int block_idx) const {
|
||||
return block_idx < get_single_Fn_blocks();
|
||||
}
|
||||
|
||||
bool is_Bn_single_block(int block_idx) const {
|
||||
int Bn = get_single_Bn_blocks();
|
||||
return Bn > 0 && block_idx >= (total_single_blocks - Bn);
|
||||
}
|
||||
|
||||
bool is_Mn_single_block(int block_idx) const {
|
||||
return !is_Fn_single_block(block_idx) && !is_Bn_single_block(block_idx);
|
||||
}
|
||||
|
||||
void store_Fn_residual(const float* img, const float* txt, size_t img_size, size_t txt_size, const float* input_img, const float* input_txt) {
|
||||
Fn_residual_img.resize(img_size);
|
||||
Fn_residual_txt.resize(txt_size);
|
||||
|
||||
for (size_t i = 0; i < img_size; i++) {
|
||||
Fn_residual_img[i] = img[i] - input_img[i];
|
||||
}
|
||||
for (size_t i = 0; i < txt_size; i++) {
|
||||
Fn_residual_txt[i] = txt[i] - input_txt[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool check_cache_decision() {
|
||||
if (!can_cache_this_step) {
|
||||
is_caching_this_step = false;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!has_prev_Fn_residual || prev_Fn_residual_img.empty()) {
|
||||
is_caching_this_step = false;
|
||||
return false;
|
||||
}
|
||||
|
||||
float diff_img = calculate_residual_diff(prev_Fn_residual_img, Fn_residual_img);
|
||||
float diff_txt = calculate_residual_diff(prev_Fn_residual_txt, Fn_residual_txt);
|
||||
float diff = (diff_img + diff_txt) / 2.0f;
|
||||
|
||||
if (diff < config.dbcache.residual_diff_threshold) {
|
||||
is_caching_this_step = true;
|
||||
accumulated_residual_diff += diff;
|
||||
return true;
|
||||
}
|
||||
|
||||
is_caching_this_step = false;
|
||||
return false;
|
||||
}
|
||||
|
||||
void update_prev_Fn_residual() {
|
||||
prev_Fn_residual_img = Fn_residual_img;
|
||||
prev_Fn_residual_txt = Fn_residual_txt;
|
||||
has_prev_Fn_residual = !prev_Fn_residual_img.empty();
|
||||
}
|
||||
|
||||
void store_double_block_residual(int block_idx, const float* img, const float* txt, size_t img_size, size_t txt_size, const float* prev_img, const float* prev_txt) {
|
||||
if (block_idx < 0 || block_idx >= static_cast<int>(double_block_cache.size()))
|
||||
return;
|
||||
|
||||
BlockCacheEntry& entry = double_block_cache[block_idx];
|
||||
|
||||
entry.residual_img.resize(img_size);
|
||||
entry.residual_txt.resize(txt_size);
|
||||
for (size_t i = 0; i < img_size; i++) {
|
||||
entry.residual_img[i] = img[i] - prev_img[i];
|
||||
}
|
||||
for (size_t i = 0; i < txt_size; i++) {
|
||||
entry.residual_txt[i] = txt[i] - prev_txt[i];
|
||||
}
|
||||
|
||||
entry.prev_img.resize(img_size);
|
||||
entry.prev_txt.resize(txt_size);
|
||||
for (size_t i = 0; i < img_size; i++) {
|
||||
entry.prev_img[i] = img[i];
|
||||
}
|
||||
for (size_t i = 0; i < txt_size; i++) {
|
||||
entry.prev_txt[i] = txt[i];
|
||||
}
|
||||
entry.has_prev = true;
|
||||
}
|
||||
|
||||
void apply_double_block_cache(int block_idx, float* img, float* txt, size_t img_size, size_t txt_size) {
|
||||
if (block_idx < 0 || block_idx >= static_cast<int>(double_block_cache.size()))
|
||||
return;
|
||||
|
||||
const BlockCacheEntry& entry = double_block_cache[block_idx];
|
||||
if (entry.residual_img.size() != img_size || entry.residual_txt.size() != txt_size)
|
||||
return;
|
||||
|
||||
for (size_t i = 0; i < img_size; i++) {
|
||||
img[i] += entry.residual_img[i];
|
||||
}
|
||||
for (size_t i = 0; i < txt_size; i++) {
|
||||
txt[i] += entry.residual_txt[i];
|
||||
}
|
||||
|
||||
total_blocks_cached++;
|
||||
}
|
||||
|
||||
void store_single_block_residual(int block_idx, const float* output, size_t size, const float* input) {
|
||||
if (block_idx < 0 || block_idx >= static_cast<int>(single_block_cache.size()))
|
||||
return;
|
||||
|
||||
BlockCacheEntry& entry = single_block_cache[block_idx];
|
||||
|
||||
entry.residual.resize(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
entry.residual[i] = output[i] - input[i];
|
||||
}
|
||||
|
||||
entry.prev_output.resize(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
entry.prev_output[i] = output[i];
|
||||
}
|
||||
entry.has_prev = true;
|
||||
}
|
||||
|
||||
void apply_single_block_cache(int block_idx, float* output, size_t size) {
|
||||
if (block_idx < 0 || block_idx >= static_cast<int>(single_block_cache.size()))
|
||||
return;
|
||||
|
||||
const BlockCacheEntry& entry = single_block_cache[block_idx];
|
||||
if (entry.residual.size() != size)
|
||||
return;
|
||||
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
output[i] += entry.residual[i];
|
||||
}
|
||||
|
||||
total_blocks_cached++;
|
||||
}
|
||||
|
||||
void store_Bn_buffer(const float* img, const float* txt, size_t img_size, size_t txt_size, const float* Bn_start_img, const float* Bn_start_txt) {
|
||||
Bn_buffer_img.resize(img_size);
|
||||
Bn_buffer_txt.resize(txt_size);
|
||||
|
||||
for (size_t i = 0; i < img_size; i++) {
|
||||
Bn_buffer_img[i] = img[i] - Bn_start_img[i];
|
||||
}
|
||||
for (size_t i = 0; i < txt_size; i++) {
|
||||
Bn_buffer_txt[i] = txt[i] - Bn_start_txt[i];
|
||||
}
|
||||
has_Bn_buffer = true;
|
||||
}
|
||||
|
||||
void apply_Bn_buffer(float* img, float* txt, size_t img_size, size_t txt_size) {
|
||||
if (!has_Bn_buffer)
|
||||
return;
|
||||
if (Bn_buffer_img.size() != img_size || Bn_buffer_txt.size() != txt_size)
|
||||
return;
|
||||
|
||||
for (size_t i = 0; i < img_size; i++) {
|
||||
img[i] += Bn_buffer_img[i];
|
||||
}
|
||||
for (size_t i = 0; i < txt_size; i++) {
|
||||
txt[i] += Bn_buffer_txt[i];
|
||||
}
|
||||
}
|
||||
|
||||
void taylor_update(const float* hidden_state, size_t size) {
|
||||
if (!config.taylorseer.enabled)
|
||||
return;
|
||||
taylor_state.update_derivatives(hidden_state, size, current_step);
|
||||
}
|
||||
|
||||
bool taylor_can_approximate() const {
|
||||
return config.taylorseer.enabled && taylor_state.can_approximate();
|
||||
}
|
||||
|
||||
void taylor_approximate(float* output, size_t size) {
|
||||
if (!config.taylorseer.enabled)
|
||||
return;
|
||||
taylor_state.approximate(output, size, current_step);
|
||||
}
|
||||
|
||||
bool should_use_taylor_this_step() const {
|
||||
if (!config.taylorseer.enabled)
|
||||
return false;
|
||||
if (current_step < config.taylorseer.max_warmup_steps)
|
||||
return false;
|
||||
|
||||
int interval = config.taylorseer.skip_interval_steps;
|
||||
if (interval <= 0)
|
||||
interval = 1;
|
||||
|
||||
return (current_step % (interval + 1)) != 0;
|
||||
}
|
||||
|
||||
void log_metrics() const {
|
||||
if (!enabled())
|
||||
return;
|
||||
|
||||
int total_blocks = total_blocks_computed + total_blocks_cached;
|
||||
float cache_ratio = (total_blocks > 0) ? (static_cast<float>(total_blocks_cached) / total_blocks * 100.0f) : 0.0f;
|
||||
|
||||
float step_cache_ratio = (total_steps > 0) ? (static_cast<float>(cached_steps.size()) / total_steps * 100.0f) : 0.0f;
|
||||
|
||||
LOG_INFO("CacheDIT: steps_cached=%zu/%d (%.1f%%), blocks_cached=%d/%d (%.1f%%), accum_diff=%.4f",
|
||||
cached_steps.size(), total_steps, step_cache_ratio,
|
||||
total_blocks_cached, total_blocks, cache_ratio,
|
||||
accumulated_residual_diff);
|
||||
}
|
||||
|
||||
std::string get_summary() const {
|
||||
char buf[256];
|
||||
snprintf(buf, sizeof(buf),
|
||||
"CacheDIT[thresh=%.2f]: cached %zu/%d steps, %d/%d blocks",
|
||||
config.dbcache.residual_diff_threshold,
|
||||
cached_steps.size(), total_steps,
|
||||
total_blocks_cached, total_blocks_computed + total_blocks_cached);
|
||||
return std::string(buf);
|
||||
}
|
||||
};
|
||||
|
||||
/// Parses a comma-separated list of integers (e.g. "1,1,0,1") into a step mask.
///
/// Empty or whitespace-only tokens are skipped wherever they occur, so "1,,0" and
/// "1,0," both yield {1, 0}. Each remaining token is converted with std::stoi,
/// which throws std::invalid_argument / std::out_of_range for non-numeric or
/// out-of-range input.
inline std::vector<int> parse_scm_mask(const std::string& mask_str) {
    std::vector<int> mask;

    size_t start = 0;
    while (start <= mask_str.size()) {
        const size_t comma = mask_str.find(',', start);
        const size_t stop  = (comma == std::string::npos) ? mask_str.size() : comma;

        const std::string token = mask_str.substr(start, stop - start);
        if (token.find_first_not_of(" \t\r\n") != std::string::npos) {
            mask.push_back(std::stoi(token));
        }

        if (comma == std::string::npos) {
            break;
        }
        start = comma + 1;
    }

    return mask;
}
|
||||
|
||||
/// Builds a step mask by alternating runs of 1 ("compute") and 0 ("cache").
///
/// Run lengths are taken pairwise: compute_bins[0] ones, cache_bins[0] zeros,
/// compute_bins[1] ones, and so on, truncated at `total_steps` entries.
///
/// If the bins cover fewer than `total_steps` steps, the remainder is filled with
/// 1 so that every step has an explicit entry and the real final step is a compute
/// step. Returns an empty mask when `total_steps` <= 0 or the bins contribute no
/// entries at all.
inline std::vector<int> generate_scm_mask(
    const std::vector<int>& compute_bins,
    const std::vector<int>& cache_bins,
    int total_steps) {
    std::vector<int> mask;
    if (total_steps <= 0) {
        return mask;
    }

    const size_t wanted = static_cast<size_t>(total_steps);
    mask.reserve(wanted);

    const auto append_run = [&mask, wanted](int length, int value) {
        for (int i = 0; i < length && mask.size() < wanted; ++i) {
            mask.push_back(value);
        }
    };

    const size_t rounds = std::max(compute_bins.size(), cache_bins.size());
    for (size_t r = 0; r < rounds && mask.size() < wanted; ++r) {
        if (r < compute_bins.size()) {
            append_run(compute_bins[r], 1);
        }
        if (r < cache_bins.size()) {
            append_run(cache_bins[r], 0);
        }
    }

    if (mask.empty()) {
        return mask;  // nothing scheduled: keep "no mask" semantics
    }

    mask.resize(wanted, 1);  // uncovered tail steps are computed
    mask.back() = 1;         // the final step is always computed
    return mask;
}
|
||||
|
||||
inline std::vector<int> get_scm_preset(const std::string& preset, int total_steps) {
|
||||
struct Preset {
|
||||
std::vector<int> compute_bins;
|
||||
std::vector<int> cache_bins;
|
||||
};
|
||||
|
||||
Preset slow = {{8, 3, 3, 2, 1, 1}, {1, 2, 2, 2, 3}};
|
||||
Preset medium = {{6, 2, 2, 2, 2, 1}, {1, 3, 3, 3, 3}};
|
||||
Preset fast = {{6, 1, 1, 1, 1, 1}, {1, 3, 4, 5, 4}};
|
||||
Preset ultra = {{4, 1, 1, 1, 1}, {2, 5, 6, 7}};
|
||||
|
||||
Preset* p = nullptr;
|
||||
if (preset == "slow" || preset == "s" || preset == "S")
|
||||
p = &slow;
|
||||
else if (preset == "medium" || preset == "m" || preset == "M")
|
||||
p = &medium;
|
||||
else if (preset == "fast" || preset == "f" || preset == "F")
|
||||
p = &fast;
|
||||
else if (preset == "ultra" || preset == "u" || preset == "U")
|
||||
p = &ultra;
|
||||
else
|
||||
return {};
|
||||
|
||||
if (total_steps != 28 && total_steps > 0) {
|
||||
float scale = static_cast<float>(total_steps) / 28.0f;
|
||||
std::vector<int> scaled_compute, scaled_cache;
|
||||
|
||||
for (int v : p->compute_bins) {
|
||||
scaled_compute.push_back(std::max(1, static_cast<int>(v * scale + 0.5f)));
|
||||
}
|
||||
for (int v : p->cache_bins) {
|
||||
scaled_cache.push_back(std::max(1, static_cast<int>(v * scale + 0.5f)));
|
||||
}
|
||||
|
||||
return generate_scm_mask(scaled_compute, scaled_cache, total_steps);
|
||||
}
|
||||
|
||||
return generate_scm_mask(p->compute_bins, p->cache_bins, total_steps);
|
||||
}
|
||||
|
||||
/// Residual-difference threshold associated with a preset name; 0.08 for unknown names.
inline float get_preset_threshold(const std::string& preset) {
    const auto named = [&preset](const char* full, const char* lower, const char* upper) {
        return preset == full || preset == lower || preset == upper;
    };

    float threshold = 0.08f;
    if (named("slow", "s", "S")) {
        threshold = 0.20f;
    } else if (named("medium", "m", "M")) {
        threshold = 0.25f;
    } else if (named("fast", "f", "F")) {
        threshold = 0.30f;
    } else if (named("ultra", "u", "U")) {
        threshold = 0.34f;
    }
    return threshold;
}
|
||||
|
||||
/// Warm-up step count associated with a preset name; 8 for unknown names.
inline int get_preset_warmup(const std::string& preset) {
    const auto named = [&preset](const char* full, const char* lower, const char* upper) {
        return preset == full || preset == lower || preset == upper;
    };

    int warmup = 8;
    if (named("slow", "s", "S")) {
        warmup = 8;
    } else if (named("medium", "m", "M")) {
        warmup = 6;
    } else if (named("fast", "f", "F")) {
        warmup = 6;
    } else if (named("ultra", "u", "U")) {
        warmup = 4;
    }
    return warmup;
}
|
||||
|
||||
/// Fn (leading block) count associated with a preset name; 8 for unknown names.
inline int get_preset_Fn(const std::string& preset) {
    const auto named = [&preset](const char* full, const char* lower, const char* upper) {
        return preset == full || preset == lower || preset == upper;
    };

    int front_blocks = 8;
    if (named("slow", "s", "S")) {
        front_blocks = 8;
    } else if (named("medium", "m", "M")) {
        front_blocks = 8;
    } else if (named("fast", "f", "F")) {
        front_blocks = 6;
    } else if (named("ultra", "u", "U")) {
        front_blocks = 4;
    }
    return front_blocks;
}
|
||||
|
||||
/// Bn (trailing block) count associated with a preset name.
/// Every preset currently maps to 0; the argument is accepted but not inspected.
inline int get_preset_Bn(const std::string& preset) {
    static_cast<void>(preset);
    const int back_blocks = 0;
    return back_blocks;
}
|
||||
|
||||
/// Parses "Fn,Bn,threshold,warmup,max_cached,max_continuous" into `cfg`.
///
/// Fields are positional and may be omitted from the right. sscanf stores only
/// the values it converts, so scanning straight into `cfg` leaves every omitted
/// or unparsable trailing field at the value `cfg` already holds, instead of
/// overwriting it with a second, hard-coded copy of the defaults. For a
/// default-constructed `cfg` the result is the same as before.
inline void parse_dbcache_options(const std::string& opts, DBCacheConfig& cfg) {
    if (opts.empty()) {
        return;
    }

    // The conversion count is intentionally not treated as an error: a partial
    // option string is valid input.
    (void)sscanf(opts.c_str(), "%d,%d,%f,%d,%d,%d",
                 &cfg.Fn_compute_blocks,
                 &cfg.Bn_compute_blocks,
                 &cfg.residual_diff_threshold,
                 &cfg.max_warmup_steps,
                 &cfg.max_cached_steps,
                 &cfg.max_continuous_cached_steps);
}
|
||||
|
||||
/// Parses "n_derivatives,warmup,interval" into `cfg`.
///
/// Fields are positional and may be omitted from the right. sscanf stores only
/// the values it converts, so scanning straight into `cfg` leaves every omitted
/// or unparsable trailing field at the value `cfg` already holds, instead of
/// overwriting it with a second, hard-coded copy of the defaults. For a
/// default-constructed `cfg` the result is the same as before.
inline void parse_taylorseer_options(const std::string& opts, TaylorSeerConfig& cfg) {
    if (opts.empty()) {
        return;
    }

    // The conversion count is intentionally not treated as an error: a partial
    // option string is valid input.
    (void)sscanf(opts.c_str(), "%d,%d,%d",
                 &cfg.n_derivatives,
                 &cfg.max_warmup_steps,
                 &cfg.skip_interval_steps);
}
|
||||
|
||||
struct CacheDitConditionState {
|
||||
DBCacheConfig config;
|
||||
TaylorSeerConfig taylor_config;
|
||||
bool initialized = false;
|
||||
|
||||
int current_step_index = -1;
|
||||
bool step_active = false;
|
||||
bool skip_current_step = false;
|
||||
bool initial_step = true;
|
||||
int warmup_remaining = 0;
|
||||
std::vector<int> cached_steps;
|
||||
int continuous_cached_steps = 0;
|
||||
float accumulated_residual_diff = 0.0f;
|
||||
int total_steps_skipped = 0;
|
||||
|
||||
const void* anchor_condition = nullptr;
|
||||
|
||||
struct CacheEntry {
|
||||
std::vector<float> diff;
|
||||
std::vector<float> prev_input;
|
||||
std::vector<float> prev_output;
|
||||
bool has_prev = false;
|
||||
};
|
||||
std::unordered_map<const void*, CacheEntry> cache_diffs;
|
||||
|
||||
TaylorSeerState taylor_state;
|
||||
|
||||
float start_sigma = std::numeric_limits<float>::max();
|
||||
float end_sigma = 0.0f;
|
||||
|
||||
void reset_runtime() {
|
||||
current_step_index = -1;
|
||||
step_active = false;
|
||||
skip_current_step = false;
|
||||
initial_step = true;
|
||||
warmup_remaining = config.max_warmup_steps;
|
||||
cached_steps.clear();
|
||||
continuous_cached_steps = 0;
|
||||
accumulated_residual_diff = 0.0f;
|
||||
total_steps_skipped = 0;
|
||||
anchor_condition = nullptr;
|
||||
cache_diffs.clear();
|
||||
taylor_state.reset();
|
||||
}
|
||||
|
||||
void init(const DBCacheConfig& dbcfg, const TaylorSeerConfig& tcfg) {
|
||||
config = dbcfg;
|
||||
taylor_config = tcfg;
|
||||
initialized = dbcfg.enabled || tcfg.enabled;
|
||||
reset_runtime();
|
||||
|
||||
if (taylor_config.enabled) {
|
||||
taylor_state.init(taylor_config.n_derivatives, 0);
|
||||
}
|
||||
}
|
||||
|
||||
void set_sigmas(const std::vector<float>& sigmas) {
|
||||
if (!initialized || sigmas.size() < 2)
|
||||
return;
|
||||
|
||||
float start_percent = 0.15f;
|
||||
float end_percent = 0.95f;
|
||||
|
||||
size_t n_steps = sigmas.size() - 1;
|
||||
size_t start_step = static_cast<size_t>(start_percent * n_steps);
|
||||
size_t end_step = static_cast<size_t>(end_percent * n_steps);
|
||||
|
||||
if (start_step >= n_steps)
|
||||
start_step = n_steps - 1;
|
||||
if (end_step >= n_steps)
|
||||
end_step = n_steps - 1;
|
||||
|
||||
start_sigma = sigmas[start_step];
|
||||
end_sigma = sigmas[end_step];
|
||||
|
||||
if (start_sigma < end_sigma) {
|
||||
std::swap(start_sigma, end_sigma);
|
||||
}
|
||||
}
|
||||
|
||||
bool enabled() const {
|
||||
return initialized && (config.enabled || taylor_config.enabled);
|
||||
}
|
||||
|
||||
void begin_step(int step_index, float sigma) {
|
||||
if (!enabled())
|
||||
return;
|
||||
if (step_index == current_step_index)
|
||||
return;
|
||||
|
||||
current_step_index = step_index;
|
||||
skip_current_step = false;
|
||||
step_active = false;
|
||||
|
||||
if (sigma > start_sigma)
|
||||
return;
|
||||
if (!(sigma > end_sigma))
|
||||
return;
|
||||
|
||||
step_active = true;
|
||||
|
||||
if (warmup_remaining > 0) {
|
||||
warmup_remaining--;
|
||||
return;
|
||||
}
|
||||
|
||||
if (!config.steps_computation_mask.empty()) {
|
||||
if (step_index < static_cast<int>(config.steps_computation_mask.size())) {
|
||||
if (config.steps_computation_mask[step_index] == 1) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (config.max_cached_steps >= 0 &&
|
||||
static_cast<int>(cached_steps.size()) >= config.max_cached_steps) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (config.max_continuous_cached_steps >= 0 &&
|
||||
continuous_cached_steps >= config.max_continuous_cached_steps) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
bool step_is_active() const {
|
||||
return enabled() && step_active;
|
||||
}
|
||||
|
||||
bool is_step_skipped() const {
|
||||
return enabled() && step_active && skip_current_step;
|
||||
}
|
||||
|
||||
bool has_cache(const void* cond) const {
|
||||
auto it = cache_diffs.find(cond);
|
||||
return it != cache_diffs.end() && !it->second.diff.empty();
|
||||
}
|
||||
|
||||
void update_cache(const void* cond, const float* input, const float* output, size_t size) {
|
||||
CacheEntry& entry = cache_diffs[cond];
|
||||
entry.diff.resize(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
entry.diff[i] = output[i] - input[i];
|
||||
}
|
||||
|
||||
entry.prev_input.resize(size);
|
||||
entry.prev_output.resize(size);
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
entry.prev_input[i] = input[i];
|
||||
entry.prev_output[i] = output[i];
|
||||
}
|
||||
entry.has_prev = true;
|
||||
}
|
||||
|
||||
void apply_cache(const void* cond, const float* input, float* output, size_t size) {
|
||||
auto it = cache_diffs.find(cond);
|
||||
if (it == cache_diffs.end() || it->second.diff.empty())
|
||||
return;
|
||||
if (it->second.diff.size() != size)
|
||||
return;
|
||||
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
output[i] = input[i] + it->second.diff[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool before_condition(const void* cond, struct ggml_tensor* input, struct ggml_tensor* output, float sigma, int step_index) {
|
||||
if (!enabled() || step_index < 0)
|
||||
return false;
|
||||
|
||||
if (step_index != current_step_index) {
|
||||
begin_step(step_index, sigma);
|
||||
}
|
||||
|
||||
if (!step_active)
|
||||
return false;
|
||||
|
||||
if (initial_step) {
|
||||
anchor_condition = cond;
|
||||
initial_step = false;
|
||||
}
|
||||
|
||||
bool is_anchor = (cond == anchor_condition);
|
||||
|
||||
if (skip_current_step) {
|
||||
if (has_cache(cond)) {
|
||||
apply_cache(cond, (float*)input->data, (float*)output->data,
|
||||
static_cast<size_t>(ggml_nelements(output)));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!is_anchor)
|
||||
return false;
|
||||
|
||||
auto it = cache_diffs.find(cond);
|
||||
if (it == cache_diffs.end() || !it->second.has_prev)
|
||||
return false;
|
||||
|
||||
size_t ne = static_cast<size_t>(ggml_nelements(input));
|
||||
if (it->second.prev_input.size() != ne)
|
||||
return false;
|
||||
|
||||
float* input_data = (float*)input->data;
|
||||
float diff = CacheDitState::calculate_residual_diff(
|
||||
it->second.prev_input.data(), input_data, ne);
|
||||
|
||||
float effective_threshold = config.residual_diff_threshold;
|
||||
if (config.Fn_compute_blocks > 0) {
|
||||
float fn_confidence = 1.0f + 0.02f * (config.Fn_compute_blocks - 8);
|
||||
fn_confidence = std::max(0.5f, std::min(2.0f, fn_confidence));
|
||||
effective_threshold *= fn_confidence;
|
||||
}
|
||||
if (config.Bn_compute_blocks > 0) {
|
||||
float bn_quality = 1.0f - 0.03f * config.Bn_compute_blocks;
|
||||
bn_quality = std::max(0.5f, std::min(1.0f, bn_quality));
|
||||
effective_threshold *= bn_quality;
|
||||
}
|
||||
|
||||
if (diff < effective_threshold) {
|
||||
skip_current_step = true;
|
||||
total_steps_skipped++;
|
||||
cached_steps.push_back(current_step_index);
|
||||
continuous_cached_steps++;
|
||||
accumulated_residual_diff += diff;
|
||||
apply_cache(cond, input_data, (float*)output->data, ne);
|
||||
return true;
|
||||
}
|
||||
|
||||
continuous_cached_steps = 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
void after_condition(const void* cond, struct ggml_tensor* input, struct ggml_tensor* output) {
|
||||
if (!step_is_active())
|
||||
return;
|
||||
|
||||
size_t ne = static_cast<size_t>(ggml_nelements(output));
|
||||
update_cache(cond, (float*)input->data, (float*)output->data, ne);
|
||||
|
||||
if (cond == anchor_condition && taylor_config.enabled) {
|
||||
taylor_state.update_derivatives((float*)output->data, ne, current_step_index);
|
||||
}
|
||||
}
|
||||
|
||||
void log_metrics() const {
|
||||
if (!enabled())
|
||||
return;
|
||||
|
||||
LOG_INFO("CacheDIT: steps_skipped=%d/%d (%.1f%%), accum_residual_diff=%.4f",
|
||||
total_steps_skipped,
|
||||
current_step_index + 1,
|
||||
(current_step_index > 0) ? (100.0f * total_steps_skipped / (current_step_index + 1)) : 0.0f,
|
||||
accumulated_residual_diff);
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
126
docs/caching.md
Normal file
126
docs/caching.md
Normal file
@ -0,0 +1,126 @@
|
||||
## Caching
|
||||
|
||||
Caching methods accelerate diffusion inference by reusing intermediate computations when changes between steps are small.
|
||||
|
||||
### Cache Modes
|
||||
|
||||
| Mode | Target | Description |
|
||||
|------|--------|-------------|
|
||||
| `ucache` | UNET models | Condition-level caching with error tracking |
|
||||
| `easycache` | DiT models | Condition-level cache |
|
||||
| `dbcache` | DiT models | Block-level L1 residual threshold |
|
||||
| `taylorseer` | DiT models | Taylor series approximation |
|
||||
| `cache-dit` | DiT models | Combined DBCache + TaylorSeer |
|
||||
|
||||
### UCache (UNET Models)
|
||||
|
||||
UCache caches the residual difference (output - input) and reuses it when input changes are below threshold.
|
||||
|
||||
```bash
|
||||
sd-cli -m model.safetensors -p "a cat" --cache-mode ucache --cache-option "threshold=1.5"
|
||||
```
|
||||
|
||||
#### Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `threshold` | Error threshold for reuse decision | 1.0 |
|
||||
| `start` | Start caching at this percent of steps | 0.15 |
|
||||
| `end` | Stop caching at this percent of steps | 0.95 |
|
||||
| `decay` | Error decay rate (0-1) | 1.0 |
|
||||
| `relative` | Scale threshold by output norm (0/1) | 1 |
|
||||
| `reset` | Reset error after computing (0/1) | 1 |
|
||||
|
||||
#### Reset Parameter
|
||||
|
||||
The `reset` parameter controls error accumulation behavior:
|
||||
|
||||
- `reset=1` (default): Resets accumulated error after each computed step. More aggressive caching, works well with most samplers.
|
||||
- `reset=0`: Keeps error accumulated. More conservative, recommended for `euler_a` sampler.
|
||||
|
||||
### EasyCache (DiT Models)
|
||||
|
||||
Condition-level caching for DiT models. Caches each condition's output and reuses it when the change in the input is below the threshold.
|
||||
|
||||
```bash
|
||||
--cache-mode easycache --cache-option "threshold=0.3"
|
||||
```
|
||||
|
||||
#### Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `threshold` | Input change threshold for reuse | 0.2 |
|
||||
| `start` | Start caching at this percent of steps | 0.15 |
|
||||
| `end` | Stop caching at this percent of steps | 0.95 |
|
||||
|
||||
### Cache-DIT (DiT Models)
|
||||
|
||||
For DiT models like FLUX and QWEN, use block-level caching modes.
|
||||
|
||||
#### DBCache
|
||||
|
||||
Caches blocks based on L1 residual difference threshold:
|
||||
|
||||
```bash
|
||||
--cache-mode dbcache --cache-option "threshold=0.25,warmup=4"
|
||||
```
|
||||
|
||||
#### TaylorSeer
|
||||
|
||||
Uses Taylor series approximation to predict block outputs:
|
||||
|
||||
```bash
|
||||
--cache-mode taylorseer
|
||||
```
|
||||
|
||||
#### Cache-DIT (Combined)
|
||||
|
||||
Combines DBCache and TaylorSeer:
|
||||
|
||||
```bash
|
||||
--cache-mode cache-dit --cache-preset fast
|
||||
```
|
||||
|
||||
#### Parameters
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `Fn` | Front blocks to always compute | 8 |
|
||||
| `Bn` | Back blocks to always compute | 0 |
|
||||
| `threshold` | L1 residual difference threshold | 0.08 |
|
||||
| `warmup` | Steps before caching starts | 8 |
|
||||
|
||||
#### Presets
|
||||
|
||||
Available presets: `slow`, `medium`, `fast`, `ultra` (or `s`, `m`, `f`, `u`).
|
||||
|
||||
```bash
|
||||
--cache-mode cache-dit --cache-preset fast
|
||||
```
|
||||
|
||||
#### SCM Options
|
||||
|
||||
The Steps Computation Mask (SCM) controls which steps can be cached:
|
||||
|
||||
```bash
|
||||
--scm-mask "1,1,1,1,0,0,1,0,0,0,1,0,0,0,1,0,0,0,1,1"
|
||||
```
|
||||
|
||||
Mask values: `1` = compute, `0` = can cache.
|
||||
|
||||
| Policy | Description |
|
||||
|--------|-------------|
|
||||
| `dynamic` | Check threshold before caching (default) |
|
||||
| `static` | Always cache on cacheable steps |
|
||||
|
||||
```bash
|
||||
--scm-policy dynamic
|
||||
```
|
||||
|
||||
### Performance Tips
|
||||
|
||||
- Start with default thresholds and adjust based on output quality
|
||||
- Lower threshold = better quality, less speedup
|
||||
- Higher threshold = more speedup, potential quality loss
|
||||
- More steps generally means more caching opportunities
|
||||
@ -127,5 +127,12 @@ Generation Options:
|
||||
--skip-layers layers to skip for SLG steps (default: [7,8,9])
|
||||
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
|
||||
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
|
||||
--easycache enable EasyCache for DiT models with optional "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95)
|
||||
--cache-mode caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)
|
||||
--cache-option named cache params (key=value format, comma-separated):
|
||||
- easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=
|
||||
- dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=
|
||||
Examples: "threshold=0.25" or "threshold=1.5,reset=0"
|
||||
--cache-preset cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'
|
||||
--scm-mask SCM steps mask: comma-separated 0/1 (1=compute, 0=can cache)
|
||||
--scm-policy SCM policy: 'dynamic' (default) or 'static'
|
||||
```
|
||||
|
||||
@ -617,7 +617,7 @@ int main(int argc, const char* argv[]) {
|
||||
gen_params.pm_style_strength,
|
||||
}, // pm_params
|
||||
ctx_params.vae_tiling_params,
|
||||
gen_params.easycache_params,
|
||||
gen_params.cache_params,
|
||||
};
|
||||
|
||||
results = generate_image(sd_ctx, &img_gen_params);
|
||||
@ -642,7 +642,7 @@ int main(int argc, const char* argv[]) {
|
||||
gen_params.seed,
|
||||
gen_params.video_frames,
|
||||
gen_params.vace_strength,
|
||||
gen_params.easycache_params,
|
||||
gen_params.cache_params,
|
||||
};
|
||||
|
||||
results = generate_video(sd_ctx, &vid_gen_params, &num_results);
|
||||
|
||||
@ -1018,8 +1018,12 @@ struct SDGenerationParams {
|
||||
|
||||
std::vector<float> custom_sigmas;
|
||||
|
||||
std::string easycache_option;
|
||||
sd_easycache_params_t easycache_params;
|
||||
std::string cache_mode;
|
||||
std::string cache_option;
|
||||
std::string cache_preset;
|
||||
std::string scm_mask;
|
||||
bool scm_policy_dynamic = true;
|
||||
sd_cache_params_t cache_params{};
|
||||
|
||||
float moe_boundary = 0.875f;
|
||||
int video_frames = 1;
|
||||
@ -1381,36 +1385,64 @@ struct SDGenerationParams {
|
||||
return 1;
|
||||
};
|
||||
|
||||
auto on_easycache_arg = [&](int argc, const char** argv, int index) {
|
||||
const std::string default_values = "0.2,0.15,0.95";
|
||||
auto looks_like_value = [](const std::string& token) {
|
||||
if (token.empty()) {
|
||||
return false;
|
||||
}
|
||||
if (token[0] != '-') {
|
||||
return true;
|
||||
}
|
||||
if (token.size() == 1) {
|
||||
return false;
|
||||
}
|
||||
unsigned char next = static_cast<unsigned char>(token[1]);
|
||||
return std::isdigit(next) || token[1] == '.';
|
||||
};
|
||||
// Handler for `--cache-mode <name>`.
// Stores the token following the flag in `cache_mode` and checks it against
// the list of known modes. Returns 1 on success, -1 when the value is missing
// or not a recognised mode (the rejected text stays in `cache_mode`).
auto on_cache_mode_arg = [&](int argc, const char** argv, int index) {
    static const char* const known_modes[] = {
        "easycache", "ucache", "dbcache", "taylorseer", "cache-dit"};

    const int value_index = index + 1;
    if (value_index >= argc) {
        return -1;
    }

    cache_mode = argv_to_utf8(value_index, argv);

    bool recognised = false;
    for (const char* candidate : known_modes) {
        if (cache_mode == candidate) {
            recognised = true;
            break;
        }
    }
    if (!recognised) {
        fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache', 'ucache', 'dbcache', 'taylorseer', or 'cache-dit'\n", cache_mode.c_str());
        return -1;
    }
    return 1;
};
|
||||
|
||||
std::string option_value;
|
||||
int consumed = 0;
|
||||
if (index + 1 < argc) {
|
||||
std::string next_arg = argv[index + 1];
|
||||
if (looks_like_value(next_arg)) {
|
||||
option_value = argv_to_utf8(index + 1, argv);
|
||||
consumed = 1;
|
||||
}
|
||||
auto on_cache_option_arg = [&](int argc, const char** argv, int index) {
|
||||
if (++index >= argc) {
|
||||
return -1;
|
||||
}
|
||||
if (option_value.empty()) {
|
||||
option_value = default_values;
|
||||
cache_option = argv_to_utf8(index, argv);
|
||||
return 1;
|
||||
};
|
||||
|
||||
auto on_scm_mask_arg = [&](int argc, const char** argv, int index) {
|
||||
if (++index >= argc) {
|
||||
return -1;
|
||||
}
|
||||
easycache_option = option_value;
|
||||
return consumed;
|
||||
scm_mask = argv_to_utf8(index, argv);
|
||||
return 1;
|
||||
};
|
||||
|
||||
auto on_scm_policy_arg = [&](int argc, const char** argv, int index) {
|
||||
if (++index >= argc) {
|
||||
return -1;
|
||||
}
|
||||
std::string policy = argv_to_utf8(index, argv);
|
||||
if (policy == "dynamic") {
|
||||
scm_policy_dynamic = true;
|
||||
} else if (policy == "static") {
|
||||
scm_policy_dynamic = false;
|
||||
} else {
|
||||
fprintf(stderr, "error: invalid scm policy '%s', must be 'dynamic' or 'static'\n", policy.c_str());
|
||||
return -1;
|
||||
}
|
||||
return 1;
|
||||
};
|
||||
|
||||
auto on_cache_preset_arg = [&](int argc, const char** argv, int index) {
|
||||
if (++index >= argc) {
|
||||
return -1;
|
||||
}
|
||||
cache_preset = argv_to_utf8(index, argv);
|
||||
if (cache_preset != "slow" && cache_preset != "s" && cache_preset != "S" &&
|
||||
cache_preset != "medium" && cache_preset != "m" && cache_preset != "M" &&
|
||||
cache_preset != "fast" && cache_preset != "f" && cache_preset != "F" &&
|
||||
cache_preset != "ultra" && cache_preset != "u" && cache_preset != "U") {
|
||||
fprintf(stderr, "error: invalid cache preset '%s', must be 'slow'/'s', 'medium'/'m', 'fast'/'f', or 'ultra'/'u'\n", cache_preset.c_str());
|
||||
return -1;
|
||||
}
|
||||
return 1;
|
||||
};
|
||||
|
||||
options.manual_options = {
|
||||
@ -1449,9 +1481,25 @@ struct SDGenerationParams {
|
||||
"reference image for Flux Kontext models (can be used multiple times)",
|
||||
on_ref_image_arg},
|
||||
{"",
|
||||
"--easycache",
|
||||
"enable EasyCache for DiT models with optional \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95)",
|
||||
on_easycache_arg},
|
||||
"--cache-mode",
|
||||
"caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)",
|
||||
on_cache_mode_arg},
|
||||
{"",
|
||||
"--cache-option",
|
||||
"named cache params (key=value format, comma-separated):\n - easycache/ucache: threshold=,start=,end=,decay=,relative=,reset=\n - dbcache/taylorseer/cache-dit: Fn=,Bn=,threshold=,warmup=\n Examples: \"threshold=0.25\" or \"threshold=1.5,reset=0\"",
|
||||
on_cache_option_arg},
|
||||
{"",
|
||||
"--cache-preset",
|
||||
"cache-dit preset: 'slow'/'s', 'medium'/'m', 'fast'/'f', 'ultra'/'u'",
|
||||
on_cache_preset_arg},
|
||||
{"",
|
||||
"--scm-mask",
|
||||
"SCM steps mask for cache-dit: comma-separated 0/1 (e.g., \"1,1,1,0,0,1,0,0,1,0\") - 1=compute, 0=can cache",
|
||||
on_scm_mask_arg},
|
||||
{"",
|
||||
"--scm-policy",
|
||||
"SCM policy: 'dynamic' (default) or 'static'",
|
||||
on_scm_policy_arg},
|
||||
|
||||
};
|
||||
|
||||
@ -1494,7 +1542,10 @@ struct SDGenerationParams {
|
||||
|
||||
load_if_exists("prompt", prompt);
|
||||
load_if_exists("negative_prompt", negative_prompt);
|
||||
load_if_exists("easycache_option", easycache_option);
|
||||
load_if_exists("cache_mode", cache_mode);
|
||||
load_if_exists("cache_option", cache_option);
|
||||
load_if_exists("cache_preset", cache_preset);
|
||||
load_if_exists("scm_mask", scm_mask);
|
||||
|
||||
load_if_exists("clip_skip", clip_skip);
|
||||
load_if_exists("width", width);
|
||||
@ -1634,57 +1685,118 @@ struct SDGenerationParams {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!easycache_option.empty()) {
|
||||
float values[3] = {0.0f, 0.0f, 0.0f};
|
||||
std::stringstream ss(easycache_option);
|
||||
sd_cache_params_init(&cache_params);
|
||||
|
||||
auto parse_named_params = [&](const std::string& opt_str) -> bool {
|
||||
std::stringstream ss(opt_str);
|
||||
std::string token;
|
||||
int idx = 0;
|
||||
while (std::getline(ss, token, ',')) {
|
||||
auto trim = [](std::string& s) {
|
||||
const char* whitespace = " \t\r\n";
|
||||
auto start = s.find_first_not_of(whitespace);
|
||||
if (start == std::string::npos) {
|
||||
s.clear();
|
||||
return;
|
||||
}
|
||||
auto end = s.find_last_not_of(whitespace);
|
||||
s = s.substr(start, end - start + 1);
|
||||
};
|
||||
trim(token);
|
||||
if (token.empty()) {
|
||||
LOG_ERROR("error: invalid easycache option '%s'", easycache_option.c_str());
|
||||
return false;
|
||||
}
|
||||
if (idx >= 3) {
|
||||
LOG_ERROR("error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
|
||||
size_t eq_pos = token.find('=');
|
||||
if (eq_pos == std::string::npos) {
|
||||
LOG_ERROR("error: cache option '%s' missing '=' separator", token.c_str());
|
||||
return false;
|
||||
}
|
||||
std::string key = token.substr(0, eq_pos);
|
||||
std::string val = token.substr(eq_pos + 1);
|
||||
try {
|
||||
values[idx] = std::stof(token);
|
||||
if (key == "threshold") {
|
||||
if (cache_mode == "easycache" || cache_mode == "ucache") {
|
||||
cache_params.reuse_threshold = std::stof(val);
|
||||
} else {
|
||||
cache_params.residual_diff_threshold = std::stof(val);
|
||||
}
|
||||
} else if (key == "start") {
|
||||
cache_params.start_percent = std::stof(val);
|
||||
} else if (key == "end") {
|
||||
cache_params.end_percent = std::stof(val);
|
||||
} else if (key == "decay") {
|
||||
cache_params.error_decay_rate = std::stof(val);
|
||||
} else if (key == "relative") {
|
||||
cache_params.use_relative_threshold = (std::stof(val) != 0.0f);
|
||||
} else if (key == "reset") {
|
||||
cache_params.reset_error_on_compute = (std::stof(val) != 0.0f);
|
||||
} else if (key == "Fn" || key == "fn") {
|
||||
cache_params.Fn_compute_blocks = std::stoi(val);
|
||||
} else if (key == "Bn" || key == "bn") {
|
||||
cache_params.Bn_compute_blocks = std::stoi(val);
|
||||
} else if (key == "warmup") {
|
||||
cache_params.max_warmup_steps = std::stoi(val);
|
||||
} else {
|
||||
LOG_ERROR("error: unknown cache parameter '%s'", key.c_str());
|
||||
return false;
|
||||
}
|
||||
} catch (const std::exception&) {
|
||||
LOG_ERROR("error: invalid easycache value '%s'", token.c_str());
|
||||
LOG_ERROR("error: invalid value '%s' for parameter '%s'", val.c_str(), key.c_str());
|
||||
return false;
|
||||
}
|
||||
idx++;
|
||||
}
|
||||
if (idx != 3) {
|
||||
LOG_ERROR("error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
|
||||
return false;
|
||||
return true;
|
||||
};
|
||||
|
||||
if (!cache_mode.empty()) {
|
||||
if (cache_mode == "easycache") {
|
||||
cache_params.mode = SD_CACHE_EASYCACHE;
|
||||
cache_params.reuse_threshold = 0.2f;
|
||||
cache_params.start_percent = 0.15f;
|
||||
cache_params.end_percent = 0.95f;
|
||||
cache_params.error_decay_rate = 1.0f;
|
||||
cache_params.use_relative_threshold = true;
|
||||
cache_params.reset_error_on_compute = true;
|
||||
} else if (cache_mode == "ucache") {
|
||||
cache_params.mode = SD_CACHE_UCACHE;
|
||||
cache_params.reuse_threshold = 1.0f;
|
||||
cache_params.start_percent = 0.15f;
|
||||
cache_params.end_percent = 0.95f;
|
||||
cache_params.error_decay_rate = 1.0f;
|
||||
cache_params.use_relative_threshold = true;
|
||||
cache_params.reset_error_on_compute = true;
|
||||
} else if (cache_mode == "dbcache") {
|
||||
cache_params.mode = SD_CACHE_DBCACHE;
|
||||
cache_params.Fn_compute_blocks = 8;
|
||||
cache_params.Bn_compute_blocks = 0;
|
||||
cache_params.residual_diff_threshold = 0.08f;
|
||||
cache_params.max_warmup_steps = 8;
|
||||
} else if (cache_mode == "taylorseer") {
|
||||
cache_params.mode = SD_CACHE_TAYLORSEER;
|
||||
cache_params.Fn_compute_blocks = 8;
|
||||
cache_params.Bn_compute_blocks = 0;
|
||||
cache_params.residual_diff_threshold = 0.08f;
|
||||
cache_params.max_warmup_steps = 8;
|
||||
} else if (cache_mode == "cache-dit") {
|
||||
cache_params.mode = SD_CACHE_CACHE_DIT;
|
||||
cache_params.Fn_compute_blocks = 8;
|
||||
cache_params.Bn_compute_blocks = 0;
|
||||
cache_params.residual_diff_threshold = 0.08f;
|
||||
cache_params.max_warmup_steps = 8;
|
||||
}
|
||||
if (values[0] < 0.0f) {
|
||||
LOG_ERROR("error: easycache threshold must be non-negative\n");
|
||||
return false;
|
||||
|
||||
if (!cache_option.empty()) {
|
||||
if (!parse_named_params(cache_option)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
|
||||
LOG_ERROR("error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
|
||||
return false;
|
||||
|
||||
if (cache_mode == "easycache" || cache_mode == "ucache") {
|
||||
if (cache_params.reuse_threshold < 0.0f) {
|
||||
LOG_ERROR("error: cache threshold must be non-negative");
|
||||
return false;
|
||||
}
|
||||
if (cache_params.start_percent < 0.0f || cache_params.start_percent >= 1.0f ||
|
||||
cache_params.end_percent <= 0.0f || cache_params.end_percent > 1.0f ||
|
||||
cache_params.start_percent >= cache_params.end_percent) {
|
||||
LOG_ERROR("error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
easycache_params.enabled = true;
|
||||
easycache_params.reuse_threshold = values[0];
|
||||
easycache_params.start_percent = values[1];
|
||||
easycache_params.end_percent = values[2];
|
||||
} else {
|
||||
easycache_params.enabled = false;
|
||||
}
|
||||
|
||||
if (cache_params.mode == SD_CACHE_DBCACHE ||
|
||||
cache_params.mode == SD_CACHE_TAYLORSEER ||
|
||||
cache_params.mode == SD_CACHE_CACHE_DIT) {
|
||||
if (!scm_mask.empty()) {
|
||||
cache_params.scm_mask = scm_mask.c_str();
|
||||
}
|
||||
cache_params.scm_policy_dynamic = scm_policy_dynamic;
|
||||
}
|
||||
|
||||
sample_params.guidance.slg.layers = skip_layers.data();
|
||||
@ -1786,12 +1898,13 @@ struct SDGenerationParams {
|
||||
<< " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n"
|
||||
<< " high_noise_sample_params: " << high_noise_sample_params_str << ",\n"
|
||||
<< " custom_sigmas: " << vec_to_string(custom_sigmas) << ",\n"
|
||||
<< " easycache_option: \"" << easycache_option << "\",\n"
|
||||
<< " easycache: "
|
||||
<< (easycache_params.enabled ? "enabled" : "disabled")
|
||||
<< " (threshold=" << easycache_params.reuse_threshold
|
||||
<< ", start=" << easycache_params.start_percent
|
||||
<< ", end=" << easycache_params.end_percent << "),\n"
|
||||
<< " cache_mode: \"" << cache_mode << "\",\n"
|
||||
<< " cache_option: \"" << cache_option << "\",\n"
|
||||
<< " cache: "
|
||||
<< (cache_params.mode != SD_CACHE_DISABLED ? "enabled" : "disabled")
|
||||
<< " (threshold=" << cache_params.reuse_threshold
|
||||
<< ", start=" << cache_params.start_percent
|
||||
<< ", end=" << cache_params.end_percent << "),\n"
|
||||
<< " moe_boundary: " << moe_boundary << ",\n"
|
||||
<< " video_frames: " << video_frames << ",\n"
|
||||
<< " fps: " << fps << ",\n"
|
||||
|
||||
@ -432,7 +432,7 @@ int main(int argc, const char** argv) {
|
||||
gen_params.pm_style_strength,
|
||||
}, // pm_params
|
||||
ctx_params.vae_tiling_params,
|
||||
gen_params.easycache_params,
|
||||
gen_params.cache_params,
|
||||
};
|
||||
|
||||
sd_image_t* results = nullptr;
|
||||
@ -645,7 +645,7 @@ int main(int argc, const char** argv) {
|
||||
gen_params.pm_style_strength,
|
||||
}, // pm_params
|
||||
ctx_params.vae_tiling_params,
|
||||
gen_params.easycache_params,
|
||||
gen_params.cache_params,
|
||||
};
|
||||
|
||||
sd_image_t* results = nullptr;
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
#include "stable-diffusion.h"
|
||||
#include "util.h"
|
||||
|
||||
#include "cache_dit.hpp"
|
||||
#include "conditioner.hpp"
|
||||
#include "control.hpp"
|
||||
#include "denoiser.hpp"
|
||||
@ -16,6 +17,7 @@
|
||||
#include "lora.hpp"
|
||||
#include "pmid.hpp"
|
||||
#include "tae.hpp"
|
||||
#include "ucache.hpp"
|
||||
#include "vae.hpp"
|
||||
|
||||
#include "latent-preview.h"
|
||||
@ -1525,12 +1527,12 @@ public:
|
||||
const std::vector<float>& sigmas,
|
||||
int start_merge_step,
|
||||
SDCondition id_cond,
|
||||
std::vector<ggml_tensor*> ref_latents = {},
|
||||
bool increase_ref_index = false,
|
||||
ggml_tensor* denoise_mask = nullptr,
|
||||
ggml_tensor* vace_context = nullptr,
|
||||
float vace_strength = 1.f,
|
||||
const sd_easycache_params_t* easycache_params = nullptr) {
|
||||
std::vector<ggml_tensor*> ref_latents = {},
|
||||
bool increase_ref_index = false,
|
||||
ggml_tensor* denoise_mask = nullptr,
|
||||
ggml_tensor* vace_context = nullptr,
|
||||
float vace_strength = 1.f,
|
||||
const sd_cache_params_t* cache_params = nullptr) {
|
||||
if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) {
|
||||
LOG_WARN("timestep shifting is only supported for SDXL models!");
|
||||
shifted_timestep = 0;
|
||||
@ -1558,31 +1560,40 @@ public:
|
||||
}
|
||||
|
||||
EasyCacheState easycache_state;
|
||||
UCacheState ucache_state;
|
||||
CacheDitConditionState cachedit_state;
|
||||
bool easycache_enabled = false;
|
||||
if (easycache_params != nullptr && easycache_params->enabled) {
|
||||
bool easycache_supported = sd_version_is_dit(version);
|
||||
if (!easycache_supported) {
|
||||
LOG_WARN("EasyCache requested but not supported for this model type");
|
||||
} else {
|
||||
EasyCacheConfig easycache_config;
|
||||
easycache_config.enabled = true;
|
||||
easycache_config.reuse_threshold = std::max(0.0f, easycache_params->reuse_threshold);
|
||||
easycache_config.start_percent = easycache_params->start_percent;
|
||||
easycache_config.end_percent = easycache_params->end_percent;
|
||||
bool percent_valid = easycache_config.start_percent >= 0.0f &&
|
||||
easycache_config.start_percent < 1.0f &&
|
||||
easycache_config.end_percent > 0.0f &&
|
||||
easycache_config.end_percent <= 1.0f &&
|
||||
easycache_config.start_percent < easycache_config.end_percent;
|
||||
if (!percent_valid) {
|
||||
LOG_WARN("EasyCache disabled due to invalid percent range (start=%.3f, end=%.3f)",
|
||||
easycache_config.start_percent,
|
||||
easycache_config.end_percent);
|
||||
bool ucache_enabled = false;
|
||||
bool cachedit_enabled = false;
|
||||
|
||||
if (cache_params != nullptr && cache_params->mode != SD_CACHE_DISABLED) {
|
||||
bool percent_valid = true;
|
||||
if (cache_params->mode == SD_CACHE_EASYCACHE || cache_params->mode == SD_CACHE_UCACHE) {
|
||||
percent_valid = cache_params->start_percent >= 0.0f &&
|
||||
cache_params->start_percent < 1.0f &&
|
||||
cache_params->end_percent > 0.0f &&
|
||||
cache_params->end_percent <= 1.0f &&
|
||||
cache_params->start_percent < cache_params->end_percent;
|
||||
}
|
||||
|
||||
if (!percent_valid) {
|
||||
LOG_WARN("Cache disabled due to invalid percent range (start=%.3f, end=%.3f)",
|
||||
cache_params->start_percent,
|
||||
cache_params->end_percent);
|
||||
} else if (cache_params->mode == SD_CACHE_EASYCACHE) {
|
||||
bool easycache_supported = sd_version_is_dit(version);
|
||||
if (!easycache_supported) {
|
||||
LOG_WARN("EasyCache requested but not supported for this model type");
|
||||
} else {
|
||||
EasyCacheConfig easycache_config;
|
||||
easycache_config.enabled = true;
|
||||
easycache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold);
|
||||
easycache_config.start_percent = cache_params->start_percent;
|
||||
easycache_config.end_percent = cache_params->end_percent;
|
||||
easycache_state.init(easycache_config, denoiser.get());
|
||||
if (easycache_state.enabled()) {
|
||||
easycache_enabled = true;
|
||||
LOG_INFO("EasyCache enabled - threshold: %.3f, start_percent: %.2f, end_percent: %.2f",
|
||||
LOG_INFO("EasyCache enabled - threshold: %.3f, start: %.2f, end: %.2f",
|
||||
easycache_config.reuse_threshold,
|
||||
easycache_config.start_percent,
|
||||
easycache_config.end_percent);
|
||||
@ -1590,9 +1601,84 @@ public:
|
||||
LOG_WARN("EasyCache requested but could not be initialized for this run");
|
||||
}
|
||||
}
|
||||
} else if (cache_params->mode == SD_CACHE_UCACHE) {
|
||||
bool ucache_supported = sd_version_is_unet(version);
|
||||
if (!ucache_supported) {
|
||||
LOG_WARN("UCache requested but not supported for this model type (only UNET models)");
|
||||
} else {
|
||||
UCacheConfig ucache_config;
|
||||
ucache_config.enabled = true;
|
||||
ucache_config.reuse_threshold = std::max(0.0f, cache_params->reuse_threshold);
|
||||
ucache_config.start_percent = cache_params->start_percent;
|
||||
ucache_config.end_percent = cache_params->end_percent;
|
||||
ucache_config.error_decay_rate = std::max(0.0f, std::min(1.0f, cache_params->error_decay_rate));
|
||||
ucache_config.use_relative_threshold = cache_params->use_relative_threshold;
|
||||
ucache_config.reset_error_on_compute = cache_params->reset_error_on_compute;
|
||||
ucache_state.init(ucache_config, denoiser.get());
|
||||
if (ucache_state.enabled()) {
|
||||
ucache_enabled = true;
|
||||
LOG_INFO("UCache enabled - threshold: %.3f, start: %.2f, end: %.2f, decay: %.2f, relative: %s, reset: %s",
|
||||
ucache_config.reuse_threshold,
|
||||
ucache_config.start_percent,
|
||||
ucache_config.end_percent,
|
||||
ucache_config.error_decay_rate,
|
||||
ucache_config.use_relative_threshold ? "true" : "false",
|
||||
ucache_config.reset_error_on_compute ? "true" : "false");
|
||||
} else {
|
||||
LOG_WARN("UCache requested but could not be initialized for this run");
|
||||
}
|
||||
}
|
||||
} else if (cache_params->mode == SD_CACHE_DBCACHE ||
|
||||
cache_params->mode == SD_CACHE_TAYLORSEER ||
|
||||
cache_params->mode == SD_CACHE_CACHE_DIT) {
|
||||
bool cachedit_supported = sd_version_is_dit(version);
|
||||
if (!cachedit_supported) {
|
||||
LOG_WARN("CacheDIT requested but not supported for this model type (only DiT models)");
|
||||
} else {
|
||||
DBCacheConfig dbcfg;
|
||||
dbcfg.enabled = (cache_params->mode == SD_CACHE_DBCACHE ||
|
||||
cache_params->mode == SD_CACHE_CACHE_DIT);
|
||||
dbcfg.Fn_compute_blocks = cache_params->Fn_compute_blocks;
|
||||
dbcfg.Bn_compute_blocks = cache_params->Bn_compute_blocks;
|
||||
dbcfg.residual_diff_threshold = cache_params->residual_diff_threshold;
|
||||
dbcfg.max_warmup_steps = cache_params->max_warmup_steps;
|
||||
dbcfg.max_cached_steps = cache_params->max_cached_steps;
|
||||
dbcfg.max_continuous_cached_steps = cache_params->max_continuous_cached_steps;
|
||||
if (cache_params->scm_mask != nullptr && strlen(cache_params->scm_mask) > 0) {
|
||||
dbcfg.steps_computation_mask = parse_scm_mask(cache_params->scm_mask);
|
||||
}
|
||||
dbcfg.scm_policy_dynamic = cache_params->scm_policy_dynamic;
|
||||
|
||||
TaylorSeerConfig tcfg;
|
||||
tcfg.enabled = (cache_params->mode == SD_CACHE_TAYLORSEER ||
|
||||
cache_params->mode == SD_CACHE_CACHE_DIT);
|
||||
tcfg.n_derivatives = cache_params->taylorseer_n_derivatives;
|
||||
tcfg.skip_interval_steps = cache_params->taylorseer_skip_interval;
|
||||
|
||||
cachedit_state.init(dbcfg, tcfg);
|
||||
if (cachedit_state.enabled()) {
|
||||
cachedit_enabled = true;
|
||||
LOG_INFO("CacheDIT enabled - mode: %s, Fn: %d, Bn: %d, threshold: %.3f, warmup: %d",
|
||||
cache_params->mode == SD_CACHE_CACHE_DIT ? "DBCache+TaylorSeer" : (cache_params->mode == SD_CACHE_DBCACHE ? "DBCache" : "TaylorSeer"),
|
||||
dbcfg.Fn_compute_blocks,
|
||||
dbcfg.Bn_compute_blocks,
|
||||
dbcfg.residual_diff_threshold,
|
||||
dbcfg.max_warmup_steps);
|
||||
} else {
|
||||
LOG_WARN("CacheDIT requested but could not be initialized for this run");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (ucache_enabled) {
|
||||
ucache_state.set_sigmas(sigmas);
|
||||
}
|
||||
|
||||
if (cachedit_enabled) {
|
||||
cachedit_state.set_sigmas(sigmas);
|
||||
}
|
||||
|
||||
size_t steps = sigmas.size() - 1;
|
||||
struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent);
|
||||
copy_ggml_tensor(x, init_latent);
|
||||
@ -1696,6 +1782,91 @@ public:
|
||||
return easycache_step_active && easycache_state.is_step_skipped();
|
||||
};
|
||||
|
||||
const bool ucache_step_active = ucache_enabled && step > 0;
|
||||
int ucache_step_index = ucache_step_active ? (step - 1) : -1;
|
||||
if (ucache_step_active) {
|
||||
ucache_state.begin_step(ucache_step_index, sigma);
|
||||
}
|
||||
|
||||
// Asks UCache whether the model call for `condition` can be skipped.
// Always answers false when UCache is not active for this step or when
// either pointer is null; otherwise forwards to the UCache state together
// with the current input tensor, sigma and step index.
auto ucache_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool {
    const bool can_query = ucache_step_active && condition != nullptr && output_tensor != nullptr;
    if (can_query) {
        return ucache_state.before_condition(condition,
                                             diffusion_params.x,
                                             output_tensor,
                                             sigma,
                                             ucache_step_index);
    }
    return false;
};
|
||||
|
||||
// Hands a freshly computed output for `condition` to UCache.
// Does nothing when UCache is not active for this step or when either
// pointer is null.
auto ucache_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) {
    const bool can_record = ucache_step_active && condition != nullptr && output_tensor != nullptr;
    if (can_record) {
        ucache_state.after_condition(condition,
                                     diffusion_params.x,
                                     output_tensor);
    }
};
|
||||
|
||||
auto ucache_step_is_skipped = [&]() {
|
||||
return ucache_step_active && ucache_state.is_step_skipped();
|
||||
};
|
||||
|
||||
const bool cachedit_step_active = cachedit_enabled && step > 0;
|
||||
int cachedit_step_index = cachedit_step_active ? (step - 1) : -1;
|
||||
if (cachedit_step_active) {
|
||||
cachedit_state.begin_step(cachedit_step_index, sigma);
|
||||
}
|
||||
|
||||
// Asks CacheDIT whether the model call for `condition` can be skipped.
// Always answers false when CacheDIT is not active for this step or when
// either pointer is null; otherwise forwards to the CacheDIT state together
// with the current input tensor, sigma and step index.
auto cachedit_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool {
    const bool can_query = cachedit_step_active && condition != nullptr && output_tensor != nullptr;
    if (can_query) {
        return cachedit_state.before_condition(condition,
                                               diffusion_params.x,
                                               output_tensor,
                                               sigma,
                                               cachedit_step_index);
    }
    return false;
};
|
||||
|
||||
// Hands a freshly computed output for `condition` to CacheDIT.
// Does nothing when CacheDIT is not active for this step or when either
// pointer is null.
auto cachedit_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) {
    const bool can_record = cachedit_step_active && condition != nullptr && output_tensor != nullptr;
    if (can_record) {
        cachedit_state.after_condition(condition,
                                       diffusion_params.x,
                                       output_tensor);
    }
};
|
||||
|
||||
auto cachedit_step_is_skipped = [&]() {
|
||||
return cachedit_step_active && cachedit_state.is_step_skipped();
|
||||
};
|
||||
|
||||
// Single entry point for the pre-compute cache check.
// Consults the first backend that is active for this step, in the order
// EasyCache, UCache, CacheDIT, and returns its verdict; returns false (do
// not skip the model call) when no backend is active.
auto cache_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool {
    if (easycache_step_active) {
        return easycache_before_condition(condition, output_tensor);
    }
    if (ucache_step_active) {
        return ucache_before_condition(condition, output_tensor);
    }
    if (cachedit_step_active) {
        return cachedit_before_condition(condition, output_tensor);
    }
    return false;
};
|
||||
|
||||
// Single entry point for the post-compute cache update.
// Notifies the first backend that is active for this step, in the order
// EasyCache, UCache, CacheDIT; does nothing when no backend is active.
auto cache_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) {
    if (easycache_step_active) {
        easycache_after_condition(condition, output_tensor);
        return;
    }
    if (ucache_step_active) {
        ucache_after_condition(condition, output_tensor);
        return;
    }
    if (cachedit_step_active) {
        cachedit_after_condition(condition, output_tensor);
    }
};
|
||||
|
||||
auto cache_step_is_skipped = [&]() {
|
||||
return easycache_step_is_skipped() || ucache_step_is_skipped() || cachedit_step_is_skipped();
|
||||
};
|
||||
|
||||
std::vector<float> scaling = denoiser->get_scalings(sigma);
|
||||
GGML_ASSERT(scaling.size() == 3);
|
||||
float c_skip = scaling[0];
|
||||
@ -1771,7 +1942,7 @@ public:
|
||||
active_condition = &id_cond;
|
||||
}
|
||||
|
||||
bool skip_model = easycache_before_condition(active_condition, *active_output);
|
||||
bool skip_model = cache_before_condition(active_condition, *active_output);
|
||||
if (!skip_model) {
|
||||
if (!work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
@ -1779,10 +1950,10 @@ public:
|
||||
LOG_ERROR("diffusion model compute failed");
|
||||
return nullptr;
|
||||
}
|
||||
easycache_after_condition(active_condition, *active_output);
|
||||
cache_after_condition(active_condition, *active_output);
|
||||
}
|
||||
|
||||
bool current_step_skipped = easycache_step_is_skipped();
|
||||
bool current_step_skipped = cache_step_is_skipped();
|
||||
|
||||
float* negative_data = nullptr;
|
||||
if (has_unconditioned) {
|
||||
@ -1794,12 +1965,12 @@ public:
|
||||
LOG_ERROR("controlnet compute failed");
|
||||
}
|
||||
}
|
||||
current_step_skipped = easycache_step_is_skipped();
|
||||
current_step_skipped = cache_step_is_skipped();
|
||||
diffusion_params.controls = controls;
|
||||
diffusion_params.context = uncond.c_crossattn;
|
||||
diffusion_params.c_concat = uncond.c_concat;
|
||||
diffusion_params.y = uncond.c_vector;
|
||||
bool skip_uncond = easycache_before_condition(&uncond, out_uncond);
|
||||
bool skip_uncond = cache_before_condition(&uncond, out_uncond);
|
||||
if (!skip_uncond) {
|
||||
if (!work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
@ -1807,7 +1978,7 @@ public:
|
||||
LOG_ERROR("diffusion model compute failed");
|
||||
return nullptr;
|
||||
}
|
||||
easycache_after_condition(&uncond, out_uncond);
|
||||
cache_after_condition(&uncond, out_uncond);
|
||||
}
|
||||
negative_data = (float*)out_uncond->data;
|
||||
}
|
||||
@ -1817,7 +1988,7 @@ public:
|
||||
diffusion_params.context = img_cond.c_crossattn;
|
||||
diffusion_params.c_concat = img_cond.c_concat;
|
||||
diffusion_params.y = img_cond.c_vector;
|
||||
bool skip_img_cond = easycache_before_condition(&img_cond, out_img_cond);
|
||||
bool skip_img_cond = cache_before_condition(&img_cond, out_img_cond);
|
||||
if (!skip_img_cond) {
|
||||
if (!work_diffusion_model->compute(n_threads,
|
||||
diffusion_params,
|
||||
@ -1825,7 +1996,7 @@ public:
|
||||
LOG_ERROR("diffusion model compute failed");
|
||||
return nullptr;
|
||||
}
|
||||
easycache_after_condition(&img_cond, out_img_cond);
|
||||
cache_after_condition(&img_cond, out_img_cond);
|
||||
}
|
||||
img_cond_data = (float*)out_img_cond->data;
|
||||
}
|
||||
@ -1835,7 +2006,7 @@ public:
|
||||
float* skip_layer_data = has_skiplayer ? (float*)out_skip->data : nullptr;
|
||||
if (is_skiplayer_step) {
|
||||
LOG_DEBUG("Skipping layers at step %d\n", step);
|
||||
if (!easycache_step_is_skipped()) {
|
||||
if (!cache_step_is_skipped()) {
|
||||
// skip layer (same as conditioned)
|
||||
diffusion_params.context = cond.c_crossattn;
|
||||
diffusion_params.c_concat = cond.c_concat;
|
||||
@ -1939,6 +2110,48 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
// Summarise how many sampling steps UCache skipped. The number of steps is taken as
// one less than the number of sigmas; nothing is logged when there are no steps.
if (ucache_enabled) {
    const size_t total_steps = sigmas.size() == 0 ? 0 : sigmas.size() - 1;
    const auto& n_skipped    = ucache_state.total_steps_skipped;
    if (total_steps > 0) {
        if (!(n_skipped > 0)) {
            LOG_INFO("UCache completed without skipping steps");
        } else if (n_skipped < static_cast<int>(total_steps)) {
            // Estimated speedup = total steps / steps that were actually computed.
            const double speedup = static_cast<double>(total_steps) /
                                   static_cast<double>(total_steps - n_skipped);
            LOG_INFO("UCache skipped %d/%zu steps (%.2fx estimated speedup)",
                     n_skipped,
                     total_steps,
                     speedup);
        } else {
            // Every step was skipped, so the ratio above would divide by zero.
            LOG_INFO("UCache skipped %d/%zu steps",
                     n_skipped,
                     total_steps);
        }
    }
}
|
||||
|
||||
// Summarise how many sampling steps CacheDIT skipped, along with its accumulated
// residual difference. The number of steps is taken as one less than the number of
// sigmas; nothing is logged when there are no steps.
if (cachedit_enabled) {
    const size_t total_steps = sigmas.size() == 0 ? 0 : sigmas.size() - 1;
    const auto& n_skipped    = cachedit_state.total_steps_skipped;
    if (total_steps > 0) {
        if (!(n_skipped > 0)) {
            LOG_INFO("CacheDIT completed without skipping steps");
        } else if (n_skipped < static_cast<int>(total_steps)) {
            // Estimated speedup = total steps / steps that were actually computed.
            const double speedup = static_cast<double>(total_steps) /
                                   static_cast<double>(total_steps - n_skipped);
            LOG_INFO("CacheDIT skipped %d/%zu steps (%.2fx estimated speedup), accum_diff: %.4f",
                     n_skipped,
                     total_steps,
                     speedup,
                     cachedit_state.accumulated_residual_diff);
        } else {
            // Every step was skipped, so the ratio above would divide by zero.
            LOG_INFO("CacheDIT skipped %d/%zu steps, accum_diff: %.4f",
                     n_skipped,
                     total_steps,
                     cachedit_state.accumulated_residual_diff);
        }
    }
}
|
||||
|
||||
if (inverse_noise_scaling) {
|
||||
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
|
||||
}
|
||||
@ -2554,12 +2767,25 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) {
|
||||
return LORA_APPLY_MODE_COUNT;
|
||||
}
|
||||
|
||||
void sd_easycache_params_init(sd_easycache_params_t* easycache_params) {
|
||||
*easycache_params = {};
|
||||
easycache_params->enabled = false;
|
||||
easycache_params->reuse_threshold = 0.2f;
|
||||
easycache_params->start_percent = 0.15f;
|
||||
easycache_params->end_percent = 0.95f;
|
||||
// Fills *cache_params with the library defaults: caching disabled, with the per-mode
// tuning fields set to the values below. Every field is zero-initialised first.
// The pointer is dereferenced unconditionally, so it must not be null.
void sd_cache_params_init(sd_cache_params_t* cache_params) {
    sd_cache_params_t defaults = {};

    // Common settings.
    defaults.mode            = SD_CACHE_DISABLED;
    defaults.reuse_threshold = 1.0f;
    defaults.start_percent   = 0.15f;
    defaults.end_percent     = 0.95f;

    // Error accumulation settings.
    defaults.error_decay_rate       = 1.0f;
    defaults.use_relative_threshold = true;
    defaults.reset_error_on_compute = true;

    // Block-level cache settings.
    defaults.Fn_compute_blocks           = 8;
    defaults.Bn_compute_blocks           = 0;
    defaults.residual_diff_threshold     = 0.08f;
    defaults.max_warmup_steps            = 8;
    defaults.max_cached_steps            = -1;
    defaults.max_continuous_cached_steps = -1;

    // TaylorSeer settings.
    defaults.taylorseer_n_derivatives = 1;
    defaults.taylorseer_skip_interval = 1;

    // Steps computation mask settings.
    defaults.scm_mask           = nullptr;
    defaults.scm_policy_dynamic = true;

    *cache_params = defaults;
}
|
||||
|
||||
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
|
||||
@ -2724,7 +2950,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
|
||||
sd_img_gen_params->control_strength = 0.9f;
|
||||
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
|
||||
sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
|
||||
sd_easycache_params_init(&sd_img_gen_params->easycache);
|
||||
sd_cache_params_init(&sd_img_gen_params->cache);
|
||||
}
|
||||
|
||||
char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
||||
@ -2768,12 +2994,18 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
|
||||
sd_img_gen_params->pm_params.id_images_count,
|
||||
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
|
||||
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled));
|
||||
const char* cache_mode_str = "disabled";
|
||||
if (sd_img_gen_params->cache.mode == SD_CACHE_EASYCACHE) {
|
||||
cache_mode_str = "easycache";
|
||||
} else if (sd_img_gen_params->cache.mode == SD_CACHE_UCACHE) {
|
||||
cache_mode_str = "ucache";
|
||||
}
|
||||
snprintf(buf + strlen(buf), 4096 - strlen(buf),
|
||||
"easycache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
|
||||
sd_img_gen_params->easycache.enabled ? "enabled" : "disabled",
|
||||
sd_img_gen_params->easycache.reuse_threshold,
|
||||
sd_img_gen_params->easycache.start_percent,
|
||||
sd_img_gen_params->easycache.end_percent);
|
||||
"cache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
|
||||
cache_mode_str,
|
||||
sd_img_gen_params->cache.reuse_threshold,
|
||||
sd_img_gen_params->cache.start_percent,
|
||||
sd_img_gen_params->cache.end_percent);
|
||||
free(sample_params_str);
|
||||
return buf;
|
||||
}
|
||||
@ -2790,7 +3022,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
|
||||
sd_vid_gen_params->video_frames = 6;
|
||||
sd_vid_gen_params->moe_boundary = 0.875f;
|
||||
sd_vid_gen_params->vace_strength = 1.f;
|
||||
sd_easycache_params_init(&sd_vid_gen_params->easycache);
|
||||
sd_cache_params_init(&sd_vid_gen_params->cache);
|
||||
}
|
||||
|
||||
struct sd_ctx_t {
|
||||
@ -2869,9 +3101,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
std::vector<sd_image_t*> ref_images,
|
||||
std::vector<ggml_tensor*> ref_latents,
|
||||
bool increase_ref_index,
|
||||
ggml_tensor* concat_latent = nullptr,
|
||||
ggml_tensor* denoise_mask = nullptr,
|
||||
const sd_easycache_params_t* easycache_params = nullptr) {
|
||||
ggml_tensor* concat_latent = nullptr,
|
||||
ggml_tensor* denoise_mask = nullptr,
|
||||
const sd_cache_params_t* cache_params = nullptr) {
|
||||
if (seed < 0) {
|
||||
// Generally, when using the provided command line, the seed is always >0.
|
||||
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
|
||||
@ -3160,7 +3392,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
||||
denoise_mask,
|
||||
nullptr,
|
||||
1.0f,
|
||||
easycache_params);
|
||||
cache_params);
|
||||
int64_t sampling_end = ggml_time_ms();
|
||||
if (x_0 != nullptr) {
|
||||
// print_ggml_tensor(x_0);
|
||||
@ -3498,7 +3730,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
||||
sd_img_gen_params->increase_ref_index,
|
||||
concat_latent,
|
||||
denoise_mask,
|
||||
&sd_img_gen_params->easycache);
|
||||
&sd_img_gen_params->cache);
|
||||
|
||||
size_t t2 = ggml_time_ms();
|
||||
|
||||
@ -3869,7 +4101,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
denoise_mask,
|
||||
vace_context,
|
||||
sd_vid_gen_params->vace_strength,
|
||||
&sd_vid_gen_params->easycache);
|
||||
&sd_vid_gen_params->cache);
|
||||
|
||||
int64_t sampling_end = ggml_time_ms();
|
||||
LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||
@ -3906,7 +4138,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
||||
denoise_mask,
|
||||
vace_context,
|
||||
sd_vid_gen_params->vace_strength,
|
||||
&sd_vid_gen_params->easycache);
|
||||
&sd_vid_gen_params->cache);
|
||||
|
||||
int64_t sampling_end = ggml_time_ms();
|
||||
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
|
||||
|
||||
@ -238,12 +238,34 @@ typedef struct {
|
||||
float style_strength;
|
||||
} sd_pm_params_t; // photo maker
|
||||
|
||||
// Selects which step-caching strategy is used during sampling.
// Values are spelled out so the numeric value of each mode is visible at a glance.
enum sd_cache_mode_t {
    SD_CACHE_DISABLED   = 0,
    SD_CACHE_EASYCACHE  = 1,
    SD_CACHE_UCACHE     = 2,
    SD_CACHE_DBCACHE    = 3,
    SD_CACHE_TAYLORSEER = 4,
    SD_CACHE_CACHE_DIT  = 5,
};
|
||||
|
||||
typedef struct {
|
||||
bool enabled;
|
||||
enum sd_cache_mode_t mode;
|
||||
float reuse_threshold;
|
||||
float start_percent;
|
||||
float end_percent;
|
||||
} sd_easycache_params_t;
|
||||
float error_decay_rate;
|
||||
bool use_relative_threshold;
|
||||
bool reset_error_on_compute;
|
||||
int Fn_compute_blocks;
|
||||
int Bn_compute_blocks;
|
||||
float residual_diff_threshold;
|
||||
int max_warmup_steps;
|
||||
int max_cached_steps;
|
||||
int max_continuous_cached_steps;
|
||||
int taylorseer_n_derivatives;
|
||||
int taylorseer_skip_interval;
|
||||
const char* scm_mask;
|
||||
bool scm_policy_dynamic;
|
||||
} sd_cache_params_t;
|
||||
|
||||
typedef struct {
|
||||
bool is_high_noise;
|
||||
@ -273,7 +295,7 @@ typedef struct {
|
||||
float control_strength;
|
||||
sd_pm_params_t pm_params;
|
||||
sd_tiling_params_t vae_tiling_params;
|
||||
sd_easycache_params_t easycache;
|
||||
sd_cache_params_t cache;
|
||||
} sd_img_gen_params_t;
|
||||
|
||||
typedef struct {
|
||||
@ -295,7 +317,7 @@ typedef struct {
|
||||
int64_t seed;
|
||||
int video_frames;
|
||||
float vace_strength;
|
||||
sd_easycache_params_t easycache;
|
||||
sd_cache_params_t cache;
|
||||
} sd_vid_gen_params_t;
|
||||
|
||||
typedef struct sd_ctx_t sd_ctx_t;
|
||||
@ -325,7 +347,7 @@ SD_API enum preview_t str_to_preview(const char* str);
|
||||
SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode);
|
||||
SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str);
|
||||
|
||||
SD_API void sd_easycache_params_init(sd_easycache_params_t* easycache_params);
|
||||
SD_API void sd_cache_params_init(sd_cache_params_t* cache_params);
|
||||
|
||||
SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
|
||||
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
|
||||
|
||||
404
ucache.hpp
Normal file
404
ucache.hpp
Normal file
@ -0,0 +1,404 @@
|
||||
#ifndef __UCACHE_HPP__
|
||||
#define __UCACHE_HPP__
|
||||
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "denoiser.hpp"
|
||||
#include "ggml_extend.hpp"
|
||||
|
||||
// Tunable settings for UCache. All members carry in-class defaults, so a
// default-constructed config describes a disabled cache.
struct UCacheConfig {
    bool enabled{false};
    float reuse_threshold{1.0f};        // base threshold compared against the accumulated error
    float start_percent{0.15f};         // start of the window in which caching may be active
    float end_percent{0.95f};           // end of the window in which caching may be active
    float error_decay_rate{1.0f};       // factor applied to the accumulated error before adding to it
    bool use_relative_threshold{true};  // scale the threshold by the reference output norm
    bool adaptive_threshold{true};      // apply the early/late multipliers below
    float early_step_multiplier{0.5f};
    float late_step_multiplier{1.5f};
    bool reset_error_on_compute{true};  // clear the accumulated error when a step is computed
};
|
||||
|
||||
// One cached result for a single condition: the element-wise difference
// (output - input) recorded the last time the model was computed for it.
struct UCacheCacheEntry {
    std::vector<float> diff{};
};
|
||||
|
||||
struct UCacheState {
|
||||
UCacheConfig config;
|
||||
Denoiser* denoiser = nullptr;
|
||||
float start_sigma = std::numeric_limits<float>::max();
|
||||
float end_sigma = 0.0f;
|
||||
bool initialized = false;
|
||||
bool initial_step = true;
|
||||
bool skip_current_step = false;
|
||||
bool step_active = false;
|
||||
const SDCondition* anchor_condition = nullptr;
|
||||
std::unordered_map<const SDCondition*, UCacheCacheEntry> cache_diffs;
|
||||
std::vector<float> prev_input;
|
||||
std::vector<float> prev_output;
|
||||
float output_prev_norm = 0.0f;
|
||||
bool has_prev_input = false;
|
||||
bool has_prev_output = false;
|
||||
bool has_output_prev_norm = false;
|
||||
bool has_relative_transformation_rate = false;
|
||||
float relative_transformation_rate = 0.0f;
|
||||
float cumulative_change_rate = 0.0f;
|
||||
float last_input_change = 0.0f;
|
||||
bool has_last_input_change = false;
|
||||
int total_steps_skipped = 0;
|
||||
int current_step_index = -1;
|
||||
int steps_computed_since_active = 0;
|
||||
float accumulated_error = 0.0f;
|
||||
float reference_output_norm = 0.0f;
|
||||
|
||||
struct BlockMetrics {
|
||||
float sum_transformation_rate = 0.0f;
|
||||
float sum_output_norm = 0.0f;
|
||||
int sample_count = 0;
|
||||
float min_change_rate = std::numeric_limits<float>::max();
|
||||
float max_change_rate = 0.0f;
|
||||
|
||||
void reset() {
|
||||
sum_transformation_rate = 0.0f;
|
||||
sum_output_norm = 0.0f;
|
||||
sample_count = 0;
|
||||
min_change_rate = std::numeric_limits<float>::max();
|
||||
max_change_rate = 0.0f;
|
||||
}
|
||||
|
||||
void record(float change_rate, float output_norm) {
|
||||
if (std::isfinite(change_rate) && change_rate > 0.0f) {
|
||||
sum_transformation_rate += change_rate;
|
||||
sum_output_norm += output_norm;
|
||||
sample_count++;
|
||||
if (change_rate < min_change_rate)
|
||||
min_change_rate = change_rate;
|
||||
if (change_rate > max_change_rate)
|
||||
max_change_rate = change_rate;
|
||||
}
|
||||
}
|
||||
|
||||
float avg_transformation_rate() const {
|
||||
return (sample_count > 0) ? (sum_transformation_rate / sample_count) : 0.0f;
|
||||
}
|
||||
|
||||
float avg_output_norm() const {
|
||||
return (sample_count > 0) ? (sum_output_norm / sample_count) : 0.0f;
|
||||
}
|
||||
};
|
||||
BlockMetrics block_metrics;
|
||||
int total_active_steps = 0;
|
||||
|
||||
void reset_runtime() {
|
||||
initial_step = true;
|
||||
skip_current_step = false;
|
||||
step_active = false;
|
||||
anchor_condition = nullptr;
|
||||
cache_diffs.clear();
|
||||
prev_input.clear();
|
||||
prev_output.clear();
|
||||
output_prev_norm = 0.0f;
|
||||
has_prev_input = false;
|
||||
has_prev_output = false;
|
||||
has_output_prev_norm = false;
|
||||
has_relative_transformation_rate = false;
|
||||
relative_transformation_rate = 0.0f;
|
||||
cumulative_change_rate = 0.0f;
|
||||
last_input_change = 0.0f;
|
||||
has_last_input_change = false;
|
||||
total_steps_skipped = 0;
|
||||
current_step_index = -1;
|
||||
steps_computed_since_active = 0;
|
||||
accumulated_error = 0.0f;
|
||||
reference_output_norm = 0.0f;
|
||||
block_metrics.reset();
|
||||
total_active_steps = 0;
|
||||
}
|
||||
|
||||
void init(const UCacheConfig& cfg, Denoiser* d) {
|
||||
config = cfg;
|
||||
denoiser = d;
|
||||
initialized = cfg.enabled && d != nullptr;
|
||||
reset_runtime();
|
||||
if (initialized) {
|
||||
start_sigma = percent_to_sigma(config.start_percent);
|
||||
end_sigma = percent_to_sigma(config.end_percent);
|
||||
}
|
||||
}
|
||||
|
||||
void set_sigmas(const std::vector<float>& sigmas) {
|
||||
if (!initialized || sigmas.size() < 2) {
|
||||
return;
|
||||
}
|
||||
size_t n_steps = sigmas.size() - 1;
|
||||
|
||||
size_t start_step = static_cast<size_t>(config.start_percent * n_steps);
|
||||
size_t end_step = static_cast<size_t>(config.end_percent * n_steps);
|
||||
|
||||
if (start_step >= n_steps)
|
||||
start_step = n_steps - 1;
|
||||
if (end_step >= n_steps)
|
||||
end_step = n_steps - 1;
|
||||
|
||||
start_sigma = sigmas[start_step];
|
||||
end_sigma = sigmas[end_step];
|
||||
|
||||
if (start_sigma < end_sigma) {
|
||||
std::swap(start_sigma, end_sigma);
|
||||
}
|
||||
}
|
||||
|
||||
bool enabled() const {
|
||||
return initialized && config.enabled;
|
||||
}
|
||||
|
||||
float percent_to_sigma(float percent) const {
|
||||
if (!denoiser) {
|
||||
return 0.0f;
|
||||
}
|
||||
if (percent <= 0.0f) {
|
||||
return std::numeric_limits<float>::max();
|
||||
}
|
||||
if (percent >= 1.0f) {
|
||||
return 0.0f;
|
||||
}
|
||||
float t = (1.0f - percent) * (TIMESTEPS - 1);
|
||||
return denoiser->t_to_sigma(t);
|
||||
}
|
||||
|
||||
void begin_step(int step_index, float sigma) {
|
||||
if (!enabled()) {
|
||||
return;
|
||||
}
|
||||
if (step_index == current_step_index) {
|
||||
return;
|
||||
}
|
||||
current_step_index = step_index;
|
||||
skip_current_step = false;
|
||||
has_last_input_change = false;
|
||||
step_active = false;
|
||||
|
||||
if (sigma > start_sigma) {
|
||||
return;
|
||||
}
|
||||
if (!(sigma > end_sigma)) {
|
||||
return;
|
||||
}
|
||||
step_active = true;
|
||||
total_active_steps++;
|
||||
}
|
||||
|
||||
bool step_is_active() const {
|
||||
return enabled() && step_active;
|
||||
}
|
||||
|
||||
bool is_step_skipped() const {
|
||||
return enabled() && step_active && skip_current_step;
|
||||
}
|
||||
|
||||
float get_adaptive_threshold(int estimated_total_steps = 0) const {
|
||||
float base_threshold = config.reuse_threshold;
|
||||
|
||||
if (!config.adaptive_threshold) {
|
||||
return base_threshold;
|
||||
}
|
||||
|
||||
int effective_total = estimated_total_steps;
|
||||
if (effective_total <= 0) {
|
||||
effective_total = std::max(20, steps_computed_since_active * 2);
|
||||
}
|
||||
|
||||
float progress = (effective_total > 0) ? (static_cast<float>(steps_computed_since_active) / effective_total) : 0.0f;
|
||||
|
||||
float multiplier = 1.0f;
|
||||
if (progress < 0.2f) {
|
||||
multiplier = config.early_step_multiplier;
|
||||
} else if (progress > 0.8f) {
|
||||
multiplier = config.late_step_multiplier;
|
||||
}
|
||||
|
||||
return base_threshold * multiplier;
|
||||
}
|
||||
|
||||
bool has_cache(const SDCondition* cond) const {
|
||||
auto it = cache_diffs.find(cond);
|
||||
return it != cache_diffs.end() && !it->second.diff.empty();
|
||||
}
|
||||
|
||||
void update_cache(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
|
||||
UCacheCacheEntry& entry = cache_diffs[cond];
|
||||
size_t ne = static_cast<size_t>(ggml_nelements(output));
|
||||
entry.diff.resize(ne);
|
||||
float* out_data = (float*)output->data;
|
||||
float* in_data = (float*)input->data;
|
||||
|
||||
for (size_t i = 0; i < ne; ++i) {
|
||||
entry.diff[i] = out_data[i] - in_data[i];
|
||||
}
|
||||
}
|
||||
|
||||
void apply_cache(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
|
||||
auto it = cache_diffs.find(cond);
|
||||
if (it == cache_diffs.end() || it->second.diff.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
copy_ggml_tensor(output, input);
|
||||
float* out_data = (float*)output->data;
|
||||
const std::vector<float>& diff = it->second.diff;
|
||||
for (size_t i = 0; i < diff.size(); ++i) {
|
||||
out_data[i] += diff[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool before_condition(const SDCondition* cond,
|
||||
ggml_tensor* input,
|
||||
ggml_tensor* output,
|
||||
float sigma,
|
||||
int step_index) {
|
||||
if (!enabled() || step_index < 0) {
|
||||
return false;
|
||||
}
|
||||
if (step_index != current_step_index) {
|
||||
begin_step(step_index, sigma);
|
||||
}
|
||||
if (!step_active) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (initial_step) {
|
||||
anchor_condition = cond;
|
||||
initial_step = false;
|
||||
}
|
||||
|
||||
bool is_anchor = (cond == anchor_condition);
|
||||
|
||||
if (skip_current_step) {
|
||||
if (has_cache(cond)) {
|
||||
apply_cache(cond, input, output);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!is_anchor) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!has_prev_input || !has_prev_output || !has_cache(cond)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t ne = static_cast<size_t>(ggml_nelements(input));
|
||||
if (prev_input.size() != ne) {
|
||||
return false;
|
||||
}
|
||||
|
||||
float* input_data = (float*)input->data;
|
||||
last_input_change = 0.0f;
|
||||
for (size_t i = 0; i < ne; ++i) {
|
||||
last_input_change += std::fabs(input_data[i] - prev_input[i]);
|
||||
}
|
||||
if (ne > 0) {
|
||||
last_input_change /= static_cast<float>(ne);
|
||||
}
|
||||
has_last_input_change = true;
|
||||
|
||||
if (has_output_prev_norm && has_relative_transformation_rate &&
|
||||
last_input_change > 0.0f && output_prev_norm > 0.0f) {
|
||||
float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm;
|
||||
accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
|
||||
|
||||
float effective_threshold = get_adaptive_threshold();
|
||||
if (config.use_relative_threshold && reference_output_norm > 0.0f) {
|
||||
effective_threshold = effective_threshold * reference_output_norm;
|
||||
}
|
||||
|
||||
if (accumulated_error < effective_threshold) {
|
||||
skip_current_step = true;
|
||||
total_steps_skipped++;
|
||||
apply_cache(cond, input, output);
|
||||
return true;
|
||||
} else if (config.reset_error_on_compute) {
|
||||
accumulated_error = 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void after_condition(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
|
||||
if (!step_is_active()) {
|
||||
return;
|
||||
}
|
||||
|
||||
update_cache(cond, input, output);
|
||||
|
||||
if (cond != anchor_condition) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t ne = static_cast<size_t>(ggml_nelements(input));
|
||||
float* in_data = (float*)input->data;
|
||||
prev_input.resize(ne);
|
||||
for (size_t i = 0; i < ne; ++i) {
|
||||
prev_input[i] = in_data[i];
|
||||
}
|
||||
has_prev_input = true;
|
||||
|
||||
float* out_data = (float*)output->data;
|
||||
float output_change = 0.0f;
|
||||
if (has_prev_output && prev_output.size() == ne) {
|
||||
for (size_t i = 0; i < ne; ++i) {
|
||||
output_change += std::fabs(out_data[i] - prev_output[i]);
|
||||
}
|
||||
if (ne > 0) {
|
||||
output_change /= static_cast<float>(ne);
|
||||
}
|
||||
}
|
||||
|
||||
prev_output.resize(ne);
|
||||
for (size_t i = 0; i < ne; ++i) {
|
||||
prev_output[i] = out_data[i];
|
||||
}
|
||||
has_prev_output = true;
|
||||
|
||||
float mean_abs = 0.0f;
|
||||
for (size_t i = 0; i < ne; ++i) {
|
||||
mean_abs += std::fabs(out_data[i]);
|
||||
}
|
||||
output_prev_norm = (ne > 0) ? (mean_abs / static_cast<float>(ne)) : 0.0f;
|
||||
has_output_prev_norm = output_prev_norm > 0.0f;
|
||||
|
||||
if (reference_output_norm == 0.0f) {
|
||||
reference_output_norm = output_prev_norm;
|
||||
}
|
||||
|
||||
if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) {
|
||||
float rate = output_change / last_input_change;
|
||||
if (std::isfinite(rate)) {
|
||||
relative_transformation_rate = rate;
|
||||
has_relative_transformation_rate = true;
|
||||
block_metrics.record(rate, output_prev_norm);
|
||||
}
|
||||
}
|
||||
|
||||
has_last_input_change = false;
|
||||
}
|
||||
|
||||
void log_block_metrics() const {
|
||||
if (block_metrics.sample_count > 0) {
|
||||
LOG_INFO("UCacheBlockMetrics: samples=%d, avg_rate=%.4f, min=%.4f, max=%.4f, avg_norm=%.4f",
|
||||
block_metrics.sample_count,
|
||||
block_metrics.avg_transformation_rate(),
|
||||
block_metrics.min_change_rate,
|
||||
block_metrics.max_change_rate,
|
||||
block_metrics.avg_output_norm());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif // __UCACHE_HPP__
|
||||
Loading…
x
Reference in New Issue
Block a user