Mirror of https://github.com/leejet/stable-diffusion.cpp.git
feat: add support for more esrgan models & x2 & x1 models (#855)
This commit is contained in:
parent 02af48a97f
commit e70d0205ca
esrgan.hpp (218 changes)
@@ -83,39 +83,44 @@ public:

 class RRDBNet : public GGMLBlock {
 protected:
-    int scale = 4; // default RealESRGAN_x4plus_anime_6B
-    int num_block = 6; // default RealESRGAN_x4plus_anime_6B
+    int scale = 4;
+    int num_block = 23;
     int num_in_ch = 3;
     int num_out_ch = 3;
-    int num_feat = 64; // default RealESRGAN_x4plus_anime_6B
-    int num_grow_ch = 32; // default RealESRGAN_x4plus_anime_6B
+    int num_feat = 64;
+    int num_grow_ch = 32;

 public:
-    RRDBNet() {
+    RRDBNet(int scale, int num_block, int num_in_ch, int num_out_ch, int num_feat, int num_grow_ch)
+        : scale(scale), num_block(num_block), num_in_ch(num_in_ch), num_out_ch(num_out_ch), num_feat(num_feat), num_grow_ch(num_grow_ch) {
         blocks["conv_first"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_in_ch, num_feat, {3, 3}, {1, 1}, {1, 1}));
         for (int i = 0; i < num_block; i++) {
             std::string name = "body." + std::to_string(i);
             blocks[name] = std::shared_ptr<GGMLBlock>(new RRDB(num_feat, num_grow_ch));
         }
         blocks["conv_body"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
         // upsample
-        blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
-        blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
+        if (scale >= 2) {
+            blocks["conv_up1"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
+        }
+        if (scale == 4) {
+            blocks["conv_up2"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
+        }
         blocks["conv_hr"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_feat, {3, 3}, {1, 1}, {1, 1}));
         blocks["conv_last"] = std::shared_ptr<GGMLBlock>(new Conv2d(num_feat, num_out_ch, {3, 3}, {1, 1}, {1, 1}));
     }

+    int get_scale() { return scale; }
+    int get_num_block() { return num_block; }

     struct ggml_tensor* lrelu(struct ggml_context* ctx, struct ggml_tensor* x) {
         return ggml_leaky_relu(ctx, x, 0.2f, true);
     }

     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
         // x: [n, num_in_ch, h, w]
-        // return: [n, num_out_ch, h*4, w*4]
+        // return: [n, num_out_ch, h*scale, w*scale]
         auto conv_first = std::dynamic_pointer_cast<Conv2d>(blocks["conv_first"]);
         auto conv_body = std::dynamic_pointer_cast<Conv2d>(blocks["conv_body"]);
-        auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
-        auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
         auto conv_hr = std::dynamic_pointer_cast<Conv2d>(blocks["conv_hr"]);
         auto conv_last = std::dynamic_pointer_cast<Conv2d>(blocks["conv_last"]);

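Note: the parameterized constructor above is what enables x1 and x2 models, since only the upsample convolutions that the requested scale actually needs are registered. A minimal sketch of constructing an x2 network with the new signature (the argument values here are illustrative, not taken from a real checkpoint):

// Hypothetical construction of an x2 RRDBNet via the constructor shown above:
// scale=2, num_block=23, num_in_ch=3, num_out_ch=3, num_feat=64, num_grow_ch=32.
RRDBNet net_x2(2, 23, 3, 3, 64, 32);
// With scale == 2 only "conv_up1" is registered; "conv_up2" is absent, so
// forward() performs a single nearest-neighbour upscale instead of two.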
@@ -130,15 +135,22 @@ public:
         body_feat = conv_body->forward(ctx, body_feat);
         feat = ggml_add(ctx, feat, body_feat);
         // upsample
-        feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
-        feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
+        if (scale >= 2) {
+            auto conv_up1 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up1"]);
+            feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
+            if (scale == 4) {
+                auto conv_up2 = std::dynamic_pointer_cast<Conv2d>(blocks["conv_up2"]);
+                feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST)));
+            }
+        }
+        // for all scales
         auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat)));
         return out;
     }
 };

 struct ESRGAN : public GGMLRunner {
-    RRDBNet rrdb_net;
+    std::unique_ptr<RRDBNet> rrdb_net;
     int scale = 4;
     int tile_size = 128; // avoid cuda OOM for 4gb VRAM

@@ -146,12 +158,14 @@ struct ESRGAN : public GGMLRunner {
            bool offload_params_to_cpu,
            const String2GGMLType& tensor_types = {})
         : GGMLRunner(backend, offload_params_to_cpu) {
-        rrdb_net.init(params_ctx, tensor_types, "");
+        // rrdb_net will be created in load_from_file
     }

     void enable_conv2d_direct() {
+        if (!rrdb_net)
+            return;
         std::vector<GGMLBlock*> blocks;
-        rrdb_net.get_all_blocks(blocks);
+        rrdb_net->get_all_blocks(blocks);
         for (auto block : blocks) {
             if (block->get_desc() == "Conv2d") {
                 auto conv_block = (Conv2d*)block;
@@ -167,31 +181,185 @@ struct ESRGAN : public GGMLRunner {
     bool load_from_file(const std::string& file_path, int n_threads) {
         LOG_INFO("loading esrgan from '%s'", file_path.c_str());

-        alloc_params_buffer();
-        std::map<std::string, ggml_tensor*> esrgan_tensors;
-        rrdb_net.get_param_tensors(esrgan_tensors);

         ModelLoader model_loader;
         if (!model_loader.init_from_file(file_path)) {
             LOG_ERROR("init esrgan model loader from file failed: '%s'", file_path.c_str());
             return false;
         }

-        bool success = model_loader.load_tensors(esrgan_tensors, {}, n_threads);
+        // Get tensor names
+        auto tensor_names = model_loader.get_tensor_names();
+
+        // Detect if it's ESRGAN format
+        bool is_ESRGAN = std::find(tensor_names.begin(), tensor_names.end(), "model.0.weight") != tensor_names.end();
+
+        // Detect parameters from tensor names
+        int detected_num_block = 0;
+        if (is_ESRGAN) {
+            for (const auto& name : tensor_names) {
+                if (name.find("model.1.sub.") == 0) {
+                    size_t first_dot = name.find('.', 12);
+                    if (first_dot != std::string::npos) {
+                        size_t second_dot = name.find('.', first_dot + 1);
+                        if (second_dot != std::string::npos && name.substr(first_dot + 1, 3) == "RDB") {
+                            try {
+                                int idx = std::stoi(name.substr(12, first_dot - 12));
+                                detected_num_block = std::max(detected_num_block, idx + 1);
+                            } catch (...) {
+                            }
+                        }
+                    }
+                }
+            }
+        } else {
+            // Original format
+            for (const auto& name : tensor_names) {
+                if (name.find("body.") == 0) {
+                    size_t pos = name.find('.', 5);
+                    if (pos != std::string::npos) {
+                        try {
+                            int idx = std::stoi(name.substr(5, pos - 5));
+                            detected_num_block = std::max(detected_num_block, idx + 1);
+                        } catch (...) {
+                        }
+                    }
+                }
+            }
+        }
+
+        int detected_scale = 4; // default
+        if (is_ESRGAN) {
+            // For ESRGAN format, detect scale by highest model number
+            int max_model_num = 0;
+            for (const auto& name : tensor_names) {
+                if (name.find("model.") == 0) {
+                    size_t dot_pos = name.find('.', 6);
+                    if (dot_pos != std::string::npos) {
+                        try {
+                            int num = std::stoi(name.substr(6, dot_pos - 6));
+                            max_model_num = std::max(max_model_num, num);
+                        } catch (...) {
+                        }
+                    }
+                }
+            }
+            if (max_model_num <= 4) {
+                detected_scale = 1;
+            } else if (max_model_num <= 7) {
+                detected_scale = 2;
+            } else {
+                detected_scale = 4;
+            }
+        } else {
+            // Original format
+            bool has_conv_up2 = std::any_of(tensor_names.begin(), tensor_names.end(), [](const std::string& name) {
+                return name == "conv_up2.weight";
+            });
+            bool has_conv_up1 = std::any_of(tensor_names.begin(), tensor_names.end(), [](const std::string& name) {
+                return name == "conv_up1.weight";
+            });
+            if (has_conv_up2) {
+                detected_scale = 4;
+            } else if (has_conv_up1) {
+                detected_scale = 2;
+            } else {
+                detected_scale = 1;
+            }
+        }
+
+        int detected_num_in_ch = 3;
+        int detected_num_out_ch = 3;
+        int detected_num_feat = 64;
+        int detected_num_grow_ch = 32;
+
+        // Create RRDBNet with detected parameters
+        rrdb_net = std::make_unique<RRDBNet>(detected_scale, detected_num_block, detected_num_in_ch, detected_num_out_ch, detected_num_feat, detected_num_grow_ch);
+        rrdb_net->init(params_ctx, {}, "");
+
+        alloc_params_buffer();
+        std::map<std::string, ggml_tensor*> esrgan_tensors;
+        rrdb_net->get_param_tensors(esrgan_tensors);
+
+        bool success;
+        if (is_ESRGAN) {
+            // Build name mapping for ESRGAN format
+            std::map<std::string, std::string> expected_to_model;
+            expected_to_model["conv_first.weight"] = "model.0.weight";
+            expected_to_model["conv_first.bias"] = "model.0.bias";
+
+            for (int i = 0; i < detected_num_block; i++) {
+                for (int j = 1; j <= 3; j++) {
+                    for (int k = 1; k <= 5; k++) {
+                        std::string expected_weight = "body." + std::to_string(i) + ".rdb" + std::to_string(j) + ".conv" + std::to_string(k) + ".weight";
+                        std::string model_weight = "model.1.sub." + std::to_string(i) + ".RDB" + std::to_string(j) + ".conv" + std::to_string(k) + ".0.weight";
+                        expected_to_model[expected_weight] = model_weight;

+                        std::string expected_bias = "body." + std::to_string(i) + ".rdb" + std::to_string(j) + ".conv" + std::to_string(k) + ".bias";
+                        std::string model_bias = "model.1.sub." + std::to_string(i) + ".RDB" + std::to_string(j) + ".conv" + std::to_string(k) + ".0.bias";
+                        expected_to_model[expected_bias] = model_bias;
+                    }
+                }
+            }
+
+            if (detected_scale == 1) {
+                expected_to_model["conv_body.weight"] = "model.1.sub." + std::to_string(detected_num_block) + ".weight";
+                expected_to_model["conv_body.bias"] = "model.1.sub." + std::to_string(detected_num_block) + ".bias";
+                expected_to_model["conv_hr.weight"] = "model.2.weight";
+                expected_to_model["conv_hr.bias"] = "model.2.bias";
+                expected_to_model["conv_last.weight"] = "model.4.weight";
+                expected_to_model["conv_last.bias"] = "model.4.bias";
+            } else {
+                expected_to_model["conv_body.weight"] = "model.1.sub." + std::to_string(detected_num_block) + ".weight";
+                expected_to_model["conv_body.bias"] = "model.1.sub." + std::to_string(detected_num_block) + ".bias";
+                if (detected_scale >= 2) {
+                    expected_to_model["conv_up1.weight"] = "model.3.weight";
+                    expected_to_model["conv_up1.bias"] = "model.3.bias";
+                }
+                if (detected_scale == 4) {
+                    expected_to_model["conv_up2.weight"] = "model.6.weight";
+                    expected_to_model["conv_up2.bias"] = "model.6.bias";
+                    expected_to_model["conv_hr.weight"] = "model.8.weight";
+                    expected_to_model["conv_hr.bias"] = "model.8.bias";
+                    expected_to_model["conv_last.weight"] = "model.10.weight";
+                    expected_to_model["conv_last.bias"] = "model.10.bias";
+                } else if (detected_scale == 2) {
+                    expected_to_model["conv_hr.weight"] = "model.5.weight";
+                    expected_to_model["conv_hr.bias"] = "model.5.bias";
+                    expected_to_model["conv_last.weight"] = "model.7.weight";
+                    expected_to_model["conv_last.bias"] = "model.7.bias";
+                }
+            }
+
+            std::map<std::string, ggml_tensor*> model_tensors;
+            for (auto& p : esrgan_tensors) {
+                auto it = expected_to_model.find(p.first);
+                if (it != expected_to_model.end()) {
+                    model_tensors[it->second] = p.second;
+                }
+            }
+
+            success = model_loader.load_tensors(model_tensors, {}, n_threads);
+        } else {
+            success = model_loader.load_tensors(esrgan_tensors, {}, n_threads);
+        }

         if (!success) {
             LOG_ERROR("load esrgan tensors from model loader failed");
             return false;
         }

-        LOG_INFO("esrgan model loaded");
+        scale = rrdb_net->get_scale();
+        LOG_INFO("esrgan model loaded with scale=%d, num_block=%d", scale, detected_num_block);
         return success;
     }

     struct ggml_cgraph* build_graph(struct ggml_tensor* x) {
-        struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
-        x = to_backend(x);
-        struct ggml_tensor* out = rrdb_net.forward(compute_ctx, x);
+        if (!rrdb_net)
+            return nullptr;
+        constexpr int kGraphNodes = 1 << 16; // 65k
+        struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, kGraphNodes, /*grads*/ false);
+        x = to_backend(x);
+        struct ggml_tensor* out = rrdb_net->forward(compute_ctx, x);
         ggml_build_forward_expand(gf, out);
         return gf;
     }
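The core heuristic in load_from_file() is how the upscale factor is inferred for old-style ESRGAN checkpoints: the highest index N among the "model.N.*" tensor names tells how many layers follow the RRDB trunk, and therefore how many upsample stages the checkpoint contains. A standalone sketch of that rule, assuming the same tensor-name layout as the code above (detect_scale_from_names is an illustrative helper, not part of this commit):

#include <algorithm>
#include <string>
#include <vector>

// Mirrors the detection logic above: max "model.N" index <= 4 -> x1, <= 7 -> x2, else x4.
static int detect_scale_from_names(const std::vector<std::string>& tensor_names) {
    int max_model_num = 0;
    for (const auto& name : tensor_names) {
        if (name.rfind("model.", 0) == 0) {  // name starts with "model."
            auto dot_pos = name.find('.', 6);
            if (dot_pos != std::string::npos) {
                try {
                    max_model_num = std::max(max_model_num, std::stoi(name.substr(6, dot_pos - 6)));
                } catch (...) {
                }
            }
        }
    }
    if (max_model_num <= 4)
        return 1;  // x1: no upsample convolutions after the trunk (conv_hr = model.2, conv_last = model.4)
    if (max_model_num <= 7)
        return 2;  // x2: one upsample stage (conv_hr = model.5, conv_last = model.7)
    return 4;      // x4: two upsample stages (conv_hr = model.8, conv_last = model.10)
}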
model.h (8 changes)
@@ -258,6 +258,14 @@ public:
                       std::set<std::string> ignore_tensors = {},
                       int n_threads = 0);

+    std::vector<std::string> get_tensor_names() const {
+        std::vector<std::string> names;
+        for (const auto& ts : tensor_storages) {
+            names.push_back(ts.name);
+        }
+        return names;
+    }
+
     bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules);
     bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
     int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
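The new ModelLoader::get_tensor_names() is what lets a caller inspect a checkpoint's layout before deciding how to map its weights, which is how esrgan.hpp uses it above. A small usage sketch following that pattern (the file path is a placeholder):

// Sketch: probe a checkpoint before loading it (path is a placeholder).
ModelLoader model_loader;
if (model_loader.init_from_file("some_esrgan_model.pth")) {
    auto tensor_names = model_loader.get_tensor_names();
    // Old-style ESRGAN checkpoints expose "model.N.*" names; the RealESRGAN-style
    // layout uses "conv_first", "body.N.rdbM.convK", "conv_up1", and so on.
    bool is_old_esrgan = std::find(tensor_names.begin(), tensor_names.end(),
                                   "model.0.weight") != tensor_names.end();
    (void)is_old_esrgan;  // choose a tensor-name mapping based on this
}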
@@ -283,6 +283,8 @@ SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx,
                           sd_image_t input_image,
                           uint32_t upscale_factor);

+SD_API int get_upscale_factor(upscaler_ctx_t* upscaler_ctx);
+
 SD_API bool convert(const char* input_path,
                     const char* vae_path,
                     const char* output_path,
@@ -138,6 +138,13 @@ sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_
     return upscaler_ctx->upscaler->upscale(input_image, upscale_factor);
 }

+int get_upscale_factor(upscaler_ctx_t* upscaler_ctx) {
+    if (upscaler_ctx == NULL || upscaler_ctx->upscaler == NULL || upscaler_ctx->upscaler->esrgan_upscaler == NULL) {
+        return 1;
+    }
+    return upscaler_ctx->upscaler->esrgan_upscaler->scale;
+}
+
 void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx) {
     if (upscaler_ctx->upscaler != NULL) {
         delete upscaler_ctx->upscaler;
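With get_upscale_factor() exposed through the public API, a frontend no longer has to assume x4 when sizing its output. A hedged usage sketch, assuming an already-created upscaler_ctx_t* and the sd_image_t struct from stable-diffusion.h with width/height fields:

// Sketch: let the loaded model decide the output size (upscaler_ctx and input_image assumed to exist).
int factor = get_upscale_factor(upscaler_ctx);  // 1, 2 or 4 depending on the loaded model; 1 if no model is loaded
sd_image_t upscaled = upscale(upscaler_ctx, input_image, (uint32_t)factor);
// For an x2 model, upscaled.width is expected to be input_image.width * 2;
// an x1 model keeps the input resolution.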