Compare commits

...

8 Commits

12 changed files with 182 additions and 46 deletions

View File

@ -149,7 +149,7 @@ jobs:
runs-on: windows-2025
env:
VULKAN_VERSION: 1.3.261.1
VULKAN_VERSION: 1.4.328.1
strategy:
matrix:
@ -199,9 +199,9 @@ jobs:
version: 1.11.1
- name: Install Vulkan SDK
id: get_vulkan
if: ${{ matrix.build == 'vulkan' }}
if: ${{ matrix.build == 'vulkan' }} https://sdk.lunarg.com/sdk/download/1.4.328.1/windows/vulkansdk-windows-X64-1.4.328.1.exe
run: |
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe"
curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
& "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}"
Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin"

View File

@ -1,16 +1,21 @@
ARG UBUNTU_VERSION=22.04
FROM ubuntu:$UBUNTU_VERSION as build
FROM ubuntu:$UBUNTU_VERSION AS build
RUN apt-get update && apt-get install -y build-essential git cmake
RUN apt-get update && apt-get install -y --no-install-recommends build-essential git cmake
WORKDIR /sd.cpp
COPY . .
RUN mkdir build && cd build && cmake .. && cmake --build . --config Release
RUN cmake . -B ./build
RUN cmake --build ./build --config Release --parallel
FROM ubuntu:$UBUNTU_VERSION as runtime
FROM ubuntu:$UBUNTU_VERSION AS runtime
RUN apt-get update && \
apt-get install --yes --no-install-recommends libgomp1 && \
apt-get clean
COPY --from=build /sd.cpp/build/bin/sd /sd

View File

@ -449,6 +449,7 @@ These projects use `stable-diffusion.cpp` as a backend for their image generatio
- [Local Diffusion](https://github.com/rmatif/Local-Diffusion)
- [sd.cpp-webui](https://github.com/daniandtheweb/sd.cpp-webui)
- [LocalAI](https://github.com/mudler/LocalAI)
- [Neural-Pixel](https://github.com/Luiz-Alcantara/Neural-Pixel)
## Contributors
@ -473,4 +474,4 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp
- [generative-models](https://github.com/Stability-AI/generative-models/)
- [PhotoMaker](https://github.com/TencentARC/PhotoMaker)
- [Wan2.1](https://github.com/Wan-Video/Wan2.1)
- [Wan2.2](https://github.com/Wan-Video/Wan2.2)
- [Wan2.2](https://github.com/Wan-Video/Wan2.2)

View File

@ -242,7 +242,8 @@ public:
FeedForward(int64_t dim,
int64_t dim_out,
int64_t mult = 4,
Activation activation = Activation::GEGLU) {
Activation activation = Activation::GEGLU,
bool force_prec_f32 = false) {
int64_t inner_dim = dim * mult;
if (activation == Activation::GELU) {
@ -252,7 +253,7 @@ public:
}
// net_1 is nn.Dropout(), skip for inference
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out));
blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, force_prec_f32));
}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {

View File

@ -1111,10 +1111,19 @@ bool load_images_from_dir(const std::string dir,
return false;
}
std::vector<fs::directory_entry> entries;
for (const auto& entry : fs::directory_iterator(dir)) {
if (!entry.is_regular_file())
continue;
if (entry.is_regular_file()) {
entries.push_back(entry);
}
}
std::sort(entries.begin(), entries.end(),
[](const fs::directory_entry& a, const fs::directory_entry& b) {
return a.path().filename().string() < b.path().filename().string();
});
for (const auto& entry : entries) {
std::string path = entry.path().string();
std::string ext = entry.path().extension().string();
std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
@ -1254,7 +1263,7 @@ int main(int argc, const char* argv[]) {
}
}
if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
if (params.control_image_path.size() > 0) {
int width = 0;
int height = 0;
control_image.data = load_image(params.control_image_path.c_str(), width, height, params.width, params.height);

View File

@ -565,6 +565,7 @@ namespace Flux {
bool guidance_embed = true;
bool flash_attn = true;
bool is_chroma = false;
SDVersion version = VERSION_FLUX;
};
struct Flux : public GGMLBlock {
@ -799,7 +800,8 @@ namespace Flux {
auto img = process_img(ctx, x);
uint64_t img_tokens = img->ne[1];
if (c_concat != NULL) {
if (params.version == VERSION_FLUX_FILL) {
GGML_ASSERT(c_concat != NULL);
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
@ -807,6 +809,27 @@ namespace Flux {
mask = process_img(ctx, mask);
img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
} else if (params.version == VERSION_FLEX_2) {
GGML_ASSERT(c_concat != NULL);
ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
masked = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
mask = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
masked = patchify(ctx, masked, patch_size);
mask = patchify(ctx, mask, patch_size);
control = patchify(ctx, control, patch_size);
img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
} else if (params.version == VERSION_FLUX_CONTROLS) {
GGML_ASSERT(c_concat != NULL);
ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
control = patchify(ctx, control, patch_size);
img = ggml_concat(ctx, img, control, 0);
}
if (ref_latents.size() > 0) {
@ -817,6 +840,7 @@ namespace Flux {
}
auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
if (out->ne[1] > img_tokens) {
out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
@ -846,13 +870,18 @@ namespace Flux {
SDVersion version = VERSION_FLUX,
bool flash_attn = false,
bool use_mask = false)
: GGMLRunner(backend, offload_params_to_cpu), use_mask(use_mask) {
: GGMLRunner(backend, offload_params_to_cpu), version(version), use_mask(use_mask) {
flux_params.version = version;
flux_params.flash_attn = flash_attn;
flux_params.guidance_embed = false;
flux_params.depth = 0;
flux_params.depth_single_blocks = 0;
if (version == VERSION_FLUX_FILL) {
flux_params.in_channels = 384;
} else if (version == VERSION_FLUX_CONTROLS) {
flux_params.in_channels = 128;
} else if (version == VERSION_FLEX_2) {
flux_params.in_channels = 196;
}
for (auto pair : tensor_types) {
std::string tensor_name = pair.first;

View File

@ -431,18 +431,24 @@ __STATIC_INLINE__ void sd_image_to_tensor(sd_image_t image,
__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
struct ggml_tensor* mask,
struct ggml_tensor* output) {
struct ggml_tensor* output,
float masked_value = 0.5f) {
int64_t width = output->ne[0];
int64_t height = output->ne[1];
int64_t channels = output->ne[2];
float rescale_mx = mask->ne[0] / output->ne[0];
float rescale_my = mask->ne[1] / output->ne[1];
GGML_ASSERT(output->type == GGML_TYPE_F32);
for (int ix = 0; ix < width; ix++) {
for (int iy = 0; iy < height; iy++) {
float m = ggml_tensor_get_f32(mask, ix, iy);
int mx = (int)(ix * rescale_mx);
int my = (int)(iy * rescale_my);
float m = ggml_tensor_get_f32(mask, mx, my);
m = round(m); // inpaint models need binary masks
ggml_tensor_set_f32(mask, m, ix, iy);
ggml_tensor_set_f32(mask, m, mx, my);
for (int k = 0; k < channels; k++) {
float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
float value = ggml_tensor_get_f32(image_data, ix, iy, k);
value = (1 - m) * (value - masked_value) + masked_value;
ggml_tensor_set_f32(output, value, ix, iy, k);
}
}
@ -930,8 +936,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ct
__STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
struct ggml_tensor* x,
struct ggml_tensor* w,
struct ggml_tensor* b) {
struct ggml_tensor* b,
bool force_prec_f32 = false) {
x = ggml_mul_mat(ctx, w, x);
if (force_prec_f32) {
ggml_mul_mat_set_prec(x, GGML_PREC_F32);
}
if (b != NULL) {
x = ggml_add_inplace(ctx, x, b);
}
@ -1944,6 +1954,7 @@ protected:
int64_t out_features;
bool bias;
bool force_f32;
bool force_prec_f32;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
@ -1960,12 +1971,14 @@ protected:
public:
Linear(int64_t in_features,
int64_t out_features,
bool bias = true,
bool force_f32 = false)
bool bias = true,
bool force_f32 = false,
bool force_prec_f32 = false)
: in_features(in_features),
out_features(out_features),
bias(bias),
force_f32(force_f32) {}
force_f32(force_f32),
force_prec_f32(force_prec_f32) {}
struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
struct ggml_tensor* w = params["weight"];
@ -1973,7 +1986,7 @@ public:
if (bias) {
b = params["bias"];
}
return ggml_nn_linear(ctx, x, w, b);
return ggml_nn_linear(ctx, x, w, b, force_prec_f32);
}
};

View File

@ -1863,10 +1863,15 @@ SDVersion ModelLoader::get_sd_version() {
}
if (is_flux) {
is_inpaint = input_block_weight.ne[0] == 384;
if (is_inpaint) {
if (input_block_weight.ne[0] == 384) {
return VERSION_FLUX_FILL;
}
if (input_block_weight.ne[0] == 128) {
return VERSION_FLUX_CONTROLS;
}
if (input_block_weight.ne[0] == 196) {
return VERSION_FLEX_2;
}
return VERSION_FLUX;
}

12
model.h
View File

@ -31,6 +31,8 @@ enum SDVersion {
VERSION_SD3,
VERSION_FLUX,
VERSION_FLUX_FILL,
VERSION_FLUX_CONTROLS,
VERSION_FLEX_2,
VERSION_WAN2,
VERSION_WAN2_2_I2V,
VERSION_WAN2_2_TI2V,
@ -67,7 +69,7 @@ static inline bool sd_version_is_sd3(SDVersion version) {
}
static inline bool sd_version_is_flux(SDVersion version) {
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2) {
return true;
}
return false;
@ -88,7 +90,7 @@ static inline bool sd_version_is_qwen_image(SDVersion version) {
}
static inline bool sd_version_is_inpaint(SDVersion version) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
return true;
}
return false;
@ -108,8 +110,12 @@ static inline bool sd_version_is_unet_edit(SDVersion version) {
return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
}
static inline bool sd_version_is_control(SDVersion version) {
return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;
}
static bool sd_version_is_inpaint_or_unet_edit(SDVersion version) {
return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version);
return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version) || sd_version_is_control(version);
}
enum PMVersion {

View File

@ -196,7 +196,7 @@ namespace Qwen {
blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU));
blocks["img_mlp"] = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU, true));
// txt_mod.0 is nn.SiLU()
blocks["txt_mod.1"] = std::shared_ptr<GGMLBlock>(new Linear(dim, 6 * dim, true));

View File

@ -37,6 +37,8 @@ const char* model_version_to_str[] = {
"SD3.x",
"Flux",
"Flux Fill",
"Flux Control",
"Flex.2",
"Wan 2.x",
"Wan 2.2 I2V",
"Wan 2.2 TI2V",
@ -103,7 +105,7 @@ public:
std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
std::shared_ptr<VAE> first_stage_model;
std::shared_ptr<TinyAutoEncoder> tae_first_stage;
std::shared_ptr<ControlNet> control_net;
std::shared_ptr<ControlNet> control_net = NULL;
std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
std::shared_ptr<LoraModel> pmid_lora;
std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
@ -344,6 +346,11 @@ public:
scale_factor = 1.0f;
}
if (sd_version_is_control(version)) {
// Might need vae encode for control cond
vae_decode_only = false;
}
bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;
{
@ -1196,7 +1203,7 @@ public:
std::vector<struct ggml_tensor*> controls;
if (control_hint != NULL) {
if (control_hint != NULL && control_net != NULL) {
control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
controls = control_net->controls;
// print_ggml_tensor(controls[12]);
@ -1234,7 +1241,7 @@ public:
float* negative_data = NULL;
if (has_unconditioned) {
// uncond
if (control_hint != NULL) {
if (control_hint != NULL && control_net != NULL) {
control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
controls = control_net->controls;
}
@ -2160,10 +2167,19 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
int W = width / 8;
int H = height / 8;
LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
struct ggml_tensor* control_latent = NULL;
if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
ggml_tensor_scale(control_latent, control_strength);
}
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
int64_t mask_channels = 1;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
mask_channels = 8 * 8; // flatten the whole mask
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
mask_channels = 1 + init_latent->ne[2];
}
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
// no mask, set the whole image as masked
@ -2177,6 +2193,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
ggml_tensor_set_f32(empty_latent, 1, x, y, c);
}
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
// 0x16,1x1,0x16
ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
}
} else {
ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
@ -2185,7 +2206,28 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
}
}
}
if (concat_latent == NULL) {
if (sd_ctx->sd->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
bool no_inpaint = concat_latent == NULL;
if (no_inpaint) {
concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
}
// fill in the control image here
for (int64_t x = 0; x < control_latent->ne[0]; x++) {
for (int64_t y = 0; y < control_latent->ne[1]; y++) {
if (no_inpaint) {
for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
// 0x16,1x1,0x16
ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
}
}
for (int64_t c = 0; c < control_latent->ne[2]; c++) {
float v = ggml_tensor_get_f32(control_latent, x, y, c);
ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
}
}
}
} else if (concat_latent == NULL) {
concat_latent = empty_latent;
}
cond.c_concat = concat_latent;
@ -2195,10 +2237,20 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
ggml_set_f32(empty_latent, 0);
uncond.c_concat = empty_latent;
if (concat_latent == NULL) {
concat_latent = empty_latent;
cond.c_concat = ref_latents[0];
if (cond.c_concat == NULL) {
cond.c_concat = empty_latent;
}
} else if (sd_version_is_control(sd_ctx->sd->version)) {
auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
ggml_set_f32(empty_latent, 0);
uncond.c_concat = empty_latent;
if (sd_ctx->sd->control_net == NULL) {
cond.c_concat = control_latent;
}
if (cond.c_concat == NULL) {
cond.c_concat = empty_latent;
}
cond.c_concat = ref_latents[0];
}
SDCondition img_cond;
if (uncond.c_crossattn != NULL &&
@ -2402,17 +2454,27 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
sd_image_to_tensor(sd_img_gen_params->mask_image, mask_img);
sd_image_to_tensor(sd_img_gen_params->init_image, init_img);
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
int64_t mask_channels = 1;
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
mask_channels = 8 * 8; // flatten the whole mask
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
mask_channels = 1 + init_latent->ne[2];
}
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_apply_mask(init_img, mask_img, masked_img);
ggml_tensor* masked_latent = NULL;
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
if (sd_ctx->sd->version != VERSION_FLEX_2) {
// most inpaint models mask before vae
ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
sd_apply_mask(init_img, mask_img, masked_img);
masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
} else {
// mask after vae
masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
sd_apply_mask(init_latent, mask_img, masked_latent, 0.);
}
concat_latent = ggml_new_tensor_4d(work_ctx,
GGML_TYPE_F32,
masked_latent->ne[0],
@ -2437,12 +2499,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
}
}
} else {
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
float m = ggml_tensor_get_f32(mask_img, mx, my);
ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
// masked image
for (int k = 0; k < masked_latent->ne[2]; k++) {
float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
}
// downsampled mask
ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
// control (todo: support this)
for (int k = 0; k < masked_latent->ne[2]; k++) {
ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
}
}
}
@ -2461,8 +2529,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
}
}
}
init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
} else {
LOG_INFO("TXT2IMG");
if (sd_version_is_inpaint(sd_ctx->sd->version)) {

View File

@ -583,6 +583,7 @@ struct AutoEncoderKL : public VAE {
bool decode_graph,
struct ggml_tensor** output,
struct ggml_context* output_ctx = NULL) {
GGML_ASSERT(!decode_only || decode_graph);
auto get_graph = [&]() -> struct ggml_cgraph* {
return build_graph(z, decode_graph);
};