feat: inpaint improvements (#1357)

* inpaint: get max pixel max instead of single sample

* inpaint: masked diffusion for inpainting models with inflated mask

* refactor tensor interpolate nearest-like reduction paths and generalize max_pool_2d

---------

Co-authored-by: leejet <leejet714@gmail.com>
This commit is contained in:
stduhpf 2026-04-05 18:44:26 +02:00 committed by GitHub
parent 687a81f251
commit 9369ab759f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 193 additions and 15 deletions

View File

@ -2846,7 +2846,8 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
{request->width / request->vae_scale_factor, {request->width / request->vae_scale_factor,
request->height / request->vae_scale_factor, request->height / request->vae_scale_factor,
1, 1,
1}); 1},
sd::ops::InterpolateMode::NearestMax);
sd::Tensor<float> init_latent; sd::Tensor<float> init_latent;
sd::Tensor<float> control_latent; sd::Tensor<float> control_latent;
@ -2991,8 +2992,12 @@ static std::optional<ImageGenerationLatents> prepare_image_generation_latents(sd
latents.ref_latents = std::move(ref_latents); latents.ref_latents = std::move(ref_latents);
if (sd_version_is_inpaint(sd_ctx->sd->version)) { if (sd_version_is_inpaint(sd_ctx->sd->version)) {
latents.denoise_mask = std::move(latent_mask); latent_mask = sd::ops::max_pool_2d(latent_mask,
{3, 3},
{1, 1},
{1, 1});
} }
latents.denoise_mask = std::move(latent_mask);
return latents; return latents;
} }

View File

@ -815,6 +815,9 @@ namespace sd {
namespace ops { namespace ops {
enum class InterpolateMode { enum class InterpolateMode {
Nearest, Nearest,
NearestMax,
NearestMin,
NearestAvg,
}; };
inline int64_t normalize_slice_bound(int64_t index, int64_t dim_size) { inline int64_t normalize_slice_bound(int64_t index, int64_t dim_size) {
@ -1012,12 +1015,16 @@ namespace sd {
std::vector<int64_t> output_shape, std::vector<int64_t> output_shape,
InterpolateMode mode = InterpolateMode::Nearest, InterpolateMode mode = InterpolateMode::Nearest,
bool align_corners = false) { bool align_corners = false) {
if (mode != InterpolateMode::Nearest) { const bool is_nearest_like_mode = (mode == InterpolateMode::Nearest ||
tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" + mode == InterpolateMode::NearestMax ||
mode == InterpolateMode::NearestMin ||
mode == InterpolateMode::NearestAvg);
if (!is_nearest_like_mode) {
tensor_throw_invalid_argument("Only nearest-like interpolate modes are implemented, got mode=" +
std::to_string(static_cast<int>(mode))); std::to_string(static_cast<int>(mode)));
} }
if (align_corners) { if (align_corners) {
tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" + tensor_throw_invalid_argument("align_corners is not supported for nearest-like interpolate: input_shape=" +
tensor_shape_to_string(input.shape()) + ", output_shape=" + tensor_shape_to_string(input.shape()) + ", output_shape=" +
tensor_shape_to_string(output_shape)); tensor_shape_to_string(output_shape));
} }
@ -1044,14 +1051,102 @@ namespace sd {
} }
} }
Tensor<T> output(std::move(output_shape)); bool has_downsampling = false;
for (int64_t flat = 0; flat < output.numel(); ++flat) { for (int64_t i = 0; i < input.dim(); ++i) {
std::vector<int64_t> output_coord = tensor_unravel_index(flat, output.shape()); if (input.shape()[i] > output_shape[i]) {
std::vector<int64_t> input_coord(static_cast<size_t>(input.dim()), 0); has_downsampling = true;
for (size_t i = 0; i < static_cast<size_t>(input.dim()); ++i) { break;
input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i];
} }
output[flat] = input.index(input_coord); }
Tensor<T> output(std::move(output_shape));
if (mode == InterpolateMode::Nearest || !has_downsampling) {
for (int64_t flat = 0; flat < output.numel(); ++flat) {
std::vector<int64_t> output_coord = tensor_unravel_index(flat, output.shape());
std::vector<int64_t> input_coord(static_cast<size_t>(input.dim()), 0);
for (size_t i = 0; i < static_cast<size_t>(input.dim()); ++i) {
input_coord[i] = output_coord[i] * input.shape()[i] / output.shape()[i];
}
output[flat] = input.index(input_coord);
}
return output;
}
auto init_reduction = [&]() -> T {
switch (mode) {
case InterpolateMode::NearestMax:
return std::numeric_limits<T>::lowest();
case InterpolateMode::NearestMin:
return std::numeric_limits<T>::max();
case InterpolateMode::NearestAvg:
return T(0);
case InterpolateMode::Nearest:
return T(0);
}
tensor_throw_invalid_argument("Unsupported interpolate mode: mode=" +
std::to_string(static_cast<int>(mode)));
};
auto reduce_value = [&](T& acc, const T& sample) {
switch (mode) {
case InterpolateMode::NearestMax:
acc = std::max(acc, sample);
break;
case InterpolateMode::NearestMin:
acc = std::min(acc, sample);
break;
case InterpolateMode::NearestAvg:
acc += sample;
break;
case InterpolateMode::Nearest:
break;
}
};
// Reduction modes only differ from nearest mode when downsampling.
for (int64_t flat_out = 0; flat_out < output.numel(); ++flat_out) {
std::vector<int64_t> output_coord = tensor_unravel_index(flat_out, output.shape());
std::vector<int64_t> input_start(output.dim(), 0);
std::vector<int64_t> input_end(output.dim(), 0);
for (size_t i = 0; i < static_cast<size_t>(output.dim()); ++i) {
const int64_t input_dim = input.shape()[i];
const int64_t output_dim = output.shape()[i];
input_start[i] = std::max(int64_t(0), static_cast<int64_t>(output_coord[i] * input_dim / output_dim));
input_end[i] = std::min(input_dim, ((output_coord[i] + 1) * input_dim + output_dim - 1) / output_dim);
}
T value = init_reduction();
bool done_window = false;
std::vector<int64_t> current_in_coord = input_start;
while (!done_window) {
reduce_value(value, input.index(current_in_coord));
for (int d = static_cast<int>(output.dim()) - 1; d >= 0; --d) {
if (++current_in_coord[d] < input_end[d]) {
break;
}
current_in_coord[d] = input_start[d];
if (d == 0) {
done_window = true;
}
}
}
if (mode == InterpolateMode::NearestAvg) {
int64_t window_size = 1;
for (size_t i = 0; i < static_cast<size_t>(output.dim()); ++i) {
window_size *= (input_end[i] - input_start[i]);
}
value /= static_cast<T>(window_size);
}
output[flat_out] = value;
} }
return output; return output;
@ -1063,12 +1158,16 @@ namespace sd {
const std::optional<std::vector<double>>& scale_factor, const std::optional<std::vector<double>>& scale_factor,
InterpolateMode mode = InterpolateMode::Nearest, InterpolateMode mode = InterpolateMode::Nearest,
bool align_corners = false) { bool align_corners = false) {
if (mode != InterpolateMode::Nearest) { const bool is_nearest_like_mode = (mode == InterpolateMode::Nearest ||
tensor_throw_invalid_argument("Only nearest interpolate mode is implemented, got mode=" + mode == InterpolateMode::NearestMax ||
mode == InterpolateMode::NearestMin ||
mode == InterpolateMode::NearestAvg);
if (!is_nearest_like_mode) {
tensor_throw_invalid_argument("Only nearest-like interpolate modes are implemented, got mode=" +
std::to_string(static_cast<int>(mode))); std::to_string(static_cast<int>(mode)));
} }
if (align_corners) { if (align_corners) {
tensor_throw_invalid_argument("align_corners is not supported for nearest interpolate: input_shape=" + tensor_throw_invalid_argument("align_corners is not supported for nearest-like interpolate: input_shape=" +
tensor_shape_to_string(input.shape())); tensor_shape_to_string(input.shape()));
} }
if (size.has_value() == scale_factor.has_value()) { if (size.has_value() == scale_factor.has_value()) {
@ -1128,6 +1227,80 @@ namespace sd {
align_corners); align_corners);
} }
template <typename T>
inline Tensor<T> max_pool_2d(const Tensor<T>& input,
std::vector<int64_t> kernel_size,
std::vector<int64_t> stride,
std::vector<int64_t> padding) {
if (input.dim() < 2) {
tensor_throw_invalid_argument("Tensor max_pool_2d requires input_dim >= 2: input_dim=" +
std::to_string(input.dim()) + ", input_shape=" +
tensor_shape_to_string(input.shape()));
}
if (kernel_size.size() != 2 || stride.size() != 2 || padding.size() != 2) {
tensor_throw_invalid_argument("Tensor max_pool_2d requires kernel_size, stride, and padding to have length 2");
}
for (size_t i = 0; i < 2; ++i) {
if (kernel_size[i] <= 0) {
tensor_throw_invalid_argument("Tensor max_pool_2d kernel_size must be positive: kernel_size=" +
tensor_shape_to_string(kernel_size));
}
if (stride[i] <= 0) {
tensor_throw_invalid_argument("Tensor max_pool_2d stride must be positive: stride=" +
tensor_shape_to_string(stride));
}
if (padding[i] < 0) {
tensor_throw_invalid_argument("Tensor max_pool_2d padding must be non-negative: padding=" +
tensor_shape_to_string(padding));
}
}
const int64_t in_height = input.shape()[0];
const int64_t in_width = input.shape()[1];
const int64_t out_height = (in_height + 2 * padding[0] - kernel_size[0]) / stride[0] + 1;
const int64_t out_width = (in_width + 2 * padding[1] - kernel_size[1]) / stride[1] + 1;
if (out_height <= 0 || out_width <= 0) {
tensor_throw_invalid_argument("max_pool_2d results in invalid output dimensions: " +
std::to_string(out_height) + "x" + std::to_string(out_width));
}
std::vector<int64_t> output_shape = input.shape();
output_shape[0] = out_height;
output_shape[1] = out_width;
Tensor<T> output(std::move(output_shape));
for (int64_t flat_out = 0; flat_out < output.numel(); ++flat_out) {
std::vector<int64_t> output_coord = tensor_unravel_index(flat_out, output.shape());
std::vector<int64_t> input_coord = output_coord;
const int64_t oh = output_coord[0];
const int64_t ow = output_coord[1];
T max_val = std::numeric_limits<T>::lowest();
bool has_valid_input = false;
for (int64_t kh = 0; kh < kernel_size[0]; ++kh) {
for (int64_t kw = 0; kw < kernel_size[1]; ++kw) {
const int64_t ih = oh * stride[0] + kh - padding[0];
const int64_t iw = ow * stride[1] + kw - padding[1];
if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
input_coord[0] = ih;
input_coord[1] = iw;
max_val = std::max(max_val, input.index(input_coord));
has_valid_input = true;
}
}
}
output[flat_out] = has_valid_input ? max_val : T(0);
}
return output;
}
template <typename T> template <typename T>
inline Tensor<T> concat(const Tensor<T>& lhs, const Tensor<T>& rhs, size_t dim) { inline Tensor<T> concat(const Tensor<T>& lhs, const Tensor<T>& rhs, size_t dim) {
if (lhs.dim() != rhs.dim()) { if (lhs.dim() != rhs.dim()) {