mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-03-24 02:08:51 +00:00
refactor: introduce ggml_ext_zeros_like/ggml_ext_ones_like (#1312)
This commit is contained in:
parent
d41f5fff69
commit
ba35dd734e
@ -345,7 +345,7 @@ int main(int argc, const char** argv) {
|
||||
auto get_lora_full_path = [&](const std::string& path) -> std::string {
|
||||
std::lock_guard<std::mutex> lock(lora_mutex);
|
||||
auto it = std::find_if(lora_cache.begin(), lora_cache.end(),
|
||||
[&](const LoraEntry& e) { return e.path == path; });
|
||||
[&](const LoraEntry& e) { return e.path == path; });
|
||||
return (it != lora_cache.end()) ? it->fullpath : "";
|
||||
};
|
||||
|
||||
@ -567,7 +567,7 @@ int main(int argc, const char** argv) {
|
||||
|
||||
std::string sd_cpp_extra_args_str = extract_and_remove_sd_cpp_extra_args(prompt);
|
||||
|
||||
size_t image_count = req.form.get_file_count("image[]");
|
||||
size_t image_count = req.form.get_file_count("image[]");
|
||||
bool has_legacy_image = req.form.has_file("image");
|
||||
if (image_count == 0 && !has_legacy_image) {
|
||||
res.status = 400;
|
||||
|
||||
@ -1219,6 +1219,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_zeros(struct ggml_context* ctx,
|
||||
return ggml_ext_full(ctx, 0.f, ne0, ne1, ne2, ne3);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_zeros_like(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x) {
|
||||
return ggml_ext_zeros(ctx, x->ne[0], x->ne[1], x->ne[2], x->ne[3]);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones(struct ggml_context* ctx,
|
||||
int64_t ne0,
|
||||
int64_t ne1,
|
||||
@ -1227,6 +1232,11 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones(struct ggml_context* ctx,
|
||||
return ggml_ext_full(ctx, 1.f, ne0, ne1, ne2, ne3);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_ones_like(struct ggml_context* ctx,
|
||||
struct ggml_tensor* x) {
|
||||
return ggml_ext_ones(ctx, x->ne[0], x->ne[1], x->ne[2], x->ne[3]);
|
||||
}
|
||||
|
||||
__STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor* a) {
|
||||
#ifdef SD_USE_VULKAN
|
||||
auto zero_index = ggml_get_tensor(ctx, "ggml_runner_build_in_tensor:zero_int");
|
||||
|
||||
@ -404,7 +404,7 @@ namespace Qwen {
|
||||
|
||||
auto t_emb = time_text_embed->forward(ctx, timestep);
|
||||
if (params.zero_cond_t) {
|
||||
auto t_emb_0 = time_text_embed->forward(ctx, ggml_ext_zeros(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]));
|
||||
auto t_emb_0 = time_text_embed->forward(ctx, ggml_ext_zeros_like(ctx->ggml_ctx, timestep));
|
||||
t_emb = ggml_concat(ctx->ggml_ctx, t_emb, t_emb_0, 1);
|
||||
}
|
||||
auto img = img_in->forward(ctx, x);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user