tokenizers/pre_tokenizers/unicode_scripts/
pre_tokenizer.rs

1use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script};
2use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result};
3use crate::utils::macro_rules_attribute;
4
/// Pre-tokenizer that splits text wherever the Unicode script of the
/// characters changes (e.g. between Latin and Han runs), using the
/// SentencePiece-style adjustments applied by `fixed_script` below.
/// Stateless unit struct; (de)serialization impls are generated by the
/// `impl_serde_type!` attribute macro.
#[derive(Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct UnicodeScripts;
8
9impl UnicodeScripts {
10    pub fn new() -> Self {
11        Self {}
12    }
13}
14
15impl Default for UnicodeScripts {
16    fn default() -> Self {
17        Self::new()
18    }
19}
20
21// This code exists in the Unigram default IsValidSentencePiece.
22// It could be integrated directly within `get_script` but I
23// think it's kind of tricky to see those modifications later
24// I am guessing release mode will optimize this away anyway.
25fn fixed_script(c: char) -> Script {
26    let raw_script = get_script(c);
27    if c as u32 == 0x30FC {
28        Script::Han
29    } else if c == ' ' {
30        Script::Any
31    } else {
32        match raw_script {
33            Script::Hiragana => Script::Han,
34            Script::Katakana => Script::Han,
35            script => script,
36        }
37    }
38}
39
40impl PreTokenizer for UnicodeScripts {
41    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
42        pretokenized.split(|_, normalized| {
43            let mut last_script = None;
44            let mut offset = 0;
45            let mut ranges: Vec<_> = normalized
46                .get()
47                .chars()
48                .filter_map(|c| {
49                    let script = Some(fixed_script(c));
50                    let result = if script != Some(Script::Any)
51                        && last_script != Some(Script::Any)
52                        && last_script != script
53                    {
54                        Some(offset)
55                    } else {
56                        None
57                    };
58                    offset += c.len_utf8();
59                    if script != Some(Script::Any) {
60                        last_script = script;
61                    }
62
63                    result
64                })
65                .collect();
66            ranges.push(normalized.get().len());
67            Ok(ranges
68                .windows(2)
69                .map(|item| {
70                    normalized
71                        .slice(Range::Normalized(item[0]..item[1]))
72                        .expect("NormalizedString bad split")
73                })
74                .collect::<Vec<_>>())
75        })
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OffsetReferential;
    use crate::OffsetType;

    #[test]
    fn basic() {
        let pre_tokenizer = UnicodeScripts::default();
        let mut pretokenized = PreTokenizedString::from("どこで生れ。Yes");
        pre_tokenizer.pre_tokenize(&mut pretokenized).unwrap();
        // No normalization happened, so both referentials yield the same
        // (piece, byte-offset) pairs.
        for referential in [OffsetReferential::Normalized, OffsetReferential::Original] {
            let observed: Vec<_> = pretokenized
                .get_splits(referential, OffsetType::Byte)
                .into_iter()
                .map(|(piece, offsets, _)| (piece, offsets))
                .collect();
            assert_eq!(
                observed,
                vec![("どこで生れ", (0, 15)), ("。", (15, 18)), ("Yes", (18, 21))]
            );
        }
    }

    #[test]
    fn spaces_are_included_in_every_script() {
        let pre_tokenizer = UnicodeScripts::default();
        let mut pretokenized = PreTokenizedString::from("Apples are りんご 林檎");
        pre_tokenizer.pre_tokenize(&mut pretokenized).unwrap();
        // Spaces are script-neutral: each one stays glued to the run it sits in.
        for referential in [OffsetReferential::Normalized, OffsetReferential::Original] {
            let observed: Vec<_> = pretokenized
                .get_splits(referential, OffsetType::Byte)
                .into_iter()
                .map(|(piece, offsets, _)| (piece, offsets))
                .collect();
            assert_eq!(
                observed,
                vec![("Apples are ", (0, 11)), ("りんご 林檎", (11, 27))]
            );
        }
    }

    #[test]
    fn test_unicode_script() {
        // Han proper, kana, and the long-vowel mark all collapse into Han.
        for c in ['京', '太', 'い', 'グ', 'ー'] {
            assert_eq!(fixed_script(c), Script::Han);
        }
        for c in ['a', 'A'] {
            assert_eq!(fixed_script(c), Script::Latin);
        }
        // Digits and punctuation stay in the Common script.
        for c in ['0', '$', '@', '-'] {
            assert_eq!(fixed_script(c), Script::Common);
        }
        // The space alone maps to the catch-all `Any` script.
        assert_eq!(fixed_script(' '), Script::Any);
    }
}
146}