simd_itertools/
filter.rs

1use crate::PositionSimd;
2use std::slice;
3
/// Iterator returned by [`FilterSimd::filter_simd`].
///
/// Walks `arr` from index `position` onward and yields (by value) each
/// element for which `f` returns `true`.
///
/// Trait bounds live on the `impl` blocks that need them rather than on the
/// type itself, so merely naming `SimdFilter<'a, T, F>` imposes no extra
/// requirements on `T` or `F`.
#[derive(Clone)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct SimdFilter<'a, T, F> {
    /// Index into `arr` of the next element that has not been examined yet.
    pub position: usize,
    /// Predicate deciding which elements are yielded.
    pub f: F,
    /// The full slice being filtered; elements before `position` are done.
    pub arr: &'a [T],
}
13
14impl<'a, T, F> Iterator for SimdFilter<'a, T, F>
15where
16    T: std::cmp::PartialEq + Copy,
17    F: Fn(&T) -> bool,
18{
19    type Item = T;
20
21    fn next(&mut self) -> Option<Self::Item> {
22        match self.arr[self.position..].iter().position_simd(&self.f) {
23            Some(pos) => {
24                self.position += pos + 1;
25                Some(self.arr[self.position - 1])
26            }
27            None => None,
28        }
29    }
30}
31
32pub trait FilterSimd<'a, T>
33where
34    T: std::cmp::PartialEq + Copy,
35{
36    fn filter_simd<F>(&self, f: F) -> SimdFilter<'a, T, F>
37    where
38        F: Fn(&T) -> bool + 'a;
39}
40
41impl<'a, T> FilterSimd<'a, T> for slice::Iter<'a, T>
42where
43    T: std::cmp::PartialEq + Copy,
44{
45    /// This is the least optimal of all functions.
46    /// current implementation relies on sparsity of elems.
47    ///
48    ///
49    /// This kind of pattern is fast:
50    /// ```[0,0,0,0,0,0,0,0,0,0,1,1,0,1,1,0,0,0,0,0,0]```
51    ///
52    /// This kind of pattern is slow (similar to scalar speed):
53    /// ```[1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1]```
54    ///
55    /// The speed comes from checking if a chunks contains any wanted element.
56    ///
57    ///
58    ///  ```(0..10000).collect_vec().iter().filter_simd(|x| *x % 100 == 0).collect::<Vec<i32>>()```
59    /// is ~4x faster on x86 with avx2
60    ///
61    ///  ```(0..10000).collect_vec().iter().filter_simd(|x| *x % 10 == 0).collect::<Vec<i32>>()```
62    /// is ~2x faster on x86 with avx2
63    ///
64    ///```(0..10000).collect_vec().iter().filter_simd(|x| *x % 1 == 0).collect::<Vec<i32>>()```
65    /// is 30% slower than scalar on x86 with avx2
66    ///
67    /// Something like this works well on all patterns on x86:
68    fn filter_simd<F>(&self, f: F) -> SimdFilter<'a, T, F>
69    where
70        F: Fn(&T) -> bool + 'a,
71    {
72        SimdFilter {
73            position: 0,
74            f,
75            arr: self.as_slice(),
76        }
77    }
78}
79
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::distributions::Standard;
    use rand::prelude::Distribution;
    use rand::Rng;
    use std::fmt::Debug;

    use crate::FilterSimd;

    /// Checks `filter_simd` against the scalar `Iterator::filter` for every
    /// slice length in `0..5000` and a fixed set of predicates.
    ///
    /// Inputs mix random values with `T::default()` at varying densities.
    /// With purely random wide integers or floats, the equality-based
    /// predicates would almost never match, so only the "no match" path
    /// would be exercised.
    fn test_simd_for_type<T>()
    where
        T: Copy + Default + Debug + PartialOrd,
        Standard: Distribution<T>,
    {
        // Plain function pointers, so a single array drives both the SIMD
        // side and the scalar reference side of the comparison.
        let ops: [fn(&T) -> bool; 5] = [
            |x: &T| *x == T::default(),
            |x: &T| *x != T::default(),
            |x: &T| *x < T::default(),
            |x: &T| *x > T::default(),
            |x: &T| [T::default()].contains(x),
        ];

        let mut rng = rand::thread_rng();
        for len in 0..5000 {
            // Cycle through no / sparse / dense placements of `T::default()`,
            // covering both the chunk-skipping and the match-heavy paths.
            let default_ratio = [0.0, 0.05, 0.5][len % 3];
            let v: Vec<T> = (0..len)
                .map(|_| {
                    if rng.gen_bool(default_ratio) {
                        T::default()
                    } else {
                        rng.gen()
                    }
                })
                .collect();

            for &op in ops.iter() {
                let ans = v.iter().filter_simd(op).collect_vec();
                let correct = v.iter().copied().filter(|x| op(x)).collect_vec();
                assert_eq!(
                    ans,
                    correct,
                    "Failed for length {} and type {:?}",
                    len,
                    std::any::type_name::<T>()
                );
            }
        }
    }

    #[test]
    fn test_simd() {
        test_simd_for_type::<i8>();
        test_simd_for_type::<i16>();
        test_simd_for_type::<i32>();
        test_simd_for_type::<i64>();
        test_simd_for_type::<u8>();
        test_simd_for_type::<u16>();
        test_simd_for_type::<u32>();
        test_simd_for_type::<u64>();
        test_simd_for_type::<usize>();
        test_simd_for_type::<isize>();
        test_simd_for_type::<f32>();
        test_simd_for_type::<f64>();
    }
}