mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-02-04 10:53:34 +00:00
feat: use image width and height when not explicitly set (#1206)
This commit is contained in:
parent
329571131d
commit
5e4579c11d
@ -245,7 +245,7 @@ std::string get_image_params(const SDCliParams& cli_params, const SDContextParam
|
||||
parameter_string += "Guidance: " + std::to_string(gen_params.sample_params.guidance.distilled_guidance) + ", ";
|
||||
parameter_string += "Eta: " + std::to_string(gen_params.sample_params.eta) + ", ";
|
||||
parameter_string += "Seed: " + std::to_string(seed) + ", ";
|
||||
parameter_string += "Size: " + std::to_string(gen_params.width) + "x" + std::to_string(gen_params.height) + ", ";
|
||||
parameter_string += "Size: " + std::to_string(gen_params.get_resolved_width()) + "x" + std::to_string(gen_params.get_resolved_height()) + ", ";
|
||||
parameter_string += "Model: " + sd_basename(ctx_params.model_path) + ", ";
|
||||
parameter_string += "RNG: " + std::string(sd_rng_type_name(ctx_params.rng_type)) + ", ";
|
||||
if (ctx_params.sampler_rng_type != RNG_TYPE_COUNT) {
|
||||
@ -526,10 +526,10 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
|
||||
bool vae_decode_only = true;
|
||||
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
|
||||
sd_image_t end_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
|
||||
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
|
||||
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
|
||||
sd_image_t init_image = {0, 0, 3, nullptr};
|
||||
sd_image_t end_image = {0, 0, 3, nullptr};
|
||||
sd_image_t control_image = {0, 0, 3, nullptr};
|
||||
sd_image_t mask_image = {0, 0, 1, nullptr};
|
||||
std::vector<sd_image_t> ref_images;
|
||||
std::vector<sd_image_t> pmid_images;
|
||||
std::vector<sd_image_t> control_frames;
|
||||
@ -556,57 +556,79 @@ int main(int argc, const char* argv[]) {
|
||||
control_frames.clear();
|
||||
};
|
||||
|
||||
auto load_image_and_update_size = [&](const std::string& path,
|
||||
sd_image_t& image,
|
||||
bool resize_image = true,
|
||||
int expected_channel = 3) -> bool {
|
||||
int expected_width = 0;
|
||||
int expected_height = 0;
|
||||
if (resize_image && gen_params.width_and_height_are_set()) {
|
||||
expected_width = gen_params.width;
|
||||
expected_height = gen_params.height;
|
||||
}
|
||||
|
||||
if (!load_sd_image_from_file(&image, path.c_str(), expected_width, expected_height, expected_channel)) {
|
||||
LOG_ERROR("load image from '%s' failed", path.c_str());
|
||||
release_all_resources();
|
||||
return false;
|
||||
}
|
||||
|
||||
gen_params.set_width_and_height_if_unset(image.width, image.height);
|
||||
return true;
|
||||
};
|
||||
|
||||
if (gen_params.init_image_path.size() > 0) {
|
||||
vae_decode_only = false;
|
||||
|
||||
int width = 0;
|
||||
int height = 0;
|
||||
init_image.data = load_image_from_file(gen_params.init_image_path.c_str(), width, height, gen_params.width, gen_params.height);
|
||||
if (init_image.data == nullptr) {
|
||||
LOG_ERROR("load image from '%s' failed", gen_params.init_image_path.c_str());
|
||||
release_all_resources();
|
||||
if (!load_image_and_update_size(gen_params.init_image_path, init_image)) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (gen_params.end_image_path.size() > 0) {
|
||||
vae_decode_only = false;
|
||||
|
||||
int width = 0;
|
||||
int height = 0;
|
||||
end_image.data = load_image_from_file(gen_params.end_image_path.c_str(), width, height, gen_params.width, gen_params.height);
|
||||
if (end_image.data == nullptr) {
|
||||
LOG_ERROR("load image from '%s' failed", gen_params.end_image_path.c_str());
|
||||
release_all_resources();
|
||||
// This branch is guarded by end_image_path being non-empty and fills
// end_image, so it must read end_image_path (not init_image_path).
// The helper already logs and releases resources on failure.
if (!load_image_and_update_size(gen_params.end_image_path, end_image)) {
    return 1;
}
|
||||
}
|
||||
|
||||
if (gen_params.ref_image_paths.size() > 0) {
|
||||
vae_decode_only = false;
|
||||
for (auto& path : gen_params.ref_image_paths) {
|
||||
sd_image_t ref_image = {0, 0, 3, nullptr};
|
||||
if (!load_image_and_update_size(path, ref_image, false)) {
|
||||
return 1;
|
||||
}
|
||||
ref_images.push_back(ref_image);
|
||||
}
|
||||
}
|
||||
|
||||
if (gen_params.mask_image_path.size() > 0) {
|
||||
int c = 0;
|
||||
int width = 0;
|
||||
int height = 0;
|
||||
mask_image.data = load_image_from_file(gen_params.mask_image_path.c_str(), width, height, gen_params.width, gen_params.height, 1);
|
||||
if (mask_image.data == nullptr) {
|
||||
// load_sd_image_from_file returns true on success, so the error path must be
// taken only when it returns false. The mask is requested at the resolved
// generation size with a single channel.
if (!load_sd_image_from_file(&mask_image,
                             gen_params.mask_image_path.c_str(),
                             gen_params.get_resolved_width(),
                             gen_params.get_resolved_height(),
                             1)) {
    LOG_ERROR("load image from '%s' failed", gen_params.mask_image_path.c_str());
    release_all_resources();
    return 1;
}
|
||||
} else {
|
||||
mask_image.data = (uint8_t*)malloc(gen_params.width * gen_params.height);
|
||||
mask_image.data = (uint8_t*)malloc(gen_params.get_resolved_width() * gen_params.get_resolved_height());
|
||||
if (mask_image.data == nullptr) {
|
||||
LOG_ERROR("malloc mask image failed");
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
memset(mask_image.data, 255, gen_params.width * gen_params.height);
|
||||
mask_image.width = gen_params.get_resolved_width();
|
||||
mask_image.height = gen_params.get_resolved_height();
|
||||
memset(mask_image.data, 255, gen_params.get_resolved_width() * gen_params.get_resolved_height());
|
||||
}
|
||||
|
||||
if (gen_params.control_image_path.size() > 0) {
|
||||
int width = 0;
|
||||
int height = 0;
|
||||
control_image.data = load_image_from_file(gen_params.control_image_path.c_str(), width, height, gen_params.width, gen_params.height);
|
||||
if (control_image.data == nullptr) {
|
||||
// load_sd_image_from_file returns true on success, so the error path that
// follows must be entered only when the load fails.
if (!load_sd_image_from_file(&control_image,
                             gen_params.control_image_path.c_str(),
                             gen_params.get_resolved_width(),
                             gen_params.get_resolved_height())) {
|
||||
LOG_ERROR("load image from '%s' failed", gen_params.control_image_path.c_str());
|
||||
release_all_resources();
|
||||
return 1;
|
||||
@ -621,29 +643,11 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
if (gen_params.ref_image_paths.size() > 0) {
|
||||
vae_decode_only = false;
|
||||
for (auto& path : gen_params.ref_image_paths) {
|
||||
int width = 0;
|
||||
int height = 0;
|
||||
uint8_t* image_buffer = load_image_from_file(path.c_str(), width, height);
|
||||
if (image_buffer == nullptr) {
|
||||
LOG_ERROR("load image from '%s' failed", path.c_str());
|
||||
release_all_resources();
|
||||
return 1;
|
||||
}
|
||||
ref_images.push_back({(uint32_t)width,
|
||||
(uint32_t)height,
|
||||
3,
|
||||
image_buffer});
|
||||
}
|
||||
}
|
||||
|
||||
if (!gen_params.control_video_path.empty()) {
|
||||
if (!load_images_from_dir(gen_params.control_video_path,
|
||||
control_frames,
|
||||
gen_params.width,
|
||||
gen_params.height,
|
||||
gen_params.get_resolved_width(),
|
||||
gen_params.get_resolved_height(),
|
||||
gen_params.video_frames,
|
||||
cli_params.verbose)) {
|
||||
release_all_resources();
|
||||
@ -717,8 +721,8 @@ int main(int argc, const char* argv[]) {
|
||||
gen_params.auto_resize_ref_image,
|
||||
gen_params.increase_ref_index,
|
||||
mask_image,
|
||||
gen_params.width,
|
||||
gen_params.height,
|
||||
gen_params.get_resolved_width(),
|
||||
gen_params.get_resolved_height(),
|
||||
gen_params.sample_params,
|
||||
gen_params.strength,
|
||||
gen_params.seed,
|
||||
@ -748,8 +752,8 @@ int main(int argc, const char* argv[]) {
|
||||
end_image,
|
||||
control_frames.data(),
|
||||
(int)control_frames.size(),
|
||||
gen_params.width,
|
||||
gen_params.height,
|
||||
gen_params.get_resolved_width(),
|
||||
gen_params.get_resolved_height(),
|
||||
gen_params.sample_params,
|
||||
gen_params.high_noise_sample_params,
|
||||
gen_params.moe_boundary,
|
||||
|
||||
@ -1024,8 +1024,8 @@ struct SDGenerationParams {
|
||||
std::string prompt_with_lora; // for metadata record only
|
||||
std::string negative_prompt;
|
||||
int clip_skip = -1; // <= 0 represents unspecified
|
||||
int width = 512;
|
||||
int height = 512;
|
||||
int width = -1;
|
||||
int height = -1;
|
||||
int batch_count = 1;
|
||||
std::string init_image_path;
|
||||
std::string end_image_path;
|
||||
@ -1705,17 +1705,24 @@ struct SDGenerationParams {
|
||||
}
|
||||
}
|
||||
|
||||
bool width_and_height_are_set() const {
|
||||
return width > 0 && height > 0;
|
||||
}
|
||||
|
||||
void set_width_and_height_if_unset(int w, int h) {
|
||||
if (!width_and_height_are_set()) {
|
||||
LOG_INFO("set width x height to %d x %d", w, h);
|
||||
width = w;
|
||||
height = h;
|
||||
}
|
||||
}
|
||||
|
||||
// Width to generate at: the stored width when positive, otherwise 512.
int get_resolved_width() const {
    if (width > 0) {
        return width;
    }
    return 512;
}
|
||||
|
||||
// Height to generate at: the stored height when positive, otherwise 512.
int get_resolved_height() const {
    if (height > 0) {
        return height;
    }
    return 512;
}
|
||||
|
||||
bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
|
||||
prompt_with_lora = prompt;
|
||||
if (width <= 0) {
|
||||
LOG_ERROR("error: the width must be greater than 0\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (height <= 0) {
|
||||
LOG_ERROR("error: the height must be greater than 0\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (sample_params.sample_steps <= 0) {
|
||||
LOG_ERROR("error: the sample_steps must be greater than 0\n");
|
||||
@ -2083,6 +2090,22 @@ uint8_t* load_image_from_file(const char* image_path,
|
||||
return load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
|
||||
}
|
||||
|
||||
// Loads an image file into `*image`.
//
// image            - destination; its `data` is always overwritten with the
//                    loader's result (nullptr on failure), `width`/`height`
//                    are written only on success.
// image_path       - file to load.
// expected_width,
// expected_height,
// expected_channel - forwarded unchanged to load_image_common.
//
// Returns true when the loader produced a non-null buffer, false otherwise.
bool load_sd_image_from_file(sd_image_t* image,
                             const char* image_path,
                             int expected_width   = 0,
                             int expected_height  = 0,
                             int expected_channel = 3) {
    // Zero-initialised so that no indeterminate value can ever reach the
    // image, even if the loader leaves an out-parameter unwritten.
    int width  = 0;
    int height = 0;

    image->data = load_image_common(false, image_path, 0, width, height, expected_width, expected_height, expected_channel);
    if (image->data == nullptr) {
        return false;
    }

    image->width  = width;
    image->height = height;
    return true;
}
|
||||
|
||||
uint8_t* load_image_from_memory(const char* image_bytes,
|
||||
int len,
|
||||
int& width,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user