mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
feat: add shift factor support (#903)
This commit is contained in:
parent
d05e46ca5e
commit
48e0a28ddf
@ -94,6 +94,7 @@ public:
|
||||
std::shared_ptr<RNG> rng = std::make_shared<STDDefaultRNG>();
|
||||
int n_threads = -1;
|
||||
float scale_factor = 0.18215f;
|
||||
float shift_factor = 0.f;
|
||||
|
||||
std::shared_ptr<Conditioner> cond_stage_model;
|
||||
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v
|
||||
@ -324,9 +325,10 @@ public:
|
||||
scale_factor = 0.13025f;
|
||||
} else if (sd_version_is_sd3(version)) {
|
||||
scale_factor = 1.5305f;
|
||||
shift_factor = 0.0609f;
|
||||
} else if (sd_version_is_flux(version)) {
|
||||
scale_factor = 0.3611f;
|
||||
// TODO: shift_factor
|
||||
shift_factor = 0.1159f;
|
||||
} else if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||
scale_factor = 1.0f;
|
||||
}
|
||||
@ -1404,7 +1406,11 @@ public:
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ggml_tensor_scale(latent, scale_factor);
|
||||
ggml_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = ggml_tensor_get_f32(latent, i0, i1, i2, i3);
|
||||
value = (value - shift_factor) * scale_factor;
|
||||
ggml_tensor_set_f32(latent, value, i0, i1, i2, i3);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -1444,7 +1450,11 @@ public:
|
||||
}
|
||||
}
|
||||
} else {
|
||||
ggml_tensor_scale(latent, 1.0f / scale_factor);
|
||||
ggml_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||
float value = ggml_tensor_get_f32(latent, i0, i1, i2, i3);
|
||||
value = (value / scale_factor) + shift_factor;
|
||||
ggml_tensor_set_f32(latent, value, i0, i1, i2, i3);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user