refactor: introduce ggml_ext_zeros_like/ggml_ext_ones_like (#1312)

This commit is contained in:
leejet 2026-03-04 00:36:52 +08:00 committed by GitHub
parent d41f5fff69
commit ba35dd734e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 3 deletions

View File

@ -345,7 +345,7 @@ int main(int argc, const char** argv) {
auto get_lora_full_path = [&](const std::string& path) -> std::string {
std::lock_guard<std::mutex> lock(lora_mutex);
auto it = std::find_if(lora_cache.begin(), lora_cache.end(),
[&](const LoraEntry& e) { return e.path == path; });
[&](const LoraEntry& e) { return e.path == path; });
return (it != lora_cache.end()) ? it->fullpath : "";
};
@ -567,7 +567,7 @@ int main(int argc, const char** argv) {
std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(prompt);
size_t image_count = req.form.get_file_count("image[]");
size_t image_count = req.form.get_file_count("image[]");
bool has_legacy_image = req.form.has_file("image");
if (image_count == 0 && !has_legacy_image) {
res.status = 400;

View File

@ -1219,6 +1219,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_zeros(struct ggml_context* ctx,
return ggml_ext_full(ctx, 0.f, ne0, ne1, ne2, ne3);
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_zeros_like(struct ggml_context* ctx,
struct ggml_tensor* x) {
return ggml_ext_zeros(ctx, x->ne[0], x->ne[1], x->ne[2], x->ne[3]);
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones(struct ggml_context* ctx,
int64_t ne0,
int64_t ne1,
@ -1227,6 +1232,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones(struct ggml_context* ctx,
return ggml_ext_full(ctx, 1.f, ne0, ne1, ne2, ne3);
}
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones_like(struct ggml_context* ctx,
struct ggml_tensor* x) {
return ggml_ext_ones(ctx, x->ne[0], x->ne[1], x->ne[2], x->ne[3]);
}
__STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor* a) {
#ifdef SD_USE_VULKAN
auto zero_index = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:zero_int");

View File

@ -404,7 +404,7 @@ namespace Qwen {
auto t_emb = time_text_embed->forward(ctx, timestep);
if (params.zero_cond_t) {
auto t_emb_0 = time_text_embed->forward(ctx, ggml_ext_zeros(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]));
auto t_emb_0 = time_text_embed->forward(ctx, ggml_ext_zeros_like(ctx->ggml_ctx, timestep));
t_emb = ggml_concat(ctx->ggml_ctx, t_emb, t_emb_0, 1);
}
auto img = img_in->forward(ctx, x);