#ifndef __SD_TOKENIZERS_TOKENIZER_H__
#define __SD_TOKENIZERS_TOKENIZER_H__

#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Callback invoked with each piece of text while encoding; it may append
// token ids itself and return true to take over handling of that piece
// (e.g. for custom/embedding trigger words), or return false to fall back
// to the tokenizer's normal encoding.
using on_new_token_cb_t = std::function<bool(std::string&, std::vector<int32_t>&)>;

class Tokenizer {
protected:
    std::vector<std::string> special_tokens;
    bool add_bos_token = false;
    bool add_eos_token = false;
    bool pad_left      = false;
    std::string end_of_word_suffix;  // e.g. "</w>" for CLIP-style BPE vocabularies

    virtual std::string decode_token(int token_id) const = 0;
    virtual std::string normalize(const std::string& text) const;

public:
    std::string UNK_TOKEN;
    std::string BOS_TOKEN;
    std::string EOS_TOKEN;
    std::string PAD_TOKEN;

    int UNK_TOKEN_ID = 0;
    int BOS_TOKEN_ID = 0;
    int EOS_TOKEN_ID = 0;
    int PAD_TOKEN_ID = 0;

    virtual ~Tokenizer() = default;

    void add_special_token(const std::string& token);
    bool is_special_token(const std::string& token) const;

    virtual std::vector<int> encode(const std::string& text,
                                    on_new_token_cb_t on_new_token_cb = nullptr) = 0;

    std::vector<int> tokenize(const std::string& text,
                              on_new_token_cb_t on_new_token_cb = nullptr,
                              bool padding                      = false,
                              size_t min_length                 = 0,
                              size_t max_length                 = 100000000,
                              bool allow_overflow_expand        = false);

    void pad_tokens(std::vector<int>& tokens,
                    std::vector<float>* weights,
                    std::vector<float>* mask,
                    size_t min_length          = 0,
                    size_t max_length          = 100000000,
                    bool allow_overflow_expand = false);

    std::string decode(const std::vector<int>& tokens) const;
};

#endif  // __SD_TOKENIZERS_TOKENIZER_H__
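
// Usage sketch (an illustration, not part of the upstream sources): a minimal
// concrete tokenizer showing that only the two pure-virtual members,
// decode_token() and encode(), must be overridden; tokenize(), pad_tokens(),
// and decode() come from the base class, whose out-of-line definitions are
// assumed to live in an accompanying tokenizer.cpp. "WhitespaceTokenizer" and
// its toy vocabulary are hypothetical names; if uncommented, this would also
// need <algorithm> and <sstream>.
//
//     class WhitespaceTokenizer : public Tokenizer {
//     protected:
//         // Hypothetical fixed vocabulary; id 0 doubles as UNK/BOS/EOS/PAD.
//         std::vector<std::string> vocab = {"<|endoftext|>", "a", "photo", "of", "cat"};
//
//         std::string decode_token(int token_id) const override {
//             if (token_id < 0 || token_id >= (int)vocab.size()) {
//                 return UNK_TOKEN;
//             }
//             return vocab[token_id];
//         }
//
//     public:
//         WhitespaceTokenizer() {
//             UNK_TOKEN = BOS_TOKEN = EOS_TOKEN = PAD_TOKEN = "<|endoftext|>";
//             add_eos_token = true;  // append EOS_TOKEN_ID after the text ids
//         }
//
//         std::vector<int> encode(const std::string& text,
//                                 on_new_token_cb_t on_new_token_cb) override {
//             // on_new_token_cb is ignored in this toy implementation.
//             std::vector<int> ids;
//             std::stringstream ss(text);
//             std::string word;
//             while (ss >> word) {  // naive whitespace split
//                 auto it = std::find(vocab.begin(), vocab.end(), word);
//                 ids.push_back(it == vocab.end() ? UNK_TOKEN_ID
//                                                 : (int)(it - vocab.begin()));
//             }
//             return ids;
//         }
//     };
//
//     // Pad (or truncate) to a fixed 77-id window, CLIP style:
//     WhitespaceTokenizer tok;
//     std::vector<int> ids = tok.tokenize("a photo of a cat", nullptr,
//                                         /*padding=*/true,
//                                         /*min_length=*/77, /*max_length=*/77);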