Mirror of https://github.com/leejet/stable-diffusion.cpp.git (synced 2025-12-13 05:48:56 +00:00)

Commit fe4e73156f (parent f88daa5114): add qwen2.5 vl support
@@ -1142,17 +1142,7 @@ int main(int argc, const char* argv[]) {
     SDParams params;
     params.verbose = true;
     sd_set_log_callback(sd_log_cb, (void*)&params);
-    auto on_new_token_cb = [&](std::string& str, std::vector<int32_t>& bpe_tokens) -> bool {
-        return false;
-    };
-    // auto tokenizer = CLIPTokenizer();
-    auto tokenizer = Qwen::Qwen2Tokenizer();
-    std::string text("a lovely cat");
-    auto tokens = tokenizer.encode(text, on_new_token_cb);
-    for (auto token : tokens) {
-        std::cout << token << " ";
-    }
-    std::cout << std::endl;
+    Qwen::Qwen2_5_VLEmbedder::load_from_file_and_test(argv[1]);
     exit(1);
     parse_args(argc, argv, params);
     params.sample_params.guidance.slg.layers = params.skip_layers.data();
@@ -1119,9 +1119,9 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
     return kqv;
 }
 
-// q: [N, L_q, C] or [N*n_head, L_q, d_head]
-// k: [N, L_k, C] or [N*n_head, L_k, d_head]
-// v: [N, L_k, C] or [N, L_k, n_head, d_head]
+// q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head]
+// k: [N, L_k, n_kv_head*d_head] or [N*n_kv_head, L_k, d_head]
+// v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head]
 // mask: [N, L_q, L_k]
 // return: [N, L_q, C]
 __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* ctx,
@@ -1139,27 +1139,31 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     int64_t C;
     int64_t N;
     int64_t d_head;
+    int64_t n_kv_head;
     if (!skip_reshape) {
         L_q = q->ne[1];
         L_k = k->ne[1];
         C = q->ne[0];
         N = q->ne[2];
         d_head = C / n_head;
-        q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N);     // [N, L_q, n_head, d_head]
-        q = ggml_nn_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head]
-        q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N);    // [N * n_head, L_q, d_head]
-
-        k = ggml_reshape_4d(ctx, k, d_head, n_head, L_k, N);     // [N, L_k, n_head, d_head]
-        k = ggml_nn_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
-        k = ggml_reshape_3d(ctx, k, d_head, L_k, n_head * N);    // [N * n_head, L_k, d_head]
-
-        v = ggml_reshape_4d(ctx, v, d_head, n_head, L_k, N);     // [N, L_k, n_head, d_head]
+        n_kv_head = k->ne[0] / d_head;
+
+        q = ggml_reshape_4d(ctx, q, d_head, n_head, L_q, N);     // [N, L_q, n_head, d_head]
+        q = ggml_nn_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, L_q, d_head]
+        q = ggml_reshape_3d(ctx, q, d_head, L_q, n_head * N);    // [N * n_head, L_q, d_head]
+
+        k = ggml_reshape_4d(ctx, k, d_head, n_kv_head, L_k, N);     // [N, L_k, n_kv_head, d_head]
+        k = ggml_nn_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3));    // [N, n_kv_head, L_k, d_head]
+        k = ggml_reshape_3d(ctx, k, d_head, L_k, n_kv_head * N);    // [N * n_kv_head, L_k, d_head]
+
+        v = ggml_reshape_4d(ctx, v, d_head, n_kv_head, L_k, N);     // [N, L_k, n_kv_head, d_head]
     } else {
         L_q = q->ne[1];
         L_k = k->ne[1];
         d_head = v->ne[0];
         N = v->ne[3];
-        C = d_head * n_head;
+        n_kv_head = k->ne[2] / N;
+        C = d_head * n_head;
     }
 
     float scale = (1.0f / sqrt((float)d_head));
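The new n_kv_head bookkeeping lets this attention helper serve grouped-query models such as the Qwen2.5-VL text stack, where K and V carry fewer heads than Q. Below is a minimal numeric sketch of the shape arithmetic in the hunk above; the head counts are illustrative and not read from the commit.

#include <cassert>
#include <cstdint>

int main() {
    // Illustrative GQA configuration (not taken from the diff).
    const int64_t n_head = 28;
    const int64_t C      = 3584;               // q->ne[0]
    const int64_t d_head = C / n_head;         // 128
    const int64_t kv_dim = 512;                // k->ne[0] for a hypothetical 4-KV-head layer
    const int64_t n_kv_head = kv_dim / d_head; // 4, exactly as computed in the hunk above

    // After the reshapes above: q becomes [N * n_head, L_q, d_head],
    // while k and v become [N * n_kv_head, L_k, d_head].
    assert(n_kv_head * d_head == kv_dim);
    assert(n_head % n_kv_head == 0);  // each KV head serves n_head / n_kv_head query heads
    return 0;
}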
@@ -1174,7 +1178,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         k_in = ggml_cast(ctx, k_in, GGML_TYPE_F16);
 
         v_in = ggml_nn_cont(ctx, ggml_permute(ctx, v_in, 0, 2, 1, 3));
-        v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_head * N);
+        v_in = ggml_reshape_3d(ctx, v_in, d_head, L_k, n_kv_head * N);
         if (kv_pad != 0) {
             v_in = ggml_pad(ctx, v_in, 0, kv_pad, 0, 0);
         }
@@ -1232,8 +1236,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
         // if (flash_attn) {
         //     LOG_DEBUG("fallback to default attention, L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
         // }
-        v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k]
-        v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N);    // [N * n_head, d_head, L_k]
+        v = ggml_nn_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_kv_head, d_head, L_k]
+        v = ggml_reshape_3d(ctx, v, L_k, d_head, n_kv_head * N); // [N * n_kv_head, d_head, L_k]
 
         auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k]
         kq = ggml_scale_inplace(ctx, kq, scale);
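With K now batched as N * n_kv_head slices against Q's N * n_head slices, the kq product relies on ggml_mul_mat repeating the smaller batch over the larger one. A rough sketch of the resulting index mapping, assuming ggml's usual src0-over-src1 batch broadcasting; this mapping is inferred and not spelled out in the commit.

#include <cassert>
#include <cstdint>

// Which flattened K/V batch a given flattened Q batch reads from, when q batches
// are laid out as n * n_head + h and k/v batches as n * n_kv_head + h_kv.
int64_t kv_batch_for_q_batch(int64_t q_batch, int64_t n_head, int64_t n_kv_head) {
    assert(n_head % n_kv_head == 0);
    const int64_t group = n_head / n_kv_head;  // query heads per KV head
    // (n * n_head + h) / group == n * n_kv_head + h / group
    return q_batch / group;
}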
model.cpp (31 changed lines)
@@ -110,6 +110,9 @@ const char* unused_tensors[] = {
     "embedding_manager",
     "denoiser.sigmas",
     "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight",  // only used during training
+    "qwen2vl.output.weight",
+    "qwen2vl.lm_head.",
+    "qwen2vl.visual.",
 };
 
 bool is_unused_tensor(std::string name) {
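The three new entries keep the Qwen2.5-VL LM head and the vision tower out of the diffusion-side weight loading. A hedged sketch of how such a prefix list is typically consumed follows; is_unused_tensor's body is not part of this hunk, so the loop below is an assumption rather than the project's code.

#include <string>

// Assumed behaviour: any tensor whose name begins with one of the listed prefixes
// is skipped, so e.g. "qwen2vl.visual.patch_embed.weight" (hypothetical name)
// would never be allocated by this loader.
bool is_unused_tensor_sketch(const std::string& name, const char* const* prefixes, size_t count) {
    for (size_t i = 0; i < count; i++) {
        if (name.rfind(prefixes[i], 0) == 0) {  // starts-with check
            return true;
        }
    }
    return false;
}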
@@ -193,6 +196,21 @@ std::unordered_map<std::string, std::string> pmid_v2_name_map = {
      "pmid.qformer_perceiver.token_proj.fc2.weight"},
 };
 
+std::unordered_map<std::string, std::string> qwenvl_name_map{
+    {"token_embd.", "model.embed_tokens."},
+    {"blk.", "model.layers."},
+    {"attn_q.", "self_attn.q_proj."},
+    {"attn_k.", "self_attn.k_proj."},
+    {"attn_v.", "self_attn.v_proj."},
+    {"attn_output.", "self_attn.o_proj."},
+    {"attn_norm.", "input_layernorm."},
+    {"ffn_down.", "mlp.down_proj."},
+    {"ffn_gate.", "mlp.gate_proj."},
+    {"ffn_up.", "mlp.up_proj."},
+    {"ffn_norm.", "post_attention_layernorm."},
+    {"output_norm.", "model.norm."},
+};
+
 std::string convert_cond_model_name(const std::string& name) {
     std::string new_name = name;
     std::string prefix;
@@ -250,6 +268,13 @@ std::string convert_cond_model_name(const std::string& name) {
         if (pos != std::string::npos) {
             new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias.");
         }
+    } else if (contains(name, "qwen2vl")) {
+        for (auto kv : qwenvl_name_map) {
+            size_t pos = new_name.find(kv.first);
+            if (pos != std::string::npos) {
+                new_name.replace(pos, kv.first.size(), kv.second);
+            }
+        }
     } else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") {
         new_name = "text_encoders.t5xxl.transformer.shared.weight";
     }
@@ -580,7 +605,11 @@ std::string convert_tensor_name(std::string name) {
     //     name.replace(pos, strlen("lora_B"), "lora_down");
     // }
     std::string new_name = name;
-    if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) {
+    if (starts_with(name, "cond_stage_model.") ||
+        starts_with(name, "conditioner.embedders.") ||
+        starts_with(name, "text_encoders.") ||
+        ends_with(name, ".vision_model.visual_projection.weight") ||
+        starts_with(name, "qwen2vl")) {
         new_name = convert_cond_model_name(name);
     } else if (starts_with(name, "first_stage_model.decoder")) {
         new_name = convert_vae_decoder_name(name);
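To make the new mapping concrete, here is a small standalone sketch of the same first-match find-and-replace pass that the qwen2vl routing above feeds into; the map entries are copied from the hunk, while the sample tensor name is a hypothetical GGUF-style name.

#include <iostream>
#include <string>
#include <unordered_map>

int main() {
    // Subset of qwenvl_name_map from the hunk above.
    std::unordered_map<std::string, std::string> qwenvl_name_map{
        {"blk.", "model.layers."},
        {"attn_q.", "self_attn.q_proj."},
        {"ffn_up.", "mlp.up_proj."},
    };

    // Hypothetical llama.cpp-style tensor name carrying the qwen2vl prefix
    // that convert_tensor_name routes into convert_cond_model_name.
    std::string name = "qwen2vl.blk.0.attn_q.weight";
    for (const auto& kv : qwenvl_name_map) {
        size_t pos = name.find(kv.first);
        if (pos != std::string::npos) {
            name.replace(pos, kv.first.size(), kv.second);
        }
    }
    // Prints: qwen2vl.model.layers.0.self_attn.q_proj.weight
    std::cout << name << std::endl;
    return 0;
}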
@@ -1,7 +1,7 @@
+#include <algorithm>
 #include <iostream>
 #include <string>
 #include <vector>
-#include <algorithm>
 
 #include "tokenize_util.h"
 
@@ -697,36 +697,37 @@ bool is_letter(char32_t ch) {
         {0x31350, 0x33479},
     };
 
-    for (const auto &r : ranges) {
-        if (ch >= r.start && ch <= r.end) return true;
+    for (const auto& r : ranges) {
+        if (ch >= r.start && ch <= r.end)
+            return true;
     }
     return false;
 }
 
 bool is_space(char32_t cp) {
     switch (cp) {
         case 0x0009:  // TAB \t
         case 0x000A:  // LF \n
         case 0x000B:  // VT
         case 0x000C:  // FF
         case 0x000D:  // CR \r
         case 0x0020:  // Space
         case 0x00A0:  // No-Break Space
         case 0x1680:  // Ogham Space Mark
         case 0x2000:  // En Quad
         case 0x2001:  // Em Quad
         case 0x2002:  // En Space
         case 0x2003:  // Em Space
         case 0x2004:  // Three-Per-Em Space
         case 0x2005:  // Four-Per-Em Space
         case 0x2006:  // Six-Per-Em Space
         case 0x2007:  // Figure Space
         case 0x2008:  // Punctuation Space
         case 0x2009:  // Thin Space
         case 0x200A:  // Hair Space
         case 0x202F:  // Narrow No-Break Space
         case 0x205F:  // Medium Mathematical Space
         case 0x3000:  // Ideographic Space
             return true;
         default:
             return false;
@@ -736,7 +737,7 @@ bool is_space(char32_t cp) {
 std::string str_to_lower(const std::string& input) {
     std::string result = input;
     std::transform(result.begin(), result.end(), result.begin(),
-                   [](unsigned char c){ return std::tolower(c); });
+                   [](unsigned char c) { return std::tolower(c); });
     return result;
 }
 
@@ -745,17 +746,28 @@ std::vector<char32_t> utf8_to_codepoints(const std::string& str) {
     std::vector<char32_t> codepoints;
     size_t i = 0;
     while (i < str.size()) {
         unsigned char c = str[i];
         char32_t cp = 0;
         size_t extra_bytes = 0;
 
-        if ((c & 0x80) == 0) cp = c;
-        else if ((c & 0xE0) == 0xC0) { cp = c & 0x1F; extra_bytes = 1; }
-        else if ((c & 0xF0) == 0xE0) { cp = c & 0x0F; extra_bytes = 2; }
-        else if ((c & 0xF8) == 0xF0) { cp = c & 0x07; extra_bytes = 3; }
-        else { ++i; continue; } // Invalid UTF-8
+        if ((c & 0x80) == 0)
+            cp = c;
+        else if ((c & 0xE0) == 0xC0) {
+            cp = c & 0x1F;
+            extra_bytes = 1;
+        } else if ((c & 0xF0) == 0xE0) {
+            cp = c & 0x0F;
+            extra_bytes = 2;
+        } else if ((c & 0xF8) == 0xF0) {
+            cp = c & 0x07;
+            extra_bytes = 3;
+        } else {
+            ++i;
+            continue;
+        }  // Invalid UTF-8
 
-        if (i + extra_bytes >= str.size()) break;
+        if (i + extra_bytes >= str.size())
+            break;
 
         for (size_t j = 1; j <= extra_bytes; ++j)
             cp = (cp << 6) | (str[i + j] & 0x3F);
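A worked example of the decode loop above, using standard UTF-8 byte values that are not taken from the commit:

// Decoding U+4F60 ("你"), encoded as the three bytes E4 BD A0:
//   lead byte 0xE4: (0xE4 & 0xF0) == 0xE0, so cp = 0xE4 & 0x0F = 0x04 and extra_bytes = 2
//   first continuation:  cp = (0x04 << 6) | (0xBD & 0x3F) = 0x13D
//   second continuation: cp = (0x13D << 6) | (0xA0 & 0x3F) = 0x4F60
static_assert(((((0xE4 & 0x0F) << 6) | (0xBD & 0x3F)) << 6 | (0xA0 & 0x3F)) == 0x4F60,
              "three-byte UTF-8 sequence E4 BD A0 decodes to U+4F60");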
@@ -769,7 +781,8 @@ std::vector<char32_t> utf8_to_codepoints(const std::string& str) {
 // Unicode code point -> UTF-8
 std::string codepoint_to_utf8(char32_t cp) {
     std::string out;
-    if (cp <= 0x7F) out.push_back(static_cast<char>(cp));
+    if (cp <= 0x7F)
+        out.push_back(static_cast<char>(cp));
     else if (cp <= 0x7FF) {
         out.push_back(static_cast<char>(0xC0 | (cp >> 6)));
         out.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
@@ -786,6 +799,17 @@ std::string codepoint_to_utf8(char32_t cp) {
     return out;
 }
 
+bool starts_with(const std::vector<char32_t>& text,
+                 const std::vector<char32_t>& prefix,
+                 std::size_t index) {
+    if (index > text.size()) {
+        return false;
+    }
+    if (prefix.size() > text.size() - index) {
+        return false;
+    }
+    return std::equal(prefix.begin(), prefix.end(), text.begin() + index);
+}
+
 std::vector<std::string> token_split(const std::string& text) {
     std::vector<std::string> tokens;
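A brief usage sketch for the new codepoint-level starts_with overload; the special-token string is illustrative (Qwen's chat markers use this form), and the snippet assumes it runs inside the same translation unit where the overload and utf8_to_codepoints are visible.

// Scanning a decoded prompt for a special token at a given offset.
std::vector<char32_t> text   = utf8_to_codepoints("abc<|im_start|>hi");
std::vector<char32_t> prefix = utf8_to_codepoints("<|im_start|>");
bool hit = starts_with(text, prefix, 3);  // true: the marker begins at codepoint index 3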
@@ -797,14 +821,14 @@ std::vector<std::string> token_split(const std::string& text) {
 
         // `(?i:'s|'t|'re|'ve|'m|'ll|'d)`
         if (cp == U'\'' && i + 1 < cps.size()) {
-            std::string next = str_to_lower(codepoint_to_utf8(cps[i+1]));
+            std::string next = str_to_lower(codepoint_to_utf8(cps[i + 1]));
             if (next == "s" || next == "t" || next == "m") {
                 tokens.push_back("'" + next);
                 i += 2;
                 continue;
             }
             if (i + 2 < cps.size()) {
-                next += str_to_lower(codepoint_to_utf8(cps[i+2]));
+                next += str_to_lower(codepoint_to_utf8(cps[i + 2]));
                 if (next == "re" || next == "ve" || next == "ll" || next == "d") {
                     tokens.push_back("'" + next);
                     i += 3;
@@ -823,7 +847,7 @@ std::vector<std::string> token_split(const std::string& text) {
         // `[^\r\n\p{L}\p{N}]?\p{L}+`
         {
             // `[^\r\n\p{L}\p{N}]\p{L}+`
-            if (!is_letter(cp) && cp != U'\r' && cp != U'\n' && i + 1 < cps.size() && is_letter(cps[i+1])) {
+            if (!is_letter(cp) && cp != U'\r' && cp != U'\n' && i + 1 < cps.size() && is_letter(cps[i + 1])) {
                 std::string token = codepoint_to_utf8(cp);
                 ++i;
 
@@ -847,14 +871,14 @@ std::vector<std::string> token_split(const std::string& text) {
                 continue;
             }
         }
 
         // ` ?[^\s\p{L}\p{N}]+[\r\n]*`
         {
             // ` [^\s\p{L}\p{N}]+[\r\n]*`
-            if (cp == U' ' && i + 1 < cps.size() && !isspace(cps[i+1]) && !is_letter(cps[i+1]) && !is_number(cps[i+1])) {
+            if (cp == U' ' && i + 1 < cps.size() && !isspace(cps[i + 1]) && !is_letter(cps[i + 1]) && !is_number(cps[i + 1])) {
                 std::string token = codepoint_to_utf8(cp);
-                token += codepoint_to_utf8(cps[i+1]);
-                i+=2;
+                token += codepoint_to_utf8(cps[i + 1]);
+                i += 2;
 
                 while (i < cps.size() && !is_letter(cps[i]) && !is_number(cps[i]) && !isspace(cps[i])) {
                     token += codepoint_to_utf8(cps[i]);
@@ -915,6 +939,40 @@ std::vector<std::string> token_split(const std::string& text) {
     return tokens;
 }
 
+std::vector<std::string> split_with_special_tokens(
+    const std::string& text,
+    const std::vector<std::string>& special_tokens) {
+    std::vector<std::string> result;
+    size_t pos = 0;
+    size_t text_len = text.size();
+
+    while (pos < text_len) {
+        size_t next_pos = text_len;
+        std::string matched_token;
+
+        for (const auto& token : special_tokens) {
+            size_t token_pos = text.find(token, pos);
+            if (token_pos != std::string::npos && token_pos < next_pos) {
+                next_pos = token_pos;
+                matched_token = token;
+            }
+        }
+
+        if (next_pos > pos) {
+            result.push_back(text.substr(pos, next_pos - pos));
+        }
+
+        if (!matched_token.empty()) {
+            result.push_back(matched_token);
+            pos = next_pos + matched_token.size();
+        } else {
+            break;
+        }
+    }
+
+    return result;
+}
+
 // int main() {
 //     std::string text = "I'm testing C++ token_split function. 你好,世界! 123";
 //     auto tokens = token_split(text);
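A short usage sketch of the new helper; the special-token strings are illustrative Qwen-style markers, not values read from this commit.

#include <iostream>
#include <string>
#include <vector>

#include "tokenize_util.h"

int main() {
    std::vector<std::string> specials = {"<|im_start|>", "<|im_end|>"};
    auto parts = split_with_special_tokens("<|im_start|>user\nhello<|im_end|>", specials);
    // parts: "<|im_start|>", "user\nhello", "<|im_end|>"
    for (const auto& p : parts) {
        std::cout << "[" << p << "]" << std::endl;
    }
    return 0;
}

Plain text between markers is returned verbatim as its own element, each marker comes back as a separate element, and any trailing text after the last marker is kept as well.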
@@ -5,5 +5,6 @@
 #include <vector>
 
 std::vector<std::string> token_split(const std::string& text);
+std::vector<std::string> split_with_special_tokens(const std::string& text, const std::vector<std::string>& special_tokens);
 
 #endif  // __TOKENIZE_UTIL__