//! Vectorized search for the offsets of any of three needle bytes in a
//! haystack: SSE2 on x86_64, NEON on aarch64, with the scalar implementation
//! from the `memchr` crate as a fallback on other architectures.

use std::iter::FusedIterator;

#[cfg(target_arch = "x86_64")]
mod x86_64 {
    use std::marker::PhantomData;

    use crate::ext::Pointer;

    #[inline(always)]
    fn get_for_offset(mask: u32) -> u32 {
        #[cfg(target_endian = "big")]
        {
            mask.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            mask
        }
    }

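    // `_mm_movemask_epi8` sets one bit per haystack byte, so the position of
    // the lowest set bit is the offset of the first match within a 16-byte
    // chunk, and clearing that bit steps to the next match in the same chunk.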
    #[inline(always)]
    fn first_offset(mask: u32) -> usize {
        get_for_offset(mask).trailing_zeros() as usize
    }

    #[inline(always)]
    fn clear_least_significant_bit(mask: u32) -> u32 {
        mask & (mask - 1)
    }

    pub mod sse2 {
        use super::*;

        use core::arch::x86_64::{
            __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128,
            _mm_set1_epi8,
        };

        #[derive(Debug)]
        pub struct SSE2Searcher {
            n1: u8,
            n2: u8,
            n3: u8,
            v1: __m128i,
            v2: __m128i,
            v3: __m128i,
        }

        impl SSE2Searcher {
            #[inline]
            pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
                Self {
                    n1,
                    n2,
                    n3,
                    v1: _mm_set1_epi8(n1 as i8),
                    v2: _mm_set1_epi8(n2 as i8),
                    v3: _mm_set1_epi8(n3 as i8),
                }
            }

            #[inline(always)]
            pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> SSE2Indices<'s, 'h> {
                SSE2Indices::new(self, haystack)
            }
        }

        #[derive(Debug)]
        pub struct SSE2Indices<'s, 'h> {
            searcher: &'s SSE2Searcher,
            haystack: PhantomData<&'h [u8]>,
            start: *const u8,
            end: *const u8,
            current: *const u8,
            mask: u32,
        }

        impl<'s, 'h> SSE2Indices<'s, 'h> {
            #[inline]
            fn new(searcher: &'s SSE2Searcher, haystack: &'h [u8]) -> Self {
                let ptr = haystack.as_ptr();

                Self {
                    searcher,
                    haystack: PhantomData,
                    start: ptr,
                    end: ptr.wrapping_add(haystack.len()),
                    current: ptr,
                    mask: 0,
                }
            }
        }

        // Number of bytes handled per vector iteration: one 128-bit register.
        const SSE2_STEP: usize = 16;

        impl<'s, 'h> SSE2Indices<'s, 'h> {
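            /// Advances the iterator: first drains any matches still recorded
            /// in the mask from the previously loaded chunk, then scans the
            /// haystack 16 bytes at a time, and finally falls back to a
            /// byte-by-byte scan for a tail shorter than one chunk.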
            pub unsafe fn next(&mut self) -> Option<usize> {
                if self.start >= self.end {
                    return None;
                }

                let mut mask = self.mask;
                let vectorized_end = self.end.sub(SSE2_STEP);
                let mut current = self.current;
                let start = self.start;
                let v1 = self.searcher.v1;
                let v2 = self.searcher.v2;
                let v3 = self.searcher.v3;

                'main: loop {
                    // Report matches left over from the last chunk we compared.
                    if mask != 0 {
                        let offset = current.sub(SSE2_STEP).add(first_offset(mask));
                        self.mask = clear_least_significant_bit(mask);
                        self.current = current;

                        return Some(offset.distance(start));
                    }

                    // Compare 16 bytes at a time against the three needles.
                    while current <= vectorized_end {
                        let chunk = _mm_loadu_si128(current as *const __m128i);
                        let cmp1 = _mm_cmpeq_epi8(chunk, v1);
                        let cmp2 = _mm_cmpeq_epi8(chunk, v2);
                        let cmp3 = _mm_cmpeq_epi8(chunk, v3);
                        let cmp = _mm_or_si128(cmp1, cmp2);
                        let cmp = _mm_or_si128(cmp, cmp3);

                        mask = _mm_movemask_epi8(cmp) as u32;

                        current = current.add(SSE2_STEP);

                        if mask != 0 {
                            continue 'main;
                        }
                    }

                    // Scalar scan over the tail that does not fill a chunk.
                    while current < self.end {
                        if *current == self.searcher.n1
                            || *current == self.searcher.n2
                            || *current == self.searcher.n3
                        {
                            let offset = current.distance(start);
                            self.current = current.add(1);
                            return Some(offset);
                        }
                        current = current.add(1);
                    }

                    return None;
                }
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
mod aarch64 {
    use core::arch::aarch64::{
        uint8x16_t, vceqq_u8, vdupq_n_u8, vget_lane_u64, vld1q_u8, vorrq_u8, vreinterpret_u64_u8,
        vreinterpretq_u16_u8, vshrn_n_u16,
    };
    use std::marker::PhantomData;

    use crate::ext::Pointer;

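    // NEON has no direct `_mm_movemask_epi8` equivalent. The narrowing shift
    // (`vshrn`) packs each comparison byte into a nibble of a 64-bit value,
    // and masking with 0x8888... keeps a single bit per haystack byte; that is
    // why `first_offset` divides the trailing-zero count by four.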
    #[inline(always)]
    unsafe fn neon_movemask(v: uint8x16_t) -> u64 {
        let asu16s = vreinterpretq_u16_u8(v);
        let mask = vshrn_n_u16(asu16s, 4);
        let asu64 = vreinterpret_u64_u8(mask);
        let scalar64 = vget_lane_u64(asu64, 0);

        scalar64 & 0x8888888888888888
    }

    #[inline(always)]
    fn first_offset(mask: u64) -> usize {
        (mask.trailing_zeros() >> 2) as usize
    }

    #[inline(always)]
    fn clear_least_significant_bit(mask: u64) -> u64 {
        mask & (mask - 1)
    }

    #[derive(Debug)]
    pub struct NeonSearcher {
        n1: u8,
        n2: u8,
        n3: u8,
        v1: uint8x16_t,
        v2: uint8x16_t,
        v3: uint8x16_t,
    }

    impl NeonSearcher {
        #[inline]
        pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
            Self {
                n1,
                n2,
                n3,
                v1: vdupq_n_u8(n1),
                v2: vdupq_n_u8(n2),
                v3: vdupq_n_u8(n3),
            }
        }

        #[inline(always)]
        pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> NeonIndices<'s, 'h> {
            NeonIndices::new(self, haystack)
        }
    }

    #[derive(Debug)]
    pub struct NeonIndices<'s, 'h> {
        searcher: &'s NeonSearcher,
        haystack: PhantomData<&'h [u8]>,
        start: *const u8,
        end: *const u8,
        current: *const u8,
        mask: u64,
    }

    impl<'s, 'h> NeonIndices<'s, 'h> {
        #[inline]
        fn new(searcher: &'s NeonSearcher, haystack: &'h [u8]) -> Self {
            let ptr = haystack.as_ptr();

            Self {
                searcher,
                haystack: PhantomData,
                start: ptr,
                end: ptr.wrapping_add(haystack.len()),
                current: ptr,
                mask: 0,
            }
        }
    }

    // Number of bytes handled per vector iteration: one 128-bit NEON register.
    const NEON_STEP: usize = 16;

    impl<'s, 'h> NeonIndices<'s, 'h> {
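        /// Advances the iterator: first drains any matches still recorded in
        /// the mask from the previously loaded chunk, then scans the haystack
        /// 16 bytes at a time, and finally falls back to a byte-by-byte scan
        /// for a tail shorter than one chunk.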
        pub unsafe fn next(&mut self) -> Option<usize> {
            if self.start >= self.end {
                return None;
            }

            let mut mask = self.mask;
            let vectorized_end = self.end.sub(NEON_STEP);
            let mut current = self.current;
            let start = self.start;
            let v1 = self.searcher.v1;
            let v2 = self.searcher.v2;
            let v3 = self.searcher.v3;

            'main: loop {
                // Report matches left over from the last chunk we compared.
                if mask != 0 {
                    let offset = current.sub(NEON_STEP).add(first_offset(mask));
                    self.mask = clear_least_significant_bit(mask);
                    self.current = current;

                    return Some(offset.distance(start));
                }

                // Compare 16 bytes at a time against the three needles.
                while current <= vectorized_end {
                    let chunk = vld1q_u8(current);
                    let cmp1 = vceqq_u8(chunk, v1);
                    let cmp2 = vceqq_u8(chunk, v2);
                    let cmp3 = vceqq_u8(chunk, v3);
                    let cmp = vorrq_u8(cmp1, cmp2);
                    let cmp = vorrq_u8(cmp, cmp3);

                    mask = neon_movemask(cmp);

                    current = current.add(NEON_STEP);

                    if mask != 0 {
                        continue 'main;
                    }
                }

                // Scalar scan over the tail that does not fill a chunk.
                while current < self.end {
                    if *current == self.searcher.n1
                        || *current == self.searcher.n2
                        || *current == self.searcher.n3
                    {
                        let offset = current.distance(start);
                        self.current = current.add(1);
                        return Some(offset);
                    }
                    current = current.add(1);
                }

                return None;
            }
        }
    }
}

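/// Searches a haystack for the positions of any of three needle bytes,
/// dispatching to SSE2 on x86_64, NEON on aarch64, and the scalar `memchr`
/// implementation elsewhere.
///
/// A minimal sketch of the intended usage, mirroring the tests below (the
/// exact import path depends on where this module lives in the crate):
///
/// ```ignore
/// let searcher = Searcher::new(b',', b'"', b'\n');
/// let offsets: Vec<usize> = searcher.search(b"a,b\n").collect();
/// assert_eq!(offsets, vec![1, 3]);
/// ```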
#[derive(Debug)]
pub struct Searcher {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Searcher,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonSearcher,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::Three,
}

impl Searcher {
    pub fn leveraged_simd_instructions() -> &'static str {
        #[cfg(target_arch = "x86_64")]
        {
            "sse2"
        }

        #[cfg(target_arch = "aarch64")]
        {
            "neon"
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            "none"
        }
    }

    #[inline(always)]
    pub fn new(n1: u8, n2: u8, n3: u8) -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe {
                Self {
                    inner: x86_64::sse2::SSE2Searcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe {
                Self {
                    inner: aarch64::NeonSearcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self {
                inner: memchr::arch::all::memchr::Three::new(n1, n2, n3),
            }
        }
    }

    #[inline(always)]
    pub fn search<'s, 'h>(&'s self, haystack: &'h [u8]) -> Indices<'s, 'h> {
        #[cfg(target_arch = "x86_64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }
    }
}

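/// Iterator over the offsets of every needle occurrence in a haystack,
/// yielded in ascending order. Created by [`Searcher::search`].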
#[derive(Debug)]
pub struct Indices<'s, 'h> {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Indices<'s, 'h>,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonIndices<'s, 'h>,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::ThreeIter<'s, 'h>,
}

impl<'s, 'h> FusedIterator for Indices<'s, 'h> {}

impl<'s, 'h> Iterator for Indices<'s, 'h> {
    type Item = usize;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe { self.inner.next() }
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe { self.inner.next() }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            self.inner.next()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use memchr::arch::all::memchr::Three;

    static TEST_STRING: &[u8] = b"name,\"surname\",age,color,oper\n,\n,\nation,punctuation\nname,surname,age,color,operation,punctuation";
    static TEST_STRING_OFFSETS: &[usize; 18] = &[
        4, 5, 13, 14, 18, 24, 29, 30, 31, 32, 33, 39, 51, 56, 64, 68, 74, 84,
    ];

    #[test]
    fn test_scalar_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Three::new(b',', b'"', b'\n');
            searcher.iter(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        assert!(split("b".repeat(75).as_bytes()).is_empty());

        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);
    }

    #[test]
    fn test_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Searcher::new(b',', b'"', b'\n');
            searcher.search(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        assert!(split("b".repeat(75).as_bytes()).is_empty());

        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);

        let complex = b"name,surname,age\n\"john\",\"landy, the \"\"everlasting\"\" bastard\",45\nlucy,rose,\"67\"\njermaine,jackson,\"89\"\n\nkarine,loucan,\"52\"\nrose,\"glib\",12\n\"guillaume\",\"plique\",\"42\"\r\n";
        let complex_indices = split(complex);

        assert!(complex_indices
            .iter()
            .copied()
            .all(|c| complex[c] == b',' || complex[c] == b'\n' || complex[c] == b'"'));

        assert_eq!(
            complex_indices,
            Three::new(b',', b'\n', b'"')
                .iter(complex)
                .collect::<Vec<_>>()
        );
    }
}