use std::iter::FusedIterator;

#[cfg(target_arch = "x86_64")]
mod x86_64 {
    use std::marker::PhantomData;

    use crate::ext::Pointer;

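    /// Normalizes a comparison mask so that `trailing_zeros` points at the
    /// first matching byte regardless of the target's byte order.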
    #[inline(always)]
    fn get_for_offset(mask: u32) -> u32 {
        #[cfg(target_endian = "big")]
        {
            mask.swap_bytes()
        }
        #[cfg(target_endian = "little")]
        {
            mask
        }
    }

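    /// Index of the first matching byte within a chunk, i.e. the position of
    /// the least significant set bit of the normalized mask.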
    #[inline(always)]
    fn first_offset(mask: u32) -> usize {
        get_for_offset(mask).trailing_zeros() as usize
    }

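    /// Clears the least significant set bit so that iteration can resume at
    /// the next match within the same chunk.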
    #[inline(always)]
    fn clear_least_significant_bit(mask: u32) -> u32 {
        mask & (mask - 1)
    }

    pub mod sse2 {
        use super::*;

        use core::arch::x86_64::{
            __m128i, _mm_cmpeq_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_or_si128,
            _mm_set1_epi8,
        };

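        /// Searches for any one of three needle bytes. Each needle is kept
        /// both as a scalar (for the tail loop) and splatted across a 128-bit
        /// SSE2 register (for the vectorized loop).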
        #[derive(Debug)]
        pub struct SSE2Searcher {
            n1: u8,
            n2: u8,
            n3: u8,
            v1: __m128i,
            v2: __m128i,
            v3: __m128i,
        }

        impl SSE2Searcher {
            #[inline]
            pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
                Self {
                    n1,
                    n2,
                    n3,
                    v1: _mm_set1_epi8(n1 as i8),
                    v2: _mm_set1_epi8(n2 as i8),
                    v3: _mm_set1_epi8(n3 as i8),
                }
            }

            #[inline(always)]
            pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> SSE2Indices<'s, 'h> {
                SSE2Indices::new(self, haystack)
            }
        }

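        /// Iterator over the offsets of all matches. Raw pointers track
        /// progress through the haystack, `mask` holds matches found in the
        /// last chunk but not yet yielded, and `PhantomData` ties the
        /// iterator to the haystack's lifetime.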
        #[derive(Debug)]
        pub struct SSE2Indices<'s, 'h> {
            searcher: &'s SSE2Searcher,
            haystack: PhantomData<&'h [u8]>,
            start: *const u8,
            end: *const u8,
            current: *const u8,
            mask: u32,
        }

        impl<'s, 'h> SSE2Indices<'s, 'h> {
            #[inline]
            fn new(searcher: &'s SSE2Searcher, haystack: &'h [u8]) -> Self {
                let ptr = haystack.as_ptr();

                Self {
                    searcher,
                    haystack: PhantomData,
                    start: ptr,
                    end: ptr.wrapping_add(haystack.len()),
                    current: ptr,
                    mask: 0,
                }
            }
        }

        const SSE2_STEP: usize = 16;

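        // `next` works in three phases: drain matches still pending in the
        // saved mask, scan 16 bytes at a time while at least one full chunk
        // remains, then finish the tail with a scalar loop.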
        impl SSE2Indices<'_, '_> {
            pub unsafe fn next(&mut self) -> Option<usize> {
                if self.start >= self.end {
                    return None;
                }

                let mut mask = self.mask;

                let mut current = self.current;
                let start = self.start;
                let len = self.end.distance(start);
                let v1 = self.searcher.v1;
                let v2 = self.searcher.v2;
                let v3 = self.searcher.v3;

                'main: loop {
                    // Phase 1: yield matches already recorded in the mask.
                    if mask != 0 {
                        let offset = current.sub(SSE2_STEP).add(first_offset(mask));
                        self.mask = clear_least_significant_bit(mask);
                        self.current = current;

                        return Some(offset.distance(start));
                    }

                    // Phase 2: compare 16 bytes at a time against the three
                    // splatted needles.
                    if len >= SSE2_STEP {
                        let vectorized_end = self.end.sub(SSE2_STEP);

                        while current <= vectorized_end {
                            let chunk = _mm_loadu_si128(current as *const __m128i);
                            let cmp1 = _mm_cmpeq_epi8(chunk, v1);
                            let cmp2 = _mm_cmpeq_epi8(chunk, v2);
                            let cmp3 = _mm_cmpeq_epi8(chunk, v3);
                            let cmp = _mm_or_si128(cmp1, cmp2);
                            let cmp = _mm_or_si128(cmp, cmp3);

                            mask = _mm_movemask_epi8(cmp) as u32;

                            current = current.add(SSE2_STEP);

                            if mask != 0 {
                                continue 'main;
                            }
                        }
                    }

                    // Phase 3: scan the remaining tail byte by byte.
                    while current < self.end {
                        if *current == self.searcher.n1
                            || *current == self.searcher.n2
                            || *current == self.searcher.n3
                        {
                            let offset = current.distance(start);
                            self.current = current.add(1);
                            return Some(offset);
                        }
                        current = current.add(1);
                    }

                    return None;
                }
            }
        }
    }
}

#[cfg(target_arch = "aarch64")]
mod aarch64 {
    use core::arch::aarch64::{
        uint8x16_t, vceqq_u8, vdupq_n_u8, vget_lane_u64, vld1q_u8, vorrq_u8, vreinterpret_u64_u8,
        vreinterpretq_u16_u8, vshrn_n_u16,
    };
    use std::marker::PhantomData;

    use crate::ext::Pointer;

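    /// NEON has no direct equivalent of `_mm_movemask_epi8`. The `vshrn`
    /// trick (shift each 16-bit lane right by 4, then narrow) compresses the
    /// 128-bit comparison result into a 64-bit scalar carrying 4 bits per
    /// input byte; masking with 0x8888_8888_8888_8888 keeps one bit per byte.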
    #[inline(always)]
    unsafe fn neon_movemask(v: uint8x16_t) -> u64 {
        let asu16s = vreinterpretq_u16_u8(v);
        let mask = vshrn_n_u16(asu16s, 4);
        let asu64 = vreinterpret_u64_u8(mask);
        let scalar64 = vget_lane_u64(asu64, 0);

        scalar64 & 0x8888888888888888
    }

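    /// `neon_movemask` produces 4 bits per haystack byte, so the trailing
    /// zero count is divided by 4 to recover the byte index within the chunk.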
    #[inline(always)]
    fn first_offset(mask: u64) -> usize {
        (mask.trailing_zeros() >> 2) as usize
    }

    #[inline(always)]
    fn clear_least_significant_bit(mask: u64) -> u64 {
        mask & (mask - 1)
    }

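    /// NEON counterpart of the SSE2 searcher: each needle byte is kept as a
    /// scalar and splatted across a 128-bit vector.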
    #[derive(Debug)]
    pub struct NeonSearcher {
        n1: u8,
        n2: u8,
        n3: u8,
        v1: uint8x16_t,
        v2: uint8x16_t,
        v3: uint8x16_t,
    }

    impl NeonSearcher {
        #[inline]
        pub unsafe fn new(n1: u8, n2: u8, n3: u8) -> Self {
            Self {
                n1,
                n2,
                n3,
                v1: vdupq_n_u8(n1),
                v2: vdupq_n_u8(n2),
                v3: vdupq_n_u8(n3),
            }
        }

        #[inline(always)]
        pub fn iter<'s, 'h>(&'s self, haystack: &'h [u8]) -> NeonIndices<'s, 'h> {
            NeonIndices::new(self, haystack)
        }
    }

    #[derive(Debug)]
    pub struct NeonIndices<'s, 'h> {
        searcher: &'s NeonSearcher,
        haystack: PhantomData<&'h [u8]>,
        start: *const u8,
        end: *const u8,
        current: *const u8,
        mask: u64,
    }

    impl<'s, 'h> NeonIndices<'s, 'h> {
        #[inline]
        fn new(searcher: &'s NeonSearcher, haystack: &'h [u8]) -> Self {
            let ptr = haystack.as_ptr();

            Self {
                searcher,
                haystack: PhantomData,
                start: ptr,
                end: ptr.wrapping_add(haystack.len()),
                current: ptr,
                mask: 0,
            }
        }
    }

    const NEON_STEP: usize = 16;

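    // Mirrors the SSE2 iterator: drain the pending mask, scan 16-byte
    // chunks, then finish the tail with a scalar loop.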
    impl NeonIndices<'_, '_> {
        pub unsafe fn next(&mut self) -> Option<usize> {
            if self.start >= self.end {
                return None;
            }

            let mut mask = self.mask;
            let mut current = self.current;
            let start = self.start;
            let len = self.end.distance(start);
            let v1 = self.searcher.v1;
            let v2 = self.searcher.v2;
            let v3 = self.searcher.v3;

            'main: loop {
                // Phase 1: yield matches already recorded in the mask.
                if mask != 0 {
                    let offset = current.sub(NEON_STEP).add(first_offset(mask));
                    self.mask = clear_least_significant_bit(mask);
                    self.current = current;

                    return Some(offset.distance(start));
                }

                // Phase 2: compare 16 bytes at a time against the three
                // splatted needles.
                if len >= NEON_STEP {
                    let vectorized_end = self.end.sub(NEON_STEP);

                    while current <= vectorized_end {
                        let chunk = vld1q_u8(current);
                        let cmp1 = vceqq_u8(chunk, v1);
                        let cmp2 = vceqq_u8(chunk, v2);
                        let cmp3 = vceqq_u8(chunk, v3);
                        let cmp = vorrq_u8(cmp1, cmp2);
                        let cmp = vorrq_u8(cmp, cmp3);

                        mask = neon_movemask(cmp);

                        current = current.add(NEON_STEP);

                        if mask != 0 {
                            continue 'main;
                        }
                    }
                }

                // Phase 3: scan the remaining tail byte by byte.
                while current < self.end {
                    if *current == self.searcher.n1
                        || *current == self.searcher.n2
                        || *current == self.searcher.n3
                    {
                        let offset = current.distance(start);
                        self.current = current.add(1);
                        return Some(offset);
                    }
                    current = current.add(1);
                }

                return None;
            }
        }
    }
}

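/// Returns the name of the SIMD instruction set used by [`Searcher`] on the
/// current target: `"sse2"` on x86_64, `"neon"` on aarch64, `"none"` otherwise.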
pub fn searcher_simd_instructions() -> &'static str {
    #[cfg(target_arch = "x86_64")]
    {
        "sse2"
    }

    #[cfg(target_arch = "aarch64")]
    {
        "neon"
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        "none"
    }
}

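/// Searcher for three needle bytes at once, dispatched at compile time: the
/// SSE2 implementation on x86_64, the NEON implementation on aarch64, and
/// `memchr`'s scalar `Three` searcher on every other architecture.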
#[derive(Debug)]
pub struct Searcher {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Searcher,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonSearcher,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::Three,
}

impl Searcher {
    #[inline(always)]
    pub fn new(n1: u8, n2: u8, n3: u8) -> Self {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe {
                Self {
                    inner: x86_64::sse2::SSE2Searcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe {
                Self {
                    inner: aarch64::NeonSearcher::new(n1, n2, n3),
                }
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Self {
                inner: memchr::arch::all::memchr::Three::new(n1, n2, n3),
            }
        }
    }

    #[inline(always)]
    pub fn search<'s, 'h>(&'s self, haystack: &'h [u8]) -> Indices<'s, 'h> {
        #[cfg(target_arch = "x86_64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(target_arch = "aarch64")]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            Indices {
                inner: self.inner.iter(haystack),
            }
        }
    }
}

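/// Iterator over the offsets of all needle matches, wrapping the
/// arch-specific iterator selected at compile time.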
#[derive(Debug)]
pub struct Indices<'s, 'h> {
    #[cfg(target_arch = "x86_64")]
    inner: x86_64::sse2::SSE2Indices<'s, 'h>,

    #[cfg(target_arch = "aarch64")]
    inner: aarch64::NeonIndices<'s, 'h>,

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    inner: memchr::arch::all::memchr::ThreeIter<'s, 'h>,
}

impl FusedIterator for Indices<'_, '_> {}

impl Iterator for Indices<'_, '_> {
    type Item = usize;

    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        #[cfg(target_arch = "x86_64")]
        {
            unsafe { self.inner.next() }
        }

        #[cfg(target_arch = "aarch64")]
        {
            unsafe { self.inner.next() }
        }

        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            self.inner.next()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use memchr::arch::all::memchr::Three;

    static TEST_STRING: &[u8] = b"name,\"surname\",age,color,oper\n,\n,\nation,punctuation\nname,surname,age,color,operation,punctuation";
    static TEST_STRING_OFFSETS: &[usize; 18] = &[
        4, 5, 13, 14, 18, 24, 29, 30, 31, 32, 33, 39, 51, 56, 64, 68, 74, 84,
    ];

    #[test]
    fn test_scalar_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Three::new(b',', b'"', b'\n');
            searcher.iter(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        assert!(split("b".repeat(75).as_bytes()).is_empty());

        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);
    }

    #[test]
    fn test_searcher() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Searcher::new(b',', b'"', b'\n');
            searcher.search(haystack).collect()
        }

        let offsets = split(TEST_STRING);
        assert_eq!(offsets, TEST_STRING_OFFSETS);

        assert!(split("b".repeat(75).as_bytes()).is_empty());

        assert_eq!(split("b,".repeat(75).as_bytes()).len(), 75);

        assert_eq!(split("b,".repeat(64).as_bytes()).len(), 64);

        assert_eq!(split("b,".repeat(25).as_bytes()).len(), 25);

        assert_eq!(split("b,".repeat(13).as_bytes()).len(), 13);

        let complex = b"name,surname,age\n\"john\",\"landy, the \"\"everlasting\"\" bastard\",45\nlucy,rose,\"67\"\njermaine,jackson,\"89\"\n\nkarine,loucan,\"52\"\nrose,\"glib\",12\n\"guillaume\",\"plique\",\"42\"\r\n";
        let complex_indices = split(complex);

        assert!(complex_indices
            .iter()
            .copied()
            .all(|c| complex[c] == b',' || complex[c] == b'\n' || complex[c] == b'"'));

        assert_eq!(
            complex_indices,
            Three::new(b',', b'\n', b'"')
                .iter(complex)
                .collect::<Vec<_>>()
        );
    }
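
    // Edge cases: empty input, input shorter than one 16-byte register (only
    // the scalar tail loop runs), and input where every byte is a needle.
    #[test]
    fn test_searcher_edge_cases() {
        fn split(haystack: &[u8]) -> Vec<usize> {
            let searcher = Searcher::new(b',', b'"', b'\n');
            searcher.search(haystack).collect()
        }

        // An empty haystack yields no offsets.
        assert!(split(b"").is_empty());

        // Sub-register input is handled entirely by the scalar tail.
        assert_eq!(split(b"a,b\"c"), vec![1, 3]);

        // Every byte matches: one 16-byte chunk is fully drained from the
        // saved mask, then the 4-byte tail is scanned.
        assert_eq!(split(&b",".repeat(20)), (0..20).collect::<Vec<_>>());
    }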
}