mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-24 23:26:43 +00:00
add gpt oss tokenizer
This commit is contained in:
parent
a397e03488
commit
1fa06bac5c
91
src/tokenizers/gpt_oss_tokenizer.cpp
Normal file
91
src/tokenizers/gpt_oss_tokenizer.cpp
Normal file
@ -0,0 +1,91 @@
|
||||
#include "gpt_oss_tokenizer.h"
|
||||
|
||||
#include "json.hpp"
|
||||
#include "util.h"
|
||||
#include "vocab/vocab.h"
|
||||
|
||||
void GPTOSSTokenizer::load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
|
||||
auto byte_unicode_pairs = bytes_to_unicode();
|
||||
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
||||
for (auto& pair : byte_unicode_pairs) {
|
||||
byte_decoder[pair.second] = pair.first;
|
||||
}
|
||||
|
||||
nlohmann::json vocab;
|
||||
try {
|
||||
vocab = nlohmann::json::parse(vocab_utf8_str);
|
||||
} catch (const nlohmann::json::parse_error&) {
|
||||
GGML_ABORT("invalid vocab json str");
|
||||
}
|
||||
for (const auto& [key, value] : vocab.items()) {
|
||||
std::u32string token = utf8_to_utf32(key);
|
||||
int i = value;
|
||||
encoder[token] = i;
|
||||
decoder[i] = token;
|
||||
}
|
||||
encoder_len = static_cast<int>(encoder.size());
|
||||
for (auto& special_token : special_tokens) {
|
||||
auto token = utf8_to_utf32(special_token);
|
||||
encoder[token] = encoder_len;
|
||||
decoder[encoder_len] = token;
|
||||
encoder_len++;
|
||||
}
|
||||
encoder_len = static_cast<int>(encoder.size());
|
||||
LOG_DEBUG("vocab size: %d", encoder_len);
|
||||
|
||||
std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
|
||||
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||
for (const auto& merge : merges) {
|
||||
size_t space_pos = merge.find(' ');
|
||||
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||
}
|
||||
LOG_DEBUG("merges size %zu", merge_pairs.size());
|
||||
|
||||
int rank = 0;
|
||||
for (const auto& merge : merge_pairs) {
|
||||
bpe_ranks[merge] = rank++;
|
||||
}
|
||||
bpe_len = rank;
|
||||
}
|
||||
|
||||
// Sets the BOS/EOS/UNK/PAD tokens and ids, declares the special-token list and
// loads the tables. When merges_utf8_str is empty, the merges and vocabulary
// returned by load_gpt_oss_merges() / load_gpt_oss_vocab_json() are used;
// otherwise both caller-supplied strings are used as given.
GPTOSSTokenizer::GPTOSSTokenizer(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
    BOS_TOKEN_ID = 199998;
    EOS_TOKEN_ID = 199999;
    UNK_TOKEN_ID = 199999;
    PAD_TOKEN_ID = 199999;

    // UNK, EOS and PAD all share the end-of-text token.
    BOS_TOKEN = "<|startoftext|>";
    EOS_TOKEN = "<|endoftext|>";
    UNK_TOKEN = "<|endoftext|>";
    PAD_TOKEN = "<|endoftext|>";

    // Order matters: load_from_merges() hands out consecutive ids to these
    // tokens in list order, after the vocabulary entries.
    special_tokens = {
        "<|startoftext|>",
        "<|endoftext|>",
        "<|reserved_200000|>",
        "<|reserved_200001|>",
        "<|return|>",
        "<|constrain|>",
        "<|reserved_200004|>",
        "<|channel|>",
        "<|start|>",
        "<|end|>",
        "<|message|>",
        "<|reserved_200009|>",
        "<|reserved_200010|>",
        "<|reserved_200011|>",
        "<|call|>",
        "<|reserved_200013|>",
        "<|reserved_200014|>",
        "<|reserved_200015|>",
        "<|reserved_200016|>",
        "<|reserved_200017|>",
        "<|endofprompt|>",
    };

    if (merges_utf8_str.empty()) {
        // No merges supplied: use the built-in data for both inputs.
        const std::string builtin_merges = load_gpt_oss_merges();
        const std::string builtin_vocab  = load_gpt_oss_vocab_json();
        load_from_merges(builtin_merges, builtin_vocab);
    } else {
        load_from_merges(merges_utf8_str, vocab_utf8_str);
    }
}
|
||||
16
src/tokenizers/gpt_oss_tokenizer.h
Normal file
16
src/tokenizers/gpt_oss_tokenizer.h
Normal file
@ -0,0 +1,16 @@
|
||||
#ifndef __SD_TOKENIZERS_GPT_OSS_TOKENIZER_H__
|
||||
#define __SD_TOKENIZERS_GPT_OSS_TOKENIZER_H__
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "bpe_tokenizer.h"
|
||||
|
||||
// BPE tokenizer for GPT-OSS, derived from BPETokenizer.
class GPTOSSTokenizer : public BPETokenizer {
public:
    // merges_utf8_str: merges listing as UTF-8 text (defaults to empty).
    // vocab_utf8_str:  vocabulary JSON as UTF-8 text (defaults to empty).
    explicit GPTOSSTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "");

protected:
    // Loads the tokenizer tables from the given merges listing and vocabulary JSON.
    void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str);
};
|
||||
|
||||
#endif // __SD_TOKENIZERS_GPT_OSS_TOKENIZER_H__
|
||||
3
src/tokenizers/vocab/gpt_oss_merges.hpp
Normal file
3
src/tokenizers/vocab/gpt_oss_merges.hpp
Normal file
File diff suppressed because one or more lines are too long
3
src/tokenizers/vocab/gpt_oss_vocab.hpp
Normal file
3
src/tokenizers/vocab/gpt_oss_vocab.hpp
Normal file
File diff suppressed because one or more lines are too long
@ -2,6 +2,8 @@
|
||||
#include "clip_merges.hpp"
|
||||
#include "gemma_merges.hpp"
|
||||
#include "gemma_vocab.hpp"
|
||||
#include "gpt_oss_merges.hpp"
|
||||
#include "gpt_oss_vocab.hpp"
|
||||
#include "mistral_merges.hpp"
|
||||
#include "mistral_vocab.hpp"
|
||||
#include "qwen_merges.hpp"
|
||||
@ -47,3 +49,13 @@ std::string load_gemma_vocab_json() {
|
||||
std::string json_str(reinterpret_cast<const char*>(gemma_vocab_json_utf8_c_str), sizeof(gemma_vocab_json_utf8_c_str));
|
||||
return json_str;
|
||||
}
|
||||
|
||||
std::string load_gpt_oss_merges() {
|
||||
std::string merges_utf8_str(reinterpret_cast<const char*>(gpt_oss_merges_utf8_c_str), sizeof(gpt_oss_merges_utf8_c_str));
|
||||
return merges_utf8_str;
|
||||
}
|
||||
|
||||
std::string load_gpt_oss_vocab_json() {
|
||||
std::string json_str(reinterpret_cast<const char*>(gpt_oss_vocab_json_utf8_c_str), sizeof(gpt_oss_vocab_json_utf8_c_str));
|
||||
return json_str;
|
||||
}
|
||||
@ -11,5 +11,7 @@ std::string load_t5_tokenizer_json();
|
||||
std::string load_umt5_tokenizer_json();
|
||||
std::string load_gemma_merges();
|
||||
std::string load_gemma_vocab_json();
|
||||
std::string load_gpt_oss_merges();
|
||||
std::string load_gpt_oss_vocab_json();
|
||||
|
||||
#endif // __SD_TOKENIZERS_VOCAB_VOCAB_H__
|
||||
Loading…
x
Reference in New Issue
Block a user