#ifndef __VAE_HPP__
#define __VAE_HPP__

#include "common_block.hpp"

struct VAE : public GGMLRunner {
protected:
    SDVersion version;
    bool scale_input = true;

    virtual bool _compute(const int n_threads,
                          ggml_tensor* z,
                          bool decode_graph,
                          ggml_tensor** output,
                          ggml_context* output_ctx) = 0;

public:
    VAE(SDVersion version, ggml_backend_t backend, bool offload_params_to_cpu)
        : GGMLRunner(backend, offload_params_to_cpu), version(version) {}

    // Spatial scale factor between pixel space and latent space.
    int get_scale_factor() {
        int scale_factor = 8;
        if (version == VERSION_WAN2_2_TI2V) {
            scale_factor = 16;
        } else if (sd_version_is_flux2(version)) {
            scale_factor = 16;
        } else if (version == VERSION_CHROMA_RADIANCE) {
            scale_factor = 1;
        }
        return scale_factor;
    }

    virtual int get_encoder_output_channels(int input_channels) = 0;

    void get_tile_sizes(int& tile_size_x,
                        int& tile_size_y,
                        float& tile_overlap,
                        const sd_tiling_params_t& params,
                        int64_t latent_x,
                        int64_t latent_y,
                        float encoding_factor = 1.0f) {
        tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f);

        auto get_tile_size = [&](int requested_size, float factor, int64_t latent_size) {
            const int default_tile_size  = 32;
            const int min_tile_dimension = 4;
            int tile_size                = default_tile_size;
            // factor <= 1 means a simple fraction of the latent dimension,
            // factor > 1 means the number of tiles across that dimension
            if (factor > 0.f) {
                if (factor > 1.0)
                    factor = 1 / (factor - factor * tile_overlap + tile_overlap);
                tile_size = static_cast<int>(std::round(latent_size * factor));
            } else if (requested_size >= min_tile_dimension) {
                tile_size = requested_size;
            }
            tile_size = static_cast<int>(tile_size * encoding_factor);
            return std::max(std::min(tile_size, static_cast<int>(latent_size)), min_tile_dimension);
        };

        tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
        tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
    }
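    // Worked example for get_tile_sizes (illustrative only, not part of the API):
    // with rel_size = 2.0 (two tiles across) and target_overlap = 0.25, the factor
    // becomes 1 / (2 - 2 * 0.25 + 0.25) = 1 / 1.75 ~= 0.571, so each tile spans
    // about 57% of the latent dimension; two such tiles, overlapping by 25% of a
    // tile, cover the whole dimension: 0.571 + (0.571 - 0.143) = 1.0.
    // With rel_size = 0.5 the factor is used directly and each tile is half the
    // latent dimension.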
    ggml_tensor* encode(int n_threads,
                        ggml_context* work_ctx,
                        ggml_tensor* x,
                        sd_tiling_params_t tiling_params,
                        bool circular_x = false,
                        bool circular_y = false) {
        int64_t t0          = ggml_time_ms();
        ggml_tensor* result = nullptr;

        const int scale_factor = get_scale_factor();
        int64_t W              = x->ne[0] / scale_factor;
        int64_t H              = x->ne[1] / scale_factor;
        // Wan video models keep channels on ne[3]; image models keep them on ne[2].
        int channel_dim        = sd_version_is_wan(version) ? 3 : 2;
        int64_t C              = get_encoder_output_channels(static_cast<int>(x->ne[channel_dim]));
        int64_t ne2;
        int64_t ne3;
        if (sd_version_is_wan(version)) {
            // Wan compresses time by 4: T input frames map to (T - 1) / 4 + 1 latent frames.
            int64_t T = x->ne[2];
            ne2       = (T - 1) / 4 + 1;
            ne3       = C;
        } else {
            ne2 = C;
            ne3 = x->ne[3];
        }
        result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);

        if (scale_input) {
            scale_to_minus1_1(x);
        }
        if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
            x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
        }
        if (tiling_params.enabled) {
            float tile_overlap;
            int tile_size_x, tile_size_y;
            // multiply tile size for encode to keep the compute buffer size consistent
            get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, tiling_params, W, H, 1.30539f);
            LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
            auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
                return _compute(n_threads, in, false, &out, work_ctx);
            };
            sd_tiling_non_square(x, result, scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
        } else {
            _compute(n_threads, x, false, &result, work_ctx);
        }
        free_compute_buffer();

        int64_t t1 = ggml_time_ms();
        LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
        return result;
    }

    ggml_tensor* decode(int n_threads,
                        ggml_context* work_ctx,
                        ggml_tensor* x,
                        sd_tiling_params_t tiling_params,
                        bool decode_video   = false,
                        bool circular_x     = false,
                        bool circular_y     = false,
                        ggml_tensor* result = nullptr,
                        bool silent         = false) {
        const int scale_factor = get_scale_factor();
        int64_t W              = x->ne[0] * scale_factor;
        int64_t H              = x->ne[1] * scale_factor;
        int64_t C              = 3;
        if (result == nullptr) {
            if (decode_video) {
                int64_t T = x->ne[2];
                if (sd_version_is_wan(version)) {
                    // Inverse of the encoder's temporal compression: latent frames back to pixel frames.
                    T = ((T - 1) * 4) + 1;
                }
                result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, 3);
            } else {
                result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, x->ne[3]);
            }
        }

        int64_t t0 = ggml_time_ms();
        if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
            x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
        }
        if (tiling_params.enabled) {
            float tile_overlap;
            int tile_size_x, tile_size_y;
            get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, tiling_params, x->ne[0], x->ne[1]);
            if (!silent) {
                LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
            }
            auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
                return _compute(n_threads, in, true, &out, nullptr);
            };
            sd_tiling_non_square(x, result, scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling, silent);
        } else {
            if (!_compute(n_threads, x, true, &result, work_ctx)) {
                LOG_ERROR("Failed to decode latents");
                free_compute_buffer();
                return nullptr;
            }
        }
        free_compute_buffer();
        if (scale_input) {
            scale_to_0_1(result);
        }
        int64_t t1 = ggml_time_ms();
        LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
        ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
        return result;
    }

    virtual ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) = 0;
    virtual ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents)                           = 0;
    virtual ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents)                           = 0;
    virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix)         = 0;
    virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); }
};
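// Minimal usage sketch (illustrative only): a caller with a concrete VAE
// instance could round-trip an image roughly as below. Names such as `vae`,
// `work_ctx`, `rng`, `image`, and `n_threads` are assumptions about the
// surrounding code, not definitions provided by this header.
//
//     sd_tiling_params_t tiling = {};  // tiling disabled
//     ggml_tensor* vae_out = vae->encode(n_threads, work_ctx, image, tiling);
//     ggml_tensor* latents = vae->vae_output_to_latents(work_ctx, vae_out, rng);
//     // ... diffusion would operate on vae->vae_to_diffuison_latents(work_ctx, latents) ...
//     ggml_tensor* decoded = vae->decode(n_threads, work_ctx, latents, tiling);
//     if (decoded == nullptr) { /* decode failed */ }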
struct FakeVAE : public VAE {
    FakeVAE(SDVersion version, ggml_backend_t backend, bool offload_params_to_cpu)
        : VAE(version, backend, offload_params_to_cpu) {}

    int get_encoder_output_channels(int input_channels) override {
        return input_channels;
    }

    bool _compute(const int n_threads,
                  ggml_tensor* z,
                  bool decode_graph,
                  ggml_tensor** output,
                  ggml_context* output_ctx) override {
        if (*output == nullptr && output_ctx != nullptr) {
            *output = ggml_dup_tensor(output_ctx, z);
        }
        // Pass-through: copy every element of z into the output tensor unchanged.
        ggml_ext_tensor_iter(z, [&](ggml_tensor* z, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
            float value = ggml_ext_tensor_get_f32(z, i0, i1, i2, i3);
            ggml_ext_tensor_set_f32(*output, value, i0, i1, i2, i3);
        });
        return true;
    }

    ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) override {
        return vae_output;
    }

    ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) override {
        return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
    }

    ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) override {
        return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
    }

    void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {}

    std::string get_desc() override {
        return "fake_vae";
    }
};

#endif  // __VAE_HPP__