mirror of
https://github.com/leejet/stable-diffusion.cpp.git
synced 2026-06-24 23:26:43 +00:00
add gpt oss tokenizer
This commit is contained in:
parent
a397e03488
commit
1fa06bac5c
91
src/tokenizers/gpt_oss_tokenizer.cpp
Normal file
91
src/tokenizers/gpt_oss_tokenizer.cpp
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
#include "gpt_oss_tokenizer.h"
|
||||||
|
|
||||||
|
#include "json.hpp"
|
||||||
|
#include "util.h"
|
||||||
|
#include "vocab/vocab.h"
|
||||||
|
|
||||||
|
void GPTOSSTokenizer::load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
|
||||||
|
auto byte_unicode_pairs = bytes_to_unicode();
|
||||||
|
byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
|
||||||
|
for (auto& pair : byte_unicode_pairs) {
|
||||||
|
byte_decoder[pair.second] = pair.first;
|
||||||
|
}
|
||||||
|
|
||||||
|
nlohmann::json vocab;
|
||||||
|
try {
|
||||||
|
vocab = nlohmann::json::parse(vocab_utf8_str);
|
||||||
|
} catch (const nlohmann::json::parse_error&) {
|
||||||
|
GGML_ABORT("invalid vocab json str");
|
||||||
|
}
|
||||||
|
for (const auto& [key, value] : vocab.items()) {
|
||||||
|
std::u32string token = utf8_to_utf32(key);
|
||||||
|
int i = value;
|
||||||
|
encoder[token] = i;
|
||||||
|
decoder[i] = token;
|
||||||
|
}
|
||||||
|
encoder_len = static_cast<int>(encoder.size());
|
||||||
|
for (auto& special_token : special_tokens) {
|
||||||
|
auto token = utf8_to_utf32(special_token);
|
||||||
|
encoder[token] = encoder_len;
|
||||||
|
decoder[encoder_len] = token;
|
||||||
|
encoder_len++;
|
||||||
|
}
|
||||||
|
encoder_len = static_cast<int>(encoder.size());
|
||||||
|
LOG_DEBUG("vocab size: %d", encoder_len);
|
||||||
|
|
||||||
|
std::vector<std::u32string> merges = split_utf32(merges_utf8_str);
|
||||||
|
std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
|
||||||
|
for (const auto& merge : merges) {
|
||||||
|
size_t space_pos = merge.find(' ');
|
||||||
|
merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
|
||||||
|
}
|
||||||
|
LOG_DEBUG("merges size %zu", merge_pairs.size());
|
||||||
|
|
||||||
|
int rank = 0;
|
||||||
|
for (const auto& merge : merge_pairs) {
|
||||||
|
bpe_ranks[merge] = rank++;
|
||||||
|
}
|
||||||
|
bpe_len = rank;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Builds a GPT-OSS tokenizer.
//
// merges_utf8_str: custom merges text; when empty, the built-in merges and the
//                  built-in vocabulary are both used.
// vocab_utf8_str:  custom vocabulary JSON to pair with custom merges; when it
//                  is empty the built-in vocabulary is used instead of handing
//                  an empty string to the JSON parser (which would abort).
GPTOSSTokenizer::GPTOSSTokenizer(const std::string& merges_utf8_str, const std::string& vocab_utf8_str) {
    BOS_TOKEN = "<|startoftext|>";
    UNK_TOKEN = "<|endoftext|>";
    EOS_TOKEN = "<|endoftext|>";
    PAD_TOKEN = "<|endoftext|>";

    // NOTE(review): these ids match the special-token order below only if the
    // base vocabulary holds exactly 199998 entries - confirm against the vocab.
    BOS_TOKEN_ID = 199998;
    EOS_TOKEN_ID = 199999;
    UNK_TOKEN_ID = 199999;
    PAD_TOKEN_ID = 199999;

    // Order matters: load_from_merges() assigns ids to these sequentially.
    special_tokens = {
        "<|startoftext|>",
        "<|endoftext|>",
        "<|reserved_200000|>",
        "<|reserved_200001|>",
        "<|return|>",
        "<|constrain|>",
        "<|reserved_200004|>",
        "<|channel|>",
        "<|start|>",
        "<|end|>",
        "<|message|>",
        "<|reserved_200009|>",
        "<|reserved_200010|>",
        "<|reserved_200011|>",
        "<|call|>",
        "<|reserved_200013|>",
        "<|reserved_200014|>",
        "<|reserved_200015|>",
        "<|reserved_200016|>",
        "<|reserved_200017|>",
        "<|endofprompt|>",
    };

    if (merges_utf8_str.empty()) {
        // No custom merges: use the built-in tables for both inputs.
        load_from_merges(load_gpt_oss_merges(), load_gpt_oss_vocab_json());
    } else if (vocab_utf8_str.empty()) {
        // Custom merges with the defaulted vocabulary argument: fall back to
        // the built-in vocabulary rather than aborting on an empty JSON string.
        load_from_merges(merges_utf8_str, load_gpt_oss_vocab_json());
    } else {
        load_from_merges(merges_utf8_str, vocab_utf8_str);
    }
}
|
||||||
16
src/tokenizers/gpt_oss_tokenizer.h
Normal file
16
src/tokenizers/gpt_oss_tokenizer.h
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
#ifndef __SD_TOKENIZERS_GPT_OSS_TOKENIZER_H__
|
||||||
|
#define __SD_TOKENIZERS_GPT_OSS_TOKENIZER_H__
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "bpe_tokenizer.h"
|
||||||
|
|
||||||
|
// BPE tokenizer for GPT-OSS, layered on top of BPETokenizer.
class GPTOSSTokenizer : public BPETokenizer {
public:
    // Creates the tokenizer from the given merges text and vocabulary JSON.
    // Both arguments default to the empty string.
    explicit GPTOSSTokenizer(const std::string& merges_utf8_str = "", const std::string& vocab_utf8_str = "");

protected:
    // Loads the tokenizer tables from a merges text and a vocabulary JSON
    // string, both UTF-8 encoded.
    void load_from_merges(const std::string& merges_utf8_str, const std::string& vocab_utf8_str);
};
|
||||||
|
|
||||||
|
#endif // __SD_TOKENIZERS_GPT_OSS_TOKENIZER_H__
|
||||||
3
src/tokenizers/vocab/gpt_oss_merges.hpp
Normal file
3
src/tokenizers/vocab/gpt_oss_merges.hpp
Normal file
File diff suppressed because one or more lines are too long
3
src/tokenizers/vocab/gpt_oss_vocab.hpp
Normal file
3
src/tokenizers/vocab/gpt_oss_vocab.hpp
Normal file
File diff suppressed because one or more lines are too long
@ -2,6 +2,8 @@
|
|||||||
#include "clip_merges.hpp"
|
#include "clip_merges.hpp"
|
||||||
#include "gemma_merges.hpp"
|
#include "gemma_merges.hpp"
|
||||||
#include "gemma_vocab.hpp"
|
#include "gemma_vocab.hpp"
|
||||||
|
#include "gpt_oss_merges.hpp"
|
||||||
|
#include "gpt_oss_vocab.hpp"
|
||||||
#include "mistral_merges.hpp"
|
#include "mistral_merges.hpp"
|
||||||
#include "mistral_vocab.hpp"
|
#include "mistral_vocab.hpp"
|
||||||
#include "qwen_merges.hpp"
|
#include "qwen_merges.hpp"
|
||||||
@ -46,4 +48,14 @@ std::string load_gemma_merges() {
|
|||||||
std::string load_gemma_vocab_json() {
|
std::string load_gemma_vocab_json() {
|
||||||
std::string json_str(reinterpret_cast<const char*>(gemma_vocab_json_utf8_c_str), sizeof(gemma_vocab_json_utf8_c_str));
|
std::string json_str(reinterpret_cast<const char*>(gemma_vocab_json_utf8_c_str), sizeof(gemma_vocab_json_utf8_c_str));
|
||||||
return json_str;
|
return json_str;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string load_gpt_oss_merges() {
|
||||||
|
std::string merges_utf8_str(reinterpret_cast<const char*>(gpt_oss_merges_utf8_c_str), sizeof(gpt_oss_merges_utf8_c_str));
|
||||||
|
return merges_utf8_str;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string load_gpt_oss_vocab_json() {
|
||||||
|
std::string json_str(reinterpret_cast<const char*>(gpt_oss_vocab_json_utf8_c_str), sizeof(gpt_oss_vocab_json_utf8_c_str));
|
||||||
|
return json_str;
|
||||||
}
|
}
|
||||||
@ -11,5 +11,7 @@ std::string load_t5_tokenizer_json();
|
|||||||
std::string load_umt5_tokenizer_json();
|
std::string load_umt5_tokenizer_json();
|
||||||
std::string load_gemma_merges();
|
std::string load_gemma_merges();
|
||||||
std::string load_gemma_vocab_json();
|
std::string load_gemma_vocab_json();
|
||||||
|
std::string load_gpt_oss_merges();
|
||||||
|
std::string load_gpt_oss_vocab_json();
|
||||||
|
|
||||||
#endif // __SD_TOKENIZERS_VOCAB_VOCAB_H__
|
#endif // __SD_TOKENIZERS_VOCAB_VOCAB_H__
|
||||||
Loading…
x
Reference in New Issue
Block a user