Compare commits

...

2 Commits

Author    SHA1          Message                                                            Date
leejet    c2d8ffc22c    fix: compatibility for models with modified tensor shapes (#951)  2025-11-07 23:04:41 +08:00
stduhpf   fb748bb8a4    fix: TAE encoding (#935)                                           2025-11-07 22:58:59 +08:00
4 changed files with 39 additions and 2 deletions


@@ -410,6 +410,22 @@ protected:
     int64_t context_dim = 768;  // hidden_size, 1024 for VERSION_SD2
     bool use_linear = false;

+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
+        auto iter = tensor_storage_map.find(prefix + "proj_out.weight");
+        if (iter != tensor_storage_map.end()) {
+            int64_t inner_dim = n_head * d_head;
+            if (iter->second.n_dims == 4 && use_linear) {
+                use_linear = false;
+                blocks["proj_in"] = std::make_shared<Conv2d>(in_channels, inner_dim, std::pair{1, 1});
+                blocks["proj_out"] = std::make_shared<Conv2d>(inner_dim, in_channels, std::pair{1, 1});
+            } else if (iter->second.n_dims == 2 && !use_linear) {
+                use_linear = true;
+                blocks["proj_in"] = std::make_shared<Linear>(in_channels, inner_dim);
+                blocks["proj_out"] = std::make_shared<Linear>(inner_dim, in_channels);
+            }
+        }
+    }
+
 public:
     SpatialTransformer(int64_t in_channels,
                        int64_t n_head,
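Note on the check above: a 1×1 Conv2d weight is stored as a 4-D tensor while a Linear weight is 2-D, so the rank of the checkpoint's proj_out.weight tells the loader which block type the model was actually exported with, regardless of what the model version would normally imply. A minimal standalone sketch of that dispatch (the enum and helper are hypothetical, not part of the patch):

    // Map the stored rank of a projection weight to a block type.
    enum class ProjKind { Conv2d1x1,   // weight rank 4: [out, in, 1, 1]
                          Linear };    // weight rank 2: [out, in]

    inline ProjKind proj_kind_from_ndims(int n_dims, ProjKind fallback) {
        if (n_dims == 4) return ProjKind::Conv2d1x1;
        if (n_dims == 2) return ProjKind::Linear;
        return fallback;  // unexpected rank: keep the version default
    }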


@@ -1926,8 +1926,8 @@ public:
         if (prefix.size() > 0) {
             prefix = prefix + ".";
         }
-        init_blocks(ctx, tensor_storage_map, prefix);
         init_params(ctx, tensor_storage_map, prefix);
+        init_blocks(ctx, tensor_storage_map, prefix);
     }

     size_t get_params_num() {
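The reorder matters because init_params can now replace entries in blocks (as in the hunk above), and init_blocks initializes whatever is currently in that map; running init_params first ensures the swapped-in blocks are the ones that get initialized. A toy illustration of the ordering constraint (simplified types, not the real GGMLBlock API):

    #include <map>
    #include <memory>
    #include <string>

    struct Block {
        std::map<std::string, std::shared_ptr<Block>> blocks;
        virtual ~Block() = default;
        virtual void init_params() {}           // may replace entries in `blocks`
        void init() {
            init_params();                      // first: settle the final block set
            for (auto& [name, child] : blocks)  // then: recurse into what survived
                child->init();
        }
    };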


@@ -1645,7 +1645,9 @@ public:
         } else {
             latent = gaussian_latent_sample(work_ctx, vae_output);
         }
-        process_latent_in(latent);
+        if (!use_tiny_autoencoder) {
+            process_latent_in(latent);
+        }
         if (sd_version_is_qwen_image(version)) {
             latent = ggml_reshape_4d(work_ctx, latent, latent->ne[0], latent->ne[1], latent->ne[3], 1);
         }
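The guard reflects that the tiny autoencoder (TAESD) encodes directly into the scaled latent space the sampler expects, while the full VAE emits unscaled latents that process_latent_in still has to transform; applying that transform to TAE output would distort the latents. A rough standalone sketch of the idea, assuming the SD1.x scale factor of 0.18215 (the real process_latent_in is version-dependent and may also apply a shift):

    #include <cstddef>

    constexpr float kScaleFactor = 0.18215f;  // SD1.x latent scale

    // Hypothetical standalone analogue of the scaling step.
    inline void scale_latent_in(float* latent, size_t n) {
        for (size_t i = 0; i < n; i++)
            latent[i] *= kScaleFactor;  // needed for KL-VAE output, not for TAESD
    }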

vae.hpp (+19)

@@ -66,6 +66,25 @@ protected:
     int64_t in_channels;
     bool use_linear;

+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
+        auto iter = tensor_storage_map.find(prefix + "proj_out.weight");
+        if (iter != tensor_storage_map.end()) {
+            if (iter->second.n_dims == 4 && use_linear) {
+                use_linear = false;
+                blocks["q"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
+                blocks["k"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
+                blocks["v"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
+                blocks["proj_out"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
+            } else if (iter->second.n_dims == 2 && !use_linear) {
+                use_linear = true;
+                blocks["q"] = std::make_shared<Linear>(in_channels, in_channels);
+                blocks["k"] = std::make_shared<Linear>(in_channels, in_channels);
+                blocks["v"] = std::make_shared<Linear>(in_channels, in_channels);
+                blocks["proj_out"] = std::make_shared<Linear>(in_channels, in_channels);
+            }
+        }
+    }
+
 public:
     AttnBlock(int64_t in_channels, bool use_linear)
         : in_channels(in_channels), use_linear(use_linear) {
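Swapping Conv2d for Linear here is lossless because a 1×1 convolution computes exactly a per-pixel Linear layer; only the stored weight layout differs. A toy check with made-up sizes:

    #include <cstdio>
    #include <vector>

    int main() {
        const int c_in = 2, c_out = 2, hw = 3;              // toy channel counts, 3 pixels
        std::vector<float> w = {0.5f, -1.0f, 2.0f, 0.25f};  // weight [c_out][c_in]
        std::vector<float> x = {1, 2, 3, 4, 5, 6};          // feature map [c_in][hw]
        // Both a 1x1 conv and a per-pixel Linear compute y[o][p] = sum_i w[o][i] * x[i][p].
        for (int o = 0; o < c_out; o++)
            for (int p = 0; p < hw; p++) {
                float acc = 0.0f;
                for (int i = 0; i < c_in; i++)
                    acc += w[o * c_in + i] * x[i * hw + p];
                std::printf("y[%d][%d] = %g\n", o, p, acc);
            }
        return 0;
    }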