refactor: simplify ControlNet output caching (#1655)

This commit is contained in:
leejet 2026-06-14 16:58:37 +08:00 committed by GitHub
parent 17d70b91e6
commit 9838264c49
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 51 deletions

View File

@ -2007,6 +2007,10 @@ protected:
} }
bool copy_cache_tensors_to_cache_buffer(const std::unordered_set<std::string>* cache_keep_names = nullptr) { bool copy_cache_tensors_to_cache_buffer(const std::unordered_set<std::string>* cache_keep_names = nullptr) {
if (cache_tensor_map.empty() && cache_keep_names == nullptr) {
return true;
}
ggml_context* old_cache_ctx = cache_ctx; ggml_context* old_cache_ctx = cache_ctx;
ggml_backend_buffer_t old_cache_buffer = cache_buffer; ggml_backend_buffer_t old_cache_buffer = cache_buffer;
cache_ctx = nullptr; cache_ctx = nullptr;

View File

@ -312,16 +312,17 @@ struct ControlNet : public GGMLRunner {
ControlNetBlock control_net; ControlNetBlock control_net;
std::string weight_prefix; std::string weight_prefix;
ggml_backend_buffer_t control_buffer = nullptr;
ggml_context* control_ctx = nullptr;
std::vector<ggml_tensor*> control_outputs_ggml; std::vector<ggml_tensor*> control_outputs_ggml;
ggml_tensor* guided_hint_output_ggml = nullptr; ggml_tensor* guided_hint_output_ggml = nullptr;
std::vector<sd::Tensor<float>> controls; std::vector<sd::Tensor<float>> controls;
sd::Tensor<float> guided_hint;
bool guided_hint_cached = false; bool guided_hint_cached = false;
std::shared_ptr<ModelManager> owned_model_manager; std::shared_ptr<ModelManager> owned_model_manager;
ggml_backend_t params_backend = nullptr; ggml_backend_t params_backend = nullptr;
static const char* guided_hint_cache_name() {
return "controlnet.guided_hint";
}
ControlNet(ggml_backend_t backend, ControlNet(ggml_backend_t backend,
ggml_backend_t params_backend_, ggml_backend_t params_backend_,
const String2TensorStorage& tensor_storage_map = {}, const String2TensorStorage& tensor_storage_map = {},
@ -336,44 +337,12 @@ struct ControlNet : public GGMLRunner {
free_control_ctx(); free_control_ctx();
} }
void alloc_control_ctx(std::vector<ggml_tensor*> outs) {
ggml_init_params params;
params.mem_size = static_cast<size_t>(outs.size() * ggml_tensor_overhead()) + 1024 * 1024;
params.mem_buffer = nullptr;
params.no_alloc = true;
control_ctx = ggml_init(params);
control_outputs_ggml.resize(outs.size() - 1);
size_t control_buffer_size = 0;
guided_hint_output_ggml = ggml_dup_tensor(control_ctx, outs[0]);
control_buffer_size += ggml_nbytes(guided_hint_output_ggml);
for (int i = 0; i < outs.size() - 1; i++) {
control_outputs_ggml[i] = ggml_dup_tensor(control_ctx, outs[i + 1]);
control_buffer_size += ggml_nbytes(control_outputs_ggml[i]);
}
control_buffer = ggml_backend_alloc_ctx_tensors(control_ctx, runtime_backend);
LOG_DEBUG("control buffer size %.2fMB", control_buffer_size * 1.f / 1024.f / 1024.f);
}
void free_control_ctx() { void free_control_ctx() {
if (control_buffer != nullptr) {
ggml_backend_buffer_free(control_buffer);
control_buffer = nullptr;
}
if (control_ctx != nullptr) {
ggml_free(control_ctx);
control_ctx = nullptr;
}
guided_hint_output_ggml = nullptr; guided_hint_output_ggml = nullptr;
guided_hint_cached = false; guided_hint_cached = false;
guided_hint = {};
control_outputs_ggml.clear(); control_outputs_ggml.clear();
controls.clear(); controls.clear();
free_cache_ctx_and_buffer();
} }
std::string get_desc() override { std::string get_desc() override {
@ -397,11 +366,17 @@ struct ControlNet : public GGMLRunner {
ggml_tensor* context = make_optional_input(context_tensor); ggml_tensor* context = make_optional_input(context_tensor);
ggml_tensor* y = make_optional_input(y_tensor); ggml_tensor* y = make_optional_input(y_tensor);
guided_hint_output_ggml = nullptr;
control_outputs_ggml.clear();
ggml_tensor* guided_hint_input = nullptr; ggml_tensor* guided_hint_input = nullptr;
if (guided_hint_cached && !guided_hint.empty()) { if (guided_hint_cached) {
guided_hint_input = make_input(guided_hint); guided_hint_input = get_cache_tensor_by_name(guided_hint_cache_name());
hint = nullptr; if (guided_hint_input == nullptr) {
} else { guided_hint_cached = false;
}
}
if (guided_hint_input == nullptr) {
hint = make_input(hint_tensor); hint = make_input(hint_tensor);
} }
@ -415,13 +390,19 @@ struct ControlNet : public GGMLRunner {
context, context,
y); y);
if (control_ctx == nullptr) { if (guided_hint_input == nullptr && !outs.empty()) {
alloc_control_ctx(outs); guided_hint_output_ggml = outs[0];
ggml_set_output(guided_hint_output_ggml);
cache(guided_hint_cache_name(), guided_hint_output_ggml);
ggml_build_forward_expand(gf, guided_hint_output_ggml);
} }
ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[0], guided_hint_output_ggml)); control_outputs_ggml.reserve(outs.size() > 0 ? outs.size() - 1 : 0);
for (int i = 0; i < outs.size() - 1; i++) { for (size_t i = 1; i < outs.size(); i++) {
ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[i + 1], control_outputs_ggml[i])); ggml_tensor* control_output = outs[i];
ggml_set_output(control_output);
ggml_build_forward_expand(gf, control_output);
control_outputs_ggml.push_back(control_output);
} }
return gf; return gf;
@ -441,15 +422,12 @@ struct ControlNet : public GGMLRunner {
return build_graph(x, hint, timesteps, context, y); return build_graph(x, hint, timesteps, context, y);
}; };
auto compute_result = GGMLRunner::compute<float>(get_graph, n_threads, false, false, false); auto compute_result = GGMLRunner::compute<float>(get_graph, n_threads, false, false, false, true);
if (!compute_result.has_value()) { if (!compute_result.has_value()) {
return std::nullopt; return std::nullopt;
} }
if (guided_hint_output_ggml != nullptr) { guided_hint_cached = get_cache_tensor_by_name(guided_hint_cache_name()) != nullptr;
guided_hint = restore_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(guided_hint_output_ggml),
4);
}
controls.clear(); controls.clear();
controls.reserve(control_outputs_ggml.size()); controls.reserve(control_outputs_ggml.size());
for (ggml_tensor* control : control_outputs_ggml) { for (ggml_tensor* control : control_outputs_ggml) {
@ -457,7 +435,6 @@ struct ControlNet : public GGMLRunner {
GGML_ASSERT(!control_host.empty()); GGML_ASSERT(!control_host.empty());
controls.push_back(std::move(control_host)); controls.push_back(std::move(control_host));
} }
guided_hint_cached = true;
return controls; return controls;
} }