tract_hir/ops/nn/reduce.rs

1use crate::internal::*;
2
3use tract_core::ops::nn::Reduce as TReduce;
4use tract_core::ops::nn::Reducer as TReducer;
5
/// Reduction applied by the inference-level [`Reduce`] operator.
///
/// `ArgMax`, `ArgMin`, `Max`, `Min`, `Sum` and `Prod` map directly onto a `tract_core` reducer.
/// The remaining variants are expanded by [`Reducer::wire`] into a short chain of element-wise
/// operations around a core `Sum` reduction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Reducer {
    /// Index of the maximum. The flag is forwarded to the core reducer ("take last" —
    /// presumably selects the last index among equal maxima).
    ArgMax(bool),
    /// Index of the minimum. The flag is forwarded to the core reducer, as for `ArgMax`.
    ArgMin(bool),
    /// `sum(|x|)`.
    L1,
    /// `sqrt(sum(x²))`.
    L2,
    /// `ln(sum(x))`.
    LogSum,
    /// `ln(sum(exp(x)))`.
    LogSumExp,
    /// Maximum value.
    Max,
    /// `sum(x)` divided by the number of reduced elements.
    Mean,
    /// Minimum value.
    Min,
    /// Product of the values.
    Prod,
    /// Sum of the values.
    Sum,
    /// `sum(x²)`.
    SumSquare,
}
21
22impl Reducer {
23    pub fn wire(
24        &self,
25        axes: TVec<usize>,
26        name: &str,
27        target: &mut TypedModel,
28        mut wire: OutletId,
29    ) -> TractResult<OutletId> {
30        use tract_core::ops::math;
31        use Reducer::*;
32        match self {
33            ArgMax(last) => {
34                wire =
35                    target.wire_node(name, TReduce::new(axes, TReducer::ArgMax(*last)), &[wire])?[0]
36            }
37            ArgMin(last) => {
38                wire =
39                    target.wire_node(name, TReduce::new(axes, TReducer::ArgMin(*last)), &[wire])?[0]
40            }
41            Max => wire = target.wire_node(name, TReduce::new(axes, TReducer::Max), &[wire])?[0],
42            Min => wire = target.wire_node(name, TReduce::new(axes, TReducer::Min), &[wire])?[0],
43            Sum => wire = target.wire_node(name, TReduce::new(axes, TReducer::Sum), &[wire])?[0],
44            Prod => wire = target.wire_node(name, TReduce::new(axes, TReducer::Prod), &[wire])?[0],
45
46            L1 => {
47                wire = target.wire_node(format!("{name}.abs"), math::abs(), &[wire])?[0];
48                wire = target.wire_node(
49                    format!("{name}.sum"),
50                    TReduce::new(axes, TReducer::Sum),
51                    &[wire],
52                )?[0];
53            }
54            L2 => {
55                wire = target.wire_node(format!("{name}.sq"), math::square(), &[wire])?[0];
56                wire = target.wire_node(
57                    format!("{name}.sum"),
58                    TReduce::new(axes, TReducer::Sum),
59                    &[wire],
60                )?[0];
61                wire = target.wire_node(format!("{name}.sqrt"), math::sqrt(), &[wire])?[0];
62            }
63            LogSum => {
64                wire = target.wire_node(
65                    format!("{name}.sum"),
66                    TReduce::new(axes, TReducer::Sum),
67                    &[wire],
68                )?[0];
69                wire = target.wire_node(format!("{name}.ln"), math::ln(), &[wire])?[0];
70            }
71            LogSumExp => {
72                wire = target.wire_node(format!("{name}.exp"), math::exp(), &[wire])?[0];
73                wire = target.wire_node(
74                    format!("{name}.sum"),
75                    TReduce::new(axes, TReducer::Sum),
76                    &[wire],
77                )?[0];
78                wire = target.wire_node(format!("{name}.ln"), math::ln(), &[wire])?[0];
79            }
80            SumSquare => {
81                wire = target.wire_node(format!("{name}.sq"), math::square(), &[wire])?[0];
82                wire = target.wire_node(
83                    name.to_string() + ".sum",
84                    TReduce::new(axes, TReducer::Sum),
85                    &[wire],
86                )?[0]
87            }
88            Mean => {
89                let fact = target.outlet_fact(wire)?.clone();
90                wire = target.wire_node(
91                    name.to_string() + ".sum",
92                    TReduce::new(axes.clone(), TReducer::Sum),
93                    &[wire],
94                )?[0];
95                let size: TDim = axes.iter().map(|ax| &fact.shape[*ax]).product();
96                let size = tensor0(size).broadcast_into_rank(fact.rank())?;
97                let size = target.add_const(name.to_string() + ".size", size)?;
98                let size = target.wire_node(
99                    name.to_string() + ".cast",
100                    tract_core::ops::cast::cast(fact.datum_type),
101                    &[size],
102                )?[0];
103                wire = target.wire_node(name.to_string() + ".norm", math::div(), &[wire, size])?[0];
104            }
105        };
106        Ok(wire)
107    }
108}
109
110#[derive(Clone, Debug, new, Hash)]
111pub struct Reduce {
112    pub axes: Option<Vec<i64>>,
113    pub keep_dims: bool,
114    pub reducer: Reducer,
115}
116
117
118
119impl Reduce {
120    pub fn must_reduce(&self, ax: usize, rank: usize) -> bool {
121        let resolved_axes: Option<Vec<usize>> = match &self.axes {
122            None => None,
123            Some(original_axes) => {
124                let mut ans: Vec<usize> = vec![];
125                for or_ax in original_axes.iter() {
126                    ans.push(Self::resolve_axis(*or_ax, rank).unwrap());
127                }
128                Some(ans)
129            }
130        };
131
132        resolved_axes.as_ref().map(|axes| axes.contains(&ax)).unwrap_or(true)
133    }
134
135    pub fn output_shape(&self, shape: &[TDim]) -> TVec<TDim> {
136        shape
137            .iter()
138            .enumerate()
139            .filter_map(|(ix, d)| {
140                if self.must_reduce(ix, shape.len()) {
141                    if self.keep_dims {
142                        Some(1.to_dim())
143                    } else {
144                        None
145                    }
146                } else {
147                    Some(d.clone())
148                }
149            })
150            .collect()
151    }
152
153    fn resolve_axis(axis: i64, rank: usize) -> TractResult<usize> {
154        if 0 <= axis && axis < rank as i64 {
155            Ok(axis as usize)
156        } else if -(rank as i64) <= axis && axis < 0 {
157            Ok((axis + rank as i64) as usize)
158        } else {
159            bail!("Illegal combination of values for rank and axis: {} and {}", rank, axis)
160        }
161    }
162
163    fn resolve_axes(&self, input_rank: usize) -> TractResult<TVec<usize>> {
164        let mut axes: TVec<usize> = match self.axes.as_ref() {
165            None => Ok((0..input_rank).collect()),
166            Some(axis) => axis.iter().map(|&a| Self::resolve_axis(a, input_rank)).collect(),
167        }?;
168        axes.sort();
169        Ok(axes)
170    }
171}
172
173impl Expansion for Reduce {
174    fn name(&self) -> StaticName {
175        format!("Reduce<{:?}>", self.reducer).into()
176    }
177    fn info(&self) -> TractResult<Vec<String>> {
178        Ok(vec![format!("axes: {:?} keep_dims: {}", self.axes, self.keep_dims)])
179    }
180
181    fn rules<'r, 'p: 'r, 's: 'r>(
182        &'s self,
183        s: &mut Solver<'r>,
184        inputs: &'p [TensorProxy],
185        outputs: &'p [TensorProxy],
186    ) -> InferenceResult {
187        check_input_arity(inputs, 1)?;
188        check_output_arity(outputs, 1)?;
189        if let Reducer::ArgMax(_) | Reducer::ArgMin(_) = self.reducer {
190            s.equals(&outputs[0].datum_type, DatumType::I64)?;
191        } else {
192            s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
193        }
194        if self.keep_dims {
195            s.equals(&inputs[0].rank, &outputs[0].rank)?;
196        } else if let Some(axes) = self.axes.as_ref() {
197            s.equals(inputs[0].rank.bex() - axes.len() as i64, &outputs[0].rank)?;
198        } else {
199            s.equals(&outputs[0].rank, 0)?;
200        }
201        s.given(&inputs[0].shape, move |s, shape| {
202            let out_shape = self.output_shape(&shape);
203            s.equals(&outputs[0].shape, out_shape)
204        })
205    }
206
207    fn wire(
208        &self,
209        name: &str,
210        target: &mut TypedModel,
211        inputs: &[OutletId],
212    ) -> TractResult<TVec<OutletId>> {
213        let mut wire = inputs[0];
214        let fact = target.outlet_fact(wire)?.clone();
215        let mut axes = self.resolve_axes(fact.rank())?;
216        axes.sort();
217        if fact.datum_type == TDim::datum_type() {
218            wire = target.wire_node(
219                format!("{name}.cast_from_tdim"),
220                tract_core::ops::cast::cast(i64::datum_type()),
221                &[wire],
222            )?[0];
223        }
224        wire = self.reducer.wire(axes.clone(), name, target, wire).context("wiring reducer")?;
225        if fact.datum_type == TDim::datum_type() {
226            wire = target.wire_node(
227                format!("{name}.cast_to_tdim"),
228                tract_core::ops::cast::cast(TDim::datum_type()),
229                &[wire],
230            )?[0];
231        }
232        if !self.keep_dims {
233            for axis in axes.into_iter().rev() {
234                wire = target.wire_node(
235                    format!("{name}-dispose-dims-{axis}"),
236                    AxisOp::Rm(axis),
237                    &[wire],
238                )?[0];
239            }
240        }
241        Ok(tvec!(wire))
242    }
243}