simd_itertools/
position.rs1use crate::LANE_COUNT;
2use multiversion::multiversion;
3use std::slice;
4
/// Extension trait providing a `position`-style search intended to be
/// friendlier to auto-vectorization than [`Iterator::position`].
///
/// The lifetime parameter `'a` is not used by the trait itself; it is
/// available so implementors can tie the implementation to a borrowed
/// source such as `std::slice::Iter<'a, T>`.
pub trait PositionSimd<'a, T>
where
    T: PartialEq,
{
    /// Searches for an element accepted by the predicate `f`.
    ///
    /// Returns `Some(index)` of a matching element, or `None` when the
    /// predicate rejects every element. The receiver is taken by shared
    /// reference, so calling this does not advance an underlying iterator.
    fn position_simd<F: Fn(&T) -> bool>(&self, f: F) -> Option<usize>;
}
13impl<'a, T> PositionSimd<'a, T> for slice::Iter<'a, T>
14where
15 T: std::cmp::PartialEq,
16{
17 fn position_simd<F>(&self, f: F) -> Option<usize>
18 where
19 F: Fn(&T) -> bool,
20 {
21 position_autovec(self.as_slice(), f)
22 }
23}
24
25#[multiversion(targets = "simd")]
26pub fn position_autovec<F, T>(arr: &[T], f: F) -> Option<usize>
27where
28 F: Fn(&T) -> bool,
29{
30 let mut chunks = arr.chunks_exact(LANE_COUNT);
31 for (chunk_idx, chunk) in chunks.by_ref().enumerate() {
32 if chunk.iter().fold(false, |acc, x| acc | (f(x))) {
33 return Some(
34 chunk_idx * LANE_COUNT + unsafe { chunk.iter().position(f).unwrap_unchecked() },
35 );
36 }
37 }
38 chunks
39 .remainder()
40 .iter()
41 .position(f)
42 .map(|i| (arr.len() / LANE_COUNT) * LANE_COUNT + i)
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distributions::Standard;
    use rand::prelude::Distribution;
    use rand::Rng;

    /// Compares `position_simd` with `Iterator::position` on randomly filled
    /// vectors of every length in `0..5000`, for several predicates built
    /// around `T::default()`.
    fn test_simd_for_type<T>()
    where
        T: rand::distributions::uniform::SampleUniform + PartialEq + PartialOrd + Copy + Default,
        Standard: Distribution<T>,
    {
        // Non-capturing closures coerce to plain function pointers, which are
        // `Copy`; that lets the same predicate be handed to both searches.
        let predicates: [fn(&T) -> bool; 5] = [
            |x: &T| *x == T::default(),
            |x: &T| *x != T::default(),
            |x: &T| *x < T::default(),
            |x: &T| *x > T::default(),
            |x: &T| [T::default()].contains(x),
        ];

        for len in 0..5000 {
            for op in predicates {
                let mut rng = rand::thread_rng();
                let v: Vec<T> = (0..len).map(|_| rng.gen()).collect();

                let ans = v.iter().position_simd(op);
                let correct = v.iter().position(op);
                assert_eq!(
                    ans,
                    correct,
                    "Failed for length {} and type {:?}",
                    len,
                    std::any::type_name::<T>()
                );
            }
        }
    }

    /// Runs the randomized comparison for each supported primitive type.
    #[test]
    fn test_simd() {
        macro_rules! check_types {
            ($($ty:ty),* $(,)?) => {
                $( test_simd_for_type::<$ty>(); )*
            };
        }

        check_types!(i8, i16, i32, i64, u8, u16, u32, u64, usize, isize, f32, f64);
    }
}