tract_linalg/generic/
reduce.rs1pub mod max {
3 pub use tract_data::internal::f16;
4
5 reduce_impl_wrap!(
6 f32,
7 SMax4,
8 4,
9 4,
10 (),
11 f32::MIN,
12 fn run(x: &[f32], _: ()) -> f32 {
13 debug_assert!(x.len() % Self::nr() == 0);
14 debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
15 *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap()
16 },
17 fn reduce_two(a: f32, b: f32) -> f32 {
18 a.max(b)
19 }
20 );
21
22 reduce_impl_wrap!(
23 f16,
24 HMax8,
25 8,
26 8,
27 (),
28 f16::MIN,
29 fn run(x: &[f16], _: ()) -> f16 {
30 debug_assert!(x.len() % Self::nr() == 0);
31 debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
32 *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap()
33 },
34 fn reduce_two(a: f16, b: f16) -> f16 {
35 a.max(b)
36 }
37 );
38
39 #[cfg(test)]
40 #[macro_use]
41 pub mod s {
42 crate::max_frame_tests!(true, f32, crate::generic::reduce::max::SMax4);
43 }
44
45 #[cfg(test)]
46 #[macro_use]
47 pub mod h {
48 use super::*;
49 crate::max_frame_tests!(true, f16, crate::generic::reduce::max::HMax8);
50 }
51}
52
53pub mod sum {
55 use crate::num_traits::Zero;
56 pub use tract_data::internal::f16;
57
58 reduce_impl_wrap!(
59 f32,
60 SSum4,
61 4,
62 4,
63 (),
64 0.0,
65 fn run(x: &[f32], _: ()) -> f32 {
66 debug_assert!(x.len() % Self::nr() == 0);
67 debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
68 x.iter().sum::<f32>()
69 },
70 fn reduce_two(a: f32, b: f32) -> f32 {
71 a + b
72 }
73 );
74
75 reduce_impl_wrap!(
76 f16,
77 HSum8,
78 8,
79 8,
80 (),
81 f16::zero(),
82 fn run(x: &[f16], _: ()) -> f16 {
83 debug_assert!(x.len() % Self::nr() == 0);
84 debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
85 x.iter().sum::<f16>()
86 },
87 fn reduce_two(a: f16, b: f16) -> f16 {
88 a + b
89 }
90 );
91
92 #[cfg(test)]
93 #[macro_use]
94 pub mod s {
95 crate::sum_frame_tests!(true, f32, crate::generic::reduce::sum::SSum4);
96 }
97
98 #[cfg(test)]
99 #[macro_use]
100 pub mod h {
101 use super::*;
102 crate::sum_frame_tests!(true, f16, crate::generic::reduce::sum::HSum8);
103 }
104}
105
106pub mod softmax_l2 {
108 use crate::num_traits::Zero;
109 use tract_data::internal::f16;
110
111 map_reduce_impl_wrap!(
112 f32,
113 SSoftMaxL2,
114 4,
115 4,
116 f32,
117 f32::MIN,
118 0.0,
119 fn run(x: &mut [f32], max: f32) -> f32 {
120 debug_assert!(x.len() % Self::nr() == 0);
121 debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
122 let mut sum = 0.;
123 for v in x.iter_mut() {
124 let y = *v - max;
125 let y = fast_compact_exp_f32(y);
126 *v = y;
127 sum += y;
128 }
129 sum
130 },
131 fn reduce_two(a: f32, b: f32) -> f32 {
132 a + b
133 }
134 );
135
136 map_reduce_impl_wrap!(
137 f16,
138 HSoftMaxL2,
139 8,
140 8,
141 f16,
142 f16::MIN,
143 f16::zero(),
144 fn run(x: &mut [f16], max: f16) -> f16 {
145 debug_assert!(x.len() % Self::nr() == 0);
146 debug_assert!(x.as_ptr() as usize % Self::alignment_bytes() == 0);
147 let mut sum = f16::zero();
148 for v in x.iter_mut() {
149 let y = *v - max;
150 let y = f16::from_f32(fast_compact_exp_f32(y.to_f32()));
151 *v = y;
152 sum += y;
153 }
154 sum
155 },
156 fn reduce_two(a: f16, b: f16) -> f16 {
157 a + b
158 }
159 );
160
161 pub fn fast_compact_exp_f32(v: f32) -> f32 {
166 const MLN2: f32 = 0.6931471805f32;
167 const A: f32 = 8388608.0f32;
168 const B: f32 = 1065353216.0f32;
169 const C: f32 = 60801.0f32;
170 const SLOPE: f32 = A / MLN2;
171 const OFFSET: f32 = B - C;
172 f32::from_bits(((SLOPE * v) + OFFSET) as u32)
173 }
174
175 #[cfg(test)]
176 #[macro_use]
177 pub mod s {
178 crate::softmax_l2_frame_tests!(true, f32, super::SSoftMaxL2);
179 }
180
181 #[cfg(test)]
182 #[macro_use]
183 pub mod h {
184 use super::*;
185 crate::softmax_l2_frame_tests!(true, f16, HSoftMaxL2);
186 }
187}