simd_slice/
lib.rs

1#![feature(portable_simd)]
2
3use std::simd::{Simd, SimdElement};
4
5pub trait AsSimdSlice<T: SimdElement> {
6    fn as_simd_slice<'a>(&'a self) -> SimdSlice<'a, T>;
7}
8
9impl<T: SimdElement> AsSimdSlice<T> for [T] {
10    fn as_simd_slice<'a>(&'a self) -> SimdSlice<'a, T> {
11        SimdSlice(self)
12    }
13}
14
15pub struct SimdSlice<'a, T: SimdElement>(&'a [T]);
16
/// Implements `sum`, `min` and `max` on `SimdSlice<'_, $t>` for one primitive
/// element type.
///
/// Each reduction walks the largest prefix whose length is a multiple of 4
/// using 4-lane vectors, then folds the remaining (at most 3) elements with
/// scalar code.
macro_rules! impl_op {
    ($t: ty) => {
        impl<'a> SimdSlice<'a, $t> {
            /// Returns the sum of all elements (zero for an empty slice).
            pub fn sum(&self) -> $t {
                let n = self.0.len() & (!3); // round down to a multiple of 4
                let slice = &self.0[..n];

                let mut i = 0;
                let mut agg = Simd::<$t, 4>::splat(num::zero());

                // A manual index loop is used on purpose: the original author
                // noted that `slice.chunks(4).for_each(..)` was much slower.
                while i < n {
                    // SAFETY: `slice.len() == n`, both `n` and `i` are
                    // multiples of 4 and `i < n`, hence `i + 4 <= n`.
                    agg += Simd::<$t, 4>::from_slice(unsafe { slice.get_unchecked(i..i + 4) });
                    i += 4;
                }

                // Collapse the four lanes, then add the scalar tail.
                let mut agg = agg.horizontal_sum();
                self.0[n..].iter().for_each(|x| agg += x);
                agg
            }

            /// Returns the smallest element, or `None` if the slice is empty.
            pub fn min(&self) -> Option<$t> {
                let n = self.0.len() & (!3); // round down to a multiple of 4
                let (head, tail) = self.0.split_at(n);

                // Seed the running minimum from the first full chunk when
                // there is one, otherwise from the first element.
                let (mut agg, rest) = if n > 0 {
                    let mut agg = Simd::<$t, 4>::from_slice(&head[..4]).horizontal_min();
                    let mut i = 4;
                    // NOTE(review): every chunk is reduced to a scalar here; a
                    // lane-wise minimum reduced once at the end would likely be
                    // faster if the targeted nightly exposes one.
                    while i < n {
                        // SAFETY: `head.len() == n`, both `n` and `i` are
                        // multiples of 4 and `i < n`, hence `i + 4 <= n`.
                        let lanes =
                            Simd::<$t, 4>::from_slice(unsafe { head.get_unchecked(i..i + 4) });
                        let chunk_min = lanes.horizontal_min();
                        if chunk_min < agg {
                            agg = chunk_min;
                        }
                        i += 4;
                    }
                    (agg, tail)
                } else {
                    // Fewer than 4 elements: `tail` is the whole slice.
                    let (&first, rest) = tail.split_first()?;
                    (first, rest)
                };

                // Scalar pass over whatever the vector loop did not cover.
                for &x in rest {
                    if x < agg {
                        agg = x;
                    }
                }
                Some(agg)
            }

            /// Returns the largest element, or `None` if the slice is empty.
            pub fn max(&self) -> Option<$t> {
                let n = self.0.len() & (!3); // round down to a multiple of 4
                let (head, tail) = self.0.split_at(n);

                // Seed the running maximum from the first full chunk when
                // there is one, otherwise from the first element.
                let (mut agg, rest) = if n > 0 {
                    let mut agg = Simd::<$t, 4>::from_slice(&head[..4]).horizontal_max();
                    let mut i = 4;
                    // NOTE(review): every chunk is reduced to a scalar here; a
                    // lane-wise maximum reduced once at the end would likely be
                    // faster if the targeted nightly exposes one.
                    while i < n {
                        // SAFETY: `head.len() == n`, both `n` and `i` are
                        // multiples of 4 and `i < n`, hence `i + 4 <= n`.
                        let lanes =
                            Simd::<$t, 4>::from_slice(unsafe { head.get_unchecked(i..i + 4) });
                        let chunk_max = lanes.horizontal_max();
                        if chunk_max > agg {
                            agg = chunk_max;
                        }
                        i += 4;
                    }
                    (agg, tail)
                } else {
                    // Fewer than 4 elements: `tail` is the whole slice.
                    let (&first, rest) = tail.split_first()?;
                    (first, rest)
                };

                // Scalar pass over whatever the vector loop did not cover.
                for &x in rest {
                    if x > agg {
                        agg = x;
                    }
                }
                Some(agg)
            }
        }
    };
}
105impl_op!(u8);
106impl_op!(u16);
107impl_op!(u32);
108impl_op!(u64);
109impl_op!(usize);
110impl_op!(i8);
111impl_op!(i16);
112impl_op!(i32);
113impl_op!(i64);
114impl_op!(isize);
115impl_op!(f32);
116impl_op!(f64);
117
#[cfg(test)]
mod tests {
    use super::AsSimdSlice;

    /// The SIMD reductions must agree with the scalar iterator reductions
    /// on a four-element window of `i32` values.
    #[test]
    fn it_works() {
        let data = [10, 20, 3, 4, 5, 6, 7_i32];
        let window = &data[1..5];

        let expected_sum = window.iter().sum::<i32>();
        let expected_max = window.iter().max().cloned();
        let expected_min = window.iter().min().cloned();

        assert_eq!(expected_sum, window.as_simd_slice().sum());
        assert_eq!(expected_max, window.as_simd_slice().max());
        assert_eq!(expected_min, window.as_simd_slice().min());
    }
}