simd_csv/searcher.rs

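//! Vectorized search for three needle bytes at once.
//!
//! A [`Searcher`] scans a haystack for any of three bytes (e.g. `,`, `"` and
//! `\n` when splitting CSV), using SSE2 on x86_64, NEON on aarch64, and
//! `memchr`'s generic three-byte searcher on other architectures.
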
/// A trait for adding some helper routines to pointers.
trait Pointer {
    /// Returns the distance, in units of `T`, between `self` and `origin`.
    ///
    /// # Safety
    ///
    /// Same requirements as `ptr::offset_from`, with the additional
    /// requirement that `self >= origin`.
    unsafe fn distance(self, origin: Self) -> usize;
}

impl<T> Pointer for *const T {
    #[inline(always)]
    unsafe fn distance(self, origin: *const T) -> usize {
        // TODO: Replace with `ptr::sub_ptr` once stabilized.
        usize::try_from(self.offset_from(origin)).unwrap_unchecked()
    }
}

#[cfg(target_arch = "x86_64")]
mod x86_64 {
    use std::marker::PhantomData;

    use super::Pointer;

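    /// Normalizes a movemask before bit scanning: byte-swapped on big-endian
    /// targets, used as is on little-endian ones.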
    #[inline(always)]
    fn get_for_offset(mask: u32) -> u32 {
        #[cfg(target_endian = "big")]
        {
            mask.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            mask
        }
    }

    #[inline(always)]
    fn first_offset(mask: u32) -> usize {
        get_for_offset(mask).trailing_zeros() as usize
    }

    #[inline(always)]
    fn clear_least_significant_bit(mask: u32) -> u32 {
        mask & (mask - 1)
    }

    pub mod sse2 {
        use super::*;

        use core::arch::x86_64::{
            __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128,
            _mm_set1_epi8,
        };

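        /// Searcher for three needle bytes, each splatted across a 128-bit
        /// SSE2 vector so a whole 16-byte chunk can be compared at once.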
        #[derive(Debug)]
        pub struct SSE2Searcher {
            n1: u8,
            n2: u8,
            n3: u8,
            v1: __m128i,
            v2: __m128i,
            v3: __m128i,
        }

        impl SSE2Searcher {
            #[inline]
            pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
                Self {
                    n1,
                    n2,
                    n3,
                    v1: _mm_set1_epi8(n1 as i8),
                    v2: _mm_set1_epi8(n2 as i8),
                    v3: _mm_set1_epi8(n3 as i8),
                }
            }

            #[inline(always)]
            pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> SSE2Indices<'s, 'h> {
                SSE2Indices::new(self, haystack)
            }
        }

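        /// Iterator over the offsets of needle bytes in a haystack.
        ///
        /// `mask` holds match bits from the most recently scanned 16-byte
        /// chunk that have not been yielded yet, and `current` points just
        /// past that chunk (or at the next byte to scan in the scalar tail).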
        #[derive(Debug)]
        pub struct SSE2Indices<'s, 'h> {
            searcher: &'s SSE2Searcher,
            haystack: PhantomData<&'h [u8]>,
            start: *const u8,
            end: *const u8,
            current: *const u8,
            mask: u32,
        }

        impl<'s, 'h> SSE2Indices<'s, 'h> {
            #[inline]
            fn new(searcher: &'s SSE2Searcher, haystack: &'h [u8]) -> Self {
                let ptr = haystack.as_ptr();

                Self {
                    searcher,
                    haystack: PhantomData,
                    start: ptr,
                    end: ptr.wrapping_add(haystack.len()),
                    current: ptr,
                    mask: 0,
                }
            }
        }

        const SSE2_STEP: usize = 16;

        impl<'s, 'h> SSE2Indices<'s, 'h> {
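            /// Returns the offset of the next needle byte, or `None` once the
            /// haystack is exhausted.
            ///
            /// The loop first drains matches already recorded in `self.mask`,
            /// then scans 16 bytes at a time with unaligned loads, and finally
            /// falls back to a byte-by-byte scan for the trailing partial chunk.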
            pub unsafe fn next(&mut self) -> Option<usize> {
                if self.start >= self.end {
                    return None;
                }

                let mut mask = self.mask;
                // `wrapping_sub` so this stays well-defined even when the
                // haystack is shorter than a full 16-byte chunk.
                let vectorized_end = self.end.wrapping_sub(SSE2_STEP);
                let mut current = self.current;
                let start = self.start;
                let v1 = self.searcher.v1;
                let v2 = self.searcher.v2;
                let v3 = self.searcher.v3;

                'main: loop {
                    // Drain pending matches from the last scanned chunk.
                    if mask != 0 {
                        // `current` already points past that chunk, so step
                        // back to recover the chunk's start.
                        let offset = current.sub(SSE2_STEP).add(first_offset(mask));
                        self.mask = clear_least_significant_bit(mask);
                        self.current = current;

                        return Some(offset.distance(start));
                    }

                    // Main loop of unaligned 16-byte loads.
                    while current <= vectorized_end {
                        let chunk = _mm_loadu_si128(current as *const __m128i);
                        let cmp1 = _mm_cmpeq_epi8(chunk, v1);
                        let cmp2 = _mm_cmpeq_epi8(chunk, v2);
                        let cmp3 = _mm_cmpeq_epi8(chunk, v3);
                        let cmp = _mm_or_si128(cmp1, cmp2);
                        let cmp = _mm_or_si128(cmp, cmp3);

                        mask = _mm_movemask_epi8(cmp) as u32;

                        current = current.add(SSE2_STEP);

                        if mask != 0 {
                            continue 'main;
                        }
                    }

                    // Process the remaining tail bytes one at a time.
                    while current < self.end {
                        if *current == self.searcher.n1
                            || *current == self.searcher.n2
                            || *current == self.searcher.n3
                        {
                            let offset = current.distance(start);
                            self.current = current.add(1);
                            return Some(offset);
                        }
                        current = current.add(1);
                    }

                    return None;
                }
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
mod aarch64 {
    use core::arch::aarch64::{
        uint8x16_t, vceqq_u8, vdupq_n_u8, vget_lane_u64, vld1q_u8, vorrq_u8, vreinterpret_u64_u8,
        vreinterpretq_u16_u8, vshrn_n_u16,
    };
    use std::marker::PhantomData;

    use super::Pointer;

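    /// Emulates x86's byte movemask on NEON: shifting each 16-bit lane right
    /// by 4 and narrowing packs the comparison result into a 64-bit scalar
    /// with 4 bits per input byte; masking with `0x8888888888888888` keeps a
    /// single bit per byte, so `trailing_zeros() >> 2` recovers a byte offset.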
    #[inline(always)]
    unsafe fn neon_movemask(v: uint8x16_t) -> u64 {
        let asu16s = vreinterpretq_u16_u8(v);
        let mask = vshrn_n_u16(asu16s, 4);
        let asu64 = vreinterpret_u64_u8(mask);
        let scalar64 = vget_lane_u64(asu64, 0);

        scalar64 & 0x8888888888888888
    }

    #[inline(always)]
    fn first_offset(mask: u64) -> usize {
        (mask.trailing_zeros() >> 2) as usize
    }

    #[inline(always)]
    fn clear_least_significant_bit(mask: u64) -> u64 {
        mask & (mask - 1)
    }

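    /// NEON counterpart of `SSE2Searcher`: the three needle bytes splatted
    /// across 128-bit vectors.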
    #[derive(Debug)]
    pub struct NeonSearcher {
        n1: u8,
        n2: u8,
        n3: u8,
        v1: uint8x16_t,
        v2: uint8x16_t,
        v3: uint8x16_t,
    }

    impl NeonSearcher {
        #[inline]
        pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
            Self {
                n1,
                n2,
                n3,
                v1: vdupq_n_u8(n1),
                v2: vdupq_n_u8(n2),
                v3: vdupq_n_u8(n3),
            }
        }

        #[inline(always)]
        pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> NeonIndices<'s, 'h> {
            NeonIndices::new(self, haystack)
        }
    }

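    /// NEON counterpart of `SSE2Indices`; the fields have the same meaning,
    /// with one match bit per nibble of the `u64` mask.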
    #[derive(Debug)]
    pub struct NeonIndices<'s, 'h> {
        searcher: &'s NeonSearcher,
        haystack: PhantomData<&'h [u8]>,
        start: *const u8,
        end: *const u8,
        current: *const u8,
        mask: u64,
    }

    impl<'s, 'h> NeonIndices<'s, 'h> {
        #[inline]
        fn new(searcher: &'s NeonSearcher, haystack: &'h [u8]) -> Self {
            let ptr = haystack.as_ptr();

            Self {
                searcher,
                haystack: PhantomData,
                start: ptr,
                end: ptr.wrapping_add(haystack.len()),
                current: ptr,
                mask: 0,
            }
        }
    }

    const NEON_STEP: usize = 16;

    impl<'s, 'h> NeonIndices<'s, 'h> {
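        /// Same three-phase loop as `SSE2Indices::next`, using NEON loads and
        /// the emulated movemask above.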
        pub unsafe fn next(&mut self) -> Option<usize> {
            if self.start >= self.end {
                return None;
            }

            let mut mask = self.mask;
            // `wrapping_sub` so this stays well-defined even when the
            // haystack is shorter than a full 16-byte chunk.
            let vectorized_end = self.end.wrapping_sub(NEON_STEP);
            let mut current = self.current;
            let start = self.start;
            let v1 = self.searcher.v1;
            let v2 = self.searcher.v2;
            let v3 = self.searcher.v3;

            'main: loop {
                // Drain pending matches from the last scanned chunk.
                if mask != 0 {
                    // `current` already points past that chunk, so step back
                    // to recover the chunk's start.
                    let offset = current.sub(NEON_STEP).add(first_offset(mask));
                    self.mask = clear_least_significant_bit(mask);
                    self.current = current;

                    return Some(offset.distance(start));
                }

                // Main loop of unaligned 16-byte loads.
                while current <= vectorized_end {
                    let chunk = vld1q_u8(current);
                    let cmp1 = vceqq_u8(chunk, v1);
                    let cmp2 = vceqq_u8(chunk, v2);
                    let cmp3 = vceqq_u8(chunk, v3);
                    let cmp = vorrq_u8(cmp1, cmp2);
                    let cmp = vorrq_u8(cmp, cmp3);

                    mask = neon_movemask(cmp);

                    current = current.add(NEON_STEP);

                    if mask != 0 {
                        continue 'main;
                    }
                }

                // Process the remaining tail bytes one at a time.
                while current < self.end {
                    if *current == self.searcher.n1
                        || *current == self.searcher.n2
                        || *current == self.searcher.n3
                    {
                        let offset = current.distance(start);
                        self.current = current.add(1);
                        return Some(offset);
                    }
                    current = current.add(1);
                }

                return None;
            }
        }
    }
}

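/// Searches a haystack for occurrences of any of three needle bytes.
///
/// Uses SSE2 on x86_64, NEON on aarch64, and falls back to `memchr`'s generic
/// three-byte searcher on other architectures.
///
/// ```ignore
/// // Find every `,`, `"` and `\n` in a CSV chunk; offsets are byte indices.
/// let searcher = Searcher::new(b',', b'"', b'\n');
/// let offsets: Vec<usize> = searcher.search(b"name,\"surname\"\n").collect();
/// ```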
#[derive(Debug)]
pub struct Searcher {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Searcher,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonSearcher,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::Three,
}

impl Searcher {
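    /// Returns the name of the SIMD instruction set leveraged on the current
    /// target, or `"none"` when falling back to the scalar implementation.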
    pub fn leveraged_simd_instructions() -> &'static str {
        #[cfg(target_arch = "x86_64")]
        {
            "sse2"
        }

        #[cfg(target_arch = "aarch64")]
        {
            "neon"
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            "none"
        }
    }

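    /// Builds a searcher for the three given needle bytes.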
    #[inline(always)]
    pub fn new(n1: u8, n2: u8, n3: u8) -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe {
                Self {
                    inner: x86_64::sse2::SSE2Searcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe {
                Self {
                    inner: aarch64::NeonSearcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self {
                inner: memchr::arch::all::memchr::Three::new(n1, n2, n3),
            }
        }
    }

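    /// Returns an iterator over the offsets of every needle byte in `haystack`.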
    #[inline(always)]
    pub fn search<'s, 'h>(&'s self, haystack: &'h [u8]) -> Indices<'s, 'h> {
        #[cfg(target_arch = "x86_64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }
    }
}

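/// Iterator over the byte offsets of needle matches, yielded in increasing
/// order.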
#[derive(Debug)]
pub struct Indices<'s, 'h> {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Indices<'s, 'h>,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonIndices<'s, 'h>,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::ThreeIter<'s, 'h>,
}

impl<'s, 'h> Iterator for Indices<'s, 'h> {
    type Item = usize;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe { self.inner.next() }
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe { self.inner.next() }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            self.inner.next()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use memchr::arch::all::memchr::Three;

    static TEST_STRING: &[u8] = b"name,\"surname\",age,color,oper\n,\n,\nation,punctuation\nname,surname,age,color,operation,punctuation";
    static TEST_STRING_OFFSETS: &[usize; 18] = &[
        4, 5, 13, 14, 18, 24, 29, 30, 31, 32, 33, 39, 51, 56, 64, 68, 74, 84,
    ];

    #[test]
    fn test_scalar_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Three::new(b',', b'"', b'\n');
            searcher.iter(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        // Not found at all
        assert!(split("b".repeat(75).as_bytes()).is_empty());

        // Regular
        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        // Exactly 64
        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        // Less than 32
        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        // Less than 16
        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);
    }

    #[test]
    fn test_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Searcher::new(b',', b'"', b'\n');
            searcher.search(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        // Not found at all
        assert!(split("b".repeat(75).as_bytes()).is_empty());

        // Regular
        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        // Exactly 64
        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        // Less than 32
        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        // Less than 16
        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);

        // Complex input
        let complex = b"name,surname,age\n\"john\",\"landy, the \"\"everlasting\"\" bastard\",45\nlucy,rose,\"67\"\njermaine,jackson,\"89\"\n\nkarine,loucan,\"52\"\nrose,\"glib\",12\n\"guillaume\",\"plique\",\"42\"\r\n";
        let complex_indices = split(complex);

        assert!(complex_indices
            .iter()
            .copied()
            .all(|c| complex[c] == b',' || complex[c] == b'\n' || complex[c] == b'"'));

        assert_eq!(
            complex_indices,
            Three::new(b',', b'\n', b'"')
                .iter(complex)
                .collect::<Vec<_>>()
        );
    }
}