mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-31 17:53:35 +00:00
Compare commits
3 Commits
50f921119e
...
00b0a0053d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
00b0a0053d | ||
|
|
e2a3a406b6 | ||
|
|
33ff442c1d |
@ -69,8 +69,9 @@ struct SDParams {
|
|||||||
std::string tensor_type_rules;
|
std::string tensor_type_rules;
|
||||||
std::string lora_model_dir;
|
std::string lora_model_dir;
|
||||||
std::string output_path = "output.png";
|
std::string output_path = "output.png";
|
||||||
std::string input_path;
|
std::string init_image_path;
|
||||||
std::string mask_path;
|
std::string end_image_path;
|
||||||
|
std::string mask_image_path;
|
||||||
std::string control_image_path;
|
std::string control_image_path;
|
||||||
std::vector<std::string> ref_image_paths;
|
std::vector<std::string> ref_image_paths;
|
||||||
|
|
||||||
@ -123,59 +124,60 @@ void print_params(SDParams params) {
|
|||||||
char* sample_params_str = sd_sample_params_to_str(&params.sample_params);
|
char* sample_params_str = sd_sample_params_to_str(&params.sample_params);
|
||||||
char* high_noise_sample_params_str = sd_sample_params_to_str(&params.high_noise_sample_params);
|
char* high_noise_sample_params_str = sd_sample_params_to_str(&params.high_noise_sample_params);
|
||||||
printf("Option: \n");
|
printf("Option: \n");
|
||||||
printf(" n_threads: %d\n", params.n_threads);
|
printf(" n_threads: %d\n", params.n_threads);
|
||||||
printf(" mode: %s\n", modes_str[params.mode]);
|
printf(" mode: %s\n", modes_str[params.mode]);
|
||||||
printf(" model_path: %s\n", params.model_path.c_str());
|
printf(" model_path: %s\n", params.model_path.c_str());
|
||||||
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
|
printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? sd_type_name(params.wtype) : "unspecified");
|
||||||
printf(" clip_l_path: %s\n", params.clip_l_path.c_str());
|
printf(" clip_l_path: %s\n", params.clip_l_path.c_str());
|
||||||
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
|
printf(" clip_g_path: %s\n", params.clip_g_path.c_str());
|
||||||
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
|
printf(" clip_vision_path: %s\n", params.clip_vision_path.c_str());
|
||||||
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
|
printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str());
|
||||||
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
|
printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str());
|
||||||
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
|
printf(" high_noise_diffusion_model_path: %s\n", params.high_noise_diffusion_model_path.c_str());
|
||||||
printf(" vae_path: %s\n", params.vae_path.c_str());
|
printf(" vae_path: %s\n", params.vae_path.c_str());
|
||||||
printf(" taesd_path: %s\n", params.taesd_path.c_str());
|
printf(" taesd_path: %s\n", params.taesd_path.c_str());
|
||||||
printf(" esrgan_path: %s\n", params.esrgan_path.c_str());
|
printf(" esrgan_path: %s\n", params.esrgan_path.c_str());
|
||||||
printf(" control_net_path: %s\n", params.control_net_path.c_str());
|
printf(" control_net_path: %s\n", params.control_net_path.c_str());
|
||||||
printf(" embedding_dir: %s\n", params.embedding_dir.c_str());
|
printf(" embedding_dir: %s\n", params.embedding_dir.c_str());
|
||||||
printf(" stacked_id_embed_dir: %s\n", params.stacked_id_embed_dir.c_str());
|
printf(" stacked_id_embed_dir: %s\n", params.stacked_id_embed_dir.c_str());
|
||||||
printf(" input_id_images_path: %s\n", params.input_id_images_path.c_str());
|
printf(" input_id_images_path: %s\n", params.input_id_images_path.c_str());
|
||||||
printf(" style ratio: %.2f\n", params.style_ratio);
|
printf(" style ratio: %.2f\n", params.style_ratio);
|
||||||
printf(" normalize input image : %s\n", params.normalize_input ? "true" : "false");
|
printf(" normalize input image: %s\n", params.normalize_input ? "true" : "false");
|
||||||
printf(" output_path: %s\n", params.output_path.c_str());
|
printf(" output_path: %s\n", params.output_path.c_str());
|
||||||
printf(" init_img: %s\n", params.input_path.c_str());
|
printf(" init_image_path: %s\n", params.init_image_path.c_str());
|
||||||
printf(" mask_img: %s\n", params.mask_path.c_str());
|
printf(" end_image_path: %s\n", params.end_image_path.c_str());
|
||||||
printf(" control_image: %s\n", params.control_image_path.c_str());
|
printf(" mask_image_path: %s\n", params.mask_image_path.c_str());
|
||||||
|
printf(" control_image_path: %s\n", params.control_image_path.c_str());
|
||||||
printf(" ref_images_paths:\n");
|
printf(" ref_images_paths:\n");
|
||||||
for (auto& path : params.ref_image_paths) {
|
for (auto& path : params.ref_image_paths) {
|
||||||
printf(" %s\n", path.c_str());
|
printf(" %s\n", path.c_str());
|
||||||
};
|
};
|
||||||
printf(" offload_params_to_cpu: %s\n", params.offload_params_to_cpu ? "true" : "false");
|
printf(" offload_params_to_cpu: %s\n", params.offload_params_to_cpu ? "true" : "false");
|
||||||
printf(" clip_on_cpu: %s\n", params.clip_on_cpu ? "true" : "false");
|
printf(" clip_on_cpu: %s\n", params.clip_on_cpu ? "true" : "false");
|
||||||
printf(" control_net_cpu: %s\n", params.control_net_cpu ? "true" : "false");
|
printf(" control_net_cpu: %s\n", params.control_net_cpu ? "true" : "false");
|
||||||
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false");
|
printf(" vae_on_cpu: %s\n", params.vae_on_cpu ? "true" : "false");
|
||||||
printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false");
|
printf(" diffusion flash attention: %s\n", params.diffusion_flash_attn ? "true" : "false");
|
||||||
printf(" diffusion Conv2d direct:%s\n", params.diffusion_conv_direct ? "true" : "false");
|
printf(" diffusion Conv2d direct: %s\n", params.diffusion_conv_direct ? "true" : "false");
|
||||||
printf(" vae Conv2d direct:%s\n", params.vae_conv_direct ? "true" : "false");
|
printf(" vae_conv_direct: %s\n", params.vae_conv_direct ? "true" : "false");
|
||||||
printf(" strength(control): %.2f\n", params.control_strength);
|
printf(" control_strength: %.2f\n", params.control_strength);
|
||||||
printf(" prompt: %s\n", params.prompt.c_str());
|
printf(" prompt: %s\n", params.prompt.c_str());
|
||||||
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
|
printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
|
||||||
printf(" clip_skip: %d\n", params.clip_skip);
|
printf(" clip_skip: %d\n", params.clip_skip);
|
||||||
printf(" width: %d\n", params.width);
|
printf(" width: %d\n", params.width);
|
||||||
printf(" height: %d\n", params.height);
|
printf(" height: %d\n", params.height);
|
||||||
printf(" sample_params: %s\n", SAFE_STR(sample_params_str));
|
printf(" sample_params: %s\n", SAFE_STR(sample_params_str));
|
||||||
printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str));
|
printf(" high_noise_sample_params: %s\n", SAFE_STR(high_noise_sample_params_str));
|
||||||
printf(" strength(img2img): %.2f\n", params.strength);
|
printf(" strength(img2img): %.2f\n", params.strength);
|
||||||
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
|
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
|
||||||
printf(" seed: %ld\n", params.seed);
|
printf(" seed: %ld\n", params.seed);
|
||||||
printf(" batch_count: %d\n", params.batch_count);
|
printf(" batch_count: %d\n", params.batch_count);
|
||||||
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
|
printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false");
|
||||||
printf(" upscale_repeats: %d\n", params.upscale_repeats);
|
printf(" upscale_repeats: %d\n", params.upscale_repeats);
|
||||||
printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
|
printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
|
||||||
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
|
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
|
||||||
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
|
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
|
||||||
printf(" video_frames: %d\n", params.video_frames);
|
printf(" video_frames: %d\n", params.video_frames);
|
||||||
printf(" fps: %d\n", params.fps);
|
printf(" fps: %d\n", params.fps);
|
||||||
free(sample_params_str);
|
free(sample_params_str);
|
||||||
free(high_noise_sample_params_str);
|
free(high_noise_sample_params_str);
|
||||||
}
|
}
|
||||||
@ -208,8 +210,9 @@ void print_usage(int argc, const char* argv[]) {
|
|||||||
printf(" If not specified, the default is the type of the weight file\n");
|
printf(" If not specified, the default is the type of the weight file\n");
|
||||||
printf(" --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n");
|
printf(" --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n");
|
||||||
printf(" --lora-model-dir [DIR] lora model directory\n");
|
printf(" --lora-model-dir [DIR] lora model directory\n");
|
||||||
printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n");
|
printf(" -i, --init-img [IMAGE] path to the init image, required by img2img\n");
|
||||||
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
|
printf(" --mask [MASK] path to the mask image, required by img2img with mask\n");
|
||||||
|
printf(" -i, --end-img [IMAGE] path to the end image, required by flf2v\n");
|
||||||
printf(" --control-image [IMAGE] path to image condition, control net\n");
|
printf(" --control-image [IMAGE] path to image condition, control net\n");
|
||||||
printf(" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) \n");
|
printf(" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) \n");
|
||||||
printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
|
printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n");
|
||||||
@ -449,10 +452,11 @@ void parse_args(int argc, const char** argv, SDParams& params) {
|
|||||||
{"", "--embd-dir", "", &params.embedding_dir},
|
{"", "--embd-dir", "", &params.embedding_dir},
|
||||||
{"", "--stacked-id-embd-dir", "", &params.stacked_id_embed_dir},
|
{"", "--stacked-id-embd-dir", "", &params.stacked_id_embed_dir},
|
||||||
{"", "--lora-model-dir", "", &params.lora_model_dir},
|
{"", "--lora-model-dir", "", &params.lora_model_dir},
|
||||||
{"-i", "--init-img", "", &params.input_path},
|
{"-i", "--init-img", "", &params.init_image_path},
|
||||||
|
{"", "--end-img", "", &params.end_image_path},
|
||||||
{"", "--tensor-type-rules", "", &params.tensor_type_rules},
|
{"", "--tensor-type-rules", "", &params.tensor_type_rules},
|
||||||
{"", "--input-id-images-dir", "", &params.input_id_images_path},
|
{"", "--input-id-images-dir", "", &params.input_id_images_path},
|
||||||
{"", "--mask", "", &params.mask_path},
|
{"", "--mask", "", &params.mask_image_path},
|
||||||
{"", "--control-image", "", &params.control_image_path},
|
{"", "--control-image", "", &params.control_image_path},
|
||||||
{"-o", "--output", "", &params.output_path},
|
{"-o", "--output", "", &params.output_path},
|
||||||
{"-p", "--prompt", "", &params.prompt},
|
{"-p", "--prompt", "", &params.prompt},
|
||||||
@ -902,6 +906,94 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
|
|||||||
fflush(out_stream);
|
fflush(out_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint8_t* load_image(const char* image_path, int& width, int& height, int expected_width = 0, int expected_height = 0, int expected_channel = 3) {
|
||||||
|
int c = 0;
|
||||||
|
uint8_t* image_buffer = (uint8_t*)stbi_load(image_path, &width, &height, &c, expected_channel);
|
||||||
|
if (image_buffer == NULL) {
|
||||||
|
fprintf(stderr, "load image from '%s' failed\n", image_path);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
if (c < expected_channel) {
|
||||||
|
fprintf(stderr,
|
||||||
|
"the number of channels for the input image must be >= %d,"
|
||||||
|
"but got %d channels, image_path = %s\n",
|
||||||
|
expected_channel,
|
||||||
|
c,
|
||||||
|
image_path);
|
||||||
|
free(image_buffer);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
if (width <= 0) {
|
||||||
|
fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path);
|
||||||
|
free(image_buffer);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
if (height <= 0) {
|
||||||
|
fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path);
|
||||||
|
free(image_buffer);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize input image ...
|
||||||
|
if ((expected_width > 0 && expected_height > 0) && (height != expected_height || width != expected_width)) {
|
||||||
|
float dst_aspect = (float)expected_width / (float)expected_height;
|
||||||
|
float src_aspect = (float)width / (float)height;
|
||||||
|
|
||||||
|
int crop_x = 0, crop_y = 0;
|
||||||
|
int crop_w = width, crop_h = height;
|
||||||
|
|
||||||
|
if (src_aspect > dst_aspect) {
|
||||||
|
crop_w = (int)(height * dst_aspect);
|
||||||
|
crop_x = (width - crop_w) / 2;
|
||||||
|
} else if (src_aspect < dst_aspect) {
|
||||||
|
crop_h = (int)(width / dst_aspect);
|
||||||
|
crop_y = (height - crop_h) / 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (crop_x != 0 || crop_y != 0) {
|
||||||
|
printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path);
|
||||||
|
uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel);
|
||||||
|
if (cropped_image_buffer == NULL) {
|
||||||
|
fprintf(stderr, "error: allocate memory for crop\n");
|
||||||
|
free(image_buffer);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
for (int row = 0; row < crop_h; row++) {
|
||||||
|
uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel;
|
||||||
|
uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel;
|
||||||
|
memcpy(dst, src, crop_w * expected_channel);
|
||||||
|
}
|
||||||
|
|
||||||
|
width = crop_w;
|
||||||
|
height = crop_h;
|
||||||
|
free(image_buffer);
|
||||||
|
image_buffer = cropped_image_buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height);
|
||||||
|
int resized_height = expected_height;
|
||||||
|
int resized_width = expected_width;
|
||||||
|
|
||||||
|
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel);
|
||||||
|
if (resized_image_buffer == NULL) {
|
||||||
|
fprintf(stderr, "error: allocate memory for resize input image\n");
|
||||||
|
free(image_buffer);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
stbir_resize(image_buffer, width, height, 0,
|
||||||
|
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
|
||||||
|
expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0,
|
||||||
|
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
|
||||||
|
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
|
||||||
|
STBIR_COLORSPACE_SRGB, nullptr);
|
||||||
|
|
||||||
|
// Save resized result
|
||||||
|
free(image_buffer);
|
||||||
|
image_buffer = resized_image_buffer;
|
||||||
|
}
|
||||||
|
return image_buffer;
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, const char* argv[]) {
|
int main(int argc, const char* argv[]) {
|
||||||
SDParams params;
|
SDParams params;
|
||||||
parse_args(argc, argv, params);
|
parse_args(argc, argv, params);
|
||||||
@ -935,120 +1027,101 @@ int main(int argc, const char* argv[]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool vae_decode_only = true;
|
bool vae_decode_only = true;
|
||||||
uint8_t* input_image_buffer = NULL;
|
sd_image_t init_image = {(uint32_t)params.width, (uint32_t)params.height, 3, NULL};
|
||||||
uint8_t* control_image_buffer = NULL;
|
sd_image_t end_image = {(uint32_t)params.width, (uint32_t)params.height, 3, NULL};
|
||||||
uint8_t* mask_image_buffer = NULL;
|
sd_image_t control_image = {(uint32_t)params.width, (uint32_t)params.height, 3, NULL};
|
||||||
|
sd_image_t mask_image = {(uint32_t)params.width, (uint32_t)params.height, 1, NULL};
|
||||||
std::vector<sd_image_t> ref_images;
|
std::vector<sd_image_t> ref_images;
|
||||||
|
|
||||||
if (params.input_path.size() > 0) {
|
auto release_all_resources = [&]() {
|
||||||
|
free(init_image.data);
|
||||||
|
free(end_image.data);
|
||||||
|
free(control_image.data);
|
||||||
|
free(mask_image.data);
|
||||||
|
for (auto ref_image : ref_images) {
|
||||||
|
free(ref_image.data);
|
||||||
|
ref_image.data = NULL;
|
||||||
|
}
|
||||||
|
ref_images.clear();
|
||||||
|
};
|
||||||
|
|
||||||
|
if (params.init_image_path.size() > 0) {
|
||||||
vae_decode_only = false;
|
vae_decode_only = false;
|
||||||
|
|
||||||
int c = 0;
|
int width = 0;
|
||||||
|
int height = 0;
|
||||||
|
init_image.data = load_image(params.init_image_path.c_str(), width, height, params.width, params.height);
|
||||||
|
if (init_image.data == NULL) {
|
||||||
|
fprintf(stderr, "load image from '%s' failed\n", params.init_image_path.c_str());
|
||||||
|
release_all_resources();
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.end_image_path.size() > 0) {
|
||||||
|
vae_decode_only = false;
|
||||||
|
|
||||||
|
int width = 0;
|
||||||
|
int height = 0;
|
||||||
|
end_image.data = load_image(params.end_image_path.c_str(), width, height, params.width, params.height);
|
||||||
|
if (end_image.data == NULL) {
|
||||||
|
fprintf(stderr, "load image from '%s' failed\n", params.end_image_path.c_str());
|
||||||
|
release_all_resources();
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.mask_image_path.size() > 0) {
|
||||||
|
int c = 0;
|
||||||
|
int width = 0;
|
||||||
|
int height = 0;
|
||||||
|
mask_image.data = load_image(params.mask_image_path.c_str(), width, height, params.width, params.height, 1);
|
||||||
|
if (mask_image.data == NULL) {
|
||||||
|
fprintf(stderr, "load image from '%s' failed\n", params.mask_image_path.c_str());
|
||||||
|
release_all_resources();
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mask_image.data = (uint8_t*)malloc(params.width * params.height);
|
||||||
|
memset(mask_image.data, 255, params.width * params.height);
|
||||||
|
if (mask_image.data == NULL) {
|
||||||
|
fprintf(stderr, "malloc mask image failed\n");
|
||||||
|
release_all_resources();
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
|
||||||
int width = 0;
|
int width = 0;
|
||||||
int height = 0;
|
int height = 0;
|
||||||
input_image_buffer = stbi_load(params.input_path.c_str(), &width, &height, &c, 3);
|
control_image.data = load_image(params.control_image_path.c_str(), width, height, params.width, params.height);
|
||||||
if (input_image_buffer == NULL) {
|
if (control_image.data == NULL) {
|
||||||
fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str());
|
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
|
||||||
|
release_all_resources();
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
if (c < 3) {
|
if (params.canny_preprocess) { // apply preprocessor
|
||||||
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
|
control_image.data = preprocess_canny(control_image.data,
|
||||||
free(input_image_buffer);
|
control_image.width,
|
||||||
return 1;
|
control_image.height,
|
||||||
}
|
0.08f,
|
||||||
if (width <= 0) {
|
0.08f,
|
||||||
fprintf(stderr, "error: the width of image must be greater than 0\n");
|
0.8f,
|
||||||
free(input_image_buffer);
|
1.0f,
|
||||||
return 1;
|
false);
|
||||||
}
|
|
||||||
if (height <= 0) {
|
|
||||||
fprintf(stderr, "error: the height of image must be greater than 0\n");
|
|
||||||
free(input_image_buffer);
|
|
||||||
return 1;
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Resize input image ...
|
if (params.ref_image_paths.size() > 0) {
|
||||||
if (params.height != height || params.width != width) {
|
|
||||||
float dst_aspect = (float)params.width / (float)params.height;
|
|
||||||
float src_aspect = (float)width / (float)height;
|
|
||||||
|
|
||||||
int crop_x = 0, crop_y = 0;
|
|
||||||
int crop_w = width, crop_h = height;
|
|
||||||
|
|
||||||
if (src_aspect > dst_aspect) {
|
|
||||||
crop_w = (int)(height * dst_aspect);
|
|
||||||
crop_x = (width - crop_w) / 2;
|
|
||||||
} else if (src_aspect < dst_aspect) {
|
|
||||||
crop_h = (int)(width / dst_aspect);
|
|
||||||
crop_y = (height - crop_h) / 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (crop_x != 0 || crop_y != 0) {
|
|
||||||
printf("crop input image from %dx%d to %dx%d\n", width, height, crop_w, crop_h);
|
|
||||||
uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * 3);
|
|
||||||
if (cropped_image_buffer == NULL) {
|
|
||||||
fprintf(stderr, "error: allocate memory for crop\n");
|
|
||||||
free(input_image_buffer);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
for (int row = 0; row < crop_h; row++) {
|
|
||||||
uint8_t* src = input_image_buffer + ((crop_y + row) * width + crop_x) * 3;
|
|
||||||
uint8_t* dst = cropped_image_buffer + (row * crop_w) * 3;
|
|
||||||
memcpy(dst, src, crop_w * 3);
|
|
||||||
}
|
|
||||||
|
|
||||||
width = crop_w;
|
|
||||||
height = crop_h;
|
|
||||||
free(input_image_buffer);
|
|
||||||
input_image_buffer = cropped_image_buffer;
|
|
||||||
}
|
|
||||||
|
|
||||||
printf("resize input image from %dx%d to %dx%d\n", width, height, params.width, params.height);
|
|
||||||
int resized_height = params.height;
|
|
||||||
int resized_width = params.width;
|
|
||||||
|
|
||||||
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3);
|
|
||||||
if (resized_image_buffer == NULL) {
|
|
||||||
fprintf(stderr, "error: allocate memory for resize input image\n");
|
|
||||||
free(input_image_buffer);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
stbir_resize(input_image_buffer, width, height, 0,
|
|
||||||
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
|
|
||||||
3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
|
|
||||||
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
|
|
||||||
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
|
|
||||||
STBIR_COLORSPACE_SRGB, nullptr);
|
|
||||||
|
|
||||||
// Save resized result
|
|
||||||
free(input_image_buffer);
|
|
||||||
input_image_buffer = resized_image_buffer;
|
|
||||||
}
|
|
||||||
} else if (params.ref_image_paths.size() > 0) {
|
|
||||||
vae_decode_only = false;
|
vae_decode_only = false;
|
||||||
for (auto& path : params.ref_image_paths) {
|
for (auto& path : params.ref_image_paths) {
|
||||||
int c = 0;
|
|
||||||
int width = 0;
|
int width = 0;
|
||||||
int height = 0;
|
int height = 0;
|
||||||
uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3);
|
uint8_t* image_buffer = load_image(path.c_str(), width, height);
|
||||||
if (image_buffer == NULL) {
|
if (image_buffer == NULL) {
|
||||||
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
|
fprintf(stderr, "load image from '%s' failed\n", path.c_str());
|
||||||
return 1;
|
release_all_resources();
|
||||||
}
|
|
||||||
if (c < 3) {
|
|
||||||
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
|
|
||||||
free(image_buffer);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (width <= 0) {
|
|
||||||
fprintf(stderr, "error: the width of image must be greater than 0\n");
|
|
||||||
free(image_buffer);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
if (height <= 0) {
|
|
||||||
fprintf(stderr, "error: the height of image must be greater than 0\n");
|
|
||||||
free(image_buffer);
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
ref_images.push_back({(uint32_t)width,
|
ref_images.push_back({(uint32_t)width,
|
||||||
@ -1098,50 +1171,10 @@ int main(int argc, const char* argv[]) {
|
|||||||
|
|
||||||
if (sd_ctx == NULL) {
|
if (sd_ctx == NULL) {
|
||||||
printf("new_sd_ctx_t failed\n");
|
printf("new_sd_ctx_t failed\n");
|
||||||
|
release_all_resources();
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
sd_image_t input_image = {(uint32_t)params.width,
|
|
||||||
(uint32_t)params.height,
|
|
||||||
3,
|
|
||||||
input_image_buffer};
|
|
||||||
|
|
||||||
sd_image_t* control_image = NULL;
|
|
||||||
if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
|
|
||||||
int c = 0;
|
|
||||||
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
|
|
||||||
if (control_image_buffer == NULL) {
|
|
||||||
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
control_image = new sd_image_t{(uint32_t)params.width,
|
|
||||||
(uint32_t)params.height,
|
|
||||||
3,
|
|
||||||
control_image_buffer};
|
|
||||||
if (params.canny_preprocess) { // apply preprocessor
|
|
||||||
control_image->data = preprocess_canny(control_image->data,
|
|
||||||
control_image->width,
|
|
||||||
control_image->height,
|
|
||||||
0.08f,
|
|
||||||
0.08f,
|
|
||||||
0.8f,
|
|
||||||
1.0f,
|
|
||||||
false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<uint8_t> default_mask_image_vec(params.width * params.height, 255);
|
|
||||||
if (params.mask_path != "") {
|
|
||||||
int c = 0;
|
|
||||||
mask_image_buffer = stbi_load(params.mask_path.c_str(), &params.width, &params.height, &c, 1);
|
|
||||||
} else {
|
|
||||||
mask_image_buffer = default_mask_image_vec.data();
|
|
||||||
}
|
|
||||||
sd_image_t mask_image = {(uint32_t)params.width,
|
|
||||||
(uint32_t)params.height,
|
|
||||||
1,
|
|
||||||
mask_image_buffer};
|
|
||||||
|
|
||||||
sd_image_t* results;
|
sd_image_t* results;
|
||||||
int num_results = 1;
|
int num_results = 1;
|
||||||
if (params.mode == IMG_GEN) {
|
if (params.mode == IMG_GEN) {
|
||||||
@ -1149,7 +1182,7 @@ int main(int argc, const char* argv[]) {
|
|||||||
params.prompt.c_str(),
|
params.prompt.c_str(),
|
||||||
params.negative_prompt.c_str(),
|
params.negative_prompt.c_str(),
|
||||||
params.clip_skip,
|
params.clip_skip,
|
||||||
input_image,
|
init_image,
|
||||||
ref_images.data(),
|
ref_images.data(),
|
||||||
(int)ref_images.size(),
|
(int)ref_images.size(),
|
||||||
mask_image,
|
mask_image,
|
||||||
@ -1173,7 +1206,8 @@ int main(int argc, const char* argv[]) {
|
|||||||
params.prompt.c_str(),
|
params.prompt.c_str(),
|
||||||
params.negative_prompt.c_str(),
|
params.negative_prompt.c_str(),
|
||||||
params.clip_skip,
|
params.clip_skip,
|
||||||
input_image,
|
init_image,
|
||||||
|
end_image,
|
||||||
params.width,
|
params.width,
|
||||||
params.height,
|
params.height,
|
||||||
params.sample_params,
|
params.sample_params,
|
||||||
@ -1275,8 +1309,8 @@ int main(int argc, const char* argv[]) {
|
|||||||
}
|
}
|
||||||
free(results);
|
free(results);
|
||||||
free_sd_ctx(sd_ctx);
|
free_sd_ctx(sd_ctx);
|
||||||
free(control_image_buffer);
|
|
||||||
free(input_image_buffer);
|
release_all_resources();
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -401,7 +401,7 @@ public:
|
|||||||
version,
|
version,
|
||||||
sd_ctx_params->diffusion_flash_attn);
|
sd_ctx_params->diffusion_flash_attn);
|
||||||
}
|
}
|
||||||
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
|
if (diffusion_model->get_desc() == "Wan2.1-I2V-14B" || diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
||||||
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
|
clip_vision = std::make_shared<FrozenCLIPVisionEmbedder>(backend,
|
||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
model_loader.tensor_storages_types);
|
model_loader.tensor_storages_types);
|
||||||
@ -1780,7 +1780,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
const std::vector<float>& sigmas,
|
const std::vector<float>& sigmas,
|
||||||
int64_t seed,
|
int64_t seed,
|
||||||
int batch_count,
|
int batch_count,
|
||||||
const sd_image_t* control_cond,
|
sd_image_t control_image,
|
||||||
float control_strength,
|
float control_strength,
|
||||||
float style_ratio,
|
float style_ratio,
|
||||||
bool normalize_input,
|
bool normalize_input,
|
||||||
@ -1947,9 +1947,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
|
|
||||||
// Control net hint
|
// Control net hint
|
||||||
struct ggml_tensor* image_hint = NULL;
|
struct ggml_tensor* image_hint = NULL;
|
||||||
if (control_cond != NULL) {
|
if (control_image.data != NULL) {
|
||||||
image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
|
image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
|
||||||
sd_image_to_tensor(control_cond->data, image_hint);
|
sd_image_to_tensor(control_image.data, image_hint);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sample
|
// Sample
|
||||||
@ -2342,7 +2342,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
sigmas,
|
sigmas,
|
||||||
seed,
|
seed,
|
||||||
sd_img_gen_params->batch_count,
|
sd_img_gen_params->batch_count,
|
||||||
sd_img_gen_params->control_cond,
|
sd_img_gen_params->control_image,
|
||||||
sd_img_gen_params->control_strength,
|
sd_img_gen_params->control_strength,
|
||||||
sd_img_gen_params->style_strength,
|
sd_img_gen_params->style_strength,
|
||||||
sd_img_gen_params->normalize_input,
|
sd_img_gen_params->normalize_input,
|
||||||
@ -2413,38 +2413,53 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
ggml_tensor* concat_latent = NULL;
|
ggml_tensor* concat_latent = NULL;
|
||||||
ggml_tensor* denoise_mask = NULL;
|
ggml_tensor* denoise_mask = NULL;
|
||||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
|
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
|
||||||
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B") {
|
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-I2V-14B" ||
|
||||||
|
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
||||||
LOG_INFO("IMG2VID");
|
LOG_INFO("IMG2VID");
|
||||||
|
|
||||||
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B") {
|
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-I2V-14B" ||
|
||||||
|
sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
||||||
if (sd_vid_gen_params->init_image.data) {
|
if (sd_vid_gen_params->init_image.data) {
|
||||||
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);
|
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2);
|
||||||
} else {
|
} else {
|
||||||
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2, true);
|
clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->init_image, false, -2, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.1-FLF2V-14B") {
|
||||||
|
ggml_tensor* end_image_clip_vision_output = NULL;
|
||||||
|
if (sd_vid_gen_params->end_image.data) {
|
||||||
|
end_image_clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->end_image, false, -2);
|
||||||
|
} else {
|
||||||
|
end_image_clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, sd_vid_gen_params->end_image, false, -2, true);
|
||||||
|
}
|
||||||
|
clip_vision_output = ggml_tensor_concat(work_ctx, clip_vision_output, end_image_clip_vision_output, 1);
|
||||||
|
}
|
||||||
|
|
||||||
int64_t t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
|
LOG_INFO("get_clip_vision_output completed, taking %" PRId64 " ms", t1 - t0);
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t t1 = ggml_time_ms();
|
int64_t t1 = ggml_time_ms();
|
||||||
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
|
ggml_tensor* image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, frames, 3);
|
||||||
for (int i3 = 0; i3 < init_img->ne[3]; i3++) { // channels
|
for (int i3 = 0; i3 < image->ne[3]; i3++) { // channels
|
||||||
for (int i2 = 0; i2 < init_img->ne[2]; i2++) {
|
for (int i2 = 0; i2 < image->ne[2]; i2++) {
|
||||||
for (int i1 = 0; i1 < init_img->ne[1]; i1++) { // height
|
for (int i1 = 0; i1 < image->ne[1]; i1++) { // height
|
||||||
for (int i0 = 0; i0 < init_img->ne[0]; i0++) { // width
|
for (int i0 = 0; i0 < image->ne[0]; i0++) { // width
|
||||||
float value = 0.5f;
|
float value = 0.5f;
|
||||||
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
|
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
|
||||||
value = *(sd_vid_gen_params->init_image.data + i1 * width * 3 + i0 * 3 + i3);
|
value = *(sd_vid_gen_params->init_image.data + i1 * width * 3 + i0 * 3 + i3);
|
||||||
value /= 255.f;
|
value /= 255.f;
|
||||||
|
} else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data) {
|
||||||
|
value = *(sd_vid_gen_params->end_image.data + i1 * width * 3 + i0 * 3 + i3);
|
||||||
|
value /= 255.f;
|
||||||
}
|
}
|
||||||
ggml_tensor_set_f32(init_img, value, i0, i1, i2, i3);
|
ggml_tensor_set_f32(image, value, i0, i1, i2, i3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); // [b*c, t, h/8, w/8]
|
concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image); // [b*c, t, h/8, w/8]
|
||||||
|
|
||||||
int64_t t2 = ggml_time_ms();
|
int64_t t2 = ggml_time_ms();
|
||||||
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
||||||
@ -2464,6 +2479,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
float value = 0.0f;
|
float value = 0.0f;
|
||||||
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
|
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
|
||||||
value = 1.0f;
|
value = 1.0f;
|
||||||
|
} else if (i2 == frames - 1 && sd_vid_gen_params->end_image.data && i3 == 3) {
|
||||||
|
value = 1.0f;
|
||||||
}
|
}
|
||||||
ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3);
|
ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -188,7 +188,7 @@ typedef struct {
|
|||||||
float strength;
|
float strength;
|
||||||
int64_t seed;
|
int64_t seed;
|
||||||
int batch_count;
|
int batch_count;
|
||||||
const sd_image_t* control_cond;
|
sd_image_t control_image;
|
||||||
float control_strength;
|
float control_strength;
|
||||||
float style_strength;
|
float style_strength;
|
||||||
bool normalize_input;
|
bool normalize_input;
|
||||||
@ -200,6 +200,7 @@ typedef struct {
|
|||||||
const char* negative_prompt;
|
const char* negative_prompt;
|
||||||
int clip_skip;
|
int clip_skip;
|
||||||
sd_image_t init_image;
|
sd_image_t init_image;
|
||||||
|
sd_image_t end_image;
|
||||||
int width;
|
int width;
|
||||||
int height;
|
int height;
|
||||||
sd_sample_params_t sample_params;
|
sd_sample_params_t sample_params;
|
||||||
|
|||||||
9
wan.hpp
9
wan.hpp
@ -1934,6 +1934,9 @@ namespace WAN {
|
|||||||
if (tensor_name.find("img_emb") != std::string::npos) {
|
if (tensor_name.find("img_emb") != std::string::npos) {
|
||||||
wan_params.model_type = "i2v";
|
wan_params.model_type = "i2v";
|
||||||
}
|
}
|
||||||
|
if (tensor_name.find("img_emb.emb_pos") != std::string::npos) {
|
||||||
|
wan_params.flf_pos_embed_token_number = 514;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (wan_params.num_layers == 30) {
|
if (wan_params.num_layers == 30) {
|
||||||
@ -1968,8 +1971,12 @@ namespace WAN {
|
|||||||
wan_params.in_dim = 16;
|
wan_params.in_dim = 16;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
desc = "Wan2.1-I2V-14B";
|
|
||||||
wan_params.in_dim = 36;
|
wan_params.in_dim = 36;
|
||||||
|
if (wan_params.flf_pos_embed_token_number > 0) {
|
||||||
|
desc = "Wan2.1-FLF2V-14B";
|
||||||
|
} else {
|
||||||
|
desc = "Wan2.1-I2V-14B";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
wan_params.dim = 5120;
|
wan_params.dim = 5120;
|
||||||
wan_params.eps = 1e-06;
|
wan_params.eps = 1e-06;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user