simd_csv/
searcher.rs

use std::iter::FusedIterator;

#[cfg(target_arch = "x86_64")]
mod x86_64 {
    use std::marker::PhantomData;

    use crate::ext::Pointer;

    #[inline(always)]
    fn get_for_offset(mask: u32) -> u32 {
        #[cfg(target_endian = "big")]
        {
            mask.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            mask
        }
    }

    // Byte offset of the first match recorded in a move mask.
    #[inline(always)]
    fn first_offset(mask: u32) -> usize {
        get_for_offset(mask).trailing_zeros() as usize
    }

    // Clears the lowest set bit, i.e. consumes the match that was just reported.
    #[inline(always)]
    fn clear_least_significant_bit(mask: u32) -> u32 {
        mask & (mask - 1)
    }
    pub mod sse2 {
        use super::*;

        use core::arch::x86_64::{
            __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128,
            _mm_set1_epi8,
        };

        #[derive(Debug)]
        pub struct SSE2Searcher {
            n1: u8,
            n2: u8,
            n3: u8,
            v1: __m128i,
            v2: __m128i,
            v3: __m128i,
        }

        impl SSE2Searcher {
            #[inline]
            pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
                Self {
                    n1,
                    n2,
                    n3,
                    v1: _mm_set1_epi8(n1 as i8),
                    v2: _mm_set1_epi8(n2 as i8),
                    v3: _mm_set1_epi8(n3 as i8),
                }
            }

            #[inline(always)]
            pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> SSE2Indices<'s, 'h> {
                SSE2Indices::new(self, haystack)
            }
        }

        #[derive(Debug)]
        pub struct SSE2Indices<'s, 'h> {
            searcher: &'s SSE2Searcher,
            haystack: PhantomData<&'h [u8]>,
            start: *const u8,
            end: *const u8,
            current: *const u8,
            mask: u32,
        }

        impl<'s, 'h> SSE2Indices<'s, 'h> {
            #[inline]
            fn new(searcher: &'s SSE2Searcher, haystack: &'h [u8]) -> Self {
                let ptr = haystack.as_ptr();

                Self {
                    searcher,
                    haystack: PhantomData,
                    start: ptr,
                    end: ptr.wrapping_add(haystack.len()),
                    current: ptr,
                    mask: 0,
                }
            }
        }

        const SSE2_STEP: usize = 16;

        impl SSE2Indices<'_, '_> {
            pub unsafe fn next(&mut self) -> Option<usize> {
                if self.start >= self.end {
                    return None;
                }

                let mut mask = self.mask;

                let mut current = self.current;
                let start = self.start;
                let len = self.end.distance(start);
                let v1 = self.searcher.v1;
                let v2 = self.searcher.v2;
                let v3 = self.searcher.v3;

                'main: loop {
                    // Processing current move mask
                    if mask != 0 {
                        // `current` already points one full step past the chunk
                        // that produced this mask.
                        let offset = current.sub(SSE2_STEP).add(first_offset(mask));
                        self.mask = clear_least_significant_bit(mask);
                        self.current = current;

                        return Some(offset.distance(start));
                    }

                    // Main loop of unaligned loads
                    if len >= SSE2_STEP {
                        let vectorized_end = self.end.sub(SSE2_STEP);

                        while current <= vectorized_end {
                            let chunk = _mm_loadu_si128(current as *const __m128i);
                            let cmp1 = _mm_cmpeq_epi8(chunk, v1);
                            let cmp2 = _mm_cmpeq_epi8(chunk, v2);
                            let cmp3 = _mm_cmpeq_epi8(chunk, v3);
                            let cmp = _mm_or_si128(cmp1, cmp2);
                            let cmp = _mm_or_si128(cmp, cmp3);

                            mask = _mm_movemask_epi8(cmp) as u32;

                            current = current.add(SSE2_STEP);

                            if mask != 0 {
                                continue 'main;
                            }
                        }
                    }

                    // Processing remaining bytes linearly
                    while current < self.end {
                        if *current == self.searcher.n1
                            || *current == self.searcher.n2
                            || *current == self.searcher.n3
                        {
                            let offset = current.distance(start);
                            self.current = current.add(1);
                            return Some(offset);
                        }
                        current = current.add(1);
                    }

                    return None;
                }
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
mod aarch64 {
    use core::arch::aarch64::{
        uint8x16_t, vceqq_u8, vdupq_n_u8, vget_lane_u64, vld1q_u8, vorrq_u8, vreinterpret_u64_u8,
        vreinterpretq_u16_u8, vshrn_n_u16,
    };
    use std::marker::PhantomData;

    use crate::ext::Pointer;

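    // NEON has no direct equivalent of `_mm_movemask_epi8`. The helper below
    // uses the common "shift right and narrow" trick: the 16 comparison bytes
    // (each 0x00 or 0xFF) are reinterpreted as eight u16 lanes, shifted right
    // by 4 and narrowed to u8, yielding a 64-bit value holding one nibble per
    // input byte. Masking with 0x8888_8888_8888_8888 keeps a single bit per
    // nibble, which is why `first_offset` divides `trailing_zeros` by 4.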
    #[inline(always)]
    unsafe fn neon_movemask(v: uint8x16_t) -> u64 {
        let asu16s = vreinterpretq_u16_u8(v);
        let mask = vshrn_n_u16(asu16s, 4);
        let asu64 = vreinterpret_u64_u8(mask);
        let scalar64 = vget_lane_u64(asu64, 0);

        scalar64 & 0x8888888888888888
    }

    #[inline(always)]
    fn first_offset(mask: u64) -> usize {
        (mask.trailing_zeros() >> 2) as usize
    }

    #[inline(always)]
    fn clear_least_significant_bit(mask: u64) -> u64 {
        mask & (mask - 1)
    }

    #[derive(Debug)]
    pub struct NeonSearcher {
        n1: u8,
        n2: u8,
        n3: u8,
        v1: uint8x16_t,
        v2: uint8x16_t,
        v3: uint8x16_t,
    }

    impl NeonSearcher {
        #[inline]
        pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
            Self {
                n1,
                n2,
                n3,
                v1: vdupq_n_u8(n1),
                v2: vdupq_n_u8(n2),
                v3: vdupq_n_u8(n3),
            }
        }

        #[inline(always)]
        pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> NeonIndices<'s, 'h> {
            NeonIndices::new(self, haystack)
        }
    }

    #[derive(Debug)]
    pub struct NeonIndices<'s, 'h> {
        searcher: &'s NeonSearcher,
        haystack: PhantomData<&'h [u8]>,
        start: *const u8,
        end: *const u8,
        current: *const u8,
        mask: u64,
    }

    impl<'s, 'h> NeonIndices<'s, 'h> {
        #[inline]
        fn new(searcher: &'s NeonSearcher, haystack: &'h [u8]) -> Self {
            let ptr = haystack.as_ptr();

            Self {
                searcher,
                haystack: PhantomData,
                start: ptr,
                end: ptr.wrapping_add(haystack.len()),
                current: ptr,
                mask: 0,
            }
        }
    }

    const NEON_STEP: usize = 16;

    impl NeonIndices<'_, '_> {
        pub unsafe fn next(&mut self) -> Option<usize> {
            if self.start >= self.end {
                return None;
            }

            let mut mask = self.mask;
            let mut current = self.current;
            let start = self.start;
            let len = self.end.distance(start);
            let v1 = self.searcher.v1;
            let v2 = self.searcher.v2;
            let v3 = self.searcher.v3;

            'main: loop {
                // Processing current move mask
                if mask != 0 {
                    // `current` already points one full step past the chunk
                    // that produced this mask.
                    let offset = current.sub(NEON_STEP).add(first_offset(mask));
                    self.mask = clear_least_significant_bit(mask);
                    self.current = current;

                    return Some(offset.distance(start));
                }

                // Main loop of unaligned loads
                if len >= NEON_STEP {
                    let vectorized_end = self.end.sub(NEON_STEP);

                    while current <= vectorized_end {
                        let chunk = vld1q_u8(current);
                        let cmp1 = vceqq_u8(chunk, v1);
                        let cmp2 = vceqq_u8(chunk, v2);
                        let cmp3 = vceqq_u8(chunk, v3);
                        let cmp = vorrq_u8(cmp1, cmp2);
                        let cmp = vorrq_u8(cmp, cmp3);

                        mask = neon_movemask(cmp);

                        current = current.add(NEON_STEP);

                        if mask != 0 {
                            continue 'main;
                        }
                    }
                }

                // Processing remaining bytes linearly
                while current < self.end {
                    if *current == self.searcher.n1
                        || *current == self.searcher.n2
                        || *current == self.searcher.n3
                    {
                        let offset = current.distance(start);
                        self.current = current.add(1);
                        return Some(offset);
                    }
                    current = current.add(1);
                }

                return None;
            }
        }
    }
}

/// Returns the SIMD instruction set used by this crate's amortized
/// `memchr`-like searcher.
///
/// Note that the `memchr` routines also used by this crate might rely on
/// different instruction sets.
pub fn searcher_simd_instructions() -> &'static str {
    #[cfg(target_arch = "x86_64")]
    {
        "sse2"
    }

    #[cfg(target_arch = "aarch64")]
    {
        "neon"
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        "none"
    }
}

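/// Searcher reporting the offsets of any of three needle bytes in a haystack,
/// dispatching to SSE2 on `x86_64`, NEON on `aarch64`, and `memchr`'s scalar
/// `Three` searcher elsewhere.
///
/// Minimal usage sketch, mirroring the tests below (the exact import path
/// depends on how this module is exposed):
///
/// ```ignore
/// let searcher = Searcher::new(b',', b'"', b'\n');
/// let offsets: Vec<usize> = searcher.search(b"a,b\nc").collect();
/// assert_eq!(offsets, vec![1, 3]);
/// ```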
#[derive(Debug)]
pub struct Searcher {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Searcher,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonSearcher,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::Three,
}

impl Searcher {
    #[inline(always)]
    pub fn new(n1: u8, n2: u8, n3: u8) -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe {
                Self {
                    inner: x86_64::sse2::SSE2Searcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe {
                Self {
                    inner: aarch64::NeonSearcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self {
                inner: memchr::arch::all::memchr::Three::new(n1, n2, n3),
            }
        }
    }

    #[inline(always)]
    pub fn search<'s, 'h>(&'s self, haystack: &'h [u8]) -> Indices<'s, 'h> {
        #[cfg(target_arch = "x86_64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }
    }
}

#[derive(Debug)]
pub struct Indices<'s, 'h> {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Indices<'s, 'h>,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonIndices<'s, 'h>,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::ThreeIter<'s, 'h>,
}

impl FusedIterator for Indices<'_, '_> {}

impl Iterator for Indices<'_, '_> {
    type Item = usize;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe { self.inner.next() }
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe { self.inner.next() }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            self.inner.next()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use memchr::arch::all::memchr::Three;

    static TEST_STRING: &[u8] = b"name,\"surname\",age,color,oper\n,\n,\nation,punctuation\nname,surname,age,color,operation,punctuation";
    static TEST_STRING_OFFSETS: &[usize; 18] = &[
        4, 5, 13, 14, 18, 24, 29, 30, 31, 32, 33, 39, 51, 56, 64, 68, 74, 84,
    ];

    #[test]
    fn test_scalar_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Three::new(b',', b'"', b'\n');
            searcher.iter(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        // Not found at all
        assert!(split("b".repeat(75).as_bytes()).is_empty());

        // Regular
        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        // Exactly 64
        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        // Less than 32
        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        // Less than 16
        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);
    }

    #[test]
    fn test_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Searcher::new(b',', b'"', b'\n');
            searcher.search(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        // Not found at all
        assert!(split("b".repeat(75).as_bytes()).is_empty());

        // Regular
        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        // Exactly 64
        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        // Less than 32
        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        // Less than 16
        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);

        // Complex input
        let complex = b"name,surname,age\n\"john\",\"landy, the \"\"everlasting\"\" bastard\",45\nlucy,rose,\"67\"\njermaine,jackson,\"89\"\n\nkarine,loucan,\"52\"\nrose,\"glib\",12\n\"guillaume\",\"plique\",\"42\"\r\n";
        let complex_indices = split(complex);

        assert!(complex_indices
            .iter()
            .copied()
            .all(|c| complex[c] == b',' || complex[c] == b'\n' || complex[c] == b'"'));

        assert_eq!(
            complex_indices,
            Three::new(b',', b'\n', b'"')
                .iter(complex)
                .collect::<Vec<_>>()
        );
    }
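
    // Supplementary sketch (not part of the original suite): it assumes the
    // same `Searcher` and scalar `Three` APIs used above and checks that a
    // single needle is found at every position, including offsets landing
    // exactly on 16-byte chunk boundaries, in agreement with the scalar
    // searcher.
    #[test]
    fn test_searcher_single_needle_positions() {
        let searcher = Searcher::new(b',', b'"', b'\n');

        for len in 1..128 {
            for pos in 0..len {
                let mut haystack = vec![b'x'; len];
                haystack[pos] = b',';

                let simd: Vec<usize> = searcher.search(&haystack).collect();
                let scalar: Vec<usize> =
                    Three::new(b',', b'"', b'\n').iter(&haystack).collect();

                assert_eq!(simd, scalar);
                assert_eq!(simd, vec![pos]);
            }
        }
    }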
}