simd_itertools/
contains.rs

1use crate::LANE_COUNT;
2use std::slice;
3
/// Membership test intended as a SIMD-friendly counterpart to `contains`.
///
/// The trait's own method signature does not use `'a`; it is presumably kept
/// so implementors can name a borrow lifetime (the `slice::Iter<'a, T>` impl
/// in this file does).
pub trait ContainsSimd<'a, T: PartialEq> {
    /// Returns `true` when at least one element of `self` compares equal
    /// (via [`PartialEq`]) to `elem`.
    fn contains_simd(&self, elem: &T) -> bool;
}
10impl<'a, T> ContainsSimd<'a, T> for slice::Iter<'a, T>
11where
12    T: std::cmp::PartialEq,
13{
14    fn contains_simd(&self, elem: &T) -> bool
15    where
16        T: PartialEq,
17    {
18        let mut chunks = self.as_slice().chunks_exact(LANE_COUNT);
19        for chunk in chunks.by_ref() {
20            if chunk.iter().fold(false, |acc, x| acc | (x == elem)) {
21                return true;
22            }
23        }
24        chunks.remainder().contains(elem)
25    }
26}
27
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use rand::distributions::Standard;
    use rand::prelude::Distribution;
    use rand::seq::SliceRandom;
    use rand::Rng;

    /// Cap on attempts to draw a needle that is absent from the haystack.
    /// For narrow types such as `u8`, a long random vector could in principle
    /// contain every representable value, and an unbounded rejection loop
    /// would then never terminate. If the cap is hit, the fallback needle may
    /// be present; the assertion still holds because it compares against the
    /// reference implementation rather than assuming absence.
    const MAX_ABSENT_DRAWS: usize = 1_000;

    /// Randomised differential test: for lengths `0..500`, compares
    /// `contains_simd` against `Itertools::contains` using needles that are
    /// present about half the time and (usually) absent otherwise.
    fn test_simd_for_type<T>()
    where
        T: PartialEq + Copy + Default,
        Standard: Distribution<T>,
    {
        let mut rng = rand::thread_rng();
        for len in 0..500 {
            for _ in 0..5 {
                let v: Vec<T> = (0..len).map(|_| rng.gen()).collect();
                let needle: T = if rng.gen_bool(0.5) {
                    // Empty `v` has nothing to choose; fall back to the default.
                    v.choose(&mut rng).copied().unwrap_or_default()
                } else {
                    (0..MAX_ABSENT_DRAWS)
                        .map(|_| rng.gen::<T>())
                        .find(|candidate| !v.contains(candidate))
                        .unwrap_or_default()
                };
                let ans = v.iter().contains_simd(&needle);
                let correct = v.iter().contains(&needle);
                assert_eq!(
                    ans,
                    correct,
                    "Failed for length {} and type {:?}",
                    len,
                    std::any::type_name::<T>()
                );
            }
        }
    }

    #[test]
    fn test_simd_contains() {
        test_simd_for_type::<i8>();
        test_simd_for_type::<i16>();
        test_simd_for_type::<i32>();
        test_simd_for_type::<i64>();
        test_simd_for_type::<u8>();
        test_simd_for_type::<u16>();
        test_simd_for_type::<u32>();
        test_simd_for_type::<u64>();
        test_simd_for_type::<usize>();
        test_simd_for_type::<isize>();
        test_simd_for_type::<f32>();
        test_simd_for_type::<f64>();
    }

    /// Deterministically places a single match at every index, for lengths
    /// spanning several full `LANE_COUNT` chunks plus a remainder, so both
    /// the chunked path and the remainder path are exercised.
    #[test]
    fn test_needle_at_every_position() {
        for len in 0..(3 * crate::LANE_COUNT + 3) {
            let zeros = vec![0u32; len];
            assert!(
                !zeros.iter().contains_simd(&1),
                "false positive for length {}",
                len
            );
            for pos in 0..len {
                let v: Vec<u32> = (0..len).map(|i| u32::from(i == pos)).collect();
                assert!(
                    v.iter().contains_simd(&1),
                    "missed needle at index {} for length {}",
                    pos,
                    len
                );
            }
        }
    }
}