//! `simd_csv/searcher.rs` — amortized SIMD searcher for three needle bytes.

1use std::iter::FusedIterator;
2
#[cfg(target_arch = "x86_64")]
mod x86_64 {
    use std::marker::PhantomData;

    use crate::ext::Pointer;

    /// Normalizes a movemask so `trailing_zeros` yields the offset of the
    /// lowest matching byte regardless of host endianness.
    #[inline(always)]
    fn get_for_offset(mask: u32) -> u32 {
        #[cfg(target_endian = "big")]
        {
            mask.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            mask
        }
    }

    /// Offset (0..16) of the first matching byte encoded in `mask`.
    #[inline(always)]
    fn first_offset(mask: u32) -> usize {
        get_for_offset(mask).trailing_zeros() as usize
    }

    /// Consumes the lowest match from `mask`. Callers only invoke this with
    /// `mask != 0`, so `mask - 1` cannot wrap.
    #[inline(always)]
    fn clear_least_significant_bit(mask: u32) -> u32 {
        mask & (mask - 1)
    }

    pub mod sse2 {
        use super::*;

        use core::arch::x86_64::{
            __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128,
            _mm_set1_epi8,
        };

        /// Finds occurrences of any of three needle bytes, 16 bytes at a time.
        #[derive(Debug)]
        pub struct SSE2Searcher {
            // Plain needle bytes, used by the scalar tail loop.
            n1: u8,
            n2: u8,
            n3: u8,
            // Each needle splatted across all 16 lanes, used by the SIMD loop.
            v1: __m128i,
            v2: __m128i,
            v3: __m128i,
        }

        impl SSE2Searcher {
            /// Builds a searcher for needle bytes `n1`, `n2` and `n3`.
            ///
            /// # Safety
            ///
            /// Requires SSE2, which is part of the x86_64 baseline, so this is
            /// always satisfied on this target.
            #[inline]
            pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
                Self {
                    n1,
                    n2,
                    n3,
                    v1: _mm_set1_epi8(n1 as i8),
                    v2: _mm_set1_epi8(n2 as i8),
                    v3: _mm_set1_epi8(n3 as i8),
                }
            }

            /// Returns an iterator over the offsets of all matches in `haystack`.
            #[inline(always)]
            pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> SSE2Indices<'s, 'h> {
                SSE2Indices::new(self, haystack)
            }
        }

        /// Raw-pointer scan state over a haystack.
        #[derive(Debug)]
        pub struct SSE2Indices<'s, 'h> {
            searcher: &'s SSE2Searcher,
            // Ties this iterator to the haystack's lifetime; only raw pointers
            // into the haystack are stored below.
            haystack: PhantomData<&'h [u8]>,
            start: *const u8,
            end: *const u8,
            // Scan cursor. Invariant: after a vector load it points *past* the
            // chunk the pending `mask` refers to.
            current: *const u8,
            // Pending movemask of the most recently scanned chunk: one bit per
            // byte, still-unreported matches only.
            mask: u32,
        }

        impl<'s, 'h> SSE2Indices<'s, 'h> {
            #[inline]
            fn new(searcher: &'s SSE2Searcher, haystack: &'h [u8]) -> Self {
                let ptr = haystack.as_ptr();

                Self {
                    searcher,
                    haystack: PhantomData,
                    // `wrapping_add` keeps this safe even for an empty slice.
                    start: ptr,
                    end: ptr.wrapping_add(haystack.len()),
                    current: ptr,
                    mask: 0,
                }
            }
        }

        /// Bytes consumed per 128-bit vector load.
        const SSE2_STEP: usize = 16;

        impl SSE2Indices<'_, '_> {
            /// Pops the next match from the pending mask of the most recently
            /// loaded chunk, without scanning any further data. `None` once
            /// that mask is exhausted.
            #[inline]
            pub unsafe fn _next_in_current_mask(&mut self) -> Option<usize> {
                let mask = self.mask;
                let current = self.current;

                if mask != 0 {
                    // `current` already points past the chunk, hence the
                    // `- SSE2_STEP` to recover the chunk start.
                    let offset = current.sub(SSE2_STEP).add(first_offset(mask));
                    self.mask = clear_least_significant_bit(mask);
                    self.current = current;

                    Some(offset.distance(self.start))
                } else {
                    None
                }
            }

            /// Offset of the next match, or `None` when the haystack is
            /// exhausted (stays `None` on subsequent calls).
            pub unsafe fn next(&mut self) -> Option<usize> {
                // Empty haystack: nothing can ever match.
                if self.start >= self.end {
                    return None;
                }

                let mut mask = self.mask;

                let mut current = self.current;
                let start = self.start;
                // Total haystack length (start/end never move); only used to
                // decide whether a full vector load fits at all.
                let len = self.end.distance(start);
                let v1 = self.searcher.v1;
                let v2 = self.searcher.v2;
                let v3 = self.searcher.v3;

                'main: loop {
                    // Phase 1: drain the pending move mask of the last chunk.
                    if mask != 0 {
                        // `current` points one step past the chunk the mask
                        // belongs to (see invariant on the struct field).
                        let offset = current.sub(SSE2_STEP).add(first_offset(mask));
                        self.mask = clear_least_significant_bit(mask);
                        self.current = current;

                        return Some(offset.distance(start));
                    }

                    // Phase 2: unaligned 16-byte loads while a whole vector
                    // still fits before `end`.
                    if len >= SSE2_STEP {
                        let vectorized_end = self.end.sub(SSE2_STEP);

                        while current <= vectorized_end {
                            let chunk = _mm_loadu_si128(current as *const __m128i);
                            // Compare against all three needles and OR the
                            // per-lane results into one match vector.
                            let cmp1 = _mm_cmpeq_epi8(chunk, v1);
                            let cmp2 = _mm_cmpeq_epi8(chunk, v2);
                            let cmp3 = _mm_cmpeq_epi8(chunk, v3);
                            let cmp = _mm_or_si128(cmp1, cmp2);
                            let cmp = _mm_or_si128(cmp, cmp3);

                            mask = _mm_movemask_epi8(cmp) as u32;

                            // Advance past the chunk *before* reporting so the
                            // mask-draining phase can rely on that invariant.
                            current = current.add(SSE2_STEP);

                            if mask != 0 {
                                continue 'main;
                            }
                        }
                    }

                    // Phase 3: scalar scan of the remaining tail (< 16 bytes).
                    while current < self.end {
                        if *current == self.searcher.n1
                            || *current == self.searcher.n2
                            || *current == self.searcher.n3
                        {
                            let offset = current.distance(start);
                            self.current = current.add(1);
                            return Some(offset);
                        }
                        current = current.add(1);
                    }

                    return None;
                }
            }
        }
    }
}
178
#[cfg(target_arch = "aarch64")]
mod aarch64 {
    use core::arch::aarch64::{
        uint8x16_t, vceqq_u8, vdupq_n_u8, vget_lane_u64, vld1q_u8, vorrq_u8, vreinterpret_u64_u8,
        vreinterpretq_u16_u8, vshrn_n_u16,
    };
    use std::marker::PhantomData;

    use crate::ext::Pointer;

    /// Compresses a 16-lane byte mask into a 64-bit scalar with 4 bits per
    /// lane (NEON has no direct movemask; this is the `vshrn`-by-4 narrowing
    /// trick). The final AND keeps a single bit per nibble so trailing-zero
    /// arithmetic stays unambiguous.
    #[inline(always)]
    unsafe fn neon_movemask(v: uint8x16_t) -> u64 {
        let asu16s = vreinterpretq_u16_u8(v);
        let mask = vshrn_n_u16(asu16s, 4);
        let asu64 = vreinterpret_u64_u8(mask);
        let scalar64 = vget_lane_u64(asu64, 0);

        scalar64 & 0x8888888888888888
    }

    /// Byte offset (0..16) of the first match: each haystack byte occupies 4
    /// mask bits, hence the `>> 2` (divide by 4).
    #[inline(always)]
    fn first_offset(mask: u64) -> usize {
        (mask.trailing_zeros() >> 2) as usize
    }

    /// Consumes the lowest match from `mask`. Callers only invoke this with
    /// `mask != 0`, so `mask - 1` cannot wrap.
    #[inline(always)]
    fn clear_least_significant_bit(mask: u64) -> u64 {
        mask & (mask - 1)
    }

    /// Finds occurrences of any of three needle bytes, 16 bytes at a time.
    #[derive(Debug)]
    pub struct NeonSearcher {
        // Plain needle bytes, used by the scalar tail loop.
        n1: u8,
        n2: u8,
        n3: u8,
        // Each needle splatted across all 16 lanes, used by the SIMD loop.
        v1: uint8x16_t,
        v2: uint8x16_t,
        v3: uint8x16_t,
    }

    impl NeonSearcher {
        /// Builds a searcher for needle bytes `n1`, `n2` and `n3`.
        ///
        /// # Safety
        ///
        /// Requires NEON, which is part of the aarch64 baseline, so this is
        /// always satisfied on this target.
        #[inline]
        pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
            Self {
                n1,
                n2,
                n3,
                v1: vdupq_n_u8(n1),
                v2: vdupq_n_u8(n2),
                v3: vdupq_n_u8(n3),
            }
        }

        /// Returns an iterator over the offsets of all matches in `haystack`.
        #[inline(always)]
        pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> NeonIndices<'s, 'h> {
            NeonIndices::new(self, haystack)
        }
    }

    /// Raw-pointer scan state over a haystack.
    #[derive(Debug)]
    pub struct NeonIndices<'s, 'h> {
        searcher: &'s NeonSearcher,
        // Ties this iterator to the haystack's lifetime; only raw pointers
        // into the haystack are stored below.
        haystack: PhantomData<&'h [u8]>,
        start: *const u8,
        end: *const u8,
        // Scan cursor. Invariant: after a vector load it points *past* the
        // chunk the pending `mask` refers to.
        current: *const u8,
        // Pending movemask of the most recently scanned chunk: 4 bits per
        // byte (see `neon_movemask`), still-unreported matches only.
        mask: u64,
    }

    impl<'s, 'h> NeonIndices<'s, 'h> {
        #[inline]
        fn new(searcher: &'s NeonSearcher, haystack: &'h [u8]) -> Self {
            let ptr = haystack.as_ptr();

            Self {
                searcher,
                haystack: PhantomData,
                // `wrapping_add` keeps this safe even for an empty slice.
                start: ptr,
                end: ptr.wrapping_add(haystack.len()),
                current: ptr,
                mask: 0,
            }
        }
    }

    /// Bytes consumed per 128-bit vector load.
    const NEON_STEP: usize = 16;

    impl NeonIndices<'_, '_> {
        /// Pops the next match from the pending mask of the most recently
        /// loaded chunk, without scanning any further data. `None` once that
        /// mask is exhausted.
        #[inline]
        pub unsafe fn _next_in_current_mask(&mut self) -> Option<usize> {
            let mask = self.mask;
            let current = self.current;

            if mask != 0 {
                // `current` already points past the chunk, hence the
                // `- NEON_STEP` to recover the chunk start.
                let offset = current.sub(NEON_STEP).add(first_offset(mask));
                self.mask = clear_least_significant_bit(mask);
                self.current = current;

                Some(offset.distance(self.start))
            } else {
                None
            }
        }

        /// Offset of the next match, or `None` when the haystack is exhausted
        /// (stays `None` on subsequent calls).
        pub unsafe fn next(&mut self) -> Option<usize> {
            // Empty haystack: nothing can ever match.
            if self.start >= self.end {
                return None;
            }

            let mut mask = self.mask;
            let mut current = self.current;
            let start = self.start;
            // Total haystack length (start/end never move); only used to
            // decide whether a full vector load fits at all.
            let len = self.end.distance(start);
            let v1 = self.searcher.v1;
            let v2 = self.searcher.v2;
            let v3 = self.searcher.v3;

            'main: loop {
                // Phase 1: drain the pending move mask of the last chunk.
                if mask != 0 {
                    // `current` points one step past the chunk the mask
                    // belongs to (see invariant on the struct field).
                    let offset = current.sub(NEON_STEP).add(first_offset(mask));
                    self.mask = clear_least_significant_bit(mask);
                    self.current = current;

                    return Some(offset.distance(start));
                }

                // Phase 2: unaligned 16-byte loads while a whole vector still
                // fits before `end`.
                if len >= NEON_STEP {
                    let vectorized_end = self.end.sub(NEON_STEP);

                    while current <= vectorized_end {
                        let chunk = vld1q_u8(current);
                        // Compare against all three needles and OR the
                        // per-lane results into one match vector.
                        let cmp1 = vceqq_u8(chunk, v1);
                        let cmp2 = vceqq_u8(chunk, v2);
                        let cmp3 = vceqq_u8(chunk, v3);
                        let cmp = vorrq_u8(cmp1, cmp2);
                        let cmp = vorrq_u8(cmp, cmp3);

                        mask = neon_movemask(cmp);

                        // Advance past the chunk *before* reporting so the
                        // mask-draining phase can rely on that invariant.
                        current = current.add(NEON_STEP);

                        if mask != 0 {
                            continue 'main;
                        }
                    }
                }

                // Phase 3: scalar scan of the remaining tail (< 16 bytes).
                while current < self.end {
                    if *current == self.searcher.n1
                        || *current == self.searcher.n2
                        || *current == self.searcher.n3
                    {
                        let offset = current.distance(start);
                        self.current = current.add(1);
                        return Some(offset);
                    }
                    current = current.add(1);
                }

                return None;
            }
        }
    }
}
346
/// Reports which SIMD instruction set backs this crate's amortized
/// `memchr`-like searcher on the compile target.
///
/// Note that the `memchr` routines this crate also relies on may select
/// different instruction sets of their own.
pub fn searcher_simd_instructions() -> &'static str {
    if cfg!(target_arch = "x86_64") {
        "sse2"
    } else if cfg!(target_arch = "aarch64") {
        "neon"
    } else {
        "none"
    }
}
368
/// Arch-dispatching searcher for three needle bytes, wrapping the best
/// vectorized implementation compiled in for the current target.
#[derive(Debug)]
pub struct Searcher {
    // SSE2 path: part of the x86_64 baseline.
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Searcher,

    // NEON path: part of the aarch64 baseline.
    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonSearcher,

    // Portable scalar fallback from the `memchr` crate.
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::Three,
}
380
381impl Searcher {
382    #[inline(always)]
383    pub fn new(n1: u8, n2: u8, n3: u8) -> Self {
384        #[cfg(target_arch = "x86_64")]
385        {
386            unsafe {
387                Self {
388                    inner: x86_64::sse2::SSE2Searcher::new(n1, n2, n3),
389                }
390            }
391        }
392
393        #[cfg(target_arch = "aarch64")]
394        {
395            unsafe {
396                Self {
397                    inner: aarch64::NeonSearcher::new(n1, n2, n3),
398                }
399            }
400        }
401
402        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
403        {
404            Self {
405                inner: memchr::arch::all::memchr::Three::new(n1, n2, n3),
406            }
407        }
408    }
409
410    #[inline(always)]
411    pub fn search<'s, 'h>(&'s self, haystack: &'h [u8]) -> Indices<'s, 'h> {
412        #[cfg(target_arch = "x86_64")]
413        {
414            Indices {
415                inner: self.inner.iter(haystack),
416            }
417        }
418
419        #[cfg(target_arch = "aarch64")]
420        {
421            Indices {
422                inner: self.inner.iter(haystack),
423            }
424        }
425
426        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
427        {
428            Indices {
429                inner: self.inner.iter(haystack),
430            }
431        }
432    }
433}
434
/// Iterator over needle-byte offsets, produced by [`Searcher::search`].
#[derive(Debug)]
pub struct Indices<'s, 'h> {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Indices<'s, 'h>,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonIndices<'s, 'h>,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::ThreeIter<'s, 'h>,
}

// Sound: once the underlying iterator has returned `None` (cursor at or past
// `end`, mask drained) every later call also returns `None`.
impl FusedIterator for Indices<'_, '_> {}
448
449impl Iterator for Indices<'_, '_> {
450    type Item = usize;
451
452    #[inline(always)]
453    fn next(&mut self) -> Option<Self::Item> {
454        #[cfg(target_arch = "x86_64")]
455        {
456            unsafe { self.inner.next() }
457        }
458
459        #[cfg(target_arch = "aarch64")]
460        {
461            unsafe { self.inner.next() }
462        }
463
464        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
465        {
466            self.inner.next()
467        }
468    }
469}
470
471impl Indices<'_, '_> {
472    #[inline(always)]
473    pub fn _next_in_current_mask(&mut self) -> Option<usize> {
474        #[cfg(target_arch = "x86_64")]
475        {
476            unsafe { self.inner._next_in_current_mask() }
477        }
478
479        #[cfg(target_arch = "aarch64")]
480        {
481            unsafe { self.inner._next_in_current_mask() }
482        }
483
484        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
485        {
486            None
487        }
488    }
489}
490
#[cfg(test)]
mod tests {
    use super::*;

    use memchr::arch::all::memchr::Three;

    static TEST_STRING: &[u8] = b"name,\"surname\",age,color,oper\n,\n,\nation,punctuation\nname,surname,age,color,operation,punctuation";
    static TEST_STRING_OFFSETS: &[usize; 18] = &[
        4, 5, 13, 14, 18, 24, 29, 30, 31, 32, 33, 39, 51, 56, 64, 68, 74, 84,
    ];

    // Separator-pair counts probing the vector-width boundaries: above 64,
    // exactly 64, below 32, below 16.
    const REPEAT_COUNTS: [usize; 4] = [75, 64, 25, 13];

    #[test]
    fn test_scalar_searcher() {
        fn offsets(haystack: &[u8]) -> Vec<usize> {
            Three::new(b',', b'"', b'\n').iter(haystack).collect()
        }

        assert_eq!(offsets(TEST_STRING), TEST_STRING_OFFSETS);

        // Haystack containing no needle byte at all.
        assert!(offsets("b".repeat(75).as_bytes()).is_empty());

        for &count in REPEAT_COUNTS.iter() {
            assert_eq!(offsets("b,".repeat(count).as_bytes()).len(), count);
        }
    }

    #[test]
    fn test_searcher() {
        fn offsets(haystack: &[u8]) -> Vec<usize> {
            Searcher::new(b',', b'"', b'\n').search(haystack).collect()
        }

        assert_eq!(offsets(TEST_STRING), TEST_STRING_OFFSETS);

        // Haystack containing no needle byte at all.
        assert!(offsets("b".repeat(75).as_bytes()).is_empty());

        for &count in REPEAT_COUNTS.iter() {
            assert_eq!(offsets("b,".repeat(count).as_bytes()).len(), count);
        }

        // Messy quoted CSV payload: every reported offset must land on a
        // needle byte, and the result must agree with the scalar reference.
        let complex = b"name,surname,age\n\"john\",\"landy, the \"\"everlasting\"\" bastard\",45\nlucy,rose,\"67\"\njermaine,jackson,\"89\"\n\nkarine,loucan,\"52\"\nrose,\"glib\",12\n\"guillaume\",\"plique\",\"42\"\r\n";
        let complex_indices = offsets(complex);

        assert!(complex_indices
            .iter()
            .copied()
            .all(|i| matches!(complex[i], b',' | b'\n' | b'"')));

        assert_eq!(
            complex_indices,
            Three::new(b',', b'\n', b'"')
                .iter(complex)
                .collect::<Vec<_>>()
        );
    }
}
569}