mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2025-12-13 05:48:56 +00:00
allow more quant types
This commit is contained in:
parent
b6c2244d9a
commit
4023083f70
7
clip.hpp
7
clip.hpp
@ -553,12 +553,13 @@ protected:
|
|||||||
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
|
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
|
||||||
enum ggml_type token_wtype = GGML_TYPE_F32;
|
enum ggml_type token_wtype = GGML_TYPE_F32;
|
||||||
if (!force_clip_f32) {
|
if (!force_clip_f32) {
|
||||||
auto tensor_type = tensor_types.find(prefix + "token_embedding.weight");
|
auto tensor_type = tensor_types.find(prefix + "token_embedding.weight");
|
||||||
if (tensor_type != tensor_types.end() && tensor_type->second == GGML_TYPE_F16) {
|
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
|
||||||
|
if (tensor_type != tensor_types.end() && allow_types.find(tensor_type->second) != allow_types.end()) {
|
||||||
token_wtype = tensor_type->second;
|
token_wtype = tensor_type->second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
enum ggml_type position_wtype = GGML_TYPE_F32;
|
enum ggml_type position_wtype = GGML_TYPE_F32;
|
||||||
params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
|
params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
|
||||||
params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
|
params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user