Conv2D direct support (#744)

* Conv2DDirect for VAE stage

* Enable only for Vulkan, reduced duplicated code

* CMake option to use conv2d direct

* Conv2d direct always on for OpenCL

* Conv2d direct as a flag

* Fix merge typo

* Align conv2d behavior to flash attention's

* Fix README

* Add conv2d direct for controlnet

* Add conv2d direct for esrgan

* Clean code, use enable_conv2d_direct/get_all_blocks

* Format code

---------

Co-authored-by: leejet <leejet714@gmail.com>
Daniele 2025-08-02 17:25:17 +00:00 committed by GitHub
parent f7f05fb185
commit 5b8996f74a
11 changed files with 151 additions and 7 deletions

README.md

@@ -341,6 +341,10 @@ arguments:
--diffusion-fa use flash attention in the diffusion model (for low vram)
Might lower quality, since it implies converting k and v to f16.
This might crash if it is not supported by the backend.
--diffusion-conv-direct use Conv2d direct in the diffusion model
This might crash if it is not supported by the backend.
--vae-conv-direct use Conv2d direct in the vae model (should improve performance)
This might crash if it is not supported by the backend.
--control-net-cpu keep controlnet in cpu (for low vram)
--canny apply canny preprocessor (edge detection)
--color colors the logging tags according to level
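
For reference, a typical invocation enabling both options might look like this (hypothetical model file and prompt; sd is the CLI binary this project builds):

sd -m sd-v1-5.safetensors -p "a lighthouse at dusk" --diffusion-conv-direct --vae-conv-direct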

control.hpp

@@ -323,6 +323,17 @@ struct ControlNet : public GGMLRunner {
control_net.init(params_ctx, tensor_types, "");
}
void enable_conv2d_direct() {
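// walk all nested blocks and switch every Conv2d layer to the direct kernel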
std::vector<GGMLBlock*> blocks;
control_net.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}
~ControlNet() {
free_control_ctx();
}

esrgan.hpp

@@ -147,6 +147,17 @@ struct ESRGAN : public GGMLRunner {
rrdb_net.init(params_ctx, tensor_types, "");
}
void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
rrdb_net.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}
std::string get_desc() {
return "esrgan";
}

examples/cli/main.cpp

@@ -97,6 +97,8 @@ struct SDParams {
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool diffusion_flash_attn = false;
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
bool canny_preprocess = false;
bool color = false;
int upscale_repeats = 1;
@@ -142,6 +144,8 @@ void print_params(SDParams params) {
printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false");
printf(" vae decoder on cpu: %s\n", params.vae_on_cpu ? "true" : "false");
printf(" diffusion flash attention: %s\n", params.diffusion_flash_attn ? "true" : "false");
printf(" diffusion Conv2d direct: %s\n", params.diffusion_conv_direct ? "true" : "false");
printf(" vae Conv2d direct: %s\n", params.vae_conv_direct ? "true" : "false");
printf(" strength(control): %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
@@ -232,6 +236,10 @@ void print_usage(int argc, const char* argv[]) {
printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n");
printf(" Might lower quality, since it implies converting k and v to f16.\n");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --diffusion-conv-direct use Conv2d direct in the diffusion model\n");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --vae-conv-direct use Conv2d direct in the vae model (should improve performance)\n");
printf(" This might crash if it is not supported by the backend.\n");
printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n");
printf(" --canny apply canny preprocessor (edge detection)\n");
printf(" --color colors the logging tags according to level\n");
@@ -422,6 +430,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--clip-on-cpu", "", true, &params.clip_on_cpu},
{"", "--vae-on-cpu", "", true, &params.vae_on_cpu},
{"", "--diffusion-fa", "", true, &params.diffusion_flash_attn},
{"", "--diffusion-conv-direct", "", true, &params.diffusion_conv_direct},
{"", "--vae-conv-direct", "", true, &params.vae_conv_direct},
{"", "--canny", "", true, &params.canny_preprocess},
{"-v", "--verbose", "", true, &params.verbose},
{"", "--color", "", true, &params.color},
@@ -901,6 +911,8 @@ int main(int argc, const char* argv[]) {
params.control_net_cpu,
params.vae_on_cpu,
params.diffusion_flash_attn,
params.diffusion_conv_direct,
params.vae_conv_direct,
params.chroma_use_dit_mask,
params.chroma_use_t5_mask,
params.chroma_t5_mask_pad,
@@ -1012,7 +1024,8 @@ int main(int argc, const char* argv[]) {
int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth
if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) {
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
params.n_threads,
params.diffusion_conv_direct);
if (upscaler_ctx == NULL) {
printf("new_upscaler_ctx failed\n");

ggml_extend.hpp

@@ -708,6 +708,25 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx,
return x;
}
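// Same as ggml_nn_conv_2d, but lowers to ggml's direct convolution kernel
// (ggml_conv_2d_direct), which avoids the im2col intermediate tensor and
// reduces memory use on backends that implement it (e.g. Vulkan, OpenCL).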
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_conv_2d_direct(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
struct ggml_tensor* b,
int s0 = 1,
int s1 = 1,
int p0 = 0,
int p1 = 0,
int d0 = 1,
int d1 = 1) {
x = ggml_conv_2d_direct(ctx, w, x, s0, s1, p0, p1, d0, d1);
if (b != NULL) {
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
// b = ggml_repeat(ctx, b, x);
x = ggml_add(ctx, x, b);
}
return x;
}
// w: [OC, IC, KD, 1 * 1]
// x: [N, IC, IH, IW]
// b: [OC,]
@@ -1377,6 +1396,19 @@ public:
tensors[prefix + pair.first] = pair.second;
}
}
virtual std::string get_desc() {
return "GGMLBlock";
}
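// collect this block and, recursively, all nested child blocks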
void get_all_blocks(std::vector<GGMLBlock*>& result) {
result.push_back(this);
for (auto& block_iter : blocks) {
if (block_iter.second) {
block_iter.second->get_all_blocks(result);
}
}
}
};
class UnaryBlock : public GGMLBlock {
@@ -1466,6 +1498,7 @@ protected:
std::pair<int, int> padding;
std::pair<int, int> dilation;
bool bias;
bool direct = false;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
enum ggml_type wtype = GGML_TYPE_F16;
@@ -1492,13 +1525,25 @@ public:
dilation(dilation),
bias(bias) {}
void enable_direct() {
direct = true;
}
std::string get_desc() {
return "Conv2d";
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
struct ggml_tensor* b = NULL;
if (bias) {
b = params["bias"];
}
if (direct) {
return ggml_nn_conv_2d_direct(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
} else {
return ggml_nn_conv_2d(ctx, x, w, b, stride.second, stride.first, padding.second, padding.first, dilation.second, dilation.first);
}
}
};

stable-diffusion.cpp

@@ -374,6 +374,10 @@ public:
model_loader.tensor_storages_types,
version,
sd_ctx_params->diffusion_flash_attn);
if (sd_ctx_params->diffusion_conv_direct) {
LOG_INFO("Using Conv2d direct in the diffusion model");
std::dynamic_pointer_cast<UNetModel>(diffusion_model)->unet.enable_conv2d_direct();
}
}
cond_stage_model->alloc_params_buffer();
@@ -395,6 +399,10 @@
vae_decode_only,
false,
version);
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the vae model");
first_stage_model->enable_conv2d_direct();
}
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else {
@@ -403,6 +411,10 @@ public:
"decoder.layers",
vae_decode_only,
version);
if (sd_ctx_params->vae_conv_direct) {
LOG_INFO("Using Conv2d direct in the tae model");
tae_first_stage->enable_conv2d_direct();
}
}
// first_stage_model->get_param_tensors(tensors, "first_stage_model.");
@@ -415,6 +427,10 @@ public:
controlnet_backend = backend;
}
control_net = std::make_shared<ControlNet>(controlnet_backend, model_loader.tensor_storages_types, version);
if (sd_ctx_params->diffusion_conv_direct) {
LOG_INFO("Using Conv2d direct in the control net");
control_net->enable_conv2d_direct();
}
}
if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) {

stable-diffusion.h

@@ -134,6 +134,8 @@ typedef struct {
bool keep_control_net_on_cpu;
bool keep_vae_on_cpu;
bool diffusion_flash_attn;
bool diffusion_conv_direct;
bool vae_conv_direct;
bool chroma_use_dit_mask;
bool chroma_use_t5_mask;
int chroma_t5_mask_pad;
@@ -236,7 +238,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
typedef struct upscaler_ctx_t upscaler_ctx_t;
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
int n_threads,
bool direct);
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);
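
On the caller side, the widened constructor simply threads the flag through to the ESRGAN runner. A minimal usage sketch, assuming a hypothetical model path and an sd_image_t img loaded elsewhere:

upscaler_ctx_t* ctx = new_upscaler_ctx("realesrgan-x4.pth", 4 /* n_threads */, true /* Conv2d direct */);
if (ctx != NULL) {
    sd_image_t upscaled = upscale(ctx, img, 4); // upscale_factor = 4
    // ... consume upscaled.data ...
    free_upscaler_ctx(ctx);
}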

tae.hpp

@@ -206,6 +206,17 @@ struct TinyAutoEncoder : public GGMLRunner {
taesd.init(params_ctx, tensor_types, prefix);
}
void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
taesd.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}
std::string get_desc() {
return "taesd";
}

unet.hpp

@@ -546,6 +546,18 @@ struct UNetModelRunner : public GGMLRunner {
unet.init(params_ctx, tensor_types, prefix);
}
void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
unet.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
LOG_DEBUG("block %s", block->get_desc().c_str());
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}
std::string get_desc() {
return "unet";
}

upscaler.cpp

@@ -9,9 +9,12 @@ struct UpscalerGGML {
std::shared_ptr<ESRGAN> esrgan_upscaler;
std::string esrgan_path;
int n_threads;
bool direct = false;
UpscalerGGML(int n_threads,
bool direct = false)
: n_threads(n_threads),
direct(direct) {
}
bool load_from_file(const std::string& esrgan_path) {
@@ -47,6 +50,9 @@
}
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
esrgan_upscaler = std::make_shared<ESRGAN>(backend, model_loader.tensor_storages_types);
if (direct) {
esrgan_upscaler->enable_conv2d_direct();
}
if (!esrgan_upscaler->load_from_file(esrgan_path)) {
return false;
}
@@ -104,14 +110,15 @@
};
upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
int n_threads,
bool direct = false) {
upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
if (upscaler_ctx == NULL) {
return NULL;
}
std::string esrgan_path(esrgan_path_c_str);
upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct);
if (upscaler_ctx->upscaler == NULL) {
return NULL;
}

vae.hpp

@@ -534,6 +534,17 @@ struct AutoEncoderKL : public GGMLRunner {
ae.init(params_ctx, tensor_types, prefix);
}
void enable_conv2d_direct() {
std::vector<GGMLBlock*> blocks;
ae.get_all_blocks(blocks);
for (auto block : blocks) {
if (block->get_desc() == "Conv2d") {
auto conv_block = (Conv2d*)block;
conv_block->enable_direct();
}
}
}
std::string get_desc() {
return "vae";
}