feat: support Qwen-Image/Wan VAE with diffusers naming (#1713)

This commit is contained in:
stduhpf 2026-06-28 16:50:12 +02:00 committed by GitHub
parent 03e9a22f4d
commit d77b8f5ee8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -850,7 +850,78 @@ std::string convert_diffusers_vae_to_original_sd1(std::string name) {
return result;
}
std::string convert_first_stage_model_name(std::string name, std::string prefix) {
std::string convert_diffusers_to_original_wan_vae(std::string name) {
static const std::vector<std::pair<std::string, std::string>> prefix_map = {
{"quant_conv.", "conv1."},
{"post_quant_conv.", "conv2."},
{"decoder.up_blocks.0.resnets.0.", "decoder.upsamples.0.residual."},
{"decoder.up_blocks.0.resnets.1.", "decoder.upsamples.1.residual."},
{"decoder.up_blocks.0.resnets.2.", "decoder.upsamples.2.residual."},
{"decoder.up_blocks.0.upsamplers.0.", "decoder.upsamples.3."},
{"decoder.up_blocks.1.resnets.0.conv_shortcut.", "decoder.upsamples.4.shortcut."},
{"decoder.up_blocks.1.resnets.0.", "decoder.upsamples.4.residual."},
{"decoder.up_blocks.1.resnets.1.", "decoder.upsamples.5.residual."},
{"decoder.up_blocks.1.resnets.2.", "decoder.upsamples.6.residual."},
{"decoder.up_blocks.1.upsamplers.0.", "decoder.upsamples.7."},
{"decoder.up_blocks.2.resnets.0.", "decoder.upsamples.8.residual."},
{"decoder.up_blocks.2.resnets.1.", "decoder.upsamples.9.residual."},
{"decoder.up_blocks.2.resnets.2.", "decoder.upsamples.10.residual."},
{"decoder.up_blocks.2.upsamplers.0.", "decoder.upsamples.11."},
{"decoder.up_blocks.3.resnets.0.", "decoder.upsamples.12.residual."},
{"decoder.up_blocks.3.resnets.1.", "decoder.upsamples.13.residual."},
{"decoder.up_blocks.3.resnets.2.", "decoder.upsamples.14.residual."},
{"encoder.down_blocks.0.", "encoder.downsamples.0.residual."},
{"encoder.down_blocks.1.", "encoder.downsamples.1.residual."},
{"encoder.down_blocks.2.", "encoder.downsamples.2."},
{"encoder.down_blocks.3.conv_shortcut.", "encoder.downsamples.3.shortcut."},
{"encoder.down_blocks.3.", "encoder.downsamples.3.residual."},
{"encoder.down_blocks.4.", "encoder.downsamples.4.residual."},
{"encoder.down_blocks.5.", "encoder.downsamples.5."},
{"encoder.down_blocks.6.conv_shortcut.", "encoder.downsamples.6.shortcut."},
{"encoder.down_blocks.6.", "encoder.downsamples.6.residual."},
{"encoder.down_blocks.7.", "encoder.downsamples.7.residual."},
{"encoder.down_blocks.8.", "encoder.downsamples.8."},
{"encoder.down_blocks.9.", "encoder.downsamples.9.residual."},
{"encoder.down_blocks.10.", "encoder.downsamples.10.residual."},
};
static const std::vector<std::pair<std::string, std::string>> shared_name_map = {
{".conv_in.", ".conv1."},
{".norm_out.", ".head.0."},
{".conv_out.", ".head.2."},
{".mid_block.attentions.0.", ".middle.1."},
{".mid_block.resnets.0.", ".middle.0.residual."},
{".mid_block.resnets.1.", ".middle.2.residual."},
};
static const std::vector<std::pair<std::string, std::string>> resnet_name_map = {
{".norm1.", ".0."},
{".conv1.", ".2."},
{".norm2.", ".3."},
{".conv2.", ".6."},
};
replace_with_name_map(name, shared_name_map);
replace_with_prefix_map(name, prefix_map);
// Only apply the ResNet-specific renaming if the tensor belongs to a ResNet block.
// This prevents generic ".conv1." or ".conv2." matching on top-level encoder/decoder convolutions.
if (name.find(".residual.") != std::string::npos) {
replace_with_name_map(name, resnet_name_map);
}
return name;
}
std::string convert_first_stage_model_name(std::string name, std::string prefix, SDVersion version) {
if (sd_version_uses_wan_vae(version)) {
return convert_diffusers_to_original_wan_vae(name);
}
static std::unordered_map<std::string, std::string> vae_name_map = {
{"decoder.post_quant_conv.", "post_quant_conv."},
{"encoder.quant_conv.", "quant_conv."},
@ -1239,7 +1310,7 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
{
for (const auto& prefix : first_stage_model_prefix_vec) {
if (starts_with(name, prefix)) {
name = convert_first_stage_model_name(name.substr(prefix.size()), prefix);
name = convert_first_stage_model_name(name.substr(prefix.size()), prefix, version);
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
name = "tae." + name;
} else {