mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-23 17:58:58 +00:00
236 lines
9.7 KiB
C++
236 lines
9.7 KiB
C++
#ifndef __VAE_HPP__
|
|
#define __VAE_HPP__
|
|
|
|
#include "common_block.hpp"
|
|
|
|
struct VAE : public GGMLRunner {
|
|
protected:
|
|
SDVersion version;
|
|
bool scale_input = true;
|
|
virtual bool _compute(const int n_threads,
|
|
ggml_tensor* z,
|
|
bool decode_graph,
|
|
ggml_tensor** output,
|
|
ggml_context* output_ctx) = 0;
|
|
|
|
public:
|
|
VAE(SDVersion version, ggml_backend_t backend, bool offload_params_to_cpu)
|
|
: version(version), GGMLRunner(backend, offload_params_to_cpu) {}
|
|
|
|
int get_scale_factor() {
|
|
int scale_factor = 8;
|
|
if (version == VERSION_WAN2_2_TI2V) {
|
|
scale_factor = 16;
|
|
} else if (sd_version_is_flux2(version)) {
|
|
scale_factor = 16;
|
|
} else if (version == VERSION_CHROMA_RADIANCE) {
|
|
scale_factor = 1;
|
|
}
|
|
return scale_factor;
|
|
}
|
|
|
|
virtual int get_encoder_output_channels(int input_channels) = 0;
|
|
|
|
void get_tile_sizes(int& tile_size_x,
|
|
int& tile_size_y,
|
|
float& tile_overlap,
|
|
const sd_tiling_params_t& params,
|
|
int64_t latent_x,
|
|
int64_t latent_y,
|
|
float encoding_factor = 1.0f) {
|
|
tile_overlap = std::max(std::min(params.target_overlap, 0.5f), 0.0f);
|
|
auto get_tile_size = [&](int requested_size, float factor, int64_t latent_size) {
|
|
const int default_tile_size = 32;
|
|
const int min_tile_dimension = 4;
|
|
int tile_size = default_tile_size;
|
|
// factor <= 1 means simple fraction of the latent dimension
|
|
// factor > 1 means number of tiles across that dimension
|
|
if (factor > 0.f) {
|
|
if (factor > 1.0)
|
|
factor = 1 / (factor - factor * tile_overlap + tile_overlap);
|
|
tile_size = static_cast<int>(std::round(latent_size * factor));
|
|
} else if (requested_size >= min_tile_dimension) {
|
|
tile_size = requested_size;
|
|
}
|
|
tile_size = static_cast<int>(tile_size * encoding_factor);
|
|
return std::max(std::min(tile_size, static_cast<int>(latent_size)), min_tile_dimension);
|
|
};
|
|
|
|
tile_size_x = get_tile_size(params.tile_size_x, params.rel_size_x, latent_x);
|
|
tile_size_y = get_tile_size(params.tile_size_y, params.rel_size_y, latent_y);
|
|
}
|
|
|
|
ggml_tensor* encode(int n_threads,
|
|
ggml_context* work_ctx,
|
|
ggml_tensor* x,
|
|
sd_tiling_params_t tiling_params,
|
|
bool circular_x = false,
|
|
bool circular_y = false) {
|
|
int64_t t0 = ggml_time_ms();
|
|
ggml_tensor* result = nullptr;
|
|
const int scale_factor = get_scale_factor();
|
|
int64_t W = x->ne[0] / scale_factor;
|
|
int64_t H = x->ne[1] / scale_factor;
|
|
int channel_dim = sd_version_is_wan(version) ? 3 : 2;
|
|
int64_t C = get_encoder_output_channels(static_cast<int>(x->ne[channel_dim]));
|
|
int64_t ne2;
|
|
int64_t ne3;
|
|
if (sd_version_is_wan(version)) {
|
|
int64_t T = x->ne[2];
|
|
ne2 = (T - 1) / 4 + 1;
|
|
ne3 = C;
|
|
} else {
|
|
ne2 = C;
|
|
ne3 = x->ne[3];
|
|
}
|
|
result = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, ne2, ne3);
|
|
|
|
if (scale_input) {
|
|
scale_to_minus1_1(x);
|
|
}
|
|
|
|
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
|
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
|
|
}
|
|
|
|
if (tiling_params.enabled) {
|
|
float tile_overlap;
|
|
int tile_size_x, tile_size_y;
|
|
// multiply tile size for encode to keep the compute buffer size consistent
|
|
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, tiling_params, W, H, 1.30539f);
|
|
|
|
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
|
|
|
|
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
|
return _compute(n_threads, in, false, &out, work_ctx);
|
|
};
|
|
sd_tiling_non_square(x, result, scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling);
|
|
} else {
|
|
_compute(n_threads, x, false, &result, work_ctx);
|
|
}
|
|
free_compute_buffer();
|
|
|
|
int64_t t1 = ggml_time_ms();
|
|
LOG_DEBUG("computing vae encode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
|
return result;
|
|
}
|
|
|
|
ggml_tensor* decode(int n_threads,
|
|
ggml_context* work_ctx,
|
|
ggml_tensor* x,
|
|
sd_tiling_params_t tiling_params,
|
|
bool decode_video = false,
|
|
bool circular_x = false,
|
|
bool circular_y = false,
|
|
ggml_tensor* result = nullptr,
|
|
bool silent = false) {
|
|
const int scale_factor = get_scale_factor();
|
|
int64_t W = x->ne[0] * scale_factor;
|
|
int64_t H = x->ne[1] * scale_factor;
|
|
int64_t C = 3;
|
|
if (result == nullptr) {
|
|
if (decode_video) {
|
|
int64_t T = x->ne[2];
|
|
if (sd_version_is_wan(version)) {
|
|
T = ((T - 1) * 4) + 1;
|
|
}
|
|
result = ggml_new_tensor_4d(work_ctx,
|
|
GGML_TYPE_F32,
|
|
W,
|
|
H,
|
|
T,
|
|
3);
|
|
} else {
|
|
result = ggml_new_tensor_4d(work_ctx,
|
|
GGML_TYPE_F32,
|
|
W,
|
|
H,
|
|
C,
|
|
x->ne[3]);
|
|
}
|
|
}
|
|
int64_t t0 = ggml_time_ms();
|
|
if (sd_version_is_qwen_image(version) || sd_version_is_anima(version)) {
|
|
x = ggml_reshape_4d(work_ctx, x, x->ne[0], x->ne[1], 1, x->ne[2] * x->ne[3]);
|
|
}
|
|
if (tiling_params.enabled) {
|
|
float tile_overlap;
|
|
int tile_size_x, tile_size_y;
|
|
get_tile_sizes(tile_size_x, tile_size_y, tile_overlap, tiling_params, x->ne[0], x->ne[1]);
|
|
|
|
if (!silent) {
|
|
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
|
|
}
|
|
|
|
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
|
|
return _compute(n_threads, in, true, &out, nullptr);
|
|
};
|
|
sd_tiling_non_square(x, result, scale_factor, tile_size_x, tile_size_y, tile_overlap, circular_x, circular_y, on_tiling, silent);
|
|
} else {
|
|
if (!_compute(n_threads, x, true, &result, work_ctx)) {
|
|
LOG_ERROR("Failed to decode latetnts");
|
|
free_compute_buffer();
|
|
return nullptr;
|
|
}
|
|
}
|
|
free_compute_buffer();
|
|
if (scale_input) {
|
|
scale_to_0_1(result);
|
|
}
|
|
int64_t t1 = ggml_time_ms();
|
|
LOG_DEBUG("computing vae decode graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
|
|
ggml_ext_tensor_clamp_inplace(result, 0.0f, 1.0f);
|
|
return result;
|
|
}
|
|
|
|
virtual ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) = 0;
|
|
virtual ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) = 0;
|
|
virtual ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) = 0;
|
|
virtual void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) = 0;
|
|
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
|
};
|
|
|
|
struct FakeVAE : public VAE {
|
|
FakeVAE(SDVersion version, ggml_backend_t backend, bool offload_params_to_cpu)
|
|
: VAE(version, backend, offload_params_to_cpu) {}
|
|
|
|
int get_encoder_output_channels(int input_channels) {
|
|
return input_channels;
|
|
}
|
|
|
|
bool _compute(const int n_threads,
|
|
ggml_tensor* z,
|
|
bool decode_graph,
|
|
ggml_tensor** output,
|
|
ggml_context* output_ctx) override {
|
|
if (*output == nullptr && output_ctx != nullptr) {
|
|
*output = ggml_dup_tensor(output_ctx, z);
|
|
}
|
|
ggml_ext_tensor_iter(z, [&](ggml_tensor* z, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
|
float value = ggml_ext_tensor_get_f32(z, i0, i1, i2, i3);
|
|
ggml_ext_tensor_set_f32(*output, value, i0, i1, i2, i3);
|
|
});
|
|
return true;
|
|
}
|
|
|
|
ggml_tensor* vae_output_to_latents(ggml_context* work_ctx, ggml_tensor* vae_output, std::shared_ptr<RNG> rng) {
|
|
return vae_output;
|
|
}
|
|
|
|
ggml_tensor* diffusion_to_vae_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
|
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
|
}
|
|
|
|
ggml_tensor* vae_to_diffuison_latents(ggml_context* work_ctx, ggml_tensor* latents) {
|
|
return ggml_ext_dup_and_cpy_tensor(work_ctx, latents);
|
|
}
|
|
|
|
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors, const std::string prefix) override {}
|
|
|
|
std::string get_desc() override {
|
|
return "fake_vae";
|
|
}
|
|
};
|
|
|
|
#endif // __VAE_HPP__
|