Skip to main content

thread_utilities/
simd.rs

1// SPDX-FileCopyrightText: 2025 Knitli Inc. <knitli@knit.li>
2// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
3// SPDX-License-Identifier: AGPL-3.0-or-later
4//! SIMD optimized utilities for string processing.
5//!
//! This module provides SIMD-optimized functions for string processing.
//! The vector kernels are written with the `simdeez` crate, and byte searches use `memchr`,
//! which ships its own SIMD-accelerated search routines.
8//! Both libraries provide SIMD support for wasm32, `x86_64/x86`, and aarch64 and can find
9//! optimal instruction sets at runtime.
10//! If no SIMD support is available, they will fall back to scalar operations.
11
12use memchr::memmem::FinderRev;
13use simdeez::{prelude::*, simd_runtime_generate};
14use std::sync::OnceLock;
15
/// Value of the top two bits of a UTF-8 continuation byte.
///
/// Continuation bytes have the form `10xxxxxx` (`0x80..=0xBF`). Stored as `i8` so it can be
/// splatted directly into signed SIMD lanes.
const UTF_8_CONTINUATION_PATTERN: i8 = 0b1000_0000_u8 as i8;
/// Mask isolating the top two bits of a byte (`0b1100_0000`).
///
/// Despite its name, this is a *mask*, not the bit pattern of non-continuation bytes. A byte `b`
/// is a continuation byte exactly when
/// `b & NON_UTF_8_CONTINUATION_PATTERN == UTF_8_CONTINUATION_PATTERN`. Every other byte starts a
/// character: either ASCII (`0xxxxxxx`) or a multi-byte lead byte (`11xxxxxx`). Counting the
/// bytes that fail the test therefore counts characters.
const NON_UTF_8_CONTINUATION_PATTERN: i8 = 0b1100_0000_u8 as i8;
22static REV_LINE_FINDER: OnceLock<FinderRev> = OnceLock::new();
23
24// Checks if a string is all ascii.
25simd_runtime_generate!(
26    pub fn is_ascii_simd(text: &str) -> bool {
27        let bytes = text.as_bytes();
28        let len = bytes.len();
29
30        // reinterpret u8 as i8 slice (safe because underlying bits match)
31        let bytes_i8 = unsafe { std::slice::from_raw_parts(bytes.as_ptr().cast::<i8>(), len) };
32
33        let mut remainder = bytes_i8;
34
35        // Process in vector-width chunks
36        while remainder.len() >= S::Vi8::WIDTH {
37            let chunk = &remainder[..S::Vi8::WIDTH];
38            let v = S::Vi8::load_from_slice(chunk);
39
40            // For ASCII, all values must be >= 0 (since ASCII is 0..127)
41            let mask = v.cmp_lt(S::Vi8::set1(0));
42            // Check if any lane is negative (non-ASCII)
43            // get_mask() returns a bitmask, if any bit is set, it means non-ASCII was found
44            if mask.get_mask() != 0 {
45                return false;
46            }
47
48            remainder = &remainder[S::Vi8::WIDTH..];
49        }
50
51        // Handle remaining bytes
52        remainder.iter().all(|&b| b >= 0)
53    }
54);
55
56// Find the last occurrence of a byte value in a slice, searching backwards
57// Returns the index of the last occurrence, or None if not found
58simd_runtime_generate!(
59    fn find_last_byte_simd(haystack: &[u8], needle: u8, is_eol: bool) -> Option<usize> {
60        if haystack.is_empty() {
61            return None;
62        }
63        if is_eol {
64            // Special case for newline, use cached finder
65            // Use into_owned() to ensure the FinderRev outlives the reference to its needle (it doesn't need after it's constructed)
66            let line_finder =
67                REV_LINE_FINDER.get_or_init(|| FinderRev::new(&[needle]).into_owned());
68            return line_finder.rfind(haystack);
69        }
70        let bound_needle = &[needle];
71        let finder = FinderRev::new(bound_needle);
72
73        finder.rfind(haystack)
74    }
75);
76
77// Count UTF-8 characters in a byte slice using SIMD
78// This counts by identifying non-continuation bytes
79simd_runtime_generate!(
80    fn count_utf8_chars_simd(bytes: &[u8]) -> usize {
81        let len = bytes.len();
82        if len == 0 {
83            return 0;
84        }
85
86        // Convert to i8 for SIMD operations
87        let bytes_i8 = unsafe { std::slice::from_raw_parts(bytes.as_ptr().cast::<i8>(), len) };
88
89        let mut remainder = bytes_i8;
90        let mut char_count = 0;
91
92        let continuation_pattern = S::Vi8::set1(UTF_8_CONTINUATION_PATTERN);
93        let mask_pattern = S::Vi8::set1(NON_UTF_8_CONTINUATION_PATTERN);
94
95        // Process in SIMD chunks
96        while remainder.len() >= S::Vi8::WIDTH {
97            let chunk = &remainder[..S::Vi8::WIDTH];
98            let v = S::Vi8::load_from_slice(chunk);
99
100            // Check which bytes are NOT continuation bytes
101            // Continuation bytes: (byte & 0b11000000) == 0b10000000
102            let masked = v & mask_pattern;
103            let is_continuation = masked.cmp_eq(continuation_pattern);
104
105            // Count non-continuation bytes
106            let mask = is_continuation.get_mask();
107            // Count zeros in the mask (non-continuation bytes)
108            char_count += S::Vi8::WIDTH - mask.count_ones() as usize;
109
110            remainder = &remainder[S::Vi8::WIDTH..];
111        }
112
113        // Handle remaining bytes
114        for &byte in remainder {
115            if (byte as u8) & NON_UTF_8_CONTINUATION_PATTERN as u8
116                != UTF_8_CONTINUATION_PATTERN as u8
117            {
118                char_count += 1;
119            }
120        }
121
122        char_count
123    }
124);
125
126/// Optimized character column calculation with SIMD, finding the last newline character's index
127///
128/// It first checks if the line is entirely `ascii` with [`is_ascii_simd`],
129/// and if so, uses a faster search strategy with [`find_last_byte_simd`].
130/// If there are utf-8 characters present, it still uses the same approach but then
131/// must use [`count_utf8_chars_simd`] to count non-continuation bytes.
132/// All operations are highly optimized with full SIMD support.
133#[inline]
134#[must_use]
135pub fn get_char_column_simd(text: &str, offset: usize) -> usize {
136    if offset == 0 {
137        return 0;
138    }
139
140    let bytes = text.as_bytes();
141    if offset > bytes.len() {
142        return 0;
143    }
144
145    let search_slice = &bytes[..offset];
146
147    // Check if the text is ASCII for fast path
148    if is_ascii_simd(text) {
149        // ASCII fast path: find last newline and count bytes
150        match find_last_byte_simd(search_slice, b'\n', true) {
151            Some(newline_pos) => offset - newline_pos - 1,
152            None => offset, // No newline found, entire offset is the column
153        }
154    } else {
155        // UTF-8 path: find last newline then count UTF-8 characters
156        match find_last_byte_simd(search_slice, b'\n', true) {
157            Some(newline_pos) => {
158                let line_start = newline_pos + 1;
159                let line_bytes = &search_slice[line_start..];
160                count_utf8_chars_simd(line_bytes)
161            }
162            None => {
163                // No newline found, count characters from start
164                count_utf8_chars_simd(search_slice)
165            }
166        }
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar reference for `get_char_column_simd`: walks back from `offset` to the previous
    /// newline, counting non-continuation bytes.
    fn original_get_char_column(text: &str, offset: usize) -> usize {
        let src = text.as_bytes();
        let mut col = 0;
        for &b in src[..offset].iter().rev() {
            if b == b'\n' {
                break;
            }
            if b & 0b1100_0000 != 0b1000_0000 {
                col += 1;
            }
        }
        col
    }

    /// Scalar reference for `count_utf8_chars_simd`: counts non-continuation bytes.
    fn reference_utf8_count(bytes: &[u8]) -> usize {
        bytes
            .iter()
            .filter(|&&b| b & 0b1100_0000 != 0b1000_0000)
            .count()
    }

    #[test]
    fn test_empty_string() {
        assert!(is_ascii_simd(""));
    }

    #[test]
    fn test_pure_ascii() {
        assert!(is_ascii_simd("Hello, World!"));
        assert!(is_ascii_simd("123456789"));
        assert!(is_ascii_simd("ABCDEFGHIJKLMNOPQRSTUVWXYZ"));
        assert!(is_ascii_simd("abcdefghijklmnopqrstuvwxyz"));
        assert!(is_ascii_simd("!@#$%^&*()_+-=[]{}|;':\",./<>?"));
    }

    #[test]
    fn test_ascii_with_newlines_and_tabs() {
        assert!(is_ascii_simd("Hello\nWorld\t!"));
        assert!(is_ascii_simd("\t\n\r"));
    }

    #[test]
    fn test_ascii_control_characters() {
        // Test ASCII control characters (0-31, 127)
        assert!(is_ascii_simd("\x00\x01\x02\x03\x04\x05\x06\x07"));
        assert!(is_ascii_simd("\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F"));
        assert!(is_ascii_simd("\x10\x11\x12\x13\x14\x15\x16\x17"));
        assert!(is_ascii_simd("\x18\x19\x1A\x1B\x1C\x1D\x1E\x1F"));
        assert!(is_ascii_simd("\x7F")); // DEL character
    }

    #[test]
    fn test_non_ascii_characters() {
        // UTF-8 encoded non-ASCII characters
        assert!(!is_ascii_simd("café")); // contains é
        assert!(!is_ascii_simd("naïve")); // contains ï
        assert!(!is_ascii_simd("résumé")); // contains é
        assert!(!is_ascii_simd("🚀")); // emoji
        assert!(!is_ascii_simd("こんにちは")); // Japanese
        assert!(!is_ascii_simd("Привет")); // Russian
        assert!(!is_ascii_simd("مرحبا")); // Arabic
        // all together for fun
        assert!(!is_ascii_simd(
            "café مرحبا こんにちは 🚀 Привет résumé naïve"
        ));
    }

    #[test]
    fn test_mixed_ascii_non_ascii() {
        assert!(!is_ascii_simd("Hello café"));
        assert!(!is_ascii_simd("ASCII and 🚀"));
        assert!(!is_ascii_simd("test\u{200B}")); // zero-width space
    }

    #[test]
    fn test_long_ascii_strings() {
        // Test strings longer than typical SIMD vector width
        let long_ascii = "a".repeat(1000);
        assert!(is_ascii_simd(&long_ascii));

        let long_ascii_mixed = "ABC123!@#".repeat(100);
        assert!(is_ascii_simd(&long_ascii_mixed));
    }

    #[test]
    fn test_long_non_ascii_strings() {
        let long_non_ascii = "café".repeat(100);
        assert!(!is_ascii_simd(&long_non_ascii));
    }

    #[test]
    fn test_ascii_boundary_values() {
        // Test characters at ASCII boundaries
        assert!(is_ascii_simd("\x00")); // NULL (0)
        assert!(is_ascii_simd("\x7F")); // DEL (127)

        // Test non-ASCII characters (properly encoded UTF-8)
        assert!(!is_ascii_simd("ü")); // UTF-8 encoded ü (first byte is 0xC3)
        assert!(!is_ascii_simd("€")); // UTF-8 encoded € (first byte is 0xE2)
    }

    #[test]
    fn test_various_lengths() {
        // Test strings of various lengths to exercise both SIMD and scalar paths
        for i in 1..=100 {
            let ascii_string = "a".repeat(i);
            assert!(is_ascii_simd(&ascii_string), "Failed for length {i}");
        }
    }

    #[test]
    fn test_non_ascii_at_different_positions() {
        // Non-ASCII at the beginning
        assert!(!is_ascii_simd("éabc"));

        // Non-ASCII in the middle
        assert!(!is_ascii_simd("abéc"));

        // Non-ASCII at the end
        assert!(!is_ascii_simd("abcé"));

        // Multiple non-ASCII characters
        assert!(!is_ascii_simd("éabcé"));
    }

    #[test]
    fn test_non_ascii_detected_in_every_lane() {
        // Slide a single non-ASCII character through a long ASCII string so it lands in every
        // lane of the vector loop as well as in the scalar tail.
        for prefix_len in 0..100 {
            let text = format!("{}é{}", "a".repeat(prefix_len), "a".repeat(100 - prefix_len));
            assert!(
                !is_ascii_simd(&text),
                "Missed non-ASCII after {prefix_len} ASCII bytes"
            );
        }
    }

    #[test]
    fn test_consistency_with_str_is_ascii() {
        let test_strings = [
            "",
            "Hello",
            "café",
            "🚀",
            "ASCII123!@#",
            "test\u{200B}",
            "\x00\x7F",
        ];

        // Test regular strings
        for test_str in &test_strings {
            assert_eq!(
                is_ascii_simd(test_str),
                test_str.is_ascii(),
                "Mismatch for string: {test_str:?}"
            );
        }

        // Test long string separately
        let long_string = "a".repeat(1000);
        assert_eq!(
            is_ascii_simd(&long_string),
            long_string.is_ascii(),
            "Mismatch for long string"
        );

        // Test additional non-ASCII characters
        let non_ascii_chars = ["ü", "€", "漢", "🎉"];
        for ch in &non_ascii_chars {
            assert_eq!(
                is_ascii_simd(ch),
                ch.is_ascii(),
                "Mismatch for non-ASCII character: {ch:?}"
            );
        }
    }

    #[test]
    fn test_simd_vector_width_boundaries() {
        // Test strings that are exactly SIMD vector width and around those boundaries
        // Common SIMD widths are 16, 32, 64 bytes
        for width in [16, 32, 64] {
            // Exactly vector width
            let exact = "a".repeat(width);
            assert!(is_ascii_simd(&exact));

            // One less than vector width
            let one_less = "a".repeat(width - 1);
            assert!(is_ascii_simd(&one_less));

            // One more than vector width
            let one_more = "a".repeat(width + 1);
            assert!(is_ascii_simd(&one_more));

            // Non-ASCII at exact boundary
            let mut boundary_test = "a".repeat(width - 1);
            boundary_test.push('é');
            assert!(!is_ascii_simd(&boundary_test));
        }
    }

    #[test]
    fn test_all_ascii_characters() {
        // Test all valid ASCII characters (0-127)
        let mut all_ascii = String::new();
        for i in 0u8..=127 {
            all_ascii.push(i as char);
        }
        assert!(is_ascii_simd(&all_ascii));
    }

    #[test]
    fn test_short_strings() {
        // Very short inputs never reach the vector loop and exercise only the scalar tail.
        assert!(is_ascii_simd("a"));
        assert!(is_ascii_simd("aa"));
        assert!(is_ascii_simd("aaa"));

        assert!(!is_ascii_simd("é"));
    }

    // Tests for find_last_byte_simd
    #[test]
    fn test_find_last_byte_empty() {
        assert_eq!(find_last_byte_simd(&[], b'a', false), None);
    }

    #[test]
    fn test_find_last_byte_single() {
        assert_eq!(find_last_byte_simd(b"a", b'a', false), Some(0));
        assert_eq!(find_last_byte_simd(b"a", b'b', false), None);
    }

    #[test]
    fn test_find_last_byte_multiple() {
        let haystack = b"hello world hello";
        assert_eq!(find_last_byte_simd(haystack, b'l', false), Some(15)); // Last 'l'
        assert_eq!(find_last_byte_simd(haystack, b'h', false), Some(12)); // Last 'h'
        assert_eq!(find_last_byte_simd(haystack, b'o', false), Some(16)); // Last 'o'
        assert_eq!(find_last_byte_simd(haystack, b'x', false), None); // Not found
    }

    #[test]
    fn test_find_last_byte_newlines() {
        let text = b"line1\nline2\nline3";
        assert_eq!(find_last_byte_simd(text, b'\n', true), Some(11)); // Last newline

        let single_line = b"no newlines here";
        assert_eq!(find_last_byte_simd(single_line, b'\n', true), None);
    }

    #[test]
    fn test_find_last_byte_long() {
        // Test with strings longer than SIMD width
        let long_text = "a".repeat(100) + "b" + &"a".repeat(100);
        let bytes = long_text.as_bytes();
        assert_eq!(find_last_byte_simd(bytes, b'b', false), Some(100));
    }

    #[test]
    fn test_find_last_byte_newlines_long() {
        let text = "ab\n".repeat(40);
        let bytes = text.as_bytes();
        let last_newline = bytes.len() - 1;

        // The cached newline path and the generic path must agree.
        assert_eq!(find_last_byte_simd(bytes, b'\n', true), Some(last_newline));
        assert_eq!(find_last_byte_simd(bytes, b'\n', false), Some(last_newline));
        assert_eq!(find_last_byte_simd(bytes, b'b', false), Some(last_newline - 1));

        let trailing = format!("{text}tail");
        assert_eq!(
            find_last_byte_simd(trailing.as_bytes(), b'\n', true),
            Some(last_newline)
        );
    }

    // Tests for count_utf8_chars_simd
    #[test]
    fn test_count_utf8_chars_empty() {
        assert_eq!(count_utf8_chars_simd(&[]), 0);
    }

    #[test]
    fn test_count_utf8_chars_ascii() {
        assert_eq!(count_utf8_chars_simd(b"hello"), 5);
        assert_eq!(count_utf8_chars_simd(b"Hello, World!"), 13);
        assert_eq!(count_utf8_chars_simd(b"123"), 3);
    }

    #[test]
    fn test_count_utf8_chars_utf8() {
        // "café" in UTF-8: c(1) a(1) f(1) é(2 bytes: 0xC3 0xA9)
        assert_eq!(count_utf8_chars_simd("café".as_bytes()), 4);

        // "🚀" in UTF-8: 4 bytes (0xF0 0x9F 0x9A 0x80)
        assert_eq!(count_utf8_chars_simd("🚀".as_bytes()), 1);

        // Mixed: "Hello🚀" = 5 ASCII + 1 emoji = 6 chars
        assert_eq!(count_utf8_chars_simd("Hello🚀".as_bytes()), 6);
    }

    #[test]
    fn test_count_utf8_chars_consistency() {
        let test_strings = ["Hello", "café", "🚀", "Hello, 世界!", "résumé", "测试", ""];

        for test_str in test_strings {
            let simd_count = count_utf8_chars_simd(test_str.as_bytes());
            let std_count = test_str.chars().count();
            assert_eq!(simd_count, std_count, "Mismatch for string: {test_str:?}");
        }
    }

    #[test]
    fn test_count_utf8_chars_long_mixed() {
        // 1 + 2 + 3 + 4 byte characters, repeated well past any SIMD vector width so the
        // vector loop sees multi-byte sequences split across block boundaries.
        let text = "aé漢🚀".repeat(20);
        let bytes = text.as_bytes();
        assert_eq!(count_utf8_chars_simd(bytes), text.chars().count());

        // Every prefix, including ones that cut a character in half.
        for end in 0..=bytes.len() {
            assert_eq!(
                count_utf8_chars_simd(&bytes[..end]),
                reference_utf8_count(&bytes[..end]),
                "Mismatch for prefix length {end}"
            );
        }
    }

    // Tests for get_char_column_simd
    #[test]
    fn test_get_char_column_simple() {
        // Simple case: no newlines
        assert_eq!(get_char_column_simd("hello", 5), 5);
        assert_eq!(get_char_column_simd("hello", 3), 3);
        assert_eq!(get_char_column_simd("hello", 0), 0);
    }

    #[test]
    fn test_get_char_column_with_newlines() {
        let text = "line1\nline2\nline3";

        // Position at start of each line
        assert_eq!(get_char_column_simd(text, 0), 0); // Start of "line1"
        assert_eq!(get_char_column_simd(text, 6), 0); // Start of "line2"
        assert_eq!(get_char_column_simd(text, 12), 0); // Start of "line3"

        // Positions within lines
        assert_eq!(get_char_column_simd(text, 3), 3); // "lin|e1"
        assert_eq!(get_char_column_simd(text, 9), 3); // "lin|e2"
        assert_eq!(get_char_column_simd(text, 15), 3); // "lin|e3"
    }

    #[test]
    fn test_get_char_column_utf8() {
        // Test with UTF-8 characters
        let text = "café\nnaïve";

        // Position within first line: "ca|fé" = position 2
        assert_eq!(get_char_column_simd(text, 2), 2);

        // Position at start of second line after newline
        assert_eq!(get_char_column_simd(text, 6), 0); // Start of "naïve"

        // Position within second line: "na|ïve" = position 2 (after 'n', 'a')
        assert_eq!(get_char_column_simd(text, 8), 2);
    }

    #[test]
    fn test_get_char_column_consistency_with_original() {
        let test_cases: [(&str, &[usize]); 6] = [
            ("hello", &[0, 1, 3, 5]),
            ("line1\nline2", &[0, 3, 5, 6, 9]),
            ("café\nworld", &[0, 2, 5, 6, 8]),
            ("🚀test\nnew", &[0, 1, 3, 6, 7]),
            ("", &[0]),
            ("a", &[0, 1]),
        ];

        for (text, offsets) in test_cases {
            for &offset in offsets {
                let original = original_get_char_column(text, offset);
                let simd = get_char_column_simd(text, offset);
                assert_eq!(
                    original, simd,
                    "Mismatch for text: {text:?}, offset: {offset}"
                );
            }
        }
    }

    #[test]
    fn test_get_char_column_long_utf8_lines() {
        // Long multi-byte lines so column counting runs through the SIMD loop with non-ASCII
        // input. Every byte offset is checked, including offsets inside a character.
        let text = format!(
            "{}\n{}\n{}",
            "αβγ".repeat(30),
            "x🚀".repeat(25),
            "漢字".repeat(20)
        );
        for offset in 0..=text.len() {
            assert_eq!(
                get_char_column_simd(&text, offset),
                original_get_char_column(&text, offset),
                "Mismatch at offset {offset}"
            );
        }
    }

    #[test]
    fn test_get_char_column_edge_cases() {
        // Test edge cases
        assert_eq!(get_char_column_simd("", 0), 0);
        assert_eq!(get_char_column_simd("test", 0), 0);
        assert_eq!(get_char_column_simd("test", 100), 0); // Offset beyond length

        // Test with only newlines
        assert_eq!(get_char_column_simd("\n\n\n", 1), 0);
        assert_eq!(get_char_column_simd("\n\n\n", 2), 0);

        // Test long lines
        let long_line = "a".repeat(1000);
        assert_eq!(get_char_column_simd(&long_line, 500), 500);

        let long_with_newline = "a".repeat(500) + "\n" + &"b".repeat(300);
        assert_eq!(get_char_column_simd(&long_with_newline, 800), 299);
    }
}