mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-29 09:36:40 +00:00
feat: support Qwen-Image/Wan VAE with diffusers naming (#1713)
This commit is contained in:
parent
03e9a22f4d
commit
d77b8f5ee8
@ -850,7 +850,78 @@ std::string convert_diffusers_vae_to_original_sd1(std::string name) {
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string convert_first_stage_model_name(std::string name, std::string prefix) {
|
||||
// Translates one VAE tensor name from the diffusers naming scheme into the
// original Wan VAE naming scheme and returns the rewritten name.
//
// Three rewrite tables are applied in a fixed order:
//   1. `shared_rules`       - substring renames for conv_in / norm_out /
//                             conv_out and the mid block.
//   2. `block_prefix_rules` - renames keyed on the leading part of the name
//                             (quant convs, decoder up blocks, encoder down
//                             blocks).
//   3. `resnet_rules`       - norm/conv renames, applied only when the name
//                             contains ".residual." after steps 1 and 2.
//
// NOTE(review): the exact matching semantics of replace_with_name_map and
// replace_with_prefix_map are defined elsewhere; this function only relies on
// them updating `name` in place from the given table.
std::string convert_diffusers_to_original_wan_vae(std::string name) {
    using RenameTable = std::vector<std::pair<std::string, std::string>>;

    // Entry order is preserved exactly: each "...conv_shortcut." entry is
    // listed ahead of the more general entry for the same block (presumably
    // so the more specific rule is seen first - confirm against the helper).
    static const RenameTable block_prefix_rules = {
        {"quant_conv.", "conv1."},
        {"post_quant_conv.", "conv2."},

        // Decoder: up_blocks.<i>.{resnets,upsamplers}.<j> -> flat upsamples.<k>
        {"decoder.up_blocks.0.resnets.0.", "decoder.upsamples.0.residual."},
        {"decoder.up_blocks.0.resnets.1.", "decoder.upsamples.1.residual."},
        {"decoder.up_blocks.0.resnets.2.", "decoder.upsamples.2.residual."},
        {"decoder.up_blocks.0.upsamplers.0.", "decoder.upsamples.3."},

        {"decoder.up_blocks.1.resnets.0.conv_shortcut.", "decoder.upsamples.4.shortcut."},
        {"decoder.up_blocks.1.resnets.0.", "decoder.upsamples.4.residual."},
        {"decoder.up_blocks.1.resnets.1.", "decoder.upsamples.5.residual."},
        {"decoder.up_blocks.1.resnets.2.", "decoder.upsamples.6.residual."},
        {"decoder.up_blocks.1.upsamplers.0.", "decoder.upsamples.7."},
        {"decoder.up_blocks.2.resnets.0.", "decoder.upsamples.8.residual."},
        {"decoder.up_blocks.2.resnets.1.", "decoder.upsamples.9.residual."},
        {"decoder.up_blocks.2.resnets.2.", "decoder.upsamples.10.residual."},
        {"decoder.up_blocks.2.upsamplers.0.", "decoder.upsamples.11."},
        {"decoder.up_blocks.3.resnets.0.", "decoder.upsamples.12.residual."},
        {"decoder.up_blocks.3.resnets.1.", "decoder.upsamples.13.residual."},
        {"decoder.up_blocks.3.resnets.2.", "decoder.upsamples.14.residual."},

        // Encoder: down_blocks.<i> -> downsamples.<i>; entries 2, 5 and 8 get
        // no ".residual." segment, all others do.
        {"encoder.down_blocks.0.", "encoder.downsamples.0.residual."},
        {"encoder.down_blocks.1.", "encoder.downsamples.1.residual."},
        {"encoder.down_blocks.2.", "encoder.downsamples.2."},
        {"encoder.down_blocks.3.conv_shortcut.", "encoder.downsamples.3.shortcut."},
        {"encoder.down_blocks.3.", "encoder.downsamples.3.residual."},
        {"encoder.down_blocks.4.", "encoder.downsamples.4.residual."},
        {"encoder.down_blocks.5.", "encoder.downsamples.5."},
        {"encoder.down_blocks.6.conv_shortcut.", "encoder.downsamples.6.shortcut."},
        {"encoder.down_blocks.6.", "encoder.downsamples.6.residual."},
        {"encoder.down_blocks.7.", "encoder.downsamples.7.residual."},
        {"encoder.down_blocks.8.", "encoder.downsamples.8."},
        {"encoder.down_blocks.9.", "encoder.downsamples.9.residual."},
        {"encoder.down_blocks.10.", "encoder.downsamples.10.residual."},
    };

    // Renames that use the same pattern regardless of which half of the VAE
    // the tensor belongs to (the patterns start with '.', not with
    // "encoder"/"decoder").
    static const RenameTable shared_rules = {
        {".conv_in.", ".conv1."},
        {".norm_out.", ".head.0."},
        {".conv_out.", ".head.2."},

        {".mid_block.attentions.0.", ".middle.1."},
        {".mid_block.resnets.0.", ".middle.0.residual."},
        {".mid_block.resnets.1.", ".middle.2.residual."},
    };

    // Layer names inside a residual block -> numeric slots in the original layout.
    static const RenameTable resnet_rules = {
        {".norm1.", ".0."},
        {".conv1.", ".2."},
        {".norm2.", ".3."},
        {".conv2.", ".6."},
    };

    replace_with_name_map(name, shared_rules);
    replace_with_prefix_map(name, block_prefix_rules);

    // The residual-block renames must not touch tensors outside a residual
    // block: ".conv1." / ".conv2." also occur in top-level encoder/decoder
    // convolution names (e.g. the ".conv_in." -> ".conv1." rewrite above).
    const bool inside_residual_block = name.find(".residual.") != std::string::npos;
    if (!inside_residual_block) {
        return name;
    }

    replace_with_name_map(name, resnet_rules);
    return name;
}
|
||||
|
||||
std::string convert_first_stage_model_name(std::string name, std::string prefix, SDVersion version) {
|
||||
if (sd_version_uses_wan_vae(version)) {
|
||||
return convert_diffusers_to_original_wan_vae(name);
|
||||
}
|
||||
static std::unordered_map<std::string, std::string> vae_name_map = {
|
||||
{"decoder.post_quant_conv.", "post_quant_conv."},
|
||||
{"encoder.quant_conv.", "quant_conv."},
|
||||
@ -1239,7 +1310,7 @@ std::string convert_tensor_name(std::string name, SDVersion version) {
|
||||
{
|
||||
for (const auto& prefix : first_stage_model_prefix_vec) {
|
||||
if (starts_with(name, prefix)) {
|
||||
name = convert_first_stage_model_name(name.substr(prefix.size()), prefix);
|
||||
name = convert_first_stage_model_name(name.substr(prefix.size()), prefix, version);
|
||||
if (version == VERSION_SDXS_512_DS || version == VERSION_SDXS_09) {
|
||||
name = "tae." + name;
|
||||
} else {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user