simd_itertools/
argmin.rs

1use crate::position::PositionSimd;
2use std::slice;
3
/// Index-of-minimum queries over a sequence of `T`.
///
/// Both methods report a position as `Option<usize>`; the implementation in
/// this file yields `None` when there are no elements to inspect.
pub trait ArgminSimd<'a, T>
where
    T: PartialEq,
{
    /// Position of a minimal element, or `None` if there is none.
    fn argmin_simd(&self) -> Option<usize>;

    /// Same contract as [`ArgminSimd::argmin_simd`]; implementors may pick the
    /// minimum with a different strategy.
    fn argmin_simd_fast(&self) -> Option<usize>;
}
11
12impl<'a, T> ArgminSimd<'a, T> for slice::Iter<'a, T>
13where
14    T: std::cmp::PartialEq + std::cmp::PartialOrd + Copy + std::cmp::Ord,
15{
16    fn argmin_simd(&self) -> Option<usize> {
17        match self.as_slice().iter().copied().min() {
18            Some(min) => self.position_simd(|x| *x == min),
19            None => None,
20        }
21    }
22    fn argmin_simd_fast(&self) -> Option<usize> {
23        match self
24            .as_slice()
25            .iter()
26            .reduce(|a, b| if a < b { a } else { b })
27        {
28            Some(min) => self.position_simd(|x| *x == *min),
29            None => None,
30        }
31    }
32}
33
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distributions::Standard;
    use rand::prelude::Distribution;
    use rand::Rng;
    use std::fmt::Debug;

    /// Randomized check of both argmin variants for element type `T`.
    ///
    /// For every length in `0..1000`, five random vectors are generated and
    /// each variant is compared against a scalar reference: the index of the
    /// first element equal to the minimum (`None` for an empty vector).
    fn test_simd_for_type<T>()
    where
        T: rand::distributions::uniform::SampleUniform + Debug + Copy + Default + Ord,
        Standard: Distribution<T>,
    {
        let mut rng = rand::thread_rng();
        for len in 0..1000 {
            for _ in 0..5 {
                let mut v: Vec<T> = vec![T::default(); len];
                for x in v.iter_mut() {
                    *x = rng.gen();
                }

                // normal
                let ans = v.iter().argmin_simd();
                // Compute the reference minimum once instead of once per
                // element visited by `position`.
                let correct = v
                    .iter()
                    .copied()
                    .min()
                    .and_then(|min| v.iter().position(|x| *x == min));
                assert_eq!(
                    ans,
                    correct,
                    "Failed for length {} and type {:?} {:?}",
                    len,
                    std::any::type_name::<T>(),
                    v
                );

                // fast
                let ans = v.iter().argmin_simd_fast();
                let correct = v
                    .iter()
                    .copied()
                    .reduce(|a, b| if a < b { a } else { b })
                    .and_then(|min| v.iter().position(|x| *x == min));
                assert_eq!(
                    ans,
                    correct,
                    "Failed for length {} and type {:?} {:?}",
                    len,
                    std::any::type_name::<T>(),
                    v
                );
            }
        }
    }

    /// Runs the randomized check for every primitive integer width covered.
    #[test]
    fn test_simd_min() {
        test_simd_for_type::<i8>();
        test_simd_for_type::<i16>();
        test_simd_for_type::<i32>();
        test_simd_for_type::<i64>();
        test_simd_for_type::<u8>();
        test_simd_for_type::<u16>();
        test_simd_for_type::<u32>();
        test_simd_for_type::<u64>();
        test_simd_for_type::<usize>();
        test_simd_for_type::<isize>();
    }
}