tokenizers/normalizers/
strip.rs

1use crate::tokenizer::{NormalizedString, Normalizer, Result};
2use crate::utils::macro_rules_attribute;
3use serde::{Deserialize, Serialize};
4use unicode_normalization_alignments::char::is_combining_mark;
5
6#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
7#[serde(tag = "type")]
8#[non_exhaustive]
9pub struct Strip {
10    pub strip_left: bool,
11    pub strip_right: bool,
12}
13
14impl Strip {
15    pub fn new(strip_left: bool, strip_right: bool) -> Self {
16        Self {
17            strip_left,
18            strip_right,
19        }
20    }
21}
22
23impl Normalizer for Strip {
24    /// Strip the normalized string inplace
25    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
26        if self.strip_left && self.strip_right {
27            // Fast path
28            normalized.strip();
29        } else {
30            if self.strip_left {
31                normalized.lstrip();
32            }
33
34            if self.strip_right {
35                normalized.rstrip();
36            }
37        }
38
39        Ok(())
40    }
41}
42
43// This normalizer removes combining marks from a normalized string
44// It's different from unidecode as it does not attempt to modify
45// non ascii languages.
46#[derive(Copy, Clone, Debug)]
47#[macro_rules_attribute(impl_serde_type!)]
48pub struct StripAccents;
49
50impl Normalizer for StripAccents {
51    /// Strip the normalized string inplace
52    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
53        normalized.filter(|c| !is_combining_mark(c));
54        Ok(())
55    }
56}
57
#[cfg(test)]
mod tests {
    use super::*;
    use crate::normalizer::NormalizedString;
    use crate::normalizers::Lowercase;
    use crate::normalizers::NFKD;
    use unicode_normalization_alignments::UnicodeNormalization;

    /// Combining marks are removed; plain ASCII and Chinese text are unchanged.
    #[test]
    fn test_strip_accents() {
        // Decomposed input, so the accent is a separate combining char.
        let expected = "Me llamo";
        let input: String = "Me llamó".nfkd().map(|(ch, _)| ch).collect();
        assert_ne!(input, expected);
        let mut ns = NormalizedString::from(input);
        StripAccents.normalize(&mut ns).unwrap();
        assert_eq!(ns.get(), expected);

        // Regular ASCII goes through untouched.
        let expected = "Me llamo";
        let input = "Me llamo";
        assert_eq!(input, expected);
        let mut ns = NormalizedString::from(input);
        StripAccents.normalize(&mut ns).unwrap();
        assert_eq!(ns.get(), expected);

        // Chinese text is left as it is.
        let expected = "这很简单";
        let input: String = "这很简单".nfkd().map(|(ch, _)| ch).collect();
        assert_eq!(input, expected);
        let mut ns = NormalizedString::from(input);
        StripAccents.normalize(&mut ns).unwrap();
        assert_eq!(ns.get(), expected);
    }

    /// Vietnamese inputs run through NFKD, then StripAccents, then Lowercase.
    #[test]
    fn test_vietnamese_bug() {
        // Short case: the result is checked both before and after lowercasing.
        let input = "ậ…";
        let expected = "a...";
        assert_ne!(input, expected);
        let mut ns = NormalizedString::from(input);
        NFKD.normalize(&mut ns).unwrap();
        StripAccents.normalize(&mut ns).unwrap();
        assert_eq!(ns.get(), expected);
        Lowercase.normalize(&mut ns).unwrap();
        assert_eq!(ns.get(), expected);

        // Full sentence: only the final result is checked.
        let input = "Cụ thể, bạn sẽ tham gia một nhóm các giám đốc điều hành tổ chức, các nhà lãnh đạo doanh nghiệp, các học giả, chuyên gia phát triển và tình nguyện viên riêng biệt trong lĩnh vực phi lợi nhuận…";
        let expected = "cu the, ban se tham gia mot nhom cac giam đoc đieu hanh to chuc, cac nha lanh đao doanh nghiep, cac hoc gia, chuyen gia phat trien va tinh nguyen vien rieng biet trong linh vuc phi loi nhuan...";
        let mut ns = NormalizedString::from(input);
        NFKD.normalize(&mut ns).unwrap();
        StripAccents.normalize(&mut ns).unwrap();
        Lowercase.normalize(&mut ns).unwrap();
        assert_eq!(ns.get(), expected);
    }

    /// Thai input run through the same NFKD, StripAccents, Lowercase chain.
    #[test]
    fn test_thai_bug() {
        let input = "ำน\u{e49}ำ3ลำ";
        let expected = "านา3ลา";
        assert_ne!(input, expected);
        let mut ns = NormalizedString::from(input);
        NFKD.normalize(&mut ns).unwrap();
        StripAccents.normalize(&mut ns).unwrap();
        Lowercase.normalize(&mut ns).unwrap();
        assert_eq!(ns.get(), expected);
    }

    /// Several consecutive combining marks are all removed, and the
    /// alignments of the resulting string are checked.
    #[test]
    fn test_strip_accents_multiple() {
        let input = "e\u{304}\u{304}\u{304}o";
        let expected = "eo";
        assert_ne!(input, expected);
        let mut ns = NormalizedString::from(input);
        StripAccents.normalize(&mut ns).unwrap();
        assert_eq!(ns.get(), expected);

        // Same original/normalized pair with explicit alignments.
        let expected_ns = NormalizedString::new(
            String::from(input),
            String::from(expected),
            vec![(0, 1), (7, 8)],
            0,
        );
        assert_eq!(ns, expected_ns);

        assert_eq!(
            ns.alignments_original(),
            vec![
                (0, 1),
                (1, 1),
                (1, 1),
                (1, 1),
                (1, 1),
                (1, 1),
                (1, 1),
                (1, 2)
            ]
        );
    }
}