mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-02-04 10:53:34 +00:00
feat: support new chroma radiance "x0_x32_proto" (#1209)
This commit is contained in:
parent
e50e1f253d
commit
b87fe13afd
28
flux.hpp
28
flux.hpp
@ -748,7 +748,7 @@ namespace Flux {
|
|||||||
int nerf_depth = 4;
|
int nerf_depth = 4;
|
||||||
int nerf_max_freqs = 8;
|
int nerf_max_freqs = 8;
|
||||||
bool use_x0 = false;
|
bool use_x0 = false;
|
||||||
bool use_patch_size_32 = false;
|
bool fake_patch_size_x2 = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FluxParams {
|
struct FluxParams {
|
||||||
@ -786,8 +786,11 @@ namespace Flux {
|
|||||||
Flux(FluxParams params)
|
Flux(FluxParams params)
|
||||||
: params(params) {
|
: params(params) {
|
||||||
if (params.version == VERSION_CHROMA_RADIANCE) {
|
if (params.version == VERSION_CHROMA_RADIANCE) {
|
||||||
std::pair<int, int> kernel_size = {16, 16};
|
std::pair<int, int> kernel_size = {params.patch_size, params.patch_size};
|
||||||
std::pair<int, int> stride = kernel_size;
|
if (params.chroma_radiance_params.fake_patch_size_x2) {
|
||||||
|
kernel_size = {params.patch_size / 2, params.patch_size / 2};
|
||||||
|
}
|
||||||
|
std::pair<int, int> stride = kernel_size;
|
||||||
|
|
||||||
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
|
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
|
||||||
params.hidden_size,
|
params.hidden_size,
|
||||||
@ -1082,7 +1085,7 @@ namespace Flux {
|
|||||||
auto img = pad_to_patch_size(ctx, x);
|
auto img = pad_to_patch_size(ctx, x);
|
||||||
auto orig_img = img;
|
auto orig_img = img;
|
||||||
|
|
||||||
if (params.chroma_radiance_params.use_patch_size_32) {
|
if (params.chroma_radiance_params.fake_patch_size_x2) {
|
||||||
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
|
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
|
||||||
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
|
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
|
||||||
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
|
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
|
||||||
@ -1303,7 +1306,8 @@ namespace Flux {
|
|||||||
flux_params.ref_index_scale = 10.f;
|
flux_params.ref_index_scale = 10.f;
|
||||||
flux_params.use_mlp_silu_act = true;
|
flux_params.use_mlp_silu_act = true;
|
||||||
}
|
}
|
||||||
int64_t head_dim = 0;
|
int64_t head_dim = 0;
|
||||||
|
int64_t actual_radiance_patch_size = -1;
|
||||||
for (auto pair : tensor_storage_map) {
|
for (auto pair : tensor_storage_map) {
|
||||||
std::string tensor_name = pair.first;
|
std::string tensor_name = pair.first;
|
||||||
if (!starts_with(tensor_name, prefix))
|
if (!starts_with(tensor_name, prefix))
|
||||||
@ -1316,9 +1320,12 @@ namespace Flux {
|
|||||||
flux_params.chroma_radiance_params.use_x0 = true;
|
flux_params.chroma_radiance_params.use_x0 = true;
|
||||||
}
|
}
|
||||||
if (tensor_name.find("__32x32__") != std::string::npos) {
|
if (tensor_name.find("__32x32__") != std::string::npos) {
|
||||||
LOG_DEBUG("using patch size 32 prediction");
|
LOG_DEBUG("using patch size 32");
|
||||||
flux_params.chroma_radiance_params.use_patch_size_32 = true;
|
flux_params.patch_size = 32;
|
||||||
flux_params.patch_size = 32;
|
}
|
||||||
|
if (tensor_name.find("img_in_patch.weight") != std::string::npos) {
|
||||||
|
actual_radiance_patch_size = pair.second.ne[0];
|
||||||
|
LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size);
|
||||||
}
|
}
|
||||||
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
|
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
|
||||||
// Chroma
|
// Chroma
|
||||||
@ -1351,6 +1358,11 @@ namespace Flux {
|
|||||||
head_dim = pair.second.ne[0];
|
head_dim = pair.second.ne[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (actual_radiance_patch_size > 0 && actual_radiance_patch_size != flux_params.patch_size) {
|
||||||
|
GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size);
|
||||||
|
LOG_DEBUG("using fake x2 patch size");
|
||||||
|
flux_params.chroma_radiance_params.fake_patch_size_x2 = true;
|
||||||
|
}
|
||||||
|
|
||||||
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
|
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user