1use crate::PositionSimd;
2use std::slice;
3
/// Iterator adapter returned by [`FilterSimd::filter_simd`].
///
/// Yields, by value and in order, every element of `arr` at or after index
/// `position` for which the predicate `f` returns `true`.
///
/// The type itself carries no trait bounds; the requirements on `T` and `F` live
/// on the `Iterator` impl, which is the only place they are needed.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct SimdFilter<'a, T, F> {
    /// Index into `arr` of the next element to examine.
    pub position: usize,
    /// Predicate selecting which elements are yielded.
    pub f: F,
    /// The slice being filtered.
    pub arr: &'a [T],
}
13
14impl<'a, T, F> Iterator for SimdFilter<'a, T, F>
15where
16 T: std::cmp::PartialEq + Copy,
17 F: Fn(&T) -> bool,
18{
19 type Item = T;
20
21 fn next(&mut self) -> Option<Self::Item> {
22 match self.arr[self.position..].iter().position_simd(&self.f) {
23 Some(pos) => {
24 self.position += pos + 1;
25 Some(self.arr[self.position - 1])
26 }
27 None => None,
28 }
29 }
30}
31
32pub trait FilterSimd<'a, T>
33where
34 T: std::cmp::PartialEq + Copy,
35{
36 fn filter_simd<F>(&self, f: F) -> SimdFilter<'a, T, F>
37 where
38 F: Fn(&T) -> bool + 'a;
39}
40
41impl<'a, T> FilterSimd<'a, T> for slice::Iter<'a, T>
42where
43 T: std::cmp::PartialEq + Copy,
44{
45 fn filter_simd<F>(&self, f: F) -> SimdFilter<'a, T, F>
69 where
70 F: Fn(&T) -> bool + 'a,
71 {
72 SimdFilter {
73 position: 0,
74 f,
75 arr: self.as_slice(),
76 }
77 }
78}
79
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::distributions::Standard;
    use rand::prelude::Distribution;
    use rand::Rng;
    use std::fmt::Debug;

    use crate::FilterSimd;

    /// Builds a random `Vec` of `len` elements in which roughly one element in
    /// four is forced to `T::default()`.
    ///
    /// Uniformly random wide integers (and `Standard` floats, drawn from
    /// `[0, 1)`) almost never equal the default. Without the injected defaults,
    /// the `==` and `contains` predicates would essentially only ever see inputs
    /// with no matches.
    fn random_vec<T, R>(rng: &mut R, len: usize) -> Vec<T>
    where
        T: Default,
        R: Rng,
        Standard: Distribution<T>,
    {
        (0..len)
            .map(|_| {
                if rng.gen_ratio(1, 4) {
                    T::default()
                } else {
                    rng.gen()
                }
            })
            .collect()
    }

    /// Checks `filter_simd` against the scalar `Iterator::filter` for every
    /// length in `0..5000` and a fixed set of comparisons against `T::default()`.
    fn test_simd_for_type<T>()
    where
        T: Copy + Default + Debug + PartialEq + PartialOrd,
        Standard: Distribution<T>,
    {
        // Non-capturing closures coerce to plain function pointers, so the SIMD
        // and scalar paths below share the exact same predicate.
        let predicates: [fn(&T) -> bool; 5] = [
            |x: &T| *x == T::default(),
            |x: &T| *x != T::default(),
            |x: &T| *x < T::default(),
            |x: &T| *x > T::default(),
            |x: &T| [T::default()].contains(x),
        ];
        let mut rng = rand::thread_rng();

        for len in 0..5000 {
            let v: Vec<T> = random_vec(&mut rng, len);
            for (i, &pred) in predicates.iter().enumerate() {
                let ans = v.iter().filter_simd(pred).collect_vec();
                let correct = v.iter().filter(|&x| pred(x)).copied().collect_vec();
                assert_eq!(
                    ans,
                    correct,
                    "Failed for length {} and type {:?} (predicate #{})",
                    len,
                    std::any::type_name::<T>(),
                    i
                );
            }
        }
    }

    #[test]
    fn test_simd() {
        test_simd_for_type::<i8>();
        test_simd_for_type::<i16>();
        test_simd_for_type::<i32>();
        test_simd_for_type::<i64>();
        test_simd_for_type::<u8>();
        test_simd_for_type::<u16>();
        test_simd_for_type::<u32>();
        test_simd_for_type::<u64>();
        test_simd_for_type::<usize>();
        test_simd_for_type::<isize>();
        test_simd_for_type::<f32>();
        test_simd_for_type::<f64>();
    }

    /// An exhausted adapter must keep returning `None` rather than rescanning
    /// into stale data or panicking.
    #[test]
    fn test_exhausted_iterator_stays_exhausted() {
        let v = [1i32, 0, 2, 0];
        let mut it = v.iter().filter_simd(|x: &i32| *x == 0);
        assert_eq!(it.next(), Some(0));
        assert_eq!(it.next(), Some(0));
        assert_eq!(it.next(), None);
        assert_eq!(it.next(), None);
    }
}