1use crate::PositionSimd;
2use std::slice;
3
/// Index-of-maximum queries ("argmax") over a collection of `T`.
///
/// Both methods return an index wrapped in `Option`, so an implementor can
/// report "no maximum" (for example on empty input) as `None`. The trait
/// itself only requires that elements can be compared for equality; an
/// implementation may demand stronger bounds.
pub trait ArgmaxSimd<'a, T: PartialEq> {
    /// Returns the index of a maximum element, or `None` if there is none.
    fn argmax_simd(&self) -> Option<usize>;

    /// Alternative argmax entry point with the same signature as
    /// [`ArgmaxSimd::argmax_simd`].
    // NOTE(review): presumably intended as a faster variant, as the name
    // suggests; the trait itself does not state how the two differ.
    fn argmax_simd_fast(&self) -> Option<usize>;
}
11
12impl<'a, T> ArgmaxSimd<'a, T> for slice::Iter<'a, T>
13where
14 T: std::cmp::PartialEq + std::cmp::PartialOrd + Copy + std::cmp::Ord,
15{
16 fn argmax_simd(&self) -> Option<usize> {
17 match self.as_slice().iter().copied().max() {
18 Some(max) => self.position_simd(|x| *x == max),
19 None => None,
20 }
21 }
22 fn argmax_simd_fast(&self) -> Option<usize> {
23 match self
24 .as_slice()
25 .iter()
26 .reduce(|a, b| if a > b { a } else { b })
27 {
28 Some(max) => self.position_simd(|x| *x == *max),
29 None => None,
30 }
31 }
32}
33
#[cfg(test)]
mod tests {
    use super::*;
    use rand::distributions::Standard;
    use rand::prelude::Distribution;
    use rand::Rng;
    use std::fmt::Debug;

    /// Checks both argmax methods against a scalar reference on random
    /// vectors of `T`, for every length in `0..1000` (five vectors each).
    fn test_simd_for_type<T>()
    where
        T: Debug + Copy + Ord,
        Standard: Distribution<T>,
    {
        // One RNG handle serves every generated vector.
        let mut rng = rand::thread_rng();
        for len in 0..1000 {
            for _ in 0..5 {
                let v: Vec<T> = (0..len).map(|_| rng.gen()).collect();

                // The maximum is computed once, before the scan. Evaluating it
                // inside the `position` closure would redo an O(n) pass for
                // every element visited. `and_then` keeps the empty vector at
                // `None` without unwrapping.
                let ans = v.iter().argmax_simd();
                let correct = v
                    .iter()
                    .copied()
                    .max()
                    .and_then(|max| v.iter().position(|x| *x == max));
                assert_eq!(
                    ans,
                    correct,
                    "Failed for length {} and type {:?} {:?}",
                    len,
                    std::any::type_name::<T>(),
                    v
                );

                // Same check for the `reduce`-based variant, with the
                // reference maximum obtained the same way.
                let ans = v.iter().argmax_simd_fast();
                let correct = v
                    .iter()
                    .copied()
                    .reduce(|a, b| if a > b { a } else { b })
                    .and_then(|max| v.iter().position(|x| *x == max));
                assert_eq!(
                    ans,
                    correct,
                    "Failed for length {} and type {:?} {:?}",
                    len,
                    std::any::type_name::<T>(),
                    v
                );
            }
        }
    }

    /// Runs the randomized comparison for each primitive integer type listed.
    #[test]
    fn test_simd_max() {
        test_simd_for_type::<i8>();
        test_simd_for_type::<i16>();
        test_simd_for_type::<i32>();
        test_simd_for_type::<i64>();
        test_simd_for_type::<u8>();
        test_simd_for_type::<u16>();
        test_simd_for_type::<u32>();
        test_simd_for_type::<u64>();
        test_simd_for_type::<usize>();
        test_simd_for_type::<isize>();
    }
}