fix: optimize the handling of CLIP embedding weight (#840)

This commit is contained in:
leejet 2025-09-25 00:28:20 +08:00 committed by GitHub
parent f3140eadbb
commit 2abe9451c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 4 deletions

View File

@ -553,12 +553,13 @@ protected:
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
enum ggml_type token_wtype = GGML_TYPE_F32;
if (!force_clip_f32) {
auto tensor_type = tensor_types.find(prefix + "token_embedding.weight");
if (tensor_type != tensor_types.end())
auto tensor_type = tensor_types.find(prefix + "token_embedding.weight");
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
if (tensor_type != tensor_types.end() && allow_types.find(tensor_type->second) != allow_types.end()) {
token_wtype = tensor_type->second;
}
}
enum ggml_type position_wtype = GGML_TYPE_F32;
enum ggml_type position_wtype = GGML_TYPE_F32;
params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
}

View File

@ -2422,6 +2422,8 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
// Pass, do not convert. For MMDiT
} else if (contains(name, "time_embed.") || contains(name, "label_emb.")) {
// Pass, do not convert. For Unet
} else if (contains(name, "embedding")) {
// Pass, do not convert embedding
} else {
return true;
}