refactor: simplify logic for saving results (#1149)

This commit is contained in:
leejet 2025-12-28 23:27:27 +08:00 committed by GitHub
parent 51bd9c8004
commit 4ff2c8c74b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -370,6 +370,95 @@ std::string format_frame_idx(std::string pattern, int frame_idx) {
return result;
}
bool save_results(const SDCliParams& cli_params,
const SDContextParams& ctx_params,
const SDGenerationParams& gen_params,
sd_image_t* results,
int num_results) {
if (results == nullptr || num_results <= 0) {
return false;
}
namespace fs = std::filesystem;
fs::path out_path = cli_params.output_path;
if (!out_path.parent_path().empty()) {
std::error_code ec;
fs::create_directories(out_path.parent_path(), ec);
if (ec) {
LOG_ERROR("failed to create directory '%s': %s",
out_path.parent_path().string().c_str(), ec.message().c_str());
return false;
}
}
fs::path base_path = out_path;
fs::path ext = out_path.has_extension() ? out_path.extension() : fs::path{};
if (!ext.empty())
base_path.replace_extension();
std::string ext_lower = ext.string();
std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(), ::tolower);
bool is_jpg = (ext_lower == ".jpg" || ext_lower == ".jpeg" || ext_lower == ".jpe");
int output_begin_idx = cli_params.output_begin_idx;
if (output_begin_idx < 0) {
output_begin_idx = 0;
}
auto write_image = [&](const fs::path& path, int idx) {
const sd_image_t& img = results[idx];
if (!img.data)
return;
std::string params = get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + idx);
int ok = 0;
if (is_jpg) {
ok = stbi_write_jpg(path.string().c_str(), img.width, img.height, img.channel, img.data, 90, params.c_str());
} else {
ok = stbi_write_png(path.string().c_str(), img.width, img.height, img.channel, img.data, 0, params.c_str());
}
LOG_INFO("save result image %d to '%s' (%s)", idx, path.string().c_str(), ok ? "success" : "failure");
};
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
if (!is_jpg && ext_lower != ".png")
ext = ".png";
fs::path pattern = base_path;
pattern += ext;
for (int i = 0; i < num_results; ++i) {
fs::path img_path = format_frame_idx(pattern.string(), output_begin_idx + i);
write_image(img_path, i);
}
return true;
}
if (cli_params.mode == VID_GEN && num_results > 1) {
if (ext_lower != ".avi")
ext = ".avi";
fs::path video_path = base_path;
video_path += ext;
create_mjpg_avi_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps);
LOG_INFO("save result MJPG AVI video to '%s'", video_path.string().c_str());
return true;
}
if (!is_jpg && ext_lower != ".png")
ext = ".png";
for (int i = 0; i < num_results; ++i) {
fs::path img_path = base_path;
if (num_results > 1) {
img_path += "_" + std::to_string(output_begin_idx + i);
}
img_path += ext;
write_image(img_path, i);
}
return true;
}
int main(int argc, const char* argv[]) {
if (argc > 1 && std::string(argv[1]) == "--version") {
std::cout << version_string() << "\n";
@ -713,101 +802,8 @@ int main(int argc, const char* argv[]) {
}
}
// create directory if not exists
{
const fs::path out_path = cli_params.output_path;
if (const fs::path out_dir = out_path.parent_path(); !out_dir.empty()) {
std::error_code ec;
fs::create_directories(out_dir, ec); // OK if already exists
if (ec) {
LOG_ERROR("failed to create directory '%s': %s",
out_dir.string().c_str(), ec.message().c_str());
return 1;
}
}
}
std::string base_path;
std::string file_ext;
std::string file_ext_lower;
bool is_jpg;
size_t last_dot_pos = cli_params.output_path.find_last_of(".");
size_t last_slash_pos = std::min(cli_params.output_path.find_last_of("/"),
cli_params.output_path.find_last_of("\\"));
if (last_dot_pos != std::string::npos && (last_slash_pos == std::string::npos || last_dot_pos > last_slash_pos)) { // filename has extension
base_path = cli_params.output_path.substr(0, last_dot_pos);
file_ext = file_ext_lower = cli_params.output_path.substr(last_dot_pos);
std::transform(file_ext.begin(), file_ext.end(), file_ext_lower.begin(), ::tolower);
is_jpg = (file_ext_lower == ".jpg" || file_ext_lower == ".jpeg" || file_ext_lower == ".jpe");
} else {
base_path = cli_params.output_path;
file_ext = file_ext_lower = "";
is_jpg = false;
}
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
std::string final_output_path = cli_params.output_path;
if (cli_params.output_begin_idx == -1) {
cli_params.output_begin_idx = 0;
}
// writing image sequence, default to PNG
if (!is_jpg && file_ext_lower != ".png") {
base_path += file_ext;
file_ext = ".png";
}
final_output_path = base_path + file_ext;
for (int i = 0; i < num_results; i++) {
if (results[i].data == nullptr) {
continue;
}
std::string final_image_path = format_frame_idx(final_output_path, cli_params.output_begin_idx + i);
if (is_jpg) {
int write_ok = stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 90, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
LOG_INFO("save result JPEG image %d to '%s' (%s)", i, final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
} else {
int write_ok = stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 0, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
LOG_INFO("save result PNG image %d to '%s' (%s)", i, final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
}
}
} else if (cli_params.mode == VID_GEN && num_results > 1) {
std::string final_output_path = cli_params.output_path;
if (file_ext_lower != ".avi") {
if (!is_jpg && file_ext_lower != ".png") {
base_path += file_ext;
}
file_ext = ".avi";
final_output_path = base_path + file_ext;
}
create_mjpg_avi_from_sd_images(final_output_path.c_str(), results, num_results, gen_params.fps);
LOG_INFO("save result MJPG AVI video to '%s'\n", final_output_path.c_str());
} else {
// appending ".png" to absent or unknown extension
if (!is_jpg && file_ext_lower != ".png") {
base_path += file_ext;
file_ext = ".png";
}
if (cli_params.output_begin_idx == -1) {
cli_params.output_begin_idx = 1;
}
for (int i = 0; i < num_results; i++) {
if (results[i].data == nullptr) {
continue;
}
int write_ok;
std::string final_image_path;
final_image_path = i > 0 ? base_path + "_" + std::to_string(cli_params.output_begin_idx + i) + file_ext : base_path + file_ext;
if (is_jpg) {
write_ok = stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 90, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
LOG_INFO("save result JPEG image to '%s' (%s)", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
} else {
write_ok = stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
results[i].data, 0, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
LOG_INFO("save result PNG image to '%s' (%s)", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
}
}
if (!save_results(cli_params, ctx_params, gen_params, results, num_results)) {
return 1;
}
for (int i = 0; i < num_results; i++) {