tokenizers/pre_tokenizers/
bert.rs

1use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
2use crate::utils::macro_rules_attribute;
3use unicode_categories::UnicodeCategories;
4
5fn is_bert_punc(x: char) -> bool {
6    char::is_ascii_punctuation(&x) || x.is_punctuation()
7}
8
/// Pre-tokenizer replicating BERT's basic tokenization: splits on whitespace
/// (discarding it) and isolates each punctuation character as its own split.
/// Stateless unit struct; serde support is generated by `impl_serde_type!`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct BertPreTokenizer;
12
13impl PreTokenizer for BertPreTokenizer {
14    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
15        pretokenized.split(|_, s| s.split(char::is_whitespace, SplitDelimiterBehavior::Removed))?;
16        pretokenized.split(|_, s| s.split(is_bert_punc, SplitDelimiterBehavior::Isolated))
17    }
18}
19
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{NormalizedString, OffsetReferential, OffsetType};

    #[test]
    fn basic() {
        // Whitespace must be dropped while each punctuation mark becomes
        // its own split, with byte offsets referencing the original string.
        let mut input: PreTokenizedString = "Hey friend!     How are you?!?".into();
        BertPreTokenizer.pre_tokenize(&mut input).unwrap();
        let observed: Vec<_> = input
            .get_splits(OffsetReferential::Original, OffsetType::Byte)
            .into_iter()
            .map(|(piece, offsets, _)| (piece, offsets))
            .collect();
        assert_eq!(
            observed,
            vec![
                ("Hey", (0, 3)),
                ("friend", (4, 10)),
                ("!", (10, 11)),
                ("How", (16, 19)),
                ("are", (20, 23)),
                ("you", (24, 27)),
                ("?", (27, 28)),
                ("!", (28, 29)),
                ("?", (29, 30)),
            ]
        );
    }

    #[test]
    fn chinese_chars() {
        // Pad every codepoint above U+4E00 with spaces so the whitespace
        // pass isolates each CJK character; Latin words stay intact.
        let mut normalized = NormalizedString::from("野口里佳 Noguchi Rika");
        normalized.transform(
            normalized.get().to_owned().chars().flat_map(|c| {
                match (c as usize) > 0x4E00 {
                    true => vec![(' ', 0), (c, 1), (' ', 1)],
                    false => vec![(c, 0)],
                }
            }),
            0,
        );
        let mut pretokenized: PreTokenizedString = normalized.into();
        BertPreTokenizer.pre_tokenize(&mut pretokenized).unwrap();
        let observed: Vec<_> = pretokenized
            .get_splits(OffsetReferential::Original, OffsetType::Byte)
            .into_iter()
            .map(|(piece, offsets, _)| (piece, offsets))
            .collect();
        assert_eq!(
            observed,
            vec![
                ("野", (0, 3)),
                ("口", (3, 6)),
                ("里", (6, 9)),
                ("佳", (9, 12)),
                ("Noguchi", (13, 20)),
                ("Rika", (21, 25))
            ]
        );
    }
}