tokenizers/normalizers/
byte_level.rs

1use crate::processors::byte_level::bytes_char;
2use crate::tokenizer::{NormalizedString, Normalizer, Result};
3use crate::utils::macro_rules_attribute;
4use std::collections::{HashMap, HashSet};
5
6#[derive(Clone, Debug)]
7#[macro_rules_attribute(impl_serde_type!)]
8pub struct ByteLevel;
9
10lazy_static! {
11    static ref BYTES_CHAR: HashMap<u8, char> = bytes_char();
12    static ref CHAR_BYTES: HashMap<char, u8> =
13        bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
14}
15
16impl Default for ByteLevel {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl ByteLevel {
23    pub fn new() -> Self {
24        Self {}
25    }
26
27    pub fn alphabet() -> HashSet<char> {
28        BYTES_CHAR.values().copied().collect()
29    }
30}
31
32impl Normalizer for ByteLevel {
33    /// Strip the normalized string inplace
34    fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
35        if !normalized.is_empty() {
36            let s = normalized.get();
37            let mut transformations: Vec<(char, isize)> = Vec::with_capacity(s.len());
38            let mut i = 0;
39            for cur_char in s.chars() {
40                let size = cur_char.len_utf8();
41                let bytes = s[i..i + size].as_bytes();
42                i += size;
43                transformations.extend(
44                    bytes
45                        .iter()
46                        .enumerate()
47                        .map(|(i, b)| (BYTES_CHAR[b], isize::from(i > 0))),
48                );
49            }
50            normalized.transform(transformations, 0);
51        }
52        Ok(())
53    }
54}
55
#[cfg(test)]
mod tests {
    use super::*;

    /// Normalizes a mixed ASCII / CJK string and checks the rewritten text
    /// together with the alignments recorded in both directions.
    #[test]
    fn test_byte_level_normalize() {
        let original = "Hello 我今天能为你做什么";
        // One character per byte of `original`. The alignments below allot six
        // normalized bytes to every 3-byte ideograph, so each ideograph must be
        // spelled with three 2-byte characters: `天` is `å¤©` and the last
        // character for `我` is `ĳ` (U+0133), never the raw ideograph or an
        // ASCII "ij".
        let normalized = "HelloĠæĪĳä»Ĭå¤©èĥ½ä¸ºä½łåģļä»Ģä¹Ī";
        assert_ne!(original, normalized);

        let mut n = NormalizedString::from(original);
        let byte_level = ByteLevel::new();
        byte_level.normalize(&mut n).unwrap();
        assert_eq!(n.get(), normalized);

        // Normalized byte -> original byte range.
        // The five ASCII letters map one-to-one; the space turns into the
        // 2-byte `Ġ`, so two normalized bytes point at original byte 5.
        let mut alignments = vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (5, 6)];
        // Nine ideographs of 3 original bytes each: six normalized bytes share
        // every 3-byte original range, from (6, 9) up to (30, 33).
        for start in (6..33).step_by(3) {
            alignments.extend(std::iter::repeat((start, start + 3)).take(6));
        }
        assert_eq!(
            n,
            NormalizedString::new(original.to_string(), normalized.to_string(), alignments, 0)
        );

        // Original byte -> normalized byte range.
        let mut alignments_original = vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 7)];
        // The three original bytes of each ideograph all map to the same
        // 6-byte normalized range, from (7, 13) up to (55, 61).
        for start in (7..61).step_by(6) {
            alignments_original.extend(std::iter::repeat((start, start + 6)).take(3));
        }
        assert_eq!(n.alignments_original(), alignments_original);
    }
}