Skip to main content

trustformers_tokenizers/
simd.rs

1#[cfg(target_arch = "x86_64")]
2use std::arch::x86_64::*;
3use trustformers_core::errors::{Result, TrustformersError};
4
/// SIMD-optimized tokenization utilities for improved performance.
///
/// Holds a 256-entry byte classification table built by [`SimdTokenizer::new`].
#[derive(Debug, Clone)]
pub struct SimdTokenizer {
    /// Per-byte classification flags, indexed by byte value
    /// (bit 0: alphabetic, bit 1: numeric, bit 2: whitespace, bit 3: punctuation).
    ascii_lookup: [u8; 256],
}
10
11impl SimdTokenizer {
12    /// Create a new SIMD tokenizer
13    pub fn new() -> Self {
14        let mut ascii_lookup = [0u8; 256];
15
16        // Set up character classification flags
17        // Bit 0: alphabetic, Bit 1: numeric, Bit 2: whitespace, Bit 3: punctuation
18        for (i, flags_ref) in ascii_lookup.iter_mut().enumerate() {
19            let ch = i as u8 as char;
20            let mut flags = 0u8;
21
22            if ch.is_alphabetic() {
23                flags |= 1;
24            }
25            if ch.is_numeric() {
26                flags |= 2;
27            }
28            if ch.is_whitespace() {
29                flags |= 4;
30            }
31            if ch.is_ascii_punctuation() {
32                flags |= 8;
33            }
34
35            *flags_ref = flags;
36        }
37
38        Self { ascii_lookup }
39    }
40
41    /// Fast ASCII character classification using SIMD
42    #[cfg(target_arch = "x86_64")]
43    pub fn classify_ascii_chars(&self, text: &[u8]) -> Vec<u8> {
44        if !is_x86_feature_detected!("avx2") {
45            return self.classify_ascii_chars_scalar(text);
46        }
47
48        unsafe { self.classify_ascii_chars_avx2(text) }
49    }
50
51    /// Fast ASCII character classification - fallback for non-x86_64
52    #[cfg(not(target_arch = "x86_64"))]
53    pub fn classify_ascii_chars(&self, text: &[u8]) -> Vec<u8> {
54        self.classify_ascii_chars_scalar(text)
55    }
56
57    /// Fallback scalar implementation for character classification
58    fn classify_ascii_chars_scalar(&self, text: &[u8]) -> Vec<u8> {
59        text.iter().map(|&byte| self.ascii_lookup[byte as usize]).collect()
60    }
61
62    /// AVX2-optimized character classification
63    #[cfg(target_arch = "x86_64")]
64    #[target_feature(enable = "avx2")]
65    unsafe fn classify_ascii_chars_avx2(&self, text: &[u8]) -> Vec<u8> {
66        let mut result = Vec::with_capacity(text.len());
67        let chunks = text.chunks_exact(32);
68        let remainder = chunks.remainder();
69
70        // Process 32 bytes at a time using AVX2
71        for chunk in chunks {
72            let _input = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
73
74            // Character classification using parallel lookups
75            let mut output = [0u8; 32];
76            for i in 0..32 {
77                output[i] = self.ascii_lookup[chunk[i] as usize];
78            }
79
80            result.extend_from_slice(&output);
81        }
82
83        // Process remaining bytes
84        for &byte in remainder {
85            result.push(self.ascii_lookup[byte as usize]);
86        }
87
88        result
89    }
90
91    /// Fast whitespace detection using SIMD
92    #[cfg(target_arch = "x86_64")]
93    pub fn find_whitespace_boundaries(&self, text: &[u8]) -> Vec<usize> {
94        if !is_x86_feature_detected!("avx2") {
95            return self.find_whitespace_boundaries_scalar(text);
96        }
97
98        unsafe { self.find_whitespace_boundaries_avx2(text) }
99    }
100
101    /// Fast whitespace detection - fallback for non-x86_64
102    #[cfg(not(target_arch = "x86_64"))]
103    pub fn find_whitespace_boundaries(&self, text: &[u8]) -> Vec<usize> {
104        self.find_whitespace_boundaries_scalar(text)
105    }
106
107    /// Scalar implementation for whitespace boundary detection
108    fn find_whitespace_boundaries_scalar(&self, text: &[u8]) -> Vec<usize> {
109        let mut boundaries = Vec::new();
110        let mut in_whitespace = false;
111
112        for (i, &byte) in text.iter().enumerate() {
113            let is_whitespace = (self.ascii_lookup[byte as usize] & 4) != 0;
114
115            if is_whitespace != in_whitespace {
116                boundaries.push(i);
117                in_whitespace = is_whitespace;
118            }
119        }
120
121        boundaries
122    }
123
124    /// AVX2-optimized whitespace boundary detection
125    #[cfg(target_arch = "x86_64")]
126    #[target_feature(enable = "avx2")]
127    unsafe fn find_whitespace_boundaries_avx2(&self, text: &[u8]) -> Vec<usize> {
128        let mut boundaries = Vec::new();
129        let mut prev_whitespace_mask = 0u32;
130
131        // Define whitespace characters
132        let space = _mm256_set1_epi8(b' ' as i8);
133        let tab = _mm256_set1_epi8(b'\t' as i8);
134        let newline = _mm256_set1_epi8(b'\n' as i8);
135        let carriage_return = _mm256_set1_epi8(b'\r' as i8);
136
137        let chunks = text.chunks_exact(32);
138        let mut offset = 0;
139
140        for chunk in chunks {
141            let input = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
142
143            // Compare with whitespace characters
144            let space_mask = _mm256_cmpeq_epi8(input, space);
145            let tab_mask = _mm256_cmpeq_epi8(input, tab);
146            let newline_mask = _mm256_cmpeq_epi8(input, newline);
147            let cr_mask = _mm256_cmpeq_epi8(input, carriage_return);
148
149            // Combine all whitespace masks
150            let whitespace_mask = _mm256_or_si256(
151                _mm256_or_si256(space_mask, tab_mask),
152                _mm256_or_si256(newline_mask, cr_mask),
153            );
154
155            let mask_bits = _mm256_movemask_epi8(whitespace_mask) as u32;
156
157            // Find transitions between whitespace and non-whitespace
158            let transitions = mask_bits ^ (mask_bits << 1) ^ prev_whitespace_mask;
159
160            // Extract boundary positions
161            for i in 0..32 {
162                if (transitions & (1 << i)) != 0 {
163                    boundaries.push(offset + i);
164                }
165            }
166
167            prev_whitespace_mask = if (mask_bits & (1 << 31)) != 0 { 1 } else { 0 };
168            offset += 32;
169        }
170
171        // Process remainder with scalar code
172        let remainder = &text[offset..];
173        for (i, &byte) in remainder.iter().enumerate() {
174            let is_whitespace = matches!(byte, b' ' | b'\t' | b'\n' | b'\r');
175            let current_mask = if is_whitespace { 1 } else { 0 };
176
177            if current_mask != prev_whitespace_mask {
178                boundaries.push(offset + i);
179                prev_whitespace_mask = current_mask;
180            }
181        }
182
183        boundaries
184    }
185
186    /// Fast byte-to-UTF8 validation using SIMD
187    #[cfg(target_arch = "x86_64")]
188    pub fn validate_utf8_fast(&self, bytes: &[u8]) -> Result<()> {
189        if !is_x86_feature_detected!("avx2") {
190            return self.validate_utf8_scalar(bytes);
191        }
192
193        unsafe { self.validate_utf8_avx2(bytes) }
194    }
195
196    /// Fast byte-to-UTF8 validation - fallback for non-x86_64
197    #[cfg(not(target_arch = "x86_64"))]
198    pub fn validate_utf8_fast(&self, bytes: &[u8]) -> Result<()> {
199        self.validate_utf8_scalar(bytes)
200    }
201
202    /// Scalar UTF-8 validation
203    fn validate_utf8_scalar(&self, bytes: &[u8]) -> Result<()> {
204        std::str::from_utf8(bytes)
205            .map_err(|e| TrustformersError::invalid_input(format!("Invalid UTF-8: {}", e)))?;
206        Ok(())
207    }
208
209    /// AVX2-optimized UTF-8 validation
210    #[cfg(target_arch = "x86_64")]
211    #[target_feature(enable = "avx2")]
212    unsafe fn validate_utf8_avx2(&self, bytes: &[u8]) -> Result<()> {
213        // Simplified fast path for ASCII-only text
214        let chunks = bytes.chunks_exact(32);
215        let remainder = chunks.remainder();
216
217        for chunk in chunks {
218            let input = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
219            let ascii_mask = _mm256_cmpgt_epi8(_mm256_setzero_si256(), input);
220
221            if _mm256_movemask_epi8(ascii_mask) != 0 {
222                // Contains non-ASCII, fall back to scalar validation
223                return self.validate_utf8_scalar(bytes);
224            }
225        }
226
227        // Check remainder
228        for &byte in remainder {
229            if byte >= 128 {
230                return self.validate_utf8_scalar(bytes);
231            }
232        }
233
234        Ok(())
235    }
236
237    /// Fast case conversion using SIMD
238    #[cfg(target_arch = "x86_64")]
239    pub fn to_lowercase_ascii(&self, text: &[u8]) -> Vec<u8> {
240        if !is_x86_feature_detected!("avx2") {
241            return self.to_lowercase_ascii_scalar(text);
242        }
243
244        unsafe { self.to_lowercase_ascii_avx2(text) }
245    }
246
247    /// Fast case conversion - fallback for non-x86_64
248    #[cfg(not(target_arch = "x86_64"))]
249    pub fn to_lowercase_ascii(&self, text: &[u8]) -> Vec<u8> {
250        self.to_lowercase_ascii_scalar(text)
251    }
252
253    /// Scalar lowercase conversion
254    fn to_lowercase_ascii_scalar(&self, text: &[u8]) -> Vec<u8> {
255        text.iter()
256            .map(|&byte| {
257                if byte.is_ascii_uppercase() {
258                    byte + 32 // Convert to lowercase
259                } else {
260                    byte
261                }
262            })
263            .collect()
264    }
265
266    /// AVX2-optimized lowercase conversion
267    #[cfg(target_arch = "x86_64")]
268    #[target_feature(enable = "avx2")]
269    unsafe fn to_lowercase_ascii_avx2(&self, text: &[u8]) -> Vec<u8> {
270        let mut result = Vec::with_capacity(text.len());
271        let chunks = text.chunks_exact(32);
272        let remainder = chunks.remainder();
273
274        let a_upper = _mm256_set1_epi8(b'A' as i8);
275        let z_upper = _mm256_set1_epi8(b'Z' as i8);
276        let to_lower_offset = _mm256_set1_epi8(32);
277
278        for chunk in chunks {
279            let input = _mm256_loadu_si256(chunk.as_ptr() as *const __m256i);
280
281            // Create mask for uppercase letters
282            let ge_a = _mm256_cmpgt_epi8(input, _mm256_sub_epi8(a_upper, _mm256_set1_epi8(1)));
283            let le_z = _mm256_cmpgt_epi8(_mm256_add_epi8(z_upper, _mm256_set1_epi8(1)), input);
284            let is_upper = _mm256_and_si256(ge_a, le_z);
285
286            // Apply lowercase conversion
287            let lowercase_offset = _mm256_and_si256(is_upper, to_lower_offset);
288            let output = _mm256_add_epi8(input, lowercase_offset);
289
290            let mut temp = [0u8; 32];
291            _mm256_storeu_si256(temp.as_mut_ptr() as *mut __m256i, output);
292            result.extend_from_slice(&temp);
293        }
294
295        // Process remainder
296        for &byte in remainder {
297            let converted = if byte >= b'A' && byte <= b'Z' { byte + 32 } else { byte };
298            result.push(converted);
299        }
300
301        result
302    }
303
304    /// High-performance text preprocessing pipeline
305    pub fn preprocess_text(&self, text: &str) -> Result<Vec<String>> {
306        let bytes = text.as_bytes();
307
308        // Step 1: Validate UTF-8
309        self.validate_utf8_fast(bytes)?;
310
311        // Step 2: Find word boundaries using whitespace detection
312        let boundaries = self.find_whitespace_boundaries(bytes);
313
314        // Step 3: Extract tokens
315        let mut tokens = Vec::new();
316        let mut start = 0;
317
318        for &boundary in &boundaries {
319            if start < boundary {
320                let token_bytes = &bytes[start..boundary];
321                let token = String::from_utf8_lossy(token_bytes).into_owned();
322                if !token.trim().is_empty() {
323                    tokens.push(token);
324                }
325            }
326            start = boundary;
327        }
328
329        // Add final token if any
330        if start < bytes.len() {
331            let token_bytes = &bytes[start..];
332            let token = String::from_utf8_lossy(token_bytes).into_owned();
333            if !token.trim().is_empty() {
334                tokens.push(token);
335            }
336        }
337
338        Ok(tokens)
339    }
340}
341
342impl Default for SimdTokenizer {
343    fn default() -> Self {
344        Self::new()
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Every test starts from a freshly constructed tokenizer.
    fn tokenizer() -> SimdTokenizer {
        SimdTokenizer::new()
    }

    #[test]
    fn test_simd_character_classification() {
        let flags = tokenizer().classify_ascii_chars(b"Hello, World! 123");

        // (byte index, expected flag bit):
        // 'H' alphabetic, ',' punctuation, ' ' whitespace, '1' numeric.
        let expectations: &[(usize, u8)] = &[(0, 1), (5, 8), (6, 4), (14, 2)];
        for &(index, bit) in expectations {
            assert_eq!(flags[index] & bit, bit);
        }
    }

    #[test]
    fn test_simd_whitespace_boundaries() {
        let boundaries = tokenizer().find_whitespace_boundaries(b"Hello World Test");

        // Each space opens a whitespace run (5, 11); the following word closes it (6, 12).
        for &expected in &[5usize, 6, 11, 12] {
            assert!(boundaries.contains(&expected));
        }
    }

    #[test]
    fn test_simd_utf8_validation() {
        let tok = tokenizer();

        assert!(tok.validate_utf8_fast(b"Hello World").is_ok()); // plain ASCII
        assert!(tok.validate_utf8_fast("Hello 世界".as_bytes()).is_ok()); // multi-byte UTF-8
        assert!(tok.validate_utf8_fast(&[0xFF, 0xFE]).is_err()); // never valid UTF-8
    }

    #[test]
    fn test_simd_lowercase() {
        let lowered = tokenizer().to_lowercase_ascii(b"Hello WORLD Test");
        assert_eq!(lowered, b"hello world test");
    }

    #[test]
    fn test_simd_preprocess_pipeline() {
        let tokens = tokenizer()
            .preprocess_text("Hello, World! How are you?")
            .expect("Operation failed in test");

        assert!(!tokens.is_empty());
        for word in &["Hello,", "World!", "How"] {
            assert!(tokens.iter().any(|token| token.as_str() == *word));
        }
    }

    #[test]
    fn test_simd_empty_input() {
        let tok = tokenizer();
        let empty: &[u8] = b"";

        assert!(tok.classify_ascii_chars(empty).is_empty());
        assert!(tok.find_whitespace_boundaries(empty).is_empty());
        assert!(tok.validate_utf8_fast(empty).is_ok());
        assert!(tok.to_lowercase_ascii(empty).is_empty());
    }

    #[test]
    fn test_simd_long_input() {
        // 1000 bytes exercises many full 32-byte vectors plus a remainder.
        let upper = "A".repeat(1000);
        let lowered = tokenizer().to_lowercase_ascii(upper.as_bytes());

        assert_eq!(lowered, "a".repeat(1000).into_bytes());
    }
}