feat(server): use image and command-line dimensions by default on server (#1262)

This commit is contained in:
Wagner Bruna 2026-02-10 13:42:50 -03:00 committed by GitHub
parent 45ce78a3ae
commit adea272225
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -404,8 +404,8 @@ int main(int argc, const char** argv) {
std::string size = j.value("size", "");
std::string output_format = j.value("output_format", "png");
int output_compression = j.value("output_compression", 100);
int width = 512;
int height = 512;
int width = default_gen_params.width > 0 ? default_gen_params.width : 512;
int height = default_gen_params.width > 0 ? default_gen_params.height : 512;
if (!size.empty()) {
auto pos = size.find('x');
if (pos != std::string::npos) {
@ -593,7 +593,7 @@ int main(int argc, const char** argv) {
n = std::clamp(n, 1, 8);
std::string size = req.form.get_field("size");
int width = 512, height = 512;
int width = -1, height = -1;
if (!size.empty()) {
auto pos = size.find('x');
if (pos != std::string::npos) {
@ -650,15 +650,31 @@ int main(int argc, const char** argv) {
LOG_DEBUG("%s\n", gen_params.to_string().c_str());
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t control_image = {0, 0, 3, nullptr};
std::vector<sd_image_t> pmid_images;
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
if (gen_params.width > 0)
return gen_params.width;
if (default_gen_params.width > 0)
return default_gen_params.width;
return 512;
};
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
if (gen_params.height > 0)
return gen_params.height;
if (default_gen_params.height > 0)
return default_gen_params.height;
return 512;
};
std::vector<sd_image_t> ref_images;
ref_images.reserve(images_bytes.size());
for (auto& bytes : images_bytes) {
int img_w = width;
int img_h = height;
int img_w;
int img_h;
uint8_t* raw_pixels = load_image_from_memory(
reinterpret_cast<const char*>(bytes.data()),
static_cast<int>(bytes.size()),
@ -670,22 +686,31 @@ int main(int argc, const char** argv) {
}
sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels};
gen_params.set_width_and_height_if_unset(img.width, img.height);
ref_images.push_back(img);
}
sd_image_t mask_image = {0};
if (!mask_bytes.empty()) {
int mask_w = width;
int mask_h = height;
int expected_width = 0;
int expected_height = 0;
if (gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
int mask_w;
int mask_h;
uint8_t* mask_raw = load_image_from_memory(
reinterpret_cast<const char*>(mask_bytes.data()),
static_cast<int>(mask_bytes.size()),
mask_w, mask_h,
width, height, 1);
expected_width, expected_height, 1);
mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw};
gen_params.set_width_and_height_if_unset(mask_image.width, mask_image.height);
} else {
mask_image.width = width;
mask_image.height = height;
mask_image.width = get_resolved_width();
mask_image.height = get_resolved_height();
mask_image.channel = 1;
mask_image.data = nullptr;
}
@ -702,8 +727,8 @@ int main(int argc, const char** argv) {
gen_params.auto_resize_ref_image,
gen_params.increase_ref_index,
mask_image,
gen_params.width,
gen_params.height,
get_resolved_width(),
get_resolved_height(),
gen_params.sample_params,
gen_params.strength,
gen_params.seed,
@ -886,8 +911,6 @@ int main(int argc, const char** argv) {
SDGenerationParams gen_params = default_gen_params;
gen_params.prompt = prompt;
gen_params.negative_prompt = negative_prompt;
gen_params.width = width;
gen_params.height = height;
gen_params.seed = seed;
gen_params.sample_params.sample_steps = steps;
gen_params.batch_count = batch_size;
@ -905,17 +928,36 @@ int main(int argc, const char** argv) {
gen_params.sample_params.scheduler = scheduler;
}
// re-read to avoid applying 512 as default before the provided
// images and/or server command-line
gen_params.width = j.value("width", -1);
gen_params.height = j.value("height", -1);
LOG_DEBUG("%s\n", gen_params.to_string().c_str());
sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t control_image = {0, 0, 3, nullptr};
sd_image_t mask_image = {0, 0, 1, nullptr};
std::vector<uint8_t> mask_data;
std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> ref_images;
if (img2img) {
auto decode_image = [](sd_image_t& image, std::string encoded) -> bool {
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
if (gen_params.width > 0)
return gen_params.width;
if (default_gen_params.width > 0)
return default_gen_params.width;
return 512;
};
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
if (gen_params.height > 0)
return gen_params.height;
if (default_gen_params.height > 0)
return default_gen_params.height;
return 512;
};
auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool {
// remove data URI prefix if present ("data:image/png;base64,")
auto comma_pos = encoded.find(',');
if (comma_pos != std::string::npos) {
@ -923,20 +965,29 @@ int main(int argc, const char** argv) {
}
std::vector<uint8_t> img_data = base64_decode(encoded);
if (!img_data.empty()) {
int img_w = image.width;
int img_h = image.height;
int expected_width = 0;
int expected_height = 0;
if (gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
int img_w;
int img_h;
uint8_t* raw_data = load_image_from_memory(
(const char*)img_data.data(), (int)img_data.size(),
img_w, img_h,
image.width, image.height, image.channel);
expected_width, expected_height, image.channel);
if (raw_data) {
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
gen_params.set_width_and_height_if_unset(image.width, image.height);
return true;
}
}
return false;
};
if (img2img) {
if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
std::string encoded = j["init_images"][0].get<std::string>();
decode_image(init_image, encoded);
@ -952,13 +1003,22 @@ int main(int argc, const char** argv) {
}
}
} else {
mask_data = std::vector<uint8_t>(width * height, 255);
mask_image.width = width;
mask_image.height = height;
int m_width = get_resolved_width();
int m_height = get_resolved_height();
mask_data = std::vector<uint8_t>(m_width * m_height, 255);
mask_image.width = m_width;
mask_image.height = m_height;
mask_image.channel = 1;
mask_image.data = mask_data.data();
}
float denoising_strength = j.value("denoising_strength", -1.f);
if (denoising_strength >= 0.f) {
denoising_strength = std::min(denoising_strength, 1.0f);
gen_params.strength = denoising_strength;
}
}
if (j.contains("extra_images") && j["extra_images"].is_array()) {
for (auto extra_image : j["extra_images"]) {
std::string encoded = extra_image.get<std::string>();
@ -969,13 +1029,6 @@ int main(int argc, const char** argv) {
}
}
float denoising_strength = j.value("denoising_strength", -1.f);
if (denoising_strength >= 0.f) {
denoising_strength = std::min(denoising_strength, 1.0f);
gen_params.strength = denoising_strength;
}
}
sd_img_gen_params_t img_gen_params = {
sd_loras.data(),
static_cast<uint32_t>(sd_loras.size()),
@ -988,8 +1041,8 @@ int main(int argc, const char** argv) {
gen_params.auto_resize_ref_image,
gen_params.increase_ref_index,
mask_image,
gen_params.width,
gen_params.height,
get_resolved_width(),
get_resolved_height(),
gen_params.sample_params,
gen_params.strength,
gen_params.seed,