// tokenizers/pre_tokenizers/unicode_scripts/pre_tokenizer.rs
use crate::pre_tokenizers::unicode_scripts::scripts::{get_script, Script};
use crate::tokenizer::{normalizer::Range, PreTokenizedString, PreTokenizer, Result};
use crate::utils::macro_rules_attribute;
4
/// Pre-tokenizer that splits text wherever the Unicode script of the
/// characters changes (as reported by [`fixed_script`]), so that runs of
/// e.g. Latin, Han, or Common characters end up in separate splits.
/// It is a stateless unit struct; serde support comes from `impl_serde_type!`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct UnicodeScripts;
8
9impl UnicodeScripts {
10 pub fn new() -> Self {
11 Self {}
12 }
13}
14
15impl Default for UnicodeScripts {
16 fn default() -> Self {
17 Self::new()
18 }
19}
20
21fn fixed_script(c: char) -> Script {
26 let raw_script = get_script(c);
27 if c as u32 == 0x30FC {
28 Script::Han
29 } else if c == ' ' {
30 Script::Any
31 } else {
32 match raw_script {
33 Script::Hiragana => Script::Han,
34 Script::Katakana => Script::Han,
35 script => script,
36 }
37 }
38}
39
40impl PreTokenizer for UnicodeScripts {
41 fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
42 pretokenized.split(|_, normalized| {
43 let mut last_script = None;
44 let mut offset = 0;
45 let mut ranges: Vec<_> = normalized
46 .get()
47 .chars()
48 .filter_map(|c| {
49 let script = Some(fixed_script(c));
50 let result = if script != Some(Script::Any)
51 && last_script != Some(Script::Any)
52 && last_script != script
53 {
54 Some(offset)
55 } else {
56 None
57 };
58 offset += c.len_utf8();
59 if script != Some(Script::Any) {
60 last_script = script;
61 }
62
63 result
64 })
65 .collect();
66 ranges.push(normalized.get().len());
67 Ok(ranges
68 .windows(2)
69 .map(|item| {
70 normalized
71 .slice(Range::Normalized(item[0]..item[1]))
72 .expect("NormalizedString bad split")
73 })
74 .collect::<Vec<_>>())
75 })
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OffsetReferential;
    use crate::OffsetType;

    /// Mixed Japanese/Latin text splits into one group per script run.
    #[test]
    fn basic() {
        let pretok = UnicodeScripts {};
        let mut pretokenized = PreTokenizedString::from("どこで生れ。Yes");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        // Hiragana and Han are folded together by `fixed_script`, so
        // "どこで生れ" stays one group; "。" (Common) and "Yes" (Latin) each
        // start a new one. Offsets are bytes: each Japanese char is 3 bytes.
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("どこで生れ", (0, 15)), ("。", (15, 18)), ("Yes", (18, 21))]
        );
        // No normalization was applied, so original-referential offsets are
        // identical to the normalized ones.
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("どこで生れ", (0, 15)), ("。", (15, 18)), ("Yes", (18, 21))]
        );
    }

    /// Spaces map to `Script::Any` and never trigger a split: they are
    /// absorbed into the surrounding script group instead.
    #[test]
    fn spaces_are_included_in_every_script() {
        let pretok = UnicodeScripts {};
        let mut pretokenized = PreTokenizedString::from("Apples are りんご 林檎");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        // The trailing space stays with the Latin group and the inner space
        // stays inside the Japanese group ("りんご 林檎" is 16 bytes: 11..27).
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("Apples are ", (0, 11)), ("りんご 林檎", (11, 27))]
        );
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("Apples are ", (0, 11)), ("りんご 林檎", (11, 27))]
        );
    }

    /// Checks the script adjustments performed by `fixed_script`.
    #[test]
    fn test_unicode_script() {
        // CJK ideographs report Han directly.
        assert_eq!(Script::Han, fixed_script('京'));
        assert_eq!(Script::Han, fixed_script('太'));
        // Hiragana ('い') and Katakana ('グ') are folded into Han.
        assert_eq!(Script::Han, fixed_script('い'));
        assert_eq!(Script::Han, fixed_script('グ'));
        // U+30FC prolonged sound mark is special-cased to Han.
        assert_eq!(Script::Han, fixed_script('ー'));
        assert_eq!(Script::Latin, fixed_script('a'));
        assert_eq!(Script::Latin, fixed_script('A'));
        // Digits and punctuation belong to the Common script.
        assert_eq!(Script::Common, fixed_script('0'));
        assert_eq!(Script::Common, fixed_script('$'));
        assert_eq!(Script::Common, fixed_script('@'));
        assert_eq!(Script::Common, fixed_script('-'));
        // The space character is special-cased to `Script::Any`.
        assert_eq!(Script::Any, fixed_script(' '));
    }
}