simd_itertools/
position.rs

1use crate::LANE_COUNT;
2use multiversion::multiversion;
3use std::slice;
4
/// Extension trait offering a chunk-oriented variant of [`Iterator::position`].
///
/// `T` is the element type being searched; it must be comparable with
/// [`PartialEq`]. The lifetime `'a` is the borrow of the underlying data for
/// implementors that wrap borrowed storage.
pub trait PositionSimd<'a, T: PartialEq> {
    /// Returns the index of the first element for which `f` returns `true`,
    /// or `None` when no element matches.
    ///
    /// The receiver is only borrowed, so calling this does not consume it.
    fn position_simd<F: Fn(&T) -> bool>(&self, f: F) -> Option<usize>;
}
13impl<'a, T> PositionSimd<'a, T> for slice::Iter<'a, T>
14where
15    T: std::cmp::PartialEq,
16{
17    fn position_simd<F>(&self, f: F) -> Option<usize>
18    where
19        F: Fn(&T) -> bool,
20    {
21        position_autovec(self.as_slice(), f)
22    }
23}
24
25#[multiversion(targets = "simd")]
26pub fn position_autovec<F, T>(arr: &[T], f: F) -> Option<usize>
27where
28    F: Fn(&T) -> bool,
29{
30    let mut chunks = arr.chunks_exact(LANE_COUNT);
31    for (chunk_idx, chunk) in chunks.by_ref().enumerate() {
32        if chunk.iter().fold(false, |acc, x| acc | (f(x))) {
33            return Some(
34                chunk_idx * LANE_COUNT + unsafe { chunk.iter().position(f).unwrap_unchecked() },
35            );
36        }
37    }
38    chunks
39        .remainder()
40        .iter()
41        .position(f)
42        .map(|i| (arr.len() / LANE_COUNT) * LANE_COUNT + i)
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{distributions::Standard, prelude::Distribution, Rng};

    /// Compares `position_simd` against the standard `Iterator::position`
    /// on random data of type `T`, for every length in `0..5000` and a small
    /// set of predicates built around `T::default()`.
    fn test_simd_for_type<T>()
    where
        T: rand::distributions::uniform::SampleUniform + PartialOrd + Copy + Default,
        Standard: Distribution<T>,
    {
        // Non-capturing closures coerce to plain function pointers, which
        // lets them share one array and be reused for both searches.
        let predicates: [fn(&T) -> bool; 5] = [
            |x| *x == T::default(),
            |x| *x != T::default(),
            |x| *x < T::default(),
            |x| *x > T::default(),
            |x| [T::default()].contains(x),
        ];

        for len in 0..5000 {
            for predicate in predicates {
                let mut rng = rand::thread_rng();
                let data: Vec<T> = std::iter::repeat_with(|| rng.gen()).take(len).collect();

                let expected = data.iter().position(predicate);
                let actual = data.iter().position_simd(predicate);
                assert_eq!(
                    actual,
                    expected,
                    "Failed for length {} and type {:?}",
                    len,
                    std::any::type_name::<T>()
                );
            }
        }
    }

    /// Runs the randomized comparison for all primitive numeric types of
    /// interest.
    #[test]
    fn test_simd() {
        macro_rules! check_types {
            ($($ty:ty),* $(,)?) => {
                $( test_simd_for_type::<$ty>(); )*
            };
        }

        check_types!(i8, i16, i32, i64, u8, u16, u32, u64, usize, isize, f32, f64);
    }
}