tokenizers/normalizers/mod.rs

1pub mod bert;
2pub mod byte_level;
3pub mod precompiled;
4pub mod prepend;
5pub mod replace;
6pub mod strip;
7pub mod unicode;
8pub mod utils;
9pub use crate::normalizers::bert::BertNormalizer;
10pub use crate::normalizers::byte_level::ByteLevel;
11pub use crate::normalizers::precompiled::Precompiled;
12pub use crate::normalizers::prepend::Prepend;
13pub use crate::normalizers::replace::Replace;
14pub use crate::normalizers::strip::{Strip, StripAccents};
15pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD};
16pub use crate::normalizers::utils::{Lowercase, Sequence};
17use serde::{Deserialize, Deserializer, Serialize};
18
19use crate::{NormalizedString, Normalizer};
20
21/// Wrapper for known Normalizers.
22#[derive(Clone, Debug, Serialize)]
23#[serde(untagged)]
24pub enum NormalizerWrapper {
25    BertNormalizer(BertNormalizer),
26    StripNormalizer(Strip),
27    StripAccents(StripAccents),
28    NFC(NFC),
29    NFD(NFD),
30    NFKC(NFKC),
31    NFKD(NFKD),
32    Sequence(Sequence),
33    Lowercase(Lowercase),
34    Nmt(Nmt),
35    Precompiled(Precompiled),
36    Replace(Replace),
37    Prepend(Prepend),
38    ByteLevel(ByteLevel),
39}
40
41impl<'de> Deserialize<'de> for NormalizerWrapper {
42    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
43    where
44        D: Deserializer<'de>,
45    {
46        #[derive(Debug, Deserialize)]
47        pub struct Tagged {
48            #[serde(rename = "type")]
49            variant: EnumType,
50            #[serde(flatten)]
51            rest: serde_json::Value,
52        }
53        #[derive(Debug, Serialize, Deserialize)]
54        pub enum EnumType {
55            Bert,
56            Strip,
57            StripAccents,
58            NFC,
59            NFD,
60            NFKC,
61            NFKD,
62            Sequence,
63            Lowercase,
64            Nmt,
65            Precompiled,
66            Replace,
67            Prepend,
68            ByteLevel,
69        }
70
71        #[derive(Deserialize)]
72        #[serde(untagged)]
73        pub enum NormalizerHelper {
74            Tagged(Tagged),
75            Legacy(serde_json::Value),
76        }
77
78        #[derive(Deserialize)]
79        #[serde(untagged)]
80        pub enum NormalizerUntagged {
81            BertNormalizer(BertNormalizer),
82            StripNormalizer(Strip),
83            StripAccents(StripAccents),
84            NFC(NFC),
85            NFD(NFD),
86            NFKC(NFKC),
87            NFKD(NFKD),
88            Sequence(Sequence),
89            Lowercase(Lowercase),
90            Nmt(Nmt),
91            Precompiled(Precompiled),
92            Replace(Replace),
93            Prepend(Prepend),
94            ByteLevel(ByteLevel),
95        }
96
97        let helper = NormalizerHelper::deserialize(deserializer)?;
98        Ok(match helper {
99            NormalizerHelper::Tagged(model) => {
100                let mut values: serde_json::Map<String, serde_json::Value> =
101                    serde_json::from_value(model.rest).expect("Parsed values");
102                values.insert(
103                    "type".to_string(),
104                    serde_json::to_value(&model.variant).expect("Reinsert"),
105                );
106                let values = serde_json::Value::Object(values);
107                match model.variant {
108                    EnumType::Bert => NormalizerWrapper::BertNormalizer(
109                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
110                    ),
111                    EnumType::Strip => NormalizerWrapper::StripNormalizer(
112                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
113                    ),
114                    EnumType::StripAccents => NormalizerWrapper::StripAccents(
115                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
116                    ),
117                    EnumType::NFC => NormalizerWrapper::NFC(
118                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
119                    ),
120                    EnumType::NFD => NormalizerWrapper::NFD(
121                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
122                    ),
123                    EnumType::NFKC => NormalizerWrapper::NFKC(
124                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
125                    ),
126                    EnumType::NFKD => NormalizerWrapper::NFKD(
127                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
128                    ),
129                    EnumType::Sequence => NormalizerWrapper::Sequence(
130                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
131                    ),
132                    EnumType::Lowercase => NormalizerWrapper::Lowercase(
133                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
134                    ),
135                    EnumType::Nmt => NormalizerWrapper::Nmt(
136                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
137                    ),
138                    EnumType::Precompiled => NormalizerWrapper::Precompiled(
139                        serde_json::from_str(
140                            &serde_json::to_string(&values).expect("Can reserialize precompiled"),
141                        )
142                        // .map_err(serde::de::Error::custom)
143                        .expect("Precompiled"),
144                    ),
145                    EnumType::Replace => NormalizerWrapper::Replace(
146                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
147                    ),
148                    EnumType::Prepend => NormalizerWrapper::Prepend(
149                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
150                    ),
151                    EnumType::ByteLevel => NormalizerWrapper::ByteLevel(
152                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
153                    ),
154                }
155            }
156
157            NormalizerHelper::Legacy(value) => {
158                let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
159                match untagged {
160                    NormalizerUntagged::BertNormalizer(bpe) => {
161                        NormalizerWrapper::BertNormalizer(bpe)
162                    }
163                    NormalizerUntagged::StripNormalizer(bpe) => {
164                        NormalizerWrapper::StripNormalizer(bpe)
165                    }
166                    NormalizerUntagged::StripAccents(bpe) => NormalizerWrapper::StripAccents(bpe),
167                    NormalizerUntagged::NFC(bpe) => NormalizerWrapper::NFC(bpe),
168                    NormalizerUntagged::NFD(bpe) => NormalizerWrapper::NFD(bpe),
169                    NormalizerUntagged::NFKC(bpe) => NormalizerWrapper::NFKC(bpe),
170                    NormalizerUntagged::NFKD(bpe) => NormalizerWrapper::NFKD(bpe),
171                    NormalizerUntagged::Sequence(seq) => NormalizerWrapper::Sequence(seq),
172                    NormalizerUntagged::Lowercase(bpe) => NormalizerWrapper::Lowercase(bpe),
173                    NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe),
174                    NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe),
175                    NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe),
176                    NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe),
177                    NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe),
178                }
179            }
180        })
181    }
182}
183
184impl Normalizer for NormalizerWrapper {
185    fn normalize(&self, normalized: &mut NormalizedString) -> crate::Result<()> {
186        match self {
187            Self::BertNormalizer(bn) => bn.normalize(normalized),
188            Self::StripNormalizer(sn) => sn.normalize(normalized),
189            Self::StripAccents(sn) => sn.normalize(normalized),
190            Self::NFC(nfc) => nfc.normalize(normalized),
191            Self::NFD(nfd) => nfd.normalize(normalized),
192            Self::NFKC(nfkc) => nfkc.normalize(normalized),
193            Self::NFKD(nfkd) => nfkd.normalize(normalized),
194            Self::Sequence(sequence) => sequence.normalize(normalized),
195            Self::Lowercase(lc) => lc.normalize(normalized),
196            Self::Nmt(lc) => lc.normalize(normalized),
197            Self::Precompiled(lc) => lc.normalize(normalized),
198            Self::Replace(lc) => lc.normalize(normalized),
199            Self::Prepend(lc) => lc.normalize(normalized),
200            Self::ByteLevel(lc) => lc.normalize(normalized),
201        }
202    }
203}
204
205impl_enum_from!(BertNormalizer, NormalizerWrapper, BertNormalizer);
206impl_enum_from!(NFKD, NormalizerWrapper, NFKD);
207impl_enum_from!(NFKC, NormalizerWrapper, NFKC);
208impl_enum_from!(NFC, NormalizerWrapper, NFC);
209impl_enum_from!(NFD, NormalizerWrapper, NFD);
210impl_enum_from!(Strip, NormalizerWrapper, StripNormalizer);
211impl_enum_from!(StripAccents, NormalizerWrapper, StripAccents);
212impl_enum_from!(Sequence, NormalizerWrapper, Sequence);
213impl_enum_from!(Lowercase, NormalizerWrapper, Lowercase);
214impl_enum_from!(Nmt, NormalizerWrapper, Nmt);
215impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled);
216impl_enum_from!(Replace, NormalizerWrapper, Replace);
217impl_enum_from!(Prepend, NormalizerWrapper, Prepend);
218impl_enum_from!(ByteLevel, NormalizerWrapper, ByteLevel);
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Error text serde produces when no `NormalizerUntagged` variant accepts the input.
    const NO_UNTAGGED_MATCH: &str =
        "data did not match any variant of untagged enum NormalizerUntagged";

    /// Payloads without a `"type"` tag go through the legacy (untagged) path.
    #[test]
    fn normalizer_deserialization_no_type() {
        let json = r#"{"strip_left":false, "strip_right":true}"#;
        let reconstructed = serde_json::from_str::<NormalizerWrapper>(json).unwrap();
        assert!(matches!(
            reconstructed,
            NormalizerWrapper::StripNormalizer(_)
        ));

        // Fields that belong to no normalizer yield the generic untagged error.
        let json = r#"{"trim_offsets":true, "add_prefix_space":true}"#;
        let err = serde_json::from_str::<NormalizerWrapper>(json).unwrap_err();
        assert_eq!(err.to_string(), NO_UNTAGGED_MATCH);

        let json = r#"{"prepend":"a"}"#;
        let reconstructed = serde_json::from_str::<NormalizerWrapper>(json).unwrap();
        assert!(matches!(reconstructed, NormalizerWrapper::Prepend(_)));
    }

    /// Tagged payloads are routed to the named normalizer and report its specific errors.
    #[test]
    fn normalizer_deserialization() {
        let json = r#"{"type":"Sequence","normalizers":[]}"#;
        assert!(serde_json::from_str::<NormalizerWrapper>(json).is_ok());

        // The nested `{}` has no tag and matches no normalizer.
        let json = r#"{"type":"Sequence","normalizers":[{}]}"#;
        let err = serde_json::from_str::<NormalizerWrapper>(json).unwrap_err();
        assert_eq!(err.to_string(), NO_UNTAGGED_MATCH);

        let json = r#"{"replacement":"▁","prepend_scheme":"always"}"#;
        let err = serde_json::from_str::<NormalizerWrapper>(json).unwrap_err();
        assert_eq!(err.to_string(), NO_UNTAGGED_MATCH);

        // A recognised tag surfaces the concrete normalizer's own error message.
        let json = r#"{"type":"Sequence","prepend_scheme":"always"}"#;
        let err = serde_json::from_str::<NormalizerWrapper>(json).unwrap_err();
        assert_eq!(err.to_string(), "missing field `normalizers`");
    }
}