Compare commits

..

4 Commits

Author SHA1 Message Date
leejet
94f4f295c1 Merge branch 'master' into qwen_image 2025-09-25 23:13:00 +08:00
leejet
35843c77ea
fix: optimize the handling of embedding weight (#859) 2025-09-25 23:09:59 +08:00
leejet
178a415d89 Merge branch 'master' into qwen_image 2025-09-25 22:01:08 +08:00
leejet
6ad46bb700 sync: update ggml 2025-09-25 21:57:43 +08:00
3 changed files with 16 additions and 6 deletions

View File

@ -553,10 +553,9 @@ protected:
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
enum ggml_type token_wtype = GGML_TYPE_F32;
if (!force_clip_f32) {
auto tensor_type = tensor_types.find(prefix + "token_embedding.weight");
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
if (tensor_type != tensor_types.end() && allow_types.find(tensor_type->second) != allow_types.end()) {
token_wtype = tensor_type->second;
token_wtype = get_type(prefix + "token_embedding.weight", tensor_types, GGML_TYPE_F32);
if (!support_get_rows(token_wtype)) {
token_wtype = GGML_TYPE_F32;
}
}
enum ggml_type position_wtype = GGML_TYPE_F32;

2
ggml

@ -1 +1 @@
Subproject commit 553c44706c3cc6e4077f4ab214923fc4c20a013c
Subproject commit 7bffd79a4bec72e9a3bfbedb582a218b84401c13

View File

@ -1980,13 +1980,24 @@ public:
}
};
__STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
if (allow_types.find(wtype) != allow_types.end()) {
return true;
}
return false;
}
class Embedding : public UnaryBlock {
protected:
int64_t embedding_dim;
int64_t num_embeddings;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
if (!support_get_rows(wtype)) {
wtype = GGML_TYPE_F32;
}
params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
}
public: