mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 02:08:51 +00:00
fix: ucache: normalize reuse error (#1313)
This commit is contained in:
parent
7c880f80c7
commit
d95062737e
@ -19,6 +19,7 @@ struct UCacheConfig {
|
|||||||
bool adaptive_threshold = true;
|
bool adaptive_threshold = true;
|
||||||
float early_step_multiplier = 0.5f;
|
float early_step_multiplier = 0.5f;
|
||||||
float late_step_multiplier = 1.5f;
|
float late_step_multiplier = 1.5f;
|
||||||
|
float relative_norm_gain = 1.6f;
|
||||||
bool reset_error_on_compute = true;
|
bool reset_error_on_compute = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -45,14 +46,16 @@ struct UCacheState {
|
|||||||
bool has_output_prev_norm = false;
|
bool has_output_prev_norm = false;
|
||||||
bool has_relative_transformation_rate = false;
|
bool has_relative_transformation_rate = false;
|
||||||
float relative_transformation_rate = 0.0f;
|
float relative_transformation_rate = 0.0f;
|
||||||
float cumulative_change_rate = 0.0f;
|
|
||||||
float last_input_change = 0.0f;
|
float last_input_change = 0.0f;
|
||||||
bool has_last_input_change = false;
|
bool has_last_input_change = false;
|
||||||
|
float output_change_ema = 0.0f;
|
||||||
|
bool has_output_change_ema = false;
|
||||||
int total_steps_skipped = 0;
|
int total_steps_skipped = 0;
|
||||||
int current_step_index = -1;
|
int current_step_index = -1;
|
||||||
int steps_computed_since_active = 0;
|
int steps_computed_since_active = 0;
|
||||||
|
int expected_total_steps = 0;
|
||||||
|
int consecutive_skipped_steps = 0;
|
||||||
float accumulated_error = 0.0f;
|
float accumulated_error = 0.0f;
|
||||||
float reference_output_norm = 0.0f;
|
|
||||||
|
|
||||||
struct BlockMetrics {
|
struct BlockMetrics {
|
||||||
float sum_transformation_rate = 0.0f;
|
float sum_transformation_rate = 0.0f;
|
||||||
@ -106,14 +109,16 @@ struct UCacheState {
|
|||||||
has_output_prev_norm = false;
|
has_output_prev_norm = false;
|
||||||
has_relative_transformation_rate = false;
|
has_relative_transformation_rate = false;
|
||||||
relative_transformation_rate = 0.0f;
|
relative_transformation_rate = 0.0f;
|
||||||
cumulative_change_rate = 0.0f;
|
|
||||||
last_input_change = 0.0f;
|
last_input_change = 0.0f;
|
||||||
has_last_input_change = false;
|
has_last_input_change = false;
|
||||||
|
output_change_ema = 0.0f;
|
||||||
|
has_output_change_ema = false;
|
||||||
total_steps_skipped = 0;
|
total_steps_skipped = 0;
|
||||||
current_step_index = -1;
|
current_step_index = -1;
|
||||||
steps_computed_since_active = 0;
|
steps_computed_since_active = 0;
|
||||||
|
expected_total_steps = 0;
|
||||||
|
consecutive_skipped_steps = 0;
|
||||||
accumulated_error = 0.0f;
|
accumulated_error = 0.0f;
|
||||||
reference_output_norm = 0.0f;
|
|
||||||
block_metrics.reset();
|
block_metrics.reset();
|
||||||
total_active_steps = 0;
|
total_active_steps = 0;
|
||||||
}
|
}
|
||||||
@ -134,6 +139,7 @@ struct UCacheState {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
size_t n_steps = sigmas.size() - 1;
|
size_t n_steps = sigmas.size() - 1;
|
||||||
|
expected_total_steps = static_cast<int>(n_steps);
|
||||||
|
|
||||||
size_t start_step = static_cast<size_t>(config.start_percent * n_steps);
|
size_t start_step = static_cast<size_t>(config.start_percent * n_steps);
|
||||||
size_t end_step = static_cast<size_t>(config.end_percent * n_steps);
|
size_t end_step = static_cast<size_t>(config.end_percent * n_steps);
|
||||||
@ -207,11 +213,15 @@ struct UCacheState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int effective_total = estimated_total_steps;
|
int effective_total = estimated_total_steps;
|
||||||
|
if (effective_total <= 0) {
|
||||||
|
effective_total = expected_total_steps;
|
||||||
|
}
|
||||||
if (effective_total <= 0) {
|
if (effective_total <= 0) {
|
||||||
effective_total = std::max(20, steps_computed_since_active * 2);
|
effective_total = std::max(20, steps_computed_since_active * 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
float progress = (effective_total > 0) ? (static_cast<float>(steps_computed_since_active) / effective_total) : 0.0f;
|
float progress = (effective_total > 0) ? (static_cast<float>(steps_computed_since_active) / effective_total) : 0.0f;
|
||||||
|
progress = std::max(0.0f, std::min(1.0f, progress));
|
||||||
|
|
||||||
float multiplier = 1.0f;
|
float multiplier = 1.0f;
|
||||||
if (progress < 0.2f) {
|
if (progress < 0.2f) {
|
||||||
@ -309,17 +319,31 @@ struct UCacheState {
|
|||||||
|
|
||||||
if (has_output_prev_norm && has_relative_transformation_rate &&
|
if (has_output_prev_norm && has_relative_transformation_rate &&
|
||||||
last_input_change > 0.0f && output_prev_norm > 0.0f) {
|
last_input_change > 0.0f && output_prev_norm > 0.0f) {
|
||||||
float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm;
|
float approx_output_change = relative_transformation_rate * last_input_change;
|
||||||
|
float approx_output_change_rate;
|
||||||
|
if (config.use_relative_threshold) {
|
||||||
|
float base_scale = std::max(output_prev_norm, 1e-6f);
|
||||||
|
float dyn_scale = has_output_change_ema
|
||||||
|
? std::max(output_change_ema * std::max(1.0f, config.relative_norm_gain), 1e-6f)
|
||||||
|
: base_scale;
|
||||||
|
float scale = std::sqrt(base_scale * dyn_scale);
|
||||||
|
approx_output_change_rate = approx_output_change / scale;
|
||||||
|
} else {
|
||||||
|
approx_output_change_rate = approx_output_change;
|
||||||
|
}
|
||||||
|
// Increase estimated error with skip horizon to avoid long extrapolation streaks
|
||||||
|
approx_output_change_rate *= (1.0f + 0.50f * consecutive_skipped_steps);
|
||||||
accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
|
accumulated_error = accumulated_error * config.error_decay_rate + approx_output_change_rate;
|
||||||
|
|
||||||
float effective_threshold = get_adaptive_threshold();
|
float effective_threshold = get_adaptive_threshold();
|
||||||
if (config.use_relative_threshold && reference_output_norm > 0.0f) {
|
if (!config.use_relative_threshold && output_prev_norm > 0.0f) {
|
||||||
effective_threshold = effective_threshold * reference_output_norm;
|
effective_threshold = effective_threshold * output_prev_norm;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (accumulated_error < effective_threshold) {
|
if (accumulated_error < effective_threshold) {
|
||||||
skip_current_step = true;
|
skip_current_step = true;
|
||||||
total_steps_skipped++;
|
total_steps_skipped++;
|
||||||
|
consecutive_skipped_steps++;
|
||||||
apply_cache(cond, input, output);
|
apply_cache(cond, input, output);
|
||||||
return true;
|
return true;
|
||||||
} else if (config.reset_error_on_compute) {
|
} else if (config.reset_error_on_compute) {
|
||||||
@ -340,6 +364,8 @@ struct UCacheState {
|
|||||||
if (cond != anchor_condition) {
|
if (cond != anchor_condition) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
steps_computed_since_active++;
|
||||||
|
consecutive_skipped_steps = 0;
|
||||||
|
|
||||||
size_t ne = static_cast<size_t>(ggml_nelements(input));
|
size_t ne = static_cast<size_t>(ggml_nelements(input));
|
||||||
float* in_data = (float*)input->data;
|
float* in_data = (float*)input->data;
|
||||||
@ -359,6 +385,14 @@ struct UCacheState {
|
|||||||
output_change /= static_cast<float>(ne);
|
output_change /= static_cast<float>(ne);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (std::isfinite(output_change) && output_change > 0.0f) {
|
||||||
|
if (!has_output_change_ema) {
|
||||||
|
output_change_ema = output_change;
|
||||||
|
has_output_change_ema = true;
|
||||||
|
} else {
|
||||||
|
output_change_ema = 0.8f * output_change_ema + 0.2f * output_change;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
prev_output.resize(ne);
|
prev_output.resize(ne);
|
||||||
for (size_t i = 0; i < ne; ++i) {
|
for (size_t i = 0; i < ne; ++i) {
|
||||||
@ -373,10 +407,6 @@ struct UCacheState {
|
|||||||
output_prev_norm = (ne > 0) ? (mean_abs / static_cast<float>(ne)) : 0.0f;
|
output_prev_norm = (ne > 0) ? (mean_abs / static_cast<float>(ne)) : 0.0f;
|
||||||
has_output_prev_norm = output_prev_norm > 0.0f;
|
has_output_prev_norm = output_prev_norm > 0.0f;
|
||||||
|
|
||||||
if (reference_output_norm == 0.0f) {
|
|
||||||
reference_output_norm = output_prev_norm;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) {
|
if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) {
|
||||||
float rate = output_change / last_input_change;
|
float rate = output_change / last_input_change;
|
||||||
if (std::isfinite(rate)) {
|
if (std::isfinite(rate)) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user