simd_csv/
searcher.rs
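//! Vectorized search for the three byte values that structure a CSV stream
//! (delimiter, quote, newline): SSE2 on x86_64, NEON on aarch64, and the
//! scalar `memchr` fallback on other architectures.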

use std::iter::FusedIterator;

#[cfg(target_arch = "x86_64")]
mod x86_64 {
    use std::marker::PhantomData;

    use crate::ext::Pointer;

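    // A set bit in the SSE2 movemask marks a matching byte in a 16-byte chunk
    // (`get_for_offset` byte-swaps on big-endian targets so the lowest bit
    // always maps to the earliest byte). Matches are then drained one by one:
    // `first_offset` finds the lowest set bit and `clear_least_significant_bit`
    // pops it from the mask.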
    #[inline(always)]
    fn get_for_offset(mask: u32) -> u32 {
        #[cfg(target_endian = "big")]
        {
            mask.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            mask
        }
    }

    #[inline(always)]
    fn first_offset(mask: u32) -> usize {
        get_for_offset(mask).trailing_zeros() as usize
    }

    #[inline(always)]
    fn clear_least_significant_bit(mask: u32) -> u32 {
        mask & (mask - 1)
    }

    pub mod sse2 {
        use super::*;

        use core::arch::x86_64::{
            __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128,
            _mm_set1_epi8,
        };

        #[derive(Debug)]
        pub struct SSE2Searcher {
            n1: u8,
            n2: u8,
            n3: u8,
            v1: __m128i,
            v2: __m128i,
            v3: __m128i,
        }

        impl SSE2Searcher {
            #[inline]
            pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
                Self {
                    n1,
                    n2,
                    n3,
                    v1: _mm_set1_epi8(n1 as i8),
                    v2: _mm_set1_epi8(n2 as i8),
                    v3: _mm_set1_epi8(n3 as i8),
                }
            }

            #[inline(always)]
            pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> SSE2Indices<'s, 'h> {
                SSE2Indices::new(self, haystack)
            }
        }

        #[derive(Debug)]
        pub struct SSE2Indices<'s, 'h> {
            searcher: &'s SSE2Searcher,
            haystack: PhantomData<&'h [u8]>,
            start: *const u8,
            end: *const u8,
            current: *const u8,
            mask: u32,
        }

        impl<'s, 'h> SSE2Indices<'s, 'h> {
            #[inline]
            fn new(searcher: &'s SSE2Searcher, haystack: &'h [u8]) -> Self {
                let ptr = haystack.as_ptr();

                Self {
                    searcher,
                    haystack: PhantomData,
                    start: ptr,
                    end: ptr.wrapping_add(haystack.len()),
                    current: ptr,
                    mask: 0,
                }
            }
        }

        const SSE2_STEP: usize = 16;

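        // The iterator works in three phases: drain any matches still pending
        // in the movemask from the previous load, then scan with unaligned
        // 16-byte loads (comparing against all three needles and OR-ing the
        // results into one movemask), and finally walk the last, partial chunk
        // byte by byte. Yielded offsets are relative to the haystack start.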
        impl<'s, 'h> SSE2Indices<'s, 'h> {
            pub unsafe fn next(&mut self) -> Option<usize> {
                if self.start >= self.end {
                    return None;
                }

                let mut mask = self.mask;
                // Wrapping subtraction so haystacks shorter than one vector
                // don't compute an out-of-bounds pointer offset here.
                let vectorized_end = self.end.wrapping_sub(SSE2_STEP);
                let mut current = self.current;
                let start = self.start;
                let v1 = self.searcher.v1;
                let v2 = self.searcher.v2;
                let v3 = self.searcher.v3;

                'main: loop {
                    // Processing current move mask
                    if mask != 0 {
                        let offset = current.sub(SSE2_STEP).add(first_offset(mask));
                        self.mask = clear_least_significant_bit(mask);
                        self.current = current;

                        return Some(offset.distance(start));
                    }

                    // Main loop of unaligned loads
                    while current <= vectorized_end {
                        let chunk = _mm_loadu_si128(current as *const __m128i);
                        let cmp1 = _mm_cmpeq_epi8(chunk, v1);
                        let cmp2 = _mm_cmpeq_epi8(chunk, v2);
                        let cmp3 = _mm_cmpeq_epi8(chunk, v3);
                        let cmp = _mm_or_si128(cmp1, cmp2);
                        let cmp = _mm_or_si128(cmp, cmp3);

                        mask = _mm_movemask_epi8(cmp) as u32;

                        current = current.add(SSE2_STEP);

                        if mask != 0 {
                            continue 'main;
                        }
                    }

                    // Processing remaining bytes linearly
                    while current < self.end {
                        if *current == self.searcher.n1
                            || *current == self.searcher.n2
                            || *current == self.searcher.n3
                        {
                            let offset = current.distance(start);
                            self.current = current.add(1);
                            return Some(offset);
                        }
                        current = current.add(1);
                    }

                    return None;
                }
            }
        }
    }
}

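// The aarch64 implementation mirrors the SSE2 module above: the same
// three-needle comparison over 16-byte vectors, with the movemask emulated
// because NEON has no direct counterpart to `_mm_movemask_epi8`.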
#[cfg(target_arch = "aarch64")]
mod aarch64 {
    use core::arch::aarch64::{
        uint8x16_t, vceqq_u8, vdupq_n_u8, vget_lane_u64, vld1q_u8, vorrq_u8, vreinterpret_u64_u8,
        vreinterpretq_u16_u8, vshrn_n_u16,
    };
    use std::marker::PhantomData;

    use crate::ext::Pointer;

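    // Emulated movemask via the narrowing-shift trick: reinterpret the
    // comparison result as eight u16 lanes, shift each right by 4 while
    // narrowing back to u8, and read the low 64 bits. Every input byte then
    // owns one nibble (0xF on a match, 0x0 otherwise); masking with 0x8888...
    // keeps a single bit per nibble, so `first_offset` recovers the byte
    // index as `trailing_zeros() / 4`.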
    #[inline(always)]
    unsafe fn neon_movemask(v: uint8x16_t) -> u64 {
        let asu16s = vreinterpretq_u16_u8(v);
        let mask = vshrn_n_u16(asu16s, 4);
        let asu64 = vreinterpret_u64_u8(mask);
        let scalar64 = vget_lane_u64(asu64, 0);

        scalar64 & 0x8888888888888888
    }

    #[inline(always)]
    fn first_offset(mask: u64) -> usize {
        (mask.trailing_zeros() >> 2) as usize
    }

    #[inline(always)]
    fn clear_least_significant_bit(mask: u64) -> u64 {
        mask & (mask - 1)
    }

    #[derive(Debug)]
    pub struct NeonSearcher {
        n1: u8,
        n2: u8,
        n3: u8,
        v1: uint8x16_t,
        v2: uint8x16_t,
        v3: uint8x16_t,
    }

    impl NeonSearcher {
        #[inline]
        pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
            Self {
                n1,
                n2,
                n3,
                v1: vdupq_n_u8(n1),
                v2: vdupq_n_u8(n2),
                v3: vdupq_n_u8(n3),
            }
        }

        #[inline(always)]
        pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> NeonIndices<'s, 'h> {
            NeonIndices::new(self, haystack)
        }
    }

    #[derive(Debug)]
    pub struct NeonIndices<'s, 'h> {
        searcher: &'s NeonSearcher,
        haystack: PhantomData<&'h [u8]>,
        start: *const u8,
        end: *const u8,
        current: *const u8,
        mask: u64,
    }

    impl<'s, 'h> NeonIndices<'s, 'h> {
        #[inline]
        fn new(searcher: &'s NeonSearcher, haystack: &'h [u8]) -> Self {
            let ptr = haystack.as_ptr();

            Self {
                searcher,
                haystack: PhantomData,
                start: ptr,
                end: ptr.wrapping_add(haystack.len()),
                current: ptr,
                mask: 0,
            }
        }
    }

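    // Same three phases as the SSE2 iterator: drain the pending mask, scan
    // with unaligned 16-byte NEON loads, then finish the tail byte by byte.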
    const NEON_STEP: usize = 16;

    impl<'s, 'h> NeonIndices<'s, 'h> {
        pub unsafe fn next(&mut self) -> Option<usize> {
            if self.start >= self.end {
                return None;
            }

            let mut mask = self.mask;
            // Wrapping subtraction so haystacks shorter than one vector
            // don't compute an out-of-bounds pointer offset here.
            let vectorized_end = self.end.wrapping_sub(NEON_STEP);
            let mut current = self.current;
            let start = self.start;
            let v1 = self.searcher.v1;
            let v2 = self.searcher.v2;
            let v3 = self.searcher.v3;

            'main: loop {
                // Processing current move mask
                if mask != 0 {
                    let offset = current.sub(NEON_STEP).add(first_offset(mask));
                    self.mask = clear_least_significant_bit(mask);
                    self.current = current;

                    return Some(offset.distance(start));
                }

                // Main loop of unaligned loads
                while current <= vectorized_end {
                    let chunk = vld1q_u8(current);
                    let cmp1 = vceqq_u8(chunk, v1);
                    let cmp2 = vceqq_u8(chunk, v2);
                    let cmp3 = vceqq_u8(chunk, v3);
                    let cmp = vorrq_u8(cmp1, cmp2);
                    let cmp = vorrq_u8(cmp, cmp3);

                    mask = neon_movemask(cmp);

                    current = current.add(NEON_STEP);

                    if mask != 0 {
                        continue 'main;
                    }
                }

                // Processing remaining bytes linearly
                while current < self.end {
                    if *current == self.searcher.n1
                        || *current == self.searcher.n2
                        || *current == self.searcher.n3
                    {
                        let offset = current.distance(start);
                        self.current = current.add(1);
                        return Some(offset);
                    }
                    current = current.add(1);
                }

                return None;
            }
        }
    }
}

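/// Searches a haystack for three needle bytes at once, dispatching at compile
/// time to the SSE2, NEON, or scalar `memchr` implementation above.
///
/// Construct it once with [`Searcher::new`] and call [`Searcher::search`] to
/// obtain an iterator over the byte offsets of every match (see the tests at
/// the bottom of this file for usage).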
#[derive(Debug)]
pub struct Searcher {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Searcher,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonSearcher,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::Three,
}

impl Searcher {
    pub fn leveraged_simd_instructions() -> &'static str {
        #[cfg(target_arch = "x86_64")]
        {
            "sse2"
        }

        #[cfg(target_arch = "aarch64")]
        {
            "neon"
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            "none"
        }
    }

    #[inline(always)]
    pub fn new(n1: u8, n2: u8, n3: u8) -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe {
                Self {
                    inner: x86_64::sse2::SSE2Searcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe {
                Self {
                    inner: aarch64::NeonSearcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self {
                inner: memchr::arch::all::memchr::Three::new(n1, n2, n3),
            }
        }
    }

    #[inline(always)]
    pub fn search<'s, 'h>(&'s self, haystack: &'h [u8]) -> Indices<'s, 'h> {
        #[cfg(target_arch = "x86_64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }
    }
}

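/// Iterator over the byte offsets of every needle found in a haystack,
/// as produced by [`Searcher::search`].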
#[derive(Debug)]
pub struct Indices<'s, 'h> {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Indices<'s, 'h>,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonIndices<'s, 'h>,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::ThreeIter<'s, 'h>,
}

impl<'s, 'h> FusedIterator for Indices<'s, 'h> {}

impl<'s, 'h> Iterator for Indices<'s, 'h> {
    type Item = usize;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe { self.inner.next() }
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe { self.inner.next() }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            self.inner.next()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use memchr::arch::all::memchr::Three;

    static TEST_STRING: &[u8] = b"name,\"surname\",age,color,oper\n,\n,\nation,punctuation\nname,surname,age,color,operation,punctuation";
    static TEST_STRING_OFFSETS: &[usize; 18] = &[
        4, 5, 13, 14, 18, 24, 29, 30, 31, 32, 33, 39, 51, 56, 64, 68, 74, 84,
    ];

    #[test]
    fn test_scalar_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Three::new(b',', b'"', b'\n');
            searcher.iter(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        // Not found at all
        assert!(split("b".repeat(75).as_bytes()).is_empty());

        // Regular
        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        // Exactly 64
        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        // Less than 32
        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        // Less than 16
        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);
    }

    #[test]
    fn test_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Searcher::new(b',', b'"', b'\n');
            searcher.search(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        // Not found at all
        assert!(split("b".repeat(75).as_bytes()).is_empty());

        // Regular
        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        // Exactly 64
        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        // Less than 32
        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        // Less than 16
        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);

        // Complex input
        let complex = b"name,surname,age\n\"john\",\"landy, the \"\"everlasting\"\" bastard\",45\nlucy,rose,\"67\"\njermaine,jackson,\"89\"\n\nkarine,loucan,\"52\"\nrose,\"glib\",12\n\"guillaume\",\"plique\",\"42\"\r\n";
        let complex_indices = split(complex);

        assert!(complex_indices
            .iter()
            .copied()
            .all(|c| complex[c] == b',' || complex[c] == b'\n' || complex[c] == b'"'));

        assert_eq!(
            complex_indices,
            Three::new(b',', b'\n', b'"')
                .iter(complex)
                .collect::<Vec<_>>()
        );
    }
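
    // A few extra edge cases sketched on top of the public API exercised
    // above: empty input, inputs shorter than one 16-byte vector, and a match
    // sitting just past the first full vector.
    #[test]
    fn test_searcher_edge_cases() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Searcher::new(b',', b'"', b'\n');
            searcher.search(haystack).collect()
        }

        // Empty haystack yields nothing
        assert!(split(b"").is_empty());

        // Shorter than one vector, with matches at both ends
        assert_eq!(split(b",a,"), vec![0usize, 2]);

        // Exactly one full vector of non-matching bytes, then a delimiter
        let mut long = "a".repeat(16).into_bytes();
        long.push(b',');
        assert_eq!(split(&long), vec![16usize]);
    }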
}