diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index dba707e..3182950 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -69,7 +69,8 @@ struct SDParams { std::string tensor_type_rules; std::string lora_model_dir; std::string output_path = "output.png"; - std::string input_path; + std::string init_image_path; + std::string end_image_path; std::string mask_path; std::string control_image_path; std::vector ref_image_paths; @@ -123,59 +124,60 @@ void print_params(SDParams params) { char* sample_params_str = sd_sample_params_to_str(¶ms.sample_params); char* high_noise_sample_params_str = sd_sample_params_to_str(¶ms.high_noise_sample_params); printf("Option: \n"); - printf(" n_threads: %d\n", params.n_threads); - printf(" mode: %s\n", modes_str[params.mode]); - printf(" model_path: %s\n", params.model_path.c_str()); - printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified"); - printf(" clip_l_path: %s\n", params.clip_l_path.c_str()); - printf(" clip_g_path: %s\n", params.clip_g_path.c_str()); - printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str()); - printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str()); - printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str()); + printf(" n_threads: %d\n", params.n_threads); + printf(" mode: %s\n", modes_str[params.mode]); + printf(" model_path: %s\n", params.model_path.c_str()); + printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified"); + printf(" clip_l_path: %s\n", params.clip_l_path.c_str()); + printf(" clip_g_path: %s\n", params.clip_g_path.c_str()); + printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str()); + printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str()); + printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str()); printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str()); - printf(" vae_path: %s\n", params.vae_path.c_str()); - printf(" taesd_path: %s\n", params.taesd_path.c_str()); - printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); - printf(" control_net_path: %s\n", params.control_net_path.c_str()); - printf(" embedding_dir: %s\n", params.embedding_dir.c_str()); - printf(" stacked_id_embed_dir: %s\n", params.stacked_id_embed_dir.c_str()); - printf(" input_id_images_path: %s\n", params.input_id_images_path.c_str()); - printf(" style ratio: %.2f\n", params.style_ratio); - printf(" normalize input image : %s\n", params.normalize_input ? "true" : "false"); - printf(" output_path: %s\n", params.output_path.c_str()); - printf(" init_img: %s\n", params.input_path.c_str()); - printf(" mask_img: %s\n", params.mask_path.c_str()); - printf(" control_image: %s\n", params.control_image_path.c_str()); + printf(" vae_path: %s\n", params.vae_path.c_str()); + printf(" taesd_path: %s\n", params.taesd_path.c_str()); + printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); + printf(" control_net_path: %s\n", params.control_net_path.c_str()); + printf(" embedding_dir: %s\n", params.embedding_dir.c_str()); + printf(" stacked_id_embed_dir: %s\n", params.stacked_id_embed_dir.c_str()); + printf(" input_id_images_path: %s\n", params.input_id_images_path.c_str()); + printf(" style ratio: %.2f\n", params.style_ratio); + printf(" normalize input image: %s\n", params.normalize_input ? "true" : "false"); + printf(" output_path: %s\n", params.output_path.c_str()); + printf(" init_image_path: %s\n", params.init_image_path.c_str()); + printf(" end_image_path: %s\n", params.end_image_path.c_str()); + printf(" mask_image_path: %s\n", params.mask_path.c_str()); + printf(" control_image_path: %s\n", params.control_image_path.c_str()); printf(" ref_images_paths:\n"); for (auto& path : params.ref_image_paths) { printf(" %s\n", path.c_str()); }; - printf(" offload_params_to_cpu: %s\n", params.offload_params_to_cpu ? "true" : "false"); - printf(" clip_on_cpu: %s\n", params.clip_on_cpu ? "true" : "false"); - printf(" control_net_cpu: %s\n", params.control_net_cpu ? "true" : "false"); - printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false"); - printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false"); - printf(" diffusion Conv2d direct:%s\n", params.diffusion_conv_direct ? "true" : "false"); - printf(" vae Conv2d direct:%s\n", params.vae_conv_direct ? "true" : "false"); - printf(" strength(control): %.2f\n", params.control_strength); - printf(" prompt: %s\n", params.prompt.c_str()); - printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); - printf(" clip_skip: %d\n", params.clip_skip); - printf(" width: %d\n", params.width); - printf(" height: %d\n", params.height); - printf(" sample_params: %s\n", SAFE_STR(sample_params_str)); - printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str)); - printf(" strength(img2img): %.2f\n", params.strength); - printf(" rng: %s\n", sd_rng_type_name(params.rng_type)); - printf(" seed: %ld\n", params.seed); - printf(" batch_count: %d\n", params.batch_count); - printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false"); - printf(" upscale_repeats: %d\n", params.upscale_repeats); - printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false"); - printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false"); - printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad); - printf(" video_frames: %d\n", params.video_frames); - printf(" fps: %d\n", params.fps); + printf(" offload_params_to_cpu: %s\n", params.offload_params_to_cpu ? "true" : "false"); + printf(" clip_on_cpu: %s\n", params.clip_on_cpu ? "true" : "false"); + printf(" control_net_cpu: %s\n", params.control_net_cpu ? "true" : "false"); + printf(" vae_on_cpu: %s\n", params.vae_on_cpu ? "true" : "false"); + printf(" diffusion flash attention: %s\n", params.diffusion_flash_attn ? "true" : "false"); + printf(" diffusion Conv2d direct: %s\n", params.diffusion_conv_direct ? "true" : "false"); + printf(" vae_conv_direct: %s\n", params.vae_conv_direct ? "true" : "false"); + printf(" control_strength: %.2f\n", params.control_strength); + printf(" prompt: %s\n", params.prompt.c_str()); + printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); + printf(" clip_skip: %d\n", params.clip_skip); + printf(" width: %d\n", params.width); + printf(" height: %d\n", params.height); + printf(" sample_params: %s\n", SAFE_STR(sample_params_str)); + printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str)); + printf(" strength(img2img): %.2f\n", params.strength); + printf(" rng: %s\n", sd_rng_type_name(params.rng_type)); + printf(" seed: %ld\n", params.seed); + printf(" batch_count: %d\n", params.batch_count); + printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false"); + printf(" upscale_repeats: %d\n", params.upscale_repeats); + printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false"); + printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false"); + printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad); + printf(" video_frames: %d\n", params.video_frames); + printf(" fps: %d\n", params.fps); free(sample_params_str); free(high_noise_sample_params_str); } @@ -449,7 +451,8 @@ void parse_args(int argc, const char** argv, SDParams& params) { {"", "--embd-dir", "", ¶ms.embedding_dir}, {"", "--stacked-id-embd-dir", "", ¶ms.stacked_id_embed_dir}, {"", "--lora-model-dir", "", ¶ms.lora_model_dir}, - {"-i", "--init-img", "", ¶ms.input_path}, + {"-i", "--init-img", "", ¶ms.init_image_path}, + {"", "--end-img", "", ¶ms.end_image_path}, {"", "--tensor-type-rules", "", ¶ms.tensor_type_rules}, {"", "--input-id-images-dir", "", ¶ms.input_id_images_path}, {"", "--mask", "", ¶ms.mask_path}, @@ -902,6 +905,94 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { fflush(out_stream); } +uint8_t* load_image(const char* image_path, int& width, int& height, int expected_width = 0, int expected_height = 0, int expected_channel = 3) { + int c = 0; + uint8_t* image_buffer = (uint8_t*)stbi_load(image_path, &width, &height, &c, expected_channel); + if (image_buffer == NULL) { + fprintf(stderr, "load image from '%s' failed\n", image_path); + return NULL; + } + if (c < expected_channel) { + fprintf(stderr, + "the number of channels for the input image must be >= %d," + "but got %d channels, image_path = %s\n", + expected_channel, + c, + image_path); + free(image_buffer); + return NULL; + } + if (width <= 0) { + fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path); + free(image_buffer); + return NULL; + } + if (height <= 0) { + fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path); + free(image_buffer); + return NULL; + } + + // Resize input image ... + if ((expected_width > 0 && expected_height > 0) && (height != expected_height || width != expected_width)) { + float dst_aspect = (float)expected_width / (float)expected_height; + float src_aspect = (float)width / (float)height; + + int crop_x = 0, crop_y = 0; + int crop_w = width, crop_h = height; + + if (src_aspect > dst_aspect) { + crop_w = (int)(height * dst_aspect); + crop_x = (width - crop_w) / 2; + } else if (src_aspect < dst_aspect) { + crop_h = (int)(width / dst_aspect); + crop_y = (height - crop_h) / 2; + } + + if (crop_x != 0 || crop_y != 0) { + printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path); + uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel); + if (cropped_image_buffer == NULL) { + fprintf(stderr, "error: allocate memory for crop\n"); + free(image_buffer); + return NULL; + } + for (int row = 0; row < crop_h; row++) { + uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel; + uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel; + memcpy(dst, src, crop_w * expected_channel); + } + + width = crop_w; + height = crop_h; + free(image_buffer); + image_buffer = cropped_image_buffer; + } + + printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height); + int resized_height = expected_height; + int resized_width = expected_width; + + uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel); + if (resized_image_buffer == NULL) { + fprintf(stderr, "error: allocate memory for resize input image\n"); + free(image_buffer); + return NULL; + } + stbir_resize(image_buffer, width, height, 0, + resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8, + expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0, + STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, + STBIR_FILTER_BOX, STBIR_FILTER_BOX, + STBIR_COLORSPACE_SRGB, nullptr); + + // Save resized result + free(image_buffer); + image_buffer = resized_image_buffer; + } + return image_buffer; +} + int main(int argc, const char* argv[]) { SDParams params; parse_args(argc, argv, params); @@ -935,120 +1026,101 @@ int main(int argc, const char* argv[]) { } } - bool vae_decode_only = true; - uint8_t* input_image_buffer = NULL; - uint8_t* control_image_buffer = NULL; - uint8_t* mask_image_buffer = NULL; + bool vae_decode_only = true; + sd_image_t init_image = {(uint32_t)params.width, (uint32_t)params.height, 3, NULL}; + sd_image_t end_image = {(uint32_t)params.width, (uint32_t)params.height, 3, NULL}; + sd_image_t control_image = {(uint32_t)params.width, (uint32_t)params.height, 3, NULL}; + sd_image_t mask_image = {(uint32_t)params.width, (uint32_t)params.height, 1, NULL}; std::vector ref_images; - if (params.input_path.size() > 0) { + auto release_all_resources = [&]() { + free(init_image.data); + free(end_image.data); + free(control_image.data); + free(mask_image.data); + for (auto ref_image : ref_images) { + free(ref_image.data); + ref_image.data = NULL; + } + ref_images.clear(); + }; + + if (params.init_image_path.size() > 0) { vae_decode_only = false; - int c = 0; + int width = 0; + int height = 0; + init_image.data = load_image(params.init_image_path.c_str(), width, height, params.width, params.height); + if (init_image.data == NULL) { + fprintf(stderr, "load image from '%s' failed\n", params.init_image_path.c_str()); + release_all_resources(); + return 1; + } + } + + if (params.end_image_path.size() > 0) { + vae_decode_only = false; + + int width = 0; + int height = 0; + end_image.data = load_image(params.end_image_path.c_str(), width, height, params.width, params.height); + if (end_image.data == NULL) { + fprintf(stderr, "load image from '%s' failed\n", params.end_image_path.c_str()); + release_all_resources(); + return 1; + } + } + + if (params.mask_path.size() > 0) { + int c = 0; + int width = 0; + int height = 0; + mask_image.data = load_image(params.mask_path.c_str(), width, height, params.width, params.height, 1); + if (mask_image.data == NULL) { + fprintf(stderr, "load image from '%s' failed\n", params.mask_path.c_str()); + release_all_resources(); + return 1; + } + } else { + mask_image.data = (uint8_t*)malloc(params.width * params.height); + memset(mask_image.data, 255, params.width * params.height); + if (mask_image.data == NULL) { + fprintf(stderr, "malloc mask image failed\n"); + release_all_resources(); + return 1; + } + } + + if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) { int width = 0; int height = 0; - input_image_buffer = stbi_load(params.input_path.c_str(), &width, &height, &c, 3); - if (input_image_buffer == NULL) { - fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str()); + control_image.data = load_image(params.control_image_path.c_str(), width, height, params.width, params.height); + if (control_image.data == NULL) { + fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str()); + release_all_resources(); return 1; } - if (c < 3) { - fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c); - free(input_image_buffer); - return 1; - } - if (width <= 0) { - fprintf(stderr, "error: the width of image must be greater than 0\n"); - free(input_image_buffer); - return 1; - } - if (height <= 0) { - fprintf(stderr, "error: the height of image must be greater than 0\n"); - free(input_image_buffer); - return 1; + if (params.canny_preprocess) { // apply preprocessor + control_image.data = preprocess_canny(control_image.data, + control_image.width, + control_image.height, + 0.08f, + 0.08f, + 0.8f, + 1.0f, + false); } + } - // Resize input image ... - if (params.height != height || params.width != width) { - float dst_aspect = (float)params.width / (float)params.height; - float src_aspect = (float)width / (float)height; - - int crop_x = 0, crop_y = 0; - int crop_w = width, crop_h = height; - - if (src_aspect > dst_aspect) { - crop_w = (int)(height * dst_aspect); - crop_x = (width - crop_w) / 2; - } else if (src_aspect < dst_aspect) { - crop_h = (int)(width / dst_aspect); - crop_y = (height - crop_h) / 2; - } - - if (crop_x != 0 || crop_y != 0) { - printf("crop input image from %dx%d to %dx%d\n", width, height, crop_w, crop_h); - uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * 3); - if (cropped_image_buffer == NULL) { - fprintf(stderr, "error: allocate memory for crop\n"); - free(input_image_buffer); - return 1; - } - for (int row = 0; row < crop_h; row++) { - uint8_t* src = input_image_buffer + ((crop_y + row) * width + crop_x) * 3; - uint8_t* dst = cropped_image_buffer + (row * crop_w) * 3; - memcpy(dst, src, crop_w * 3); - } - - width = crop_w; - height = crop_h; - free(input_image_buffer); - input_image_buffer = cropped_image_buffer; - } - - printf("resize input image from %dx%d to %dx%d\n", width, height, params.width, params.height); - int resized_height = params.height; - int resized_width = params.width; - - uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3); - if (resized_image_buffer == NULL) { - fprintf(stderr, "error: allocate memory for resize input image\n"); - free(input_image_buffer); - return 1; - } - stbir_resize(input_image_buffer, width, height, 0, - resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8, - 3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0, - STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP, - STBIR_FILTER_BOX, STBIR_FILTER_BOX, - STBIR_COLORSPACE_SRGB, nullptr); - - // Save resized result - free(input_image_buffer); - input_image_buffer = resized_image_buffer; - } - } else if (params.ref_image_paths.size() > 0) { + if (params.ref_image_paths.size() > 0) { vae_decode_only = false; for (auto& path : params.ref_image_paths) { - int c = 0; int width = 0; int height = 0; - uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3); + uint8_t* image_buffer = load_image(path.c_str(), width, height); if (image_buffer == NULL) { fprintf(stderr, "load image from '%s' failed\n", path.c_str()); - return 1; - } - if (c < 3) { - fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c); - free(image_buffer); - return 1; - } - if (width <= 0) { - fprintf(stderr, "error: the width of image must be greater than 0\n"); - free(image_buffer); - return 1; - } - if (height <= 0) { - fprintf(stderr, "error: the height of image must be greater than 0\n"); - free(image_buffer); + release_all_resources(); return 1; } ref_images.push_back({(uint32_t)width, @@ -1098,50 +1170,10 @@ int main(int argc, const char* argv[]) { if (sd_ctx == NULL) { printf("new_sd_ctx_t failed\n"); + release_all_resources(); return 1; } - sd_image_t input_image = {(uint32_t)params.width, - (uint32_t)params.height, - 3, - input_image_buffer}; - - sd_image_t* control_image = NULL; - if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) { - int c = 0; - control_image_buffer = stbi_load(params.control_image_path.c_str(), ¶ms.width, ¶ms.height, &c, 3); - if (control_image_buffer == NULL) { - fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str()); - return 1; - } - control_image = new sd_image_t{(uint32_t)params.width, - (uint32_t)params.height, - 3, - control_image_buffer}; - if (params.canny_preprocess) { // apply preprocessor - control_image->data = preprocess_canny(control_image->data, - control_image->width, - control_image->height, - 0.08f, - 0.08f, - 0.8f, - 1.0f, - false); - } - } - - std::vector default_mask_image_vec(params.width * params.height, 255); - if (params.mask_path != "") { - int c = 0; - mask_image_buffer = stbi_load(params.mask_path.c_str(), ¶ms.width, ¶ms.height, &c, 1); - } else { - mask_image_buffer = default_mask_image_vec.data(); - } - sd_image_t mask_image = {(uint32_t)params.width, - (uint32_t)params.height, - 1, - mask_image_buffer}; - sd_image_t* results; int num_results = 1; if (params.mode == IMG_GEN) { @@ -1149,7 +1181,7 @@ int main(int argc, const char* argv[]) { params.prompt.c_str(), params.negative_prompt.c_str(), params.clip_skip, - input_image, + init_image, ref_images.data(), (int)ref_images.size(), mask_image, @@ -1173,7 +1205,8 @@ int main(int argc, const char* argv[]) { params.prompt.c_str(), params.negative_prompt.c_str(), params.clip_skip, - input_image, + init_image, + end_image, params.width, params.height, params.sample_params, @@ -1275,8 +1308,8 @@ int main(int argc, const char* argv[]) { } free(results); free_sd_ctx(sd_ctx); - free(control_image_buffer); - free(input_image_buffer); + + release_all_resources(); return 0; } diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index fdf7a65..e081670 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -1780,7 +1780,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, const std::vector& sigmas, int64_t seed, int batch_count, - const sd_image_t* control_cond, + sd_image_t control_image, float control_strength, float style_ratio, bool normalize_input, @@ -1947,9 +1947,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, // Control net hint struct ggml_tensor* image_hint = NULL; - if (control_cond != NULL) { + if (control_image.data != NULL) { image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); - sd_image_to_tensor(control_cond->data, image_hint); + sd_image_to_tensor(control_image.data, image_hint); } // Sample @@ -2342,7 +2342,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sigmas, seed, sd_img_gen_params->batch_count, - sd_img_gen_params->control_cond, + sd_img_gen_params->control_image, sd_img_gen_params->control_strength, sd_img_gen_params->style_strength, sd_img_gen_params->normalize_input, diff --git a/stable-diffusion.h b/stable-diffusion.h index e4d2aa1..52d4aa6 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -188,7 +188,7 @@ typedef struct { float strength; int64_t seed; int batch_count; - const sd_image_t* control_cond; + sd_image_t control_image; float control_strength; float style_strength; bool normalize_input; @@ -200,6 +200,7 @@ typedef struct { const char* negative_prompt; int clip_skip; sd_image_t init_image; + sd_image_t end_image; int width; int height; sd_sample_params_t sample_params;