From fe4e73156fac7a1058a30b66f6cf5b6894b9d3e3 Mon Sep 17 00:00:00 2001
From: leejet
Date: Sun, 21 Sep 2025 00:31:48 +0800
Subject: [PATCH] add qwen2.5 vl support

---
 examples/cli/main.cpp |  12 +-
 ggml_extend.hpp       |  50 ++-
 model.cpp             |  31 +-
 qwen.hpp              | 966 ++++++++++++++++++++++++++++++------------
 tokenize_util.cpp     | 142 +++++--
 tokenize_util.h       |   3 +-
 6 files changed, 851 insertions(+), 353 deletions(-)

diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp
index 5b43670..423d3b9 100644
--- a/examples/cli/main.cpp
+++ b/examples/cli/main.cpp
@@ -1142,17 +1142,7 @@ int main(int argc, const char* argv[]) {
     SDParams params;
     params.verbose = true;
     sd_set_log_callback(sd_log_cb, (void*)&params);
-    auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
-        return false;
-    };
-    // auto tokenizer = CLIPTokenizer();
-    auto tokenizer = Qwen::Qwen2Tokenizer();
-    std::string text("a lovely cat");
-    auto tokens = tokenizer.encode(text, on_new_token_cb);
-    for (auto token : tokens) {
-        std::cout << token << " ";
-    }
-    std::cout << std::endl;
+    Qwen::Qwen2_5_VLEmbedder::load_from_file_and_test(argv[1]);
     exit(1);
     parse_args(argc, argv, params);
     params.sample_params.guidance.slg.layers = params.skip_layers.data();
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index a5f61ea..e26472a 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1119,9 +1119,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
     return kqv;
 }
 
-// q: [N, L_q, C] or [N*n_head, L_q, d_head]
-// k: [N, L_k, C] or [N*n_head, L_k, d_head]
-// v: [N, L_k, C] or [N, L_k, n_head, d_head]
+// q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head]
+// k: [N, L_k, n_kv_head*d_head] or [N*n_kv_head, L_k, d_head]
+// v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head]
 // mask: [N, L_q, L_k]
 // return: [N, L_q, C]
 __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
@@ -1139,27 +1139,31 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     int64_t C;
     int64_t N;
     int64_t d_head;
+    int64_t n_kv_head;
     if (!skip_reshape) {
-        L_q    = q->ne[1];
-        L_k    = k->ne[1];
-        C      = q->ne[0];
-        N      = q->ne[2];
-        d_head = C / n_head;
-        q      = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N);      // [N, L_q, n_head, d_head]
-        q      = ggml_nn_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));  // [N, n_head, L_q, d_head]
-        q      = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N);     // [N * n_head, L_q, d_head]
+        L_q       = q->ne[1];
+        L_k       = k->ne[1];
+        C         = q->ne[0];
+        N         = q->ne[2];
+        d_head    = C / n_head;
+        n_kv_head = k->ne[0] / d_head;
 
-        k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N);      // [N, L_k, n_head, d_head]
-        k = ggml_nn_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
-        k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N);     // [N * n_head, L_k, d_head]
+        q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N);      // [N, L_q, n_head, d_head]
+        q = ggml_nn_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3));  // [N, n_head, L_q, d_head]
+        q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N);     // [N * n_head, L_q, d_head]
 
-        v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N);  // [N, L_k, n_head, d_head]
+        k = ggml_reshape_4d(ctx, k, d_head, n_kv_head, L_k, N);      // [N, L_k, n_kv_head, d_head]
+        k = ggml_nn_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));     // [N, n_kv_head, L_k, d_head]
+        k = ggml_reshape_3d(ctx, k, d_head, L_k, n_kv_head * N);     // [N * n_kv_head, L_k, d_head]
+
+        v = ggml_reshape_4d(ctx, v, d_head, n_kv_head, L_k, N);  // [N, L_k, n_kv_head, d_head]
     } else {
-        L_q    = q->ne[1];
-        L_k    = k->ne[1];
-        d_head = v->ne[0];
-        N      = v->ne[3];
-        C      = d_head * n_head;
+        L_q       = q->ne[1];
+        L_k       = k->ne[1];
+        d_head    = v->ne[0];
+        N         = v->ne[3];
+        n_kv_head = k->ne[2] / N;
+        C         = d_head * n_head;
     }
 
     float scale = (1.0f / sqrt((float)d_head));
@@ -1174,7 +1178,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
 
         v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
-        v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_head * N);
+        v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N);
         if (kv_pad != 0) {
             v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
         }
@@ -1232,8 +1236,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         // if (flash_attn) {
         //     LOG_DEBUG("fallback to default attention, L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
         // }
-        v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_head, d_head, L_k]
-        v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);     // [N * n_head, d_head, L_k]
+        v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3));  // [N, n_kv_head, d_head, L_k]
+        v = ggml_reshape_3d(ctx, v, L_k, d_head, n_kv_head * N);  // [N * n_kv_head, d_head, L_k]
 
         auto kq = ggml_mul_mat(ctx, k, q);  // [N * n_head, L_q, L_k]
         kq = ggml_scale_inplace(ctx, kq, scale);
diff --git a/model.cpp b/model.cpp
index 330abeb..0d9574a 100644
--- a/model.cpp
+++ b/model.cpp
@@ -110,6 +110,9 @@ const char* unused_tensors[] = {
     "embedding_manager",
     "denoiser.sigmas",
     "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight",  // only used during training
+    "qwen2vl.output.weight",
+    "qwen2vl.lm_head.",
+    "qwen2vl.visual.",
 };
 
 bool is_unused_tensor(std::string name) {
@@ -193,6 +196,21 @@ std::unordered_map<std::string, std::string> pmid_v2_name_map = {
      "pmid.qformer_perceiver.token_proj.fc2.weight"},
 };
 
+std::unordered_map<std::string, std::string> qwenvl_name_map{
+    {"token_embd.", "model.embed_tokens."},
+    {"blk.", "model.layers."},
+    {"attn_q.", "self_attn.q_proj."},
+    {"attn_k.", "self_attn.k_proj."},
+    {"attn_v.", "self_attn.v_proj."},
+    {"attn_output.", "self_attn.o_proj."},
+    {"attn_norm.", "input_layernorm."},
+    {"ffn_down.", "mlp.down_proj."},
+    {"ffn_gate.", "mlp.gate_proj."},
+    {"ffn_up.", "mlp.up_proj."},
+    {"ffn_norm.", "post_attention_layernorm."},
+    {"output_norm.", "model.norm."},
+};
+
 std::string convert_cond_model_name(const std::string& name) {
     std::string new_name = name;
     std::string prefix;
@@ -250,6 +268,13 @@ std::string convert_cond_model_name(const std::string& name) {
         if (pos != std::string::npos) {
             new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias.");
         }
+    } else if (contains(name, "qwen2vl")) {
+        for (auto kv : qwenvl_name_map) {
+            size_t pos = new_name.find(kv.first);
+            if (pos != std::string::npos) {
+                new_name.replace(pos, kv.first.size(), kv.second);
+            }
+        }
     } else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") {
         new_name = "text_encoders.t5xxl.transformer.shared.weight";
     }
@@ -580,7 +605,11 @@ std::string convert_tensor_name(std::string name) {
     //     name.replace(pos, strlen("lora_B"), "lora_down");
     // }
     std::string new_name = name;
-    if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
+    if (starts_with(name, "cond_stage_model.") ||
+        starts_with(name, "conditioner.embedders.") ||
+        starts_with(name, "text_encoders.") ||
+        ends_with(name,
".vision_model.visual_projection.weight") || + starts_with(name, "qwen2vl")) { new_name = convert_cond_model_name(name); } else if (starts_with(name, "first_stage_model.decoder")) { new_name = convert_vae_decoder_name(name); diff --git a/qwen.hpp b/qwen.hpp index d73a882..45611c8 100644 --- a/qwen.hpp +++ b/qwen.hpp @@ -3,314 +3,730 @@ #include "ggml_extend.hpp" +#include +#include +#include +#include +#include +#include +#include +#include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include "json.hpp" #include "clip.hpp" +#include "json.hpp" #include "tokenize_util.h" namespace Qwen { -class Qwen2Tokenizer { -private: - std::map byte_encoder; - std::map byte_decoder; - std::map encoder; - std::map decoder; - std::map, int> bpe_ranks; - std::regex pat; - int encoder_len; - int bpe_len; + class Qwen2Tokenizer { + private: + std::map byte_encoder; + std::map byte_decoder; + std::map encoder; + std::map decoder; + std::map, int> bpe_ranks; + std::regex pat; + int encoder_len; + int bpe_len; -public: - const std::string UNK_TOKEN = "<|endoftext|>"; - const std::string EOS_TOKEN = "<|endoftext|>"; - const std::string PAD_TOKEN = "<|endoftext|>"; + public: + const std::string UNK_TOKEN = "<|endoftext|>"; + const std::string EOS_TOKEN = "<|endoftext|>"; + const std::string PAD_TOKEN = "<|endoftext|>"; - const int UNK_TOKEN_ID = 151643; - const int EOS_TOKEN_ID = 151643; - const int PAD_TOKEN_ID = 151643; + const int UNK_TOKEN_ID = 151643; + const int EOS_TOKEN_ID = 151643; + const int PAD_TOKEN_ID = 151643; -private: - static std::string strip(const std::string& str) { - std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); - std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); + std::vector special_tokens = { + "<|endoftext|>", + "<|im_start|>", + "<|im_end|>", + "<|object_ref_start|>", + "<|object_ref_end|>", + "<|box_start|>", + "<|box_end|>", + "<|quad_start|>", + "<|quad_end|>", + "<|vision_start|>", + "<|vision_end|>", + "<|vision_pad|>", + "<|image_pad|>", + "<|video_pad|>", + "", + "", + "<|fim_prefix|>", + "<|fim_middle|>", + "<|fim_suffix|>", + "<|fim_pad|>", + "<|repo_name|>", + "<|file_sep|>", + }; - if (start == std::string::npos) { - // String contains only whitespace characters - return ""; + private: + static std::string strip(const std::string& str) { + std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); + std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); + + if (start == std::string::npos) { + // String contains only whitespace characters + return ""; + } + + return str.substr(start, end - start + 1); } - return str.substr(start, end - start + 1); - } + static std::string whitespace_clean(std::string text) { + text = std::regex_replace(text, std::regex(R"(\s+)"), " "); + text = strip(text); + return text; + } - static std::string whitespace_clean(std::string text) { - text = std::regex_replace(text, std::regex(R"(\s+)"), " "); - text = strip(text); - return text; - } - - static std::set> get_pairs(const std::vector& subwords) { - std::set> pairs; - if (subwords.size() == 0) { + static std::set> get_pairs(const std::vector& subwords) { + std::set> pairs; + if (subwords.size() == 0) { + return pairs; + } + std::u32string prev_subword = subwords[0]; + for (int i = 1; i < subwords.size(); i++) { + std::u32string subword = subwords[i]; + std::pair pair(prev_subword, subword); + pairs.insert(pair); + prev_subword = subword; + } return pairs; } - std::u32string 
prev_subword = subwords[0]; - for (int i = 1; i < subwords.size(); i++) { - std::u32string subword = subwords[i]; - std::pair pair(prev_subword, subword); - pairs.insert(pair); - prev_subword = subword; - } - return pairs; - } -public: - explicit Qwen2Tokenizer(const std::string& merges_utf8_str = "") { - if (merges_utf8_str.size() > 0) { - load_from_merges(merges_utf8_str); - } else { - load_from_merges(ModelLoader::load_qwen2_merges()); - } - } - - void load_from_merges(const std::string& merges_utf8_str) { - auto byte_unicode_pairs = bytes_to_unicode(); - // printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size()); - byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); - for (auto& pair : byte_unicode_pairs) { - byte_decoder[pair.second] = pair.first; - } - // for (auto & pair: byte_unicode_pairs) { - // std::cout << pair.first << ": " << pair.second << std::endl; - // } - std::vector merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - LOG_DEBUG("merges size %llu", merges.size()); - // GGML_ASSERT(merges.size() == 48895); - merges = std::vector(merges.begin(), merges.end()); - std::vector> merge_pairs; - for (const auto& merge : merges) { - size_t space_pos = merge.find(' '); - merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - // printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), - // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - } - - std::vector vocab; - for (const auto& pair : byte_unicode_pairs) { - vocab.push_back(pair.second); - } - for (const auto& merge : merge_pairs) { - vocab.push_back(merge.first + merge.second); - } - vocab.push_back(utf8_to_utf32("<|endoftext|>")); - vocab.push_back(utf8_to_utf32("<|im_start|>")); - vocab.push_back(utf8_to_utf32("<|im_end|>")); - vocab.push_back(utf8_to_utf32("<|object_ref_start|>")); - vocab.push_back(utf8_to_utf32("<|object_ref_end|>")); - vocab.push_back(utf8_to_utf32("<|box_start|>")); - vocab.push_back(utf8_to_utf32("<|box_end|>")); - vocab.push_back(utf8_to_utf32("<|quad_start|>")); - vocab.push_back(utf8_to_utf32("<|quad_end|>")); - vocab.push_back(utf8_to_utf32("<|vision_start|>")); - vocab.push_back(utf8_to_utf32("<|vision_end|>")); - vocab.push_back(utf8_to_utf32("<|vision_pad|>")); - vocab.push_back(utf8_to_utf32("<|image_pad|>")); - vocab.push_back(utf8_to_utf32("<|video_pad|>")); - vocab.push_back(utf8_to_utf32("")); - vocab.push_back(utf8_to_utf32("")); - vocab.push_back(utf8_to_utf32("<|fim_prefix|>")); - vocab.push_back(utf8_to_utf32("<|fim_middle|>")); - vocab.push_back(utf8_to_utf32("<|fim_suffix|>")); - vocab.push_back(utf8_to_utf32("<|fim_pad|>")); - vocab.push_back(utf8_to_utf32("<|repo_name|>")); - vocab.push_back(utf8_to_utf32("<|file_sep|>")); - - LOG_DEBUG("vocab size: %llu", vocab.size()); - int i = 0; - for (const auto& token : vocab) { - encoder[token] = i; - decoder[i] = token; - i++; - } - encoder_len = i; - - int rank = 0; - for (const auto& merge : merge_pairs) { - bpe_ranks[merge] = rank++; - } - bpe_len = rank; - }; - - std::u32string bpe(const std::u32string& token) { - std::vector word; - - for (int i = 0; i < token.size(); i++) { - word.emplace_back(1, token[i]); - } - - 
std::set> pairs = get_pairs(word); - - if (pairs.empty()) { - return token; - } - - while (true) { - auto min_pair_iter = std::min_element(pairs.begin(), - pairs.end(), - [&](const std::pair& a, - const std::pair& b) { - if (bpe_ranks.find(a) == bpe_ranks.end()) { - return false; - } else if (bpe_ranks.find(b) == bpe_ranks.end()) { - return true; - } - return bpe_ranks.at(a) < bpe_ranks.at(b); - }); - - const std::pair& bigram = *min_pair_iter; - - if (bpe_ranks.find(bigram) == bpe_ranks.end()) { - break; - } - - std::u32string first = bigram.first; - std::u32string second = bigram.second; - std::vector new_word; - int32_t i = 0; - - while (i < word.size()) { - auto it = std::find(word.begin() + i, word.end(), first); - if (it == word.end()) { - new_word.insert(new_word.end(), word.begin() + i, word.end()); - break; - } - new_word.insert(new_word.end(), word.begin() + i, it); - i = static_cast(std::distance(word.begin(), it)); - - if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { - new_word.push_back(first + second); - i += 2; - } else { - new_word.push_back(word[i]); - i += 1; + bool is_special_token(const std::string& token) { + for (auto& special_token : special_tokens) { + if (special_token == token) { + return true; } } - - word = new_word; - - if (word.size() == 1) { - break; - } - pairs = get_pairs(word); + return false; } - std::u32string result; - for (int i = 0; i < word.size(); i++) { - result += word[i]; - if (i != word.size() - 1) { - result += utf8_to_utf32(" "); + public: + explicit Qwen2Tokenizer(const std::string& merges_utf8_str = "") { + if (merges_utf8_str.size() > 0) { + load_from_merges(merges_utf8_str); + } else { + load_from_merges(ModelLoader::load_qwen2_merges()); } } - return result; - } - - std::vector tokenize(std::string text, - on_new_token_cb_t on_new_token_cb, - size_t max_length = 0, - bool padding = false) { - std::vector tokens = encode(text, on_new_token_cb); - - if (max_length > 0) { - tokens.resize(max_length); - if (padding) { - tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID); + void load_from_merges(const std::string& merges_utf8_str) { + auto byte_unicode_pairs = bytes_to_unicode(); + // printf("byte_unicode_pairs have %lu pairs \n", byte_unicode_pairs.size()); + byte_encoder = std::map(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + for (auto& pair : byte_unicode_pairs) { + byte_decoder[pair.second] = pair.first; } - } - - return tokens; - } - - void pad_tokens(std::vector& tokens, - std::vector& weights, - size_t max_length = 0, - bool padding = false) { - if (max_length > 0 && padding) { - size_t n = std::ceil(tokens.size() * 1.0 / max_length); - if (n == 0) { - n = 1; - } - size_t length = max_length * n; - LOG_DEBUG("token length: %llu", length); - tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID); - weights.insert(weights.end(), length - weights.size(), 1.0); - } - } - - std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb) { - std::string original_text = text; - std::vector bpe_tokens; - - auto tokens = token_split(text); - std::vector token_strs; - for (auto& token : tokens) { - bool skip = on_new_token_cb(token, bpe_tokens); - if (skip) { - continue; - } - std::string token_str = token; - std::u32string utf32_token; - for (int i = 0; i < token_str.length(); i++) { - unsigned char b = token_str[i]; - utf32_token += byte_encoder[b]; - } - auto bpe_strs = bpe(utf32_token); - size_t start = 0; + // for (auto & pair: byte_unicode_pairs) { + // 
std::cout << pair.first << ": " << pair.second << std::endl; + // } + std::vector merges; + size_t start = 0; size_t pos; - while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { - auto bpe_str = bpe_strs.substr(start, pos - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - + std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); + while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { + merges.push_back(merges_utf32_str.substr(start, pos - start)); start = pos + 1; } - auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); + LOG_DEBUG("merges size %llu", merges.size()); + merges = std::vector(merges.begin(), merges.end()); + std::vector> merge_pairs; + for (const auto& merge : merges) { + size_t space_pos = merge.find(' '); + merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); + // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); + // printf("%s :: %s | %s \n", utf32_to_utf8(merge).c_str(), utf32_to_utf8(merge.substr(0, space_pos)).c_str(), + // utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); + } + + std::vector vocab; + for (const auto& pair : byte_unicode_pairs) { + vocab.push_back(pair.second); + } + for (const auto& merge : merge_pairs) { + vocab.push_back(merge.first + merge.second); + } + for (auto& special_token : special_tokens) { + vocab.push_back(utf8_to_utf32(special_token)); + } + + LOG_DEBUG("vocab size: %llu", vocab.size()); + int i = 0; + for (const auto& token : vocab) { + encoder[token] = i; + decoder[i] = token; + i++; + } + encoder_len = i; + + int rank = 0; + for (const auto& merge : merge_pairs) { + bpe_ranks[merge] = rank++; + } + bpe_len = rank; + }; + + std::u32string bpe(const std::u32string& token) { + std::vector word; + + for (int i = 0; i < token.size(); i++) { + word.emplace_back(1, token[i]); + } + + std::set> pairs = get_pairs(word); + + if (pairs.empty()) { + return token; + } + + while (true) { + auto min_pair_iter = std::min_element(pairs.begin(), + pairs.end(), + [&](const std::pair& a, + const std::pair& b) { + if (bpe_ranks.find(a) == bpe_ranks.end()) { + return false; + } else if (bpe_ranks.find(b) == bpe_ranks.end()) { + return true; + } + return bpe_ranks.at(a) < bpe_ranks.at(b); + }); + + const std::pair& bigram = *min_pair_iter; + + if (bpe_ranks.find(bigram) == bpe_ranks.end()) { + break; + } + + std::u32string first = bigram.first; + std::u32string second = bigram.second; + std::vector new_word; + int32_t i = 0; + + while (i < word.size()) { + auto it = std::find(word.begin() + i, word.end(), first); + if (it == word.end()) { + new_word.insert(new_word.end(), word.begin() + i, word.end()); + break; + } + new_word.insert(new_word.end(), word.begin() + i, it); + i = static_cast(std::distance(word.begin(), it)); + + if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { + new_word.push_back(first + second); + i += 2; + } else { + new_word.push_back(word[i]); + i += 1; + } + } + + word = new_word; + + if (word.size() == 1) { + break; + } + pairs = get_pairs(word); + } + + std::u32string result; + for (int i = 0; i < word.size(); i++) { + result += word[i]; + if (i != word.size() - 1) { + result += utf8_to_utf32(" "); + } + } + + return result; } - std::stringstream ss; - ss << "["; - for (auto token : token_strs) { - ss << "\"" << token << "\", "; + std::vector 
tokenize(std::string text, + on_new_token_cb_t on_new_token_cb = nullptr, + size_t max_length = 0, + bool padding = false) { + std::vector tokens = encode(text, on_new_token_cb); + + if (max_length > 0) { + if (tokens.size() < max_length) { + tokens.resize(max_length); + } else { + if (padding) { + tokens.insert(tokens.end(), max_length - tokens.size(), PAD_TOKEN_ID); + } + } + } + + return tokens; } - ss << "]"; - LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); - // printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str()); - return bpe_tokens; - } -}; -}; + void pad_tokens(std::vector& tokens, + std::vector& weights, + size_t max_length = 0, + bool padding = false) { + if (max_length > 0 && padding) { + size_t n = std::ceil(tokens.size() * 1.0 / max_length); + if (n == 0) { + n = 1; + } + size_t length = max_length * n; + LOG_DEBUG("token length: %llu", length); + tokens.insert(tokens.end(), length - tokens.size(), PAD_TOKEN_ID); + weights.insert(weights.end(), length - weights.size(), 1.0); + } + } + std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb = nullptr) { + std::string original_text = text; + std::vector bpe_tokens; + std::vector token_strs; + auto splited_texts = split_with_special_tokens(text, special_tokens); -#endif // __QWEN_HPP__ + for (auto& splited_text : splited_texts) { + if (is_special_token(splited_text)) { + bpe_tokens.push_back(encoder[utf8_to_utf32(splited_text)]); + token_strs.push_back(splited_text); + continue; + } + auto tokens = token_split(splited_text); + for (auto& token : tokens) { + if (on_new_token_cb != nullptr) { + bool skip = on_new_token_cb(token, bpe_tokens); + if (skip) { + continue; + } + } + + std::string token_str = token; + std::u32string utf32_token; + for (int i = 0; i < token_str.length(); i++) { + unsigned char b = token_str[i]; + utf32_token += byte_encoder[b]; + } + auto bpe_strs = bpe(utf32_token); + size_t start = 0; + size_t pos; + while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { + auto bpe_str = bpe_strs.substr(start, pos - start); + bpe_tokens.push_back(encoder[bpe_str]); + token_strs.push_back(utf32_to_utf8(bpe_str)); + + start = pos + 1; + } + auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); + bpe_tokens.push_back(encoder[bpe_str]); + token_strs.push_back(utf32_to_utf8(bpe_str)); + } + } + + std::stringstream ss; + ss << "["; + for (auto token : token_strs) { + ss << "\"" << token << "\", "; + } + ss << "]"; + LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); + // printf("split prompt \"%s\" to tokens %s \n", original_text.c_str(), ss.str().c_str()); + return bpe_tokens; + } + }; + + struct Qwen2_5_VLMLP : public GGMLBlock { + public: + Qwen2_5_VLMLP(int64_t hidden_size, int64_t intermediate_size, bool bias = false) { + blocks["gate_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, false)); + blocks["up_proj"] = std::shared_ptr(new Linear(hidden_size, intermediate_size, false)); + blocks["down_proj"] = std::shared_ptr(new Linear(intermediate_size, hidden_size, false)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + // x: [N, n_token, hidden_size] + auto gate_proj = std::dynamic_pointer_cast(blocks["gate_proj"]); + auto up_proj = std::dynamic_pointer_cast(blocks["up_proj"]); + auto down_proj = std::dynamic_pointer_cast(blocks["down_proj"]); + + auto h = gate_proj->forward(ctx, x); + h = ggml_silu_inplace(ctx, h); + 
h = ggml_mul_inplace(ctx, h, up_proj->forward(ctx, x)); + h = down_proj->forward(ctx, h); + return h; + } + }; + + class Qwen2_5_VLAttention : public GGMLBlock { + protected: + int64_t head_dim; + int64_t num_heads; + int64_t num_kv_heads; + + public: + Qwen2_5_VLAttention(int64_t hidden_size, + int64_t num_heads, + int64_t num_kv_heads) + : num_heads(num_heads), num_kv_heads(num_kv_heads) { + head_dim = hidden_size / num_heads; + GGML_ASSERT(num_heads * head_dim == hidden_size); + blocks["q_proj"] = std::shared_ptr(new Linear(hidden_size, num_heads * head_dim)); + blocks["k_proj"] = std::shared_ptr(new Linear(hidden_size, num_kv_heads * head_dim)); + blocks["v_proj"] = std::shared_ptr(new Linear(hidden_size, num_kv_heads * head_dim)); + blocks["o_proj"] = std::shared_ptr(new Linear(num_heads * head_dim, hidden_size, false)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* x, + struct ggml_tensor* input_pos) { + // x: [N, n_token, hidden_size] + int64_t n_token = x->ne[1]; + int64_t N = x->ne[2]; + auto q_proj = std::dynamic_pointer_cast(blocks["q_proj"]); + auto k_proj = std::dynamic_pointer_cast(blocks["k_proj"]); + auto v_proj = std::dynamic_pointer_cast(blocks["v_proj"]); + auto out_proj = std::dynamic_pointer_cast(blocks["o_proj"]); + + auto q = q_proj->forward(ctx, x); // [N, n_token, num_heads*head_dim] + auto k = k_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim] + auto v = v_proj->forward(ctx, x); // [N, n_token, num_kv_heads*head_dim] + + q = ggml_reshape_4d(ctx, q, head_dim, num_heads, n_token, N); // [N, n_token, num_heads, head_dim] + k = ggml_reshape_4d(ctx, k, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] + v = ggml_reshape_4d(ctx, v, head_dim, num_kv_heads, n_token, N); // [N, n_token, num_kv_heads, head_dim] + + int sections[4] = {16, 24, 24, 0}; + q = ggml_rope_multi(ctx, q, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + k = ggml_rope_multi(ctx, k, input_pos, nullptr, head_dim, sections, GGML_ROPE_TYPE_MROPE, 128000, 1000000.f, 1.f, 0.f, 1.f, 32.f, 1.f); + + q = ggml_cont(ctx, ggml_torch_permute(ctx, q, 0, 2, 1, 3)); // [N, num_heads, n_token, head_dim] + q = ggml_reshape_3d(ctx, q, q->ne[0], q->ne[1], q->ne[2] * q->ne[3]); // [N*num_heads, n_token, head_dim] + + k = ggml_cont(ctx, ggml_torch_permute(ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim] + k = ggml_reshape_3d(ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim] + + x = ggml_nn_attention_ext(ctx, backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size] + + x = out_proj->forward(ctx, x); // [N, n_token, hidden_size] + return x; + } + }; + + struct Qwen2_5_VLBlock : public GGMLBlock { + public: + Qwen2_5_VLBlock(int64_t hidden_size, + int64_t intermediate_size, + int64_t num_heads, + int64_t num_kv_heads, + float eps = 1e-6f) { + blocks["self_attn"] = std::shared_ptr(new Qwen2_5_VLAttention(hidden_size, num_heads, num_kv_heads)); + blocks["mlp"] = std::shared_ptr(new Qwen2_5_VLMLP(hidden_size, intermediate_size)); + blocks["input_layernorm"] = std::shared_ptr(new RMSNorm(hidden_size, eps)); + blocks["post_attention_layernorm"] = std::shared_ptr(new RMSNorm(hidden_size, eps)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* x, + struct ggml_tensor* input_pos) { + // x: [N, n_token, hidden_size] + auto self_attn 
= std::dynamic_pointer_cast(blocks["self_attn"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + auto input_layernorm = std::dynamic_pointer_cast(blocks["input_layernorm"]); + auto post_attention_layernorm = std::dynamic_pointer_cast(blocks["post_attention_layernorm"]); + + auto residual = x; + x = input_layernorm->forward(ctx, x); + x = self_attn->forward(ctx, backend, x, input_pos); + x = ggml_add_inplace(ctx, x, residual); + + residual = x; + x = post_attention_layernorm->forward(ctx, x); + x = mlp->forward(ctx, x); + x = ggml_add_inplace(ctx, x, residual); + + return x; + } + }; + + struct Qwen2_5_VLTextModel : public GGMLBlock { + protected: + int64_t num_layers; + + public: + Qwen2_5_VLTextModel(int64_t num_layers, + int64_t vocab_size, + int64_t hidden_size, + int64_t intermediate_size, + int64_t num_heads, + int64_t num_kv_heads, + float eps = 1e-6f) + : num_layers(num_layers) { + blocks["embed_tokens"] = std::shared_ptr(new Embedding(vocab_size, hidden_size)); + for (int i = 0; i < num_layers; i++) { + blocks["layers." + std::to_string(i)] = std::shared_ptr(new Qwen2_5_VLBlock(hidden_size, + intermediate_size, + num_heads, + num_kv_heads)); + } + blocks["norm"] = std::shared_ptr(new RMSNorm(hidden_size, eps)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* input_ids, + struct ggml_tensor* input_pos) { + // input_ids: [N, n_token] + // return: [N, n_token, hidden_size] + + auto embed_tokens = std::dynamic_pointer_cast(blocks["embed_tokens"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + + auto x = embed_tokens->forward(ctx, input_ids); + + for (int i = 0; i < num_layers; i++) { + auto block = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); + + x = block->forward(ctx, backend, x, input_pos); + } + + x = norm->forward(ctx, x); + return x; + } + }; + + struct Qwen2_5_VLParams { + int64_t num_layers = 28; + int64_t hidden_size = 3584; + int64_t intermediate_size = 18944; + int64_t num_heads = 28; + int64_t num_kv_heads = 4; + int64_t vocab_size = 152064; + float rms_norm_eps = 1e-06f; + }; + + struct Qwen2_5_VL : public GGMLBlock { + Qwen2_5_VLParams params; + + public: + Qwen2_5_VL() {} + Qwen2_5_VL(Qwen2_5_VLParams params) + : params(params) { + blocks["model"] = std::shared_ptr(new Qwen2_5_VLTextModel(params.num_layers, + params.vocab_size, + params.hidden_size, + params.intermediate_size, + params.num_heads, + params.num_kv_heads, + params.rms_norm_eps)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* input_ids, + struct ggml_tensor* input_pos) { + // input_ids: [N, n_token] + auto model = std::dynamic_pointer_cast(blocks["model"]); + + auto x = model->forward(ctx, backend, input_ids, input_pos); + return x; + } + }; + + struct Qwen2_5_VLRunner : public GGMLRunner { + Qwen2_5_VLParams params; + Qwen2_5_VL model; + + std::vector input_pos_vec; + + Qwen2_5_VLRunner(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2GGMLType& tensor_types, + const std::string prefix) + : GGMLRunner(backend, offload_params_to_cpu) { + model = Qwen2_5_VL(params); + model.init(params_ctx, tensor_types, prefix); + } + + std::string get_desc() { + return "qwenvl2.5"; + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + model.get_param_tensors(tensors, prefix); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + ggml_backend_t backend, + struct ggml_tensor* input_ids, + struct 
ggml_tensor* input_pos) { + auto hidden_states = model.forward(ctx, backend, input_ids, input_pos); // [N, n_token, hidden_size] + return hidden_states; + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) { + struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); + + input_ids = to_backend(input_ids); + + int64_t n_tokens = input_ids->ne[0]; + input_pos_vec.resize(n_tokens * 4); + for (int i = 0; i < n_tokens; ++i) { + input_pos_vec[i] = i; + input_pos_vec[n_tokens + i] = i; + input_pos_vec[2 * n_tokens + i] = i; + input_pos_vec[3 * n_tokens + i] = 0; + } + + auto input_pos = ggml_new_tensor_1d(compute_ctx, + GGML_TYPE_I32, + n_tokens * 4); + set_backend_tensor_data(input_pos, input_pos_vec.data()); + + struct ggml_tensor* hidden_states = forward(compute_ctx, runtime_backend, input_ids, input_pos); + + ggml_build_forward_expand(gf, hidden_states); + + return gf; + } + + void compute(const int n_threads, + struct ggml_tensor* input_ids, + ggml_tensor** output, + ggml_context* output_ctx = NULL) { + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(input_ids); + }; + GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); + } + }; + + struct Qwen2_5_VLEmbedder { + Qwen2Tokenizer tokenizer; + Qwen2_5_VLRunner model; + + Qwen2_5_VLEmbedder(ggml_backend_t backend, + bool offload_params_to_cpu, + const String2GGMLType& tensor_types = {}, + const std::string prefix = "") + : model(backend, offload_params_to_cpu, tensor_types, prefix) { + } + + void get_param_tensors(std::map& tensors, const std::string prefix) { + model.get_param_tensors(tensors, prefix); + } + + void alloc_params_buffer() { + model.alloc_params_buffer(); + } + + std::tuple, std::vector> tokenize(std::string text, + size_t max_length = 0, + bool padding = false) { + auto parsed_attention = parse_prompt_attention(text); + + { + std::stringstream ss; + ss << "["; + for (const auto& item : parsed_attention) { + ss << "['" << item.first << "', " << item.second << "], "; + } + ss << "]"; + LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); + } + + std::vector tokens; + std::vector weights; + for (const auto& item : parsed_attention) { + const std::string& curr_text = item.first; + float curr_weight = item.second; + std::vector curr_tokens = tokenizer.tokenize(curr_text, nullptr); + tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); + weights.insert(weights.end(), curr_tokens.size(), curr_weight); + } + + tokenizer.pad_tokens(tokens, weights, max_length, padding); + + // for (int i = 0; i < tokens.size(); i++) { + // std::cout << tokens[i] << ":" << weights[i] << ", "; + // } + // std::cout << std::endl; + + return {tokens, weights}; + } + + void test() { + struct ggml_init_params params; + params.mem_size = static_cast(1024 * 1024) * 1024; // 1GB + params.mem_buffer = NULL; + params.no_alloc = false; + + struct ggml_context* work_ctx = ggml_init(params); + GGML_ASSERT(work_ctx != NULL); + + { + std::string text("<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\na lovely cat<|im_end|>\n<|im_start|>assistant\n"); + auto tokens_and_weights = tokenize(text, 0, false); + std::vector& tokens = std::get<0>(tokens_and_weights); + std::vector& weights = std::get<1>(tokens_and_weights); + for (auto token : tokens) { + printf("%d ", token); + } + printf("\n"); + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens); + struct 
ggml_tensor* out = NULL; + + int t0 = ggml_time_ms(); + model.compute(8, input_ids, &out, work_ctx); + int t1 = ggml_time_ms(); + + print_ggml_tensor(out); + LOG_DEBUG("qwen2vl test done in %dms", t1 - t0); + } + } + + static void load_from_file_and_test(const std::string& file_path) { + // cpu f16: pass + // ggml_backend_t backend = ggml_backend_cuda_init(0); + ggml_backend_t backend = ggml_backend_cpu_init(); + ggml_type model_data_type = GGML_TYPE_Q8_0; + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path, "qwen2vl.")) { + LOG_ERROR("init model loader from file failed: '%s'", file_path.c_str()); + return; + } + + auto tensor_types = model_loader.tensor_storages_types; + for (auto& item : tensor_types) { + // LOG_DEBUG("%s %u", item.first.c_str(), item.second); + if (ends_with(item.first, "weight")) { + item.second = model_data_type; + } + } + + std::shared_ptr qwenvl = std::shared_ptr(new Qwen2_5_VLEmbedder(backend, false, tensor_types, "qwen2vl")); + + qwenvl->alloc_params_buffer(); + std::map tensors; + qwenvl->get_param_tensors(tensors, "qwen2vl"); + + bool success = model_loader.load_tensors(tensors); + + if (!success) { + LOG_ERROR("load tensors from model loader failed"); + return; + } + + LOG_INFO("qwenvl model loaded"); + qwenvl->test(); + } + }; + +}; // Qwen + +#endif // __QWEN_HPP__ diff --git a/tokenize_util.cpp b/tokenize_util.cpp index 85e3821..d2e040e 100644 --- a/tokenize_util.cpp +++ b/tokenize_util.cpp @@ -1,7 +1,7 @@ +#include #include #include #include -#include #include "tokenize_util.h" @@ -697,36 +697,37 @@ bool is_letter(char32_t ch) { {0x31350, 0x33479}, }; - for (const auto &r : ranges) { - if (ch >= r.start && ch <= r.end) return true; + for (const auto& r : ranges) { + if (ch >= r.start && ch <= r.end) + return true; } return false; } bool is_space(char32_t cp) { switch (cp) { - case 0x0009: // TAB \t - case 0x000A: // LF \n - case 0x000B: // VT - case 0x000C: // FF - case 0x000D: // CR \r - case 0x0020: // Space - case 0x00A0: // No-Break Space - case 0x1680: // Ogham Space Mark - case 0x2000: // En Quad - case 0x2001: // Em Quad - case 0x2002: // En Space - case 0x2003: // Em Space - case 0x2004: // Three-Per-Em Space - case 0x2005: // Four-Per-Em Space - case 0x2006: // Six-Per-Em Space - case 0x2007: // Figure Space - case 0x2008: // Punctuation Space - case 0x2009: // Thin Space - case 0x200A: // Hair Space - case 0x202F: // Narrow No-Break Space - case 0x205F: // Medium Mathematical Space - case 0x3000: // Ideographic Space + case 0x0009: // TAB \t + case 0x000A: // LF \n + case 0x000B: // VT + case 0x000C: // FF + case 0x000D: // CR \r + case 0x0020: // Space + case 0x00A0: // No-Break Space + case 0x1680: // Ogham Space Mark + case 0x2000: // En Quad + case 0x2001: // Em Quad + case 0x2002: // En Space + case 0x2003: // Em Space + case 0x2004: // Three-Per-Em Space + case 0x2005: // Four-Per-Em Space + case 0x2006: // Six-Per-Em Space + case 0x2007: // Figure Space + case 0x2008: // Punctuation Space + case 0x2009: // Thin Space + case 0x200A: // Hair Space + case 0x202F: // Narrow No-Break Space + case 0x205F: // Medium Mathematical Space + case 0x3000: // Ideographic Space return true; default: return false; @@ -736,7 +737,7 @@ bool is_space(char32_t cp) { std::string str_to_lower(const std::string& input) { std::string result = input; std::transform(result.begin(), result.end(), result.begin(), - [](unsigned char c){ return std::tolower(c); }); + [](unsigned char c) { return std::tolower(c); }); return result; } @@ -745,17 
+746,28 @@ std::vector utf8_to_codepoints(const std::string& str) { std::vector codepoints; size_t i = 0; while (i < str.size()) { - unsigned char c = str[i]; - char32_t cp = 0; + unsigned char c = str[i]; + char32_t cp = 0; size_t extra_bytes = 0; - if ((c & 0x80) == 0) cp = c; - else if ((c & 0xE0) == 0xC0) { cp = c & 0x1F; extra_bytes = 1; } - else if ((c & 0xF0) == 0xE0) { cp = c & 0x0F; extra_bytes = 2; } - else if ((c & 0xF8) == 0xF0) { cp = c & 0x07; extra_bytes = 3; } - else { ++i; continue; } // Invalid UTF-8 + if ((c & 0x80) == 0) + cp = c; + else if ((c & 0xE0) == 0xC0) { + cp = c & 0x1F; + extra_bytes = 1; + } else if ((c & 0xF0) == 0xE0) { + cp = c & 0x0F; + extra_bytes = 2; + } else if ((c & 0xF8) == 0xF0) { + cp = c & 0x07; + extra_bytes = 3; + } else { + ++i; + continue; + } // Invalid UTF-8 - if (i + extra_bytes >= str.size()) break; + if (i + extra_bytes >= str.size()) + break; for (size_t j = 1; j <= extra_bytes; ++j) cp = (cp << 6) | (str[i + j] & 0x3F); @@ -769,7 +781,8 @@ std::vector utf8_to_codepoints(const std::string& str) { // Unicode code point -> UTF-8 std::string codepoint_to_utf8(char32_t cp) { std::string out; - if (cp <= 0x7F) out.push_back(static_cast(cp)); + if (cp <= 0x7F) + out.push_back(static_cast(cp)); else if (cp <= 0x7FF) { out.push_back(static_cast(0xC0 | (cp >> 6))); out.push_back(static_cast(0x80 | (cp & 0x3F))); @@ -786,6 +799,17 @@ std::string codepoint_to_utf8(char32_t cp) { return out; } +bool starts_with(const std::vector& text, + const std::vector& prefix, + std::size_t index) { + if (index > text.size()) { + return false; + } + if (prefix.size() > text.size() - index) { + return false; + } + return std::equal(prefix.begin(), prefix.end(), text.begin() + index); +} std::vector token_split(const std::string& text) { std::vector tokens; @@ -797,14 +821,14 @@ std::vector token_split(const std::string& text) { // `(?i:'s|'t|'re|'ve|'m|'ll|'d)` if (cp == U'\'' && i + 1 < cps.size()) { - std::string next = str_to_lower(codepoint_to_utf8(cps[i+1])); + std::string next = str_to_lower(codepoint_to_utf8(cps[i + 1])); if (next == "s" || next == "t" || next == "m") { tokens.push_back("'" + next); i += 2; continue; } if (i + 2 < cps.size()) { - next += str_to_lower(codepoint_to_utf8(cps[i+2])); + next += str_to_lower(codepoint_to_utf8(cps[i + 2])); if (next == "re" || next == "ve" || next == "ll" || next == "d") { tokens.push_back("'" + next); i += 3; @@ -823,7 +847,7 @@ std::vector token_split(const std::string& text) { // `[^\r\n\p{L}\p{N}]?\p{L}+` { // `[^\r\n\p{L}\p{N}]\p{L}+` - if (!is_letter(cp) && cp != U'\r' && cp != U'\n' && i + 1 < cps.size() && is_letter(cps[i+1])) { + if (!is_letter(cp) && cp != U'\r' && cp != U'\n' && i + 1 < cps.size() && is_letter(cps[i + 1])) { std::string token = codepoint_to_utf8(cp); ++i; @@ -847,14 +871,14 @@ std::vector token_split(const std::string& text) { continue; } } - + // ` ?[^\s\p{L}\p{N}]+[\r\n]*` { // ` [^\s\p{L}\p{N}]+[\r\n]*` - if (cp == U' ' && i + 1 < cps.size() && !isspace(cps[i+1]) && !is_letter(cps[i+1]) && !is_number(cps[i+1])) { + if (cp == U' ' && i + 1 < cps.size() && !isspace(cps[i + 1]) && !is_letter(cps[i + 1]) && !is_number(cps[i + 1])) { std::string token = codepoint_to_utf8(cp); - token += codepoint_to_utf8(cps[i+1]); - i+=2; + token += codepoint_to_utf8(cps[i + 1]); + i += 2; while (i < cps.size() && !is_letter(cps[i]) && !is_number(cps[i]) && !isspace(cps[i])) { token += codepoint_to_utf8(cps[i]); @@ -915,6 +939,40 @@ std::vector token_split(const std::string& text) { return tokens; } 
+std::vector<std::string> split_with_special_tokens(
+    const std::string& text,
+    const std::vector<std::string>& special_tokens) {
+    std::vector<std::string> result;
+    size_t pos      = 0;
+    size_t text_len = text.size();
+
+    while (pos < text_len) {
+        size_t next_pos = text_len;
+        std::string matched_token;
+
+        for (const auto& token : special_tokens) {
+            size_t token_pos = text.find(token, pos);
+            if (token_pos != std::string::npos && token_pos < next_pos) {
+                next_pos      = token_pos;
+                matched_token = token;
+            }
+        }
+
+        if (next_pos > pos) {
+            result.push_back(text.substr(pos, next_pos - pos));
+        }
+
+        if (!matched_token.empty()) {
+            result.push_back(matched_token);
+            pos = next_pos + matched_token.size();
+        } else {
+            break;
+        }
+    }
+
+    return result;
+}
+
 // int main() {
 //     std::string text = "I'm testing C++ token_split function. 你好,世界! 123";
 //     auto tokens = token_split(text);
diff --git a/tokenize_util.h b/tokenize_util.h
index fca07a8..e744d75 100644
--- a/tokenize_util.h
+++ b/tokenize_util.h
@@ -5,5 +5,6 @@
 #include <vector>
 
 std::vector<std::string> token_split(const std::string& text);
+std::vector<std::string> split_with_special_tokens(const std::string& text, const std::vector<std::string>& special_tokens);
 
-#endif  // __TOKENIZE_UTIL__
\ No newline at end of file
+#endif  // __TOKENIZE_UTIL__
\ No newline at end of file
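For reviewers, a minimal sketch (not part of the patch) of how the new tokenizer path is meant to be driven. It assumes only the interfaces introduced above in qwen.hpp and tokenize_util.h; the standalone main() and printf dump are illustrative, and the default Qwen2Tokenizer() constructor is assumed to load the bundled Qwen2 merges via ModelLoader::load_qwen2_merges() as in this diff.

    #include <cstdio>
    #include <string>
    #include "qwen.hpp"

    int main() {
        // Special tokens such as <|im_start|> are split out first by
        // split_with_special_tokens() and mapped directly to their ids;
        // everything else goes through token_split() + byte-level BPE.
        Qwen::Qwen2Tokenizer tokenizer;

        std::string prompt = "<|im_start|>user\na lovely cat<|im_end|>\n<|im_start|>assistant\n";
        auto tokens = tokenizer.tokenize(prompt);

        for (auto token : tokens) {
            printf("%d ", token);
        }
        printf("\n");
        return 0;
    }

For the full text-encoder path, Qwen2_5_VLEmbedder::test() above shows the same kind of chat-formatted prompt being tokenized with prompt weighting and run through Qwen2_5_VLRunner; its build_graph() fills four position streams per token for mRoPE (the temporal/height/width sections all set to the token index, plus a fourth, zeroed stream), since the text-only case has no image grid.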