1use crate::processors::byte_level::bytes_char;
2use crate::tokenizer::{NormalizedString, Normalizer, Result};
3use crate::utils::macro_rules_attribute;
4use std::collections::{HashMap, HashSet};
5
/// Normalizer that replaces every byte of the input text with one character taken
/// from the byte-to-char table returned by `bytes_char()`.
///
/// The struct carries no configuration; all state lives in the module-level tables.
#[derive(Clone, Debug)]
// NOTE(review): `impl_serde_type!` presumably generates the serde impls for this
// type — confirm in `crate::utils` before changing the struct's shape.
#[macro_rules_attribute(impl_serde_type!)]
pub struct ByteLevel;
9
10lazy_static! {
11 static ref BYTES_CHAR: HashMap<u8, char> = bytes_char();
12 static ref CHAR_BYTES: HashMap<char, u8> =
13 bytes_char().into_iter().map(|(c, b)| (b, c)).collect();
14}
15
16impl Default for ByteLevel {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl ByteLevel {
23 pub fn new() -> Self {
24 Self {}
25 }
26
27 pub fn alphabet() -> HashSet<char> {
28 BYTES_CHAR.values().copied().collect()
29 }
30}
31
32impl Normalizer for ByteLevel {
33 fn normalize(&self, normalized: &mut NormalizedString) -> Result<()> {
35 if !normalized.is_empty() {
36 let s = normalized.get();
37 let mut transformations: Vec<(char, isize)> = Vec::with_capacity(s.len());
38 let mut i = 0;
39 for cur_char in s.chars() {
40 let size = cur_char.len_utf8();
41 let bytes = s[i..i + size].as_bytes();
42 i += size;
43 transformations.extend(
44 bytes
45 .iter()
46 .enumerate()
47 .map(|(i, b)| (BYTES_CHAR[b], isize::from(i > 0))),
48 );
49 }
50 normalized.transform(transformations, 0);
51 }
52 Ok(())
53 }
54}
55
#[cfg(test)]
mod tests {

    use super::*;

    /// Checks the byte-level output string and both alignment directions for a
    /// mixed ASCII / CJK input.
    #[test]
    fn test_byte_level_normalize() {
        let original = "Hello 我今天能为你做什么";
        // Each original byte becomes one character:
        // - "Hello" maps to itself.
        // - The space (0x20) maps to 'Ġ'.
        // - Each CJK character (3 UTF-8 bytes) maps to 3 characters,
        //   e.g. 天 = E5 A4 A9 -> "å¤©" and 我 = E6 88 91 -> "æĪĳ".
        let normalized = "HelloĠæĪĳä»Ĭå¤©èĥ½ä¸ºä½łåģļä»Ģä¹Ī";
        assert_ne!(original, normalized);
        let mut n = NormalizedString::from(original);
        let byte_level = ByteLevel::new();
        byte_level.normalize(&mut n).unwrap();
        assert_eq!(&n.get(), &normalized);
        // One entry per *normalized* byte, giving the original byte range it came from.
        // Every non-ASCII output char here is 2 bytes in UTF-8, so the space yields
        // 2 entries and each 3-byte CJK char yields 3 chars * 2 bytes = 6 entries.
        assert_eq!(
            n,
            NormalizedString::new(
                original.to_string(),
                normalized.to_string(),
                vec![
                    (0, 1),
                    (1, 2),
                    (2, 3),
                    (3, 4),
                    (4, 5),
                    (5, 6),
                    (5, 6),
                    (6, 9),
                    (6, 9),
                    (6, 9),
                    (6, 9),
                    (6, 9),
                    (6, 9),
                    (9, 12),
                    (9, 12),
                    (9, 12),
                    (9, 12),
                    (9, 12),
                    (9, 12),
                    (12, 15),
                    (12, 15),
                    (12, 15),
                    (12, 15),
                    (12, 15),
                    (12, 15),
                    (15, 18),
                    (15, 18),
                    (15, 18),
                    (15, 18),
                    (15, 18),
                    (15, 18),
                    (18, 21),
                    (18, 21),
                    (18, 21),
                    (18, 21),
                    (18, 21),
                    (18, 21),
                    (21, 24),
                    (21, 24),
                    (21, 24),
                    (21, 24),
                    (21, 24),
                    (21, 24),
                    (24, 27),
                    (24, 27),
                    (24, 27),
                    (24, 27),
                    (24, 27),
                    (24, 27),
                    (27, 30),
                    (27, 30),
                    (27, 30),
                    (27, 30),
                    (27, 30),
                    (27, 30),
                    (30, 33),
                    (30, 33),
                    (30, 33),
                    (30, 33),
                    (30, 33),
                    (30, 33)
                ],
                0
            )
        );
        // One entry per *original* byte, giving the normalized byte range it maps to:
        // the space covers 'Ġ' (2 bytes), and each byte of a CJK char covers the
        // 6 bytes produced for that whole character.
        assert_eq!(
            n.alignments_original(),
            vec![
                (0, 1),
                (1, 2),
                (2, 3),
                (3, 4),
                (4, 5),
                (5, 7),
                (7, 13),
                (7, 13),
                (7, 13),
                (13, 19),
                (13, 19),
                (13, 19),
                (19, 25),
                (19, 25),
                (19, 25),
                (25, 31),
                (25, 31),
                (25, 31),
                (31, 37),
                (31, 37),
                (31, 37),
                (37, 43),
                (37, 43),
                (37, 43),
                (43, 49),
                (43, 49),
                (43, 49),
                (49, 55),
                (49, 55),
                (49, 55),
                (55, 61),
                (55, 61),
                (55, 61)
            ]
        );
    }
}