// tract_core/ops/nn/reduce.rs

1use crate::internal::Axis;
2use crate::internal::*;
3use crate::ops::binary::TypedBinOp;
4use crate::ops::cast::cast;
5use crate::ops::change_axes::wire_with_rank_broadcast;
6use crate::ops::element_wise::ElementWiseOp;
7use crate::ops::math::{div, square, Mul, Square};
8use std::convert::TryFrom;
9use std::iter::Sum;
10use std::mem::transmute;
11use tract_data::internal::ClampCast;
12use tract_data::itertools::Itertools;
13use tract_ndarray::prelude::*;
14use tract_num_traits::{AsPrimitive, Bounded};
15
/// Dispatches a generic call on a runtime `DatumType`.
///
/// `r!(path(dt)(args...))` expands to a `match` on `dt` calling
/// `path::<T, _, _, _>(args...)`, with `T` the Rust element type of `dt`. Quantized
/// types (`QI8`, `QU8`) are dispatched on their storage type (`i8`, `u8`).
///
/// `r!(path(dt)(args...); q_path(q_args...))` does the same, but sends the quantized
/// types to `q_path::<T, _, _, _>(q_args...)` instead.
///
/// Any other datum type makes the enclosing function `bail!` with "<dt> is not a number".
macro_rules! r {
    // Plain form: the quantized types reuse the regular path and arguments.
    ($($path:ident)::* ($dt:expr) ($($args:expr),*)) => {
        r!($($path)::*($dt)($($args),*); $($path)::*($($args),*))
    };
    // Split form: the quantized types get their own path and arguments.
    ($($path:ident)::* ($dt:expr) ($($args:expr),*); $($q_path:ident)::* ($($q_args:expr),*)) => {
        match $dt {
            DatumType::U8 => $($path)::*::<u8, _, _, _>($($args),*),
            DatumType::I8 => $($path)::*::<i8, _, _, _>($($args),*),
            DatumType::U16 => $($path)::*::<u16, _, _, _>($($args),*),
            DatumType::I16 => $($path)::*::<i16, _, _, _>($($args),*),
            DatumType::I32 => $($path)::*::<i32, _, _, _>($($args),*),
            DatumType::I64 => $($path)::*::<i64, _, _, _>($($args),*),
            DatumType::F16 => $($path)::*::<f16, _, _, _>($($args),*),
            DatumType::F32 => $($path)::*::<f32, _, _, _>($($args),*),
            DatumType::F64 => $($path)::*::<f64, _, _, _>($($args),*),
            DatumType::QI8(_) => $($q_path)::*::<i8, _, _, _>($($q_args),*),
            DatumType::QU8(_) => $($q_path)::*::<u8, _, _, _>($($q_args),*),
            _ => bail!("{:?} is not a number", $dt)
        }
    };
}
50
/// The reduction a `Reduce` operator applies over its axes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Reducer {
    /// Index of the largest element. The flag means "take last": when set, the last of
    /// several equal maxima is reported, otherwise the first one.
    ArgMax(bool),
    /// Index of the smallest element, with the same "take last" tie-breaking flag.
    ArgMin(bool),
    /// Largest element.
    Max,
    /// Smallest element.
    Min,
    /// Product of the elements.
    Prod,
    /// Sum of the elements.
    Sum,
    /// Sum of the squared elements divided by the number of reduced elements.
    MeanOfSquares,
}
61
62impl Reducer {
63    pub fn reduce(&self, axes: &[usize], input: &Tensor) -> TractResult<Tensor> {
64        use Reducer::*;
65        let dt = input.datum_type();
66        let output_shape: Vec<usize> = input
67            .shape()
68            .iter()
69            .enumerate()
70            .map(|(ax, &d)| if axes.contains(&ax) { 1 } else { d })
71            .collect();
72        let (zp, scale) = input.datum_type().zp_scale();
73        unsafe {
74            let mut t = match self {
75                ArgMax(last) => {
76                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, argmax_t, *last))
77                }
78                ArgMin(last) => {
79                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, argmin_t, *last))
80                }
81                Min => r!(Self::reduce_t(dt)(self, axes, &output_shape, input, min_t, ())),
82                Max => r!(Self::reduce_t(dt)(self, axes, &output_shape, input, max_t, ())),
83                Prod => {
84                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, prod_t, ()); Self::reduce_t(self, axes, &output_shape, input, q_prod_t, (zp, scale)))
85                }
86                Sum => {
87                    if dt.is_float() {
88                        dispatch_floatlike!(Self::sum(dt)(self, axes, input))
89                    } else {
90                        r!(Self::reduce_t(dt)(
91                            self,
92                            axes,
93                            &output_shape,
94                            input,
95                            q_sum_t,
96                            (zp, scale)
97                        ))
98                    }
99                }
100                MeanOfSquares => self.mean_of_squares(axes, input)?,
101            };
102            if input.datum_type().is_quantized()
103                && input.datum_type().unquantized() == t.datum_type().unquantized()
104            {
105                t.set_datum_type(input.datum_type());
106            }
107            Ok(t)
108        }
109    }
110
111    unsafe fn reduce_t<T, TO, F, A>(
112        &self,
113        axes: &[usize],
114        output_shape: &[usize],
115        input_tensor: &Tensor,
116        f: F,
117        args: A,
118    ) -> Tensor
119    where
120        F: for<'a> Fn(ArrayViewD<'a, T>, A) -> TO,
121        T: Copy + Datum,
122        TO: Copy + Datum,
123        A: Copy,
124    {
125        use ndarray::*;
126        let input = input_tensor.to_array_view_unchecked::<T>();
127        let result = Array::from_shape_fn(output_shape, |coords| {
128            let slice_spec: Vec<SliceInfoElem> = coords
129                .slice()
130                .iter()
131                .enumerate()
132                .map(|(ax, &d)| if axes.contains(&ax) { (..).into() } else { d.into() })
133                .collect();
134            let slice_info = SliceInfo::<_, IxDyn, IxDyn>::try_from(slice_spec).unwrap();
135            let slice = input.slice(&slice_info);
136            f(slice, args)
137        });
138        result.into_tensor()
139    }
140
141    // sum is a special citizen: enough activity that it gets "special"
142    // treatment. we could use the same "algo" for min, max and prod, to the
143    // price of more code in the library. argmax and argmin are more
144    // tricky (not associative)
145    unsafe fn sum<T>(&self, axes: &[usize], input: &Tensor) -> Tensor
146    where
147        T: Copy + Datum + num_traits::Zero + Sum,
148        f16: AsPrimitive<T>,
149        f32: AsPrimitive<T>,
150    {
151        if axes.len() == 0 {
152            return input.to_owned();
153        }
154
155        // use tract-optimized path only when single reuction axis and is at end
156        if axes.len() > 1 || axes[0] != input.rank() - 1 {
157            let mut operative_axes = vec![];
158            let mut operative_shape: Vec<usize> = vec![];
159            for (ix, dim) in input.shape().iter().enumerate() {
160                // axis is reduced, but is not the first of a series of reduced axes
161                if ix > 0 && axes.contains(&ix) && axes.contains(&(ix - 1)) {
162                    *operative_shape.last_mut().unwrap() *= *dim;
163                } else if axes.contains(&ix) {
164                    operative_axes.push(operative_shape.len());
165                    operative_shape.push(*dim);
166                } else {
167                    operative_shape.push(*dim);
168                }
169            }
170            let mut output = input
171                .to_array_view_unchecked::<T>()
172                .into_shape_with_order(operative_shape)
173                .unwrap()
174                .sum_axis(Axis(*operative_axes.iter().max().unwrap()));
175
176            for axis in operative_axes.iter().rev().skip(1) {
177                output = output.sum_axis(Axis(*axis));
178            }
179
180            let mut output = output.into_tensor();
181
182            for &axis in axes {
183                output.insert_axis(axis).unwrap();
184            }
185
186            output
187        } else {
188            let mut output: Option<ArrayD<T>> = None;
189            for axis in axes.iter().copied() {
190                let input_view = output
191                    .as_ref()
192                    .map(|o| o.view())
193                    .unwrap_or_else(|| input.to_array_view_unchecked::<T>());
194
195                // Create array that will contain intermidiate result
196                let reduced_dim = input_view.shape()[axis];
197                let input_stride = input_view.strides()[axis] as usize;
198                let output_shape = input_view
199                    .shape()
200                    .iter()
201                    .enumerate()
202                    .map(|(idx, dim)| if idx != axis { *dim } else { 1 })
203                    .collect_vec();
204
205                output = Some(ArrayD::from_shape_fn(output_shape.clone(), |coords| {
206                    let mut view = input_view.view();
207                    for ix in 0..output_shape.len() {
208                        if ix != axis {
209                            view.collapse_axis(Axis(ix), coords[ix]);
210                        }
211                    }
212
213                    if let Some(slice) = view.as_slice() {
214                        if T::datum_type() == f16::datum_type() {
215                            let slice: &[f16] = unsafe { std::mem::transmute(slice) };
216                            (tract_linalg::ops().sum_f16)()
217                                .run_with_params(slice, ())
218                                .unwrap()
219                                .as_()
220                        } else if T::datum_type() == f32::datum_type() {
221                            let slice: &[f32] = unsafe { std::mem::transmute(slice) };
222                            (tract_linalg::ops().sum_f32)()
223                                .run_with_params(slice, ())
224                                .unwrap()
225                                .as_()
226                        } else {
227                            slice.iter().cloned().sum::<T>()
228                        }
229                    } else {
230                        dbg!("ndarary code");
231                        let first: *const T = &input_view[coords];
232                        let mut sum = T::zero();
233                        for i in 0..reduced_dim {
234                            sum = sum + *(first.add(i * input_stride));
235                        }
236                        sum
237                    }
238                }));
239            }
240            output.unwrap().into_tensor()
241        }
242    }
243
244    fn mean_of_squares(&self, axis: &[usize], input: &Tensor) -> TractResult<Tensor> {
245        let dt = input.datum_type();
246        let mut input = input.cast_to::<f32>()?.into_owned();
247        input.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x = *x * *x);
248        let mut output = unsafe { self.sum::<f32>(axis, &input) };
249        let norm = output.len() as f32 / input.len() as f32;
250        output.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x *= norm);
251        Ok(output.cast_to_dt(dt)?.into_owned())
252    }
253}
254
255fn argmax_t<T>(v: ArrayViewD<T>, last: bool) -> i64
256where
257    T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
258{
259    v.iter()
260        .copied()
261        .enumerate()
262        .fold(
263            (0usize, T::min_value()),
264            |acc, v| {
265                if v.1 > acc.1 || (last && acc.1 == v.1) {
266                    v
267                } else {
268                    acc
269                }
270            },
271        )
272        .0 as i64
273}
274
275fn argmin_t<T>(v: ArrayViewD<T>, last: bool) -> i64
276where
277    T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
278{
279    v.iter()
280        .copied()
281        .enumerate()
282        .fold(
283            (0usize, T::max_value()),
284            |acc, v| {
285                if v.1 < acc.1 || (last && acc.1 == v.1) {
286                    v
287                } else {
288                    acc
289                }
290            },
291        )
292        .0 as i64
293}
294
295fn max_t<T>(v: ArrayViewD<T>, _: ()) -> T
296where
297    T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
298{
299    if T::datum_type() == f32::datum_type() {
300        if let Some(slice) = v.as_slice() {
301            let slice = unsafe { transmute::<&[T], &[f32]>(slice) };
302            (tract_linalg::ops().max_f32)().run(slice).unwrap();
303        }
304    }
305    v.fold(T::min_value(), |acc, &v| if acc > v { acc } else { v })
306}
307
308fn min_t<T>(v: ArrayViewD<T>, _: ()) -> T
309where
310    T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
311{
312    v.fold(T::max_value(), |acc, &v| if acc < v { acc } else { v })
313}
314
315fn prod_t<T>(v: ArrayViewD<T>, _: ()) -> T
316where
317    T: Copy + Datum + num_traits::One,
318{
319    v.fold(T::one(), |acc, &v| acc * v)
320}
321
322fn q_prod_t<T>(v: ArrayViewD<T>, zp_scale: (i32, f32)) -> T
323where
324    T: Copy + num_traits::AsPrimitive<f32> + Bounded + Datum,
325    f32: num_traits::AsPrimitive<T>,
326{
327    let (zp, scale) = zp_scale;
328    (v.fold(1f32, |acc, &v| acc * (v.as_() - zp as f32)) * scale.powi(v.len() as i32 - 1)
329        + zp as f32)
330        .clamp_cast()
331}
332
333fn q_sum_t<T>(v: ArrayViewD<T>, zp_scale: (i32, f32)) -> T
334where
335    T: Copy + Bounded + num_traits::AsPrimitive<i32> + Datum,
336    i32: num_traits::AsPrimitive<T>,
337{
338    let (zp, _) = zp_scale;
339    (v.fold(0i32, |acc, &v| acc + v.as_()) - zp * (v.len() as i32 - 1)).clamp_cast()
340}
341
342#[derive(Clone, Debug, new, Hash)]
343pub struct Reduce {
344    pub axes: TVec<usize>,
345    pub reducer: Reducer,
346}
347
348impl Op for Reduce {
349    fn name(&self) -> Cow<str> {
350        format!("Reduce<{:?}>", self.reducer).into()
351    }
352    fn info(&self) -> TractResult<Vec<String>> {
353        Ok(vec![format!("axes: {:?}", self.axes)])
354    }
355    op_as_typed_op!();
356}
357
358impl EvalOp for Reduce {
359    fn is_stateless(&self) -> bool {
360        true
361    }
362
363    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
364        Ok(tvec!(self.reducer.reduce(&self.axes, &inputs[0])?.into()))
365    }
366}
367
368impl TypedOp for Reduce {
369    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
370        ensure!(self.axes.iter().tuple_windows().all(|(a, b)| a < b));
371        if inputs[0].datum_type == TDim::datum_type() {
372            bail!("Reduce input must be cast from TDim to i64 beforehand")
373        }
374        let mut shape: TVec<_> = inputs[0].shape.to_tvec();
375        for &ax in &self.axes {
376            shape[ax] = 1.to_dim();
377        }
378        let dt = if let Reducer::ArgMax(_) | Reducer::ArgMin(_) = self.reducer {
379            DatumType::I64
380        } else {
381            inputs[0].datum_type
382        };
383        Ok(tvec!(dt.fact(shape)))
384    }
385
386    fn declutter(
387        &self,
388        model: &TypedModel,
389        node: &TypedNode,
390    ) -> TractResult<Option<TypedModelPatch>> {
391        if let Some(patch) = self.declutter_mean_of_square(model, node)? {
392            return Ok(Some(patch));
393        }
394        if let Some(patch) = self.declutter_scalar_mul_then_sum(model, node)? {
395            return Ok(Some(patch));
396        }
397        if let Some(patch) = self.declutter_reduce_reduce(model, node)? {
398            return Ok(Some(patch));
399        }
400        Ok(None)
401    }
402
403    fn axes_mapping(
404        &self,
405        inputs: &[&TypedFact],
406        outputs: &[&TypedFact],
407    ) -> TractResult<AxesMapping> {
408        let mut letters = 'a'..;
409        let axes = (0..inputs[0].rank())
410            .flat_map(|ix| {
411                if self.axes.contains(&ix) {
412                    tvec!(
413                        Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
414                            .input(0, ix),
415                        Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
416                            .output(0, ix),
417                    )
418                } else {
419                    tvec!(Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
420                        .input(0, ix)
421                        .output(0, ix))
422                }
423                .into_iter()
424            })
425            .collect_vec();
426        AxesMapping::new(1, 1, axes)
427    }
428
429    fn change_axes(
430        &self,
431        model: &TypedModel,
432        node: &TypedNode,
433        _io: InOut,
434        change: &AxisOp,
435    ) -> TractResult<Option<AxisChangeConsequence>> {
436        let mut axes = tvec!();
437        for reduced in &self.axes {
438            if let Some(axis) = change.transform_axis(*reduced) {
439                axes.push(axis);
440            } else {
441                return Ok(None);
442            }
443        }
444        axes.sort();
445        let op = Some(Box::new(Self { axes, ..self.clone() }) as _);
446        Ok(Some(AxisChangeConsequence::new(model, node, op, change)))
447    }
448
449    fn slice(
450        &self,
451        patch: &mut TypedModelPatch,
452        _model: &TypedModel,
453        node: &TypedNode,
454        _prefix: &str,
455        inputs: &[OutletId],
456        output_axis: usize,
457        _start: &TDim,
458        _end: &TDim,
459    ) -> TractResult<Option<TVec<OutletId>>> {
460        if self.axes.contains(&output_axis) {
461            return Ok(None);
462        }
463        patch.wire_node(&node.name, &node.op, inputs).map(Some)
464    }
465
466    as_op!();
467}
468
469impl Reduce {
470    fn declutter_reduce_reduce(
471        &self,
472        model: &TypedModel,
473        node: &TypedNode,
474    ) -> TractResult<Option<TypedModelPatch>> {
475        let Some(prec) = model.single_prec(node.id)? else {
476            return Ok(None);
477        };
478        let Some(prec_reduce) = prec.op_as::<Self>() else {
479            return Ok(None);
480        };
481        use Reducer::*;
482        if prec_reduce.reducer != self.reducer || ![Sum, Prod, Min, Max].contains(&self.reducer) {
483            return Ok(None);
484        }
485        let mut patch = TypedModelPatch::default();
486        let wire = patch.tap_model(model, prec.inputs[0])?;
487        let wire = patch.wire_node(
488            &node.name,
489            Self {
490                reducer: self.reducer,
491                axes: prec_reduce
492                    .axes
493                    .iter()
494                    .chain(self.axes.iter())
495                    .copied()
496                    .sorted()
497                    .dedup()
498                    .collect(),
499            },
500            &[wire],
501        )?;
502        patch.shunt_outside(model, node.id.into(), wire[0])?;
503        Ok(Some(patch))
504    }
505
506    fn declutter_scalar_mul_then_sum(
507        &self,
508        model: &TypedModel,
509        node: &TypedNode,
510    ) -> TractResult<Option<TypedModelPatch>> {
511        if self.reducer == Reducer::Sum {
512            let Some(prec) = model.single_prec(node.id)? else {
513                return Ok(None);
514            };
515            let Some(prec_bin) = prec.op_as::<TypedBinOp>() else {
516                return Ok(None);
517            };
518            if !prec_bin.0.is::<Mul>() {
519                return Ok(None);
520            }
521            let mul_input_fact = model.node_input_facts(prec.id)?;
522            let Some(scalar_slot) = mul_input_fact
523                .iter()
524                .position(|f| f.konst.as_ref().is_some_and(|k| k.volume() == 1))
525            else {
526                return Ok(None);
527            };
528            let mut patch = TypedModelPatch::default();
529            let scalar = patch.tap_model(model, prec.inputs[scalar_slot])?;
530            let wire = patch.tap_model(model, prec.inputs[1 - scalar_slot])?;
531            let wire = patch.wire_node(&node.name, self.clone(), &[wire])?[0];
532            let wire = patch.wire_node(&prec.name, prec_bin.clone(), &[wire, scalar])?[0];
533            patch.shunt_outside(model, node.id.into(), wire)?;
534            return Ok(Some(patch));
535        }
536        Ok(None)
537    }
538
539    fn declutter_mean_of_square(
540        &self,
541        model: &TypedModel,
542        node: &TypedNode,
543    ) -> TractResult<Option<TypedModelPatch>> {
544        if self.reducer == Reducer::Sum {
545            let Some(prec) = model.single_prec(node.id)? else {
546                return Ok(None);
547            };
548            let Some(prec_ew) = prec.op_as::<ElementWiseOp>() else {
549                return Ok(None);
550            };
551            if !prec_ew.0.is::<Square>() {
552                return Ok(None);
553            }
554            if node.outputs.len() != 1 || node.outputs[0].successors.len() != 1 {
555                return Ok(None);
556            }
557            let our_inlet = node.outputs[0].successors[0];
558            let succ = model.node(our_inlet.node);
559            let Some(succ_bin) = succ.op_as::<TypedBinOp>() else {
560                return Ok(None);
561            };
562            if !succ_bin.0.is::<Mul>() {
563                return Ok(None);
564            }
565            let other = succ.inputs[1 - our_inlet.slot];
566            let Some(other_konst) = model.outlet_fact(other)?.uniform.as_ref() else {
567                return Ok(None);
568            };
569            let norm: TDim = self.axes.iter().map(|&ax| &prec.outputs[0].fact.shape[ax]).product();
570            let Some(norm) = norm.as_i64() else {
571                return Ok(None);
572            };
573            if norm == 0 {
574                return Ok(None);
575            }
576            let norm = tensor0((norm as f32).recip());
577            if other_konst.close_enough(&norm, Approximation::Close).is_ok() {
578                let mut patch = TypedModelPatch::default();
579                let wire = patch.tap_model(model, prec.inputs[0])?;
580                let wire = patch.wire_node(
581                    &node.name,
582                    Reduce::new(self.axes.clone(), Reducer::MeanOfSquares),
583                    &[wire],
584                )?[0];
585                patch.shunt_outside(model, succ.id.into(), wire)?;
586                return Ok(Some(patch));
587            }
588        }
589        Ok(None)
590    }
591}
592
593pub fn expand_mean_of_squares(
594    _ctx: &(),
595    model: &TypedModel,
596    node: &TypedNode,
597    name: &str,
598    op: &Reduce,
599) -> TractResult<Option<TypedModelPatch>> {
600    if op.reducer == Reducer::MeanOfSquares {
601        let mut patch = TypedModelPatch::default();
602        let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
603        let input_fact = model.outlet_fact(node.inputs[0])?;
604        let dt = input_fact.datum_type;
605        if dt != f32::datum_type() {
606            wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
607        }
608        wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
609        wire = patch.wire_node(
610            format!("{name}.sum"),
611            Reduce::new(op.axes.clone(), Reducer::Sum),
612            &wire,
613        )?;
614        let card = input_fact
615            .shape
616            .iter()
617            .enumerate()
618            .filter(|(ix, _dim)| op.axes.contains(ix))
619            .map(|(_ix, dim)| dim)
620            .product::<TDim>();
621        let card = patch.add_const(format!("{name}.card"), tensor0(card))?;
622        let card =
623            patch.wire_node(format!("{name}.card_to_f32"), cast(f32::datum_type()), &[card])?;
624
625        wire = wire_with_rank_broadcast(
626            format!("{name}.norm"),
627            &mut patch,
628            div(),
629            &[wire[0], card[0]],
630        )?;
631        if dt != f32::datum_type() {
632            wire = patch.wire_node(format!("{name}.from_f32"), cast(dt), &wire)?;
633        }
634        patch.shunt_outside(model, node.id.into(), wire[0])?;
635        Ok(Some(patch))
636    } else {
637        Ok(None)
638    }
639}