tokenizers/pre_tokenizers/
bert.rs1use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
2use crate::utils::macro_rules_attribute;
3use unicode_categories::UnicodeCategories;
4
5fn is_bert_punc(x: char) -> bool {
6 char::is_ascii_punctuation(&x) || x.is_punctuation()
7}
8
9#[derive(Copy, Clone, Debug, PartialEq, Eq)]
10#[macro_rules_attribute(impl_serde_type!)]
11pub struct BertPreTokenizer;
12
13impl PreTokenizer for BertPreTokenizer {
14 fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
15 pretokenized.split(|_, s| s.split(char::is_whitespace, SplitDelimiterBehavior::Removed))?;
16 pretokenized.split(|_, s| s.split(is_bert_punc, SplitDelimiterBehavior::Isolated))
17 }
18}
19
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{NormalizedString, OffsetReferential, OffsetType};

    /// Collects `(text, byte offsets into the original string)` for every split.
    fn original_byte_splits(pretokenized: &PreTokenizedString) -> Vec<(&str, (usize, usize))> {
        pretokenized
            .get_splits(OffsetReferential::Original, OffsetType::Byte)
            .into_iter()
            .map(|(s, o, _)| (s, o))
            .collect()
    }

    #[test]
    fn basic() {
        let pretok = BertPreTokenizer;
        // Five spaces between "!" and "How". The expected offsets below depend on it,
        // and it checks that a run of whitespace is removed entirely.
        let mut pretokenized: PreTokenizedString = "Hey friend!     How are you?!?".into();
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            original_byte_splits(&pretokenized),
            vec![
                ("Hey", (0, 3)),
                ("friend", (4, 10)),
                ("!", (10, 11)),
                ("How", (16, 19)),
                ("are", (20, 23)),
                ("you", (24, 27)),
                ("?", (27, 28)),
                ("!", (28, 29)),
                ("?", (29, 30)),
            ]
        );
    }

    #[test]
    fn chinese_chars() {
        let mut n = NormalizedString::from("野口里佳 Noguchi Rika");
        // Surround every char at or above U+4E00 (start of the CJK Unified Ideographs
        // block) with spaces, so the whitespace pass yields one split per ideograph.
        n.transform(
            n.get().to_owned().chars().flat_map(|c| {
                if c >= '\u{4E00}' {
                    vec![(' ', 0), (c, 1), (' ', 1)]
                } else {
                    vec![(c, 0)]
                }
            }),
            0,
        );
        let mut pretokenized: PreTokenizedString = n.into();
        let pretok = BertPreTokenizer;
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            original_byte_splits(&pretokenized),
            vec![
                ("野", (0, 3)),
                ("口", (3, 6)),
                ("里", (6, 9)),
                ("佳", (9, 12)),
                ("Noguchi", (13, 20)),
                ("Rika", (21, 25))
            ]
        );
    }
}