fix: optimize the handling of embedding weight (#859)

This commit is contained in:
leejet 2025-09-25 23:09:59 +08:00 committed by GitHub
parent 6ad46bb700
commit 35843c77ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 5 deletions

View File

@ -553,10 +553,9 @@ protected:
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") { void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
enum ggml_type token_wtype = GGML_TYPE_F32; enum ggml_type token_wtype = GGML_TYPE_F32;
if (!force_clip_f32) { if (!force_clip_f32) {
auto tensor_type = tensor_types.find(prefix + "token_embedding.weight"); token_wtype = get_type(prefix + "token_embedding.weight", tensor_types, GGML_TYPE_F32);
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0}; if (!support_get_rows(token_wtype)) {
if (tensor_type != tensor_types.end() && allow_types.find(tensor_type->second) != allow_types.end()) { token_wtype = GGML_TYPE_F32;
token_wtype = tensor_type->second;
} }
} }
enum ggml_type position_wtype = GGML_TYPE_F32; enum ggml_type position_wtype = GGML_TYPE_F32;

View File

@ -1967,13 +1967,24 @@ public:
} }
}; };
__STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
if (allow_types.find(wtype) != allow_types.end()) {
return true;
}
return false;
}
class Embedding : public UnaryBlock { class Embedding : public UnaryBlock {
protected: protected:
int64_t embedding_dim; int64_t embedding_dim;
int64_t num_embeddings; int64_t num_embeddings;
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") { void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32); enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings); if (!support_get_rows(wtype)) {
wtype = GGML_TYPE_F32;
}
params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
} }
public: public: