unify image loading processing

This commit is contained in:
leejet 2025-08-31 18:40:04 +08:00
parent 50f921119e
commit 33ff442c1d
3 changed files with 232 additions and 198 deletions

View File

@ -69,7 +69,8 @@ struct SDParams {
std::string tensor_type_rules; std::string tensor_type_rules;
std::string lora_model_dir; std::string lora_model_dir;
std::string output_path = "output.png"; std::string output_path = "output.png";
std::string input_path; std::string init_image_path;
std::string end_image_path;
std::string mask_path; std::string mask_path;
std::string control_image_path; std::string control_image_path;
std::vector<std::string> ref_image_paths; std::vector<std::string> ref_image_paths;
@ -143,9 +144,10 @@ void print_params(SDParams params) {
printf(" style ratio: %.2f\n", params.style_ratio); printf(" style ratio: %.2f\n", params.style_ratio);
printf(" normalize input image: %s\n", params.normalize_input ? "true" : "false"); printf(" normalize input image: %s\n", params.normalize_input ? "true" : "false");
printf(" output_path: %s\n", params.output_path.c_str()); printf(" output_path: %s\n", params.output_path.c_str());
printf(" init_img: %s\n", params.input_path.c_str()); printf(" init_image_path: %s\n", params.init_image_path.c_str());
printf(" mask_img: %s\n", params.mask_path.c_str()); printf(" end_image_path: %s\n", params.end_image_path.c_str());
printf(" control_image: %s\n", params.control_image_path.c_str()); printf(" mask_image_path: %s\n", params.mask_path.c_str());
printf(" control_image_path: %s\n", params.control_image_path.c_str());
printf(" ref_images_paths:\n"); printf(" ref_images_paths:\n");
for (auto& path : params.ref_image_paths) { for (auto& path : params.ref_image_paths) {
printf(" %s\n", path.c_str()); printf(" %s\n", path.c_str());
@ -153,11 +155,11 @@ void print_params(SDParams params) {
printf(" offload_params_to_cpu: %s\n", params.offload_params_to_cpu ? "true" : "false"); printf(" offload_params_to_cpu: %s\n", params.offload_params_to_cpu ? "true" : "false");
printf(" clip_on_cpu: %s\n", params.clip_on_cpu ? "true" : "false"); printf(" clip_on_cpu: %s\n", params.clip_on_cpu ? "true" : "false");
printf(" control_net_cpu: %s\n", params.control_net_cpu ? "true" : "false"); printf(" control_net_cpu: %s\n", params.control_net_cpu ? "true" : "false");
printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false"); printf(" vae_on_cpu: %s\n", params.vae_on_cpu ? "true" : "false");
printf(" diffusion flash attention: %s\n", params.diffusion_flash_attn ? "true" : "false"); printf(" diffusion flash attention: %s\n", params.diffusion_flash_attn ? "true" : "false");
printf(" diffusion Conv2d direct: %s\n", params.diffusion_conv_direct ? "true" : "false"); printf(" diffusion Conv2d direct: %s\n", params.diffusion_conv_direct ? "true" : "false");
printf(" vae Conv2d direct:%s\n", params.vae_conv_direct ? "true" : "false"); printf(" vae_conv_direct: %s\n", params.vae_conv_direct ? "true" : "false");
printf(" strength(control): %.2f\n", params.control_strength); printf(" control_strength: %.2f\n", params.control_strength);
printf(" prompt: %s\n", params.prompt.c_str()); printf(" prompt: %s\n", params.prompt.c_str());
printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" negative_prompt: %s\n", params.negative_prompt.c_str());
printf(" clip_skip: %d\n", params.clip_skip); printf(" clip_skip: %d\n", params.clip_skip);
@ -449,7 +451,8 @@ void parse_args(int argc, const char** argv, SDParams& params) {
{"", "--embd-dir", "", &params.embedding_dir}, {"", "--embd-dir", "", &params.embedding_dir},
{"", "--stacked-id-embd-dir", "", &params.stacked_id_embed_dir}, {"", "--stacked-id-embd-dir", "", &params.stacked_id_embed_dir},
{"", "--lora-model-dir", "", &params.lora_model_dir}, {"", "--lora-model-dir", "", &params.lora_model_dir},
{"-i", "--init-img", "", &params.input_path}, {"-i", "--init-img", "", &params.init_image_path},
{"", "--end-img", "", &params.end_image_path},
{"", "--tensor-type-rules", "", &params.tensor_type_rules}, {"", "--tensor-type-rules", "", &params.tensor_type_rules},
{"", "--input-id-images-dir", "", &params.input_id_images_path}, {"", "--input-id-images-dir", "", &params.input_id_images_path},
{"", "--mask", "", &params.mask_path}, {"", "--mask", "", &params.mask_path},
@ -902,6 +905,94 @@ void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) {
fflush(out_stream); fflush(out_stream);
} }
uint8_t* load_image(const char* image_path, int& width, int& height, int expected_width = 0, int expected_height = 0, int expected_channel = 3) {
int c = 0;
uint8_t* image_buffer = (uint8_t*)stbi_load(image_path, &width, &height, &c, expected_channel);
if (image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", image_path);
return NULL;
}
if (c < expected_channel) {
fprintf(stderr,
"the number of channels for the input image must be >= %d,"
"but got %d channels, image_path = %s\n",
expected_channel,
c,
image_path);
free(image_buffer);
return NULL;
}
if (width <= 0) {
fprintf(stderr, "error: the width of image must be greater than 0, image_path = %s\n", image_path);
free(image_buffer);
return NULL;
}
if (height <= 0) {
fprintf(stderr, "error: the height of image must be greater than 0, image_path = %s\n", image_path);
free(image_buffer);
return NULL;
}
// Resize input image ...
if ((expected_width > 0 && expected_height > 0) && (height != expected_height || width != expected_width)) {
float dst_aspect = (float)expected_width / (float)expected_height;
float src_aspect = (float)width / (float)height;
int crop_x = 0, crop_y = 0;
int crop_w = width, crop_h = height;
if (src_aspect > dst_aspect) {
crop_w = (int)(height * dst_aspect);
crop_x = (width - crop_w) / 2;
} else if (src_aspect < dst_aspect) {
crop_h = (int)(width / dst_aspect);
crop_y = (height - crop_h) / 2;
}
if (crop_x != 0 || crop_y != 0) {
printf("crop input image from %dx%d to %dx%d, image_path = %s\n", width, height, crop_w, crop_h, image_path);
uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * expected_channel);
if (cropped_image_buffer == NULL) {
fprintf(stderr, "error: allocate memory for crop\n");
free(image_buffer);
return NULL;
}
for (int row = 0; row < crop_h; row++) {
uint8_t* src = image_buffer + ((crop_y + row) * width + crop_x) * expected_channel;
uint8_t* dst = cropped_image_buffer + (row * crop_w) * expected_channel;
memcpy(dst, src, crop_w * expected_channel);
}
width = crop_w;
height = crop_h;
free(image_buffer);
image_buffer = cropped_image_buffer;
}
printf("resize input image from %dx%d to %dx%d\n", width, height, expected_width, expected_height);
int resized_height = expected_height;
int resized_width = expected_width;
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * expected_channel);
if (resized_image_buffer == NULL) {
fprintf(stderr, "error: allocate memory for resize input image\n");
free(image_buffer);
return NULL;
}
stbir_resize(image_buffer, width, height, 0,
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
expected_channel, STBIR_ALPHA_CHANNEL_NONE, 0,
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
STBIR_COLORSPACE_SRGB, nullptr);
// Save resized result
free(image_buffer);
image_buffer = resized_image_buffer;
}
return image_buffer;
}
int main(int argc, const char* argv[]) { int main(int argc, const char* argv[]) {
SDParams params; SDParams params;
parse_args(argc, argv, params); parse_args(argc, argv, params);
@ -936,119 +1027,100 @@ int main(int argc, const char* argv[]) {
} }
bool vae_decode_only = true; bool vae_decode_only = true;
uint8_t* input_image_buffer = NULL; sd_image_t init_image = {(uint32_t)params.width, (uint32_t)params.height, 3, NULL};
uint8_t* control_image_buffer = NULL; sd_image_t end_image = {(uint32_t)params.width, (uint32_t)params.height, 3, NULL};
uint8_t* mask_image_buffer = NULL; sd_image_t control_image = {(uint32_t)params.width, (uint32_t)params.height, 3, NULL};
sd_image_t mask_image = {(uint32_t)params.width, (uint32_t)params.height, 1, NULL};
std::vector<sd_image_t> ref_images; std::vector<sd_image_t> ref_images;
if (params.input_path.size() > 0) { auto release_all_resources = [&]() {
free(init_image.data);
free(end_image.data);
free(control_image.data);
free(mask_image.data);
for (auto ref_image : ref_images) {
free(ref_image.data);
ref_image.data = NULL;
}
ref_images.clear();
};
if (params.init_image_path.size() > 0) {
vae_decode_only = false; vae_decode_only = false;
int width = 0;
int height = 0;
init_image.data = load_image(params.init_image_path.c_str(), width, height, params.width, params.height);
if (init_image.data == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.init_image_path.c_str());
release_all_resources();
return 1;
}
}
if (params.end_image_path.size() > 0) {
vae_decode_only = false;
int width = 0;
int height = 0;
end_image.data = load_image(params.end_image_path.c_str(), width, height, params.width, params.height);
if (end_image.data == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.end_image_path.c_str());
release_all_resources();
return 1;
}
}
if (params.mask_path.size() > 0) {
int c = 0; int c = 0;
int width = 0; int width = 0;
int height = 0; int height = 0;
input_image_buffer = stbi_load(params.input_path.c_str(), &width, &height, &c, 3); mask_image.data = load_image(params.mask_path.c_str(), width, height, params.width, params.height, 1);
if (input_image_buffer == NULL) { if (mask_image.data == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.input_path.c_str()); fprintf(stderr, "load image from '%s' failed\n", params.mask_path.c_str());
release_all_resources();
return 1; return 1;
} }
if (c < 3) { } else {
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c); mask_image.data = (uint8_t*)malloc(params.width * params.height);
free(input_image_buffer); memset(mask_image.data, 255, params.width * params.height);
if (mask_image.data == NULL) {
fprintf(stderr, "malloc mask image failed\n");
release_all_resources();
return 1; return 1;
} }
if (width <= 0) { }
fprintf(stderr, "error: the width of image must be greater than 0\n");
free(input_image_buffer); if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
int width = 0;
int height = 0;
control_image.data = load_image(params.control_image_path.c_str(), width, height, params.width, params.height);
if (control_image.data == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
release_all_resources();
return 1; return 1;
} }
if (height <= 0) { if (params.canny_preprocess) { // apply preprocessor
fprintf(stderr, "error: the height of image must be greater than 0\n"); control_image.data = preprocess_canny(control_image.data,
free(input_image_buffer); control_image.width,
return 1; control_image.height,
0.08f,
0.08f,
0.8f,
1.0f,
false);
}
} }
// Resize input image ... if (params.ref_image_paths.size() > 0) {
if (params.height != height || params.width != width) {
float dst_aspect = (float)params.width / (float)params.height;
float src_aspect = (float)width / (float)height;
int crop_x = 0, crop_y = 0;
int crop_w = width, crop_h = height;
if (src_aspect > dst_aspect) {
crop_w = (int)(height * dst_aspect);
crop_x = (width - crop_w) / 2;
} else if (src_aspect < dst_aspect) {
crop_h = (int)(width / dst_aspect);
crop_y = (height - crop_h) / 2;
}
if (crop_x != 0 || crop_y != 0) {
printf("crop input image from %dx%d to %dx%d\n", width, height, crop_w, crop_h);
uint8_t* cropped_image_buffer = (uint8_t*)malloc(crop_w * crop_h * 3);
if (cropped_image_buffer == NULL) {
fprintf(stderr, "error: allocate memory for crop\n");
free(input_image_buffer);
return 1;
}
for (int row = 0; row < crop_h; row++) {
uint8_t* src = input_image_buffer + ((crop_y + row) * width + crop_x) * 3;
uint8_t* dst = cropped_image_buffer + (row * crop_w) * 3;
memcpy(dst, src, crop_w * 3);
}
width = crop_w;
height = crop_h;
free(input_image_buffer);
input_image_buffer = cropped_image_buffer;
}
printf("resize input image from %dx%d to %dx%d\n", width, height, params.width, params.height);
int resized_height = params.height;
int resized_width = params.width;
uint8_t* resized_image_buffer = (uint8_t*)malloc(resized_height * resized_width * 3);
if (resized_image_buffer == NULL) {
fprintf(stderr, "error: allocate memory for resize input image\n");
free(input_image_buffer);
return 1;
}
stbir_resize(input_image_buffer, width, height, 0,
resized_image_buffer, resized_width, resized_height, 0, STBIR_TYPE_UINT8,
3 /*RGB channel*/, STBIR_ALPHA_CHANNEL_NONE, 0,
STBIR_EDGE_CLAMP, STBIR_EDGE_CLAMP,
STBIR_FILTER_BOX, STBIR_FILTER_BOX,
STBIR_COLORSPACE_SRGB, nullptr);
// Save resized result
free(input_image_buffer);
input_image_buffer = resized_image_buffer;
}
} else if (params.ref_image_paths.size() > 0) {
vae_decode_only = false; vae_decode_only = false;
for (auto& path : params.ref_image_paths) { for (auto& path : params.ref_image_paths) {
int c = 0;
int width = 0; int width = 0;
int height = 0; int height = 0;
uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3); uint8_t* image_buffer = load_image(path.c_str(), width, height);
if (image_buffer == NULL) { if (image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", path.c_str()); fprintf(stderr, "load image from '%s' failed\n", path.c_str());
return 1; release_all_resources();
}
if (c < 3) {
fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c);
free(image_buffer);
return 1;
}
if (width <= 0) {
fprintf(stderr, "error: the width of image must be greater than 0\n");
free(image_buffer);
return 1;
}
if (height <= 0) {
fprintf(stderr, "error: the height of image must be greater than 0\n");
free(image_buffer);
return 1; return 1;
} }
ref_images.push_back({(uint32_t)width, ref_images.push_back({(uint32_t)width,
@ -1098,50 +1170,10 @@ int main(int argc, const char* argv[]) {
if (sd_ctx == NULL) { if (sd_ctx == NULL) {
printf("new_sd_ctx_t failed\n"); printf("new_sd_ctx_t failed\n");
release_all_resources();
return 1; return 1;
} }
sd_image_t input_image = {(uint32_t)params.width,
(uint32_t)params.height,
3,
input_image_buffer};
sd_image_t* control_image = NULL;
if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
int c = 0;
control_image_buffer = stbi_load(params.control_image_path.c_str(), &params.width, &params.height, &c, 3);
if (control_image_buffer == NULL) {
fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str());
return 1;
}
control_image = new sd_image_t{(uint32_t)params.width,
(uint32_t)params.height,
3,
control_image_buffer};
if (params.canny_preprocess) { // apply preprocessor
control_image->data = preprocess_canny(control_image->data,
control_image->width,
control_image->height,
0.08f,
0.08f,
0.8f,
1.0f,
false);
}
}
std::vector<uint8_t> default_mask_image_vec(params.width * params.height, 255);
if (params.mask_path != "") {
int c = 0;
mask_image_buffer = stbi_load(params.mask_path.c_str(), &params.width, &params.height, &c, 1);
} else {
mask_image_buffer = default_mask_image_vec.data();
}
sd_image_t mask_image = {(uint32_t)params.width,
(uint32_t)params.height,
1,
mask_image_buffer};
sd_image_t* results; sd_image_t* results;
int num_results = 1; int num_results = 1;
if (params.mode == IMG_GEN) { if (params.mode == IMG_GEN) {
@ -1149,7 +1181,7 @@ int main(int argc, const char* argv[]) {
params.prompt.c_str(), params.prompt.c_str(),
params.negative_prompt.c_str(), params.negative_prompt.c_str(),
params.clip_skip, params.clip_skip,
input_image, init_image,
ref_images.data(), ref_images.data(),
(int)ref_images.size(), (int)ref_images.size(),
mask_image, mask_image,
@ -1173,7 +1205,8 @@ int main(int argc, const char* argv[]) {
params.prompt.c_str(), params.prompt.c_str(),
params.negative_prompt.c_str(), params.negative_prompt.c_str(),
params.clip_skip, params.clip_skip,
input_image, init_image,
end_image,
params.width, params.width,
params.height, params.height,
params.sample_params, params.sample_params,
@ -1275,8 +1308,8 @@ int main(int argc, const char* argv[]) {
} }
free(results); free(results);
free_sd_ctx(sd_ctx); free_sd_ctx(sd_ctx);
free(control_image_buffer);
free(input_image_buffer); release_all_resources();
return 0; return 0;
} }

View File

@ -1780,7 +1780,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
const std::vector<float>& sigmas, const std::vector<float>& sigmas,
int64_t seed, int64_t seed,
int batch_count, int batch_count,
const sd_image_t* control_cond, sd_image_t control_image,
float control_strength, float control_strength,
float style_ratio, float style_ratio,
bool normalize_input, bool normalize_input,
@ -1947,9 +1947,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
// Control net hint // Control net hint
struct ggml_tensor* image_hint = NULL; struct ggml_tensor* image_hint = NULL;
if (control_cond != NULL) { if (control_image.data != NULL) {
image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_image_to_tensor(control_cond->data, image_hint); sd_image_to_tensor(control_image.data, image_hint);
} }
// Sample // Sample
@ -2342,7 +2342,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sigmas, sigmas,
seed, seed,
sd_img_gen_params->batch_count, sd_img_gen_params->batch_count,
sd_img_gen_params->control_cond, sd_img_gen_params->control_image,
sd_img_gen_params->control_strength, sd_img_gen_params->control_strength,
sd_img_gen_params->style_strength, sd_img_gen_params->style_strength,
sd_img_gen_params->normalize_input, sd_img_gen_params->normalize_input,

View File

@ -188,7 +188,7 @@ typedef struct {
float strength; float strength;
int64_t seed; int64_t seed;
int batch_count; int batch_count;
const sd_image_t* control_cond; sd_image_t control_image;
float control_strength; float control_strength;
float style_strength; float style_strength;
bool normalize_input; bool normalize_input;
@ -200,6 +200,7 @@ typedef struct {
const char* negative_prompt; const char* negative_prompt;
int clip_skip; int clip_skip;
sd_image_t init_image; sd_image_t init_image;
sd_image_t end_image;
int width; int width;
int height; int height;
sd_sample_params_t sample_params; sd_sample_params_t sample_params;