trustformers_tokenizers/
simd.rs1#[cfg(target_arch = "x86_64")]
2use std::arch::x86_64::*;
3use trustformers_core::errors::{Result, TrustformersError};
4
5pub struct SimdTokenizer {
7 ascii_lookup: [u8; 256],
9}
10
11impl SimdTokenizer {
12 pub fn new() -> Self {
14 let mut ascii_lookup = [0u8; 256];
15
16 for (i, flags_ref) in ascii_lookup.iter_mut().enumerate() {
19 let ch = i as u8 as char;
20 let mut flags = 0u8;
21
22 if ch.is_alphabetic() {
23 flags |= 1;
24 }
25 if ch.is_numeric() {
26 flags |= 2;
27 }
28 if ch.is_whitespace() {
29 flags |= 4;
30 }
31 if ch.is_ascii_punctuation() {
32 flags |= 8;
33 }
34
35 *flags_ref = flags;
36 }
37
38 Self { ascii_lookup }
39 }
40
41 #[cfg(target_arch = "x86_64")]
43 pub fn classify_ascii_chars(&self, text: &[u8]) -> Vec<u8> {
44 if !is_x86_feature_detected!("avx2") {
45 return self.classify_ascii_chars_scalar(text);
46 }
47
48 unsafe { self.classify_ascii_chars_avx2(text) }
49 }
50
51 #[cfg(not(target_arch = "x86_64"))]
53 pub fn classify_ascii_chars(&self, text: &[u8]) -> Vec<u8> {
54 self.classify_ascii_chars_scalar(text)
55 }
56
57 fn classify_ascii_chars_scalar(&self, text: &[u8]) -> Vec<u8> {
59 text.iter().map(|&byte| self.ascii_lookup[byte as usize]).collect()
60 }
61
62 #[cfg(target_arch = "x86_64")]
64 #[target_feature(enable = "avx2")]
65 unsafe fn classify_ascii_chars_avx2(&self, text: &[u8]) -> Vec<u8> {
66 let mut result = Vec::with_capacity(text.len());
67 let chunks = text.chunks_exact(32);
68 let remainder = chunks.remainder();
69
70 for chunk in chunks {
72 let _input = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
73
74 let mut output = [0u8; 32];
76 for i in 0..32 {
77 output[i] = self.ascii_lookup[chunk[i] as usize];
78 }
79
80 result.extend_from_slice(&output);
81 }
82
83 for &byte in remainder {
85 result.push(self.ascii_lookup[byte as usize]);
86 }
87
88 result
89 }
90
91 #[cfg(target_arch = "x86_64")]
93 pub fn find_whitespace_boundaries(&self, text: &[u8]) -> Vec<usize> {
94 if !is_x86_feature_detected!("avx2") {
95 return self.find_whitespace_boundaries_scalar(text);
96 }
97
98 unsafe { self.find_whitespace_boundaries_avx2(text) }
99 }
100
101 #[cfg(not(target_arch = "x86_64"))]
103 pub fn find_whitespace_boundaries(&self, text: &[u8]) -> Vec<usize> {
104 self.find_whitespace_boundaries_scalar(text)
105 }
106
107 fn find_whitespace_boundaries_scalar(&self, text: &[u8]) -> Vec<usize> {
109 let mut boundaries = Vec::new();
110 let mut in_whitespace = false;
111
112 for (i, &byte) in text.iter().enumerate() {
113 let is_whitespace = (self.ascii_lookup[byte as usize] & 4) != 0;
114
115 if is_whitespace != in_whitespace {
116 boundaries.push(i);
117 in_whitespace = is_whitespace;
118 }
119 }
120
121 boundaries
122 }
123
124 #[cfg(target_arch = "x86_64")]
126 #[target_feature(enable = "avx2")]
127 unsafe fn find_whitespace_boundaries_avx2(&self, text: &[u8]) -> Vec<usize> {
128 let mut boundaries = Vec::new();
129 let mut prev_whitespace_mask = 0u32;
130
131 let space = _mm256_set1_epi8(b' ' as i8);
133 let tab = _mm256_set1_epi8(b'\t' as i8);
134 let newline = _mm256_set1_epi8(b'\n' as i8);
135 let carriage_return = _mm256_set1_epi8(b'\r' as i8);
136
137 let chunks = text.chunks_exact(32);
138 let mut offset = 0;
139
140 for chunk in chunks {
141 let input = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
142
143 let space_mask = _mm256_cmpeq_epi8(input, space);
145 let tab_mask = _mm256_cmpeq_epi8(input, tab);
146 let newline_mask = _mm256_cmpeq_epi8(input, newline);
147 let cr_mask = _mm256_cmpeq_epi8(input, carriage_return);
148
149 let whitespace_mask = _mm256_or_si256(
151 _mm256_or_si256(space_mask, tab_mask),
152 _mm256_or_si256(newline_mask, cr_mask),
153 );
154
155 let mask_bits = _mm256_movemask_epi8(whitespace_mask) as u32;
156
157 let transitions = mask_bits ^ (mask_bits << 1) ^ prev_whitespace_mask;
159
160 for i in 0..32 {
162 if (transitions & (1 << i)) != 0 {
163 boundaries.push(offset + i);
164 }
165 }
166
167 prev_whitespace_mask = if (mask_bits & (1 << 31)) != 0 { 1 } else { 0 };
168 offset += 32;
169 }
170
171 let remainder = &text[offset..];
173 for (i, &byte) in remainder.iter().enumerate() {
174 let is_whitespace = matches!(byte, b' ' | b'\t' | b'\n' | b'\r');
175 let current_mask = if is_whitespace { 1 } else { 0 };
176
177 if current_mask != prev_whitespace_mask {
178 boundaries.push(offset + i);
179 prev_whitespace_mask = current_mask;
180 }
181 }
182
183 boundaries
184 }
185
186 #[cfg(target_arch = "x86_64")]
188 pub fn validate_utf8_fast(&self, bytes: &[u8]) -> Result<()> {
189 if !is_x86_feature_detected!("avx2") {
190 return self.validate_utf8_scalar(bytes);
191 }
192
193 unsafe { self.validate_utf8_avx2(bytes) }
194 }
195
196 #[cfg(not(target_arch = "x86_64"))]
198 pub fn validate_utf8_fast(&self, bytes: &[u8]) -> Result<()> {
199 self.validate_utf8_scalar(bytes)
200 }
201
202 fn validate_utf8_scalar(&self, bytes: &[u8]) -> Result<()> {
204 std::str::from_utf8(bytes)
205 .map_err(|e| TrustformersError::invalid_input(format!("Invalid UTF-8: {}", e)))?;
206 Ok(())
207 }
208
209 #[cfg(target_arch = "x86_64")]
211 #[target_feature(enable = "avx2")]
212 unsafe fn validate_utf8_avx2(&self, bytes: &[u8]) -> Result<()> {
213 let chunks = bytes.chunks_exact(32);
215 let remainder = chunks.remainder();
216
217 for chunk in chunks {
218 let input = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
219 let ascii_mask = _mm256_cmpgt_epi8(_mm256_setzero_si256(), input);
220
221 if _mm256_movemask_epi8(ascii_mask) != 0 {
222 return self.validate_utf8_scalar(bytes);
224 }
225 }
226
227 for &byte in remainder {
229 if byte >= 128 {
230 return self.validate_utf8_scalar(bytes);
231 }
232 }
233
234 Ok(())
235 }
236
237 #[cfg(target_arch = "x86_64")]
239 pub fn to_lowercase_ascii(&self, text: &[u8]) -> Vec<u8> {
240 if !is_x86_feature_detected!("avx2") {
241 return self.to_lowercase_ascii_scalar(text);
242 }
243
244 unsafe { self.to_lowercase_ascii_avx2(text) }
245 }
246
247 #[cfg(not(target_arch = "x86_64"))]
249 pub fn to_lowercase_ascii(&self, text: &[u8]) -> Vec<u8> {
250 self.to_lowercase_ascii_scalar(text)
251 }
252
253 fn to_lowercase_ascii_scalar(&self, text: &[u8]) -> Vec<u8> {
255 text.iter()
256 .map(|&byte| {
257 if byte.is_ascii_uppercase() {
258 byte + 32 } else {
260 byte
261 }
262 })
263 .collect()
264 }
265
266 #[cfg(target_arch = "x86_64")]
268 #[target_feature(enable = "avx2")]
269 unsafe fn to_lowercase_ascii_avx2(&self, text: &[u8]) -> Vec<u8> {
270 let mut result = Vec::with_capacity(text.len());
271 let chunks = text.chunks_exact(32);
272 let remainder = chunks.remainder();
273
274 let a_upper = _mm256_set1_epi8(b'A' as i8);
275 let z_upper = _mm256_set1_epi8(b'Z' as i8);
276 let to_lower_offset = _mm256_set1_epi8(32);
277
278 for chunk in chunks {
279 let input = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
280
281 let ge_a = _mm256_cmpgt_epi8(input, _mm256_sub_epi8(a_upper, _mm256_set1_epi8(1)));
283 let le_z = _mm256_cmpgt_epi8(_mm256_add_epi8(z_upper, _mm256_set1_epi8(1)), input);
284 let is_upper = _mm256_and_si256(ge_a, le_z);
285
286 let lowercase_offset = _mm256_and_si256(is_upper, to_lower_offset);
288 let output = _mm256_add_epi8(input, lowercase_offset);
289
290 let mut temp = [0u8; 32];
291 _mm256_storeu_si256(temp.as_mut_ptr() as *mut __m256i, output);
292 result.extend_from_slice(&temp);
293 }
294
295 for &byte in remainder {
297 let converted = if byte >= b'A' && byte <= b'Z' { byte + 32 } else { byte };
298 result.push(converted);
299 }
300
301 result
302 }
303
304 pub fn preprocess_text(&self, text: &str) -> Result<Vec<String>> {
306 let bytes = text.as_bytes();
307
308 self.validate_utf8_fast(bytes)?;
310
311 let boundaries = self.find_whitespace_boundaries(bytes);
313
314 let mut tokens = Vec::new();
316 let mut start = 0;
317
318 for &boundary in &boundaries {
319 if start < boundary {
320 let token_bytes = &bytes[start..boundary];
321 let token = String::from_utf8_lossy(token_bytes).into_owned();
322 if !token.trim().is_empty() {
323 tokens.push(token);
324 }
325 }
326 start = boundary;
327 }
328
329 if start < bytes.len() {
331 let token_bytes = &bytes[start..];
332 let token = String::from_utf8_lossy(token_bytes).into_owned();
333 if !token.trim().is_empty() {
334 tokens.push(token);
335 }
336 }
337
338 Ok(tokens)
339 }
340}
341
342impl Default for SimdTokenizer {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Flags for a letter, punctuation mark, space and digit in a short input.
    #[test]
    fn test_simd_character_classification() {
        let tokenizer = SimdTokenizer::new();
        let text = b"Hello, World! 123";

        let classifications = tokenizer.classify_ascii_chars(text);

        // 'H' is alphabetic.
        assert_eq!(classifications[0] & 1, 1);
        // ',' is punctuation.
        assert_eq!(classifications[5] & 8, 8);
        // ' ' is whitespace.
        assert_eq!(classifications[6] & 4, 4);
        // '1' is numeric.
        assert_eq!(classifications[14] & 2, 2);
    }

    /// Classification of an input longer than one 32-byte chunk.
    #[test]
    fn test_simd_character_classification_long_input() {
        let tokenizer = SimdTokenizer::new();
        let text = "a1 ,".repeat(20);

        let classifications = tokenizer.classify_ascii_chars(text.as_bytes());

        assert_eq!(classifications.len(), 80);
        for group in classifications.chunks(4) {
            assert_eq!(group, &[1u8, 2, 4, 8][..]);
        }
    }

    /// Both edges of each space are reported as boundaries.
    #[test]
    fn test_simd_whitespace_boundaries() {
        let tokenizer = SimdTokenizer::new();
        let text = b"Hello World Test";

        let boundaries = tokenizer.find_whitespace_boundaries(text);

        assert!(boundaries.contains(&5));
        assert!(boundaries.contains(&6));
        assert!(boundaries.contains(&11));
        assert!(boundaries.contains(&12));
    }

    /// Transitions that straddle 32-byte chunk edges are reported exactly once.
    #[test]
    fn test_simd_whitespace_boundaries_across_chunks() {
        let tokenizer = SimdTokenizer::new();
        // Spaces sit at indices 31 and 63, the last byte of each full chunk.
        let text = format!("{} {} c", "a".repeat(31), "b".repeat(31));
        assert_eq!(text.len(), 65);

        let boundaries = tokenizer.find_whitespace_boundaries(text.as_bytes());

        assert_eq!(boundaries, vec![31, 32, 63, 64]);
    }

    /// ASCII, multi-byte and invalid inputs.
    #[test]
    fn test_simd_utf8_validation() {
        let tokenizer = SimdTokenizer::new();

        assert!(tokenizer.validate_utf8_fast(b"Hello World").is_ok());
        assert!(tokenizer.validate_utf8_fast("Hello 世界".as_bytes()).is_ok());
        assert!(tokenizer.validate_utf8_fast(&[0xFF, 0xFE]).is_err());
    }

    /// Non-ASCII bytes that appear only after a full ASCII chunk.
    #[test]
    fn test_simd_utf8_validation_long_input() {
        let tokenizer = SimdTokenizer::new();

        let valid = format!("{}世界", "a".repeat(40));
        assert!(tokenizer.validate_utf8_fast(valid.as_bytes()).is_ok());

        let mut invalid = vec![b'a'; 40];
        invalid.push(0xFF);
        assert!(tokenizer.validate_utf8_fast(&invalid).is_err());
    }

    /// Mixed-case short input.
    #[test]
    fn test_simd_lowercase() {
        let tokenizer = SimdTokenizer::new();
        let text = b"Hello WORLD Test";

        let lowercase = tokenizer.to_lowercase_ascii(text);
        let expected = b"hello world test";

        assert_eq!(lowercase, expected);
    }

    /// Bytes adjacent to the `A-Z` range must stay unchanged in long inputs.
    #[test]
    fn test_simd_lowercase_range_edges_long_input() {
        let tokenizer = SimdTokenizer::new();
        let text = "Hello WORLD 123 [test]@` ".repeat(5);

        let lowercase = tokenizer.to_lowercase_ascii(text.as_bytes());
        let expected = text.to_ascii_lowercase();

        assert_eq!(lowercase, expected.as_bytes());
    }

    /// End-to-end tokenisation of a short sentence.
    #[test]
    fn test_simd_preprocess_pipeline() {
        let tokenizer = SimdTokenizer::new();
        let text = "Hello, World! How are you?";

        let tokens = tokenizer.preprocess_text(text).expect("Operation failed in test");

        assert!(!tokens.is_empty());
        assert!(tokens.contains(&"Hello,".to_string()));
        assert!(tokens.contains(&"World!".to_string()));
        assert!(tokens.contains(&"How".to_string()));
    }

    /// Every entry point accepts empty input.
    #[test]
    fn test_simd_empty_input() {
        let tokenizer = SimdTokenizer::new();

        assert_eq!(tokenizer.classify_ascii_chars(b"").len(), 0);
        assert_eq!(tokenizer.find_whitespace_boundaries(b"").len(), 0);
        assert!(tokenizer.validate_utf8_fast(b"").is_ok());
        assert_eq!(tokenizer.to_lowercase_ascii(b"").len(), 0);
    }

    /// Lowercasing an input that spans many chunks plus a remainder.
    #[test]
    fn test_simd_long_input() {
        let tokenizer = SimdTokenizer::new();
        let text = "A".repeat(1000);

        let lowercase = tokenizer.to_lowercase_ascii(text.as_bytes());
        let expected = "a".repeat(1000);

        assert_eq!(lowercase, expected.as_bytes());
    }
}