feat: add image preview support (#522)

Authored by stduhpf on 2025-11-09 17:12:02 +01:00, committed by GitHub
parent ee89afc878
commit 8ecdf053ac
9 changed files with 563 additions and 10 deletions

.gitignore (vendored): 1 line changed

@ -12,3 +12,4 @@ test/
output*.png
models*
*.log
preview.png

View File

@ -32,6 +32,7 @@ Options:
-o, --output <string> path to write result image to (default: ./output.png)
-p, --prompt <string> the prompt to render
-n, --negative-prompt <string> the negative prompt (default: "")
--preview-path <string> path to write preview image to (default: ./preview.png)
--upscale-model <string> path to esrgan model.
-t, --threads <int> number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of
CPU physical cores
@ -48,6 +49,8 @@ Options:
--fps <int> fps (default: 24)
--timestep-shift <int> shift timestep for NitroFusion models (default: 0). recommended N for NitroSD-Realism around 250 and 500 for
NitroSD-Vibrant
--preview-interval <int> interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at
every step)
--cfg-scale <float> unconditional guidance scale: (default: 7.0)
--img-cfg-scale <float> image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)
--guidance <float> distilled guidance scale for models with guidance input (default: 3.5)
@ -86,6 +89,8 @@ Options:
--chroma-enable-t5-mask enable t5 mask for chroma
--increase-ref-index automatically increase the indices of references images based on the order they are listed (starting with 1).
--disable-auto-resize-ref-image disable auto resize of ref images
--taesd-preview-only do not use taesd to decode the final image (for use with --preview tae)
--preview-noisy preview the noisy model inputs rather than the denoised outputs
-M, --mode run mode, one of [img_gen, vid_gen, upscale, convert], default: img_gen
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the
type of the weight file
@ -107,4 +112,5 @@ Options:
--vae-tile-size tile size for vae tiling, format [X]x[Y] (default: 32x32)
--vae-relative-tile-size relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1
(overrides --vae-tile-size)
--preview preview method. must be one of the following [none, proj, tae, vae] (default is none)
```

View File

@ -46,6 +46,13 @@ const char* modes_str[] = {
};
#define SD_ALL_MODES_STR "img_gen, vid_gen, convert, upscale"
const char* previews_str[] = {
"none",
"proj",
"tae",
"vae",
};
enum SDMode {
IMG_GEN,
VID_GEN,
@ -135,6 +142,12 @@ struct SDParams {
sd_tiling_params_t vae_tiling_params = {false, 0, 0, 0.5f, 0.0f, 0.0f};
bool force_sdxl_vae_conv_scale = false;
preview_t preview_method = PREVIEW_NONE;
int preview_interval = 1;
std::string preview_path = "preview.png";
bool taesd_preview = false;
bool preview_noisy = false;
SDParams() {
sd_sample_params_init(&sample_params);
sd_sample_params_init(&high_noise_sample_params);
@ -210,6 +223,8 @@ void print_params(SDParams params) {
printf(" video_frames: %d\n", params.video_frames);
printf(" vace_strength: %.2f\n", params.vace_strength);
printf(" fps: %d\n", params.fps);
printf(" preview_mode: %s (%s)\n", previews_str[params.preview_method], params.preview_noisy ? "noisy" : "denoised");
printf(" preview_interval: %d\n", params.preview_interval);
free(sample_params_str);
free(high_noise_sample_params_str);
}
@ -589,6 +604,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
"--negative-prompt",
"the negative prompt (default: \"\")",
&params.negative_prompt},
{"",
"--preview-path",
"path to write preview image to (default: ./preview.png)",
&params.preview_path},
{"",
"--upscale-model",
"path to esrgan model.",
@ -647,6 +666,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
"shift timestep for NitroFusion models (default: 0). "
"recommended N for NitroSD-Realism around 250 and 500 for NitroSD-Vibrant",
&params.sample_params.shifted_timestep},
{"",
"--preview-interval",
"interval in denoising steps between consecutive updates of the image preview file (default is 1, meaning updating at every step)",
&params.preview_interval},
};
options.float_options = {
@ -801,7 +824,14 @@ void parse_args(int argc, const char** argv, SDParams& params) {
"--disable-auto-resize-ref-image",
"disable auto resize of ref images",
false, &params.auto_resize_ref_image},
};
{"",
"--taesd-preview-only",
std::string("prevents usage of taesd for decoding the final image. (for use with --preview ") + previews_str[PREVIEW_TAE] + ")",
true, &params.taesd_preview},
{"",
"--preview-noisy",
"enables previewing noisy inputs of the models rather than the denoised outputs",
true, &params.preview_noisy}};
auto on_mode_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
@ -1046,6 +1076,26 @@ void parse_args(int argc, const char** argv, SDParams& params) {
return 1;
};
auto on_preview_arg = [&](int argc, const char** argv, int index) {
if (++index >= argc) {
return -1;
}
const char* preview = argv[index];
int preview_method = -1;
for (int m = 0; m < PREVIEW_COUNT; m++) {
if (!strcmp(preview, previews_str[m])) {
preview_method = m;
}
}
if (preview_method == -1) {
fprintf(stderr, "error: preview method %s\n",
preview);
return -1;
}
params.preview_method = (preview_t)preview_method;
return 1;
};
options.manual_options = {
{"-M",
"--mode",
@ -1110,6 +1160,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
"--vae-relative-tile-size",
"relative tile size for vae tiling, format [X]x[Y], in fraction of image size if < 1, in number of tiles per dim if >=1 (overrides --vae-tile-size)",
on_relative_tile_size_arg},
{"",
"--preview",
std::string("preview method. must be one of the following [") + previews_str[0] + ", " + previews_str[1] + ", " + previews_str[2] + ", " + previews_str[3] + "] (default is " + previews_str[PREVIEW_NONE] + ")\n",
on_preview_arg},
};
if (!parse_options(argc, argv, options)) {
@ -1452,15 +1506,50 @@ bool load_images_from_dir(const std::string dir,
return true;
}
const char* preview_path;
float preview_fps;
void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy) {
(void)step;
(void)is_noisy;
// is_noisy is set to true if the preview corresponds to noisy latents, false if it's denoised latents
// unused in this app: previews here are either always noisy or always denoised
if (frame_count == 1) {
stbi_write_png(preview_path, image->width, image->height, image->channel, image->data, 0);
} else {
create_mjpg_avi_from_sd_images(preview_path, image, frame_count, preview_fps);
}
}
int main(int argc, const char* argv[]) {
SDParams params;
parse_args(argc, argv, params);
preview_path = params.preview_path.c_str();
if (params.video_frames > 4) {
size_t last_dot_pos = params.preview_path.find_last_of(".");
std::string base_path = params.preview_path;
std::string file_ext = "";
if (last_dot_pos != std::string::npos) { // filename has extension
base_path = params.preview_path.substr(0, last_dot_pos);
file_ext = params.preview_path.substr(last_dot_pos);
std::transform(file_ext.begin(), file_ext.end(), file_ext.begin(), ::tolower);
}
if (file_ext == ".png") {
base_path = base_path + ".avi";
preview_path = base_path.c_str();
}
}
preview_fps = params.fps;
if (params.preview_method == PREVIEW_PROJ)
preview_fps /= 4.0f;
params.sample_params.guidance.slg.layers = params.skip_layers.data();
params.sample_params.guidance.slg.layer_count = params.skip_layers.size();
params.high_noise_sample_params.guidance.slg.layers = params.high_noise_skip_layers.data();
params.high_noise_sample_params.guidance.slg.layer_count = params.high_noise_skip_layers.size();
sd_set_log_callback(sd_log_cb, (void*)&params);
sd_set_preview_callback((sd_preview_cb_t)step_callback, params.preview_method, params.preview_interval, !params.preview_noisy, params.preview_noisy);
if (params.verbose) {
print_params(params);
@ -1654,6 +1743,7 @@ int main(int argc, const char* argv[]) {
params.control_net_cpu,
params.vae_on_cpu,
params.diffusion_flash_attn,
params.taesd_preview,
params.diffusion_conv_direct,
params.vae_conv_direct,
params.force_sdxl_vae_conv_scale,

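The `step_callback` above overwrites a single preview file on every update. A minimal variant that keeps every intermediate frame instead might look like the sketch below (illustrative, not part of the commit; it reuses `sd_image_t` and `stbi_write_png` exactly as the code above does, and the `preview_%04d.png` naming is an assumption):

```
// Sketch: write one PNG per preview update instead of overwriting a single file.
// Assumes the same includes as main.cpp (stb_image_write.h) plus <cstdio> for snprintf.
static void step_callback_keep_all(int step, int frame_count, sd_image_t* image, bool is_noisy) {
    (void)is_noisy;
    if (frame_count != 1) {
        return;  // this sketch only handles still-image previews
    }
    char path[512];
    snprintf(path, sizeof(path), "preview_%04d.png", step);
    stbi_write_png(path, image->width, image->height, image->channel, image->data, 0);
}
```

It would be registered the same way as the existing callback, through `sd_set_preview_callback`.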
View File

@ -875,7 +875,7 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
ggml_tensor* input_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, input_tile_size_x, input_tile_size_y, input->ne[2], input->ne[3]);
ggml_tensor* output_tile = ggml_new_tensor_4d(tiles_ctx, GGML_TYPE_F32, output_tile_size_x, output_tile_size_y, output->ne[2], output->ne[3]);
int num_tiles = num_tiles_x * num_tiles_y;
LOG_INFO("processing %i tiles", num_tiles);
LOG_DEBUG("processing %i tiles", num_tiles);
pretty_progress(0, num_tiles, 0.0f);
int tile_count = 1;
bool last_y = false, last_x = false;

latent-preview.h (new file, 173 lines)

@ -0,0 +1,173 @@
#include <cstddef>
#include <cstdint>
#include "ggml.h"
const float wan_21_latent_rgb_proj[16][3] = {
{0.015123f, -0.148418f, 0.479828f},
{0.003652f, -0.010680f, -0.037142f},
{0.212264f, 0.063033f, 0.016779f},
{0.232999f, 0.406476f, 0.220125f},
{-0.051864f, -0.082384f, -0.069396f},
{0.085005f, -0.161492f, 0.010689f},
{-0.245369f, -0.506846f, -0.117010f},
{-0.151145f, 0.017721f, 0.007207f},
{-0.293239f, -0.207936f, -0.421135f},
{-0.187721f, 0.050783f, 0.177649f},
{-0.013067f, 0.265964f, 0.166578f},
{0.028327f, 0.109329f, 0.108642f},
{-0.205343f, 0.043991f, 0.148914f},
{0.014307f, -0.048647f, -0.007219f},
{0.217150f, 0.053074f, 0.319923f},
{0.155357f, 0.083156f, 0.064780f}};
float wan_21_latent_rgb_bias[3] = {-0.270270f, -0.234976f, -0.456853f};
const float wan_22_latent_rgb_proj[48][3] = {
{0.017126f, -0.027230f, -0.019257f},
{-0.113739f, -0.028715f, -0.022885f},
{-0.000106f, 0.021494f, 0.004629f},
{-0.013273f, -0.107137f, -0.033638f},
{-0.000381f, 0.000279f, 0.025877f},
{-0.014216f, -0.003975f, 0.040528f},
{0.001638f, -0.000748f, 0.011022f},
{0.029238f, -0.006697f, 0.035933f},
{0.021641f, -0.015874f, 0.040531f},
{-0.101984f, -0.070160f, -0.028855f},
{0.033207f, -0.021068f, 0.002663f},
{-0.104711f, 0.121673f, 0.102981f},
{0.082647f, -0.004991f, 0.057237f},
{-0.027375f, 0.031581f, 0.006868f},
{-0.045434f, 0.029444f, 0.019287f},
{-0.046572f, -0.012537f, 0.006675f},
{0.074709f, 0.033690f, 0.025289f},
{-0.008251f, -0.002745f, -0.006999f},
{0.012685f, -0.061856f, -0.048658f},
{0.042304f, -0.007039f, 0.000295f},
{-0.007644f, -0.060843f, -0.033142f},
{0.159909f, 0.045628f, 0.367541f},
{0.095171f, 0.086438f, 0.010271f},
{0.006812f, 0.019643f, 0.029637f},
{0.003467f, -0.010705f, 0.014252f},
{-0.099681f, -0.066272f, -0.006243f},
{0.047357f, 0.037040f, 0.000185f},
{-0.041797f, -0.089225f, -0.032257f},
{0.008928f, 0.017028f, 0.018684f},
{-0.042255f, 0.016045f, 0.006849f},
{0.011268f, 0.036462f, 0.037387f},
{0.011553f, -0.016375f, -0.048589f},
{0.046266f, -0.027189f, 0.056979f},
{0.009640f, -0.017576f, 0.030324f},
{-0.045794f, -0.036083f, -0.010616f},
{0.022418f, 0.039783f, -0.032939f},
{-0.052714f, -0.015525f, 0.007438f},
{0.193004f, 0.223541f, 0.264175f},
{-0.059406f, -0.008188f, 0.022867f},
{-0.156742f, -0.263791f, -0.007385f},
{-0.015717f, 0.016570f, 0.033969f},
{0.037969f, 0.109835f, 0.200449f},
{-0.000782f, -0.009566f, -0.008058f},
{0.010709f, 0.052960f, -0.044195f},
{0.017271f, 0.045839f, 0.034569f},
{0.009424f, 0.013088f, -0.001714f},
{-0.024805f, -0.059378f, -0.033756f},
{-0.078293f, 0.029070f, 0.026129f}};
float wan_22_latent_rgb_bias[3] = {0.013160f, -0.096492f, -0.071323f};
const float flux_latent_rgb_proj[16][3] = {
{-0.041168f, 0.019917f, 0.097253f},
{0.028096f, 0.026730f, 0.129576f},
{0.065618f, -0.067950f, -0.014651f},
{-0.012998f, -0.014762f, 0.081251f},
{0.078567f, 0.059296f, -0.024687f},
{-0.015987f, -0.003697f, 0.005012f},
{0.033605f, 0.138999f, 0.068517f},
{-0.024450f, -0.063567f, -0.030101f},
{-0.040194f, -0.016710f, 0.127185f},
{0.112681f, 0.088764f, -0.041940f},
{-0.023498f, 0.093664f, 0.025543f},
{0.082899f, 0.048320f, 0.007491f},
{0.075712f, 0.074139f, 0.081965f},
{-0.143501f, 0.018263f, -0.136138f},
{-0.025767f, -0.082035f, -0.040023f},
{-0.111849f, -0.055589f, -0.032361f}};
float flux_latent_rgb_bias[3] = {0.024600f, -0.006937f, -0.008089f};
// This one was taken straight from
// https://github.com/Stability-AI/sd3.5/blob/8565799a3b41eb0c7ba976d18375f0f753f56402/sd3_impls.py#L288-L303
// (MIT License)
const float sd3_latent_rgb_proj[16][3] = {
{-0.0645f, 0.0177f, 0.1052f},
{0.0028f, 0.0312f, 0.0650f},
{0.1848f, 0.0762f, 0.0360f},
{0.0944f, 0.0360f, 0.0889f},
{0.0897f, 0.0506f, -0.0364f},
{-0.0020f, 0.1203f, 0.0284f},
{0.0855f, 0.0118f, 0.0283f},
{-0.0539f, 0.0658f, 0.1047f},
{-0.0057f, 0.0116f, 0.0700f},
{-0.0412f, 0.0281f, -0.0039f},
{0.1106f, 0.1171f, 0.1220f},
{-0.0248f, 0.0682f, -0.0481f},
{0.0815f, 0.0846f, 0.1207f},
{-0.0120f, -0.0055f, -0.0867f},
{-0.0749f, -0.0634f, -0.0456f},
{-0.1418f, -0.1457f, -0.1259f},
};
float sd3_latent_rgb_bias[3] = {0, 0, 0};
const float sdxl_latent_rgb_proj[4][3] = {
{0.258303f, 0.277640f, 0.329699f},
{-0.299701f, 0.105446f, 0.014194f},
{0.050522f, 0.186163f, -0.143257f},
{-0.211938f, -0.149892f, -0.080036f}};
float sdxl_latent_rgb_bias[3] = {0.144381f, -0.033313f, 0.007061f};
const float sd_latent_rgb_proj[4][3] = {
{0.337366f, 0.216344f, 0.257386f},
{0.165636f, 0.386828f, 0.046994f},
{-0.267803f, 0.237036f, 0.223517f},
{-0.178022f, -0.200862f, -0.678514f}};
float sd_latent_rgb_bias[3] = {-0.017478f, -0.055834f, -0.105825f};
void preview_latent_video(uint8_t* buffer, struct ggml_tensor* latents, const float (*latent_rgb_proj)[3], const float latent_rgb_bias[3], int width, int height, int frames, int dim) {
size_t buffer_head = 0;
for (int k = 0; k < frames; k++) {
for (int j = 0; j < height; j++) {
for (int i = 0; i < width; i++) {
size_t latent_id = (i * latents->nb[0] + j * latents->nb[1] + k * latents->nb[2]);
float r = 0, g = 0, b = 0;
if (latent_rgb_proj != nullptr) {
for (int d = 0; d < dim; d++) {
float value = *(float*)((char*)latents->data + latent_id + d * latents->nb[ggml_n_dims(latents) - 1]);
r += value * latent_rgb_proj[d][0];
g += value * latent_rgb_proj[d][1];
b += value * latent_rgb_proj[d][2];
}
} else {
// interpret first 3 channels as RGB
r = *(float*)((char*)latents->data + latent_id + 0 * latents->nb[ggml_n_dims(latents) - 1]);
g = *(float*)((char*)latents->data + latent_id + 1 * latents->nb[ggml_n_dims(latents) - 1]);
b = *(float*)((char*)latents->data + latent_id + 2 * latents->nb[ggml_n_dims(latents) - 1]);
}
if (latent_rgb_bias != nullptr) {
// bias
r += latent_rgb_bias[0];
g += latent_rgb_bias[1];
b += latent_rgb_bias[2];
}
// change range
r = r * .5f + .5f;
g = g * .5f + .5f;
b = b * .5f + .5f;
// clamp rgb values to [0,1] range
r = r >= 0 ? r <= 1 ? r : 1 : 0;
g = g >= 0 ? g <= 1 ? g : 1 : 0;
b = b >= 0 ? b <= 1 ? b : 1 : 0;
buffer[buffer_head++] = (uint8_t)(r * 255);
buffer[buffer_head++] = (uint8_t)(g * 255);
buffer[buffer_head++] = (uint8_t)(b * 255);
}
}
}
}
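For a quick sanity check of the projection tables, the header can be used on its own. The following is a rough, self-contained sketch (not part of the commit): it builds a blank 4-channel SD1.x-style latent with ggml, projects it to RGB with `preview_latent_video`, and writes the result with `stbi_write_png` as the CLI example does. The include paths and buffer size are assumptions.

```
#include <cstdint>
#include <cstdlib>
#include "ggml.h"
#include "latent-preview.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"  // assumed to be on the include path, as in the CLI example

int main() {
    const int W = 64, H = 64, C = 4;  // latent width, height and channel count

    struct ggml_init_params ip = {16 * 1024 * 1024, nullptr, /*no_alloc*/ false};
    struct ggml_context* ctx = ggml_init(ip);

    // [W, H, C]: nb[0]/nb[1] step over x/y and the last dim is the latent channel,
    // which is exactly how preview_latent_video indexes the tensor.
    struct ggml_tensor* latents = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, W, H, C);
    ggml_set_f32(latents, 0.0f);  // placeholder latent content

    uint8_t* rgb = (uint8_t*)malloc((size_t)W * H * 3);
    preview_latent_video(rgb, latents, sd_latent_rgb_proj, sd_latent_rgb_bias,
                         W, H, /*frames*/ 1, /*dim*/ C);
    stbi_write_png("latent_preview_test.png", W, H, 3, rgb, 0);

    free(rgb);
    ggml_free(ctx);
    return 0;
}
```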

View File

@ -16,6 +16,8 @@
#include "tae.hpp"
#include "vae.hpp"
#include "latent-preview.h"
const char* model_version_to_str[] = {
"SD 1.x",
"SD 1.x Inpaint",
@ -74,6 +76,14 @@ void calculate_alphas_cumprod(float* alphas_cumprod,
}
}
void suppress_pp(int step, int steps, float time, void* data) {
(void)step;
(void)steps;
(void)time;
(void)data;
return;
}
/*=============================================== StableDiffusionGGML ================================================*/
class StableDiffusionGGML {
@ -487,7 +497,7 @@ public:
} else if (version == VERSION_CHROMA_RADIANCE) {
first_stage_model = std::make_shared<FakeVAE>(vae_backend,
offload_params_to_cpu);
} else if (!use_tiny_autoencoder) {
} else if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) {
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
@ -510,7 +520,8 @@ public:
}
first_stage_model->alloc_params_buffer();
first_stage_model->get_param_tensors(tensors, "first_stage_model");
} else {
}
if (use_tiny_autoencoder) {
tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
offload_params_to_cpu,
tensor_storage_map,
@ -626,9 +637,10 @@ public:
unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size();
}
size_t vae_params_mem_size = 0;
if (!use_tiny_autoencoder) {
if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) {
vae_params_mem_size = first_stage_model->get_params_buffer_size();
} else {
}
if (use_tiny_autoencoder) {
if (!tae_first_stage->load_from_file(taesd_path, n_threads)) {
return false;
}
@ -801,6 +813,7 @@ public:
LOG_DEBUG("finished loaded file");
ggml_free(ctx);
use_tiny_autoencoder = use_tiny_autoencoder && !sd_ctx_params->tae_preview_only;
return true;
}
@ -1109,6 +1122,156 @@ public:
}
}
void silent_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) {
sd_progress_cb_t cb = sd_get_progress_callback();
void* cbd = sd_get_progress_callback_data();
sd_set_progress_callback((sd_progress_cb_t)suppress_pp, nullptr);
sd_tiling(input, output, scale, tile_size, tile_overlap_factor, on_processing);
sd_set_progress_callback(cb, cbd);
}
void preview_image(ggml_context* work_ctx,
int step,
struct ggml_tensor* latents,
enum SDVersion version,
preview_t preview_mode,
ggml_tensor* result,
std::function<void(int, int, sd_image_t*, bool)> step_callback,
bool is_noisy) {
const uint32_t channel = 3;
uint32_t width = latents->ne[0];
uint32_t height = latents->ne[1];
uint32_t dim = latents->ne[ggml_n_dims(latents) - 1];
if (preview_mode == PREVIEW_PROJ) {
const float(*latent_rgb_proj)[channel] = nullptr;
float* latent_rgb_bias = nullptr;
if (dim == 48) {
if (sd_version_is_wan(version)) {
latent_rgb_proj = wan_22_latent_rgb_proj;
latent_rgb_bias = wan_22_latent_rgb_bias;
} else {
LOG_WARN("No latent to RGB projection known for this model");
// unknown model
return;
}
} else if (dim == 16) {
// 16 channels VAE -> Flux or SD3
if (sd_version_is_sd3(version)) {
latent_rgb_proj = sd3_latent_rgb_proj;
latent_rgb_bias = sd3_latent_rgb_bias;
} else if (sd_version_is_flux(version)) {
latent_rgb_proj = flux_latent_rgb_proj;
latent_rgb_bias = flux_latent_rgb_bias;
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
latent_rgb_proj = wan_21_latent_rgb_proj;
latent_rgb_bias = wan_21_latent_rgb_bias;
} else {
LOG_WARN("No latent to RGB projection known for this model");
// unknown model
return;
}
} else if (dim == 4) {
// 4 channels VAE
if (sd_version_is_sdxl(version)) {
latent_rgb_proj = sdxl_latent_rgb_proj;
latent_rgb_bias = sdxl_latent_rgb_bias;
} else if (sd_version_is_sd1(version) || sd_version_is_sd2(version)) {
latent_rgb_proj = sd_latent_rgb_proj;
latent_rgb_bias = sd_latent_rgb_bias;
} else {
// unknown model
LOG_WARN("No latent to RGB projection known for this model");
return;
}
} else if (dim == 3) {
// Do nothing, assuming already RGB latents
} else {
LOG_WARN("No latent to RGB projection known for this model");
// unknown latent space
return;
}
uint32_t frames = 1;
if (ggml_n_dims(latents) == 4) {
frames = latents->ne[2];
}
uint8_t* data = (uint8_t*)malloc(frames * width * height * channel * sizeof(uint8_t));
preview_latent_video(data, latents, latent_rgb_proj, latent_rgb_bias, width, height, frames, dim);
sd_image_t* images = (sd_image_t*)malloc(frames * sizeof(sd_image_t));
for (int i = 0; i < frames; i++) {
images[i] = {width, height, channel, data + i * width * height * channel};
}
step_callback(step, frames, images, is_noisy);
free(data);
free(images);
} else {
if (preview_mode == PREVIEW_VAE) {
process_latent_out(latents);
if (vae_tiling_params.enabled) {
// split latent in 32x32 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
first_stage_model->compute(n_threads, in, true, &out, nullptr);
};
silent_tiling(latents, result, get_vae_scale_factor(), 32, 0.5f, on_tiling);
} else {
first_stage_model->compute(n_threads, latents, true, &result, work_ctx);
}
first_stage_model->free_compute_buffer();
process_vae_output_tensor(result);
process_latent_in(latents);
} else if (preview_mode == PREVIEW_TAE) {
if (tae_first_stage == nullptr) {
LOG_WARN("TAE not found for preview");
return;
}
if (vae_tiling_params.enabled) {
// split latent in 64x64 tiles and compute in several steps
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
tae_first_stage->compute(n_threads, in, true, &out, nullptr);
};
silent_tiling(latents, result, get_vae_scale_factor(), 64, 0.5f, on_tiling);
} else {
tae_first_stage->compute(n_threads, latents, true, &result, work_ctx);
}
tae_first_stage->free_compute_buffer();
} else {
return;
}
ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
uint32_t frames = 1;
if (ggml_n_dims(latents) == 4) {
frames = result->ne[2];
}
sd_image_t* images = (sd_image_t*)malloc(frames * sizeof(sd_image_t));
// print_ggml_tensor(result,true);
for (size_t i = 0; i < frames; i++) {
images[i].width = result->ne[0];
images[i].height = result->ne[1];
images[i].channel = 3;
images[i].data = ggml_tensor_to_sd_image(result, i, ggml_n_dims(latents) == 4);
}
step_callback(step, frames, images, is_noisy);
ggml_ext_tensor_scale_inplace(result, 0);
for (int i = 0; i < frames; i++) {
free(images[i].data);
}
free(images);
}
}
ggml_tensor* sample(ggml_context* work_ctx,
std::shared_ptr<DiffusionModel> work_diffusion_model,
bool inverse_noise_scaling,
@ -1184,7 +1347,34 @@ public:
int64_t t0 = ggml_time_us();
struct ggml_tensor* preview_tensor = nullptr;
auto sd_preview_mode = sd_get_preview_mode();
if (sd_preview_mode != PREVIEW_NONE && sd_preview_mode != PREVIEW_PROJ) {
int64_t W = x->ne[0] * get_vae_scale_factor();
int64_t H = x->ne[1] * get_vae_scale_factor();
if (ggml_n_dims(x) == 4) {
// assuming video mode (if batch processing gets implemented this will break)
int T = x->ne[2];
if (sd_version_is_wan(version)) {
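// Wan's VAE uses 4x temporal compression, so e.g. 21 latent frames decode to (21 - 1) * 4 + 1 = 81 preview frames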
T = ((T - 1) * 4) + 1;
}
preview_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32,
W,
H,
T,
3);
} else {
preview_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32,
W,
H,
3,
x->ne[3]);
}
}
auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
auto sd_preview_cb = sd_get_preview_callback();
auto sd_preview_mode = sd_get_preview_mode();
if (step == 1 || step == -1) {
pretty_progress(0, (int)steps, 0);
}
@ -1219,6 +1409,11 @@ public:
if (denoise_mask != nullptr && version == VERSION_WAN2_2_TI2V) {
apply_mask(noised_input, init_latent, denoise_mask);
}
if (sd_preview_cb != nullptr && sd_should_preview_noisy()) {
if (step % sd_get_preview_interval() == 0) {
preview_image(work_ctx, step, noised_input, version, sd_preview_mode, preview_tensor, sd_preview_cb, true);
}
}
std::vector<struct ggml_tensor*> controls;
@ -1340,16 +1535,22 @@ public:
vec_denoised[i] = latent_result * c_out + vec_input[i] * c_skip;
}
if (denoise_mask != nullptr) {
apply_mask(denoised, init_latent, denoise_mask);
}
if (sd_preview_cb != nullptr && sd_should_preview_denoised()) {
if (step % sd_get_preview_interval() == 0) {
preview_image(work_ctx, step, denoised, version, sd_preview_mode, preview_tensor, sd_preview_cb, false);
}
}
int64_t t1 = ggml_time_us();
if (step > 0 || step == -(int)steps) {
int showstep = std::abs(step);
pretty_progress(showstep, (int)steps, (t1 - t0) / 1000000.f / showstep);
// LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000);
}
if (denoise_mask != nullptr) {
apply_mask(denoised, init_latent, denoise_mask);
}
return denoised;
};
@ -1855,6 +2056,29 @@ enum prediction_t str_to_prediction(const char* str) {
return PREDICTION_COUNT;
}
const char* preview_to_str[] = {
"none",
"proj",
"tae",
"vae",
};
const char* sd_preview_name(enum preview_t preview) {
if (preview < PREVIEW_COUNT) {
return preview_to_str[preview];
}
return NONE_STR;
}
enum preview_t str_to_preview(const char* str) {
for (int i = 0; i < PREVIEW_COUNT; i++) {
if (!strcmp(str, preview_to_str[i])) {
return (enum preview_t)i;
}
}
return PREVIEW_COUNT;
}
void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
*sd_ctx_params = {};
sd_ctx_params->vae_decode_only = true;

View File

@ -126,6 +126,14 @@ enum sd_log_level_t {
SD_LOG_ERROR
};
enum preview_t {
PREVIEW_NONE,
PREVIEW_PROJ,
PREVIEW_TAE,
PREVIEW_VAE,
PREVIEW_COUNT
};
typedef struct {
bool enabled;
int tile_size_x;
@ -162,6 +170,7 @@ typedef struct {
bool keep_control_net_on_cpu;
bool keep_vae_on_cpu;
bool diffusion_flash_attn;
bool tae_preview_only;
bool diffusion_conv_direct;
bool vae_conv_direct;
bool force_sdxl_vae_conv_scale;
@ -254,9 +263,11 @@ typedef struct sd_ctx_t sd_ctx_t;
typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data);
typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data);
typedef void (*sd_preview_cb_t)(int step, int frame_count, sd_image_t* frames, bool is_noisy);
SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data);
SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data);
SD_API void sd_set_preview_callback(sd_preview_cb_t cb, preview_t mode, int interval, bool denoised, bool noisy);
SD_API int32_t get_num_physical_cores();
SD_API const char* sd_get_system_info();
@ -270,6 +281,8 @@ SD_API const char* sd_schedule_name(enum scheduler_t scheduler);
SD_API enum scheduler_t str_to_schedule(const char* str);
SD_API const char* sd_prediction_name(enum prediction_t prediction);
SD_API enum prediction_t str_to_prediction(const char* str);
SD_API const char* sd_preview_name(enum preview_t preview);
SD_API enum preview_t str_to_preview(const char* str);
SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params);
SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params);
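The new public surface in this header is small: the `preview_t` enum, the `sd_preview_cb_t` callback type, `sd_set_preview_callback`, the name helpers, and the `tae_preview_only` context flag. A minimal sketch of a client wiring these together (illustrative only; the `stable-diffusion.h` include name is an assumption, and a taesd model must still be configured on the context for the `tae` method to produce previews):

```
#include <cstdio>
#include "stable-diffusion.h"  // public header (assumed name)

// Receives decoded preview frames every `interval` denoising steps.
static void on_preview(int step, int frame_count, sd_image_t* frames, bool is_noisy) {
    printf("step %d: %d preview frame(s), %ux%u, %s\n",
           step, frame_count,
           (unsigned)frames[0].width, (unsigned)frames[0].height,
           is_noisy ? "noisy" : "denoised");
}

void configure_preview(sd_ctx_params_t* params, const char* method_name) {
    enum preview_t mode = str_to_preview(method_name);  // e.g. "tae" -> PREVIEW_TAE
    if (mode == PREVIEW_COUNT) {
        mode = PREVIEW_NONE;  // unknown string, disable previews
    }
    // Keep taesd for previews only; the final image still goes through the full VAE.
    params->tae_preview_only = (mode == PREVIEW_TAE);
    // Report denoised previews every 2 steps, skip the noisy inputs.
    sd_set_preview_callback(on_preview, mode, /*interval*/ 2, /*denoised*/ true, /*noisy*/ false);
}
```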

View File

@ -185,6 +185,12 @@ int32_t get_num_physical_cores() {
static sd_progress_cb_t sd_progress_cb = nullptr;
void* sd_progress_cb_data = nullptr;
static sd_preview_cb_t sd_preview_cb = nullptr;
preview_t sd_preview_mode = PREVIEW_NONE;
int sd_preview_interval = 1;
bool sd_preview_denoised = true;
bool sd_preview_noisy = false;
std::u32string utf8_to_utf32(const std::string& utf8_str) {
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
return converter.from_bytes(utf8_str);
@ -328,6 +334,37 @@ void sd_set_progress_callback(sd_progress_cb_t cb, void* data) {
sd_progress_cb = cb;
sd_progress_cb_data = data;
}
void sd_set_preview_callback(sd_preview_cb_t cb, preview_t mode = PREVIEW_PROJ, int interval = 1, bool denoised = true, bool noisy = false) {
sd_preview_cb = cb;
sd_preview_mode = mode;
sd_preview_interval = interval;
sd_preview_denoised = denoised;
sd_preview_noisy = noisy;
}
sd_preview_cb_t sd_get_preview_callback() {
return sd_preview_cb;
}
preview_t sd_get_preview_mode() {
return sd_preview_mode;
}
int sd_get_preview_interval() {
return sd_preview_interval;
}
bool sd_should_preview_denoised() {
return sd_preview_denoised;
}
bool sd_should_preview_noisy() {
return sd_preview_noisy;
}
sd_progress_cb_t sd_get_progress_callback() {
return sd_progress_cb;
}
void* sd_get_progress_callback_data() {
return sd_progress_cb_data;
}
const char* sd_get_system_info() {
static char buffer[1024];
std::stringstream ss;

util.h (9 lines changed)

@ -54,6 +54,15 @@ std::string trim(const std::string& s);
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text);
sd_progress_cb_t sd_get_progress_callback();
void* sd_get_progress_callback_data();
sd_preview_cb_t sd_get_preview_callback();
preview_t sd_get_preview_mode();
int sd_get_preview_interval();
bool sd_should_preview_denoised();
bool sd_should_preview_noisy();
#define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_WARN(format, ...) log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__)
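These getters are what the sampling loop in stable-diffusion.cpp consults before emitting a preview. Condensed into one place (a restatement of the gating shown in the diff above, assuming util.h is included; not new behavior):

```
// A denoised preview is emitted only when a callback is registered, the step
// lands on the configured interval, and the denoised side is enabled;
// the noisy side checks sd_should_preview_noisy() on the noised input instead.
static bool should_emit_denoised_preview(int step) {
    return sd_get_preview_callback() != nullptr &&
           sd_should_preview_denoised() &&
           step % sd_get_preview_interval() == 0;
}
```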