scirs2_text/tokenizers/
unicode.rs1use std::collections::HashMap;
24
/// Options that control how a tokenizer normalizes and splits its input.
///
/// All fields are plain values, so the type also derives `PartialEq`/`Eq`;
/// this lets callers and tests compare configurations directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnicodeTokenizerConfig {
    /// Lowercase the input before any splitting takes place.
    pub lowercase: bool,
    /// Remove combining diacritical marks (U+0300–U+036F) from the text.
    pub strip_accents: bool,
    /// Split the text on Unicode whitespace. When `false`, the whole
    /// (normalized) text is carried forward as a single piece.
    pub split_on_whitespace: bool,
    /// Emit every punctuation character as a token of its own.
    pub split_on_punctuation: bool,
    /// Surround CJK ideographs with spaces so each one can be split off
    /// as an individual token.
    pub handle_cjk: bool,
    /// Minimum length a token must have to be kept; shorter tokens are dropped.
    pub min_token_length: usize,
    /// Upper bound on the number of tokens returned; `None` means no limit.
    pub max_tokens: Option<usize>,
}
47
48impl Default for UnicodeTokenizerConfig {
49 fn default() -> Self {
50 UnicodeTokenizerConfig {
51 lowercase: true,
52 strip_accents: false,
53 split_on_whitespace: true,
54 split_on_punctuation: true,
55 handle_cjk: true,
56 min_token_length: 1,
57 max_tokens: None,
58 }
59 }
60}
61
62pub struct UnicodeTokenizer {
73 config: UnicodeTokenizerConfig,
74}
75
76impl UnicodeTokenizer {
77 pub fn new(config: UnicodeTokenizerConfig) -> Self {
79 UnicodeTokenizer { config }
80 }
81
82 pub fn tokenize(&self, text: &str) -> Vec<String> {
84 if text.is_empty() {
85 return vec![];
86 }
87
88 let working: String = if self.config.lowercase {
90 text.to_lowercase()
91 } else {
92 text.to_string()
93 };
94
95 let working: String = if self.config.handle_cjk {
97 self.add_cjk_spaces(&working)
98 } else {
99 working
100 };
101
102 let working: String = if self.config.strip_accents {
104 Self::strip_accents_approx(&working)
105 } else {
106 working
107 };
108
109 let raw_tokens: Vec<String> = if self.config.split_on_whitespace {
111 working.split_whitespace().map(|s| s.to_string()).collect()
112 } else {
113 vec![working]
114 };
115
116 let tokens: Vec<String> = if self.config.split_on_punctuation {
118 raw_tokens
119 .into_iter()
120 .flat_map(|tok| self.split_on_punct(tok))
121 .collect()
122 } else {
123 raw_tokens
124 };
125
126 let mut tokens: Vec<String> = tokens
128 .into_iter()
129 .filter(|t| t.len() >= self.config.min_token_length)
130 .collect();
131
132 if let Some(max) = self.config.max_tokens {
134 tokens.truncate(max);
135 }
136
137 tokens
138 }
139
140 pub fn encode(&self, text: &str, vocab: &HashMap<String, usize>) -> Vec<usize> {
144 self.tokenize(text)
145 .iter()
146 .filter_map(|tok| vocab.get(tok).copied())
147 .collect()
148 }
149
150 fn add_cjk_spaces(&self, s: &str) -> String {
154 let mut out = String::with_capacity(s.len() + s.chars().count());
155 for c in s.chars() {
156 if Self::is_cjk(c) {
157 out.push(' ');
158 out.push(c);
159 out.push(' ');
160 } else {
161 out.push(c);
162 }
163 }
164 out
165 }
166
167 fn split_on_punct(&self, tok: String) -> Vec<String> {
169 let mut parts: Vec<String> = Vec::new();
170 let mut current = String::new();
171 for c in tok.chars() {
172 if Self::is_punctuation(c) {
173 if !current.is_empty() {
174 parts.push(current.clone());
175 current.clear();
176 }
177 parts.push(c.to_string());
178 } else {
179 current.push(c);
180 }
181 }
182 if !current.is_empty() {
183 parts.push(current);
184 }
185 parts
186 }
187
188 #[inline]
195 pub fn is_cjk(c: char) -> bool {
196 matches!(c as u32,
197 0x4E00..=0x9FFF | 0x3400..=0x4DBF | 0x20000..=0x2A6DF )
201 }
202
203 #[inline]
205 pub fn is_punctuation(c: char) -> bool {
206 let cp = c as u32;
207 if matches!(cp, 33..=47 | 58..=64 | 91..=96 | 123..=126) {
209 return true;
210 }
211 if (0x2000..=0x206F).contains(&cp) {
213 return true;
214 }
215 if (0x3000..=0x303F).contains(&cp) {
217 return true;
218 }
219 if (0xFF00..=0xFFEF).contains(&cp) {
221 return true;
222 }
223 false
224 }
225
226 pub fn strip_accents_approx(s: &str) -> String {
233 s.chars()
234 .filter(|&c| {
235 let cp = c as u32;
236 !(0x0300..=0x036F).contains(&cp)
237 })
238 .collect()
239 }
240}
241
242impl std::fmt::Debug for UnicodeTokenizer {
243 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
244 f.debug_struct("UnicodeTokenizer")
245 .field("lowercase", &self.config.lowercase)
246 .field("handle_cjk", &self.config.handle_cjk)
247 .field("split_on_punctuation", &self.config.split_on_punctuation)
248 .finish()
249 }
250}
251
#[cfg(test)]
mod tests {
    //! Unit tests for the tokenizer. Where the result is fully determined by
    //! the input, the exact token sequence is asserted so that ordering
    //! mistakes and stray extra tokens are caught too.

    use super::*;

    /// Tokenizer built from the default configuration.
    fn default_tok() -> UnicodeTokenizer {
        UnicodeTokenizer::new(UnicodeTokenizerConfig::default())
    }

    #[test]
    fn test_unicode_tokenizer_empty() {
        let tok = default_tok();
        let tokens = tok.tokenize("");
        assert!(
            tokens.is_empty(),
            "empty string must produce empty token list"
        );
    }

    #[test]
    fn test_unicode_tokenizer_cjk() {
        let tok = UnicodeTokenizer::new(UnicodeTokenizerConfig {
            lowercase: false,
            handle_cjk: true,
            split_on_punctuation: false,
            ..Default::default()
        });
        let tokens = tok.tokenize("你好world");
        assert_eq!(
            tokens,
            ["你", "好", "world"],
            "each CJK char should be its own token, followed by the Latin run"
        );
    }

    #[test]
    fn test_unicode_tokenizer_punctuation() {
        let tok = UnicodeTokenizer::new(UnicodeTokenizerConfig {
            lowercase: false,
            handle_cjk: false,
            split_on_punctuation: true,
            ..Default::default()
        });
        let tokens = tok.tokenize("Hello,world!");
        assert_eq!(
            tokens,
            ["Hello", ",", "world", "!"],
            "punctuation must be split into separate tokens in order"
        );
    }

    #[test]
    fn test_unicode_tokenizer_lowercase() {
        let tok = UnicodeTokenizer::new(UnicodeTokenizerConfig {
            lowercase: true,
            split_on_punctuation: false,
            handle_cjk: false,
            ..Default::default()
        });
        let tokens = tok.tokenize("Hello World");
        assert_eq!(tokens, ["hello", "world"]);
    }

    #[test]
    fn test_unicode_tokenizer_max_tokens() {
        let tok = UnicodeTokenizer::new(UnicodeTokenizerConfig {
            max_tokens: Some(2),
            split_on_punctuation: false,
            handle_cjk: false,
            ..Default::default()
        });
        let tokens = tok.tokenize("one two three four five");
        // Truncation must keep the leading tokens, not just any two.
        assert_eq!(tokens, ["one", "two"]);
    }

    #[test]
    fn test_unicode_tokenizer_min_length() {
        let tok = UnicodeTokenizer::new(UnicodeTokenizerConfig {
            min_token_length: 3,
            split_on_punctuation: false,
            handle_cjk: false,
            ..Default::default()
        });
        let tokens = tok.tokenize("a bb ccc dddd");
        for t in &tokens {
            assert!(t.len() >= 3, "token '{t}' is too short");
        }
        // Guard against the loop above passing vacuously on an empty result.
        assert_eq!(tokens, ["ccc", "dddd"]);
    }

    #[test]
    fn test_is_cjk_basic() {
        assert!(UnicodeTokenizer::is_cjk('中'));
        assert!(UnicodeTokenizer::is_cjk('日'));
        assert!(!UnicodeTokenizer::is_cjk('A'));
        assert!(!UnicodeTokenizer::is_cjk('é'));
    }

    #[test]
    fn test_is_punctuation_ascii() {
        assert!(UnicodeTokenizer::is_punctuation(','));
        assert!(UnicodeTokenizer::is_punctuation('!'));
        assert!(UnicodeTokenizer::is_punctuation(';'));
        assert!(!UnicodeTokenizer::is_punctuation('a'));
        assert!(!UnicodeTokenizer::is_punctuation('5'));
    }

    #[test]
    fn test_strip_accents_approx() {
        // 'e' followed by U+0301 COMBINING ACUTE ACCENT.
        let decomposed = "e\u{0301}";
        let stripped = UnicodeTokenizer::strip_accents_approx(decomposed);
        assert_eq!(stripped, "e", "combining accent should be stripped");
    }

    #[test]
    fn test_encode_returns_vocab_indices() {
        let tok = UnicodeTokenizer::new(UnicodeTokenizerConfig {
            split_on_punctuation: false,
            handle_cjk: false,
            ..Default::default()
        });
        let mut vocab = HashMap::new();
        vocab.insert("hello".to_string(), 0usize);
        vocab.insert("world".to_string(), 1usize);

        let indices = tok.encode("Hello World", &vocab);
        assert_eq!(indices, vec![0, 1]);

        // Tokens that are not in the vocabulary are skipped.
        let with_unknown = tok.encode("Hello there World", &vocab);
        assert_eq!(with_unknown, vec![0, 1]);
    }

    #[test]
    fn test_tokenize_mixed_script() {
        let tok = UnicodeTokenizer::new(UnicodeTokenizerConfig {
            handle_cjk: true,
            split_on_punctuation: true,
            lowercase: true,
            ..Default::default()
        });
        let tokens = tok.tokenize("Hello 世界 world!");
        assert!(!tokens.is_empty());
        assert!(
            tokens.iter().any(|t| t == "world"),
            "world should be a token"
        );
        assert_eq!(tokens, ["hello", "世", "界", "world", "!"]);
    }
}