tokenizers/normalizers/precompiled.rs

1use crate::tokenizer::{NormalizedString, Normalizer, Result};
2pub use spm_precompiled::Precompiled;
3use std::cmp::Ordering;
4use unicode_segmentation::UnicodeSegmentation;
5
/// Appends to `transformations` the `(char, change)` entries that turn
/// `old_part` into `new_part`, in the format passed to
/// `NormalizedString::transform`.
///
/// Every char of `new_part` is first appended with a change of `0`
/// (one-for-one replacement). The entries are then adjusted by the
/// char-count difference between the two parts:
/// - `new_part` longer: the trailing `diff` appended entries become `1`
///   (insertions).
/// - `new_part` shorter: the (negative) difference is added to the last entry
///   of `transformations`. When `new_part` is empty, that entry predates this
///   call. When `transformations` is empty, the removal is not recorded at all.
fn replace(transformations: &mut Vec<(char, isize)>, old_part: &str, new_part: &str) {
    let delta = new_part.chars().count() as isize - old_part.chars().count() as isize;

    // Start by treating every new char as a plain substitution.
    transformations.extend(new_part.chars().map(|ch| (ch, 0)));

    match delta.cmp(&0) {
        Ordering::Equal => {}
        // Growth: the trailing `delta` chars just appended are insertions.
        // `delta` never exceeds the number of chars appended above.
        Ordering::Greater => {
            let first_inserted = transformations.len() - delta as usize;
            for (_, change) in &mut transformations[first_inserted..] {
                *change = 1;
            }
        }
        // Shrink: fold the removed char count into the most recent entry, if any.
        Ordering::Less => {
            if let Some(last) = transformations.last_mut() {
                last.1 += delta;
            }
        }
    }
}
32
33impl Normalizer for Precompiled {
34    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
35        let mut transformations = Vec::with_capacity(normalized.get().len());
36        // Future reader. From @Narsil.
37        // Yes, this is weird,
38        // Yes, this seems broken
39        // No, I don't know why Google did this.
40        // If you question this code, check this normalizer against
41        // XNLI database (all languages) with Unigram model against
42        // Mbart, XLMRoberta *AND* Marian. If you don't get 100% or
43        // break a single test.
44        // You don't pass.
45        let mut modified = false;
46        normalized.get().graphemes(true).for_each(|grapheme| {
47            if grapheme.len() < 6 {
48                if let Some(norm) = self.transform(grapheme) {
49                    modified = true;
50                    replace(&mut transformations, grapheme, norm);
51                    return;
52                }
53            }
54            for (char_index, c) in grapheme.char_indices() {
55                let part = &grapheme[char_index..char_index + c.len_utf8()];
56                if let Some(norm) = self.transform(part) {
57                    modified = true;
58                    replace(&mut transformations, part, norm);
59                } else {
60                    transformations.push((c, 0));
61                }
62            }
63        });
64        if modified {
65            normalized.transform(transformations, 0);
66        }
67        Ok(())
68    }
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn expansion_followed_by_removal() {
        // Simulate transformations from "™\x1eg" to "TMg":
        // "™" expands to "TM", then "\x1e" is removed entirely.
        let mut transformations = vec![];

        let mut n = NormalizedString::from("™\x1eg");
        replace(&mut transformations, "™", "TM");
        // `M` is an insertion relative to the original string.
        assert_eq!(transformations, vec![('T', 0), ('M', 1)]);

        replace(&mut transformations, "\x1e", "");
        // The removal has no char of its own, so it is folded into the last
        // emitted entry, turning `M` into a one-for-one replacement.
        assert_eq!(transformations, vec![('T', 0), ('M', 0)]);
        transformations.push(('g', 0));

        n.transform(transformations, 0);

        assert_eq!(n.get(), "TMg");
    }

    #[test]
    fn contraction_folds_into_last_char() {
        let mut transformations = vec![];
        replace(&mut transformations, "abc", "x");
        assert_eq!(transformations, vec![('x', -2)]);
    }

    #[test]
    fn leading_removal_is_not_recorded() {
        // With nothing emitted yet, there is no entry to fold a removal into.
        // Such removals have to be passed to `transform` as `initial_offset`.
        let mut transformations = vec![];
        replace(&mut transformations, "\x1e", "");
        assert!(transformations.is_empty());

        let mut n = NormalizedString::from("\x1eg");
        transformations.push(('g', 0));
        n.transform(transformations, 1);
        assert_eq!(n.get(), "g");
    }
}