mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-01-02 10:43:35 +00:00
refactor: simplify logic for saving results (#1149)
This commit is contained in:
parent
51bd9c8004
commit
4ff2c8c74b
@ -370,6 +370,95 @@ std::string format_frame_idx(std::string pattern, int frame_idx) {
|
||||
return result;
|
||||
}
|
||||
|
||||
bool save_results(const SDCliParams& cli_params,
|
||||
const SDContextParams& ctx_params,
|
||||
const SDGenerationParams& gen_params,
|
||||
sd_image_t* results,
|
||||
int num_results) {
|
||||
if (results == nullptr || num_results <= 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
fs::path out_path = cli_params.output_path;
|
||||
|
||||
if (!out_path.parent_path().empty()) {
|
||||
std::error_code ec;
|
||||
fs::create_directories(out_path.parent_path(), ec);
|
||||
if (ec) {
|
||||
LOG_ERROR("failed to create directory '%s': %s",
|
||||
out_path.parent_path().string().c_str(), ec.message().c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
fs::path base_path = out_path;
|
||||
fs::path ext = out_path.has_extension() ? out_path.extension() : fs::path{};
|
||||
if (!ext.empty())
|
||||
base_path.replace_extension();
|
||||
|
||||
std::string ext_lower = ext.string();
|
||||
std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(), ::tolower);
|
||||
bool is_jpg = (ext_lower == ".jpg" || ext_lower == ".jpeg" || ext_lower == ".jpe");
|
||||
|
||||
int output_begin_idx = cli_params.output_begin_idx;
|
||||
if (output_begin_idx < 0) {
|
||||
output_begin_idx = 0;
|
||||
}
|
||||
|
||||
auto write_image = [&](const fs::path& path, int idx) {
|
||||
const sd_image_t& img = results[idx];
|
||||
if (!img.data)
|
||||
return;
|
||||
|
||||
std::string params = get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + idx);
|
||||
int ok = 0;
|
||||
if (is_jpg) {
|
||||
ok = stbi_write_jpg(path.string().c_str(), img.width, img.height, img.channel, img.data, 90, params.c_str());
|
||||
} else {
|
||||
ok = stbi_write_png(path.string().c_str(), img.width, img.height, img.channel, img.data, 0, params.c_str());
|
||||
}
|
||||
LOG_INFO("save result image %d to '%s' (%s)", idx, path.string().c_str(), ok ? "success" : "failure");
|
||||
};
|
||||
|
||||
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
|
||||
if (!is_jpg && ext_lower != ".png")
|
||||
ext = ".png";
|
||||
fs::path pattern = base_path;
|
||||
pattern += ext;
|
||||
|
||||
for (int i = 0; i < num_results; ++i) {
|
||||
fs::path img_path = format_frame_idx(pattern.string(), output_begin_idx + i);
|
||||
write_image(img_path, i);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (cli_params.mode == VID_GEN && num_results > 1) {
|
||||
if (ext_lower != ".avi")
|
||||
ext = ".avi";
|
||||
fs::path video_path = base_path;
|
||||
video_path += ext;
|
||||
create_mjpg_avi_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps);
|
||||
LOG_INFO("save result MJPG AVI video to '%s'", video_path.string().c_str());
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!is_jpg && ext_lower != ".png")
|
||||
ext = ".png";
|
||||
|
||||
for (int i = 0; i < num_results; ++i) {
|
||||
fs::path img_path = base_path;
|
||||
if (num_results > 1) {
|
||||
img_path += "_" + std::to_string(output_begin_idx + i);
|
||||
}
|
||||
img_path += ext;
|
||||
write_image(img_path, i);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, const char* argv[]) {
|
||||
if (argc > 1 && std::string(argv[1]) == "--version") {
|
||||
std::cout << version_string() << "\n";
|
||||
@ -713,102 +802,9 @@ int main(int argc, const char* argv[]) {
|
||||
}
|
||||
}
|
||||
|
||||
// create directory if not exists
|
||||
{
|
||||
const fs::path out_path = cli_params.output_path;
|
||||
if (const fs::path out_dir = out_path.parent_path(); !out_dir.empty()) {
|
||||
std::error_code ec;
|
||||
fs::create_directories(out_dir, ec); // OK if already exists
|
||||
if (ec) {
|
||||
LOG_ERROR("failed to create directory '%s': %s",
|
||||
out_dir.string().c_str(), ec.message().c_str());
|
||||
if (!save_results(cli_params, ctx_params, gen_params, results, num_results)) {
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string base_path;
|
||||
std::string file_ext;
|
||||
std::string file_ext_lower;
|
||||
bool is_jpg;
|
||||
size_t last_dot_pos = cli_params.output_path.find_last_of(".");
|
||||
size_t last_slash_pos = std::min(cli_params.output_path.find_last_of("/"),
|
||||
cli_params.output_path.find_last_of("\\"));
|
||||
if (last_dot_pos != std::string::npos && (last_slash_pos == std::string::npos || last_dot_pos > last_slash_pos)) { // filename has extension
|
||||
base_path = cli_params.output_path.substr(0, last_dot_pos);
|
||||
file_ext = file_ext_lower = cli_params.output_path.substr(last_dot_pos);
|
||||
std::transform(file_ext.begin(), file_ext.end(), file_ext_lower.begin(), ::tolower);
|
||||
is_jpg = (file_ext_lower == ".jpg" || file_ext_lower == ".jpeg" || file_ext_lower == ".jpe");
|
||||
} else {
|
||||
base_path = cli_params.output_path;
|
||||
file_ext = file_ext_lower = "";
|
||||
is_jpg = false;
|
||||
}
|
||||
|
||||
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
|
||||
std::string final_output_path = cli_params.output_path;
|
||||
if (cli_params.output_begin_idx == -1) {
|
||||
cli_params.output_begin_idx = 0;
|
||||
}
|
||||
// writing image sequence, default to PNG
|
||||
if (!is_jpg && file_ext_lower != ".png") {
|
||||
base_path += file_ext;
|
||||
file_ext = ".png";
|
||||
}
|
||||
final_output_path = base_path + file_ext;
|
||||
for (int i = 0; i < num_results; i++) {
|
||||
if (results[i].data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
std::string final_image_path = format_frame_idx(final_output_path, cli_params.output_begin_idx + i);
|
||||
if (is_jpg) {
|
||||
int write_ok = stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
|
||||
results[i].data, 90, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
|
||||
LOG_INFO("save result JPEG image %d to '%s' (%s)", i, final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
|
||||
} else {
|
||||
int write_ok = stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
|
||||
results[i].data, 0, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
|
||||
LOG_INFO("save result PNG image %d to '%s' (%s)", i, final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
|
||||
}
|
||||
}
|
||||
} else if (cli_params.mode == VID_GEN && num_results > 1) {
|
||||
std::string final_output_path = cli_params.output_path;
|
||||
if (file_ext_lower != ".avi") {
|
||||
if (!is_jpg && file_ext_lower != ".png") {
|
||||
base_path += file_ext;
|
||||
}
|
||||
file_ext = ".avi";
|
||||
final_output_path = base_path + file_ext;
|
||||
}
|
||||
create_mjpg_avi_from_sd_images(final_output_path.c_str(), results, num_results, gen_params.fps);
|
||||
LOG_INFO("save result MJPG AVI video to '%s'\n", final_output_path.c_str());
|
||||
} else {
|
||||
// appending ".png" to absent or unknown extension
|
||||
if (!is_jpg && file_ext_lower != ".png") {
|
||||
base_path += file_ext;
|
||||
file_ext = ".png";
|
||||
}
|
||||
if (cli_params.output_begin_idx == -1) {
|
||||
cli_params.output_begin_idx = 1;
|
||||
}
|
||||
for (int i = 0; i < num_results; i++) {
|
||||
if (results[i].data == nullptr) {
|
||||
continue;
|
||||
}
|
||||
int write_ok;
|
||||
std::string final_image_path;
|
||||
final_image_path = i > 0 ? base_path + "_" + std::to_string(cli_params.output_begin_idx + i) + file_ext : base_path + file_ext;
|
||||
if (is_jpg) {
|
||||
write_ok = stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
|
||||
results[i].data, 90, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
|
||||
LOG_INFO("save result JPEG image to '%s' (%s)", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
|
||||
} else {
|
||||
write_ok = stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
|
||||
results[i].data, 0, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
|
||||
LOG_INFO("save result PNG image to '%s' (%s)", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_results; i++) {
|
||||
free(results[i].data);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user