scirs2_text/tokenization/
wordpiece.rs1use std::collections::HashMap;
9
10use crate::error::{Result, TextError};
11
/// Configuration for the whitespace / punctuation pre-tokenizer that runs
/// before WordPiece splitting.
///
/// The type is a plain pair of flags, so it also derives `PartialEq` and `Eq`
/// to let callers and tests compare configurations directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BasicTokenizer {
    /// Lower-case the input before it is split into tokens.
    pub do_lower_case: bool,
    /// Remove combining accent marks from the input before it is split.
    pub strip_accents: bool,
}
23
24impl BasicTokenizer {
25 pub fn new(do_lower_case: bool, strip_accents: bool) -> Self {
27 BasicTokenizer {
28 do_lower_case,
29 strip_accents,
30 }
31 }
32
33 pub fn tokenize(&self, text: &str) -> Vec<String> {
35 let text = if self.do_lower_case {
36 text.to_lowercase()
37 } else {
38 text.to_string()
39 };
40
41 let text = if self.strip_accents {
43 strip_accents_str(&text)
44 } else {
45 text
46 };
47
48 let mut spaced = String::with_capacity(text.len() + 16);
50 for ch in text.chars() {
51 if ch.is_whitespace() {
52 spaced.push(' ');
53 } else if is_punctuation(ch) || is_chinese_char(ch) {
54 spaced.push(' ');
55 spaced.push(ch);
56 spaced.push(' ');
57 } else {
58 spaced.push(ch);
59 }
60 }
61
62 spaced
63 .split_whitespace()
64 .filter(|s| !s.is_empty())
65 .map(|s| s.to_string())
66 .collect()
67 }
68}
69
70impl Default for BasicTokenizer {
71 fn default() -> Self {
72 BasicTokenizer::new(true, true)
73 }
74}
75
/// Returns `true` when `ch` falls in one of the combining-diacritic ranges
/// that accent stripping removes:
///
/// * U+0300..=U+036F
/// * U+1AB0..=U+1AFF
/// * U+1DC0..=U+1DFF
fn is_combining_mark(ch: char) -> bool {
    matches!(
        ch,
        '\u{0300}'..='\u{036F}' | '\u{1AB0}'..='\u{1AFF}' | '\u{1DC0}'..='\u{1DFF}'
    )
}
87
88fn strip_accents_str(s: &str) -> String {
90 use unicode_normalization::UnicodeNormalization;
93 s.nfd().filter(|&ch| !is_combining_mark(ch)).collect()
94}
95
/// Returns `true` when `ch` must be split off as a token of its own.
///
/// Punctuation here means:
/// * ASCII punctuation, i.e. code points 33..=47, 58..=64, 91..=96 and
///   123..=126 (exactly what `char::is_ascii_punctuation` accepts);
/// * the ideographic full stop U+3002 and the fullwidth comma U+FF0C.
///
/// Code points 0..=32 (NUL, control characters and the space) are *not*
/// punctuation. The earlier hand-written range check started at 0 instead of
/// 33 and therefore classified them as punctuation.
fn is_punctuation(ch: char) -> bool {
    ch.is_ascii_punctuation() || matches!(ch, '\u{3002}' | '\u{FF0C}')
}
107
/// Returns `true` when `ch` lies in one of the CJK ideograph ranges that are
/// tokenized character by character.
///
/// Only the ranges listed in the pattern count; kana and Hangul, for example,
/// are outside all of them and are treated like ordinary letters.
fn is_chinese_char(ch: char) -> bool {
    matches!(
        ch,
        '\u{4E00}'..='\u{9FFF}'
            | '\u{3400}'..='\u{4DBF}'
            | '\u{20000}'..='\u{2A6DF}'
            | '\u{2A700}'..='\u{2B73F}'
            | '\u{2B740}'..='\u{2B81F}'
            | '\u{2B820}'..='\u{2CEAF}'
            | '\u{F900}'..='\u{FAFF}'
            | '\u{2F800}'..='\u{2FA1F}'
    )
}
120
121#[derive(Debug, Clone)]
129pub struct WordPieceTokenizer {
130 vocab: HashMap<String, u32>,
131 id_to_token: Vec<String>,
132 unk_id: u32,
133 max_input_chars_per_word: usize,
134 basic: BasicTokenizer,
135}
136
137impl WordPieceTokenizer {
138 const UNK_TOKEN: &'static str = "[UNK]";
140 const CLS_TOKEN: &'static str = "[CLS]";
141 const SEP_TOKEN: &'static str = "[SEP]";
142 const PAD_TOKEN: &'static str = "[PAD]";
143 const MASK_TOKEN: &'static str = "[MASK]";
144
145 pub fn from_vocab(mut vocab: HashMap<String, u32>) -> Self {
150 if !vocab.contains_key(Self::UNK_TOKEN) {
152 let next_id = vocab.len() as u32;
153 vocab.insert(Self::UNK_TOKEN.to_string(), next_id);
154 }
155 let unk_id = vocab[Self::UNK_TOKEN];
156
157 let max_id = vocab.values().copied().max().unwrap_or(0) as usize;
159 let mut id_to_token = vec![String::new(); max_id + 1];
160 for (tok, &id) in &vocab {
161 if let Some(slot) = id_to_token.get_mut(id as usize) {
162 *slot = tok.clone();
163 }
164 }
165
166 WordPieceTokenizer {
167 vocab,
168 id_to_token,
169 unk_id,
170 max_input_chars_per_word: 200,
171 basic: BasicTokenizer::default(),
172 }
173 }
174
175 pub fn from_vocab_list(tokens: &[impl AsRef<str>]) -> Self {
178 let vocab: HashMap<String, u32> = tokens
179 .iter()
180 .enumerate()
181 .map(|(i, t)| (t.as_ref().to_string(), i as u32))
182 .collect();
183 Self::from_vocab(vocab)
184 }
185
186 pub fn with_max_input_chars(mut self, n: usize) -> Self {
189 self.max_input_chars_per_word = n;
190 self
191 }
192
193 fn wordpiece_word(&self, word: &str) -> Vec<String> {
197 let chars: Vec<char> = word.chars().collect();
198 if chars.len() > self.max_input_chars_per_word {
199 return vec![Self::UNK_TOKEN.to_string()];
200 }
201
202 let mut sub_tokens: Vec<String> = Vec::new();
203 let mut start = 0usize;
204 let n = chars.len();
205 let mut is_bad = false;
206
207 while start < n {
208 let mut end = n;
209 let mut found: Option<String> = None;
210
211 while start < end {
212 let substr: String = chars[start..end].iter().collect();
213 let candidate = if start == 0 {
214 substr.clone()
215 } else {
216 format!("##{}", substr)
217 };
218
219 if self.vocab.contains_key(&candidate) {
220 found = Some(candidate);
221 break;
222 }
223 if end == start + 1 {
224 is_bad = true;
226 break;
227 }
228 end -= 1;
229 }
230
231 if is_bad {
232 break;
233 }
234
235 match found {
236 Some(tok) => {
237 sub_tokens.push(tok);
238 start = end;
239 }
240 None => {
241 is_bad = true;
242 break;
243 }
244 }
245 }
246
247 if is_bad {
248 vec![Self::UNK_TOKEN.to_string()]
249 } else {
250 sub_tokens
251 }
252 }
253
254 pub fn tokenize(&self, text: &str) -> Vec<u32> {
258 self.tokenize_to_strings(text)
259 .iter()
260 .map(|tok| self.vocab.get(tok.as_str()).copied().unwrap_or(self.unk_id))
261 .collect()
262 }
263
264 pub fn tokenize_to_strings(&self, text: &str) -> Vec<String> {
266 let words = self.basic.tokenize(text);
267 words.iter().flat_map(|w| self.wordpiece_word(w)).collect()
268 }
269
270 pub fn decode(&self, ids: &[u32]) -> String {
272 let mut out = String::new();
273 for &id in ids {
274 let tok = self
275 .id_to_token
276 .get(id as usize)
277 .map(|s| s.as_str())
278 .unwrap_or("[UNK]");
279
280 if tok == Self::PAD_TOKEN {
282 continue;
283 }
284
285 if tok.starts_with("##") {
286 out.push_str(&tok[2..]);
287 } else if !out.is_empty() && tok != Self::CLS_TOKEN && tok != Self::SEP_TOKEN {
288 out.push(' ');
289 out.push_str(tok);
290 } else {
291 out.push_str(tok);
292 }
293 }
294 out
295 }
296
297 pub fn encode(
303 &self,
304 text: &str,
305 max_length: usize,
306 add_special_tokens: bool,
307 ) -> Result<(Vec<u32>, Vec<u8>)> {
308 if max_length == 0 {
309 return Err(TextError::InvalidInput(
310 "max_length must be > 0".to_string(),
311 ));
312 }
313
314 let cls_id = self
315 .vocab
316 .get(Self::CLS_TOKEN)
317 .copied()
318 .unwrap_or(self.unk_id);
319 let sep_id = self
320 .vocab
321 .get(Self::SEP_TOKEN)
322 .copied()
323 .unwrap_or(self.unk_id);
324 let pad_id = self
325 .vocab
326 .get(Self::PAD_TOKEN)
327 .copied()
328 .unwrap_or(self.unk_id);
329
330 let token_ids = self.tokenize(text);
331
332 let reserve = if add_special_tokens { 2 } else { 0 };
334 let content_budget = max_length.saturating_sub(reserve);
335 let truncated: Vec<u32> = token_ids.into_iter().take(content_budget).collect();
336
337 let mut ids: Vec<u32> = Vec::with_capacity(max_length);
338 if add_special_tokens {
339 ids.push(cls_id);
340 }
341 ids.extend_from_slice(&truncated);
342 if add_special_tokens {
343 ids.push(sep_id);
344 }
345
346 let real_len = ids.len();
347 while ids.len() < max_length {
349 ids.push(pad_id);
350 }
351
352 let mut mask: Vec<u8> = vec![0u8; max_length];
353 for m in mask.iter_mut().take(real_len) {
354 *m = 1;
355 }
356
357 Ok((ids, mask))
358 }
359
360 pub fn vocab_size(&self) -> usize {
362 self.vocab.len()
363 }
364
365 pub fn vocab_snapshot(&self) -> HashMap<String, u32> {
369 self.vocab.clone()
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Small vocabulary shared by the tests.
    ///
    /// Ids are the positions in the list below. Entries that appear twice
    /// (`##ed`, `##er`, `##est`) keep their first position, so ids 15, 20 and
    /// 21 stay unused. Ids the tests rely on: `[PAD]` = 0, `[CLS]` = 2,
    /// `[SEP]` = 3, `##ed` = 11, `un` = 13, `##want` = 14, `low` = 16,
    /// `##er` = 17, `##est` = 18.
    fn mini_vocab() -> HashMap<String, u32> {
        let tokens = [
            "[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "he", "llo", "##llo", "world", "##world",
            "want", "##ed", "to", "un", "##want", "##ed", "low", "##er", "##est", "new", "##er",
            "##est", "h", "e", "l", "o", "w", "r", "d",
        ];
        let mut vocab = HashMap::new();
        for (index, token) in tokens.iter().enumerate() {
            vocab.entry(token.to_string()).or_insert(index as u32);
        }
        vocab
    }

    /// Tokenizer over [`mini_vocab`] with default settings.
    fn mini_tokenizer() -> WordPieceTokenizer {
        WordPieceTokenizer::from_vocab(mini_vocab())
    }

    #[test]
    fn test_basic_tokenizer_lower() {
        let tok = BasicTokenizer::new(true, false);
        let tokens = tok.tokenize("Hello, World!");
        assert_eq!(tokens, ["hello", ",", "world", "!"]);
    }

    #[test]
    fn test_basic_tokenizer_no_lower() {
        let tok = BasicTokenizer::new(false, false);
        let tokens = tok.tokenize("Hello World");
        assert_eq!(tokens, ["Hello", "World"]);
    }

    #[test]
    fn test_basic_tokenizer_empty_input() {
        let tok = BasicTokenizer::new(true, false);
        assert!(tok.tokenize("").is_empty());
        assert!(tok.tokenize("  \t\n").is_empty());
    }

    #[test]
    fn test_wordpiece_tokenize_to_strings_known() {
        let wp = mini_tokenizer();
        let tokens = wp.tokenize_to_strings("low");
        assert_eq!(tokens, ["low"], "got {:?}", tokens);
    }

    #[test]
    fn test_wordpiece_splits_into_subwords() {
        let wp = mini_tokenizer();
        assert_eq!(wp.tokenize_to_strings("lowest"), ["low", "##est"]);
        assert_eq!(
            wp.tokenize_to_strings("unwanted"),
            ["un", "##want", "##ed"]
        );
        assert_eq!(wp.tokenize("unwanted"), [13, 14, 11]);
    }

    #[test]
    fn test_wordpiece_unknown_word_is_unk() {
        let wp = mini_tokenizer();
        assert_eq!(wp.tokenize_to_strings("xyz"), ["[UNK]"]);
        assert_eq!(wp.tokenize("xyz"), [1]);
    }

    #[test]
    fn test_wordpiece_encode_length() {
        let wp = mini_tokenizer();
        let (ids, mask) = wp.encode("low", 8, true).expect("encode failed");
        // [CLS] low [SEP] followed by five [PAD]s.
        assert_eq!(ids, [2, 16, 3, 0, 0, 0, 0, 0]);
        assert_eq!(mask, [1, 1, 1, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_wordpiece_encode_truncation() {
        let wp = mini_tokenizer();
        let (ids, mask) = wp
            .encode("low low low low", 4, true)
            .expect("encode failed");
        // Only two content tokens fit between [CLS] and [SEP].
        assert_eq!(ids, [2, 16, 16, 3]);
        assert_eq!(mask, [1, 1, 1, 1]);
    }

    #[test]
    fn test_wordpiece_encode_no_special_tokens() {
        let wp = mini_tokenizer();
        let (ids, mask) = wp.encode("low", 4, false).expect("encode failed");
        assert_eq!(ids, [16, 0, 0, 0]);
        assert_eq!(mask, [1, 0, 0, 0]);
    }

    #[test]
    fn test_wordpiece_encode_rejects_zero_length() {
        let wp = mini_tokenizer();
        assert!(wp.encode("low", 0, true).is_err());
        assert!(wp.encode("low", 0, false).is_err());
    }

    #[test]
    fn test_wordpiece_decode_strips_double_hash() {
        let wp = mini_tokenizer();
        let low_id = wp.vocab["low"];
        let er_id = wp.vocab["##er"];
        let decoded = wp.decode(&[low_id, er_id]);
        assert_eq!(decoded, "lower");
    }

    #[test]
    fn test_wordpiece_decode_skips_padding() {
        let wp = mini_tokenizer();
        let low_id = wp.vocab["low"];
        let pad_id = wp.vocab["[PAD]"];
        assert_eq!(wp.decode(&[low_id, pad_id, low_id, pad_id]), "low low");
    }

    #[test]
    fn test_basic_tokenizer_punctuation_isolation() {
        let tok = BasicTokenizer::new(false, false);
        let tokens = tok.tokenize("It's fine.");
        assert_eq!(tokens, ["It", "'", "s", "fine", "."]);
    }
}