feat: use image width and height when not explicitly set (#1206)

This commit is contained in:
leejet 2026-01-22 23:54:41 +08:00 committed by GitHub
parent 329571131d
commit 5e4579c11d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 92 additions and 65 deletions

View File

@ -245,7 +245,7 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", "; parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", "; parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
parameter_string += "Seed: " + std::to_string(seed) + ", "; parameter_string += "Seed: " + std::to_string(seed) + ", ";
parameter_string += "Size: " + std::to_string(gen_params.width) + "x" + std::to_string(gen_params.height) + ", "; parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", ";
parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", "; parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", "; parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) { if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
@ -526,10 +526,10 @@ int main(int argc, const char* argv[]) {
} }
bool vae_decode_only = true; bool vae_decode_only = true;
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t end_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t end_image = {0, 0, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr}; sd_image_t control_image = {0, 0, 3, nullptr};
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr}; sd_image_t mask_image = {0, 0, 1, nullptr};
std::vector<sd_image_t> ref_images; std::vector<sd_image_t> ref_images;
std::vector<sd_image_t> pmid_images; std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> control_frames; std::vector<sd_image_t> control_frames;
@ -556,57 +556,79 @@ int main(int argc, const char* argv[]) {
control_frames.clear(); control_frames.clear();
}; };
// Loads the image at `path` into `image` and, on success, lets gen_params adopt
// the loaded dimensions when no complete width/height pair is configured yet.
//   path             - file to load.
//   image            - destination; filled in by load_sd_image_from_file.
//   resize_image     - when true and both width and height are already set,
//                      they are forwarded as the expected size; otherwise 0x0
//                      is forwarded.
//   expected_channel - forwarded unchanged to load_sd_image_from_file.
// Returns false after logging and calling release_all_resources() when the
// load fails; returns true otherwise.
auto load_image_and_update_size = [&](const std::string& path,
                                      sd_image_t& image,
                                      bool resize_image    = true,
                                      int expected_channel = 3) -> bool {
    const bool use_configured_size = resize_image && gen_params.width_and_height_are_set();
    const int target_width         = use_configured_size ? gen_params.width : 0;
    const int target_height        = use_configured_size ? gen_params.height : 0;

    const bool loaded = load_sd_image_from_file(&image, path.c_str(), target_width, target_height, expected_channel);
    if (loaded) {
        // First successfully loaded image defines the size if none was given.
        gen_params.set_width_and_height_if_unset(image.width, image.height);
        return true;
    }

    LOG_ERROR("load image from '%s' failed", path.c_str());
    release_all_resources();
    return false;
};
if (gen_params.init_image_path.size() > 0) { if (gen_params.init_image_path.size() > 0) {
vae_decode_only = false; vae_decode_only = false;
if (!load_image_and_update_size(gen_params.init_image_path, init_image)) {
int width = 0;
int height = 0;
init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height);
if (init_image.data == nullptr) {
LOG_ERROR("load image from '%s' failed", gen_params.init_image_path.c_str());
release_all_resources();
return 1; return 1;
} }
} }
if (gen_params.end_image_path.size() > 0) { if (gen_params.end_image_path.size() > 0) {
vae_decode_only = false; vae_decode_only = false;
if (!load_image_and_update_size(gen_params.init_image_path, end_image)) {
int width = 0;
int height = 0;
end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height);
if (end_image.data == nullptr) {
LOG_ERROR("load image from '%s' failed", gen_params.end_image_path.c_str());
release_all_resources();
return 1; return 1;
} }
} }
if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false;
for (auto& path : gen_params.ref_image_paths) {
sd_image_t ref_image = {0, 0, 3, nullptr};
if (!load_image_and_update_size(path, ref_image, false)) {
return 1;
}
ref_images.push_back(ref_image);
}
}
if (gen_params.mask_image_path.size() > 0) { if (gen_params.mask_image_path.size() > 0) {
int c = 0; if (load_sd_image_from_file(&mask_image,
int width = 0; gen_params.mask_image_path.c_str(),
int height = 0; gen_params.get_resolved_width(),
mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1); gen_params.get_resolved_height(),
if (mask_image.data == nullptr) { 1)) {
LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str()); LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
release_all_resources(); release_all_resources();
return 1; return 1;
} }
} else { } else {
mask_image.data = (uint8_t*)malloc(gen_params.width * gen_params.height); mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
if (mask_image.data == nullptr) { if (mask_image.data == nullptr) {
LOG_ERROR("malloc mask image failed"); LOG_ERROR("malloc mask image failed");
release_all_resources(); release_all_resources();
return 1; return 1;
} }
memset(mask_image.data, 255, gen_params.width * gen_params.height); mask_image.width = gen_params.get_resolved_width();
mask_image.height = gen_params.get_resolved_height();
memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
} }
if (gen_params.control_image_path.size() > 0) { if (gen_params.control_image_path.size() > 0) {
int width = 0; if (load_sd_image_from_file(&control_image,
int height = 0; gen_params.control_image_path.c_str(),
control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height); gen_params.get_resolved_width(),
if (control_image.data == nullptr) { gen_params.get_resolved_height())) {
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str()); LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
release_all_resources(); release_all_resources();
return 1; return 1;
@ -621,29 +643,11 @@ int main(int argc, const char* argv[]) {
} }
} }
if (gen_params.ref_image_paths.size() > 0) {
vae_decode_only = false;
for (auto& path : gen_params.ref_image_paths) {
int width = 0;
int height = 0;
uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height);
if (image_buffer == nullptr) {
LOG_ERROR("load image from '%s' failed", path.c_str());
release_all_resources();
return 1;
}
ref_images.push_back({(uint32_t)width,
(uint32_t)height,
3,
image_buffer});
}
}
if (!gen_params.control_video_path.empty()) { if (!gen_params.control_video_path.empty()) {
if (!load_images_from_dir(gen_params.control_video_path, if (!load_images_from_dir(gen_params.control_video_path,
control_frames, control_frames,
gen_params.width, gen_params.get_resolved_width(),
gen_params.height, gen_params.get_resolved_height(),
gen_params.video_frames, gen_params.video_frames,
cli_params.verbose)) { cli_params.verbose)) {
release_all_resources(); release_all_resources();
@ -717,8 +721,8 @@ int main(int argc, const char* argv[]) {
gen_params.auto_resize_ref_image, gen_params.auto_resize_ref_image,
gen_params.increase_ref_index, gen_params.increase_ref_index,
mask_image, mask_image,
gen_params.width, gen_params.get_resolved_width(),
gen_params.height, gen_params.get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.strength, gen_params.strength,
gen_params.seed, gen_params.seed,
@ -748,8 +752,8 @@ int main(int argc, const char* argv[]) {
end_image, end_image,
control_frames.data(), control_frames.data(),
(int)control_frames.size(), (int)control_frames.size(),
gen_params.width, gen_params.get_resolved_width(),
gen_params.height, gen_params.get_resolved_height(),
gen_params.sample_params, gen_params.sample_params,
gen_params.high_noise_sample_params, gen_params.high_noise_sample_params,
gen_params.moe_boundary, gen_params.moe_boundary,

View File

@ -1024,8 +1024,8 @@ struct SDGenerationParams {
std::string prompt_with_lora; // for metadata record only std::string prompt_with_lora; // for metadata record only
std::string negative_prompt; std::string negative_prompt;
int clip_skip = -1; // <= 0 represents unspecified int clip_skip = -1; // <= 0 represents unspecified
int width = 512; int width = -1;
int height = 512; int height = -1;
int batch_count = 1; int batch_count = 1;
std::string init_image_path; std::string init_image_path;
std::string end_image_path; std::string end_image_path;
@ -1705,17 +1705,24 @@ struct SDGenerationParams {
} }
} }
// True only when both width and height hold positive values.
bool width_and_height_are_set() const {
    if (width <= 0) {
        return false;
    }
    return height > 0;
}
// Adopts w x h as the stored size unless both width and height are already
// positive. The check is all-or-nothing: when only one of the two is positive,
// it is overwritten together with the missing one.
void set_width_and_height_if_unset(int w, int h) {
    if (width_and_height_are_set()) {
        return;  // a complete size is already present; nothing to log or change
    }
    LOG_INFO("set width x height to %d x %d", w, h);
    width  = w;
    height = h;
}
// Width to use: the stored width when it is positive, otherwise 512.
int get_resolved_width() const {
    if (width > 0) {
        return width;
    }
    return 512;
}
// Height to use: the stored height when it is positive, otherwise 512.
int get_resolved_height() const {
    if (height > 0) {
        return height;
    }
    return 512;
}
bool process_and_check(SDMode mode, const std::string& lora_model_dir) { bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
prompt_with_lora = prompt; prompt_with_lora = prompt;
if (width <= 0) {
LOG_ERROR("error: the width must be greater than 0\n");
return false;
}
if (height <= 0) {
LOG_ERROR("error: the height must be greater than 0\n");
return false;
}
if (sample_params.sample_steps <= 0) { if (sample_params.sample_steps <= 0) {
LOG_ERROR("error: the sample_steps must be greater than 0\n"); LOG_ERROR("error: the sample_steps must be greater than 0\n");
@ -2083,6 +2090,22 @@ uint8_t* load_image_from_file(const char* image_path,
return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel); return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
} }
// Loads the file at image_path through load_image_common and stores the
// resulting pixel buffer and dimensions in *image.
//   image            - destination; data, width and height are written, other
//                      fields are left as the caller set them.
//   image_path       - path of the file to load.
//   expected_width / expected_height / expected_channel
//                    - forwarded unchanged to load_image_common.
// Returns true on success. Returns false when image is null, or when
// load_image_common yields no data; in the latter case image->data is nullptr
// and image->width / image->height are left untouched.
bool load_sd_image_from_file(sd_image_t* image,
                             const char* image_path,
                             int expected_width   = 0,
                             int expected_height  = 0,
                             int expected_channel = 3) {
    if (image == nullptr) {
        return false;  // nothing to fill in
    }
    // Zero-initialised so indeterminate values can never be copied into *image.
    int width  = 0;
    int height = 0;
    image->data = load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
    if (image->data == nullptr) {
        return false;
    }
    image->width  = width;
    image->height = height;
    return true;
}
uint8_t* load_image_from_memory(const char* image_bytes, uint8_t* load_image_from_memory(const char* image_bytes,
int len, int len,
int& width, int& width,