//! tokenizers/pre_tokenizers/mod.rs

1pub mod bert;
2pub mod byte_level;
3pub mod delimiter;
4pub mod digits;
5pub mod metaspace;
6pub mod punctuation;
7pub mod sequence;
8pub mod split;
9pub mod unicode_scripts;
10pub mod whitespace;
11
12use serde::{Deserialize, Deserializer, Serialize};
13
14use crate::pre_tokenizers::bert::BertPreTokenizer;
15use crate::pre_tokenizers::byte_level::ByteLevel;
16use crate::pre_tokenizers::delimiter::CharDelimiterSplit;
17use crate::pre_tokenizers::digits::Digits;
18use crate::pre_tokenizers::metaspace::Metaspace;
19use crate::pre_tokenizers::punctuation::Punctuation;
20use crate::pre_tokenizers::sequence::Sequence;
21use crate::pre_tokenizers::split::Split;
22use crate::pre_tokenizers::unicode_scripts::UnicodeScripts;
23use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
24use crate::{PreTokenizedString, PreTokenizer};
25
26#[derive(Serialize, Clone, Debug, PartialEq)]
27#[serde(untagged)]
28pub enum PreTokenizerWrapper {
29    BertPreTokenizer(BertPreTokenizer),
30    ByteLevel(ByteLevel),
31    Delimiter(CharDelimiterSplit),
32    Metaspace(Metaspace),
33    Whitespace(Whitespace),
34    Sequence(Sequence),
35    Split(Split),
36    Punctuation(Punctuation),
37    WhitespaceSplit(WhitespaceSplit),
38    Digits(Digits),
39    UnicodeScripts(UnicodeScripts),
40}
41
42impl PreTokenizer for PreTokenizerWrapper {
43    fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> crate::Result<()> {
44        match self {
45            Self::BertPreTokenizer(bpt) => bpt.pre_tokenize(normalized),
46            Self::ByteLevel(bpt) => bpt.pre_tokenize(normalized),
47            Self::Delimiter(dpt) => dpt.pre_tokenize(normalized),
48            Self::Metaspace(mspt) => mspt.pre_tokenize(normalized),
49            Self::Whitespace(wspt) => wspt.pre_tokenize(normalized),
50            Self::Punctuation(tok) => tok.pre_tokenize(normalized),
51            Self::Sequence(tok) => tok.pre_tokenize(normalized),
52            Self::Split(tok) => tok.pre_tokenize(normalized),
53            Self::WhitespaceSplit(wspt) => wspt.pre_tokenize(normalized),
54            Self::Digits(wspt) => wspt.pre_tokenize(normalized),
55            Self::UnicodeScripts(us) => us.pre_tokenize(normalized),
56        }
57    }
58}
59
60impl<'de> Deserialize<'de> for PreTokenizerWrapper {
61    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
62    where
63        D: Deserializer<'de>,
64    {
65        #[derive(Deserialize)]
66        pub struct Tagged {
67            #[serde(rename = "type")]
68            variant: EnumType,
69            #[serde(flatten)]
70            rest: serde_json::Value,
71        }
72        #[derive(Deserialize, Serialize)]
73        pub enum EnumType {
74            BertPreTokenizer,
75            ByteLevel,
76            Delimiter,
77            Metaspace,
78            Whitespace,
79            Sequence,
80            Split,
81            Punctuation,
82            WhitespaceSplit,
83            Digits,
84            UnicodeScripts,
85        }
86
87        #[derive(Deserialize)]
88        #[serde(untagged)]
89        pub enum PreTokenizerHelper {
90            Tagged(Tagged),
91            Legacy(serde_json::Value),
92        }
93
94        #[derive(Deserialize)]
95        #[serde(untagged)]
96        pub enum PreTokenizerUntagged {
97            BertPreTokenizer(BertPreTokenizer),
98            ByteLevel(ByteLevel),
99            Delimiter(CharDelimiterSplit),
100            Metaspace(Metaspace),
101            Whitespace(Whitespace),
102            Sequence(Sequence),
103            Split(Split),
104            Punctuation(Punctuation),
105            WhitespaceSplit(WhitespaceSplit),
106            Digits(Digits),
107            UnicodeScripts(UnicodeScripts),
108        }
109
110        let helper = PreTokenizerHelper::deserialize(deserializer)?;
111
112        Ok(match helper {
113            PreTokenizerHelper::Tagged(pretok) => {
114                let mut values: serde_json::Map<String, serde_json::Value> =
115                    serde_json::from_value(pretok.rest).map_err(serde::de::Error::custom)?;
116                values.insert(
117                    "type".to_string(),
118                    serde_json::to_value(&pretok.variant).map_err(serde::de::Error::custom)?,
119                );
120                let values = serde_json::Value::Object(values);
121                match pretok.variant {
122                    EnumType::BertPreTokenizer => PreTokenizerWrapper::BertPreTokenizer(
123                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
124                    ),
125                    EnumType::ByteLevel => PreTokenizerWrapper::ByteLevel(
126                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
127                    ),
128                    EnumType::Delimiter => PreTokenizerWrapper::Delimiter(
129                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
130                    ),
131                    EnumType::Metaspace => PreTokenizerWrapper::Metaspace(
132                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
133                    ),
134                    EnumType::Whitespace => PreTokenizerWrapper::Whitespace(
135                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
136                    ),
137                    EnumType::Sequence => PreTokenizerWrapper::Sequence(
138                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
139                    ),
140                    EnumType::Split => PreTokenizerWrapper::Split(
141                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
142                    ),
143                    EnumType::Punctuation => PreTokenizerWrapper::Punctuation(
144                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
145                    ),
146                    EnumType::WhitespaceSplit => PreTokenizerWrapper::WhitespaceSplit(
147                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
148                    ),
149                    EnumType::Digits => PreTokenizerWrapper::Digits(
150                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
151                    ),
152                    EnumType::UnicodeScripts => PreTokenizerWrapper::UnicodeScripts(
153                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
154                    ),
155                }
156            }
157
158            PreTokenizerHelper::Legacy(value) => {
159                let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
160                match untagged {
161                    PreTokenizerUntagged::BertPreTokenizer(bert) => {
162                        PreTokenizerWrapper::BertPreTokenizer(bert)
163                    }
164                    PreTokenizerUntagged::ByteLevel(byte_level) => {
165                        PreTokenizerWrapper::ByteLevel(byte_level)
166                    }
167                    PreTokenizerUntagged::Delimiter(delimiter) => {
168                        PreTokenizerWrapper::Delimiter(delimiter)
169                    }
170                    PreTokenizerUntagged::Metaspace(metaspace) => {
171                        PreTokenizerWrapper::Metaspace(metaspace)
172                    }
173                    PreTokenizerUntagged::Whitespace(whitespace) => {
174                        PreTokenizerWrapper::Whitespace(whitespace)
175                    }
176                    PreTokenizerUntagged::Sequence(sequence) => {
177                        PreTokenizerWrapper::Sequence(sequence)
178                    }
179                    PreTokenizerUntagged::Split(split) => PreTokenizerWrapper::Split(split),
180                    PreTokenizerUntagged::Punctuation(punctuation) => {
181                        PreTokenizerWrapper::Punctuation(punctuation)
182                    }
183                    PreTokenizerUntagged::WhitespaceSplit(whitespace_split) => {
184                        PreTokenizerWrapper::WhitespaceSplit(whitespace_split)
185                    }
186                    PreTokenizerUntagged::Digits(digits) => PreTokenizerWrapper::Digits(digits),
187                    PreTokenizerUntagged::UnicodeScripts(unicode_scripts) => {
188                        PreTokenizerWrapper::UnicodeScripts(unicode_scripts)
189                    }
190                }
191            }
192        })
193    }
194}
195
196impl_enum_from!(BertPreTokenizer, PreTokenizerWrapper, BertPreTokenizer);
197impl_enum_from!(ByteLevel, PreTokenizerWrapper, ByteLevel);
198impl_enum_from!(CharDelimiterSplit, PreTokenizerWrapper, Delimiter);
199impl_enum_from!(Whitespace, PreTokenizerWrapper, Whitespace);
200impl_enum_from!(Punctuation, PreTokenizerWrapper, Punctuation);
201impl_enum_from!(Sequence, PreTokenizerWrapper, Sequence);
202impl_enum_from!(Split, PreTokenizerWrapper, Split);
203impl_enum_from!(Metaspace, PreTokenizerWrapper, Metaspace);
204impl_enum_from!(WhitespaceSplit, PreTokenizerWrapper, WhitespaceSplit);
205impl_enum_from!(Digits, PreTokenizerWrapper, Digits);
206impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts);
207
#[cfg(test)]
mod tests {
    use super::metaspace::PrependScheme;
    use super::*;

    /// Tagged payloads round-trip into the expected wrapper variants,
    /// including legacy extra fields (`str_rep`) that serde must tolerate.
    #[test]
    fn test_deserialize() {
        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","str_rep":"▁","add_prefix_space":true}]}"#).unwrap();

        assert_eq!(
            pre_tokenizer,
            PreTokenizerWrapper::Sequence(Sequence::new(vec![
                PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
                PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
            ]))
        );

        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
            r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true}"#,
        )
        .unwrap();

        assert_eq!(
            pre_tokenizer,
            PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
        );

        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#).unwrap();

        assert_eq!(
            pre_tokenizer,
            PreTokenizerWrapper::Sequence(Sequence::new(vec![
                PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
                PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
            ]))
        );

        // Explicit `prepend_scheme` values override the default (`Always`).
        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
            r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true, "prepend_scheme":"first"}"#,
        )
        .unwrap();

        assert_eq!(
            pre_tokenizer,
            PreTokenizerWrapper::Metaspace(Metaspace::new(
                '▁',
                metaspace::PrependScheme::First,
                true
            ))
        );

        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
            r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true, "prepend_scheme":"always"}"#,
        )
        .unwrap();

        assert_eq!(
            pre_tokenizer,
            PreTokenizerWrapper::Metaspace(Metaspace::new(
                '▁',
                metaspace::PrependScheme::Always,
                true
            ))
        );
    }

    /// A unit-struct variant deserializes from just its tag.
    #[test]
    fn test_deserialize_whitespace_split() {
        let pre_tokenizer: PreTokenizerWrapper =
            serde_json::from_str(r#"{"type":"WhitespaceSplit"}"#).unwrap();
        assert_eq!(
            pre_tokenizer,
            PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {})
        );
    }

    /// Without a `type` tag the legacy (untagged) path runs; payloads that
    /// match no variant, or tagged payloads missing required fields, error.
    #[test]
    fn pre_tokenizer_deserialization_no_type() {
        let json = r#"{"replacement":"▁","add_prefix_space":true, "prepend_scheme":"always"}}"#;
        let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json);
        match reconstructed {
            Err(err) => assert_eq!(
                err.to_string(),
                "data did not match any variant of untagged enum PreTokenizerUntagged"
            ),
            _ => panic!("Expected an error here"),
        }

        let json = r#"{"type":"Metaspace", "replacement":"▁" }"#;
        let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json).unwrap();
        assert_eq!(
            reconstructed,
            PreTokenizerWrapper::Metaspace(Metaspace::default())
        );

        let json = r#"{"type":"Metaspace", "add_prefix_space":true }"#;
        let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json);
        match reconstructed {
            Err(err) => assert_eq!(err.to_string(), "missing field `replacement`"),
            _ => panic!("Expected an error here"),
        }
        let json = r#"{"behavior":"default_split"}"#;
        let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json);
        match reconstructed {
            Err(err) => assert_eq!(
                err.to_string(),
                "data did not match any variant of untagged enum PreTokenizerUntagged"
            ),
            _ => panic!("Expected an error here"),
        }
    }
}