diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index f6cf11e..a4d6ca6 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -149,7 +149,7 @@ jobs:
     runs-on: windows-2025

     env:
-      VULKAN_VERSION: 1.3.261.1
+      VULKAN_VERSION: 1.4.328.1

     strategy:
       matrix:
@@ -199,9 +199,9 @@ jobs:
           version: 1.11.1
       - name: Install Vulkan SDK
         id: get_vulkan
        if: ${{ matrix.build == 'vulkan' }}
         run: |
-          curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe"
+          curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/vulkansdk-windows-X64-${env:VULKAN_VERSION}.exe"
           & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install
           Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}"
           Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin"
diff --git a/Dockerfile b/Dockerfile
index bd9a378..4173357 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,16 +1,21 @@
 ARG UBUNTU_VERSION=22.04

-FROM ubuntu:$UBUNTU_VERSION as build
+FROM ubuntu:$UBUNTU_VERSION AS build

-RUN apt-get update && apt-get install -y build-essential git cmake
+RUN apt-get update && apt-get install -y --no-install-recommends build-essential git cmake

 WORKDIR /sd.cpp

 COPY . .

-RUN mkdir build && cd build && cmake .. && cmake --build . --config Release
+RUN cmake . -B ./build
+RUN cmake --build ./build --config Release --parallel

-FROM ubuntu:$UBUNTU_VERSION as runtime
+FROM ubuntu:$UBUNTU_VERSION AS runtime
+
+RUN apt-get update && \
+    apt-get install --yes --no-install-recommends libgomp1 && \
+    apt-get clean

 COPY --from=build /sd.cpp/build/bin/sd /sd
diff --git a/README.md b/README.md
index c5c3eb1..cef0bac 100644
--- a/README.md
+++ b/README.md
@@ -449,6 +449,7 @@ These projects use `stable-diffusion.cpp` as a backend for their image generatio
 - [Local Diffusion](https://github.com/rmatif/Local-Diffusion)
 - [sd.cpp-webui](https://github.com/daniandtheweb/sd.cpp-webui)
 - [LocalAI](https://github.com/mudler/LocalAI)
+- [Neural-Pixel](https://github.com/Luiz-Alcantara/Neural-Pixel)

 ## Contributors
@@ -473,4 +474,4 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp
 - [generative-models](https://github.com/Stability-AI/generative-models/)
 - [PhotoMaker](https://github.com/TencentARC/PhotoMaker)
 - [Wan2.1](https://github.com/Wan-Video/Wan2.1)
-- [Wan2.2](https://github.com/Wan-Video/Wan2.2)
\ No newline at end of file
+- [Wan2.2](https://github.com/Wan-Video/Wan2.2)
diff --git a/common.hpp b/common.hpp
index 9c8aba1..a197e8f 100644
--- a/common.hpp
+++ b/common.hpp
@@ -242,7 +242,8 @@ public:
     FeedForward(int64_t dim,
                 int64_t dim_out,
                 int64_t mult          = 4,
-                Activation activation = Activation::GEGLU) {
+                Activation activation = Activation::GEGLU,
+                bool force_prec_f32   = false) {
         int64_t inner_dim = dim * mult;

         if (activation == Activation::GELU) {
@@ -252,7 +253,7 @@ public:
         }

         // net_1 is nn.Dropout(), skip for inference
-        blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out));
+        blocks["net.2"] = std::shared_ptr<GGMLBlock>(new Linear(inner_dim, dim_out, true, false, force_prec_f32));
     }

     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index b3883f5..c55eb71 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -1111,10 +1111,19 @@ bool load_images_from_dir(const std::string dir,
         return false;
     }

+    std::vector<fs::directory_entry> entries;
     for (const auto& entry : fs::directory_iterator(dir)) {
-        if (!entry.is_regular_file())
-            continue;
+        if (entry.is_regular_file()) {
+            entries.push_back(entry);
+        }
+    }
+    std::sort(entries.begin(), entries.end(),
+              [](const fs::directory_entry& a, const fs::directory_entry& b) {
+                  return a.path().filename().string() < b.path().filename().string();
+              });
+
+    for (const auto& entry : entries) {
         std::string path = entry.path().string();
         std::string ext  = entry.path().extension().string();
         std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
@@ -1254,7 +1263,7 @@ int main(int argc, const char* argv[]) {
         }
     }

-    if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) {
+    if (params.control_image_path.size() > 0) {
         int width  = 0;
         int height = 0;
         control_image.data = load_image(params.control_image_path.c_str(), width, height, params.width, params.height);
diff --git a/flux.hpp b/flux.hpp
index 4153c6f..2ed4104 100644
--- a/flux.hpp
+++ b/flux.hpp
@@ -565,6 +565,7 @@ namespace Flux {
         bool guidance_embed = true;
         bool flash_attn     = true;
         bool is_chroma      = false;
+        SDVersion version   = VERSION_FLUX;
     };

     struct Flux : public GGMLBlock {
@@ -799,7 +800,8 @@ namespace Flux {
             auto img            = process_img(ctx, x);
             uint64_t img_tokens = img->ne[1];

-            if (c_concat != NULL) {
+            if (params.version == VERSION_FLUX_FILL) {
+                GGML_ASSERT(c_concat != NULL);
                 ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
                 ggml_tensor* mask   = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
@@ -807,6 +809,27 @@ namespace Flux {
                 masked = process_img(ctx, masked);
                 mask   = process_img(ctx, mask);
                 img    = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0);
+            } else if (params.version == VERSION_FLEX_2) {
+                GGML_ASSERT(c_concat != NULL);
+                ggml_tensor* masked  = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0);
+                ggml_tensor* mask    = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 1, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C);
+                ggml_tensor* control = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * (C + 1));
+
+                masked  = ggml_pad(ctx, masked, pad_w, pad_h, 0, 0);
+                mask    = ggml_pad(ctx, mask, pad_w, pad_h, 0, 0);
+                control = ggml_pad(ctx, control, pad_w, pad_h, 0, 0);
+
+                masked  = patchify(ctx, masked, patch_size);
+                mask    = patchify(ctx, mask, patch_size);
+                control = patchify(ctx, control, patch_size);
+
+                img = ggml_concat(ctx, img, ggml_concat(ctx, ggml_concat(ctx, masked, mask, 0), control, 0), 0);
+            } else if (params.version == VERSION_FLUX_CONTROLS) {
+                GGML_ASSERT(c_concat != NULL);
+
+                ggml_tensor* control = ggml_pad(ctx, c_concat, pad_w, pad_h, 0, 0);
+                control              = patchify(ctx, control, patch_size);
+                img                  = ggml_concat(ctx, img, control, 0);
             }

             if (ref_latents.size() > 0) {
@@ -817,6 +840,7 @@ namespace Flux {
             }

             auto out = forward_orig(ctx, backend, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers);  // [N, num_tokens, C * patch_size * patch_size]
+
             if (out->ne[1] > img_tokens) {
                 out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3));  // [num_tokens, N, C * patch_size * patch_size]
                 out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
@@ -846,13 +870,18 @@ namespace Flux {
                    SDVersion version = VERSION_FLUX,
                    bool flash_attn   = false,
                    bool use_mask     = false)
-            : GGMLRunner(backend, offload_params_to_cpu), use_mask(use_mask) {
+            : GGMLRunner(backend, offload_params_to_cpu), version(version), use_mask(use_mask) {
+            flux_params.version             = version;
             flux_params.flash_attn          = flash_attn;
             flux_params.guidance_embed      = false;
             flux_params.depth               = 0;
             flux_params.depth_single_blocks = 0;
             if (version == VERSION_FLUX_FILL) {
                 flux_params.in_channels = 384;
+            } else if (version == VERSION_FLUX_CONTROLS) {
+                flux_params.in_channels = 128;
+            } else if (version == VERSION_FLEX_2) {
+                flux_params.in_channels = 196;
             }
             for (auto pair : tensor_types) {
                 std::string tensor_name = pair.first;
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index e94950a..8d48341 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -431,18 +431,24 @@ __STATIC_INLINE__ void sd_image_to_tensor(sd_image_t image,

 __STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data,
                                      struct ggml_tensor* mask,
-                                     struct ggml_tensor* output) {
+                                     struct ggml_tensor* output,
+                                     float masked_value = 0.5f) {
     int64_t width    = output->ne[0];
     int64_t height   = output->ne[1];
     int64_t channels = output->ne[2];
+    float rescale_mx = mask->ne[0] / output->ne[0];
+    float rescale_my = mask->ne[1] / output->ne[1];
     GGML_ASSERT(output->type == GGML_TYPE_F32);
     for (int ix = 0; ix < width; ix++) {
         for (int iy = 0; iy < height; iy++) {
-            float m = ggml_tensor_get_f32(mask, ix, iy);
+            int mx  = (int)(ix * rescale_mx);
+            int my  = (int)(iy * rescale_my);
+            float m = ggml_tensor_get_f32(mask, mx, my);
             m = round(m);  // inpaint models need binary masks
-            ggml_tensor_set_f32(mask, m, ix, iy);
+            ggml_tensor_set_f32(mask, m, mx, my);
             for (int k = 0; k < channels; k++) {
-                float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5;
+                float value = ggml_tensor_get_f32(image_data, ix, iy, k);
+                value       = (1 - m) * (value - masked_value) + masked_value;
                 ggml_tensor_set_f32(output, value, ix, iy, k);
             }
         }
@@ -930,8 +936,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ct
 __STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
                                                      struct ggml_tensor* x,
                                                      struct ggml_tensor* w,
-                                                     struct ggml_tensor* b) {
+                                                     struct ggml_tensor* b,
+                                                     bool force_prec_f32 = false) {
     x = ggml_mul_mat(ctx, w, x);
+    if (force_prec_f32) {
+        ggml_mul_mat_set_prec(x, GGML_PREC_F32);
+    }
     if (b != NULL) {
         x = ggml_add_inplace(ctx, x, b);
     }
@@ -1944,6 +1954,7 @@ protected:
     int64_t out_features;
     bool bias;
     bool force_f32;
+    bool force_prec_f32;

     void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
         enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
@@ -1960,12 +1971,14 @@ public:
     Linear(int64_t in_features,
            int64_t out_features,
-           bool bias      = true,
-           bool force_f32 = false)
+           bool bias           = true,
+           bool force_f32      = false,
+           bool force_prec_f32 = false)
         : in_features(in_features),
          out_features(out_features),
          bias(bias),
-          force_f32(force_f32) {}
+          force_f32(force_f32),
+          force_prec_f32(force_prec_f32) {}

     struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) {
         struct ggml_tensor* w = params["weight"];
@@ -1973,7 +1986,7 @@ public:
         if (bias) {
             b = params["bias"];
         }
-        return ggml_nn_linear(ctx, x, w, b);
+        return ggml_nn_linear(ctx, x, w, b, force_prec_f32);
     }
 };
diff --git a/model.cpp b/model.cpp
index a68a99d..b45493c 100644
--- a/model.cpp
+++ b/model.cpp
@@ -1863,10 +1863,15 @@ SDVersion ModelLoader::get_sd_version() {
     }

     if (is_flux) {
-        is_inpaint = input_block_weight.ne[0] == 384;
-        if (is_inpaint) {
+        if (input_block_weight.ne[0] == 384) {
             return VERSION_FLUX_FILL;
         }
+        if (input_block_weight.ne[0] == 128) {
+            return VERSION_FLUX_CONTROLS;
+        }
+        if (input_block_weight.ne[0] == 196) {
+            return VERSION_FLEX_2;
+        }
         return VERSION_FLUX;
     }
diff --git a/model.h b/model.h
index 045d582..628639c 100644
--- a/model.h
+++ b/model.h
@@ -31,6 +31,8 @@ enum SDVersion {
     VERSION_SD3,
     VERSION_FLUX,
     VERSION_FLUX_FILL,
+    VERSION_FLUX_CONTROLS,
+    VERSION_FLEX_2,
     VERSION_WAN2,
     VERSION_WAN2_2_I2V,
     VERSION_WAN2_2_TI2V,
@@ -67,7 +69,7 @@ static inline bool sd_version_is_sd3(SDVersion version) {
 }

 static inline bool sd_version_is_flux(SDVersion version) {
-    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) {
+    if (version == VERSION_FLUX || version == VERSION_FLUX_FILL || version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2) {
         return true;
     }
     return false;
@@ -88,7 +90,7 @@ static inline bool sd_version_is_qwen_image(SDVersion version) {
 }

 static inline bool sd_version_is_inpaint(SDVersion version) {
-    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) {
+    if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL || version == VERSION_FLEX_2) {
         return true;
     }
     return false;
@@ -108,8 +110,12 @@ static inline bool sd_version_is_unet_edit(SDVersion version) {
     return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX;
 }

+static inline bool sd_version_is_control(SDVersion version) {
+    return version == VERSION_FLUX_CONTROLS || version == VERSION_FLEX_2;
+}
+
 static bool sd_version_is_inpaint_or_unet_edit(SDVersion version) {
-    return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version);
+    return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version) || sd_version_is_control(version);
 }

 enum PMVersion {
diff --git a/qwen_image.hpp b/qwen_image.hpp
index 0cd7f9e..726d24d 100644
--- a/qwen_image.hpp
+++ b/qwen_image.hpp
@@ -196,7 +196,7 @@ namespace Qwen {
             blocks["img_norm1"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
             blocks["img_norm2"] = std::shared_ptr<GGMLBlock>(new LayerNorm(dim, eps, false));
-            blocks["img_mlp"]   = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU));
+            blocks["img_mlp"]   = std::shared_ptr<GGMLBlock>(new FeedForward(dim, dim, 4, FeedForward::Activation::GELU, true));

             // txt_mod.0 is nn.SiLU()
             blocks["txt_mod.1"] = std::shared_ptr<GGMLBlock>(new Linear(dim, 6 * dim, true));
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index d75301b..ff7bc94 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -37,6 +37,8 @@ const char* model_version_to_str[] = {
     "SD3.x",
     "Flux",
     "Flux Fill",
+    "Flux Control",
+    "Flex.2",
     "Wan 2.x",
     "Wan 2.2 I2V",
     "Wan 2.2 TI2V",
@@ -103,7 +105,7 @@ public:
     std::shared_ptr<DiffusionModel> high_noise_diffusion_model;
     std::shared_ptr<AutoEncoderKL> first_stage_model;
     std::shared_ptr<TinyAutoEncoder> tae_first_stage;
-    std::shared_ptr<ControlNet> control_net;
+    std::shared_ptr<ControlNet> control_net = NULL;
     std::shared_ptr<PhotoMakerIDEncoder> pmid_model;
     std::shared_ptr<LoraModel> pmid_lora;
     std::shared_ptr<PhotoMakerIDEmbed> pmid_id_embeds;
@@ -344,6 +346,11 @@ public:
             scale_factor = 1.0f;
         }

+        if (sd_version_is_control(version)) {
+            // Might need vae encode for control cond
+            vae_decode_only = false;
+        }
+
         bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu;

         {
@@ -1196,7 +1203,7 @@ public:
         std::vector<struct ggml_tensor*> controls;

-        if (control_hint != NULL) {
+        if (control_hint != NULL && control_net != NULL) {
             control_net->compute(n_threads, noised_input, control_hint, timesteps, cond.c_crossattn, cond.c_vector);
             controls = control_net->controls;
             // print_ggml_tensor(controls[12]);
@@ -1234,7 +1241,7 @@ public:
         float* negative_data = NULL;
         if (has_unconditioned) {
             // uncond
-            if (control_hint != NULL) {
+            if (control_hint != NULL && control_net != NULL) {
                 control_net->compute(n_threads, noised_input, control_hint, timesteps, uncond.c_crossattn, uncond.c_vector);
                 controls = control_net->controls;
             }
@@ -2160,10 +2167,19 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
     int W = width / 8;
     int H = height / 8;
     LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
+
+    struct ggml_tensor* control_latent = NULL;
+    if (sd_version_is_control(sd_ctx->sd->version) && image_hint != NULL) {
+        control_latent = sd_ctx->sd->encode_first_stage(work_ctx, image_hint);
+        ggml_tensor_scale(control_latent, control_strength);
+    }
+
     if (sd_version_is_inpaint(sd_ctx->sd->version)) {
         int64_t mask_channels = 1;
         if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
             mask_channels = 8 * 8;  // flatten the whole mask
+        } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+            mask_channels = 1 + init_latent->ne[2];
         }
         auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
         // no mask, set the whole image as masked
@@ -2177,6 +2193,11 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                 for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) {
                     ggml_tensor_set_f32(empty_latent, 1, x, y, c);
                 }
+            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                for (int64_t c = 0; c < empty_latent->ne[2]; c++) {
+                    // 0x16,1x1,0x16
+                    ggml_tensor_set_f32(empty_latent, c == init_latent->ne[2], x, y, c);
+                }
             } else {
                 ggml_tensor_set_f32(empty_latent, 1, x, y, 0);
                 for (int64_t c = 1; c < empty_latent->ne[2]; c++) {
@@ -2185,7 +2206,28 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                 }
             }
         }
-        if (concat_latent == NULL) {
+
+        if (sd_ctx->sd->version == VERSION_FLEX_2 && control_latent != NULL && sd_ctx->sd->control_net == NULL) {
+            bool no_inpaint = concat_latent == NULL;
+            if (no_inpaint) {
+                concat_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
+            }
+            // fill in the control image here
+            for (int64_t x = 0; x < control_latent->ne[0]; x++) {
+                for (int64_t y = 0; y < control_latent->ne[1]; y++) {
+                    if (no_inpaint) {
+                        for (int64_t c = 0; c < concat_latent->ne[2] - control_latent->ne[2]; c++) {
+                            // 0x16,1x1,0x16
+                            ggml_tensor_set_f32(concat_latent, c == init_latent->ne[2], x, y, c);
+                        }
+                    }
+                    for (int64_t c = 0; c < control_latent->ne[2]; c++) {
+                        float v = ggml_tensor_get_f32(control_latent, x, y, c);
+                        ggml_tensor_set_f32(concat_latent, v, x, y, concat_latent->ne[2] - control_latent->ne[2] + c);
+                    }
+                }
+            }
+        } else if (concat_latent == NULL) {
             concat_latent = empty_latent;
         }
         cond.c_concat = concat_latent;
@@ -2195,10 +2237,20 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
         auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
         ggml_set_f32(empty_latent, 0);
         uncond.c_concat = empty_latent;
-        if (concat_latent == NULL) {
-            concat_latent = empty_latent;
+        cond.c_concat = ref_latents[0];
+        if (cond.c_concat == NULL) {
+            cond.c_concat = empty_latent;
+        }
+    } else if (sd_version_is_control(sd_ctx->sd->version)) {
+        auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
+        ggml_set_f32(empty_latent, 0);
+        uncond.c_concat = empty_latent;
+        if (sd_ctx->sd->control_net == NULL) {
+            cond.c_concat = control_latent;
+        }
+        if (cond.c_concat == NULL) {
+            cond.c_concat = empty_latent;
         }
-        cond.c_concat = ref_latents[0];
     }
     SDCondition img_cond;
     if (uncond.c_crossattn != NULL &&
@@ -2402,17 +2454,27 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
         sd_image_to_tensor(sd_img_gen_params->mask_image, mask_img);
         sd_image_to_tensor(sd_img_gen_params->init_image, init_img);

+        init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
+
         if (sd_version_is_inpaint(sd_ctx->sd->version)) {
             int64_t mask_channels = 1;
             if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
                 mask_channels = 8 * 8;  // flatten the whole mask
+            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
+                mask_channels = 1 + init_latent->ne[2];
             }
-            ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
-            sd_apply_mask(init_img, mask_img, masked_img);
             ggml_tensor* masked_latent = NULL;
-            masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
-
+            if (sd_ctx->sd->version != VERSION_FLEX_2) {
+                // most inpaint models mask before vae
+                ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1);
+                sd_apply_mask(init_img, mask_img, masked_img);
+                masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img);
+            } else {
+                // mask after vae
+                masked_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], init_latent->ne[2], 1);
+                sd_apply_mask(init_latent, mask_img, masked_latent, 0.);
+            }
             concat_latent = ggml_new_tensor_4d(work_ctx,
                                                GGML_TYPE_F32,
                                                masked_latent->ne[0],
@@ -2437,12 +2499,18 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                         ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y);
                     }
                 }
-            } else {
+            } else if (sd_ctx->sd->version == VERSION_FLEX_2) {
                 float m = ggml_tensor_get_f32(mask_img, mx, my);
-                ggml_tensor_set_f32(concat_latent, m, ix, iy, 0);
+                // masked image
                 for (int k = 0; k < masked_latent->ne[2]; k++) {
                     float v = ggml_tensor_get_f32(masked_latent, ix, iy, k);
-                    ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels);
+                    ggml_tensor_set_f32(concat_latent, v, ix, iy, k);
+                }
+                // downsampled mask
+                ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2]);
+                // control (todo: support this)
+                for (int k = 0; k < masked_latent->ne[2]; k++) {
+                    ggml_tensor_set_f32(concat_latent, 0, ix, iy, masked_latent->ne[2] + 1 + k);
                 }
             }
         }
@@ -2461,8 +2529,6 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
             }
         }
     }
-
-        init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img);
     } else {
         LOG_INFO("TXT2IMG");
         if (sd_version_is_inpaint(sd_ctx->sd->version)) {
diff --git a/vae.hpp b/vae.hpp
index dd982ab..622b8bb 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -583,6 +583,7 @@ struct AutoEncoderKL : public VAE {
                  bool decode_graph,
                  struct ggml_tensor** output,
                  struct ggml_context* output_ctx = NULL) {
+        GGML_ASSERT(!decode_only || decode_graph);
         auto get_graph = [&]() -> struct ggml_cgraph* {
             return build_graph(z, decode_graph);
         };