mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-17 03:37:20 +00:00
refactor: simplify ControlNet output caching (#1655)
This commit is contained in:
parent
17d70b91e6
commit
9838264c49
@ -2007,6 +2007,10 @@ protected:
|
||||
}
|
||||
|
||||
bool copy_cache_tensors_to_cache_buffer(const std::unordered_set<std::string>* cache_keep_names = nullptr) {
|
||||
if (cache_tensor_map.empty() && cache_keep_names == nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
ggml_context* old_cache_ctx = cache_ctx;
|
||||
ggml_backend_buffer_t old_cache_buffer = cache_buffer;
|
||||
cache_ctx = nullptr;
|
||||
|
||||
@ -312,16 +312,17 @@ struct ControlNet : public GGMLRunner {
|
||||
ControlNetBlock control_net;
|
||||
std::string weight_prefix;
|
||||
|
||||
ggml_backend_buffer_t control_buffer = nullptr;
|
||||
ggml_context* control_ctx = nullptr;
|
||||
std::vector<ggml_tensor*> control_outputs_ggml;
|
||||
ggml_tensor* guided_hint_output_ggml = nullptr;
|
||||
std::vector<sd::Tensor<float>> controls;
|
||||
sd::Tensor<float> guided_hint;
|
||||
bool guided_hint_cached = false;
|
||||
std::shared_ptr<ModelManager> owned_model_manager;
|
||||
ggml_backend_t params_backend = nullptr;
|
||||
|
||||
static const char* guided_hint_cache_name() {
|
||||
return "controlnet.guided_hint";
|
||||
}
|
||||
|
||||
ControlNet(ggml_backend_t backend,
|
||||
ggml_backend_t params_backend_,
|
||||
const String2TensorStorage& tensor_storage_map = {},
|
||||
@ -336,44 +337,12 @@ struct ControlNet : public GGMLRunner {
|
||||
free_control_ctx();
|
||||
}
|
||||
|
||||
void alloc_control_ctx(std::vector<ggml_tensor*> outs) {
|
||||
ggml_init_params params;
|
||||
params.mem_size = static_cast<size_t>(outs.size() * ggml_tensor_overhead()) + 1024 * 1024;
|
||||
params.mem_buffer = nullptr;
|
||||
params.no_alloc = true;
|
||||
control_ctx = ggml_init(params);
|
||||
|
||||
control_outputs_ggml.resize(outs.size() - 1);
|
||||
|
||||
size_t control_buffer_size = 0;
|
||||
|
||||
guided_hint_output_ggml = ggml_dup_tensor(control_ctx, outs[0]);
|
||||
control_buffer_size += ggml_nbytes(guided_hint_output_ggml);
|
||||
|
||||
for (int i = 0; i < outs.size() - 1; i++) {
|
||||
control_outputs_ggml[i] = ggml_dup_tensor(control_ctx, outs[i + 1]);
|
||||
control_buffer_size += ggml_nbytes(control_outputs_ggml[i]);
|
||||
}
|
||||
|
||||
control_buffer = ggml_backend_alloc_ctx_tensors(control_ctx, runtime_backend);
|
||||
|
||||
LOG_DEBUG("control buffer size %.2fMB", control_buffer_size * 1.f / 1024.f / 1024.f);
|
||||
}
|
||||
|
||||
void free_control_ctx() {
|
||||
if (control_buffer != nullptr) {
|
||||
ggml_backend_buffer_free(control_buffer);
|
||||
control_buffer = nullptr;
|
||||
}
|
||||
if (control_ctx != nullptr) {
|
||||
ggml_free(control_ctx);
|
||||
control_ctx = nullptr;
|
||||
}
|
||||
guided_hint_output_ggml = nullptr;
|
||||
guided_hint_cached = false;
|
||||
guided_hint = {};
|
||||
control_outputs_ggml.clear();
|
||||
controls.clear();
|
||||
free_cache_ctx_and_buffer();
|
||||
}
|
||||
|
||||
std::string get_desc() override {
|
||||
@ -397,11 +366,17 @@ struct ControlNet : public GGMLRunner {
|
||||
ggml_tensor* context = make_optional_input(context_tensor);
|
||||
ggml_tensor* y = make_optional_input(y_tensor);
|
||||
|
||||
guided_hint_output_ggml = nullptr;
|
||||
control_outputs_ggml.clear();
|
||||
|
||||
ggml_tensor* guided_hint_input = nullptr;
|
||||
if (guided_hint_cached && !guided_hint.empty()) {
|
||||
guided_hint_input = make_input(guided_hint);
|
||||
hint = nullptr;
|
||||
} else {
|
||||
if (guided_hint_cached) {
|
||||
guided_hint_input = get_cache_tensor_by_name(guided_hint_cache_name());
|
||||
if (guided_hint_input == nullptr) {
|
||||
guided_hint_cached = false;
|
||||
}
|
||||
}
|
||||
if (guided_hint_input == nullptr) {
|
||||
hint = make_input(hint_tensor);
|
||||
}
|
||||
|
||||
@ -415,13 +390,19 @@ struct ControlNet : public GGMLRunner {
|
||||
context,
|
||||
y);
|
||||
|
||||
if (control_ctx == nullptr) {
|
||||
alloc_control_ctx(outs);
|
||||
if (guided_hint_input == nullptr && !outs.empty()) {
|
||||
guided_hint_output_ggml = outs[0];
|
||||
ggml_set_output(guided_hint_output_ggml);
|
||||
cache(guided_hint_cache_name(), guided_hint_output_ggml);
|
||||
ggml_build_forward_expand(gf, guided_hint_output_ggml);
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[0], guided_hint_output_ggml));
|
||||
for (int i = 0; i < outs.size() - 1; i++) {
|
||||
ggml_build_forward_expand(gf, ggml_cpy(compute_ctx, outs[i + 1], control_outputs_ggml[i]));
|
||||
control_outputs_ggml.reserve(outs.size() > 0 ? outs.size() - 1 : 0);
|
||||
for (size_t i = 1; i < outs.size(); i++) {
|
||||
ggml_tensor* control_output = outs[i];
|
||||
ggml_set_output(control_output);
|
||||
ggml_build_forward_expand(gf, control_output);
|
||||
control_outputs_ggml.push_back(control_output);
|
||||
}
|
||||
|
||||
return gf;
|
||||
@ -441,15 +422,12 @@ struct ControlNet : public GGMLRunner {
|
||||
return build_graph(x, hint, timesteps, context, y);
|
||||
};
|
||||
|
||||
auto compute_result = GGMLRunner::compute<float>(get_graph, n_threads, false, false, false);
|
||||
auto compute_result = GGMLRunner::compute<float>(get_graph, n_threads, false, false, false, true);
|
||||
if (!compute_result.has_value()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (guided_hint_output_ggml != nullptr) {
|
||||
guided_hint = restore_trailing_singleton_dims(sd::make_sd_tensor_from_ggml<float>(guided_hint_output_ggml),
|
||||
4);
|
||||
}
|
||||
guided_hint_cached = get_cache_tensor_by_name(guided_hint_cache_name()) != nullptr;
|
||||
controls.clear();
|
||||
controls.reserve(control_outputs_ggml.size());
|
||||
for (ggml_tensor* control : control_outputs_ggml) {
|
||||
@ -457,7 +435,6 @@ struct ControlNet : public GGMLRunner {
|
||||
GGML_ASSERT(!control_host.empty());
|
||||
controls.push_back(std::move(control_host));
|
||||
}
|
||||
guided_hint_cached = true;
|
||||
return controls;
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user