
tract_core/ops/nn/reduce.rs

use crate::internal::Axis;
use crate::internal::*;
use crate::ops::binary::TypedBinOp;
use crate::ops::cast::cast;
use crate::ops::change_axes::wire_with_rank_broadcast;
use crate::ops::element_wise::ElementWiseOp;
use crate::ops::math::{div, square, Mul, Square};
use std::convert::TryFrom;
use std::iter::Sum;
use std::mem::transmute;
use tract_data::internal::ClampCast;
use tract_data::itertools::Itertools;
use tract_ndarray::prelude::*;
use tract_num_traits::{AsPrimitive, Bounded};

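/// Dispatch a call over all numeric `DatumType`s, monomorphizing on the
/// concrete element type. The optional second arm routes the quantized
/// types (QI8/QU8) to a dedicated implementation instead.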
macro_rules! r {
    ($($path:ident)::* ($dt:expr) ($($args:expr),*)) => {
        match $dt {
            DatumType::U8   => $($path)::*::<u8,_,_,_>($($args),*),
            DatumType::I8   => $($path)::*::<i8,_,_,_>($($args),*),
            DatumType::U16  => $($path)::*::<u16,_,_,_>($($args),*),
            DatumType::I16  => $($path)::*::<i16,_,_,_>($($args),*),
            DatumType::I32  => $($path)::*::<i32,_,_,_>($($args),*),
            DatumType::I64  => $($path)::*::<i64,_,_,_>($($args),*),
            DatumType::F16  => $($path)::*::<f16,_,_,_>($($args),*),
            DatumType::F32  => $($path)::*::<f32,_,_,_>($($args),*),
            DatumType::F64  => $($path)::*::<f64,_,_,_>($($args),*),
            DatumType::QI8(_)  => $($path)::*::<i8,_,_,_>($($args),*),
            DatumType::QU8(_)  => $($path)::*::<u8,_,_,_>($($args),*),
            _ => bail!("{:?} is not a number", $dt)
        }
    };
    ($($path:ident)::* ($dt:expr) ($($args:expr),*); $($q_path:ident)::* ($($q_args:expr),*)) => {
        match $dt {
            DatumType::U8   => $($path)::*::<u8,_,_,_>($($args),*),
            DatumType::I8   => $($path)::*::<i8,_,_,_>($($args),*),
            DatumType::U16  => $($path)::*::<u16,_,_,_>($($args),*),
            DatumType::I16  => $($path)::*::<i16,_,_,_>($($args),*),
            DatumType::I32  => $($path)::*::<i32,_,_,_>($($args),*),
            DatumType::I64  => $($path)::*::<i64,_,_,_>($($args),*),
            DatumType::F16  => $($path)::*::<f16,_,_,_>($($args),*),
            DatumType::F32  => $($path)::*::<f32,_,_,_>($($args),*),
            DatumType::F64  => $($path)::*::<f64,_,_,_>($($args),*),
            DatumType::QI8(_)  => $($q_path)::*::<i8,_,_,_>($($q_args),*),
            DatumType::QU8(_)  => $($q_path)::*::<u8,_,_,_>($($q_args),*),
            _ => bail!("{:?} is not a number", $dt)
        }
    }
}

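/// A reduction over a set of axes. For `ArgMax`/`ArgMin`, the boolean selects
/// the last matching index instead of the first when several elements tie.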
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Reducer {
    ArgMax(bool), // take last
    ArgMin(bool),
    Max,
    Min,
    Prod,
    Sum,
    MeanOfSquares,
}

impl Reducer {
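    /// Eagerly evaluate the reduction. Reduced axes are kept in the output
    /// shape with dimension 1.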
    pub fn reduce(&self, axes: &[usize], input: &Tensor) -> TractResult<Tensor> {
        use Reducer::*;
        let dt = input.datum_type();
        let output_shape: Vec<usize> = input
            .shape()
            .iter()
            .enumerate()
            .map(|(ax, &d)| if axes.contains(&ax) { 1 } else { d })
            .collect();
        let (zp, scale) = input.datum_type().zp_scale();
        unsafe {
            let mut t = match self {
                ArgMax(last) => {
                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, argmax_t, *last))
                }
                ArgMin(last) => {
                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, argmin_t, *last))
                }
                Min => r!(Self::reduce_t(dt)(self, axes, &output_shape, input, min_t, ())),
                Max => r!(Self::reduce_t(dt)(self, axes, &output_shape, input, max_t, ())),
                Prod => {
                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, prod_t, ());
                       Self::reduce_t(self, axes, &output_shape, input, q_prod_t, (zp, scale)))
                }
                Sum => {
                    if dt.is_float() {
                        dispatch_floatlike!(Self::sum(dt)(self, axes, input))
                    } else {
                        r!(Self::reduce_t(dt)(
                            self,
                            axes,
                            &output_shape,
                            input,
                            q_sum_t,
                            (zp, scale)
                        ))
                    }
                }
                MeanOfSquares => self.mean_of_squares(axes, input)?,
            };
            if input.datum_type().is_quantized()
                && input.datum_type().unquantized() == t.datum_type().unquantized()
            {
                t.set_datum_type(input.datum_type());
            }
            Ok(t)
        }
    }

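    /// Generic evaluation path: for every output coordinate, slice the input
    /// along the reduced axes and apply `f` to the resulting view.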
    unsafe fn reduce_t<T, TO, F, A>(
        &self,
        axes: &[usize],
        output_shape: &[usize],
        input_tensor: &Tensor,
        f: F,
        args: A,
    ) -> Tensor
    where
        F: for<'a> Fn(ArrayViewD<'a, T>, A) -> TO,
        T: Copy + Datum,
        TO: Copy + Datum,
        A: Copy,
    {
        use ndarray::*;
        let input = unsafe { input_tensor.to_array_view_unchecked::<T>() };
        let result = Array::from_shape_fn(output_shape, |coords| {
            let slice_spec: Vec<SliceInfoElem> = coords
                .slice()
                .iter()
                .enumerate()
                .map(|(ax, &d)| if axes.contains(&ax) { (..).into() } else { d.into() })
                .collect();
            let slice_info = SliceInfo::<_, IxDyn, IxDyn>::try_from(slice_spec).unwrap();
            let slice = input.slice(&slice_info);
            f(slice, args)
        });
        result.into_tensor()
    }

    // Sum is a special citizen: it sees enough traffic to warrant dedicated
    // treatment. We could use the same scheme for min, max and prod, at the
    // price of more code in the library. Argmax and argmin are trickier
    // (they are not associative).
    unsafe fn sum<T>(&self, axes: &[usize], input: &Tensor) -> Tensor
    where
        T: Copy + Datum + num_traits::Zero + Sum,
        f16: AsPrimitive<T>,
        f32: AsPrimitive<T>,
    {
        if axes.is_empty() {
            return input.to_owned();
        }

        // The tract-linalg fast path (else branch below) only handles a single
        // reduction axis in innermost position; anything else goes through
        // ndarray's sum_axis.
        if axes.len() > 1 || axes[0] != input.rank() - 1 {
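            // Merge each run of consecutive reduced axes into a single axis so
            // sum_axis is invoked as few times as possible.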
            let mut operative_axes = vec![];
            let mut operative_shape: Vec<usize> = vec![];
            for (ix, dim) in input.shape().iter().enumerate() {
                // axis is reduced, but is not the first of a series of reduced axes
                if ix > 0 && axes.contains(&ix) && axes.contains(&(ix - 1)) {
                    *operative_shape.last_mut().unwrap() *= *dim;
                } else if axes.contains(&ix) {
                    operative_axes.push(operative_shape.len());
                    operative_shape.push(*dim);
                } else {
                    operative_shape.push(*dim);
                }
            }
            let mut output = unsafe {
                input
                    .to_array_view_unchecked::<T>()
                    .into_shape_with_order(operative_shape)
                    .unwrap()
                    .sum_axis(Axis(*operative_axes.iter().max().unwrap()))
            };
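            // Reduce from the highest remaining axis down so the indices of
            // the axes still to be reduced stay valid after each removal.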
            for axis in operative_axes.iter().rev().skip(1) {
                output = output.sum_axis(Axis(*axis));
            }

            let mut output = output.into_tensor();

            for &axis in axes {
                output.insert_axis(axis).unwrap();
            }

            output
        } else {
            let mut output: Option<ArrayD<T>> = None;
            for axis in axes.iter().copied() {
                let input_view = output
                    .as_ref()
                    .map(|o| o.view())
                    .unwrap_or_else(|| unsafe { input.to_array_view_unchecked::<T>() });

                // Create the array that will hold the intermediate result
                let reduced_dim = input_view.shape()[axis];
                let input_stride = input_view.strides()[axis] as usize;
                let output_shape = input_view
                    .shape()
                    .iter()
                    .enumerate()
                    .map(|(idx, dim)| if idx != axis { *dim } else { 1 })
                    .collect_vec();

                output = Some(ArrayD::from_shape_fn(output_shape.clone(), |coords| {
                    let mut view = input_view.view();
                    for ix in 0..output_shape.len() {
                        if ix != axis {
                            view.collapse_axis(Axis(ix), coords[ix]);
                        }
                    }

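                    // Contiguous lane: hand it to tract-linalg's vectorized
                    // sum kernel for f16/f32, or fall back to a scalar fold.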
                    if let Some(slice) = view.as_slice() {
                        if T::datum_type() == f16::datum_type() {
                            let slice: &[f16] = unsafe { std::mem::transmute(slice) };
                            (tract_linalg::ops().sum_f16)()
                                .run_with_params(slice, ())
                                .unwrap()
                                .as_()
                        } else if T::datum_type() == f32::datum_type() {
                            let slice: &[f32] = unsafe { std::mem::transmute(slice) };
                            (tract_linalg::ops().sum_f32)()
                                .run_with_params(slice, ())
                                .unwrap()
                                .as_()
                        } else {
                            slice.iter().cloned().sum::<T>()
                        }
                    } else {
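                        // Non-contiguous lane: walk it element by element
                        // through the raw stride.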
                        let first: *const T = &input_view[coords];
                        let mut sum = T::zero();
                        for i in 0..reduced_dim {
                            sum = sum + unsafe { *(first.add(i * input_stride)) };
                        }
                        sum
                    }
                }));
            }
            output.unwrap().into_tensor()
        }
    }

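    /// Always computed in f32, whatever the input type: square, sum over the
    /// axes, then scale by 1/n where n is the number of reduced elements.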
    fn mean_of_squares(&self, axes: &[usize], input: &Tensor) -> TractResult<Tensor> {
        let dt = input.datum_type();
        let mut input = input.cast_to::<f32>()?.into_owned();
        input.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x = *x * *x);
        let mut output = unsafe { self.sum::<f32>(axes, &input) };
        let norm = output.len() as f32 / input.len() as f32;
        output.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x *= norm);
        Ok(output.cast_to_dt(dt)?.into_owned())
    }
}

fn argmax_t<T>(v: ArrayViewD<T>, last: bool) -> i64
where
    T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
{
    v.iter()
        .copied()
        .enumerate()
        .fold((0usize, T::min_value()), |acc, v| {
            if v.1 > acc.1 || (last && acc.1 == v.1) {
                v
            } else {
                acc
            }
        })
        .0 as i64
}

fn argmin_t<T>(v: ArrayViewD<T>, last: bool) -> i64
where
    T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
{
    v.iter()
        .copied()
        .enumerate()
        .fold((0usize, T::max_value()), |acc, v| {
            if v.1 < acc.1 || (last && acc.1 == v.1) {
                v
            } else {
                acc
            }
        })
        .0 as i64
}

fn max_t<T>(v: ArrayViewD<T>, _: ()) -> T
where
    T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
{
    if T::datum_type() == f32::datum_type() {
        if let Some(slice) = v.as_slice() {
            // Contiguous f32 data: use tract-linalg's vectorized max kernel
            // and return its result directly.
            let slice = unsafe { transmute::<&[T], &[f32]>(slice) };
            let max = (tract_linalg::ops().max_f32)().run(slice).unwrap();
            return unsafe { std::mem::transmute_copy(&max) };
        }
    }
    v.fold(T::min_value(), |acc, &v| if acc > v { acc } else { v })
}

fn min_t<T>(v: ArrayViewD<T>, _: ()) -> T
where
    T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
{
    v.fold(T::max_value(), |acc, &v| if acc < v { acc } else { v })
}

fn prod_t<T>(v: ArrayViewD<T>, _: ()) -> T
where
    T: Copy + Datum + num_traits::One,
{
    v.fold(T::one(), |acc, &v| acc * v)
}

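/// Quantized product. With x = (q - zp) * scale, the product of n values is
/// scale^n * prod(q_i - zp); re-expressed in the same quantized format it
/// becomes scale^(n-1) * prod(q_i - zp) + zp, clamped to the storage range.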
fn q_prod_t<T>(v: ArrayViewD<T>, zp_scale: (i32, f32)) -> T
where
    T: Copy + num_traits::AsPrimitive<f32> + Bounded + Datum,
    f32: num_traits::AsPrimitive<T>,
{
    let (zp, scale) = zp_scale;
    (v.fold(1f32, |acc, &v| acc * (v.as_() - zp as f32)) * scale.powi(v.len() as i32 - 1)
        + zp as f32)
        .clamp_cast()
}

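/// Quantized sum. All terms share the same scale, so it is enough to sum the
/// raw values and subtract (n - 1) zero-points to keep the same zp convention.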
fn q_sum_t<T>(v: ArrayViewD<T>, zp_scale: (i32, f32)) -> T
where
    T: Copy + Bounded + num_traits::AsPrimitive<i32> + Datum,
    i32: num_traits::AsPrimitive<T>,
{
    let (zp, _) = zp_scale;
    (v.fold(0i32, |acc, &v| acc + v.as_()) - zp * (v.len() as i32 - 1)).clamp_cast()
}

#[derive(Clone, Debug, new, Hash)]
pub struct Reduce {
    pub axes: TVec<usize>,
    pub reducer: Reducer,
}

impl Op for Reduce {
    fn name(&self) -> StaticName {
        format!("Reduce<{:?}>", self.reducer).into()
    }
    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!("axes: {:?}", self.axes)])
    }
    op_as_typed_op!();
}

impl EvalOp for Reduce {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        Ok(tvec!(self.reducer.reduce(&self.axes, &inputs[0])?.into()))
    }
}

impl TypedOp for Reduce {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(self.axes.iter().tuple_windows().all(|(a, b)| a < b));
        if inputs[0].datum_type == TDim::datum_type() {
            bail!("Reduce input must be cast from TDim to i64 beforehand")
        }
        let mut shape: TVec<_> = inputs[0].shape.to_tvec();
        for &ax in &self.axes {
            shape[ax] = 1.to_dim();
        }
        let dt = if let Reducer::ArgMax(_) | Reducer::ArgMin(_) = self.reducer {
            DatumType::I64
        } else {
            inputs[0].datum_type
        };
        Ok(tvec!(dt.fact(shape)))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some(patch) = self.declutter_mean_of_square(model, node)? {
            return Ok(Some(patch));
        }
        if let Some(patch) = self.declutter_scalar_mul_then_sum(model, node)? {
            return Ok(Some(patch));
        }
        if let Some(patch) = self.declutter_reduce_reduce(model, node)? {
            return Ok(Some(patch));
        }
        Ok(None)
    }

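    /// Non-reduced axes are shared between input and output under a single
    /// letter; each reduced axis gets two distinct letters, one per side,
    /// since input and output dimensions no longer correspond there.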
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let mut letters = 'a'..;
        let axes = (0..inputs[0].rank())
            .flat_map(|ix| {
                if self.axes.contains(&ix) {
                    tvec!(
                        Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
                            .input(0, ix),
                        Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
                            .output(0, ix),
                    )
                } else {
                    tvec!(Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
                        .input(0, ix)
                        .output(0, ix))
                }
                .into_iter()
            })
            .collect_vec();
        AxesMapping::new(1, 1, axes)
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let mut axes = tvec!();
        for reduced in &self.axes {
            if let Some(axis) = change.transform_axis(*reduced) {
                axes.push(axis);
            } else {
                return Ok(None);
            }
        }
        axes.sort();
        let op = Some(Box::new(Self { axes, ..self.clone() }) as _);
        Ok(Some(AxisChangeConsequence::new(model, node, op, change)))
    }

    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        _model: &TypedModel,
        node: &TypedNode,
        _prefix: &str,
        inputs: &[OutletId],
        output_axis: usize,
        _start: &TDim,
        _end: &TDim,
    ) -> TractResult<Option<TVec<OutletId>>> {
        if self.axes.contains(&output_axis) {
            return Ok(None);
        }
        patch.wire_node(&node.name, &node.op, inputs).map(Some)
    }

    as_op!();
}

impl Reduce {
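    /// Reduce(Reduce(x)): two chained reductions using the same associative
    /// reducer fold into a single Reduce over the union of their axes.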
    fn declutter_reduce_reduce(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let Some(prec) = model.linear_prec(node.id)? else {
            return Ok(None);
        };
        let Some(prec_reduce) = prec.op_as::<Self>() else {
            return Ok(None);
        };
        use Reducer::*;
        if prec_reduce.reducer != self.reducer || ![Sum, Prod, Min, Max].contains(&self.reducer) {
            return Ok(None);
        }
        let mut patch = TypedModelPatch::default();
        let wire = patch.tap_model(model, prec.inputs[0])?;
        let wire = patch.wire_node(
            &node.name,
            Self {
                reducer: self.reducer,
                axes: prec_reduce
                    .axes
                    .iter()
                    .chain(self.axes.iter())
                    .copied()
                    .sorted()
                    .dedup()
                    .collect(),
            },
            &[wire],
        )?;
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

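    /// Sum(k * x) with a scalar constant k rewrites to k * Sum(x): the
    /// multiplication then runs on the (smaller) reduced tensor.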
    fn declutter_scalar_mul_then_sum(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.reducer == Reducer::Sum {
            let Some(prec) = model.linear_prec(node.id)? else {
                return Ok(None);
            };
            let Some(prec_bin) = prec.op_as::<TypedBinOp>() else {
                return Ok(None);
            };
            if !prec_bin.0.is::<Mul>() {
                return Ok(None);
            }
            let mul_input_fact = model.node_input_facts(prec.id)?;
            let Some(scalar_slot) = mul_input_fact
                .iter()
                .position(|f| f.konst.as_ref().is_some_and(|k| k.volume() == 1))
            else {
                return Ok(None);
            };
            let mut patch = TypedModelPatch::default();
            let scalar = patch.tap_model(model, prec.inputs[scalar_slot])?;
            let wire = patch.tap_model(model, prec.inputs[1 - scalar_slot])?;
            let wire = patch.wire_node(&node.name, self.clone(), &[wire])?[0];
            let wire = patch.wire_node(&prec.name, prec_bin.clone(), &[wire, scalar])?[0];
            patch.shunt_outside(model, node.id.into(), wire)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

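    /// Recognize Mul(Sum(Square(x)), 1/n), with n the product of the reduced
    /// dimensions, and fuse the whole pattern into a single MeanOfSquares.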
    fn declutter_mean_of_square(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.reducer == Reducer::Sum {
            let Some(prec) = model.linear_prec(node.id)? else {
                return Ok(None);
            };
            let Some(prec_ew) = prec.op_as::<ElementWiseOp>() else {
                return Ok(None);
            };
            if !prec_ew.0.is::<Square>() {
                return Ok(None);
            }
            if node.outputs.len() != 1 || node.outputs[0].successors.len() != 1 {
                return Ok(None);
            }
            let our_inlet = node.outputs[0].successors[0];
            let succ = model.node(our_inlet.node);
            let Some(succ_bin) = succ.op_as::<TypedBinOp>() else {
                return Ok(None);
            };
            if !succ_bin.0.is::<Mul>() {
                return Ok(None);
            }
            let other = succ.inputs[1 - our_inlet.slot];
            let Some(other_konst) = model.outlet_fact(other)?.uniform.as_ref() else {
                return Ok(None);
            };
            let norm: TDim = self.axes.iter().map(|&ax| &prec.outputs[0].fact.shape[ax]).product();
            let Some(norm) = norm.as_i64() else {
                return Ok(None);
            };
            if norm == 0 {
                return Ok(None);
            }
            let norm = tensor0((norm as f32).recip());
            if other_konst.close_enough(&norm, Approximation::Close).is_ok() {
                let mut patch = TypedModelPatch::default();
                let wire = patch.tap_model(model, prec.inputs[0])?;
                let wire = patch.wire_node(
                    &node.name,
                    Reduce::new(self.axes.clone(), Reducer::MeanOfSquares),
                    &[wire],
                )?[0];
                patch.shunt_outside(model, succ.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
}

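/// Inverse of the declutter rule, used when lowering: expand MeanOfSquares
/// back into cast-to-f32, square, sum, then a division by the number of
/// reduced elements.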
pub fn expand_mean_of_squares(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    name: &str,
    op: &Reduce,
) -> TractResult<Option<TypedModelPatch>> {
    if op.reducer == Reducer::MeanOfSquares {
        let mut patch = TypedModelPatch::default();
        let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
        let input_fact = model.outlet_fact(node.inputs[0])?;
        let dt = input_fact.datum_type;
        if dt != f32::datum_type() {
            wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
        }
        wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
        wire = patch.wire_node(
            format!("{name}.sum"),
            Reduce::new(op.axes.clone(), Reducer::Sum),
            &wire,
        )?;
        let card = input_fact
            .shape
            .iter()
            .enumerate()
            .filter(|(ix, _dim)| op.axes.contains(ix))
            .map(|(_ix, dim)| dim)
            .product::<TDim>();
        let card = patch.add_const(format!("{name}.card"), tensor0(card))?;
        let card =
            patch.wire_node(format!("{name}.card_to_f32"), cast(f32::datum_type()), &[card])?;

        wire = wire_with_rank_broadcast(
            format!("{name}.norm"),
            &mut patch,
            div(),
            &[wire[0], card[0]],
        )?;
        if dt != f32::datum_type() {
            wire = patch.wire_node(format!("{name}.from_f32"), cast(dt), &wire)?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    } else {
        Ok(None)
    }
}
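
// A minimal usage sketch (hypothetical test module, not part of the upstream
// file): it assumes `tensor2` is in scope through the crate prelude and
// exercises Reducer::Sum over the innermost axis of a small f32 tensor.
#[cfg(test)]
mod reduce_sketch {
    use super::*;

    #[test]
    fn sum_over_last_axis() -> TractResult<()> {
        let input = tensor2(&[[1f32, 2., 3.], [4., 5., 6.]]);
        // Reduced axes are kept with size 1, so the output shape is [2, 1].
        let output = Reducer::Sum.reduce(&[1], &input)?;
        assert_eq!(output, tensor2(&[[6f32], [15.]]));
        Ok(())
    }
}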