mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
Compare commits
No commits in common. "5900ef6605c6fbf7934239f795c13c97bc993853" and "6167e2927a0a60ec020bfd05307e2f050e76371c" have entirely different histories.
5900ef6605
...
6167e2927a
@ -119,10 +119,8 @@ endif()
|
|||||||
|
|
||||||
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
|
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
|
||||||
|
|
||||||
if (NOT SD_USE_SYSTEM_GGML)
|
# see https://github.com/ggerganov/ggml/pull/682
|
||||||
# see https://github.com/ggerganov/ggml/pull/682
|
add_definitions(-DGGML_MAX_NAME=128)
|
||||||
add_definitions(-DGGML_MAX_NAME=128)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# deps
|
# deps
|
||||||
# Only add ggml if it hasn't been added yet
|
# Only add ggml if it hasn't been added yet
|
||||||
|
|||||||
@ -341,10 +341,6 @@ arguments:
|
|||||||
--diffusion-fa use flash attention in the diffusion model (for low vram)
|
--diffusion-fa use flash attention in the diffusion model (for low vram)
|
||||||
Might lower quality, since it implies converting k and v to f16.
|
Might lower quality, since it implies converting k and v to f16.
|
||||||
This might crash if it is not supported by the backend.
|
This might crash if it is not supported by the backend.
|
||||||
--diffusion-conv-direct use Conv2d direct in the diffusion model
|
|
||||||
This might crash if it is not supported by the backend.
|
|
||||||
--vae-conv-direct use Conv2d direct in the vae model (should improve the performance)
|
|
||||||
This might crash if it is not supported by the backend.
|
|
||||||
--control-net-cpu keep controlnet in cpu (for low vram)
|
--control-net-cpu keep controlnet in cpu (for low vram)
|
||||||
--canny apply canny preprocessor (edge detection)
|
--canny apply canny preprocessor (edge detection)
|
||||||
--color colors the logging tags according to level
|
--color colors the logging tags according to level
|
||||||
|
|||||||
11
control.hpp
11
control.hpp
@ -323,17 +323,6 @@ struct ControlNet : public GGMLRunner {
|
|||||||
control_net.init(params_ctx, tensor_types, "");
|
control_net.init(params_ctx, tensor_types, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_conv2d_direct() {
|
|
||||||
std::vector<GGMLBlock*> blocks;
|
|
||||||
control_net.get_all_blocks(blocks);
|
|
||||||
for (auto block : blocks) {
|
|
||||||
if (block->get_desc() == "Conv2d") {
|
|
||||||
auto conv_block = (Conv2d*)block;
|
|
||||||
conv_block->enable_direct();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
~ControlNet() {
|
~ControlNet() {
|
||||||
free_control_ctx();
|
free_control_ctx();
|
||||||
}
|
}
|
||||||
|
|||||||
11
esrgan.hpp
11
esrgan.hpp
@ -147,17 +147,6 @@ struct ESRGAN : public GGMLRunner {
|
|||||||
rrdb_net.init(params_ctx, tensor_types, "");
|
rrdb_net.init(params_ctx, tensor_types, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_conv2d_direct() {
|
|
||||||
std::vector<GGMLBlock*> blocks;
|
|
||||||
rrdb_net.get_all_blocks(blocks);
|
|
||||||
for (auto block : blocks) {
|
|
||||||
if (block->get_desc() == "Conv2d") {
|
|
||||||
auto conv_block = (Conv2d*)block;
|
|
||||||
conv_block->enable_direct();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string get_desc() {
|
std::string get_desc() {
|
||||||
return "esrgan";
|
return "esrgan";
|
||||||
}
|
}
|
||||||
|
|||||||
@ -97,8 +97,6 @@ struct SDParams {
|
|||||||
bool clip_on_cpu = false;
|
bool clip_on_cpu = false;
|
||||||
bool vae_on_cpu = false;
|
bool vae_on_cpu = false;
|
||||||
bool diffusion_flash_attn = false;
|
bool diffusion_flash_attn = false;
|
||||||
bool diffusion_conv_direct = false;
|
|
||||||
bool vae_conv_direct = false;
|
|
||||||
bool canny_preprocess = false;
|
bool canny_preprocess = false;
|
||||||
bool color = false;
|
bool color = false;
|
||||||
int upscale_repeats = 1;
|
int upscale_repeats = 1;
|
||||||
@ -144,8 +142,6 @@ void print_params(SDParams params) {
|
|||||||
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
|
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
|
||||||
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
|
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
|
||||||
printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
|
printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
|
||||||
printf(" diffusion Conv2d direct:%s\n", params.diffusion_conv_direct ? "true" : "false");
|
|
||||||
printf(" vae Conv2d direct:%s\n", params.vae_conv_direct ? "true" : "false");
|
|
||||||
printf(" strength(control): %.2f\n", params.control_strength);
|
printf(" strength(control): %.2f\n", params.control_strength);
|
||||||
printf(" prompt: %s\n", params.prompt.c_str());
|
printf(" prompt: %s\n", params.prompt.c_str());
|
||||||
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
|
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
|
||||||
@ -236,10 +232,6 @@ void print_usage(int argc, const char* argv[]) {
|
|||||||
printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n");
|
printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n");
|
||||||
printf(" Might lower quality, since it implies converting k and v to f16.\n");
|
printf(" Might lower quality, since it implies converting k and v to f16.\n");
|
||||||
printf(" This might crash if it is not supported by the backend.\n");
|
printf(" This might crash if it is not supported by the backend.\n");
|
||||||
printf(" --diffusion-conv-direct use Conv2d direct in the diffusion model");
|
|
||||||
printf(" This might crash if it is not supported by the backend.\n");
|
|
||||||
printf(" --vae-conv-direct use Conv2d direct in the vae model (should improve the performance)");
|
|
||||||
printf(" This might crash if it is not supported by the backend.\n");
|
|
||||||
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
|
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
|
||||||
printf(" --canny apply canny preprocessor (edge detection)\n");
|
printf(" --canny apply canny preprocessor (edge detection)\n");
|
||||||
printf(" --color colors the logging tags according to level\n");
|
printf(" --color colors the logging tags according to level\n");
|
||||||
@ -430,8 +422,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
|||||||
{"", "--clip-on-cpu", "", true, ¶ms.clip_on_cpu},
|
{"", "--clip-on-cpu", "", true, ¶ms.clip_on_cpu},
|
||||||
{"", "--vae-on-cpu", "", true, ¶ms.vae_on_cpu},
|
{"", "--vae-on-cpu", "", true, ¶ms.vae_on_cpu},
|
||||||
{"", "--diffusion-fa", "", true, ¶ms.diffusion_flash_attn},
|
{"", "--diffusion-fa", "", true, ¶ms.diffusion_flash_attn},
|
||||||
{"", "--diffusion-conv-direct", "", true, ¶ms.diffusion_conv_direct},
|
|
||||||
{"", "--vae-conv-direct", "", true, ¶ms.vae_conv_direct},
|
|
||||||
{"", "--canny", "", true, ¶ms.canny_preprocess},
|
{"", "--canny", "", true, ¶ms.canny_preprocess},
|
||||||
{"-v", "--verbos", "", true, ¶ms.verbose},
|
{"-v", "--verbos", "", true, ¶ms.verbose},
|
||||||
{"", "--color", "", true, ¶ms.color},
|
{"", "--color", "", true, ¶ms.color},
|
||||||
@ -911,8 +901,6 @@ int main(int argc, const char* argv[]) {
|
|||||||
params.control_net_cpu,
|
params.control_net_cpu,
|
||||||
params.vae_on_cpu,
|
params.vae_on_cpu,
|
||||||
params.diffusion_flash_attn,
|
params.diffusion_flash_attn,
|
||||||
params.diffusion_conv_direct,
|
|
||||||
params.vae_conv_direct,
|
|
||||||
params.chroma_use_dit_mask,
|
params.chroma_use_dit_mask,
|
||||||
params.chroma_use_t5_mask,
|
params.chroma_use_t5_mask,
|
||||||
params.chroma_t5_mask_pad,
|
params.chroma_t5_mask_pad,
|
||||||
@ -1024,8 +1012,7 @@ int main(int argc, const char* argv[]) {
|
|||||||
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
|
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
|
||||||
if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) {
|
if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) {
|
||||||
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
|
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
|
||||||
params.n_threads,
|
params.n_threads);
|
||||||
params.diffusion_conv_direct);
|
|
||||||
|
|
||||||
if (upscaler_ctx == NULL) {
|
if (upscaler_ctx == NULL) {
|
||||||
printf("new_upscaler_ctx failed\n");
|
printf("new_upscaler_ctx failed\n");
|
||||||
|
|||||||
2
ggml
2
ggml
@ -1 +1 @@
|
|||||||
Subproject commit 7dee1d6a1e7611f238d09be96738388da97c88ed
|
Subproject commit b96890f3ab5ffbdbe56bc126df5366c34bd08d39
|
||||||
@ -56,8 +56,6 @@
|
|||||||
#define __STATIC_INLINE__ static inline
|
#define __STATIC_INLINE__ static inline
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
static_assert(GGML_MAX_NAME >= 128, "GGML_MAX_NAME must be at least 128");
|
|
||||||
|
|
||||||
// n-mode trensor-matrix product
|
// n-mode trensor-matrix product
|
||||||
// example: 2-mode product
|
// example: 2-mode product
|
||||||
// A: [ne03, k, ne01, ne00]
|
// A: [ne03, k, ne01, ne00]
|
||||||
@ -708,25 +706,6 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d_direct(struct ggml_context* ctx,
|
|
||||||
struct ggml_tensor* x,
|
|
||||||
struct ggml_tensor* w,
|
|
||||||
struct ggml_tensor* b,
|
|
||||||
int s0 = 1,
|
|
||||||
int s1 = 1,
|
|
||||||
int p0 = 0,
|
|
||||||
int p1 = 0,
|
|
||||||
int d0 = 1,
|
|
||||||
int d1 = 1) {
|
|
||||||
x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
|
|
||||||
if (b != NULL) {
|
|
||||||
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
|
|
||||||
// b = ggml_repeat(ctx, b, x);
|
|
||||||
x = ggml_add(ctx, x, b);
|
|
||||||
}
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
// w: [OC,IC, KD, 1 * 1]
|
// w: [OC,IC, KD, 1 * 1]
|
||||||
// x: [N, IC, IH, IW]
|
// x: [N, IC, IH, IW]
|
||||||
// b: [OC,]
|
// b: [OC,]
|
||||||
@ -1396,19 +1375,6 @@ public:
|
|||||||
tensors[prefix + pair.first] = pair.second;
|
tensors[prefix + pair.first] = pair.second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual std::string get_desc() {
|
|
||||||
return "GGMLBlock";
|
|
||||||
}
|
|
||||||
|
|
||||||
void get_all_blocks(std::vector<GGMLBlock*>& result) {
|
|
||||||
result.push_back(this);
|
|
||||||
for (auto& block_iter : blocks) {
|
|
||||||
if (block_iter.second) {
|
|
||||||
block_iter.second->get_all_blocks(result);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class UnaryBlock : public GGMLBlock {
|
class UnaryBlock : public GGMLBlock {
|
||||||
@ -1498,7 +1464,6 @@ protected:
|
|||||||
std::pair<int, int> padding;
|
std::pair<int, int> padding;
|
||||||
std::pair<int, int> dilation;
|
std::pair<int, int> dilation;
|
||||||
bool bias;
|
bool bias;
|
||||||
bool direct = false;
|
|
||||||
|
|
||||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
|
||||||
enum ggml_type wtype = GGML_TYPE_F16;
|
enum ggml_type wtype = GGML_TYPE_F16;
|
||||||
@ -1525,25 +1490,13 @@ public:
|
|||||||
dilation(dilation),
|
dilation(dilation),
|
||||||
bias(bias) {}
|
bias(bias) {}
|
||||||
|
|
||||||
void enable_direct() {
|
|
||||||
direct = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string get_desc() {
|
|
||||||
return "Conv2d";
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
|
||||||
struct ggml_tensor* w = params["weight"];
|
struct ggml_tensor* w = params["weight"];
|
||||||
struct ggml_tensor* b = NULL;
|
struct ggml_tensor* b = NULL;
|
||||||
if (bias) {
|
if (bias) {
|
||||||
b = params["bias"];
|
b = params["bias"];
|
||||||
}
|
}
|
||||||
if (direct) {
|
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
|
||||||
return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
|
|
||||||
} else {
|
|
||||||
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -374,10 +374,6 @@ public:
|
|||||||
model_loader.tensor_storages_types,
|
model_loader.tensor_storages_types,
|
||||||
version,
|
version,
|
||||||
sd_ctx_params->diffusion_flash_attn);
|
sd_ctx_params->diffusion_flash_attn);
|
||||||
if (sd_ctx_params->diffusion_conv_direct) {
|
|
||||||
LOG_INFO("Using Conv2d direct in the diffusion model");
|
|
||||||
std::dynamic_pointer_cast<UNetModel>(diffusion_model)->unet.enable_conv2d_direct();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cond_stage_model->alloc_params_buffer();
|
cond_stage_model->alloc_params_buffer();
|
||||||
@ -399,10 +395,6 @@ public:
|
|||||||
vae_decode_only,
|
vae_decode_only,
|
||||||
false,
|
false,
|
||||||
version);
|
version);
|
||||||
if (sd_ctx_params->vae_conv_direct) {
|
|
||||||
LOG_INFO("Using Conv2d direct in the vae model");
|
|
||||||
first_stage_model->enable_conv2d_direct();
|
|
||||||
}
|
|
||||||
first_stage_model->alloc_params_buffer();
|
first_stage_model->alloc_params_buffer();
|
||||||
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||||
} else {
|
} else {
|
||||||
@ -411,10 +403,6 @@ public:
|
|||||||
"decoder.layers",
|
"decoder.layers",
|
||||||
vae_decode_only,
|
vae_decode_only,
|
||||||
version);
|
version);
|
||||||
if (sd_ctx_params->vae_conv_direct) {
|
|
||||||
LOG_INFO("Using Conv2d direct in the tae model");
|
|
||||||
tae_first_stage->enable_conv2d_direct();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");
|
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");
|
||||||
|
|
||||||
@ -427,10 +415,6 @@ public:
|
|||||||
controlnet_backend = backend;
|
controlnet_backend = backend;
|
||||||
}
|
}
|
||||||
control_net = std::make_shared<ControlNet>(controlnet_backend, model_loader.tensor_storages_types, version);
|
control_net = std::make_shared<ControlNet>(controlnet_backend, model_loader.tensor_storages_types, version);
|
||||||
if (sd_ctx_params->diffusion_conv_direct) {
|
|
||||||
LOG_INFO("Using Conv2d direct in the control net");
|
|
||||||
control_net->enable_conv2d_direct();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
|
if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {
|
||||||
|
|||||||
@ -134,8 +134,6 @@ typedef struct {
|
|||||||
bool keep_control_net_on_cpu;
|
bool keep_control_net_on_cpu;
|
||||||
bool keep_vae_on_cpu;
|
bool keep_vae_on_cpu;
|
||||||
bool diffusion_flash_attn;
|
bool diffusion_flash_attn;
|
||||||
bool diffusion_conv_direct;
|
|
||||||
bool vae_conv_direct;
|
|
||||||
bool chroma_use_dit_mask;
|
bool chroma_use_dit_mask;
|
||||||
bool chroma_use_t5_mask;
|
bool chroma_use_t5_mask;
|
||||||
int chroma_t5_mask_pad;
|
int chroma_t5_mask_pad;
|
||||||
@ -238,8 +236,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
typedef struct upscaler_ctx_t upscaler_ctx_t;
|
typedef struct upscaler_ctx_t upscaler_ctx_t;
|
||||||
|
|
||||||
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
|
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
|
||||||
int n_threads,
|
int n_threads);
|
||||||
bool direct);
|
|
||||||
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
|
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
|
||||||
|
|
||||||
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
|
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
|
||||||
|
|||||||
11
tae.hpp
11
tae.hpp
@ -206,17 +206,6 @@ struct TinyAutoEncoder : public GGMLRunner {
|
|||||||
taesd.init(params_ctx, tensor_types, prefix);
|
taesd.init(params_ctx, tensor_types, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_conv2d_direct() {
|
|
||||||
std::vector<GGMLBlock*> blocks;
|
|
||||||
taesd.get_all_blocks(blocks);
|
|
||||||
for (auto block : blocks) {
|
|
||||||
if (block->get_desc() == "Conv2d") {
|
|
||||||
auto conv_block = (Conv2d*)block;
|
|
||||||
conv_block->enable_direct();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string get_desc() {
|
std::string get_desc() {
|
||||||
return "taesd";
|
return "taesd";
|
||||||
}
|
}
|
||||||
|
|||||||
12
unet.hpp
12
unet.hpp
@ -546,18 +546,6 @@ struct UNetModelRunner : public GGMLRunner {
|
|||||||
unet.init(params_ctx, tensor_types, prefix);
|
unet.init(params_ctx, tensor_types, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_conv2d_direct() {
|
|
||||||
std::vector<GGMLBlock*> blocks;
|
|
||||||
unet.get_all_blocks(blocks);
|
|
||||||
for (auto block : blocks) {
|
|
||||||
if (block->get_desc() == "Conv2d") {
|
|
||||||
LOG_DEBUG("block %s", block->get_desc().c_str());
|
|
||||||
auto conv_block = (Conv2d*)block;
|
|
||||||
conv_block->enable_direct();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string get_desc() {
|
std::string get_desc() {
|
||||||
return "unet";
|
return "unet";
|
||||||
}
|
}
|
||||||
|
|||||||
15
upscaler.cpp
15
upscaler.cpp
@ -9,12 +9,9 @@ struct UpscalerGGML {
|
|||||||
std::shared_ptr<ESRGAN> esrgan_upscaler;
|
std::shared_ptr<ESRGAN> esrgan_upscaler;
|
||||||
std::string esrgan_path;
|
std::string esrgan_path;
|
||||||
int n_threads;
|
int n_threads;
|
||||||
bool direct = false;
|
|
||||||
|
|
||||||
UpscalerGGML(int n_threads,
|
UpscalerGGML(int n_threads)
|
||||||
bool direct = false)
|
: n_threads(n_threads) {
|
||||||
: n_threads(n_threads),
|
|
||||||
direct(direct) {
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool load_from_file(const std::string& esrgan_path) {
|
bool load_from_file(const std::string& esrgan_path) {
|
||||||
@ -50,9 +47,6 @@ struct UpscalerGGML {
|
|||||||
}
|
}
|
||||||
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
|
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
|
||||||
esrgan_upscaler = std::make_shared<ESRGAN>(backend, model_loader.tensor_storages_types);
|
esrgan_upscaler = std::make_shared<ESRGAN>(backend, model_loader.tensor_storages_types);
|
||||||
if (direct) {
|
|
||||||
esrgan_upscaler->enable_conv2d_direct();
|
|
||||||
}
|
|
||||||
if (!esrgan_upscaler->load_from_file(esrgan_path)) {
|
if (!esrgan_upscaler->load_from_file(esrgan_path)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -110,15 +104,14 @@ struct upscaler_ctx_t {
|
|||||||
};
|
};
|
||||||
|
|
||||||
upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
|
upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
|
||||||
int n_threads,
|
int n_threads) {
|
||||||
bool direct = false) {
|
|
||||||
upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
|
upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
|
||||||
if (upscaler_ctx == NULL) {
|
if (upscaler_ctx == NULL) {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
std::string esrgan_path(esrgan_path_c_str);
|
std::string esrgan_path(esrgan_path_c_str);
|
||||||
|
|
||||||
upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct);
|
upscaler_ctx->upscaler = new UpscalerGGML(n_threads);
|
||||||
if (upscaler_ctx->upscaler == NULL) {
|
if (upscaler_ctx->upscaler == NULL) {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|||||||
11
vae.hpp
11
vae.hpp
@ -534,17 +534,6 @@ struct AutoEncoderKL : public GGMLRunner {
|
|||||||
ae.init(params_ctx, tensor_types, prefix);
|
ae.init(params_ctx, tensor_types, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
void enable_conv2d_direct() {
|
|
||||||
std::vector<GGMLBlock*> blocks;
|
|
||||||
ae.get_all_blocks(blocks);
|
|
||||||
for (auto block : blocks) {
|
|
||||||
if (block->get_desc() == "Conv2d") {
|
|
||||||
auto conv_block = (Conv2d*)block;
|
|
||||||
conv_block->enable_direct();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string get_desc() {
|
std::string get_desc() {
|
||||||
return "vae";
|
return "vae";
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user