1pub mod bert;
2pub mod byte_level;
3pub mod precompiled;
4pub mod prepend;
5pub mod replace;
6pub mod strip;
7pub mod unicode;
8pub mod utils;
9pub use crate::normalizers::bert::BertNormalizer;
10pub use crate::normalizers::byte_level::ByteLevel;
11pub use crate::normalizers::precompiled::Precompiled;
12pub use crate::normalizers::prepend::Prepend;
13pub use crate::normalizers::replace::Replace;
14pub use crate::normalizers::strip::{Strip, StripAccents};
15pub use crate::normalizers::unicode::{Nmt, NFC, NFD, NFKC, NFKD};
16pub use crate::normalizers::utils::{Lowercase, Sequence};
17use serde::{Deserialize, Deserializer, Serialize};
18
19use crate::{NormalizedString, Normalizer};
20
21#[derive(Clone, Debug, Serialize)]
23#[serde(untagged)]
24pub enum NormalizerWrapper {
25 BertNormalizer(BertNormalizer),
26 StripNormalizer(Strip),
27 StripAccents(StripAccents),
28 NFC(NFC),
29 NFD(NFD),
30 NFKC(NFKC),
31 NFKD(NFKD),
32 Sequence(Sequence),
33 Lowercase(Lowercase),
34 Nmt(Nmt),
35 Precompiled(Precompiled),
36 Replace(Replace),
37 Prepend(Prepend),
38 ByteLevel(ByteLevel),
39}
40
41impl<'de> Deserialize<'de> for NormalizerWrapper {
42 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
43 where
44 D: Deserializer<'de>,
45 {
46 #[derive(Debug, Deserialize)]
47 pub struct Tagged {
48 #[serde(rename = "type")]
49 variant: EnumType,
50 #[serde(flatten)]
51 rest: serde_json::Value,
52 }
53 #[derive(Debug, Serialize, Deserialize)]
54 pub enum EnumType {
55 Bert,
56 Strip,
57 StripAccents,
58 NFC,
59 NFD,
60 NFKC,
61 NFKD,
62 Sequence,
63 Lowercase,
64 Nmt,
65 Precompiled,
66 Replace,
67 Prepend,
68 ByteLevel,
69 }
70
71 #[derive(Deserialize)]
72 #[serde(untagged)]
73 pub enum NormalizerHelper {
74 Tagged(Tagged),
75 Legacy(serde_json::Value),
76 }
77
78 #[derive(Deserialize)]
79 #[serde(untagged)]
80 pub enum NormalizerUntagged {
81 BertNormalizer(BertNormalizer),
82 StripNormalizer(Strip),
83 StripAccents(StripAccents),
84 NFC(NFC),
85 NFD(NFD),
86 NFKC(NFKC),
87 NFKD(NFKD),
88 Sequence(Sequence),
89 Lowercase(Lowercase),
90 Nmt(Nmt),
91 Precompiled(Precompiled),
92 Replace(Replace),
93 Prepend(Prepend),
94 ByteLevel(ByteLevel),
95 }
96
97 let helper = NormalizerHelper::deserialize(deserializer)?;
98 Ok(match helper {
99 NormalizerHelper::Tagged(model) => {
100 let mut values: serde_json::Map<String, serde_json::Value> =
101 serde_json::from_value(model.rest).expect("Parsed values");
102 values.insert(
103 "type".to_string(),
104 serde_json::to_value(&model.variant).expect("Reinsert"),
105 );
106 let values = serde_json::Value::Object(values);
107 match model.variant {
108 EnumType::Bert => NormalizerWrapper::BertNormalizer(
109 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
110 ),
111 EnumType::Strip => NormalizerWrapper::StripNormalizer(
112 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
113 ),
114 EnumType::StripAccents => NormalizerWrapper::StripAccents(
115 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
116 ),
117 EnumType::NFC => NormalizerWrapper::NFC(
118 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
119 ),
120 EnumType::NFD => NormalizerWrapper::NFD(
121 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
122 ),
123 EnumType::NFKC => NormalizerWrapper::NFKC(
124 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
125 ),
126 EnumType::NFKD => NormalizerWrapper::NFKD(
127 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
128 ),
129 EnumType::Sequence => NormalizerWrapper::Sequence(
130 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
131 ),
132 EnumType::Lowercase => NormalizerWrapper::Lowercase(
133 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
134 ),
135 EnumType::Nmt => NormalizerWrapper::Nmt(
136 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
137 ),
138 EnumType::Precompiled => NormalizerWrapper::Precompiled(
139 serde_json::from_str(
140 &serde_json::to_string(&values).expect("Can reserialize precompiled"),
141 )
142 .expect("Precompiled"),
144 ),
145 EnumType::Replace => NormalizerWrapper::Replace(
146 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
147 ),
148 EnumType::Prepend => NormalizerWrapper::Prepend(
149 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
150 ),
151 EnumType::ByteLevel => NormalizerWrapper::ByteLevel(
152 serde_json::from_value(values).map_err(serde::de::Error::custom)?,
153 ),
154 }
155 }
156
157 NormalizerHelper::Legacy(value) => {
158 let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
159 match untagged {
160 NormalizerUntagged::BertNormalizer(bpe) => {
161 NormalizerWrapper::BertNormalizer(bpe)
162 }
163 NormalizerUntagged::StripNormalizer(bpe) => {
164 NormalizerWrapper::StripNormalizer(bpe)
165 }
166 NormalizerUntagged::StripAccents(bpe) => NormalizerWrapper::StripAccents(bpe),
167 NormalizerUntagged::NFC(bpe) => NormalizerWrapper::NFC(bpe),
168 NormalizerUntagged::NFD(bpe) => NormalizerWrapper::NFD(bpe),
169 NormalizerUntagged::NFKC(bpe) => NormalizerWrapper::NFKC(bpe),
170 NormalizerUntagged::NFKD(bpe) => NormalizerWrapper::NFKD(bpe),
171 NormalizerUntagged::Sequence(seq) => NormalizerWrapper::Sequence(seq),
172 NormalizerUntagged::Lowercase(bpe) => NormalizerWrapper::Lowercase(bpe),
173 NormalizerUntagged::Nmt(bpe) => NormalizerWrapper::Nmt(bpe),
174 NormalizerUntagged::Precompiled(bpe) => NormalizerWrapper::Precompiled(bpe),
175 NormalizerUntagged::Replace(bpe) => NormalizerWrapper::Replace(bpe),
176 NormalizerUntagged::Prepend(bpe) => NormalizerWrapper::Prepend(bpe),
177 NormalizerUntagged::ByteLevel(bpe) => NormalizerWrapper::ByteLevel(bpe),
178 }
179 }
180 })
181 }
182}
183
184impl Normalizer for NormalizerWrapper {
185 fn normalize(&self, normalized: &mut NormalizedString) -> crate::Result<()> {
186 match self {
187 Self::BertNormalizer(bn) => bn.normalize(normalized),
188 Self::StripNormalizer(sn) => sn.normalize(normalized),
189 Self::StripAccents(sn) => sn.normalize(normalized),
190 Self::NFC(nfc) => nfc.normalize(normalized),
191 Self::NFD(nfd) => nfd.normalize(normalized),
192 Self::NFKC(nfkc) => nfkc.normalize(normalized),
193 Self::NFKD(nfkd) => nfkd.normalize(normalized),
194 Self::Sequence(sequence) => sequence.normalize(normalized),
195 Self::Lowercase(lc) => lc.normalize(normalized),
196 Self::Nmt(lc) => lc.normalize(normalized),
197 Self::Precompiled(lc) => lc.normalize(normalized),
198 Self::Replace(lc) => lc.normalize(normalized),
199 Self::Prepend(lc) => lc.normalize(normalized),
200 Self::ByteLevel(lc) => lc.normalize(normalized),
201 }
202 }
203}
204
205impl_enum_from!(BertNormalizer, NormalizerWrapper, BertNormalizer);
206impl_enum_from!(NFKD, NormalizerWrapper, NFKD);
207impl_enum_from!(NFKC, NormalizerWrapper, NFKC);
208impl_enum_from!(NFC, NormalizerWrapper, NFC);
209impl_enum_from!(NFD, NormalizerWrapper, NFD);
210impl_enum_from!(Strip, NormalizerWrapper, StripNormalizer);
211impl_enum_from!(StripAccents, NormalizerWrapper, StripAccents);
212impl_enum_from!(Sequence, NormalizerWrapper, Sequence);
213impl_enum_from!(Lowercase, NormalizerWrapper, Lowercase);
214impl_enum_from!(Nmt, NormalizerWrapper, Nmt);
215impl_enum_from!(Precompiled, NormalizerWrapper, Precompiled);
216impl_enum_from!(Replace, NormalizerWrapper, Replace);
217impl_enum_from!(Prepend, NormalizerWrapper, Prepend);
218impl_enum_from!(ByteLevel, NormalizerWrapper, ByteLevel);
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Message reported when input without a usable `"type"` matches no normalizer.
    const NO_UNTAGGED_MATCH: &str =
        "data did not match any variant of untagged enum NormalizerUntagged";

    /// Objects without a `"type"` field go through the legacy (untagged) path.
    #[test]
    fn post_processor_deserialization_no_type() {
        // Fields of `Strip` select the `StripNormalizer` variant.
        let strip = serde_json::from_str::<NormalizerWrapper>(
            r#"{"strip_left":false, "strip_right":true}"#,
        );
        assert!(matches!(
            strip.unwrap(),
            NormalizerWrapper::StripNormalizer(_)
        ));

        // These fields match no normalizer, so every untagged variant is rejected.
        let unknown = serde_json::from_str::<NormalizerWrapper>(
            r#"{"trim_offsets":true, "add_prefix_space":true}"#,
        );
        if let Err(err) = unknown {
            assert_eq!(format!("{err}"), NO_UNTAGGED_MATCH);
        } else {
            panic!("Expected an error here");
        }

        // A lone `prepend` field selects the `Prepend` variant.
        let prepend = serde_json::from_str::<NormalizerWrapper>(r#"{"prepend":"a"}"#);
        assert!(matches!(prepend.unwrap(), NormalizerWrapper::Prepend(_)));
    }

    /// Tagged input is deserialized as exactly the normalizer named by `"type"`.
    #[test]
    fn normalizer_serialization() {
        // An empty sequence is valid.
        let empty_sequence =
            serde_json::from_str::<NormalizerWrapper>(r#"{"type":"Sequence","normalizers":[]}"#);
        assert!(empty_sequence.is_ok());

        // The nested `{}` has no type and matches no untagged variant.
        let nested_empty =
            serde_json::from_str::<NormalizerWrapper>(r#"{"type":"Sequence","normalizers":[{}]}"#);
        if let Err(err) = nested_empty {
            assert_eq!(err.to_string(), NO_UNTAGGED_MATCH);
        } else {
            panic!("Expected error");
        }

        // Untyped object whose fields match no normalizer.
        let not_a_normalizer = serde_json::from_str::<NormalizerWrapper>(
            r#"{"replacement":"▁","prepend_scheme":"always"}"#,
        );
        if let Err(err) = not_a_normalizer {
            assert_eq!(err.to_string(), NO_UNTAGGED_MATCH);
        } else {
            panic!("Expected error");
        }

        // With an explicit type, the error names the field that type is missing.
        let missing_field = serde_json::from_str::<NormalizerWrapper>(
            r#"{"type":"Sequence","prepend_scheme":"always"}"#,
        );
        if let Err(err) = missing_field {
            assert_eq!(err.to_string(), "missing field `normalizers`");
        } else {
            panic!("Expected error");
        }
    }
}