1pub mod max;
2pub mod softmax;
3pub mod sum;
4
5use std::fmt::Debug;
6use std::marker::PhantomData;
7
8use tract_data::TractResult;
9
10use crate::LADatum;
11
12use super::element_wise_helper::{map_reduce_slice_with_alignment, reduce_slice_with_alignment};
13
// Generates a reduction kernel type.
//
// Expands to a zero-sized `Copy + Clone + Debug` struct named `$kernel` plus an
// implementation of `crate::frame::reduce::ReduceKer<$elem, $params>` for it:
//   - `name()` is the stringified struct name,
//   - `nr()` / `alignment_items()` / `neutral()` return the given expressions,
//   - `alignment_bytes()` is `$alignment_items` times `size_of::<$elem>()`,
//   - the caller-provided `$run` and `$reduce_two` method items are spliced in verbatim.
//
// NOTE(review): the expansion goes through `paste!`, so `paste` must be in scope at
// every invocation site.
macro_rules! reduce_impl_wrap {
    ($elem: ident, $kernel: ident, $nr: expr, $alignment_items: expr, $params: ty, $neutral: expr, $run: item, $reduce_two: item) => {
        paste! {
            #[derive(Copy, Clone, Debug)]
            #[allow(non_camel_case_types)]
            pub struct $kernel;

            impl crate::frame::reduce::ReduceKer<$elem, $params> for $kernel {
                #[inline(always)]
                fn name() -> &'static str {
                    stringify!($kernel)
                }
                #[inline(always)]
                fn neutral() -> $elem {
                    $neutral
                }
                #[inline(always)]
                fn nr() -> usize {
                    $nr
                }
                #[inline(always)]
                fn alignment_items() -> usize {
                    $alignment_items
                }
                #[inline(always)]
                fn alignment_bytes() -> usize {
                    // Same value as the trait default, but computed from the concrete type.
                    $alignment_items * std::mem::size_of::<$elem>()
                }
                $run
                $reduce_two
            }
        }
    };
}
48
49pub trait Reduce<T, Params = ()>: Send + Sync + Debug + dyn_clone::DynClone
50where
51 Params: Copy + Send + Sync + Debug + 'static + Default,
52 T: Copy + Debug + PartialEq + Send + Sync,
53{
54 fn name(&self) -> &'static str;
55 fn run(&self, vec: &[T]) -> TractResult<T> {
56 self.run_with_params(vec, Params::default())
57 }
58 fn run_with_params(&self, vec: &[T], params: Params) -> TractResult<T>;
59}
60
61dyn_clone::clone_trait_object!(<T, Params> Reduce<T, Params> where T: Copy, Params: Copy);
62
63#[derive(Debug, Clone, new)]
64pub struct ReduceImpl<K, T, Params = ()>
65where
66 T: LADatum,
67 Params: Copy + Send + Sync + Debug + 'static + Default,
68 K: ReduceKer<T, Params> + Clone,
69{
70 phantom: PhantomData<(K, T, Params)>,
71}
72
73impl<K, T, Params> Reduce<T, Params> for ReduceImpl<K, T, Params>
74where
75 T: LADatum,
76 Params: Copy + Send + Sync + Debug + 'static + Default,
77 K: ReduceKer<T, Params> + Clone,
78{
79 fn name(&self) -> &'static str {
80 K::name()
81 }
82
83 fn run_with_params(&self, vec: &[T], params: Params) -> TractResult<T> {
84 reduce_slice_with_alignment(
85 vec,
86 |data| K::run(data, params),
87 K::nr(),
88 K::alignment_bytes(),
89 K::neutral(),
90 K::reduce_two,
91 )
92 }
93}
94
95pub trait ReduceKer<T, Params = ()>:
96 Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static
97where
98 Params: Copy + Send + Sync + Debug + 'static + Default,
99 T: LADatum,
100{
101 fn name() -> &'static str;
102 fn alignment_bytes() -> usize {
103 Self::alignment_items() * T::datum_type().size_of()
104 }
105 fn alignment_items() -> usize;
106 fn nr() -> usize;
107 fn neutral() -> T;
108 fn reduce_two(a: T, b: T) -> T;
109 fn run(vec: &[T], params: Params) -> T;
110 fn red() -> Box<dyn Reduce<T, Params>> {
111 Box::new(ReduceImpl::<Self, T, Params>::new())
112 }
113}
114
// Generates a map-reduce kernel type.
//
// Expands to a zero-sized `Copy + Clone + Debug` struct named `$kernel` plus an
// implementation of `crate::frame::reduce::MapReduceKer<$elem, $params>` for it:
//   - `name()` is the stringified struct name,
//   - `nr()` / `alignment_items()` return the given expressions,
//   - `map_neutral()` / `reduce_neutral()` return `$map_neutral` / `$reduce_neutral`,
//   - `alignment_bytes()` is `$alignment_items` times `size_of::<$elem>()`,
//   - the caller-provided `$run` and `$reduce_two` method items are spliced in verbatim.
//
// NOTE(review): the expansion goes through `paste!`, so `paste` must be in scope at
// every invocation site.
#[allow(unused_macros)]
macro_rules! map_reduce_impl_wrap {
    ($elem: ident, $kernel: ident, $nr: expr, $alignment_items: expr, $params: ty, $map_neutral: expr, $reduce_neutral: expr, $run: item, $reduce_two: item) => {
        paste! {
            #[derive(Copy, Clone, Debug)]
            #[allow(non_camel_case_types)]
            pub struct $kernel;

            impl crate::frame::reduce::MapReduceKer<$elem, $params> for $kernel {
                #[inline(always)]
                fn name() -> &'static str {
                    stringify!($kernel)
                }
                #[inline(always)]
                fn map_neutral() -> $elem {
                    $map_neutral
                }
                #[inline(always)]
                fn reduce_neutral() -> $elem {
                    $reduce_neutral
                }
                #[inline(always)]
                fn nr() -> usize {
                    $nr
                }
                #[inline(always)]
                fn alignment_items() -> usize {
                    $alignment_items
                }
                #[inline(always)]
                fn alignment_bytes() -> usize {
                    // Same value as the trait default, but computed from the concrete type.
                    $alignment_items * std::mem::size_of::<$elem>()
                }
                $run
                $reduce_two
            }
        }
    };
}
154
155pub trait MapReduce<T, Params = ()>: Send + Sync + Debug + dyn_clone::DynClone
156where
157 Params: Copy + Send + Sync + Debug + 'static + Default,
158 T: Copy + Debug + PartialEq + Send + Sync,
159{
160 fn name(&self) -> &'static str;
161 fn run(&self, vec: &mut [T]) -> TractResult<T> {
162 self.run_with_params(vec, Params::default())
163 }
164 fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult<T>;
165}
166
167dyn_clone::clone_trait_object!(<T, Params> MapReduce<T, Params> where T: Copy, Params: Copy);
168
169#[derive(Debug, Clone, new)]
170pub struct MapReduceImpl<K, T, Params = ()>
171where
172 T: LADatum,
173 Params: Copy + Send + Sync + Debug + 'static + Default,
174 K: MapReduceKer<T, Params> + Clone,
175{
176 phantom: PhantomData<(K, T, Params)>,
177}
178
179impl<K, T, Params> MapReduce<T, Params> for MapReduceImpl<K, T, Params>
180where
181 T: LADatum,
182 Params: Copy + Send + Sync + Debug + 'static + Default,
183 K: MapReduceKer<T, Params> + Clone,
184{
185 fn name(&self) -> &'static str {
186 K::name()
187 }
188 fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult<T> {
189 map_reduce_slice_with_alignment(
190 vec,
191 |data| K::run(data, params),
192 K::nr(),
193 K::alignment_bytes(),
194 K::map_neutral(),
195 K::reduce_neutral(),
196 K::reduce_two,
197 )
198 }
199}
200
201pub trait MapReduceKer<T, Params = ()>:
202 Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static
203where
204 Params: Copy + Send + Sync + Debug + 'static + Default,
205 T: LADatum,
206{
207 fn name() -> &'static str;
208 fn alignment_bytes() -> usize {
209 Self::alignment_items() * T::datum_type().size_of()
210 }
211 fn alignment_items() -> usize;
212 fn nr() -> usize;
213 fn map_neutral() -> T;
214 fn reduce_neutral() -> T;
215 fn reduce_two(a: T, b: T) -> T;
216 fn run(vec: &mut [T], params: Params) -> T;
217 fn red() -> Box<dyn MapReduce<T, Params>> {
218 Box::new(MapReduceImpl::<Self, T, Params>::new())
219 }
220}
221
#[cfg(test)]
pub mod test {
    //! Reference-checking helpers shared by the reduce / map-reduce kernel tests.
    use super::*;
    use proptest::test_runner::{TestCaseError, TestCaseResult};
    use tract_data::internal::*;
    use tract_data::itertools::Itertools;

    /// Checks kernel `K` (default params) against a scalar left fold of `values`
    /// seeded with `neutral`.
    pub fn test_reduce<K: ReduceKer<T, ()>, T: LADatum>(
        values: &[T],
        neutral: T,
        reference_reduce: impl Fn(T, T) -> T,
    ) -> TestCaseResult {
        test_reduce_params::<K, T, ()>(values, neutral, reference_reduce, ())
    }

    /// Checks kernel `K` run with `params` against a scalar left fold of `values`
    /// seeded with `neutral` and combined with `reference_reducer`.
    ///
    /// Fails the proptest case if the results are not `close_enough`.
    pub fn test_reduce_params<K: ReduceKer<T, Params>, T: LADatum, Params>(
        values: &[T],
        neutral: T,
        reference_reducer: impl Fn(T, T) -> T,
        params: Params,
    ) -> TestCaseResult
    where
        Params: Copy + Send + Sync + Debug + 'static + Default,
    {
        crate::setup_test_logger();
        let op = K::red();
        let expected = values.iter().fold(neutral, |acc, i| reference_reducer(acc, *i));
        let red = op.run_with_params(values, params).unwrap();
        tensor0(red)
            .close_enough(&tensor0(expected), true)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        Ok(())
    }

    /// Checks map-reduce kernel `K` (default params) against a scalar reference.
    ///
    /// Arguments follow the `MapReduceKer` order: `map_neutral` then `reduce_neutral`.
    pub fn test_map_reduce<K: MapReduceKer<T, ()>, T: LADatum>(
        values: &[T],
        map_neutral: T,
        reduce_neutral: T,
        reference_map: impl Fn(T) -> T,
        reference_reduce: impl Fn(T, T) -> T,
    ) -> TestCaseResult {
        test_map_reduce_params::<K, T, ()>(
            values,
            map_neutral,
            reduce_neutral,
            reference_map,
            reference_reduce,
            (),
        )
    }

    /// Checks map-reduce kernel `K` run with `params` against a scalar reference.
    ///
    /// Verifies both that the slice is mapped in place (element-wise `reference_map`)
    /// and that the returned value equals a left fold of the mapped values seeded with
    /// `reduce_neutral` and combined with `reference_reducer`.
    ///
    /// `_map_neutral` is accepted for call-site symmetry with [`MapReduceKer`] but is
    /// not used by the reference computation.
    pub fn test_map_reduce_params<K: MapReduceKer<T, Params>, T: LADatum, Params>(
        values: &[T],
        _map_neutral: T,
        reduce_neutral: T,
        reference_map: impl Fn(T) -> T,
        reference_reducer: impl Fn(T, T) -> T,
        params: Params,
    ) -> TestCaseResult
    where
        Params: Copy + Send + Sync + Debug + 'static + Default,
    {
        crate::setup_test_logger();
        let op = K::red();
        let mut found = values.to_vec();
        let expected_values = values.iter().copied().map(reference_map).collect_vec();
        // The reduction of the mapped values starts from the *reduce* neutral element.
        let expected_reduced =
            expected_values.iter().fold(reduce_neutral, |acc, i| reference_reducer(acc, *i));
        let red = op.run_with_params(&mut found, params).unwrap();
        tensor1(&found)
            .close_enough(&tensor1(&expected_values), Approximation::SuperApproximate)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        tensor0(red)
            .close_enough(&tensor0(expected_reduced), Approximation::SuperApproximate)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        Ok(())
    }
}