simd_itertools/
contains.rs1use crate::LANE_COUNT;
2use std::slice;
3
/// Membership query over a sequence of `T` values.
///
/// Implementors answer whether a given element is present, comparing with
/// `PartialEq`.
pub trait ContainsSimd<'a, T: PartialEq> {
    /// Returns `true` if `elem` compares equal to at least one element of
    /// `self`, and `false` otherwise.
    fn contains_simd(&self, elem: &T) -> bool;
}
10impl<'a, T> ContainsSimd<'a, T> for slice::Iter<'a, T>
11where
12 T: std::cmp::PartialEq,
13{
14 fn contains_simd(&self, elem: &T) -> bool
15 where
16 T: PartialEq,
17 {
18 let mut chunks = self.as_slice().chunks_exact(LANE_COUNT);
19 for chunk in chunks.by_ref() {
20 if chunk.iter().fold(false, |acc, x| acc | (x == elem)) {
21 return true;
22 }
23 }
24 chunks.remainder().contains(elem)
25 }
26}
27
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use rand::distributions::Standard;
    use rand::prelude::Distribution;
    use rand::seq::SliceRandom;
    use rand::Rng;

    /// Randomized comparison of `contains_simd` against `Itertools::contains`
    /// for element type `T`.
    ///
    /// For every length in `0..500`, five rounds are run. Each round fills a
    /// vector with random values and then picks a needle that is, with equal
    /// probability, either taken from the vector (falling back to
    /// `T::default()` when the vector is empty) or a freshly generated value
    /// that is not in the vector.
    fn test_simd_for_type<T>()
    where
        T: rand::distributions::uniform::SampleUniform + PartialEq + Copy + Default,
        Standard: Distribution<T>,
    {
        let mut rng = rand::thread_rng();

        for len in 0..500 {
            for _ in 0..5 {
                let haystack: Vec<T> = (0..len).map(|_| rng.gen()).collect();

                let needle = if rng.gen_bool(0.5) {
                    // Present case (or the default value for an empty haystack).
                    haystack.choose(&mut rng).copied().unwrap_or_default()
                } else {
                    // Absent case: keep drawing until the value is not in the haystack.
                    loop {
                        let candidate = rng.gen();
                        if !haystack.contains(&candidate) {
                            break candidate;
                        }
                    }
                };

                let ans = haystack.iter().contains_simd(&needle);
                let correct = haystack.iter().contains(&needle);
                assert_eq!(
                    ans,
                    correct,
                    "Failed for length {} and type {:?}",
                    len,
                    std::any::type_name::<T>()
                );
            }
        }
    }

    /// Runs the randomized check for each primitive integer and float type.
    #[test]
    fn test_simd_contains() {
        // Signed integers.
        test_simd_for_type::<i8>();
        test_simd_for_type::<i16>();
        test_simd_for_type::<i32>();
        test_simd_for_type::<i64>();
        // Unsigned integers.
        test_simd_for_type::<u8>();
        test_simd_for_type::<u16>();
        test_simd_for_type::<u32>();
        test_simd_for_type::<u64>();
        // Pointer-sized integers.
        test_simd_for_type::<usize>();
        test_simd_for_type::<isize>();
        // Floating point.
        test_simd_for_type::<f32>();
        test_simd_for_type::<f64>();
    }
}