mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-01-02 10:43:35 +00:00
refactor: simplify logic for saving results (#1149)
This commit is contained in:
parent
51bd9c8004
commit
4ff2c8c74b
@ -370,6 +370,95 @@ std::string format_frame_idx(std::string pattern, int frame_idx) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool save_results(const SDCliParams& cli_params,
|
||||||
|
const SDContextParams& ctx_params,
|
||||||
|
const SDGenerationParams& gen_params,
|
||||||
|
sd_image_t* results,
|
||||||
|
int num_results) {
|
||||||
|
if (results == nullptr || num_results <= 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace fs = std::filesystem;
|
||||||
|
fs::path out_path = cli_params.output_path;
|
||||||
|
|
||||||
|
if (!out_path.parent_path().empty()) {
|
||||||
|
std::error_code ec;
|
||||||
|
fs::create_directories(out_path.parent_path(), ec);
|
||||||
|
if (ec) {
|
||||||
|
LOG_ERROR("failed to create directory '%s': %s",
|
||||||
|
out_path.parent_path().string().c_str(), ec.message().c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fs::path base_path = out_path;
|
||||||
|
fs::path ext = out_path.has_extension() ? out_path.extension() : fs::path{};
|
||||||
|
if (!ext.empty())
|
||||||
|
base_path.replace_extension();
|
||||||
|
|
||||||
|
std::string ext_lower = ext.string();
|
||||||
|
std::transform(ext_lower.begin(), ext_lower.end(), ext_lower.begin(), ::tolower);
|
||||||
|
bool is_jpg = (ext_lower == ".jpg" || ext_lower == ".jpeg" || ext_lower == ".jpe");
|
||||||
|
|
||||||
|
int output_begin_idx = cli_params.output_begin_idx;
|
||||||
|
if (output_begin_idx < 0) {
|
||||||
|
output_begin_idx = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto write_image = [&](const fs::path& path, int idx) {
|
||||||
|
const sd_image_t& img = results[idx];
|
||||||
|
if (!img.data)
|
||||||
|
return;
|
||||||
|
|
||||||
|
std::string params = get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + idx);
|
||||||
|
int ok = 0;
|
||||||
|
if (is_jpg) {
|
||||||
|
ok = stbi_write_jpg(path.string().c_str(), img.width, img.height, img.channel, img.data, 90, params.c_str());
|
||||||
|
} else {
|
||||||
|
ok = stbi_write_png(path.string().c_str(), img.width, img.height, img.channel, img.data, 0, params.c_str());
|
||||||
|
}
|
||||||
|
LOG_INFO("save result image %d to '%s' (%s)", idx, path.string().c_str(), ok ? "success" : "failure");
|
||||||
|
};
|
||||||
|
|
||||||
|
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
|
||||||
|
if (!is_jpg && ext_lower != ".png")
|
||||||
|
ext = ".png";
|
||||||
|
fs::path pattern = base_path;
|
||||||
|
pattern += ext;
|
||||||
|
|
||||||
|
for (int i = 0; i < num_results; ++i) {
|
||||||
|
fs::path img_path = format_frame_idx(pattern.string(), output_begin_idx + i);
|
||||||
|
write_image(img_path, i);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cli_params.mode == VID_GEN && num_results > 1) {
|
||||||
|
if (ext_lower != ".avi")
|
||||||
|
ext = ".avi";
|
||||||
|
fs::path video_path = base_path;
|
||||||
|
video_path += ext;
|
||||||
|
create_mjpg_avi_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps);
|
||||||
|
LOG_INFO("save result MJPG AVI video to '%s'", video_path.string().c_str());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_jpg && ext_lower != ".png")
|
||||||
|
ext = ".png";
|
||||||
|
|
||||||
|
for (int i = 0; i < num_results; ++i) {
|
||||||
|
fs::path img_path = base_path;
|
||||||
|
if (num_results > 1) {
|
||||||
|
img_path += "_" + std::to_string(output_begin_idx + i);
|
||||||
|
}
|
||||||
|
img_path += ext;
|
||||||
|
write_image(img_path, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, const char* argv[]) {
|
int main(int argc, const char* argv[]) {
|
||||||
if (argc > 1 && std::string(argv[1]) == "--version") {
|
if (argc > 1 && std::string(argv[1]) == "--version") {
|
||||||
std::cout << version_string() << "\n";
|
std::cout << version_string() << "\n";
|
||||||
@ -713,101 +802,8 @@ int main(int argc, const char* argv[]) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// create directory if not exists
|
if (!save_results(cli_params, ctx_params, gen_params, results, num_results)) {
|
||||||
{
|
return 1;
|
||||||
const fs::path out_path = cli_params.output_path;
|
|
||||||
if (const fs::path out_dir = out_path.parent_path(); !out_dir.empty()) {
|
|
||||||
std::error_code ec;
|
|
||||||
fs::create_directories(out_dir, ec); // OK if already exists
|
|
||||||
if (ec) {
|
|
||||||
LOG_ERROR("failed to create directory '%s': %s",
|
|
||||||
out_dir.string().c_str(), ec.message().c_str());
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string base_path;
|
|
||||||
std::string file_ext;
|
|
||||||
std::string file_ext_lower;
|
|
||||||
bool is_jpg;
|
|
||||||
size_t last_dot_pos = cli_params.output_path.find_last_of(".");
|
|
||||||
size_t last_slash_pos = std::min(cli_params.output_path.find_last_of("/"),
|
|
||||||
cli_params.output_path.find_last_of("\\"));
|
|
||||||
if (last_dot_pos != std::string::npos && (last_slash_pos == std::string::npos || last_dot_pos > last_slash_pos)) { // filename has extension
|
|
||||||
base_path = cli_params.output_path.substr(0, last_dot_pos);
|
|
||||||
file_ext = file_ext_lower = cli_params.output_path.substr(last_dot_pos);
|
|
||||||
std::transform(file_ext.begin(), file_ext.end(), file_ext_lower.begin(), ::tolower);
|
|
||||||
is_jpg = (file_ext_lower == ".jpg" || file_ext_lower == ".jpeg" || file_ext_lower == ".jpe");
|
|
||||||
} else {
|
|
||||||
base_path = cli_params.output_path;
|
|
||||||
file_ext = file_ext_lower = "";
|
|
||||||
is_jpg = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
|
|
||||||
std::string final_output_path = cli_params.output_path;
|
|
||||||
if (cli_params.output_begin_idx == -1) {
|
|
||||||
cli_params.output_begin_idx = 0;
|
|
||||||
}
|
|
||||||
// writing image sequence, default to PNG
|
|
||||||
if (!is_jpg && file_ext_lower != ".png") {
|
|
||||||
base_path += file_ext;
|
|
||||||
file_ext = ".png";
|
|
||||||
}
|
|
||||||
final_output_path = base_path + file_ext;
|
|
||||||
for (int i = 0; i < num_results; i++) {
|
|
||||||
if (results[i].data == nullptr) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
std::string final_image_path = format_frame_idx(final_output_path, cli_params.output_begin_idx + i);
|
|
||||||
if (is_jpg) {
|
|
||||||
int write_ok = stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
|
|
||||||
results[i].data, 90, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
|
|
||||||
LOG_INFO("save result JPEG image %d to '%s' (%s)", i, final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
|
|
||||||
} else {
|
|
||||||
int write_ok = stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
|
|
||||||
results[i].data, 0, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
|
|
||||||
LOG_INFO("save result PNG image %d to '%s' (%s)", i, final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if (cli_params.mode == VID_GEN && num_results > 1) {
|
|
||||||
std::string final_output_path = cli_params.output_path;
|
|
||||||
if (file_ext_lower != ".avi") {
|
|
||||||
if (!is_jpg && file_ext_lower != ".png") {
|
|
||||||
base_path += file_ext;
|
|
||||||
}
|
|
||||||
file_ext = ".avi";
|
|
||||||
final_output_path = base_path + file_ext;
|
|
||||||
}
|
|
||||||
create_mjpg_avi_from_sd_images(final_output_path.c_str(), results, num_results, gen_params.fps);
|
|
||||||
LOG_INFO("save result MJPG AVI video to '%s'\n", final_output_path.c_str());
|
|
||||||
} else {
|
|
||||||
// appending ".png" to absent or unknown extension
|
|
||||||
if (!is_jpg && file_ext_lower != ".png") {
|
|
||||||
base_path += file_ext;
|
|
||||||
file_ext = ".png";
|
|
||||||
}
|
|
||||||
if (cli_params.output_begin_idx == -1) {
|
|
||||||
cli_params.output_begin_idx = 1;
|
|
||||||
}
|
|
||||||
for (int i = 0; i < num_results; i++) {
|
|
||||||
if (results[i].data == nullptr) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
int write_ok;
|
|
||||||
std::string final_image_path;
|
|
||||||
final_image_path = i > 0 ? base_path + "_" + std::to_string(cli_params.output_begin_idx + i) + file_ext : base_path + file_ext;
|
|
||||||
if (is_jpg) {
|
|
||||||
write_ok = stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
|
|
||||||
results[i].data, 90, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
|
|
||||||
LOG_INFO("save result JPEG image to '%s' (%s)", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
|
|
||||||
} else {
|
|
||||||
write_ok = stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel,
|
|
||||||
results[i].data, 0, get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + i).c_str());
|
|
||||||
LOG_INFO("save result PNG image to '%s' (%s)", final_image_path.c_str(), write_ok == 0 ? "failure" : "success");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < num_results; i++) {
|
for (int i = 0; i < num_results; i++) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user