feat: add easycache support (#940)

This commit is contained in:
rmatif 2025-11-19 16:19:32 +01:00 committed by GitHub
parent 28ffb6c13d
commit a14e2b321d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 541 additions and 32 deletions

265
easycache.hpp Normal file
View File

@ -0,0 +1,265 @@
#include <cmath>
#include <limits>
#include <unordered_map>
#include <vector>
#include "denoiser.hpp"
#include "ggml_extend.hpp"
// User-facing EasyCache settings. The member order is part of the aggregate
// layout and must not change.
struct EasyCacheConfig {
    // Master switch; EasyCacheState::init() only arms the cache when this is true.
    bool enabled{false};
    // Upper bound on the accumulated estimated output change for which a step
    // may still reuse the cached result instead of running the model.
    float reuse_threshold{0.2f};
    // Fractions of the schedule that bound the window in which caching is
    // considered; converted to sigmas by EasyCacheState::percent_to_sigma().
    float start_percent{0.15f};
    float end_percent{0.95f};
};
// Per-condition cache payload: the element-wise difference (output - input)
// recorded by EasyCacheState::update_cache() for the last computed step.
struct EasyCacheCacheEntry {
    std::vector<float> diff{};
};
struct EasyCacheState {
EasyCacheConfig config;
Denoiser* denoiser = nullptr;
float start_sigma = std::numeric_limits<float>::max();
float end_sigma = 0.0f;
bool initialized = false;
bool initial_step = true;
bool skip_current_step = false;
bool step_active = false;
const SDCondition* anchor_condition = nullptr;
std::unordered_map<const SDCondition*, EasyCacheCacheEntry> cache_diffs;
std::vector<float> prev_input;
std::vector<float> prev_output;
float output_prev_norm = 0.0f;
bool has_prev_input = false;
bool has_prev_output = false;
bool has_output_prev_norm = false;
bool has_relative_transformation_rate = false;
float relative_transformation_rate = 0.0f;
float cumulative_change_rate = 0.0f;
float last_input_change = 0.0f;
bool has_last_input_change = false;
int total_steps_skipped = 0;
int current_step_index = -1;
void reset_runtime() {
initial_step = true;
skip_current_step = false;
step_active = false;
anchor_condition = nullptr;
cache_diffs.clear();
prev_input.clear();
prev_output.clear();
output_prev_norm = 0.0f;
has_prev_input = false;
has_prev_output = false;
has_output_prev_norm = false;
has_relative_transformation_rate = false;
relative_transformation_rate = 0.0f;
cumulative_change_rate = 0.0f;
last_input_change = 0.0f;
has_last_input_change = false;
total_steps_skipped = 0;
current_step_index = -1;
}
void init(const EasyCacheConfig& cfg, Denoiser* d) {
config = cfg;
denoiser = d;
initialized = cfg.enabled && d != nullptr;
reset_runtime();
if (initialized) {
start_sigma = percent_to_sigma(config.start_percent);
end_sigma = percent_to_sigma(config.end_percent);
}
}
bool enabled() const {
return initialized && config.enabled;
}
float percent_to_sigma(float percent) const {
if (!denoiser) {
return 0.0f;
}
if (percent <= 0.0f) {
return std::numeric_limits<float>::max();
}
if (percent >= 1.0f) {
return 0.0f;
}
float t = (1.0f - percent) * (TIMESTEPS - 1);
return denoiser->t_to_sigma(t);
}
void begin_step(int step_index, float sigma) {
if (!enabled()) {
return;
}
if (step_index == current_step_index) {
return;
}
current_step_index = step_index;
skip_current_step = false;
has_last_input_change = false;
step_active = false;
if (sigma > start_sigma) {
return;
}
if (!(sigma > end_sigma)) {
return;
}
step_active = true;
}
bool step_is_active() const {
return enabled() && step_active;
}
bool is_step_skipped() const {
return enabled() && step_active && skip_current_step;
}
bool has_cache(const SDCondition* cond) const {
auto it = cache_diffs.find(cond);
return it != cache_diffs.end() && !it->second.diff.empty();
}
void update_cache(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
EasyCacheCacheEntry& entry = cache_diffs[cond];
size_t ne = static_cast<size_t>(ggml_nelements(output));
entry.diff.resize(ne);
float* out_data = (float*)output->data;
float* in_data = (float*)input->data;
for (size_t i = 0; i < ne; ++i) {
entry.diff[i] = out_data[i] - in_data[i];
}
}
void apply_cache(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
auto it = cache_diffs.find(cond);
if (it == cache_diffs.end() || it->second.diff.empty()) {
return;
}
copy_ggml_tensor(output, input);
float* out_data = (float*)output->data;
const std::vector<float>& diff = it->second.diff;
for (size_t i = 0; i < diff.size(); ++i) {
out_data[i] += diff[i];
}
}
bool before_condition(const SDCondition* cond,
ggml_tensor* input,
ggml_tensor* output,
float sigma,
int step_index) {
if (!enabled() || step_index < 0) {
return false;
}
if (step_index != current_step_index) {
begin_step(step_index, sigma);
}
if (!step_active) {
return false;
}
if (initial_step) {
anchor_condition = cond;
initial_step = false;
}
bool is_anchor = (cond == anchor_condition);
if (skip_current_step) {
if (has_cache(cond)) {
apply_cache(cond, input, output);
return true;
}
return false;
}
if (!is_anchor) {
return false;
}
if (!has_prev_input || !has_prev_output || !has_cache(cond)) {
return false;
}
size_t ne = static_cast<size_t>(ggml_nelements(input));
if (prev_input.size() != ne) {
return false;
}
float* input_data = (float*)input->data;
last_input_change = 0.0f;
for (size_t i = 0; i < ne; ++i) {
last_input_change += std::fabs(input_data[i] - prev_input[i]);
}
if (ne > 0) {
last_input_change /= static_cast<float>(ne);
}
has_last_input_change = true;
if (has_output_prev_norm && has_relative_transformation_rate && last_input_change > 0.0f && output_prev_norm > 0.0f) {
float approx_output_change_rate = (relative_transformation_rate * last_input_change) / output_prev_norm;
cumulative_change_rate += approx_output_change_rate;
if (cumulative_change_rate < config.reuse_threshold) {
skip_current_step = true;
total_steps_skipped++;
apply_cache(cond, input, output);
return true;
} else {
cumulative_change_rate = 0.0f;
}
}
return false;
}
void after_condition(const SDCondition* cond, ggml_tensor* input, ggml_tensor* output) {
if (!step_is_active()) {
return;
}
update_cache(cond, input, output);
if (cond != anchor_condition) {
return;
}
size_t ne = static_cast<size_t>(ggml_nelements(input));
float* in_data = (float*)input->data;
prev_input.resize(ne);
for (size_t i = 0; i < ne; ++i) {
prev_input[i] = in_data[i];
}
has_prev_input = true;
float* out_data = (float*)output->data;
float output_change = 0.0f;
if (has_prev_output && prev_output.size() == ne) {
for (size_t i = 0; i < ne; ++i) {
output_change += std::fabs(out_data[i] - prev_output[i]);
}
if (ne > 0) {
output_change /= static_cast<float>(ne);
}
}
prev_output.resize(ne);
for (size_t i = 0; i < ne; ++i) {
prev_output[i] = out_data[i];
}
has_prev_output = true;
float mean_abs = 0.0f;
for (size_t i = 0; i < ne; ++i) {
mean_abs += std::fabs(out_data[i]);
}
output_prev_norm = (ne > 0) ? (mean_abs / static_cast<float>(ne)) : 0.0f;
has_output_prev_norm = output_prev_norm > 0.0f;
if (has_last_input_change && last_input_change > 0.0f && output_change > 0.0f) {
float rate = output_change / last_input_change;
if (std::isfinite(rate)) {
relative_transformation_rate = rate;
has_relative_transformation_rate = true;
}
}
cumulative_change_rate = 0.0f;
has_last_input_change = false;
}
};

View File

@ -33,6 +33,7 @@ Options:
-p, --prompt <string> the prompt to render -p, --prompt <string> the prompt to render
-n, --negative-prompt <string> the negative prompt (default: "") -n, --negative-prompt <string> the negative prompt (default: "")
--preview-path <string> path to write preview image to (default: ./preview.png) --preview-path <string> path to write preview image to (default: ./preview.png)
--easycache <string> enable EasyCache for DiT models, accepts optional "threshold,start_percent,end_percent" values (defaults to 0.2,0.15,0.95)
--upscale-model <string> path to esrgan model. --upscale-model <string> path to esrgan model.
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of -t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of
CPU physical cores CPU physical cores

View File

@ -1,6 +1,7 @@
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
#include <time.h> #include <time.h>
#include <cctype>
#include <filesystem> #include <filesystem>
#include <functional> #include <functional>
#include <iostream> #include <iostream>
@ -105,6 +106,9 @@ struct SDParams {
std::vector<int> high_noise_skip_layers = {7, 8, 9}; std::vector<int> high_noise_skip_layers = {7, 8, 9};
sd_sample_params_t high_noise_sample_params; sd_sample_params_t high_noise_sample_params;
std::string easycache_option;
sd_easycache_params_t easycache_params;
float moe_boundary = 0.875f; float moe_boundary = 0.875f;
int video_frames = 1; int video_frames = 1;
int fps = 16; int fps = 16;
@ -154,6 +158,7 @@ struct SDParams {
sd_sample_params_init(&sample_params); sd_sample_params_init(&sample_params);
sd_sample_params_init(&high_noise_sample_params); sd_sample_params_init(&high_noise_sample_params);
high_noise_sample_params.sample_steps = -1; high_noise_sample_params.sample_steps = -1;
sd_easycache_params_init(&easycache_params);
} }
}; };
@ -225,6 +230,11 @@ void print_params(SDParams params) {
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false"); printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad); printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
printf(" video_frames: %d\n", params.video_frames); printf(" video_frames: %d\n", params.video_frames);
printf(" easycache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
params.easycache_params.enabled ? "enabled" : "disabled",
params.easycache_params.reuse_threshold,
params.easycache_params.start_percent,
params.easycache_params.end_percent);
printf(" vace_strength: %.2f\n", params.vace_strength); printf(" vace_strength: %.2f\n", params.vace_strength);
printf(" fps: %d\n", params.fps); printf(" fps: %d\n", params.fps);
printf(" preview_mode: %s (%s)\n", previews_str[params.preview_method], params.preview_noisy ? "noisy" : "denoised"); printf(" preview_mode: %s (%s)\n", previews_str[params.preview_method], params.preview_noisy ? "noisy" : "denoised");
@ -1128,6 +1138,38 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1; return 1;
}; };
// Handler for "--easycache". The value is optional: the next argument is
// consumed only when it does not look like another flag. The chosen (or
// default) string is stored in params.easycache_option. Returns the number of
// extra arguments consumed (0 or 1).
auto on_easycache_arg = [&](int argc, const char** argv, int index) {
    // A token is a value unless it starts with '-'; a leading '-' is still a
    // value when followed by a digit or '.', e.g. "-0.5" or "-.5".
    auto is_value_token = [](const std::string& tok) -> bool {
        if (tok.empty()) {
            return false;
        }
        if (tok.front() != '-') {
            return true;
        }
        if (tok.size() < 2) {
            return false;
        }
        const unsigned char second = static_cast<unsigned char>(tok[1]);
        return std::isdigit(second) || tok[1] == '.';
    };

    int consumed = 0;
    std::string option_value;
    const int value_index = index + 1;
    if (value_index < argc && is_value_token(std::string(argv[value_index]))) {
        option_value = argv_to_utf8(value_index, argv);
        consumed     = 1;
    }
    // Fall back to the documented defaults when no usable value was supplied.
    params.easycache_option = option_value.empty() ? std::string("0.2,0.15,0.95") : option_value;
    return consumed;
};
options.manual_options = { options.manual_options = {
{"-M", {"-M",
"--mode", "--mode",
@ -1208,6 +1250,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
"--preview", "--preview",
std::string("preview method. must be one of the following [") + previews_str[0] + ", " + previews_str[1] + ", " + previews_str[2] + ", " + previews_str[3] + "] (default is " + previews_str[PREVIEW_NONE] + ")\n", std::string("preview method. must be one of the following [") + previews_str[0] + ", " + previews_str[1] + ", " + previews_str[2] + ", " + previews_str[3] + "] (default is " + previews_str[PREVIEW_NONE] + ")\n",
on_preview_arg}, on_preview_arg},
{"",
"--easycache",
"enable EasyCache for DiT models with optional \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95)",
on_easycache_arg},
}; };
if (!parse_options(argc, argv, options)) { if (!parse_options(argc, argv, options)) {
@ -1215,6 +1261,59 @@ void parse_args(int argc, const char** argv, SDParams& params) {
exit(1); exit(1);
} }
// Translate the raw "--easycache" string ("threshold,start_percent,end_percent")
// into params.easycache_params. Any malformed input is fatal.
if (!params.easycache_option.empty()) {
    float values[3] = {0.0f, 0.0f, 0.0f};
    std::stringstream ss(params.easycache_option);
    std::string token;
    int idx = 0;

    // Strips leading/trailing whitespace in place; an all-whitespace token becomes empty.
    auto trim = [](std::string& s) {
        const char* whitespace = " \t\r\n";
        auto start             = s.find_first_not_of(whitespace);
        if (start == std::string::npos) {
            s.clear();
            return;
        }
        auto end = s.find_last_not_of(whitespace);
        s        = s.substr(start, end - start + 1);
    };

    while (std::getline(ss, token, ',')) {
        trim(token);
        if (token.empty()) {
            fprintf(stderr, "error: invalid easycache option '%s'\n", params.easycache_option.c_str());
            exit(1);
        }
        if (idx >= 3) {
            fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
            exit(1);
        }
        // std::stof stops at the first character it cannot parse, so "0.2abc"
        // would silently become 0.2; require the whole token to be consumed.
        bool parsed_ok = false;
        try {
            size_t parsed_chars = 0;
            values[idx]         = std::stof(token, &parsed_chars);
            parsed_ok           = (parsed_chars == token.size());
        } catch (const std::exception&) {
            parsed_ok = false;
        }
        if (!parsed_ok) {
            fprintf(stderr, "error: invalid easycache value '%s'\n", token.c_str());
            exit(1);
        }
        idx++;
    }
    if (idx != 3) {
        fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
        exit(1);
    }

    // The range checks are phrased positively on purpose: every comparison with
    // NaN is false, so "nan" is rejected here instead of slipping through.
    if (!(values[0] >= 0.0f)) {
        fprintf(stderr, "error: easycache threshold must be non-negative\n");
        exit(1);
    }
    const bool percents_valid = values[1] >= 0.0f &&
                                values[1] < 1.0f &&
                                values[2] > 0.0f &&
                                values[2] <= 1.0f &&
                                values[1] < values[2];
    if (!percents_valid) {
        fprintf(stderr, "error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
        exit(1);
    }

    params.easycache_params.enabled         = true;
    params.easycache_params.reuse_threshold = values[0];
    params.easycache_params.start_percent   = values[1];
    params.easycache_params.end_percent     = values[2];
} else {
    params.easycache_params.enabled = false;
}
if (params.n_threads <= 0) { if (params.n_threads <= 0) {
params.n_threads = get_num_physical_cores(); params.n_threads = get_num_physical_cores();
} }
@ -1852,6 +1951,7 @@ int main(int argc, const char* argv[]) {
params.pm_style_strength, params.pm_style_strength,
}, // pm_params }, // pm_params
params.vae_tiling_params, params.vae_tiling_params,
params.easycache_params,
}; };
results = generate_image(sd_ctx, &img_gen_params); results = generate_image(sd_ctx, &img_gen_params);
@ -1874,6 +1974,7 @@ int main(int argc, const char* argv[]) {
params.seed, params.seed,
params.video_frames, params.video_frames,
params.vace_strength, params.vace_strength,
params.easycache_params,
}; };
results = generate_video(sd_ctx, &vid_gen_params, &num_results); results = generate_video(sd_ctx, &vid_gen_params, &num_results);

View File

@ -11,6 +11,7 @@
#include "control.hpp" #include "control.hpp"
#include "denoiser.hpp" #include "denoiser.hpp"
#include "diffusion_model.hpp" #include "diffusion_model.hpp"
#include "easycache.hpp"
#include "esrgan.hpp" #include "esrgan.hpp"
#include "lora.hpp" #include "lora.hpp"
#include "pmid.hpp" #include "pmid.hpp"
@ -1481,11 +1482,12 @@ public:
const std::vector<float>& sigmas, const std::vector<float>& sigmas,
int start_merge_step, int start_merge_step,
SDCondition id_cond, SDCondition id_cond,
std::vector<ggml_tensor*> ref_latents = {}, std::vector<ggml_tensor*> ref_latents = {},
bool increase_ref_index = false, bool increase_ref_index = false,
ggml_tensor* denoise_mask = nullptr, ggml_tensor* denoise_mask = nullptr,
ggml_tensor* vace_context = nullptr, ggml_tensor* vace_context = nullptr,
float vace_strength = 1.f) { float vace_strength = 1.f,
const sd_easycache_params_t* easycache_params = nullptr) {
if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) { if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) {
LOG_WARN("timestep shifting is only supported for SDXL models!"); LOG_WARN("timestep shifting is only supported for SDXL models!");
shifted_timestep = 0; shifted_timestep = 0;
@ -1501,6 +1503,42 @@ public:
img_cfg_scale = cfg_scale; img_cfg_scale = cfg_scale;
} }
EasyCacheState easycache_state;
bool easycache_enabled = false;
if (easycache_params != nullptr && easycache_params->enabled) {
bool easycache_supported = sd_version_is_dit(version);
if (!easycache_supported) {
LOG_WARN("EasyCache requested but not supported for this model type");
} else {
EasyCacheConfig easycache_config;
easycache_config.enabled = true;
easycache_config.reuse_threshold = std::max(0.0f, easycache_params->reuse_threshold);
easycache_config.start_percent = easycache_params->start_percent;
easycache_config.end_percent = easycache_params->end_percent;
bool percent_valid = easycache_config.start_percent >= 0.0f &&
easycache_config.start_percent < 1.0f &&
easycache_config.end_percent > 0.0f &&
easycache_config.end_percent <= 1.0f &&
easycache_config.start_percent < easycache_config.end_percent;
if (!percent_valid) {
LOG_WARN("EasyCache disabled due to invalid percent range (start=%.3f, end=%.3f)",
easycache_config.start_percent,
easycache_config.end_percent);
} else {
easycache_state.init(easycache_config, denoiser.get());
if (easycache_state.enabled()) {
easycache_enabled = true;
LOG_INFO("EasyCache enabled - threshold: %.3f, start_percent: %.2f, end_percent: %.2f",
easycache_config.reuse_threshold,
easycache_config.start_percent,
easycache_config.end_percent);
} else {
LOG_WARN("EasyCache requested but could not be initialized for this run");
}
}
}
}
size_t steps = sigmas.size() - 1; size_t steps = sigmas.size() - 1;
struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent); struct ggml_tensor* x = ggml_dup_tensor(work_ctx, init_latent);
copy_ggml_tensor(x, init_latent); copy_ggml_tensor(x, init_latent);
@ -1571,6 +1609,38 @@ public:
pretty_progress(0, (int)steps, 0); pretty_progress(0, (int)steps, 0);
} }
DiffusionParams diffusion_params;
const bool easycache_step_active = easycache_enabled && step > 0;
int easycache_step_index = easycache_step_active ? (step - 1) : -1;
if (easycache_step_active) {
easycache_state.begin_step(easycache_step_index, sigma);
}
auto easycache_before_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) -> bool {
if (!easycache_step_active || condition == nullptr || output_tensor == nullptr) {
return false;
}
return easycache_state.before_condition(condition,
diffusion_params.x,
output_tensor,
sigma,
easycache_step_index);
};
auto easycache_after_condition = [&](const SDCondition* condition, struct ggml_tensor* output_tensor) {
if (!easycache_step_active || condition == nullptr || output_tensor == nullptr) {
return;
}
easycache_state.after_condition(condition,
diffusion_params.x,
output_tensor);
};
auto easycache_step_is_skipped = [&]() {
return easycache_step_active && easycache_state.is_step_skipped();
};
std::vector<float> scaling = denoiser->get_scalings(sigma); std::vector<float> scaling = denoiser->get_scalings(sigma);
GGML_ASSERT(scaling.size() == 3); GGML_ASSERT(scaling.size() == 3);
float c_skip = scaling[0]; float c_skip = scaling[0];
@ -1616,7 +1686,6 @@ public:
// GGML_ASSERT(0); // GGML_ASSERT(0);
} }
DiffusionParams diffusion_params;
diffusion_params.x = noised_input; diffusion_params.x = noised_input;
diffusion_params.timesteps = timesteps; diffusion_params.timesteps = timesteps;
diffusion_params.guidance = guidance_tensor; diffusion_params.guidance = guidance_tensor;
@ -1627,37 +1696,50 @@ public:
diffusion_params.vace_context = vace_context; diffusion_params.vace_context = vace_context;
diffusion_params.vace_strength = vace_strength; diffusion_params.vace_strength = vace_strength;
const SDCondition* active_condition = nullptr;
struct ggml_tensor** active_output = &out_cond;
if (start_merge_step == -1 || step <= start_merge_step) { if (start_merge_step == -1 || step <= start_merge_step) {
// cond // cond
diffusion_params.context = cond.c_crossattn; diffusion_params.context = cond.c_crossattn;
diffusion_params.c_concat = cond.c_concat; diffusion_params.c_concat = cond.c_concat;
diffusion_params.y = cond.c_vector; diffusion_params.y = cond.c_vector;
work_diffusion_model->compute(n_threads, active_condition = &cond;
diffusion_params,
&out_cond);
} else { } else {
diffusion_params.context = id_cond.c_crossattn; diffusion_params.context = id_cond.c_crossattn;
diffusion_params.c_concat = cond.c_concat; diffusion_params.c_concat = cond.c_concat;
diffusion_params.y = id_cond.c_vector; diffusion_params.y = id_cond.c_vector;
active_condition = &id_cond;
}
bool skip_model = easycache_before_condition(active_condition, *active_output);
if (!skip_model) {
work_diffusion_model->compute(n_threads, work_diffusion_model->compute(n_threads,
diffusion_params, diffusion_params,
&out_cond); active_output);
easycache_after_condition(active_condition, *active_output);
} }
bool current_step_skipped = easycache_step_is_skipped();
float* negative_data = nullptr; float* negative_data = nullptr;
if (has_unconditioned) { if (has_unconditioned) {
// uncond // uncond
if (control_hint != nullptr && control_net != nullptr) { if (!current_step_skipped && control_hint != nullptr && control_net != nullptr) {
control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector); control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
controls = control_net->controls; controls = control_net->controls;
} }
current_step_skipped = easycache_step_is_skipped();
diffusion_params.controls = controls; diffusion_params.controls = controls;
diffusion_params.context = uncond.c_crossattn; diffusion_params.context = uncond.c_crossattn;
diffusion_params.c_concat = uncond.c_concat; diffusion_params.c_concat = uncond.c_concat;
diffusion_params.y = uncond.c_vector; diffusion_params.y = uncond.c_vector;
work_diffusion_model->compute(n_threads, bool skip_uncond = easycache_before_condition(&uncond, out_uncond);
diffusion_params, if (!skip_uncond) {
&out_uncond); work_diffusion_model->compute(n_threads,
diffusion_params,
&out_uncond);
easycache_after_condition(&uncond, out_uncond);
}
negative_data = (float*)out_uncond->data; negative_data = (float*)out_uncond->data;
} }
@ -1666,25 +1748,31 @@ public:
diffusion_params.context = img_cond.c_crossattn; diffusion_params.context = img_cond.c_crossattn;
diffusion_params.c_concat = img_cond.c_concat; diffusion_params.c_concat = img_cond.c_concat;
diffusion_params.y = img_cond.c_vector; diffusion_params.y = img_cond.c_vector;
work_diffusion_model->compute(n_threads, bool skip_img_cond = easycache_before_condition(&img_cond, out_img_cond);
diffusion_params, if (!skip_img_cond) {
&out_img_cond); work_diffusion_model->compute(n_threads,
diffusion_params,
&out_img_cond);
easycache_after_condition(&img_cond, out_img_cond);
}
img_cond_data = (float*)out_img_cond->data; img_cond_data = (float*)out_img_cond->data;
} }
int step_count = sigmas.size(); int step_count = sigmas.size();
bool is_skiplayer_step = has_skiplayer && step > (int)(guidance.slg.layer_start * step_count) && step < (int)(guidance.slg.layer_end * step_count); bool is_skiplayer_step = has_skiplayer && step > (int)(guidance.slg.layer_start * step_count) && step < (int)(guidance.slg.layer_end * step_count);
float* skip_layer_data = nullptr; float* skip_layer_data = has_skiplayer ? (float*)out_skip->data : nullptr;
if (is_skiplayer_step) { if (is_skiplayer_step) {
LOG_DEBUG("Skipping layers at step %d\n", step); LOG_DEBUG("Skipping layers at step %d\n", step);
// skip layer (same as conditionned) if (!easycache_step_is_skipped()) {
diffusion_params.context = cond.c_crossattn; // skip layer (same as conditioned)
diffusion_params.c_concat = cond.c_concat; diffusion_params.context = cond.c_crossattn;
diffusion_params.y = cond.c_vector; diffusion_params.c_concat = cond.c_concat;
diffusion_params.skip_layers = skip_layers; diffusion_params.y = cond.c_vector;
work_diffusion_model->compute(n_threads, diffusion_params.skip_layers = skip_layers;
diffusion_params, work_diffusion_model->compute(n_threads,
&out_skip); diffusion_params,
&out_skip);
}
skip_layer_data = (float*)out_skip->data; skip_layer_data = (float*)out_skip->data;
} }
float* vec_denoised = (float*)denoised->data; float* vec_denoised = (float*)denoised->data;
@ -1748,6 +1836,26 @@ public:
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta); sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta);
if (easycache_enabled) {
size_t total_steps = sigmas.size() > 0 ? sigmas.size() - 1 : 0;
if (easycache_state.total_steps_skipped > 0 && total_steps > 0) {
if (easycache_state.total_steps_skipped < static_cast<int>(total_steps)) {
double speedup = static_cast<double>(total_steps) /
static_cast<double>(total_steps - easycache_state.total_steps_skipped);
LOG_INFO("EasyCache skipped %d/%zu steps (%.2fx estimated speedup)",
easycache_state.total_steps_skipped,
total_steps,
speedup);
} else {
LOG_INFO("EasyCache skipped %d/%zu steps",
easycache_state.total_steps_skipped,
total_steps);
}
} else if (total_steps > 0) {
LOG_INFO("EasyCache completed without skipping steps");
}
}
if (inverse_noise_scaling) { if (inverse_noise_scaling) {
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x); x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
} }
@ -2294,6 +2402,14 @@ enum lora_apply_mode_t str_to_lora_apply_mode(const char* str) {
return LORA_APPLY_MODE_COUNT; return LORA_APPLY_MODE_COUNT;
} }
// Fills *easycache_params with the library defaults: disabled,
// threshold 0.2, start 0.15, end 0.95. The pointer must be valid.
void sd_easycache_params_init(sd_easycache_params_t* easycache_params) {
    sd_easycache_params_t defaults = {};
    defaults.enabled               = false;
    defaults.reuse_threshold       = 0.2f;
    defaults.start_percent         = 0.15f;
    defaults.end_percent           = 0.95f;
    *easycache_params              = defaults;
}
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
*sd_ctx_params = {}; *sd_ctx_params = {};
sd_ctx_params->vae_decode_only = true; sd_ctx_params->vae_decode_only = true;
@ -2452,6 +2568,7 @@ void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->control_strength = 0.9f; sd_img_gen_params->control_strength = 0.9f;
sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f}; sd_img_gen_params->pm_params = {nullptr, 0, nullptr, 20.f};
sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f}; sd_img_gen_params->vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
sd_easycache_params_init(&sd_img_gen_params->easycache);
} }
char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
@ -2495,6 +2612,12 @@ char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) {
sd_img_gen_params->pm_params.id_images_count, sd_img_gen_params->pm_params.id_images_count,
SAFE_STR(sd_img_gen_params->pm_params.id_embed_path), SAFE_STR(sd_img_gen_params->pm_params.id_embed_path),
BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled)); BOOL_STR(sd_img_gen_params->vae_tiling_params.enabled));
snprintf(buf + strlen(buf), 4096 - strlen(buf),
"easycache: %s (threshold=%.3f, start=%.2f, end=%.2f)\n",
sd_img_gen_params->easycache.enabled ? "enabled" : "disabled",
sd_img_gen_params->easycache.reuse_threshold,
sd_img_gen_params->easycache.start_percent,
sd_img_gen_params->easycache.end_percent);
free(sample_params_str); free(sample_params_str);
return buf; return buf;
} }
@ -2511,6 +2634,7 @@ void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) {
sd_vid_gen_params->video_frames = 6; sd_vid_gen_params->video_frames = 6;
sd_vid_gen_params->moe_boundary = 0.875f; sd_vid_gen_params->moe_boundary = 0.875f;
sd_vid_gen_params->vace_strength = 1.f; sd_vid_gen_params->vace_strength = 1.f;
sd_easycache_params_init(&sd_vid_gen_params->easycache);
} }
struct sd_ctx_t { struct sd_ctx_t {
@ -2578,8 +2702,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
std::vector<sd_image_t*> ref_images, std::vector<sd_image_t*> ref_images,
std::vector<ggml_tensor*> ref_latents, std::vector<ggml_tensor*> ref_latents,
bool increase_ref_index, bool increase_ref_index,
ggml_tensor* concat_latent = nullptr, ggml_tensor* concat_latent = nullptr,
ggml_tensor* denoise_mask = nullptr) { ggml_tensor* denoise_mask = nullptr,
const sd_easycache_params_t* easycache_params = nullptr) {
if (seed < 0) { if (seed < 0) {
// Generally, when using the provided command line, the seed is always >0. // Generally, when using the provided command line, the seed is always >0.
// However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library
@ -2868,7 +2993,10 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
id_cond, id_cond,
ref_latents, ref_latents,
increase_ref_index, increase_ref_index,
denoise_mask); denoise_mask,
nullptr,
1.0f,
easycache_params);
// print_ggml_tensor(x_0); // print_ggml_tensor(x_0);
int64_t sampling_end = ggml_time_ms(); int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
@ -3185,7 +3313,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
ref_latents, ref_latents,
sd_img_gen_params->increase_ref_index, sd_img_gen_params->increase_ref_index,
concat_latent, concat_latent,
denoise_mask); denoise_mask,
&sd_img_gen_params->easycache);
size_t t2 = ggml_time_ms(); size_t t2 = ggml_time_ms();
@ -3506,7 +3635,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
false, false,
denoise_mask, denoise_mask,
vace_context, vace_context,
sd_vid_gen_params->vace_strength); sd_vid_gen_params->vace_strength,
&sd_vid_gen_params->easycache);
int64_t sampling_end = ggml_time_ms(); int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); LOG_INFO("sampling(high noise) completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);
@ -3542,7 +3672,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
false, false,
denoise_mask, denoise_mask,
vace_context, vace_context,
sd_vid_gen_params->vace_strength); sd_vid_gen_params->vace_strength,
&sd_vid_gen_params->easycache);
int64_t sampling_end = ggml_time_ms(); int64_t sampling_end = ggml_time_ms();
LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000); LOG_INFO("sampling completed, taking %.2fs", (sampling_end - sampling_start) * 1.0f / 1000);

View File

@ -229,6 +229,13 @@ typedef struct {
float style_strength; float style_strength;
} sd_pm_params_t; // photo maker } sd_pm_params_t; // photo maker
// EasyCache options carried by the image and video generation parameters.
// Initialise with sd_easycache_params_init(), which sets
// enabled=false, reuse_threshold=0.2, start_percent=0.15, end_percent=0.95.
typedef struct {
// Set to true to request EasyCache for the run.
bool enabled;
// A step may reuse cached results while the accumulated estimated output
// change stays below this value; negative values are clamped to 0 by the sampler.
float reuse_threshold;
// Start/end of the caching window as fractions of the schedule. The sampler
// requires 0.0 <= start_percent < end_percent <= 1.0 and otherwise disables
// EasyCache for the run with a warning.
float start_percent;
float end_percent;
} sd_easycache_params_t;
typedef struct { typedef struct {
const char* prompt; const char* prompt;
const char* negative_prompt; const char* negative_prompt;
@ -249,6 +256,7 @@ typedef struct {
float control_strength; float control_strength;
sd_pm_params_t pm_params; sd_pm_params_t pm_params;
sd_tiling_params_t vae_tiling_params; sd_tiling_params_t vae_tiling_params;
sd_easycache_params_t easycache;
} sd_img_gen_params_t; } sd_img_gen_params_t;
typedef struct { typedef struct {
@ -268,6 +276,7 @@ typedef struct {
int64_t seed; int64_t seed;
int video_frames; int video_frames;
float vace_strength; float vace_strength;
sd_easycache_params_t easycache;
} sd_vid_gen_params_t; } sd_vid_gen_params_t;
typedef struct sd_ctx_t sd_ctx_t; typedef struct sd_ctx_t sd_ctx_t;
@ -297,6 +306,8 @@ SD_API enum preview_t str_to_preview(const char* str);
SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode); SD_API const char* sd_lora_apply_mode_name(enum lora_apply_mode_t mode);
SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str); SD_API enum lora_apply_mode_t str_to_lora_apply_mode(const char* str);
SD_API void sd_easycache_params_init(sd_easycache_params_t* easycache_params);
SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params); SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params); SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);