tract_linalg/frame/reduce/mod.rs

1pub mod max;
2pub mod softmax;
3pub mod sum;
4
5use std::fmt::Debug;
6use std::marker::PhantomData;
7
8use tract_data::TractResult;
9
10use crate::LADatum;
11
12use super::element_wise_helper::{map_reduce_slice_with_alignment, reduce_slice_with_alignment};
13
/// Declares a zero-sized, `Copy` kernel type `$func` and implements
/// `crate::frame::reduce::ReduceKer<$ti, $params>` for it.
///
/// `$nr`, `$alignment_items` and `$neutral` become the bodies of `nr()`,
/// `alignment_items()` and `neutral()`. `$run` and `$reduce_two` must be complete
/// `fn run` / `fn reduce_two` items; they are spliced into the impl verbatim.
macro_rules! reduce_impl_wrap {
    (
        $ti: ident,
        $func: ident,
        $nr: expr,
        $alignment_items: expr,
        $params: ty,
        $neutral: expr,
        $run: item,
        $reduce_two: item
    ) => {
        paste! {
            #[derive(Copy, Clone, Debug)]
            #[allow(non_camel_case_types)]
            pub struct $func;

            impl crate::frame::reduce::ReduceKer<$ti, $params> for $func {
                #[inline(always)]
                fn name() -> &'static str { stringify!($func) }

                #[inline(always)]
                fn neutral() -> $ti { $neutral }

                #[inline(always)]
                fn nr() -> usize { $nr }

                #[inline(always)]
                fn alignment_items() -> usize { $alignment_items }

                // Byte alignment is the item alignment scaled by the in-memory size of `$ti`.
                #[inline(always)]
                fn alignment_bytes() -> usize {
                    Self::alignment_items() * std::mem::size_of::<$ti>()
                }

                $run
                $reduce_two
            }
        }
    };
}
48
49pub trait Reduce<T, Params = ()>: Send + Sync + Debug + dyn_clone::DynClone
50where
51    Params: Copy + Send + Sync + Debug + 'static + Default,
52    T: Copy + Debug + PartialEq + Send + Sync,
53{
54    fn name(&self) -> &'static str;
55    fn run(&self, vec: &[T]) -> TractResult<T> {
56        self.run_with_params(vec, Params::default())
57    }
58    fn run_with_params(&self, vec: &[T], params: Params) -> TractResult<T>;
59}
60
61dyn_clone::clone_trait_object!(<T, Params> Reduce<T, Params> where T: Copy, Params: Copy);
62
63#[derive(Debug, Clone, new)]
64pub struct ReduceImpl<K, T, Params = ()>
65where
66    T: LADatum,
67    Params: Copy + Send + Sync + Debug + 'static + Default,
68    K: ReduceKer<T, Params> + Clone,
69{
70    phantom: PhantomData<(K, T, Params)>,
71}
72
73impl<K, T, Params> Reduce<T, Params> for ReduceImpl<K, T, Params>
74where
75    T: LADatum,
76    Params: Copy + Send + Sync + Debug + 'static + Default,
77    K: ReduceKer<T, Params> + Clone,
78{
79    fn name(&self) -> &'static str {
80        K::name()
81    }
82
83    fn run_with_params(&self, vec: &[T], params: Params) -> TractResult<T> {
84        reduce_slice_with_alignment(
85            vec,
86            |data| K::run(data, params),
87            K::nr(),
88            K::alignment_bytes(),
89            K::neutral(),
90            K::reduce_two,
91        )
92    }
93}
94
95pub trait ReduceKer<T, Params = ()>:
96    Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static
97where
98    Params: Copy + Send + Sync + Debug + 'static + Default,
99    T: LADatum,
100{
101    fn name() -> &'static str;
102    fn alignment_bytes() -> usize {
103        Self::alignment_items() * T::datum_type().size_of()
104    }
105    fn alignment_items() -> usize;
106    fn nr() -> usize;
107    fn neutral() -> T;
108    fn reduce_two(a: T, b: T) -> T;
109    fn run(vec: &[T], params: Params) -> T;
110    fn red() -> Box<dyn Reduce<T, Params>> {
111        Box::new(ReduceImpl::<Self, T, Params>::new())
112    }
113}
114
/// Declares a zero-sized, `Copy` kernel type `$func` and implements
/// `crate::frame::reduce::MapReduceKer<$ti, $params>` for it.
///
/// `$nr`, `$alignment_items`, `$map_neutral` and `$reduce_neutral` become the bodies
/// of the matching trait methods. `$run` and `$reduce_two` must be complete
/// `fn run` / `fn reduce_two` items; they are spliced into the impl verbatim.
#[allow(unused_macros)]
macro_rules! map_reduce_impl_wrap {
    (
        $ti: ident,
        $func: ident,
        $nr: expr,
        $alignment_items: expr,
        $params: ty,
        $map_neutral: expr,
        $reduce_neutral: expr,
        $run: item,
        $reduce_two: item
    ) => {
        paste! {
            #[derive(Copy, Clone, Debug)]
            #[allow(non_camel_case_types)]
            pub struct $func;

            impl crate::frame::reduce::MapReduceKer<$ti, $params> for $func {
                #[inline(always)]
                fn name() -> &'static str { stringify!($func) }

                #[inline(always)]
                fn map_neutral() -> $ti { $map_neutral }

                #[inline(always)]
                fn reduce_neutral() -> $ti { $reduce_neutral }

                #[inline(always)]
                fn nr() -> usize { $nr }

                #[inline(always)]
                fn alignment_items() -> usize { $alignment_items }

                // Byte alignment is the item alignment scaled by the in-memory size of `$ti`.
                #[inline(always)]
                fn alignment_bytes() -> usize {
                    Self::alignment_items() * std::mem::size_of::<$ti>()
                }

                $run
                $reduce_two
            }
        }
    };
}
154
155pub trait MapReduce<T, Params = ()>: Send + Sync + Debug + dyn_clone::DynClone
156where
157    Params: Copy + Send + Sync + Debug + 'static + Default,
158    T: Copy + Debug + PartialEq + Send + Sync,
159{
160    fn name(&self) -> &'static str;
161    fn run(&self, vec: &mut [T]) -> TractResult<T> {
162        self.run_with_params(vec, Params::default())
163    }
164    fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult<T>;
165}
166
167dyn_clone::clone_trait_object!(<T, Params> MapReduce<T, Params> where T: Copy, Params: Copy);
168
169#[derive(Debug, Clone, new)]
170pub struct MapReduceImpl<K, T, Params = ()>
171where
172    T: LADatum,
173    Params: Copy + Send + Sync + Debug + 'static + Default,
174    K: MapReduceKer<T, Params> + Clone,
175{
176    phantom: PhantomData<(K, T, Params)>,
177}
178
179impl<K, T, Params> MapReduce<T, Params> for MapReduceImpl<K, T, Params>
180where
181    T: LADatum,
182    Params: Copy + Send + Sync + Debug + 'static + Default,
183    K: MapReduceKer<T, Params> + Clone,
184{
185    fn name(&self) -> &'static str {
186        K::name()
187    }
188    fn run_with_params(&self, vec: &mut [T], params: Params) -> TractResult<T> {
189        map_reduce_slice_with_alignment(
190            vec,
191            |data| K::run(data, params),
192            K::nr(),
193            K::alignment_bytes(),
194            K::map_neutral(),
195            K::reduce_neutral(),
196            K::reduce_two,
197        )
198    }
199}
200
201pub trait MapReduceKer<T, Params = ()>:
202    Send + Sync + Debug + dyn_clone::DynClone + Clone + 'static
203where
204    Params: Copy + Send + Sync + Debug + 'static + Default,
205    T: LADatum,
206{
207    fn name() -> &'static str;
208    fn alignment_bytes() -> usize {
209        Self::alignment_items() * T::datum_type().size_of()
210    }
211    fn alignment_items() -> usize;
212    fn nr() -> usize;
213    fn map_neutral() -> T;
214    fn reduce_neutral() -> T;
215    fn reduce_two(a: T, b: T) -> T;
216    fn run(vec: &mut [T], params: Params) -> T;
217    fn red() -> Box<dyn MapReduce<T, Params>> {
218        Box::new(MapReduceImpl::<Self, T, Params>::new())
219    }
220}
221
#[cfg(test)]
pub mod test {
    use super::*;
    use proptest::test_runner::{TestCaseError, TestCaseResult};
    use tract_data::internal::*;
    use tract_data::itertools::Itertools;

    /// Checks a parameterless reduce kernel `K` against a scalar reference fold.
    pub fn test_reduce<K: ReduceKer<T, ()>, T: LADatum>(
        values: &[T],
        neutral: T,
        reference_reduce: impl Fn(T, T) -> T,
    ) -> TestCaseResult {
        test_reduce_params::<K, T, ()>(values, neutral, reference_reduce, ())
    }

    /// Checks reduce kernel `K`, run with `params`, against
    /// `values.fold(neutral, reference_reducer)`.
    ///
    /// Kernel errors and mismatches are reported as `TestCaseError` failures.
    pub fn test_reduce_params<K: ReduceKer<T, Params>, T: LADatum, Params>(
        values: &[T],
        neutral: T,
        reference_reducer: impl Fn(T, T) -> T,
        params: Params,
    ) -> TestCaseResult
    where
        Params: Copy + Send + Sync + Debug + 'static + Default,
    {
        crate::setup_test_logger();
        let op = K::red();
        let expected = values.iter().fold(neutral, |acc, i| reference_reducer(acc, *i));
        let reduced = op
            .run_with_params(values, params)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        tensor0(reduced)
            .close_enough(&tensor0(expected), true)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        Ok(())
    }

    /// Checks a parameterless map-reduce kernel `K` against scalar references.
    pub fn test_map_reduce<K: MapReduceKer<T, ()>, T: LADatum>(
        values: &[T],
        map_neutral: T,
        neutral: T,
        reference_map: impl Fn(T) -> T,
        reference_reduce: impl Fn(T, T) -> T,
    ) -> TestCaseResult {
        test_map_reduce_params::<K, T, ()>(
            values,
            map_neutral,
            neutral,
            reference_map,
            reference_reduce,
            (),
        )
    }

    /// Checks map-reduce kernel `K`, run with `params`, against scalar references.
    ///
    /// Two things are verified:
    /// - the slice is mapped in place to `values.map(reference_map)`;
    /// - the returned value matches that mapped sequence folded with `reference_reducer`,
    ///   starting from `neutral` (the reduce neutral).
    ///
    /// `_map_neutral` is not used by the reference computation. It is kept so that
    /// callers pass both neutrals in the same order as [`test_map_reduce`]: map neutral
    /// first, reduce neutral second.
    pub fn test_map_reduce_params<K: MapReduceKer<T, Params>, T: LADatum, Params>(
        values: &[T],
        _map_neutral: T,
        neutral: T,
        reference_map: impl Fn(T) -> T,
        reference_reducer: impl Fn(T, T) -> T,
        params: Params,
    ) -> TestCaseResult
    where
        Params: Copy + Send + Sync + Debug + 'static + Default,
    {
        crate::setup_test_logger();
        let op = K::red();
        let expected_values = values.iter().copied().map(reference_map).collect_vec();
        let expected_reduced =
            expected_values.iter().fold(neutral, |acc, i| reference_reducer(acc, *i));
        // The kernel maps in place, so run it on a private copy of the input.
        let mut found = values.to_vec();
        let reduced = op
            .run_with_params(&mut found, params)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        tensor1(&found)
            .close_enough(&tensor1(&expected_values), Approximation::SuperApproximate)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        tensor0(reduced)
            .close_enough(&tensor0(expected_reduced), Approximation::SuperApproximate)
            .map_err(|e| TestCaseError::fail(e.root_cause().to_string()))?;
        Ok(())
    }
}