tract_linalg/generic/reduce.rs

1// Reduce<max> generic implementation
2pub mod max {
3    pub use tract_data::internal::f16;
4
5    reduce_impl_wrap!(
6        f32,
7        SMax4,
8        4,
9        4,
10        (),
11        f32::MIN,
12        fn run(x: &[f32], _: ()) -> f32 {
13            debug_assert!(x.len() % Self::nr() == 0);
14            debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
15            *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap()
16        },
17        fn reduce_two(a: f32, b: f32) -> f32 {
18            a.max(b)
19        }
20    );
21
22    reduce_impl_wrap!(
23        f16,
24        HMax8,
25        8,
26        8,
27        (),
28        f16::MIN,
29        fn run(x: &[f16], _: ()) -> f16 {
30            debug_assert!(x.len() % Self::nr() == 0);
31            debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
32            *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap()
33        },
34        fn reduce_two(a: f16, b: f16) -> f16 {
35            a.max(b)
36        }
37    );
38
39    #[cfg(test)]
40    #[macro_use]
41    pub mod s {
42        crate::max_frame_tests!(true, f32, crate::generic::reduce::max::SMax4);
43    }
44
45    #[cfg(test)]
46    #[macro_use]
47    pub mod h {
48        use super::*;
49        crate::max_frame_tests!(true, f16, crate::generic::reduce::max::HMax8);
50    }
51}
52
53// Reduce<sum> generic implementation
54pub mod sum {
55    use crate::num_traits::Zero;
56    pub use tract_data::internal::f16;
57
58    reduce_impl_wrap!(
59        f32,
60        SSum4,
61        4,
62        4,
63        (),
64        0.0,
65        fn run(x: &[f32], _: ()) -> f32 {
66            debug_assert!(x.len() % Self::nr() == 0);
67            debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
68            x.iter().sum::<f32>()
69        },
70        fn reduce_two(a: f32, b: f32) -> f32 {
71            a + b
72        }
73    );
74
75    reduce_impl_wrap!(
76        f16,
77        HSum8,
78        8,
79        8,
80        (),
81        f16::zero(),
82        fn run(x: &[f16], _: ()) -> f16 {
83            debug_assert!(x.len() % Self::nr() == 0);
84            debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
85            x.iter().sum::<f16>()
86        },
87        fn reduce_two(a: f16, b: f16) -> f16 {
88            a + b
89        }
90    );
91
92    #[cfg(test)]
93    #[macro_use]
94    pub mod s {
95        crate::sum_frame_tests!(true, f32, crate::generic::reduce::sum::SSum4);
96    }
97
98    #[cfg(test)]
99    #[macro_use]
100    pub mod h {
101        use super::*;
102        crate::sum_frame_tests!(true, f16, crate::generic::reduce::sum::HSum8);
103    }
104}
105
106// Softmax generic implementation
107pub mod softmax_l2 {
108    use crate::num_traits::Zero;
109    use tract_data::internal::f16;
110
111    map_reduce_impl_wrap!(
112        f32,
113        SSoftMaxL2,
114        4,
115        4,
116        f32,
117        f32::MIN,
118        0.0,
119        fn run(x: &mut [f32], max: f32) -> f32 {
120            debug_assert!(x.len() % Self::nr() == 0);
121            debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
122            let mut sum = 0.;
123            for v in x.iter_mut() {
124                let y = *v - max;
125                let y = fast_compact_exp_f32(y);
126                *v = y;
127                sum += y;
128            }
129            sum
130        },
131        fn reduce_two(a: f32, b: f32) -> f32 {
132            a + b
133        }
134    );
135
136    map_reduce_impl_wrap!(
137        f16,
138        HSoftMaxL2,
139        8,
140        8,
141        f16,
142        f16::MIN,
143        f16::zero(),
144        fn run(x: &mut [f16], max: f16) -> f16 {
145            debug_assert!(x.len() % Self::nr() == 0);
146            debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
147            let mut sum = f16::zero();
148            for v in x.iter_mut() {
149                let y = *v - max;
150                let y = f16::from_f32(fast_compact_exp_f32(y.to_f32()));
151                *v = y;
152                sum += y;
153            }
154            sum
155        },
156        fn reduce_two(a: f16, b: f16) -> f16 {
157            a + b
158        }
159    );
160
161    // ported from https://github.com/gnuradio/volk/blob/master/kernels/volk/volk_32f_expfast_32f.h
162    // probably inspired from https://nic.schraudolph.org/pubs/Schraudolph99.pdf
163    // not that the cast to u32 deals with negative right, while implem in volk code are wrong in some
164    // corner cases (need a max(0,x) before the u32 conversion)
165    pub fn fast_compact_exp_f32(v: f32) -> f32 {
166        const MLN2: f32 = 0.6931471805f32;
167        const A: f32 = 8388608.0f32;
168        const B: f32 = 1065353216.0f32;
169        const C: f32 = 60801.0f32;
170        const SLOPE: f32 = A / MLN2;
171        const OFFSET: f32 = B - C;
172        f32::from_bits(((SLOPE * v) + OFFSET) as u32)
173    }
174
175    #[cfg(test)]
176    #[macro_use]
177    pub mod s {
178        crate::softmax_l2_frame_tests!(true, f32, super::SSoftMaxL2);
179    }
180
181    #[cfg(test)]
182    #[macro_use]
183    pub mod h {
184        use super::*;
185        crate::softmax_l2_frame_tests!(true, f16, HSoftMaxL2);
186    }
187}