fix: zero Wan2.2 TI2V timesteps for fixed frames (#1604)

This commit is contained in:
leejet 2026-06-03 23:32:31 +08:00 committed by GitHub
parent a7f2e03da4
commit 1f9ee88e09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1680,12 +1680,15 @@ public:
const sd::Tensor<float>& init_latent,
const sd::Tensor<float>& denoise_mask) {
if (diffusion_model->get_desc() == "Wan2.2-TI2V-5B") {
auto new_timesteps = std::vector<float>(static_cast<size_t>(init_latent.shape()[2]), timesteps[0]);
int64_t frame_count = init_latent.shape()[2];
auto new_timesteps = std::vector<float>(static_cast<size_t>(frame_count), timesteps[0]);
if (!denoise_mask.empty()) {
float value = denoise_mask.dim() == 5 ? denoise_mask.index(0, 0, 0, 0, 0) : denoise_mask.index(0, 0, 0, 0);
if (!denoise_mask.empty() && denoise_mask.dim() >= 4 && denoise_mask.shape()[2] == frame_count) {
for (int64_t frame = 0; frame < frame_count; ++frame) {
float value = denoise_mask.dim() == 5 ? denoise_mask.index(0, 0, frame, 0, 0) : denoise_mask.index(0, 0, frame, 0);
if (value == 0.f) {
new_timesteps[0] = 0.f;
new_timesteps[static_cast<size_t>(frame)] = 0.f;
}
}
}
return new_timesteps;