mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-19 04:37:18 +00:00
refactor: manage upscaler params through model manager (#1645)
This commit is contained in:
parent
563137a592
commit
9b0fceb41b
@ -48,6 +48,7 @@ enum SDVersion {
|
|||||||
VERSION_LONGCAT,
|
VERSION_LONGCAT,
|
||||||
VERSION_PID,
|
VERSION_PID,
|
||||||
VERSION_IDEOGRAM4,
|
VERSION_IDEOGRAM4,
|
||||||
|
VERSION_ESRGAN,
|
||||||
VERSION_COUNT,
|
VERSION_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,14 @@
|
|||||||
#ifndef __SD_MODEL_UPSCALER_ESRGAN_HPP__
|
#ifndef __SD_MODEL_UPSCALER_ESRGAN_HPP__
|
||||||
#define __SD_MODEL_UPSCALER_ESRGAN_HPP__
|
#define __SD_MODEL_UPSCALER_ESRGAN_HPP__
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
#include "core/ggml_extend.hpp"
|
#include "core/ggml_extend.hpp"
|
||||||
#include "model_loader.h"
|
#include "core/util.h"
|
||||||
|
|
||||||
/*
|
/*
|
||||||
=================================== ESRGAN ===================================
|
=================================== ESRGAN ===================================
|
||||||
@ -12,6 +18,74 @@
|
|||||||
|
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
struct ESRGANConfig {
|
||||||
|
int scale = 4;
|
||||||
|
int num_block = 23;
|
||||||
|
int num_in_ch = 3;
|
||||||
|
int num_out_ch = 3;
|
||||||
|
int num_feat = 64;
|
||||||
|
int num_grow_ch = 32;
|
||||||
|
|
||||||
|
static ESRGANConfig detect_from_weights(const String2TensorStorage& tensor_storage_map,
|
||||||
|
const std::string& prefix = "") {
|
||||||
|
ESRGANConfig config;
|
||||||
|
auto find_weight = [&](const std::string& suffix) -> const TensorStorage* {
|
||||||
|
std::string name = prefix.empty() ? suffix : prefix + "." + suffix;
|
||||||
|
auto iter = tensor_storage_map.find(name);
|
||||||
|
if (iter == tensor_storage_map.end()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return &iter->second;
|
||||||
|
};
|
||||||
|
|
||||||
|
int detected_num_block = 0;
|
||||||
|
const std::string body_prefix = prefix.empty() ? "body." : prefix + ".body.";
|
||||||
|
for (const auto& [name, _] : tensor_storage_map) {
|
||||||
|
if (!starts_with(name, body_prefix)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
size_t pos = name.find('.', body_prefix.size());
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
int idx = std::stoi(name.substr(body_prefix.size(), pos - body_prefix.size()));
|
||||||
|
detected_num_block = std::max(detected_num_block, idx + 1);
|
||||||
|
} catch (...) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (detected_num_block > 0) {
|
||||||
|
config.num_block = detected_num_block;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool has_conv_up2 = find_weight("conv_up2.weight") != nullptr;
|
||||||
|
bool has_conv_up1 = find_weight("conv_up1.weight") != nullptr;
|
||||||
|
bool has_model_tensor =
|
||||||
|
detected_num_block > 0 ||
|
||||||
|
find_weight("conv_first.weight") != nullptr ||
|
||||||
|
find_weight("conv_hr.weight") != nullptr ||
|
||||||
|
find_weight("conv_last.weight") != nullptr;
|
||||||
|
if (has_conv_up2) {
|
||||||
|
config.scale = 4;
|
||||||
|
} else if (has_conv_up1) {
|
||||||
|
config.scale = 2;
|
||||||
|
} else if (has_model_tensor) {
|
||||||
|
config.scale = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (has_model_tensor || has_conv_up1 || has_conv_up2) {
|
||||||
|
LOG_DEBUG("esrgan: scale = %d, num_block = %d, num_in_ch = %d, num_out_ch = %d, num_feat = %d, num_grow_ch = %d",
|
||||||
|
config.scale,
|
||||||
|
config.num_block,
|
||||||
|
config.num_in_ch,
|
||||||
|
config.num_out_ch,
|
||||||
|
config.num_feat,
|
||||||
|
config.num_grow_ch);
|
||||||
|
}
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class ResidualDenseBlock : public GGMLBlock {
|
class ResidualDenseBlock : public GGMLBlock {
|
||||||
protected:
|
protected:
|
||||||
int num_feat;
|
int num_feat;
|
||||||
@ -83,34 +157,29 @@ public:
|
|||||||
|
|
||||||
class RRDBNet : public GGMLBlock {
|
class RRDBNet : public GGMLBlock {
|
||||||
protected:
|
protected:
|
||||||
int scale = 4;
|
ESRGANConfig config;
|
||||||
int num_block = 23;
|
|
||||||
int num_in_ch = 3;
|
|
||||||
int num_out_ch = 3;
|
|
||||||
int num_feat = 64;
|
|
||||||
int num_grow_ch = 32;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
RRDBNet(int scale, int num_block, int num_in_ch, int num_out_ch, int num_feat, int num_grow_ch)
|
explicit RRDBNet(ESRGANConfig config)
|
||||||
: scale(scale), num_block(num_block), num_in_ch(num_in_ch), num_out_ch(num_out_ch), num_feat(num_feat), num_grow_ch(num_grow_ch) {
|
: config(std::move(config)) {
|
||||||
blocks["conv_first"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_in_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv_first"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.num_in_ch, this->config.num_feat, {3, 3}, {1, 1}, {1, 1}));
|
||||||
for (int i = 0; i < num_block; i++) {
|
for (int i = 0; i < this->config.num_block; i++) {
|
||||||
std::string name = "body." + std::to_string(i);
|
std::string name = "body." + std::to_string(i);
|
||||||
blocks[name] = std::shared_ptr<GGMLBlock>(new RRDB(num_feat, num_grow_ch));
|
blocks[name] = std::shared_ptr<GGMLBlock>(new RRDB(this->config.num_feat, this->config.num_grow_ch));
|
||||||
}
|
}
|
||||||
blocks["conv_body"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv_body"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.num_feat, this->config.num_feat, {3, 3}, {1, 1}, {1, 1}));
|
||||||
if (scale >= 2) {
|
if (this->config.scale >= 2) {
|
||||||
blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.num_feat, this->config.num_feat, {3, 3}, {1, 1}, {1, 1}));
|
||||||
}
|
}
|
||||||
if (scale == 4) {
|
if (this->config.scale == 4) {
|
||||||
blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.num_feat, this->config.num_feat, {3, 3}, {1, 1}, {1, 1}));
|
||||||
}
|
}
|
||||||
blocks["conv_hr"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv_hr"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.num_feat, this->config.num_feat, {3, 3}, {1, 1}, {1, 1}));
|
||||||
blocks["conv_last"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_out_ch, {3, 3}, {1, 1}, {1, 1}));
|
blocks["conv_last"] = std::shared_ptr<GGMLBlock>(new Conv2d(this->config.num_feat, this->config.num_out_ch, {3, 3}, {1, 1}, {1, 1}));
|
||||||
}
|
}
|
||||||
|
|
||||||
int get_scale() { return scale; }
|
int get_scale() { return config.scale; }
|
||||||
int get_num_block() { return num_block; }
|
int get_num_block() { return config.num_block; }
|
||||||
|
|
||||||
ggml_tensor* lrelu(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
ggml_tensor* lrelu(GGMLRunnerContext* ctx, ggml_tensor* x) {
|
||||||
return ggml_leaky_relu(ctx->ggml_ctx, x, 0.2f, true);
|
return ggml_leaky_relu(ctx->ggml_ctx, x, 0.2f, true);
|
||||||
@ -127,7 +196,7 @@ public:
|
|||||||
auto feat = conv_first->forward(ctx, x);
|
auto feat = conv_first->forward(ctx, x);
|
||||||
sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.prelude", "feat");
|
sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.prelude", "feat");
|
||||||
auto body_feat = feat;
|
auto body_feat = feat;
|
||||||
for (int i = 0; i < num_block; i++) {
|
for (int i = 0; i < config.num_block; i++) {
|
||||||
std::string name = "body." + std::to_string(i);
|
std::string name = "body." + std::to_string(i);
|
||||||
auto block = std::dynamic_pointer_cast<RRDB>(blocks[name]);
|
auto block = std::dynamic_pointer_cast<RRDB>(blocks[name]);
|
||||||
|
|
||||||
@ -138,11 +207,11 @@ public:
|
|||||||
feat = ggml_add(ctx->ggml_ctx, feat, body_feat);
|
feat = ggml_add(ctx->ggml_ctx, feat, body_feat);
|
||||||
sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.body.out", "feat");
|
sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.body.out", "feat");
|
||||||
// upsample
|
// upsample
|
||||||
if (scale >= 2) {
|
if (config.scale >= 2) {
|
||||||
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
|
auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
|
||||||
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
|
feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
|
||||||
sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.up1", "feat");
|
sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.up1", "feat");
|
||||||
if (scale == 4) {
|
if (config.scale == 4) {
|
||||||
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
|
auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
|
||||||
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
|
feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx->ggml_ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
|
||||||
sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.up2", "feat");
|
sd::ggml_graph_cut::mark_graph_cut(feat, "esrgan.up2", "feat");
|
||||||
@ -156,201 +225,28 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct ESRGAN : public GGMLRunner {
|
struct ESRGAN : public GGMLRunner {
|
||||||
|
ESRGANConfig config;
|
||||||
std::unique_ptr<RRDBNet> rrdb_net;
|
std::unique_ptr<RRDBNet> rrdb_net;
|
||||||
int scale = 4;
|
|
||||||
int tile_size = 128; // avoid cuda OOM for 4gb VRAM
|
|
||||||
|
|
||||||
ESRGAN(ggml_backend_t backend,
|
ESRGAN(ggml_backend_t backend,
|
||||||
ggml_backend_t params_backend,
|
ggml_backend_t params_backend,
|
||||||
int tile_size = 128,
|
|
||||||
const String2TensorStorage& tensor_storage_map = {})
|
const String2TensorStorage& tensor_storage_map = {})
|
||||||
: GGMLRunner(backend, params_backend) {
|
: GGMLRunner(backend, params_backend),
|
||||||
this->tile_size = tile_size;
|
config(ESRGANConfig::detect_from_weights(tensor_storage_map)),
|
||||||
|
rrdb_net(std::make_unique<RRDBNet>(config)) {
|
||||||
|
rrdb_net->init(params_ctx, tensor_storage_map, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string get_desc() override {
|
std::string get_desc() override {
|
||||||
return "esrgan";
|
return "esrgan";
|
||||||
}
|
}
|
||||||
|
|
||||||
bool load_from_file(const std::string& file_path, int n_threads) {
|
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) {
|
||||||
LOG_INFO("loading esrgan from '%s'", file_path.c_str());
|
if (!rrdb_net) {
|
||||||
|
return;
|
||||||
ModelLoader model_loader;
|
|
||||||
if (!model_loader.init_from_file_and_convert_name(file_path)) {
|
|
||||||
LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str());
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get tensor names
|
rrdb_net->get_param_tensors(tensors);
|
||||||
auto tensor_names = model_loader.get_tensor_names();
|
|
||||||
|
|
||||||
// Detect if it's ESRGAN format
|
|
||||||
bool is_ESRGAN = std::find(tensor_names.begin(), tensor_names.end(), "model.0.weight") != tensor_names.end();
|
|
||||||
|
|
||||||
// Detect parameters from tensor names
|
|
||||||
int detected_num_block = 0;
|
|
||||||
if (is_ESRGAN) {
|
|
||||||
for (const auto& name : tensor_names) {
|
|
||||||
if (name.find("model.1.sub.") == 0) {
|
|
||||||
size_t first_dot = name.find('.', 12);
|
|
||||||
if (first_dot != std::string::npos) {
|
|
||||||
size_t second_dot = name.find('.', first_dot + 1);
|
|
||||||
if (second_dot != std::string::npos && name.substr(first_dot + 1, 3) == "RDB") {
|
|
||||||
try {
|
|
||||||
int idx = std::stoi(name.substr(12, first_dot - 12));
|
|
||||||
detected_num_block = std::max(detected_num_block, idx + 1);
|
|
||||||
} catch (...) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Original format
|
|
||||||
for (const auto& name : tensor_names) {
|
|
||||||
if (name.find("body.") == 0) {
|
|
||||||
size_t pos = name.find('.', 5);
|
|
||||||
if (pos != std::string::npos) {
|
|
||||||
try {
|
|
||||||
int idx = std::stoi(name.substr(5, pos - 5));
|
|
||||||
detected_num_block = std::max(detected_num_block, idx + 1);
|
|
||||||
} catch (...) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int detected_scale = 4; // default
|
|
||||||
if (is_ESRGAN) {
|
|
||||||
// For ESRGAN format, detect scale by highest model number
|
|
||||||
int max_model_num = 0;
|
|
||||||
for (const auto& name : tensor_names) {
|
|
||||||
if (name.find("model.") == 0) {
|
|
||||||
size_t dot_pos = name.find('.', 6);
|
|
||||||
if (dot_pos != std::string::npos) {
|
|
||||||
try {
|
|
||||||
int num = std::stoi(name.substr(6, dot_pos - 6));
|
|
||||||
max_model_num = std::max(max_model_num, num);
|
|
||||||
} catch (...) {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (max_model_num <= 4) {
|
|
||||||
detected_scale = 1;
|
|
||||||
} else if (max_model_num <= 7) {
|
|
||||||
detected_scale = 2;
|
|
||||||
} else {
|
|
||||||
detected_scale = 4;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Original format
|
|
||||||
bool has_conv_up2 = std::any_of(tensor_names.begin(), tensor_names.end(), [](const std::string& name) {
|
|
||||||
return name == "conv_up2.weight";
|
|
||||||
});
|
|
||||||
bool has_conv_up1 = std::any_of(tensor_names.begin(), tensor_names.end(), [](const std::string& name) {
|
|
||||||
return name == "conv_up1.weight";
|
|
||||||
});
|
|
||||||
if (has_conv_up2) {
|
|
||||||
detected_scale = 4;
|
|
||||||
} else if (has_conv_up1) {
|
|
||||||
detected_scale = 2;
|
|
||||||
} else {
|
|
||||||
detected_scale = 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int detected_num_in_ch = 3;
|
|
||||||
int detected_num_out_ch = 3;
|
|
||||||
int detected_num_feat = 64;
|
|
||||||
int detected_num_grow_ch = 32;
|
|
||||||
|
|
||||||
// Create RRDBNet with detected parameters
|
|
||||||
rrdb_net = std::make_unique<RRDBNet>(detected_scale, detected_num_block, detected_num_in_ch, detected_num_out_ch, detected_num_feat, detected_num_grow_ch);
|
|
||||||
rrdb_net->init(params_ctx, {}, "");
|
|
||||||
|
|
||||||
if (!alloc_params_buffer()) {
|
|
||||||
LOG_ERROR("esrgan model buffer allocation failed");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::map<std::string, ggml_tensor*> esrgan_tensors;
|
|
||||||
rrdb_net->get_param_tensors(esrgan_tensors);
|
|
||||||
|
|
||||||
bool success;
|
|
||||||
if (is_ESRGAN) {
|
|
||||||
// Build name mapping for ESRGAN format
|
|
||||||
std::map<std::string, std::string> expected_to_model;
|
|
||||||
expected_to_model["conv_first.weight"] = "model.0.weight";
|
|
||||||
expected_to_model["conv_first.bias"] = "model.0.bias";
|
|
||||||
|
|
||||||
for (int i = 0; i < detected_num_block; i++) {
|
|
||||||
for (int j = 1; j <= 3; j++) {
|
|
||||||
for (int k = 1; k <= 5; k++) {
|
|
||||||
std::string expected_weight = "body." + std::to_string(i) + ".rdb" + std::to_string(j) + ".conv" + std::to_string(k) + ".weight";
|
|
||||||
std::string model_weight = "model.1.sub." + std::to_string(i) + ".RDB" + std::to_string(j) + ".conv" + std::to_string(k) + ".0.weight";
|
|
||||||
expected_to_model[expected_weight] = model_weight;
|
|
||||||
|
|
||||||
std::string expected_bias = "body." + std::to_string(i) + ".rdb" + std::to_string(j) + ".conv" + std::to_string(k) + ".bias";
|
|
||||||
std::string model_bias = "model.1.sub." + std::to_string(i) + ".RDB" + std::to_string(j) + ".conv" + std::to_string(k) + ".0.bias";
|
|
||||||
expected_to_model[expected_bias] = model_bias;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (detected_scale == 1) {
|
|
||||||
expected_to_model["conv_body.weight"] = "model.1.sub." + std::to_string(detected_num_block) + ".weight";
|
|
||||||
expected_to_model["conv_body.bias"] = "model.1.sub." + std::to_string(detected_num_block) + ".bias";
|
|
||||||
expected_to_model["conv_hr.weight"] = "model.2.weight";
|
|
||||||
expected_to_model["conv_hr.bias"] = "model.2.bias";
|
|
||||||
expected_to_model["conv_last.weight"] = "model.4.weight";
|
|
||||||
expected_to_model["conv_last.bias"] = "model.4.bias";
|
|
||||||
} else {
|
|
||||||
expected_to_model["conv_body.weight"] = "model.1.sub." + std::to_string(detected_num_block) + ".weight";
|
|
||||||
expected_to_model["conv_body.bias"] = "model.1.sub." + std::to_string(detected_num_block) + ".bias";
|
|
||||||
if (detected_scale >= 2) {
|
|
||||||
expected_to_model["conv_up1.weight"] = "model.3.weight";
|
|
||||||
expected_to_model["conv_up1.bias"] = "model.3.bias";
|
|
||||||
}
|
|
||||||
if (detected_scale == 4) {
|
|
||||||
expected_to_model["conv_up2.weight"] = "model.6.weight";
|
|
||||||
expected_to_model["conv_up2.bias"] = "model.6.bias";
|
|
||||||
expected_to_model["conv_hr.weight"] = "model.8.weight";
|
|
||||||
expected_to_model["conv_hr.bias"] = "model.8.bias";
|
|
||||||
expected_to_model["conv_last.weight"] = "model.10.weight";
|
|
||||||
expected_to_model["conv_last.bias"] = "model.10.bias";
|
|
||||||
} else if (detected_scale == 2) {
|
|
||||||
expected_to_model["conv_hr.weight"] = "model.5.weight";
|
|
||||||
expected_to_model["conv_hr.bias"] = "model.5.bias";
|
|
||||||
expected_to_model["conv_last.weight"] = "model.7.weight";
|
|
||||||
expected_to_model["conv_last.bias"] = "model.7.bias";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::map<std::string, ggml_tensor*> model_tensors;
|
|
||||||
for (auto& p : esrgan_tensors) {
|
|
||||||
auto it = expected_to_model.find(p.first);
|
|
||||||
if (it != expected_to_model.end()) {
|
|
||||||
model_tensors[it->second] = p.second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
model_loader.set_n_threads(n_threads);
|
|
||||||
success = model_loader.load_tensors(model_tensors);
|
|
||||||
} else {
|
|
||||||
model_loader.set_n_threads(n_threads);
|
|
||||||
success = model_loader.load_tensors(esrgan_tensors);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!success) {
|
|
||||||
LOG_ERROR("load esrgan tensors from model loader failed");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
scale = rrdb_net->get_scale();
|
|
||||||
LOG_INFO("esrgan model loaded with scale=%d, num_block=%d", scale, detected_num_block);
|
|
||||||
return success;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor) {
|
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor) {
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
#ifndef __SD_MODEL_UPSCALER_LTX_LATENT_UPSCALER_HPP__
|
#ifndef __SD_MODEL_UPSCALER_LTX_LATENT_UPSCALER_HPP__
|
||||||
#define __SD_MODEL_UPSCALER_LTX_LATENT_UPSCALER_HPP__
|
#define __SD_MODEL_UPSCALER_LTX_LATENT_UPSCALER_HPP__
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cinttypes>
|
#include <cinttypes>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstdlib>
|
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <set>
|
#include <set>
|
||||||
@ -32,70 +32,69 @@ namespace LTXVUpsampler {
|
|||||||
int spatial_up_num = 2;
|
int spatial_up_num = 2;
|
||||||
int spatial_down_den = 1;
|
int spatial_down_den = 1;
|
||||||
int temporal_up_factor = 1;
|
int temporal_up_factor = 1;
|
||||||
|
|
||||||
|
static LatentUpsamplerConfig detect_from_weights(const String2TensorStorage& tensor_storage_map,
|
||||||
|
const std::string& prefix = "") {
|
||||||
|
LatentUpsamplerConfig config;
|
||||||
|
auto find_weight = [&](const std::string& suffix) -> const TensorStorage* {
|
||||||
|
std::string name = prefix.empty() ? suffix : prefix + "." + suffix;
|
||||||
|
auto iter = tensor_storage_map.find(name);
|
||||||
|
if (iter == tensor_storage_map.end()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return &iter->second;
|
||||||
};
|
};
|
||||||
|
|
||||||
static inline bool has_tensor(const String2TensorStorage& tensor_storage_map,
|
bool inferred = false;
|
||||||
const std::string& name) {
|
|
||||||
return tensor_storage_map.find(name) != tensor_storage_map.end();
|
const TensorStorage* initial_norm = find_weight("initial_norm.weight");
|
||||||
|
if (initial_norm != nullptr) {
|
||||||
|
config.mid_channels = initial_norm->ne[0];
|
||||||
|
inferred = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline int64_t get_tensor_ne(const String2TensorStorage& tensor_storage_map,
|
const TensorStorage* final_conv = find_weight("final_conv.bias");
|
||||||
const std::string& name,
|
if (final_conv != nullptr) {
|
||||||
int axis,
|
config.in_channels = final_conv->ne[0];
|
||||||
int64_t fallback) {
|
inferred = true;
|
||||||
auto it = tensor_storage_map.find(name);
|
|
||||||
if (it == tensor_storage_map.end() || axis < 0 || axis >= GGML_MAX_DIMS) {
|
|
||||||
return fallback;
|
|
||||||
}
|
|
||||||
return it->second.ne[axis];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline int64_t get_tensor_ne0(const String2TensorStorage& tensor_storage_map,
|
int detected_blocks = 0;
|
||||||
const std::string& name,
|
const std::string res_blocks_prefix = prefix.empty() ? "res_blocks." : prefix + ".res_blocks.";
|
||||||
int64_t fallback) {
|
for (const auto& [name, _] : tensor_storage_map) {
|
||||||
return get_tensor_ne(tensor_storage_map, name, 0, fallback);
|
if (!starts_with(name, res_blocks_prefix)) {
|
||||||
}
|
|
||||||
|
|
||||||
static inline int count_module_blocks(const String2TensorStorage& tensor_storage_map,
|
|
||||||
const std::string& module_name) {
|
|
||||||
int max_block = -1;
|
|
||||||
const std::string prefix = module_name + ".";
|
|
||||||
for (const auto& pair : tensor_storage_map) {
|
|
||||||
const std::string& name = pair.first;
|
|
||||||
if (name.find(prefix) != 0) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
size_t begin = prefix.size();
|
size_t begin = res_blocks_prefix.size();
|
||||||
size_t end = name.find('.', begin);
|
size_t end = name.find('.', begin);
|
||||||
if (end == std::string::npos) {
|
if (end == std::string::npos) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
int index = atoi(name.substr(begin, end - begin).c_str());
|
try {
|
||||||
max_block = std::max(max_block, index);
|
int idx = std::stoi(name.substr(begin, end - begin));
|
||||||
|
detected_blocks = std::max(detected_blocks, idx + 1);
|
||||||
|
} catch (...) {
|
||||||
}
|
}
|
||||||
return max_block + 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline LatentUpsamplerConfig detect_config_from_weights(const String2TensorStorage& tensor_storage_map) {
|
|
||||||
LatentUpsamplerConfig config;
|
|
||||||
config.mid_channels = get_tensor_ne0(tensor_storage_map, "initial_norm.weight", config.mid_channels);
|
|
||||||
config.in_channels = get_tensor_ne0(tensor_storage_map, "final_conv.bias", config.in_channels);
|
|
||||||
int detected_blocks = count_module_blocks(tensor_storage_map, "res_blocks");
|
|
||||||
if (detected_blocks > 0) {
|
if (detected_blocks > 0) {
|
||||||
config.num_blocks_per_stage = detected_blocks;
|
config.num_blocks_per_stage = detected_blocks;
|
||||||
|
inferred = true;
|
||||||
}
|
}
|
||||||
config.rational_resampler = has_tensor(tensor_storage_map, "upsampler.conv.weight");
|
|
||||||
int64_t upsampler_out_channels = get_tensor_ne0(tensor_storage_map, "upsampler.0.bias", 0);
|
const TensorStorage* rational_upsampler_weight = find_weight("upsampler.conv.weight");
|
||||||
|
const TensorStorage* upsampler_bias = find_weight("upsampler.0.bias");
|
||||||
|
config.rational_resampler = rational_upsampler_weight != nullptr;
|
||||||
|
int64_t upsampler_out_channels = upsampler_bias == nullptr ? 0 : upsampler_bias->ne[0];
|
||||||
config.spatial_upsample = config.rational_resampler || upsampler_out_channels == 4 * config.mid_channels;
|
config.spatial_upsample = config.rational_resampler || upsampler_out_channels == 4 * config.mid_channels;
|
||||||
config.temporal_upsample = upsampler_out_channels == 2 * config.mid_channels;
|
config.temporal_upsample = upsampler_out_channels == 2 * config.mid_channels;
|
||||||
|
if (config.rational_resampler || upsampler_out_channels > 0) {
|
||||||
|
inferred = true;
|
||||||
|
}
|
||||||
if (config.temporal_upsample) {
|
if (config.temporal_upsample) {
|
||||||
config.temporal_up_factor = 2;
|
config.temporal_up_factor = 2;
|
||||||
}
|
}
|
||||||
if (config.rational_resampler) {
|
if (rational_upsampler_weight != nullptr) {
|
||||||
int64_t out_channels = get_tensor_ne(tensor_storage_map,
|
int64_t out_channels = rational_upsampler_weight->ne[3];
|
||||||
"upsampler.conv.weight",
|
|
||||||
3,
|
|
||||||
config.mid_channels * 9);
|
|
||||||
if (config.mid_channels > 0 && out_channels % config.mid_channels == 0) {
|
if (config.mid_channels > 0 && out_channels % config.mid_channels == 0) {
|
||||||
int64_t ratio = out_channels / config.mid_channels;
|
int64_t ratio = out_channels / config.mid_channels;
|
||||||
int num = static_cast<int>(std::round(std::sqrt(static_cast<double>(ratio))));
|
int num = static_cast<int>(std::round(std::sqrt(static_cast<double>(ratio))));
|
||||||
@ -114,8 +113,19 @@ namespace LTXVUpsampler {
|
|||||||
config.spatial_scale = static_cast<float>(config.spatial_up_num);
|
config.spatial_scale = static_cast<float>(config.spatial_up_num);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (inferred) {
|
||||||
|
LOG_DEBUG("ltx latent upsampler: in_channels = %" PRId64 ", mid_channels = %" PRId64 ", num_blocks_per_stage = %d, spatial_scale = %.3f, temporal_up_factor = %d, rational_resampler = %d",
|
||||||
|
config.in_channels,
|
||||||
|
config.mid_channels,
|
||||||
|
config.num_blocks_per_stage,
|
||||||
|
config.spatial_scale,
|
||||||
|
config.temporal_up_factor,
|
||||||
|
config.rational_resampler);
|
||||||
|
}
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class VideoGroupNorm : public GGMLBlock {
|
class VideoGroupNorm : public GGMLBlock {
|
||||||
protected:
|
protected:
|
||||||
@ -419,34 +429,14 @@ namespace LTXVUpsampler {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct LatentUpsamplerRunner : public GGMLRunner {
|
struct LatentUpsamplerRunner : public GGMLRunner {
|
||||||
|
LatentUpsamplerConfig config;
|
||||||
std::unique_ptr<LatentUpsampler> model;
|
std::unique_ptr<LatentUpsampler> model;
|
||||||
|
|
||||||
LatentUpsamplerRunner(ggml_backend_t backend,
|
LatentUpsamplerRunner(ggml_backend_t backend,
|
||||||
ggml_backend_t params_backend)
|
ggml_backend_t params_backend,
|
||||||
: GGMLRunner(backend, params_backend) {}
|
const String2TensorStorage& tensor_storage_map)
|
||||||
|
: GGMLRunner(backend, params_backend),
|
||||||
std::string get_desc() override {
|
config(LatentUpsamplerConfig::detect_from_weights(tensor_storage_map)) {
|
||||||
return "ltx_latent_upsampler";
|
|
||||||
}
|
|
||||||
|
|
||||||
bool load_from_file(const std::string& file_path, int n_threads) {
|
|
||||||
LOG_INFO("loading LTX latent upsampler from '%s'", file_path.c_str());
|
|
||||||
ModelLoader model_loader;
|
|
||||||
if (!model_loader.init_from_file(file_path)) {
|
|
||||||
LOG_ERROR("init LTX latent upsampler model loader from file failed: '%s'", file_path.c_str());
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto& tensor_storage_map = model_loader.get_tensor_storage_map();
|
|
||||||
bool has_regular_upsampler = has_tensor(tensor_storage_map, "upsampler.0.weight");
|
|
||||||
bool has_rational_spatial = has_tensor(tensor_storage_map, "upsampler.conv.weight");
|
|
||||||
if (!has_tensor(tensor_storage_map, "post_upsample_res_blocks.0.conv2.bias") ||
|
|
||||||
(!has_regular_upsampler && !has_rational_spatial)) {
|
|
||||||
LOG_ERROR("unsupported LTX latent upsampler weights: expected upsampler tensors");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
LatentUpsamplerConfig config = detect_config_from_weights(tensor_storage_map);
|
|
||||||
if (config.dims != 3 || (!config.spatial_upsample && !config.temporal_upsample) ||
|
if (config.dims != 3 || (!config.spatial_upsample && !config.temporal_upsample) ||
|
||||||
config.spatial_up_num < 1 || config.spatial_down_den < 1 || config.temporal_up_factor < 1) {
|
config.spatial_up_num < 1 || config.spatial_down_den < 1 || config.temporal_up_factor < 1) {
|
||||||
LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d scale=%.3f temporal_factor=%d",
|
LOG_ERROR("unsupported LTX latent upsampler config: dims=%d spatial=%d temporal=%d rational=%d scale=%.3f temporal_factor=%d",
|
||||||
@ -456,35 +446,21 @@ namespace LTXVUpsampler {
|
|||||||
config.rational_resampler,
|
config.rational_resampler,
|
||||||
config.spatial_scale,
|
config.spatial_scale,
|
||||||
config.temporal_up_factor);
|
config.temporal_up_factor);
|
||||||
return false;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
model = std::make_unique<LatentUpsampler>(config);
|
model = std::make_unique<LatentUpsampler>(config);
|
||||||
model->init(params_ctx, tensor_storage_map, "");
|
model->init(params_ctx, tensor_storage_map, "");
|
||||||
if (!alloc_params_buffer()) {
|
|
||||||
LOG_ERROR("LTX latent upsampler params buffer allocation failed");
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::map<std::string, ggml_tensor*> tensors;
|
std::string get_desc() override {
|
||||||
|
return "ltx_latent_upsampler";
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) {
|
||||||
|
if (model) {
|
||||||
model->get_param_tensors(tensors);
|
model->get_param_tensors(tensors);
|
||||||
std::set<std::string> ignore_tensors;
|
|
||||||
if (config.rational_resampler) {
|
|
||||||
ignore_tensors.insert("upsampler.blur_down.kernel");
|
|
||||||
}
|
}
|
||||||
model_loader.set_n_threads(n_threads);
|
|
||||||
if (!model_loader.load_tensors(tensors, ignore_tensors)) {
|
|
||||||
LOG_ERROR("load LTX latent upsampler tensors failed");
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
LOG_INFO("LTX latent upsampler loaded: in_channels=%" PRId64 ", mid_channels=%" PRId64 ", blocks=%d, scale=%.3f, temporal_factor=%d, rational=%d",
|
|
||||||
config.in_channels,
|
|
||||||
config.mid_channels,
|
|
||||||
config.num_blocks_per_stage,
|
|
||||||
config.spatial_scale,
|
|
||||||
config.temporal_up_factor,
|
|
||||||
config.rational_resampler);
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor) {
|
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor) {
|
||||||
@ -515,9 +491,9 @@ namespace LTXVUpsampler {
|
|||||||
(long long)x.shape()[4]);
|
(long long)x.shape()[4]);
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
if (x.shape()[3] != model->config.in_channels) {
|
if (x.shape()[3] != config.in_channels) {
|
||||||
LOG_ERROR("LTX latent upsampler expected %" PRId64 " channels, got %lld",
|
LOG_ERROR("LTX latent upsampler expected %" PRId64 " channels, got %lld",
|
||||||
model->config.in_channels,
|
config.in_channels,
|
||||||
(long long)x.shape()[3]);
|
(long long)x.shape()[3]);
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|||||||
@ -990,7 +990,46 @@ bool is_first_stage_model_name(const std::string& name) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static std::string convert_esrgan_tensor_name(std::string name) {
|
||||||
|
static std::unordered_map<std::string, std::string> esrgan_name_map;
|
||||||
|
|
||||||
|
if (esrgan_name_map.empty()) {
|
||||||
|
esrgan_name_map["model.0."] = "conv_first.";
|
||||||
|
|
||||||
|
constexpr int max_num_blocks = 64;
|
||||||
|
for (int i = 0; i < max_num_blocks; i++) {
|
||||||
|
std::string block_prefix = "model.1.sub." + std::to_string(i) + ".";
|
||||||
|
for (int rdb = 1; rdb <= 3; rdb++) {
|
||||||
|
for (int conv = 1; conv <= 5; conv++) {
|
||||||
|
esrgan_name_map[block_prefix + "RDB" + std::to_string(rdb) + ".conv" + std::to_string(conv) + ".0."] =
|
||||||
|
"body." + std::to_string(i) + ".rdb" + std::to_string(rdb) + ".conv" + std::to_string(conv) + ".";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
esrgan_name_map[block_prefix + "weight"] = "conv_body.weight";
|
||||||
|
esrgan_name_map[block_prefix + "bias"] = "conv_body.bias";
|
||||||
|
}
|
||||||
|
|
||||||
|
// RealESRGAN stores only the learned layers in a Sequential. These indices
|
||||||
|
// cover the common x1, x2 and x4 layouts.
|
||||||
|
esrgan_name_map["model.2."] = "conv_hr.";
|
||||||
|
esrgan_name_map["model.3."] = "conv_up1.";
|
||||||
|
esrgan_name_map["model.4."] = "conv_last.";
|
||||||
|
esrgan_name_map["model.5."] = "conv_hr.";
|
||||||
|
esrgan_name_map["model.6."] = "conv_up2.";
|
||||||
|
esrgan_name_map["model.7."] = "conv_last.";
|
||||||
|
esrgan_name_map["model.8."] = "conv_hr.";
|
||||||
|
esrgan_name_map["model.10."] = "conv_last.";
|
||||||
|
}
|
||||||
|
|
||||||
|
replace_with_prefix_map(name, esrgan_name_map);
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
std::string convert_tensor_name(std::string name, SDVersion version) {
|
std::string convert_tensor_name(std::string name, SDVersion version) {
|
||||||
|
if (version == VERSION_ESRGAN) {
|
||||||
|
return convert_esrgan_tensor_name(std::move(name));
|
||||||
|
}
|
||||||
|
|
||||||
bool is_lora = false;
|
bool is_lora = false;
|
||||||
bool is_lycoris_underline = false;
|
bool is_lycoris_underline = false;
|
||||||
bool is_underline = false;
|
bool is_underline = false;
|
||||||
|
|||||||
@ -90,6 +90,7 @@ const char* model_version_to_str[] = {
|
|||||||
"Longcat-Image",
|
"Longcat-Image",
|
||||||
"PiD",
|
"PiD",
|
||||||
"Ideogram 4",
|
"Ideogram 4",
|
||||||
|
"ESRGAN",
|
||||||
};
|
};
|
||||||
|
|
||||||
const char* sampling_methods_str[] = {
|
const char* sampling_methods_str[] = {
|
||||||
@ -4996,17 +4997,41 @@ static sd::Tensor<float> upscale_ltx_spatial_video_latent(sd_ctx_t* sd_ctx,
|
|||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto upsampler_manager = std::make_shared<ModelManager>();
|
||||||
|
upsampler_manager->set_n_threads(sd_ctx->sd->n_threads);
|
||||||
|
upsampler_manager->set_enable_mmap(sd_ctx->sd->enable_mmap);
|
||||||
|
ModelLoader& model_loader = upsampler_manager->loader();
|
||||||
|
if (!model_loader.init_from_file(model_path)) {
|
||||||
|
LOG_ERROR("init LTX latent upsampler model loader from file failed: '%s'", model_path);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
std::unique_ptr<LTXVUpsampler::LatentUpsamplerRunner> upsampler =
|
std::unique_ptr<LTXVUpsampler::LatentUpsamplerRunner> upsampler =
|
||||||
std::make_unique<LTXVUpsampler::LatentUpsamplerRunner>(sd_ctx->sd->backend_for(SDBackendModule::UPSCALER),
|
std::make_unique<LTXVUpsampler::LatentUpsamplerRunner>(sd_ctx->sd->backend_for(SDBackendModule::UPSCALER),
|
||||||
sd_ctx->sd->backend_for(SDBackendModule::UPSCALER));
|
sd_ctx->sd->params_backend_for(SDBackendModule::UPSCALER),
|
||||||
|
model_loader.get_tensor_storage_map());
|
||||||
const size_t max_graph_vram_bytes = sd::ggml_graph_cut::max_vram_gib_to_bytes(sd_ctx->sd->max_vram);
|
const size_t max_graph_vram_bytes = sd::ggml_graph_cut::max_vram_gib_to_bytes(sd_ctx->sd->max_vram);
|
||||||
upsampler->set_max_graph_vram_bytes(max_graph_vram_bytes);
|
upsampler->set_max_graph_vram_bytes(max_graph_vram_bytes);
|
||||||
if (!upsampler->load_from_file(model_path, sd_ctx->sd->n_threads)) {
|
if (upsampler->model == nullptr) {
|
||||||
LOG_ERROR("load LTX latent upsampler failed");
|
LOG_ERROR("init LTX latent upsampler from metadata failed");
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::map<std::string, ggml_tensor*> tensors;
|
||||||
|
upsampler->get_param_tensors(tensors);
|
||||||
|
upsampler->set_weight_manager(upsampler_manager);
|
||||||
|
if (!upsampler_manager->register_param_tensors("LTX latent upsampler",
|
||||||
|
std::move(tensors),
|
||||||
|
ModelManager::ResidencyMode::Resident,
|
||||||
|
sd_ctx->sd->backend_for(SDBackendModule::UPSCALER),
|
||||||
|
sd_ctx->sd->params_backend_for(SDBackendModule::UPSCALER)) ||
|
||||||
|
!upsampler_manager->validate_registered_tensors()) {
|
||||||
|
LOG_ERROR("register LTX latent upsampler tensors with model manager failed");
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
sd::Tensor<float> upscaled = upsampler->compute(sd_ctx->sd->n_threads, unnormalized);
|
sd::Tensor<float> upscaled = upsampler->compute(sd_ctx->sd->n_threads, unnormalized);
|
||||||
|
upsampler_manager.reset();
|
||||||
upsampler.reset();
|
upsampler.reset();
|
||||||
if (upscaled.empty()) {
|
if (upscaled.empty()) {
|
||||||
LOG_ERROR("LTX latent spatial upscale failed");
|
LOG_ERROR("LTX latent spatial upscale failed");
|
||||||
|
|||||||
@ -18,6 +18,12 @@ UpscalerGGML::UpscalerGGML(int n_threads,
|
|||||||
params_backend_spec(std::move(params_backend_spec)) {
|
params_backend_spec(std::move(params_backend_spec)) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
UpscalerGGML::~UpscalerGGML() {
|
||||||
|
// ModelManager holds raw ggml tensor pointers owned by the runner context.
|
||||||
|
model_manager.reset();
|
||||||
|
esrgan_upscaler.reset();
|
||||||
|
}
|
||||||
|
|
||||||
void UpscalerGGML::set_max_graph_vram_bytes(size_t max_vram_bytes) {
|
void UpscalerGGML::set_max_graph_vram_bytes(size_t max_vram_bytes) {
|
||||||
max_graph_vram_bytes = max_vram_bytes;
|
max_graph_vram_bytes = max_vram_bytes;
|
||||||
if (esrgan_upscaler) {
|
if (esrgan_upscaler) {
|
||||||
@ -72,22 +78,40 @@ bool UpscalerGGML::load_from_file(const std::string& esrgan_path,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
ModelLoader model_loader;
|
model_manager = std::make_shared<ModelManager>();
|
||||||
if (!model_loader.init_from_file_and_convert_name(esrgan_path)) {
|
model_manager->set_n_threads(n_threads);
|
||||||
|
model_manager->set_enable_mmap(false);
|
||||||
|
|
||||||
|
ModelLoader& model_loader = model_manager->loader();
|
||||||
|
if (!model_loader.init_from_file_and_convert_name(esrgan_path, "", VERSION_ESRGAN)) {
|
||||||
LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str());
|
LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str());
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
model_loader.set_wtype_override(model_data_type);
|
model_loader.set_wtype_override(model_data_type);
|
||||||
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
|
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
|
||||||
esrgan_upscaler = std::make_shared<ESRGAN>(backend_for(SDBackendModule::UPSCALER),
|
esrgan_upscaler = std::make_shared<ESRGAN>(backend_for(SDBackendModule::UPSCALER),
|
||||||
params_backend_for(SDBackendModule::UPSCALER),
|
params_backend_for(SDBackendModule::UPSCALER),
|
||||||
tile_size,
|
|
||||||
model_loader.get_tensor_storage_map());
|
model_loader.get_tensor_storage_map());
|
||||||
|
if (esrgan_upscaler == nullptr || esrgan_upscaler->rrdb_net == nullptr) {
|
||||||
|
LOG_ERROR("init esrgan model from metadata failed: '%s'", esrgan_path.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
esrgan_upscaler->set_max_graph_vram_bytes(max_graph_vram_bytes);
|
esrgan_upscaler->set_max_graph_vram_bytes(max_graph_vram_bytes);
|
||||||
esrgan_upscaler->set_stream_layers_enabled(stream_layers_enabled);
|
esrgan_upscaler->set_stream_layers_enabled(stream_layers_enabled);
|
||||||
if (direct) {
|
if (direct) {
|
||||||
esrgan_upscaler->set_conv2d_direct_enabled(true);
|
esrgan_upscaler->set_conv2d_direct_enabled(true);
|
||||||
}
|
}
|
||||||
if (!esrgan_upscaler->load_from_file(esrgan_path, n_threads)) {
|
|
||||||
|
std::map<std::string, ggml_tensor*> tensors;
|
||||||
|
esrgan_upscaler->get_param_tensors(tensors);
|
||||||
|
esrgan_upscaler->set_weight_manager(model_manager);
|
||||||
|
if (!model_manager->register_param_tensors("ESRGAN",
|
||||||
|
std::move(tensors),
|
||||||
|
ModelManager::ResidencyMode::Resident,
|
||||||
|
backend_for(SDBackendModule::UPSCALER),
|
||||||
|
params_backend_for(SDBackendModule::UPSCALER)) ||
|
||||||
|
!model_manager->validate_registered_tensors()) {
|
||||||
|
LOG_ERROR("register esrgan tensors with model manager failed");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
@ -95,6 +119,7 @@ bool UpscalerGGML::load_from_file(const std::string& esrgan_path,
|
|||||||
|
|
||||||
sd::Tensor<float> UpscalerGGML::upscale_tensor(const sd::Tensor<float>& input_tensor) {
|
sd::Tensor<float> UpscalerGGML::upscale_tensor(const sd::Tensor<float>& input_tensor) {
|
||||||
sd::Tensor<float> upscaled;
|
sd::Tensor<float> upscaled;
|
||||||
|
const int scale = esrgan_upscaler->config.scale;
|
||||||
if (tile_size <= 0 || (input_tensor.shape()[0] <= tile_size && input_tensor.shape()[1] <= tile_size)) {
|
if (tile_size <= 0 || (input_tensor.shape()[0] <= tile_size && input_tensor.shape()[1] <= tile_size)) {
|
||||||
upscaled = esrgan_upscaler->compute(n_threads, input_tensor);
|
upscaled = esrgan_upscaler->compute(n_threads, input_tensor);
|
||||||
} else {
|
} else {
|
||||||
@ -108,9 +133,9 @@ sd::Tensor<float> UpscalerGGML::upscale_tensor(const sd::Tensor<float>& input_te
|
|||||||
};
|
};
|
||||||
|
|
||||||
upscaled = process_tiles_2d(input_tensor,
|
upscaled = process_tiles_2d(input_tensor,
|
||||||
static_cast<int>(input_tensor.shape()[0] * esrgan_upscaler->scale),
|
static_cast<int>(input_tensor.shape()[0] * scale),
|
||||||
static_cast<int>(input_tensor.shape()[1] * esrgan_upscaler->scale),
|
static_cast<int>(input_tensor.shape()[1] * scale),
|
||||||
esrgan_upscaler->scale,
|
scale,
|
||||||
tile_size,
|
tile_size,
|
||||||
tile_size,
|
tile_size,
|
||||||
0.25f,
|
0.25f,
|
||||||
@ -129,8 +154,9 @@ sd::Tensor<float> UpscalerGGML::upscale_tensor(const sd::Tensor<float>& input_te
|
|||||||
sd_image_t UpscalerGGML::upscale(sd_image_t input_image, uint32_t upscale_factor) {
|
sd_image_t UpscalerGGML::upscale(sd_image_t input_image, uint32_t upscale_factor) {
|
||||||
// upscale_factor, unused for RealESRGAN_x4plus_anime_6B.pth
|
// upscale_factor, unused for RealESRGAN_x4plus_anime_6B.pth
|
||||||
sd_image_t upscaled_image = {0, 0, 0, nullptr};
|
sd_image_t upscaled_image = {0, 0, 0, nullptr};
|
||||||
int output_width = (int)input_image.width * esrgan_upscaler->scale;
|
const int scale = esrgan_upscaler->config.scale;
|
||||||
int output_height = (int)input_image.height * esrgan_upscaler->scale;
|
int output_width = (int)input_image.width * scale;
|
||||||
|
int output_height = (int)input_image.height * scale;
|
||||||
LOG_INFO("upscaling from (%i x %i) to (%i x %i)",
|
LOG_INFO("upscaling from (%i x %i) to (%i x %i)",
|
||||||
input_image.width, input_image.height, output_width, output_height);
|
input_image.width, input_image.height, output_width, output_height);
|
||||||
|
|
||||||
@ -187,7 +213,7 @@ int get_upscale_factor(upscaler_ctx_t* upscaler_ctx) {
|
|||||||
if (upscaler_ctx == nullptr || upscaler_ctx->upscaler == nullptr || upscaler_ctx->upscaler->esrgan_upscaler == nullptr) {
|
if (upscaler_ctx == nullptr || upscaler_ctx->upscaler == nullptr || upscaler_ctx->upscaler->esrgan_upscaler == nullptr) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
return upscaler_ctx->upscaler->esrgan_upscaler->scale;
|
return upscaler_ctx->upscaler->esrgan_upscaler->config.scale;
|
||||||
}
|
}
|
||||||
|
|
||||||
void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx) {
|
void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx) {
|
||||||
|
|||||||
@ -4,6 +4,7 @@
|
|||||||
#include "core/ggml_extend_backend.h"
|
#include "core/ggml_extend_backend.h"
|
||||||
#include "core/tensor.hpp"
|
#include "core/tensor.hpp"
|
||||||
#include "model/upscaler/esrgan.hpp"
|
#include "model/upscaler/esrgan.hpp"
|
||||||
|
#include "model_manager.h"
|
||||||
#include "stable-diffusion.h"
|
#include "stable-diffusion.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
@ -11,6 +12,7 @@
|
|||||||
|
|
||||||
struct UpscalerGGML {
|
struct UpscalerGGML {
|
||||||
SDBackendManager backend_manager;
|
SDBackendManager backend_manager;
|
||||||
|
std::shared_ptr<ModelManager> model_manager;
|
||||||
ggml_type model_data_type = GGML_TYPE_F16;
|
ggml_type model_data_type = GGML_TYPE_F16;
|
||||||
std::shared_ptr<ESRGAN> esrgan_upscaler;
|
std::shared_ptr<ESRGAN> esrgan_upscaler;
|
||||||
std::string esrgan_path;
|
std::string esrgan_path;
|
||||||
@ -27,6 +29,7 @@ struct UpscalerGGML {
|
|||||||
int tile_size = 128,
|
int tile_size = 128,
|
||||||
std::string backend_spec = "",
|
std::string backend_spec = "",
|
||||||
std::string params_backend_spec = "");
|
std::string params_backend_spec = "");
|
||||||
|
~UpscalerGGML();
|
||||||
|
|
||||||
bool load_from_file(const std::string& esrgan_path,
|
bool load_from_file(const std::string& esrgan_path,
|
||||||
bool offload_params_to_cpu,
|
bool offload_params_to_cpu,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user