mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-12 13:28:37 +00:00
feat: add chroma radiance support (#910)
* add chroma radiance support * fix ci * simply generate_init_latent * workaround: avoid ggml cuda error * format code * add chroma radiance doc
This commit is contained in:
parent
062490aa7c
commit
9e28be6479
@ -35,10 +35,11 @@ API and command-line option may change frequently.***
|
|||||||
- Image Models
|
- Image Models
|
||||||
- SD1.x, SD2.x, [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo)
|
- SD1.x, SD2.x, [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo)
|
||||||
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
|
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
|
||||||
- [some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
|
- [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
|
||||||
- [SD3/SD3.5](./docs/sd3.md)
|
- [SD3/SD3.5](./docs/sd3.md)
|
||||||
- [Flux-dev/Flux-schnell](./docs/flux.md)
|
- [Flux-dev/Flux-schnell](./docs/flux.md)
|
||||||
- [Chroma](./docs/chroma.md)
|
- [Chroma](./docs/chroma.md)
|
||||||
|
- [Chroma1-Radiance](./docs/chroma_radiance.md)
|
||||||
- [Qwen Image](./docs/qwen_image.md)
|
- [Qwen Image](./docs/qwen_image.md)
|
||||||
- Image Edit Models
|
- Image Edit Models
|
||||||
- [FLUX.1-Kontext-dev](./docs/kontext.md)
|
- [FLUX.1-Kontext-dev](./docs/kontext.md)
|
||||||
|
|||||||
BIN
assets/flux/chroma1-radiance.png
Normal file
BIN
assets/flux/chroma1-radiance.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 477 KiB |
21
docs/chroma_radiance.md
Normal file
21
docs/chroma_radiance.md
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# How to Use
|
||||||
|
|
||||||
|
## Download weights
|
||||||
|
|
||||||
|
- Download Chroma1-Radiance
|
||||||
|
- safetensors: https://huggingface.co/lodestones/Chroma1-Radiance/tree/main
|
||||||
|
- gguf: https://huggingface.co/silveroxides/Chroma1-Radiance-GGUF/tree/main
|
||||||
|
|
||||||
|
- Download t5xxl
|
||||||
|
- safetensors: https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
```
|
||||||
|
.\bin\Release\sd.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\Chroma1-Radiance-v0.4-Q8_0.gguf --t5xxl ..\..\ComfyUI\models\clip\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'chroma radiance cpp'" --cfg-scale 4.0 --sampling-method euler -v
|
||||||
|
```
|
||||||
|
|
||||||
|
<img alt="Chroma1-Radiance" src="../assets/flux/chroma1-radiance.png" />
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
594
flux.hpp
594
flux.hpp
@ -399,7 +399,7 @@ namespace Flux {
|
|||||||
|
|
||||||
ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
|
ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) {
|
||||||
int64_t offset = 3 * idx;
|
int64_t offset = 3 * idx;
|
||||||
return {ctx, vec, offset};
|
return ModulationOut(ctx, vec, offset);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
@ -549,7 +549,135 @@ namespace Flux {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct NerfEmbedder : public GGMLBlock {
|
||||||
|
NerfEmbedder(int64_t in_channels,
|
||||||
|
int64_t hidden_size_input,
|
||||||
|
int64_t max_freqs) {
|
||||||
|
blocks["embedder.0"] = std::make_shared<Linear>(in_channels + max_freqs * max_freqs, hidden_size_input);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* dct) {
|
||||||
|
// x: (B, P^2, C)
|
||||||
|
// dct: (1, P^2, max_freqs^2)
|
||||||
|
// return: (B, P^2, hidden_size_input)
|
||||||
|
auto embedder = std::dynamic_pointer_cast<Linear>(blocks["embedder.0"]);
|
||||||
|
|
||||||
|
dct = ggml_repeat_4d(ctx, dct, dct->ne[0], dct->ne[1], x->ne[2], x->ne[3]);
|
||||||
|
x = ggml_concat(ctx, x, dct, 0);
|
||||||
|
x = embedder->forward(ctx, x);
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct NerfGLUBlock : public GGMLBlock {
|
||||||
|
int64_t mlp_ratio;
|
||||||
|
NerfGLUBlock(int64_t hidden_size_s,
|
||||||
|
int64_t hidden_size_x,
|
||||||
|
int64_t mlp_ratio)
|
||||||
|
: mlp_ratio(mlp_ratio) {
|
||||||
|
int64_t total_params = 3 * hidden_size_x * hidden_size_x * mlp_ratio;
|
||||||
|
blocks["param_generator"] = std::make_shared<Linear>(hidden_size_s, total_params);
|
||||||
|
blocks["norm"] = std::make_shared<RMSNorm>(hidden_size_x);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* s) {
|
||||||
|
// x: (batch_size, n_token, hidden_size_x)
|
||||||
|
// s: (batch_size, hidden_size_s)
|
||||||
|
// return: (batch_size, n_token, hidden_size_x)
|
||||||
|
auto param_generator = std::dynamic_pointer_cast<Linear>(blocks["param_generator"]);
|
||||||
|
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
|
||||||
|
|
||||||
|
int64_t batch_size = x->ne[2];
|
||||||
|
int64_t hidden_size_x = x->ne[0];
|
||||||
|
|
||||||
|
auto mlp_params = param_generator->forward(ctx, s);
|
||||||
|
auto fc_params = ggml_chunk(ctx, mlp_params, 3, 0);
|
||||||
|
auto fc1_gate = ggml_reshape_3d(ctx, fc_params[0], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
|
||||||
|
auto fc1_value = ggml_reshape_3d(ctx, fc_params[1], hidden_size_x * mlp_ratio, hidden_size_x, batch_size);
|
||||||
|
auto fc2 = ggml_reshape_3d(ctx, fc_params[2], hidden_size_x, mlp_ratio * hidden_size_x, batch_size);
|
||||||
|
|
||||||
|
fc1_gate = ggml_cont(ctx, ggml_torch_permute(ctx, fc1_gate, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
|
||||||
|
fc1_gate = ggml_l2_norm(ctx, fc1_gate, 1e-12f);
|
||||||
|
fc1_value = ggml_cont(ctx, ggml_torch_permute(ctx, fc1_value, 1, 0, 2, 3)); // [batch_size, hidden_size_x*mlp_ratio, hidden_size_x]
|
||||||
|
fc1_value = ggml_l2_norm(ctx, fc1_value, 1e-12f);
|
||||||
|
fc2 = ggml_cont(ctx, ggml_torch_permute(ctx, fc2, 1, 0, 2, 3)); // [batch_size, hidden_size_x, hidden_size_x*mlp_ratio]
|
||||||
|
fc2 = ggml_l2_norm(ctx, fc2, 1e-12f);
|
||||||
|
|
||||||
|
auto res_x = x;
|
||||||
|
x = norm->forward(ctx, x); // [batch_size, n_token, hidden_size_x]
|
||||||
|
|
||||||
|
auto x1 = ggml_mul_mat(ctx, fc1_gate, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
|
||||||
|
x1 = ggml_silu_inplace(ctx, x1);
|
||||||
|
|
||||||
|
auto x2 = ggml_mul_mat(ctx, fc1_value, x); // [batch_size, n_token, hidden_size_x*mlp_ratio]
|
||||||
|
|
||||||
|
x = ggml_mul_inplace(ctx, x1, x2); // [batch_size, n_token, hidden_size_x*mlp_ratio]
|
||||||
|
|
||||||
|
x = ggml_mul_mat(ctx, fc2, x); // [batch_size, n_token, hidden_size_x]
|
||||||
|
|
||||||
|
x = ggml_add_inplace(ctx, x, res_x);
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct NerfFinalLayer : public GGMLBlock {
|
||||||
|
NerfFinalLayer(int64_t hidden_size,
|
||||||
|
int64_t out_channels) {
|
||||||
|
blocks["norm"] = std::make_shared<RMSNorm>(hidden_size);
|
||||||
|
blocks["linear"] = std::make_shared<Linear>(hidden_size, out_channels);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x) {
|
||||||
|
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
|
||||||
|
auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
|
||||||
|
|
||||||
|
x = norm->forward(ctx, x);
|
||||||
|
x = linear->forward(ctx, x);
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct NerfFinalLayerConv : public GGMLBlock {
|
||||||
|
NerfFinalLayerConv(int64_t hidden_size,
|
||||||
|
int64_t out_channels) {
|
||||||
|
blocks["norm"] = std::make_shared<RMSNorm>(hidden_size);
|
||||||
|
blocks["conv"] = std::make_shared<Conv2d>(hidden_size, out_channels, std::pair{3, 3}, std::pair{1, 1}, std::pair{1, 1});
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x) {
|
||||||
|
// x: [N, C, H, W]
|
||||||
|
auto norm = std::dynamic_pointer_cast<RMSNorm>(blocks["norm"]);
|
||||||
|
auto conv = std::dynamic_pointer_cast<Conv2d>(blocks["conv"]);
|
||||||
|
|
||||||
|
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 2, 0, 1, 3)); // [N, H, W, C]
|
||||||
|
x = norm->forward(ctx, x);
|
||||||
|
x = ggml_cont(ctx, ggml_torch_permute(ctx, x, 1, 2, 0, 3)); // [N, C, H, W]
|
||||||
|
x = conv->forward(ctx, x);
|
||||||
|
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ChromaRadianceParams {
|
||||||
|
int64_t nerf_hidden_size = 64;
|
||||||
|
int64_t nerf_mlp_ratio = 4;
|
||||||
|
int64_t nerf_depth = 4;
|
||||||
|
int64_t nerf_max_freqs = 8;
|
||||||
|
};
|
||||||
|
|
||||||
struct FluxParams {
|
struct FluxParams {
|
||||||
|
SDVersion version = VERSION_FLUX;
|
||||||
|
bool is_chroma = false;
|
||||||
|
int64_t patch_size = 2;
|
||||||
int64_t in_channels = 64;
|
int64_t in_channels = 64;
|
||||||
int64_t out_channels = 64;
|
int64_t out_channels = 64;
|
||||||
int64_t vec_in_dim = 768;
|
int64_t vec_in_dim = 768;
|
||||||
@ -565,8 +693,8 @@ namespace Flux {
|
|||||||
bool qkv_bias = true;
|
bool qkv_bias = true;
|
||||||
bool guidance_embed = true;
|
bool guidance_embed = true;
|
||||||
bool flash_attn = true;
|
bool flash_attn = true;
|
||||||
bool is_chroma = false;
|
int64_t in_dim = 64;
|
||||||
SDVersion version = VERSION_FLUX;
|
ChromaRadianceParams chroma_radiance_params;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Flux : public GGMLBlock {
|
struct Flux : public GGMLBlock {
|
||||||
@ -575,53 +703,89 @@ namespace Flux {
|
|||||||
Flux() {}
|
Flux() {}
|
||||||
Flux(FluxParams params)
|
Flux(FluxParams params)
|
||||||
: params(params) {
|
: params(params) {
|
||||||
blocks["img_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.in_channels, params.hidden_size, true));
|
if (params.version == VERSION_CHROMA_RADIANCE) {
|
||||||
if (params.is_chroma) {
|
std::pair<int, int> kernel_size = {(int)params.patch_size, (int)params.patch_size};
|
||||||
blocks["distilled_guidance_layer"] = std::shared_ptr<GGMLBlock>(new ChromaApproximator(params.in_channels, params.hidden_size));
|
std::pair<int, int> stride = kernel_size;
|
||||||
|
|
||||||
|
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
|
||||||
|
params.hidden_size,
|
||||||
|
kernel_size,
|
||||||
|
stride);
|
||||||
} else {
|
} else {
|
||||||
blocks["time_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
|
blocks["img_in"] = std::make_shared<Linear>(params.in_channels, params.hidden_size, true);
|
||||||
blocks["vector_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(params.vec_in_dim, params.hidden_size));
|
}
|
||||||
|
if (params.is_chroma) {
|
||||||
|
blocks["distilled_guidance_layer"] = std::make_shared<ChromaApproximator>(params.in_dim, params.hidden_size);
|
||||||
|
} else {
|
||||||
|
blocks["time_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size);
|
||||||
|
blocks["vector_in"] = std::make_shared<MLPEmbedder>(params.vec_in_dim, params.hidden_size);
|
||||||
if (params.guidance_embed) {
|
if (params.guidance_embed) {
|
||||||
blocks["guidance_in"] = std::shared_ptr<GGMLBlock>(new MLPEmbedder(256, params.hidden_size));
|
blocks["guidance_in"] = std::make_shared<MLPEmbedder>(256, params.hidden_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
blocks["txt_in"] = std::shared_ptr<GGMLBlock>(new Linear(params.context_in_dim, params.hidden_size, true));
|
blocks["txt_in"] = std::make_shared<Linear>(params.context_in_dim, params.hidden_size, true);
|
||||||
|
|
||||||
for (int i = 0; i < params.depth; i++) {
|
for (int i = 0; i < params.depth; i++) {
|
||||||
blocks["double_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new DoubleStreamBlock(params.hidden_size,
|
blocks["double_blocks." + std::to_string(i)] = std::make_shared<DoubleStreamBlock>(params.hidden_size,
|
||||||
params.num_heads,
|
params.num_heads,
|
||||||
params.mlp_ratio,
|
params.mlp_ratio,
|
||||||
i,
|
i,
|
||||||
params.qkv_bias,
|
params.qkv_bias,
|
||||||
params.flash_attn,
|
params.flash_attn,
|
||||||
params.is_chroma));
|
params.is_chroma);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < params.depth_single_blocks; i++) {
|
for (int i = 0; i < params.depth_single_blocks; i++) {
|
||||||
blocks["single_blocks." + std::to_string(i)] = std::shared_ptr<GGMLBlock>(new SingleStreamBlock(params.hidden_size,
|
blocks["single_blocks." + std::to_string(i)] = std::make_shared<SingleStreamBlock>(params.hidden_size,
|
||||||
params.num_heads,
|
params.num_heads,
|
||||||
params.mlp_ratio,
|
params.mlp_ratio,
|
||||||
i,
|
i,
|
||||||
0.f,
|
0.f,
|
||||||
params.flash_attn,
|
params.flash_attn,
|
||||||
params.is_chroma));
|
params.is_chroma);
|
||||||
}
|
}
|
||||||
|
|
||||||
blocks["final_layer"] = std::shared_ptr<GGMLBlock>(new LastLayer(params.hidden_size, 1, params.out_channels, params.is_chroma));
|
if (params.version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
blocks["nerf_image_embedder"] = std::make_shared<NerfEmbedder>(params.in_channels,
|
||||||
|
params.chroma_radiance_params.nerf_hidden_size,
|
||||||
|
params.chroma_radiance_params.nerf_max_freqs);
|
||||||
|
|
||||||
|
for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) {
|
||||||
|
blocks["nerf_blocks." + std::to_string(i)] = std::make_shared<NerfGLUBlock>(params.hidden_size,
|
||||||
|
params.chroma_radiance_params.nerf_hidden_size,
|
||||||
|
params.chroma_radiance_params.nerf_mlp_ratio);
|
||||||
|
}
|
||||||
|
|
||||||
|
blocks["nerf_final_layer_conv"] = std::make_shared<NerfFinalLayerConv>(params.chroma_radiance_params.nerf_hidden_size,
|
||||||
|
params.in_channels);
|
||||||
|
|
||||||
|
} else {
|
||||||
|
blocks["final_layer"] = std::make_shared<LastLayer>(params.hidden_size, 1, params.out_channels, params.is_chroma);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* pad_to_patch_size(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x) {
|
||||||
|
int64_t W = x->ne[0];
|
||||||
|
int64_t H = x->ne[1];
|
||||||
|
|
||||||
|
int pad_h = (params.patch_size - H % params.patch_size) % params.patch_size;
|
||||||
|
int pad_w = (params.patch_size - W % params.patch_size) % params.patch_size;
|
||||||
|
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
|
||||||
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* patchify(struct ggml_context* ctx,
|
struct ggml_tensor* patchify(struct ggml_context* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x) {
|
||||||
int64_t patch_size) {
|
|
||||||
// x: [N, C, H, W]
|
// x: [N, C, H, W]
|
||||||
// return: [N, h*w, C * patch_size * patch_size]
|
// return: [N, h*w, C * patch_size * patch_size]
|
||||||
int64_t N = x->ne[3];
|
int64_t N = x->ne[3];
|
||||||
int64_t C = x->ne[2];
|
int64_t C = x->ne[2];
|
||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
int64_t W = x->ne[0];
|
int64_t W = x->ne[0];
|
||||||
int64_t p = patch_size;
|
int64_t p = params.patch_size;
|
||||||
int64_t h = H / patch_size;
|
int64_t h = H / params.patch_size;
|
||||||
int64_t w = W / patch_size;
|
int64_t w = W / params.patch_size;
|
||||||
|
|
||||||
GGML_ASSERT(h * p == H && w * p == W);
|
GGML_ASSERT(h * p == H && w * p == W);
|
||||||
|
|
||||||
@ -633,18 +797,25 @@ namespace Flux {
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* process_img(struct ggml_context* ctx,
|
||||||
|
struct ggml_tensor* x) {
|
||||||
|
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
||||||
|
x = pad_to_patch_size(ctx, x);
|
||||||
|
x = patchify(ctx, x);
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
|
struct ggml_tensor* unpatchify(struct ggml_context* ctx,
|
||||||
struct ggml_tensor* x,
|
struct ggml_tensor* x,
|
||||||
int64_t h,
|
int64_t h,
|
||||||
int64_t w,
|
int64_t w) {
|
||||||
int64_t patch_size) {
|
|
||||||
// x: [N, h*w, C*patch_size*patch_size]
|
// x: [N, h*w, C*patch_size*patch_size]
|
||||||
// return: [N, C, H, W]
|
// return: [N, C, H, W]
|
||||||
int64_t N = x->ne[2];
|
int64_t N = x->ne[2];
|
||||||
int64_t C = x->ne[0] / patch_size / patch_size;
|
int64_t C = x->ne[0] / params.patch_size / params.patch_size;
|
||||||
int64_t H = h * patch_size;
|
int64_t H = h * params.patch_size;
|
||||||
int64_t W = w * patch_size;
|
int64_t W = w * params.patch_size;
|
||||||
int64_t p = patch_size;
|
int64_t p = params.patch_size;
|
||||||
|
|
||||||
GGML_ASSERT(C * p * p == x->ne[0]);
|
GGML_ASSERT(C * p * p == x->ne[0]);
|
||||||
|
|
||||||
@ -671,7 +842,10 @@ namespace Flux {
|
|||||||
auto txt_in = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
|
auto txt_in = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
|
||||||
auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]);
|
auto final_layer = std::dynamic_pointer_cast<LastLayer>(blocks["final_layer"]);
|
||||||
|
|
||||||
img = img_in->forward(ctx, img);
|
if (img_in) {
|
||||||
|
img = img_in->forward(ctx, img);
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_tensor* vec;
|
struct ggml_tensor* vec;
|
||||||
struct ggml_tensor* txt_img_mask = nullptr;
|
struct ggml_tensor* txt_img_mask = nullptr;
|
||||||
if (params.is_chroma) {
|
if (params.is_chroma) {
|
||||||
@ -682,7 +856,7 @@ namespace Flux {
|
|||||||
|
|
||||||
// auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1);
|
// auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1);
|
||||||
// ggml_arange tot working on a lot of backends, precomputing it on CPU instead
|
// ggml_arange tot working on a lot of backends, precomputing it on CPU instead
|
||||||
GGML_ASSERT(arange != nullptr);
|
GGML_ASSERT(mod_index_arange != nullptr);
|
||||||
auto modulation_index = ggml_nn_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32]
|
auto modulation_index = ggml_nn_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32]
|
||||||
|
|
||||||
// Batch broadcast (will it ever be useful)
|
// Batch broadcast (will it ever be useful)
|
||||||
@ -749,52 +923,96 @@ namespace Flux {
|
|||||||
txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
txt_img->nb[2] * txt->ne[1]); // [n_img_token, N, hidden_size]
|
||||||
img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size]
|
||||||
|
|
||||||
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
|
if (final_layer) {
|
||||||
|
img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels)
|
||||||
|
}
|
||||||
|
|
||||||
return img;
|
return img;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_tensor* process_img(struct ggml_context* ctx,
|
struct ggml_tensor* forward_chroma_radiance(struct ggml_context* ctx,
|
||||||
struct ggml_tensor* x) {
|
ggml_backend_t backend,
|
||||||
int64_t W = x->ne[0];
|
struct ggml_tensor* x,
|
||||||
int64_t H = x->ne[1];
|
struct ggml_tensor* timestep,
|
||||||
int64_t patch_size = 2;
|
struct ggml_tensor* context,
|
||||||
int pad_h = (patch_size - H % patch_size) % patch_size;
|
struct ggml_tensor* c_concat,
|
||||||
int pad_w = (patch_size - W % patch_size) % patch_size;
|
struct ggml_tensor* y,
|
||||||
x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w]
|
struct ggml_tensor* guidance,
|
||||||
|
struct ggml_tensor* pe,
|
||||||
// img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
|
struct ggml_tensor* mod_index_arange = nullptr,
|
||||||
auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size]
|
struct ggml_tensor* dct = nullptr,
|
||||||
return img;
|
std::vector<ggml_tensor*> ref_latents = {},
|
||||||
}
|
std::vector<int> skip_layers = {}) {
|
||||||
|
|
||||||
struct ggml_tensor* forward(struct ggml_context* ctx,
|
|
||||||
ggml_backend_t backend,
|
|
||||||
struct ggml_tensor* x,
|
|
||||||
struct ggml_tensor* timestep,
|
|
||||||
struct ggml_tensor* context,
|
|
||||||
struct ggml_tensor* c_concat,
|
|
||||||
struct ggml_tensor* y,
|
|
||||||
struct ggml_tensor* guidance,
|
|
||||||
struct ggml_tensor* pe,
|
|
||||||
struct ggml_tensor* mod_index_arange = nullptr,
|
|
||||||
std::vector<ggml_tensor*> ref_latents = {},
|
|
||||||
std::vector<int> skip_layers = {}) {
|
|
||||||
// Forward pass of DiT.
|
|
||||||
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
|
||||||
// timestep: (N,) tensor of diffusion timesteps
|
|
||||||
// context: (N, L, D)
|
|
||||||
// c_concat: nullptr, or for (N,C+M, H, W) for Fill
|
|
||||||
// y: (N, adm_in_channels) tensor of class labels
|
|
||||||
// guidance: (N,)
|
|
||||||
// pe: (L, d_head/2, 2, 2)
|
|
||||||
// return: (N, C, H, W)
|
|
||||||
|
|
||||||
GGML_ASSERT(x->ne[3] == 1);
|
GGML_ASSERT(x->ne[3] == 1);
|
||||||
|
|
||||||
int64_t W = x->ne[0];
|
int64_t W = x->ne[0];
|
||||||
int64_t H = x->ne[1];
|
int64_t H = x->ne[1];
|
||||||
int64_t C = x->ne[2];
|
int64_t C = x->ne[2];
|
||||||
int64_t patch_size = 2;
|
int64_t patch_size = params.patch_size;
|
||||||
|
int pad_h = (patch_size - H % patch_size) % patch_size;
|
||||||
|
int pad_w = (patch_size - W % patch_size) % patch_size;
|
||||||
|
|
||||||
|
auto img = pad_to_patch_size(ctx, x);
|
||||||
|
auto orig_img = img;
|
||||||
|
|
||||||
|
auto img_in_patch = std::dynamic_pointer_cast<Conv2d>(blocks["img_in_patch"]);
|
||||||
|
|
||||||
|
img = img_in_patch->forward(ctx, img); // [N, hidden_size, H/patch_size, W/patch_size]
|
||||||
|
img = ggml_reshape_3d(ctx, img, img->ne[0] * img->ne[1], img->ne[2], img->ne[3]); // [N, hidden_size, H/patch_size*W/patch_size]
|
||||||
|
img = ggml_cont(ctx, ggml_torch_permute(ctx, img, 1, 0, 2, 3)); // [N, H/patch_size*W/patch_size, hidden_size]
|
||||||
|
|
||||||
|
auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, n_img_token, hidden_size]
|
||||||
|
|
||||||
|
// nerf decode
|
||||||
|
auto nerf_image_embedder = std::dynamic_pointer_cast<NerfEmbedder>(blocks["nerf_image_embedder"]);
|
||||||
|
auto nerf_final_layer_conv = std::dynamic_pointer_cast<NerfFinalLayerConv>(blocks["nerf_final_layer_conv"]);
|
||||||
|
|
||||||
|
auto nerf_pixels = patchify(ctx, orig_img); // [N, num_patches, C * patch_size * patch_size]
|
||||||
|
int64_t num_patches = nerf_pixels->ne[1];
|
||||||
|
nerf_pixels = ggml_reshape_3d(ctx,
|
||||||
|
nerf_pixels,
|
||||||
|
nerf_pixels->ne[0] / C,
|
||||||
|
C,
|
||||||
|
nerf_pixels->ne[1] * nerf_pixels->ne[2]); // [N*num_patches, C, patch_size*patch_size]
|
||||||
|
nerf_pixels = ggml_cont(ctx, ggml_torch_permute(ctx, nerf_pixels, 1, 0, 2, 3)); // [N*num_patches, patch_size*patch_size, C]
|
||||||
|
|
||||||
|
auto nerf_hidden = ggml_reshape_2d(ctx, out, out->ne[0], out->ne[1] * out->ne[2]); // [N*num_patches, hidden_size]
|
||||||
|
auto img_dct = nerf_image_embedder->forward(ctx, nerf_pixels, dct); // [N*num_patches, patch_size*patch_size, nerf_hidden_size]
|
||||||
|
|
||||||
|
for (int i = 0; i < params.chroma_radiance_params.nerf_depth; i++) {
|
||||||
|
auto block = std::dynamic_pointer_cast<NerfGLUBlock>(blocks["nerf_blocks." + std::to_string(i)]);
|
||||||
|
|
||||||
|
img_dct = block->forward(ctx, img_dct, nerf_hidden);
|
||||||
|
}
|
||||||
|
|
||||||
|
img_dct = ggml_cont(ctx, ggml_torch_permute(ctx, img_dct, 1, 0, 2, 3)); // [N*num_patches, nerf_hidden_size, patch_size*patch_size]
|
||||||
|
img_dct = ggml_reshape_3d(ctx, img_dct, img_dct->ne[0] * img_dct->ne[1], num_patches, img_dct->ne[2] / num_patches); // [N, num_patches, nerf_hidden_size*patch_size*patch_size]
|
||||||
|
img_dct = unpatchify(ctx, img_dct, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, nerf_hidden_size, H, W]
|
||||||
|
|
||||||
|
out = nerf_final_layer_conv->forward(ctx, img_dct); // [N, C, H, W]
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward_flux_chroma(struct ggml_context* ctx,
|
||||||
|
ggml_backend_t backend,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* timestep,
|
||||||
|
struct ggml_tensor* context,
|
||||||
|
struct ggml_tensor* c_concat,
|
||||||
|
struct ggml_tensor* y,
|
||||||
|
struct ggml_tensor* guidance,
|
||||||
|
struct ggml_tensor* pe,
|
||||||
|
struct ggml_tensor* mod_index_arange = nullptr,
|
||||||
|
struct ggml_tensor* dct = nullptr,
|
||||||
|
std::vector<ggml_tensor*> ref_latents = {},
|
||||||
|
std::vector<int> skip_layers = {}) {
|
||||||
|
GGML_ASSERT(x->ne[3] == 1);
|
||||||
|
|
||||||
|
int64_t W = x->ne[0];
|
||||||
|
int64_t H = x->ne[1];
|
||||||
|
int64_t C = x->ne[2];
|
||||||
|
int64_t patch_size = params.patch_size;
|
||||||
int pad_h = (patch_size - H % patch_size) % patch_size;
|
int pad_h = (patch_size - H % patch_size) % patch_size;
|
||||||
int pad_w = (patch_size - W % patch_size) % patch_size;
|
int pad_w = (patch_size - W % patch_size) % patch_size;
|
||||||
|
|
||||||
@ -816,21 +1034,16 @@ namespace Flux {
|
|||||||
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
|
||||||
ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
|
ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
|
||||||
|
|
||||||
masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
|
masked = process_img(ctx, masked);
|
||||||
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
|
mask = process_img(ctx, mask);
|
||||||
control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
|
control = process_img(ctx, control);
|
||||||
|
|
||||||
masked = patchify(ctx, masked, patch_size);
|
|
||||||
mask = patchify(ctx, mask, patch_size);
|
|
||||||
control = patchify(ctx, control, patch_size);
|
|
||||||
|
|
||||||
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
|
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
|
||||||
} else if (params.version == VERSION_FLUX_CONTROLS) {
|
} else if (params.version == VERSION_FLUX_CONTROLS) {
|
||||||
GGML_ASSERT(c_concat != nullptr);
|
GGML_ASSERT(c_concat != nullptr);
|
||||||
|
|
||||||
ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
|
auto control = process_img(ctx, c_concat);
|
||||||
control = patchify(ctx, control, patch_size);
|
img = ggml_concat(ctx, img, control, 0);
|
||||||
img = ggml_concat(ctx, img, control, 0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ref_latents.size() > 0) {
|
if (ref_latents.size() > 0) {
|
||||||
@ -849,10 +1062,63 @@ namespace Flux {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
|
// rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)
|
||||||
out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w]
|
out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size); // [N, C, H + pad_h, W + pad_w]
|
||||||
|
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor* forward(struct ggml_context* ctx,
|
||||||
|
ggml_backend_t backend,
|
||||||
|
struct ggml_tensor* x,
|
||||||
|
struct ggml_tensor* timestep,
|
||||||
|
struct ggml_tensor* context,
|
||||||
|
struct ggml_tensor* c_concat,
|
||||||
|
struct ggml_tensor* y,
|
||||||
|
struct ggml_tensor* guidance,
|
||||||
|
struct ggml_tensor* pe,
|
||||||
|
struct ggml_tensor* mod_index_arange = nullptr,
|
||||||
|
struct ggml_tensor* dct = nullptr,
|
||||||
|
std::vector<ggml_tensor*> ref_latents = {},
|
||||||
|
std::vector<int> skip_layers = {}) {
|
||||||
|
// Forward pass of DiT.
|
||||||
|
// x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||||
|
// timestep: (N,) tensor of diffusion timesteps
|
||||||
|
// context: (N, L, D)
|
||||||
|
// c_concat: nullptr, or for (N,C+M, H, W) for Fill
|
||||||
|
// y: (N, adm_in_channels) tensor of class labels
|
||||||
|
// guidance: (N,)
|
||||||
|
// pe: (L, d_head/2, 2, 2)
|
||||||
|
// return: (N, C, H, W)
|
||||||
|
|
||||||
|
if (params.version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
return forward_chroma_radiance(ctx,
|
||||||
|
backend,
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
context,
|
||||||
|
c_concat,
|
||||||
|
y,
|
||||||
|
guidance,
|
||||||
|
pe,
|
||||||
|
mod_index_arange,
|
||||||
|
dct,
|
||||||
|
ref_latents,
|
||||||
|
skip_layers);
|
||||||
|
} else {
|
||||||
|
return forward_flux_chroma(ctx,
|
||||||
|
backend,
|
||||||
|
x,
|
||||||
|
timestep,
|
||||||
|
context,
|
||||||
|
c_concat,
|
||||||
|
y,
|
||||||
|
guidance,
|
||||||
|
pe,
|
||||||
|
mod_index_arange,
|
||||||
|
dct,
|
||||||
|
ref_latents,
|
||||||
|
skip_layers);
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FluxRunner : public GGMLRunner {
|
struct FluxRunner : public GGMLRunner {
|
||||||
@ -860,7 +1126,8 @@ namespace Flux {
|
|||||||
FluxParams flux_params;
|
FluxParams flux_params;
|
||||||
Flux flux;
|
Flux flux;
|
||||||
std::vector<float> pe_vec;
|
std::vector<float> pe_vec;
|
||||||
std::vector<float> mod_index_arange_vec; // for cache
|
std::vector<float> mod_index_arange_vec;
|
||||||
|
std::vector<float> dct_vec;
|
||||||
SDVersion version;
|
SDVersion version;
|
||||||
bool use_mask = false;
|
bool use_mask = false;
|
||||||
|
|
||||||
@ -883,6 +1150,9 @@ namespace Flux {
|
|||||||
flux_params.in_channels = 128;
|
flux_params.in_channels = 128;
|
||||||
} else if (version == VERSION_FLEX_2) {
|
} else if (version == VERSION_FLEX_2) {
|
||||||
flux_params.in_channels = 196;
|
flux_params.in_channels = 196;
|
||||||
|
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
flux_params.in_channels = 3;
|
||||||
|
flux_params.patch_size = 16;
|
||||||
}
|
}
|
||||||
for (auto pair : tensor_types) {
|
for (auto pair : tensor_types) {
|
||||||
std::string tensor_name = pair.first;
|
std::string tensor_name = pair.first;
|
||||||
@ -933,6 +1203,56 @@ namespace Flux {
|
|||||||
flux.get_param_tensors(tensors, prefix);
|
flux.get_param_tensors(tensors, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<float> fetch_dct_pos(int patch_size, int max_freqs) {
|
||||||
|
const float PI = 3.14159265358979323846f;
|
||||||
|
|
||||||
|
std::vector<float> pos(patch_size);
|
||||||
|
for (int i = 0; i < patch_size; ++i) {
|
||||||
|
pos[i] = static_cast<float>(i) / static_cast<float>(patch_size - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> pos_x(patch_size * patch_size);
|
||||||
|
std::vector<float> pos_y(patch_size * patch_size);
|
||||||
|
for (int i = 0; i < patch_size; ++i) {
|
||||||
|
for (int j = 0; j < patch_size; ++j) {
|
||||||
|
pos_x[i * patch_size + j] = pos[j];
|
||||||
|
pos_y[i * patch_size + j] = pos[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> freqs(max_freqs);
|
||||||
|
for (int i = 0; i < max_freqs; ++i) {
|
||||||
|
freqs[i] = static_cast<float>(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> coeffs(max_freqs * max_freqs);
|
||||||
|
for (int fx = 0; fx < max_freqs; ++fx) {
|
||||||
|
for (int fy = 0; fy < max_freqs; ++fy) {
|
||||||
|
coeffs[fx * max_freqs + fy] = 1.0f / (1.0f + freqs[fx] * freqs[fy]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int num_positions = patch_size * patch_size;
|
||||||
|
int num_features = max_freqs * max_freqs;
|
||||||
|
std::vector<float> dct(num_positions * num_features);
|
||||||
|
|
||||||
|
for (int p = 0; p < num_positions; ++p) {
|
||||||
|
float px = pos_x[p];
|
||||||
|
float py = pos_y[p];
|
||||||
|
|
||||||
|
for (int fx = 0; fx < max_freqs; ++fx) {
|
||||||
|
float cx = std::cos(px * freqs[fx] * PI);
|
||||||
|
for (int fy = 0; fy < max_freqs; ++fy) {
|
||||||
|
float cy = std::cos(py * freqs[fy] * PI);
|
||||||
|
float val = cx * cy * coeffs[fx * max_freqs + fy];
|
||||||
|
dct[p * num_features + (fx * max_freqs + fy)] = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return dct;
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
|
struct ggml_cgraph* build_graph(struct ggml_tensor* x,
|
||||||
struct ggml_tensor* timesteps,
|
struct ggml_tensor* timesteps,
|
||||||
struct ggml_tensor* context,
|
struct ggml_tensor* context,
|
||||||
@ -946,6 +1266,7 @@ namespace Flux {
|
|||||||
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
|
struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
|
||||||
|
|
||||||
struct ggml_tensor* mod_index_arange = nullptr;
|
struct ggml_tensor* mod_index_arange = nullptr;
|
||||||
|
struct ggml_tensor* dct = nullptr; // for chroma radiance
|
||||||
|
|
||||||
x = to_backend(x);
|
x = to_backend(x);
|
||||||
context = to_backend(context);
|
context = to_backend(context);
|
||||||
@ -976,7 +1297,7 @@ namespace Flux {
|
|||||||
|
|
||||||
pe_vec = Rope::gen_flux_pe(x->ne[1],
|
pe_vec = Rope::gen_flux_pe(x->ne[1],
|
||||||
x->ne[0],
|
x->ne[0],
|
||||||
2,
|
flux_params.patch_size,
|
||||||
x->ne[3],
|
x->ne[3],
|
||||||
context->ne[1],
|
context->ne[1],
|
||||||
ref_latents,
|
ref_latents,
|
||||||
@ -991,6 +1312,17 @@ namespace Flux {
|
|||||||
// pe->data = nullptr;
|
// pe->data = nullptr;
|
||||||
set_backend_tensor_data(pe, pe_vec.data());
|
set_backend_tensor_data(pe, pe_vec.data());
|
||||||
|
|
||||||
|
if (version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
int64_t patch_size = flux_params.patch_size;
|
||||||
|
int64_t nerf_max_freqs = flux_params.chroma_radiance_params.nerf_max_freqs;
|
||||||
|
dct_vec = fetch_dct_pos(patch_size, nerf_max_freqs);
|
||||||
|
dct = ggml_new_tensor_2d(compute_ctx, GGML_TYPE_F32, nerf_max_freqs * nerf_max_freqs, patch_size * patch_size);
|
||||||
|
// dct->data = dct_vec.data();
|
||||||
|
// print_ggml_tensor(dct);
|
||||||
|
// dct->data = nullptr;
|
||||||
|
set_backend_tensor_data(dct, dct_vec.data());
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_tensor* out = flux.forward(compute_ctx,
|
struct ggml_tensor* out = flux.forward(compute_ctx,
|
||||||
runtime_backend,
|
runtime_backend,
|
||||||
x,
|
x,
|
||||||
@ -1001,6 +1333,7 @@ namespace Flux {
|
|||||||
guidance,
|
guidance,
|
||||||
pe,
|
pe,
|
||||||
mod_index_arange,
|
mod_index_arange,
|
||||||
|
dct,
|
||||||
ref_latents,
|
ref_latents,
|
||||||
skip_layers);
|
skip_layers);
|
||||||
|
|
||||||
@ -1035,7 +1368,7 @@ namespace Flux {
|
|||||||
|
|
||||||
void test() {
|
void test() {
|
||||||
struct ggml_init_params params;
|
struct ggml_init_params params;
|
||||||
params.mem_size = static_cast<size_t>(20 * 1024 * 1024); // 20 MB
|
params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1GB
|
||||||
params.mem_buffer = nullptr;
|
params.mem_buffer = nullptr;
|
||||||
params.no_alloc = false;
|
params.no_alloc = false;
|
||||||
|
|
||||||
@ -1046,22 +1379,25 @@ namespace Flux {
|
|||||||
// cpu f16:
|
// cpu f16:
|
||||||
// cuda f16: nan
|
// cuda f16: nan
|
||||||
// cuda q8_0: pass
|
// cuda q8_0: pass
|
||||||
auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1);
|
// auto x = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 16, 16, 16, 1);
|
||||||
ggml_set_f32(x, 0.01f);
|
// ggml_set_f32(x, 0.01f);
|
||||||
|
auto x = load_tensor_from_file(work_ctx, "chroma_x.bin");
|
||||||
// print_ggml_tensor(x);
|
// print_ggml_tensor(x);
|
||||||
|
|
||||||
std::vector<float> timesteps_vec(1, 999.f);
|
std::vector<float> timesteps_vec(1, 1.f);
|
||||||
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
|
auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec);
|
||||||
|
|
||||||
std::vector<float> guidance_vec(1, 3.5f);
|
std::vector<float> guidance_vec(1, 0.f);
|
||||||
auto guidance = vector_to_ggml_tensor(work_ctx, guidance_vec);
|
auto guidance = vector_to_ggml_tensor(work_ctx, guidance_vec);
|
||||||
|
|
||||||
auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 256, 1);
|
// auto context = ggml_new_tensor_3d(work_ctx, GGML_TYPE_F32, 4096, 256, 1);
|
||||||
ggml_set_f32(context, 0.01f);
|
// ggml_set_f32(context, 0.01f);
|
||||||
|
auto context = load_tensor_from_file(work_ctx, "chroma_context.bin");
|
||||||
// print_ggml_tensor(context);
|
// print_ggml_tensor(context);
|
||||||
|
|
||||||
auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1);
|
// auto y = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 768, 1);
|
||||||
ggml_set_f32(y, 0.01f);
|
// ggml_set_f32(y, 0.01f);
|
||||||
|
auto y = nullptr;
|
||||||
// print_ggml_tensor(y);
|
// print_ggml_tensor(y);
|
||||||
|
|
||||||
struct ggml_tensor* out = nullptr;
|
struct ggml_tensor* out = nullptr;
|
||||||
@ -1076,32 +1412,44 @@ namespace Flux {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void load_from_file_and_test(const std::string& file_path) {
|
static void load_from_file_and_test(const std::string& file_path) {
|
||||||
// ggml_backend_t backend = ggml_backend_cuda_init(0);
|
// ggml_backend_t backend = ggml_backend_cuda_init(0);
|
||||||
ggml_backend_t backend = ggml_backend_cpu_init();
|
ggml_backend_t backend = ggml_backend_cpu_init();
|
||||||
ggml_type model_data_type = GGML_TYPE_Q8_0;
|
ggml_type model_data_type = GGML_TYPE_Q8_0;
|
||||||
std::shared_ptr<FluxRunner> flux = std::make_shared<FluxRunner>(backend, false);
|
|
||||||
{
|
|
||||||
LOG_INFO("loading from '%s'", file_path.c_str());
|
|
||||||
|
|
||||||
flux->alloc_params_buffer();
|
ModelLoader model_loader;
|
||||||
std::map<std::string, ggml_tensor*> tensors;
|
if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) {
|
||||||
flux->get_param_tensors(tensors, "model.diffusion_model");
|
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
|
||||||
|
return;
|
||||||
ModelLoader model_loader;
|
|
||||||
if (!model_loader.init_from_file(file_path, "model.diffusion_model.")) {
|
|
||||||
LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool success = model_loader.load_tensors(tensors);
|
|
||||||
|
|
||||||
if (!success) {
|
|
||||||
LOG_ERROR("load tensors from model loader failed");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG_INFO("flux model loaded");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto tensor_types = model_loader.tensor_storages_types;
|
||||||
|
for (auto& item : tensor_types) {
|
||||||
|
// LOG_DEBUG("%s %u", item.first.c_str(), item.second);
|
||||||
|
if (ends_with(item.first, "weight")) {
|
||||||
|
// item.second = model_data_type;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<FluxRunner> flux = std::make_shared<FluxRunner>(backend,
|
||||||
|
false,
|
||||||
|
tensor_types,
|
||||||
|
"model.diffusion_model",
|
||||||
|
VERSION_CHROMA_RADIANCE,
|
||||||
|
false,
|
||||||
|
true);
|
||||||
|
|
||||||
|
flux->alloc_params_buffer();
|
||||||
|
std::map<std::string, ggml_tensor*> tensors;
|
||||||
|
flux->get_param_tensors(tensors, "model.diffusion_model");
|
||||||
|
|
||||||
|
bool success = model_loader.load_tensors(tensors);
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
LOG_ERROR("load tensors from model loader failed");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INFO("flux model loaded");
|
||||||
flux->test();
|
flux->test();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -954,7 +954,16 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
|
|||||||
if (scale != 1.f) {
|
if (scale != 1.f) {
|
||||||
x = ggml_scale(ctx, x, scale);
|
x = ggml_scale(ctx, x, scale);
|
||||||
}
|
}
|
||||||
x = ggml_mul_mat(ctx, w, x);
|
if (x->ne[2] * x->ne[3] > 1024) {
|
||||||
|
// workaround: avoid ggml cuda error
|
||||||
|
int64_t ne2 = x->ne[2];
|
||||||
|
int64_t ne3 = x->ne[3];
|
||||||
|
x = ggml_reshape_2d(ctx, x, x->ne[0], x->ne[1] * x->ne[2] * x->ne[3]);
|
||||||
|
x = ggml_mul_mat(ctx, w, x);
|
||||||
|
x = ggml_reshape_4d(ctx, x, x->ne[0], x->ne[1] / ne2 / ne3, ne2, ne3);
|
||||||
|
} else {
|
||||||
|
x = ggml_mul_mat(ctx, w, x);
|
||||||
|
}
|
||||||
if (force_prec_f32) {
|
if (force_prec_f32) {
|
||||||
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
|
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
|
||||||
}
|
}
|
||||||
|
|||||||
32
model.cpp
32
model.cpp
@ -1778,7 +1778,6 @@ bool ModelLoader::model_is_unet() {
|
|||||||
|
|
||||||
SDVersion ModelLoader::get_sd_version() {
|
SDVersion ModelLoader::get_sd_version() {
|
||||||
TensorStorage token_embedding_weight, input_block_weight;
|
TensorStorage token_embedding_weight, input_block_weight;
|
||||||
bool input_block_checked = false;
|
|
||||||
|
|
||||||
bool has_multiple_encoders = false;
|
bool has_multiple_encoders = false;
|
||||||
bool is_unet = false;
|
bool is_unet = false;
|
||||||
@ -1791,12 +1790,12 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
bool has_middle_block_1 = false;
|
bool has_middle_block_1 = false;
|
||||||
|
|
||||||
for (auto& tensor_storage : tensor_storages) {
|
for (auto& tensor_storage : tensor_storages) {
|
||||||
if (!(is_xl || is_flux)) {
|
if (!(is_xl)) {
|
||||||
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
|
||||||
is_flux = true;
|
is_flux = true;
|
||||||
if (input_block_checked) {
|
}
|
||||||
break;
|
if (tensor_storage.name.find("model.diffusion_model.nerf_final_layer_conv.") != std::string::npos) {
|
||||||
}
|
return VERSION_CHROMA_RADIANCE;
|
||||||
}
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
|
||||||
return VERSION_SD3;
|
return VERSION_SD3;
|
||||||
@ -1813,22 +1812,19 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.img_emb") != std::string::npos) {
|
||||||
has_img_emb = true;
|
has_img_emb = true;
|
||||||
}
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos ||
|
||||||
|
tensor_storage.name.find("unet.down_blocks.") != std::string::npos) {
|
||||||
is_unet = true;
|
is_unet = true;
|
||||||
if (has_multiple_encoders) {
|
if (has_multiple_encoders) {
|
||||||
is_xl = true;
|
is_xl = true;
|
||||||
if (input_block_checked) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) {
|
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos ||
|
||||||
|
tensor_storage.name.find("cond_stage_model.1") != std::string::npos ||
|
||||||
|
tensor_storage.name.find("te.1") != std::string::npos) {
|
||||||
has_multiple_encoders = true;
|
has_multiple_encoders = true;
|
||||||
if (is_unet) {
|
if (is_unet) {
|
||||||
is_xl = true;
|
is_xl = true;
|
||||||
if (input_block_checked) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
|
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
|
||||||
@ -1848,12 +1844,10 @@ SDVersion ModelLoader::get_sd_version() {
|
|||||||
token_embedding_weight = tensor_storage;
|
token_embedding_weight = tensor_storage;
|
||||||
// break;
|
// break;
|
||||||
}
|
}
|
||||||
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") {
|
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" ||
|
||||||
input_block_weight = tensor_storage;
|
tensor_storage.name == "model.diffusion_model.img_in.weight" ||
|
||||||
input_block_checked = true;
|
tensor_storage.name == "unet.conv_in.weight") {
|
||||||
if (is_flux) {
|
input_block_weight = tensor_storage;
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (is_wan) {
|
if (is_wan) {
|
||||||
|
|||||||
7
model.h
7
model.h
@ -36,6 +36,7 @@ enum SDVersion {
|
|||||||
VERSION_FLUX_FILL,
|
VERSION_FLUX_FILL,
|
||||||
VERSION_FLUX_CONTROLS,
|
VERSION_FLUX_CONTROLS,
|
||||||
VERSION_FLEX_2,
|
VERSION_FLEX_2,
|
||||||
|
VERSION_CHROMA_RADIANCE,
|
||||||
VERSION_WAN2,
|
VERSION_WAN2,
|
||||||
VERSION_WAN2_2_I2V,
|
VERSION_WAN2_2_I2V,
|
||||||
VERSION_WAN2_2_TI2V,
|
VERSION_WAN2_2_TI2V,
|
||||||
@ -72,7 +73,11 @@ static inline bool sd_version_is_sd3(SDVersion version) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static inline bool sd_version_is_flux(SDVersion version) {
|
static inline bool sd_version_is_flux(SDVersion version) {
|
||||||
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2) {
|
if (version == VERSION_FLUX ||
|
||||||
|
version == VERSION_FLUX_FILL ||
|
||||||
|
version == VERSION_FLUX_CONTROLS ||
|
||||||
|
version == VERSION_FLEX_2 ||
|
||||||
|
version == VERSION_CHROMA_RADIANCE) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@ -649,7 +649,7 @@ namespace Qwen {
|
|||||||
|
|
||||||
static void load_from_file_and_test(const std::string& file_path) {
|
static void load_from_file_and_test(const std::string& file_path) {
|
||||||
// cuda q8: pass
|
// cuda q8: pass
|
||||||
// cuda q8 fa: nan
|
// cuda q8 fa: pass
|
||||||
// ggml_backend_t backend = ggml_backend_cuda_init(0);
|
// ggml_backend_t backend = ggml_backend_cuda_init(0);
|
||||||
ggml_backend_t backend = ggml_backend_cpu_init();
|
ggml_backend_t backend = ggml_backend_cpu_init();
|
||||||
ggml_type model_data_type = GGML_TYPE_Q8_0;
|
ggml_type model_data_type = GGML_TYPE_Q8_0;
|
||||||
|
|||||||
@ -41,6 +41,7 @@ const char* model_version_to_str[] = {
|
|||||||
"Flux Fill",
|
"Flux Fill",
|
||||||
"Flux Control",
|
"Flux Control",
|
||||||
"Flex.2",
|
"Flex.2",
|
||||||
|
"Chroma Radiance",
|
||||||
"Wan 2.x",
|
"Wan 2.x",
|
||||||
"Wan 2.2 I2V",
|
"Wan 2.2 I2V",
|
||||||
"Wan 2.2 TI2V",
|
"Wan 2.2 TI2V",
|
||||||
@ -494,6 +495,9 @@ public:
|
|||||||
version);
|
version);
|
||||||
first_stage_model->alloc_params_buffer();
|
first_stage_model->alloc_params_buffer();
|
||||||
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
first_stage_model->get_param_tensors(tensors, "first_stage_model");
|
||||||
|
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
first_stage_model = std::make_shared<FakeVAE>(vae_backend,
|
||||||
|
offload_params_to_cpu);
|
||||||
} else if (!use_tiny_autoencoder) {
|
} else if (!use_tiny_autoencoder) {
|
||||||
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
|
first_stage_model = std::make_shared<AutoEncoderKL>(vae_backend,
|
||||||
offload_params_to_cpu,
|
offload_params_to_cpu,
|
||||||
@ -1041,7 +1045,7 @@ public:
|
|||||||
struct ggml_tensor* c_concat = nullptr;
|
struct ggml_tensor* c_concat = nullptr;
|
||||||
{
|
{
|
||||||
if (zero_out_masked) {
|
if (zero_out_masked) {
|
||||||
c_concat = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 4, 1);
|
c_concat = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / get_vae_scale_factor(), height / get_vae_scale_factor(), 4, 1);
|
||||||
ggml_set_f32(c_concat, 0.f);
|
ggml_set_f32(c_concat, 0.f);
|
||||||
} else {
|
} else {
|
||||||
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
|
ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
|
||||||
@ -1375,6 +1379,53 @@ public:
|
|||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int get_vae_scale_factor() {
|
||||||
|
int vae_scale_factor = 8;
|
||||||
|
if (version == VERSION_WAN2_2_TI2V) {
|
||||||
|
vae_scale_factor = 16;
|
||||||
|
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
vae_scale_factor = 1;
|
||||||
|
}
|
||||||
|
return vae_scale_factor;
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_latent_channel() {
|
||||||
|
int latent_channel = 4;
|
||||||
|
if (sd_version_is_dit(version)) {
|
||||||
|
if (version == VERSION_WAN2_2_TI2V) {
|
||||||
|
latent_channel = 48;
|
||||||
|
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
latent_channel = 3;
|
||||||
|
} else {
|
||||||
|
latent_channel = 16;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return latent_channel;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor* generate_init_latent(ggml_context* work_ctx,
|
||||||
|
int width,
|
||||||
|
int height,
|
||||||
|
int frames = 1,
|
||||||
|
bool video = false) {
|
||||||
|
int vae_scale_factor = get_vae_scale_factor();
|
||||||
|
int W = width / vae_scale_factor;
|
||||||
|
int H = height / vae_scale_factor;
|
||||||
|
int T = frames;
|
||||||
|
if (sd_version_is_wan(version)) {
|
||||||
|
T = ((T - 1) / 4) + 1;
|
||||||
|
}
|
||||||
|
int C = get_latent_channel();
|
||||||
|
ggml_tensor* init_latent;
|
||||||
|
if (video) {
|
||||||
|
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
|
||||||
|
} else {
|
||||||
|
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
|
||||||
|
}
|
||||||
|
ggml_set_f32(init_latent, shift_factor);
|
||||||
|
return init_latent;
|
||||||
|
}
|
||||||
|
|
||||||
void process_latent_in(ggml_tensor* latent) {
|
void process_latent_in(ggml_tensor* latent) {
|
||||||
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
|
||||||
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
|
GGML_ASSERT(latent->ne[3] == 16 || latent->ne[3] == 48);
|
||||||
@ -1410,6 +1461,8 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
// pass
|
||||||
} else {
|
} else {
|
||||||
ggml_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
ggml_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||||
float value = ggml_tensor_get_f32(latent, i0, i1, i2, i3);
|
float value = ggml_tensor_get_f32(latent, i0, i1, i2, i3);
|
||||||
@ -1454,6 +1507,8 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if (version == VERSION_CHROMA_RADIANCE) {
|
||||||
|
// pass
|
||||||
} else {
|
} else {
|
||||||
ggml_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
ggml_tensor_iter(latent, [&](ggml_tensor* latent, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||||
float value = ggml_tensor_get_f32(latent, i0, i1, i2, i3);
|
float value = ggml_tensor_get_f32(latent, i0, i1, i2, i3);
|
||||||
@ -1495,11 +1550,11 @@ public:
|
|||||||
ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) {
|
ggml_tensor* vae_encode(ggml_context* work_ctx, ggml_tensor* x, bool encode_video = false) {
|
||||||
int64_t t0 = ggml_time_ms();
|
int64_t t0 = ggml_time_ms();
|
||||||
ggml_tensor* result = nullptr;
|
ggml_tensor* result = nullptr;
|
||||||
int W = x->ne[0] / 8;
|
int W = x->ne[0] / get_vae_scale_factor();
|
||||||
int H = x->ne[1] / 8;
|
int H = x->ne[1] / get_vae_scale_factor();
|
||||||
|
int C = get_latent_channel();
|
||||||
if (vae_tiling_params.enabled && !encode_video) {
|
if (vae_tiling_params.enabled && !encode_video) {
|
||||||
// TODO wan2.2 vae support?
|
// TODO wan2.2 vae support?
|
||||||
int C = sd_version_is_dit(version) ? 16 : 4;
|
|
||||||
int ne2;
|
int ne2;
|
||||||
int ne3;
|
int ne3;
|
||||||
if (sd_version_is_qwen_image(version)) {
|
if (sd_version_is_qwen_image(version)) {
|
||||||
@ -1586,7 +1641,10 @@ public:
|
|||||||
|
|
||||||
ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* vae_output) {
|
ggml_tensor* get_first_stage_encoding(ggml_context* work_ctx, ggml_tensor* vae_output) {
|
||||||
ggml_tensor* latent;
|
ggml_tensor* latent;
|
||||||
if (use_tiny_autoencoder || sd_version_is_qwen_image(version) || sd_version_is_wan(version)) {
|
if (use_tiny_autoencoder ||
|
||||||
|
sd_version_is_qwen_image(version) ||
|
||||||
|
sd_version_is_wan(version) ||
|
||||||
|
version == VERSION_CHROMA_RADIANCE) {
|
||||||
latent = vae_output;
|
latent = vae_output;
|
||||||
} else if (version == VERSION_SD1_PIX2PIX) {
|
} else if (version == VERSION_SD1_PIX2PIX) {
|
||||||
latent = ggml_view_3d(work_ctx,
|
latent = ggml_view_3d(work_ctx,
|
||||||
@ -1613,18 +1671,14 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
|
ggml_tensor* decode_first_stage(ggml_context* work_ctx, ggml_tensor* x, bool decode_video = false) {
|
||||||
int64_t W = x->ne[0] * 8;
|
int64_t W = x->ne[0] * get_vae_scale_factor();
|
||||||
int64_t H = x->ne[1] * 8;
|
int64_t H = x->ne[1] * get_vae_scale_factor();
|
||||||
int64_t C = 3;
|
int64_t C = 3;
|
||||||
ggml_tensor* result = nullptr;
|
ggml_tensor* result = nullptr;
|
||||||
if (decode_video) {
|
if (decode_video) {
|
||||||
int T = x->ne[2];
|
int T = x->ne[2];
|
||||||
if (sd_version_is_wan(version)) {
|
if (sd_version_is_wan(version)) {
|
||||||
T = ((T - 1) * 4) + 1;
|
T = ((T - 1) * 4) + 1;
|
||||||
if (version == VERSION_WAN2_2_TI2V) {
|
|
||||||
W = x->ne[0] * 16;
|
|
||||||
H = x->ne[1] * 16;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
result = ggml_new_tensor_4d(work_ctx,
|
result = ggml_new_tensor_4d(work_ctx,
|
||||||
GGML_TYPE_F32,
|
GGML_TYPE_F32,
|
||||||
@ -2235,16 +2289,9 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
|
|
||||||
// Sample
|
// Sample
|
||||||
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
|
std::vector<struct ggml_tensor*> final_latents; // collect latents to decode
|
||||||
int C = 4;
|
int C = sd_ctx->sd->get_latent_channel();
|
||||||
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
int W = width / sd_ctx->sd->get_vae_scale_factor();
|
||||||
C = 16;
|
int H = height / sd_ctx->sd->get_vae_scale_factor();
|
||||||
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
|
||||||
C = 16;
|
|
||||||
} else if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
|
|
||||||
C = 16;
|
|
||||||
}
|
|
||||||
int W = width / 8;
|
|
||||||
int H = height / 8;
|
|
||||||
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
|
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
|
||||||
|
|
||||||
struct ggml_tensor* control_latent = nullptr;
|
struct ggml_tensor* control_latent = nullptr;
|
||||||
@ -2422,51 +2469,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
|
|||||||
return result_images;
|
return result_images;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx,
|
|
||||||
ggml_context* work_ctx,
|
|
||||||
int width,
|
|
||||||
int height,
|
|
||||||
int frames = 1,
|
|
||||||
bool video = false) {
|
|
||||||
int C = 4;
|
|
||||||
int T = frames;
|
|
||||||
int W = width / 8;
|
|
||||||
int H = height / 8;
|
|
||||||
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
|
||||||
C = 16;
|
|
||||||
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
|
||||||
C = 16;
|
|
||||||
} else if (sd_version_is_qwen_image(sd_ctx->sd->version)) {
|
|
||||||
C = 16;
|
|
||||||
} else if (sd_version_is_wan(sd_ctx->sd->version)) {
|
|
||||||
C = 16;
|
|
||||||
T = ((T - 1) / 4) + 1;
|
|
||||||
if (sd_ctx->sd->version == VERSION_WAN2_2_TI2V) {
|
|
||||||
C = 48;
|
|
||||||
W = width / 16;
|
|
||||||
H = height / 16;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ggml_tensor* init_latent;
|
|
||||||
if (video) {
|
|
||||||
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, T, C);
|
|
||||||
} else {
|
|
||||||
init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
|
|
||||||
}
|
|
||||||
if (sd_version_is_sd3(sd_ctx->sd->version)) {
|
|
||||||
ggml_set_f32(init_latent, 0.0609f);
|
|
||||||
} else if (sd_version_is_flux(sd_ctx->sd->version)) {
|
|
||||||
ggml_set_f32(init_latent, 0.1159f);
|
|
||||||
} else {
|
|
||||||
ggml_set_f32(init_latent, 0.f);
|
|
||||||
}
|
|
||||||
return init_latent;
|
|
||||||
}
|
|
||||||
|
|
||||||
sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
|
sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
|
||||||
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
|
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
|
||||||
int width = sd_img_gen_params->width;
|
int width = sd_img_gen_params->width;
|
||||||
int height = sd_img_gen_params->height;
|
int height = sd_img_gen_params->height;
|
||||||
|
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
|
||||||
if (sd_version_is_dit(sd_ctx->sd->version)) {
|
if (sd_version_is_dit(sd_ctx->sd->version)) {
|
||||||
if (width % 16 || height % 16) {
|
if (width % 16 || height % 16) {
|
||||||
LOG_ERROR("Image dimensions must be must be a multiple of 16 on each axis for %s models. (Got %dx%d)",
|
LOG_ERROR("Image dimensions must be must be a multiple of 16 on each axis for %s models. (Got %dx%d)",
|
||||||
@ -2562,20 +2569,20 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
1);
|
1);
|
||||||
for (int ix = 0; ix < masked_latent->ne[0]; ix++) {
|
for (int ix = 0; ix < masked_latent->ne[0]; ix++) {
|
||||||
for (int iy = 0; iy < masked_latent->ne[1]; iy++) {
|
for (int iy = 0; iy < masked_latent->ne[1]; iy++) {
|
||||||
int mx = ix * 8;
|
int mx = ix * vae_scale_factor;
|
||||||
int my = iy * 8;
|
int my = iy * vae_scale_factor;
|
||||||
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
|
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
|
||||||
for (int k = 0; k < masked_latent->ne[2]; k++) {
|
for (int k = 0; k < masked_latent->ne[2]; k++) {
|
||||||
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
|
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
|
||||||
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
|
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
|
||||||
}
|
}
|
||||||
// "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
|
// "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image
|
||||||
for (int x = 0; x < 8; x++) {
|
for (int x = 0; x < vae_scale_factor; x++) {
|
||||||
for (int y = 0; y < 8; y++) {
|
for (int y = 0; y < vae_scale_factor; y++) {
|
||||||
float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
|
float m = ggml_tensor_get_f32(mask_img, mx + x, my + y);
|
||||||
// TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?)
|
// TODO: check if the way the mask is flattened is correct (is it supposed to be x*vae_scale_factor+y or x+vae_scale_factor*y?)
|
||||||
// python code was using "b (h 8) (w 8) -> b (8 8) h w"
|
// python code was using "b (h vae_scale_factor) (w vae_scale_factor) -> b (vae_scale_factor vae_scale_factor) h w"
|
||||||
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
|
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * vae_scale_factor + y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
|
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
|
||||||
@ -2598,11 +2605,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
|
|
||||||
{
|
{
|
||||||
// LOG_WARN("Inpainting with a base model is not great");
|
// LOG_WARN("Inpainting with a base model is not great");
|
||||||
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1);
|
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / vae_scale_factor, height / vae_scale_factor, 1, 1);
|
||||||
for (int ix = 0; ix < denoise_mask->ne[0]; ix++) {
|
for (int ix = 0; ix < denoise_mask->ne[0]; ix++) {
|
||||||
for (int iy = 0; iy < denoise_mask->ne[1]; iy++) {
|
for (int iy = 0; iy < denoise_mask->ne[1]; iy++) {
|
||||||
int mx = ix * 8;
|
int mx = ix * vae_scale_factor;
|
||||||
int my = iy * 8;
|
int my = iy * vae_scale_factor;
|
||||||
float m = ggml_tensor_get_f32(mask_img, mx, my);
|
float m = ggml_tensor_get_f32(mask_img, mx, my);
|
||||||
ggml_tensor_set_f32(denoise_mask, m, ix, iy);
|
ggml_tensor_set_f32(denoise_mask, m, ix, iy);
|
||||||
}
|
}
|
||||||
@ -2613,7 +2620,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
|
|||||||
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
|
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
|
||||||
LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask");
|
LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask");
|
||||||
}
|
}
|
||||||
init_latent = generate_init_latent(sd_ctx, work_ctx, width, height);
|
init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height);
|
||||||
}
|
}
|
||||||
|
|
||||||
sd_guidance_params_t guidance = sd_img_gen_params->sample_params.guidance;
|
sd_guidance_params_t guidance = sd_img_gen_params->sample_params.guidance;
|
||||||
@ -2741,6 +2748,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
int sample_steps = sd_vid_gen_params->sample_params.sample_steps;
|
int sample_steps = sd_vid_gen_params->sample_params.sample_steps;
|
||||||
LOG_INFO("generate_video %dx%dx%d", width, height, frames);
|
LOG_INFO("generate_video %dx%dx%d", width, height, frames);
|
||||||
|
|
||||||
|
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
|
||||||
|
|
||||||
sd_ctx->sd->init_scheduler(sd_vid_gen_params->sample_params.scheduler);
|
sd_ctx->sd->init_scheduler(sd_vid_gen_params->sample_params.scheduler);
|
||||||
|
|
||||||
int high_noise_sample_steps = 0;
|
int high_noise_sample_steps = 0;
|
||||||
@ -2838,7 +2847,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
ggml_tensor_set_f32(image, value, i0, i1, i2, i3);
|
ggml_tensor_set_f32(image, value, i0, i1, i2, i3);
|
||||||
});
|
});
|
||||||
|
|
||||||
concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image); // [b*c, t, h/8, w/8]
|
concat_latent = sd_ctx->sd->encode_first_stage(work_ctx, image); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor]
|
||||||
|
|
||||||
int64_t t2 = ggml_time_ms();
|
int64_t t2 = ggml_time_ms();
|
||||||
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
LOG_INFO("encode_first_stage completed, taking %" PRId64 " ms", t2 - t1);
|
||||||
@ -2848,7 +2857,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
concat_latent->ne[0],
|
concat_latent->ne[0],
|
||||||
concat_latent->ne[1],
|
concat_latent->ne[1],
|
||||||
concat_latent->ne[2],
|
concat_latent->ne[2],
|
||||||
4); // [b*4, t, w/8, h/8]
|
4); // [b*4, t, w/vae_scale_factor, h/vae_scale_factor]
|
||||||
ggml_tensor_iter(concat_mask, [&](ggml_tensor* concat_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
ggml_tensor_iter(concat_mask, [&](ggml_tensor* concat_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||||
float value = 0.0f;
|
float value = 0.0f;
|
||||||
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
|
if (i2 == 0 && sd_vid_gen_params->init_image.data) { // start image
|
||||||
@ -2859,7 +2868,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3);
|
ggml_tensor_set_f32(concat_mask, value, i0, i1, i2, i3);
|
||||||
});
|
});
|
||||||
|
|
||||||
concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3); // [b*(c+4), t, h/8, w/8]
|
concat_latent = ggml_tensor_concat(work_ctx, concat_mask, concat_latent, 3); // [b*(c+4), t, h/vae_scale_factor, w/vae_scale_factor]
|
||||||
} else if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-TI2V-5B" && sd_vid_gen_params->init_image.data) {
|
} else if (sd_ctx->sd->diffusion_model->get_desc() == "Wan2.2-TI2V-5B" && sd_vid_gen_params->init_image.data) {
|
||||||
LOG_INFO("IMG2VID");
|
LOG_INFO("IMG2VID");
|
||||||
|
|
||||||
@ -2870,7 +2879,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
|
|
||||||
auto init_image_latent = sd_ctx->sd->vae_encode(work_ctx, init_img); // [b*c, 1, h/16, w/16]
|
auto init_image_latent = sd_ctx->sd->vae_encode(work_ctx, init_img); // [b*c, 1, h/16, w/16]
|
||||||
|
|
||||||
init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
|
init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, frames, true);
|
||||||
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
|
denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
|
||||||
ggml_set_f32(denoise_mask, 1.f);
|
ggml_set_f32(denoise_mask, 1.f);
|
||||||
|
|
||||||
@ -2927,8 +2936,8 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
ggml_tensor_set_f32(reactive, reactive_value, i0, i1, i2, i3);
|
ggml_tensor_set_f32(reactive, reactive_value, i0, i1, i2, i3);
|
||||||
});
|
});
|
||||||
|
|
||||||
inactive = sd_ctx->sd->encode_first_stage(work_ctx, inactive); // [b*c, t, h/8, w/8]
|
inactive = sd_ctx->sd->encode_first_stage(work_ctx, inactive); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor]
|
||||||
reactive = sd_ctx->sd->encode_first_stage(work_ctx, reactive); // [b*c, t, h/8, w/8]
|
reactive = sd_ctx->sd->encode_first_stage(work_ctx, reactive); // [b*c, t, h/vae_scale_factor, w/vae_scale_factor]
|
||||||
|
|
||||||
int64_t length = inactive->ne[2];
|
int64_t length = inactive->ne[2];
|
||||||
if (ref_image_latent) {
|
if (ref_image_latent) {
|
||||||
@ -2936,7 +2945,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
frames = (length - 1) * 4 + 1;
|
frames = (length - 1) * 4 + 1;
|
||||||
ref_image_num = 1;
|
ref_image_num = 1;
|
||||||
}
|
}
|
||||||
vace_context = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, inactive->ne[0], inactive->ne[1], length, 96); // [b*96, t, h/8, w/8]
|
vace_context = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, inactive->ne[0], inactive->ne[1], length, 96); // [b*96, t, h/vae_scale_factor, w/vae_scale_factor]
|
||||||
ggml_tensor_iter(vace_context, [&](ggml_tensor* vace_context, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
ggml_tensor_iter(vace_context, [&](ggml_tensor* vace_context, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||||
float value;
|
float value;
|
||||||
if (i3 < 32) {
|
if (i3 < 32) {
|
||||||
@ -2953,7 +2962,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
if (ref_image_latent && i2 == 0) {
|
if (ref_image_latent && i2 == 0) {
|
||||||
value = 0.f;
|
value = 0.f;
|
||||||
} else {
|
} else {
|
||||||
int64_t vae_stride = 8;
|
int64_t vae_stride = vae_scale_factor;
|
||||||
int64_t mask_height_index = i1 * vae_stride + (i3 - 32) / vae_stride;
|
int64_t mask_height_index = i1 * vae_stride + (i3 - 32) / vae_stride;
|
||||||
int64_t mask_width_index = i0 * vae_stride + (i3 - 32) % vae_stride;
|
int64_t mask_width_index = i0 * vae_stride + (i3 - 32) % vae_stride;
|
||||||
value = ggml_tensor_get_f32(mask, mask_width_index, mask_height_index, i2 - ref_image_num, 0);
|
value = ggml_tensor_get_f32(mask, mask_width_index, mask_height_index, i2 - ref_image_num, 0);
|
||||||
@ -2966,7 +2975,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (init_latent == nullptr) {
|
if (init_latent == nullptr) {
|
||||||
init_latent = generate_init_latent(sd_ctx, work_ctx, width, height, frames, true);
|
init_latent = sd_ctx->sd->generate_init_latent(work_ctx, width, height, frames, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get learned condition
|
// Get learned condition
|
||||||
@ -2997,16 +3006,10 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
|
|||||||
sd_ctx->sd->cond_stage_model->free_params_buffer();
|
sd_ctx->sd->cond_stage_model->free_params_buffer();
|
||||||
}
|
}
|
||||||
|
|
||||||
int W = width / 8;
|
int W = width / vae_scale_factor;
|
||||||
int H = height / 8;
|
int H = height / vae_scale_factor;
|
||||||
int T = init_latent->ne[2];
|
int T = init_latent->ne[2];
|
||||||
int C = 16;
|
int C = sd_ctx->sd->get_latent_channel();
|
||||||
|
|
||||||
if (sd_ctx->sd->version == VERSION_WAN2_2_TI2V) {
|
|
||||||
W = width / 16;
|
|
||||||
H = height / 16;
|
|
||||||
C = 48;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct ggml_tensor* final_latent;
|
struct ggml_tensor* final_latent;
|
||||||
struct ggml_tensor* x_t = init_latent;
|
struct ggml_tensor* x_t = init_latent;
|
||||||
|
|||||||
24
vae.hpp
24
vae.hpp
@ -533,6 +533,30 @@ struct VAE : public GGMLRunner {
|
|||||||
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
virtual void set_conv2d_scale(float scale) { SD_UNUSED(scale); };
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct FakeVAE : public VAE {
|
||||||
|
FakeVAE(ggml_backend_t backend, bool offload_params_to_cpu)
|
||||||
|
: VAE(backend, offload_params_to_cpu) {}
|
||||||
|
void compute(const int n_threads,
|
||||||
|
struct ggml_tensor* z,
|
||||||
|
bool decode_graph,
|
||||||
|
struct ggml_tensor** output,
|
||||||
|
struct ggml_context* output_ctx) override {
|
||||||
|
if (*output == nullptr && output_ctx != nullptr) {
|
||||||
|
*output = ggml_dup_tensor(output_ctx, z);
|
||||||
|
}
|
||||||
|
ggml_tensor_iter(z, [&](ggml_tensor* z, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
|
||||||
|
float value = ggml_tensor_get_f32(z, i0, i1, i2, i3);
|
||||||
|
ggml_tensor_set_f32(*output, value, i0, i1, i2, i3);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) override {}
|
||||||
|
|
||||||
|
std::string get_desc() override {
|
||||||
|
return "fake_vae";
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct AutoEncoderKL : public VAE {
|
struct AutoEncoderKL : public VAE {
|
||||||
bool decode_only = true;
|
bool decode_only = true;
|
||||||
AutoencodingEngine ae;
|
AutoencodingEngine ae;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user