refactor: introduce ggml_ext_zeros_like/ggml_ext_ones_like (#1312)

This commit is contained in:
leejet 2026-03-04 00:36:52 +08:00 committed by GitHub
parent d41f5fff69
commit ba35dd734e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 3 deletions

View File

@ -1219,6 +1219,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_zeros(struct ggml_context* ctx,
return ggml_ext_full(ctx, 0.f, ne0, ne1, ne2, ne3);
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_zeros_like(struct ggml_context* ctx,
struct ggml_tensor* x) {
return ggml_ext_zeros(ctx, x->ne[0], x->ne[1], x->ne[2], x->ne[3]);
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones(struct ggml_context* ctx,
int64_t ne0,
int64_t ne1,
@ -1227,6 +1232,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones(struct ggml_context* ctx,
return ggml_ext_full(ctx, 1.f, ne0, ne1, ne2, ne3);
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones_like(struct ggml_context* ctx,
struct ggml_tensor* x) {
return ggml_ext_ones(ctx, x->ne[0], x->ne[1], x->ne[2], x->ne[3]);
}
__STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor* a) {
#ifdef SD_USE_VULKAN
auto zero_index = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:zero_int");

View File

@ -404,7 +404,7 @@ namespace Qwen {
auto t_emb = time_text_embed->forward(ctx, timestep);
if (params.zero_cond_t) {
auto t_emb_0 = time_text_embed->forward(ctx, ggml_ext_zeros(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]));
auto t_emb_0 = time_text_embed->forward(ctx, ggml_ext_zeros_like(ctx->ggml_ctx, timestep));
t_emb = ggml_concat(ctx->ggml_ctx, t_emb, t_emb_0, 1);
}
auto img = img_in->forward(ctx, x);