1use crate::internal::*;
2
3use tract_core::ops::nn::Reduce as TReduce;
4use tract_core::ops::nn::Reducer as TReducer;
5
/// Reduction kinds supported by [`Reduce`].
///
/// Only some of them map directly onto a tract-core reducer; the others are
/// expanded into a small subgraph of elementwise ops plus a sum by
/// [`Reducer::wire`].
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Reducer {
    /// Index of the maximum value. The flag is forwarded as-is to tract-core's
    /// `ArgMax` reducer (NOTE(review): presumably "take the last occurrence on
    /// ties" — confirm against tract-core).
    ArgMax(bool),
    /// Index of the minimum value; the flag is forwarded like for `ArgMax`.
    ArgMin(bool),
    /// Sum of absolute values.
    L1,
    /// Square root of the sum of squares.
    L2,
    /// Natural logarithm of the sum.
    LogSum,
    /// Natural logarithm of the sum of exponentials.
    LogSumExp,
    /// Maximum value.
    Max,
    /// Sum divided by the number of reduced elements.
    Mean,
    /// Minimum value.
    Min,
    /// Product of the values.
    Prod,
    /// Sum of the values.
    Sum,
    /// Sum of squares.
    SumSquare,
}
21
22impl Reducer {
23 pub fn wire(
24 &self,
25 axes: TVec<usize>,
26 name: &str,
27 target: &mut TypedModel,
28 mut wire: OutletId,
29 ) -> TractResult<OutletId> {
30 use tract_core::ops::math;
31 use Reducer::*;
32 match self {
33 ArgMax(last) => {
34 wire =
35 target.wire_node(name, TReduce::new(axes, TReducer::ArgMax(*last)), &[wire])?[0]
36 }
37 ArgMin(last) => {
38 wire =
39 target.wire_node(name, TReduce::new(axes, TReducer::ArgMin(*last)), &[wire])?[0]
40 }
41 Max => wire = target.wire_node(name, TReduce::new(axes, TReducer::Max), &[wire])?[0],
42 Min => wire = target.wire_node(name, TReduce::new(axes, TReducer::Min), &[wire])?[0],
43 Sum => wire = target.wire_node(name, TReduce::new(axes, TReducer::Sum), &[wire])?[0],
44 Prod => wire = target.wire_node(name, TReduce::new(axes, TReducer::Prod), &[wire])?[0],
45
46 L1 => {
47 wire = target.wire_node(format!("{name}.abs"), math::abs(), &[wire])?[0];
48 wire = target.wire_node(
49 format!("{name}.sum"),
50 TReduce::new(axes, TReducer::Sum),
51 &[wire],
52 )?[0];
53 }
54 L2 => {
55 wire = target.wire_node(format!("{name}.sq"), math::square(), &[wire])?[0];
56 wire = target.wire_node(
57 format!("{name}.sum"),
58 TReduce::new(axes, TReducer::Sum),
59 &[wire],
60 )?[0];
61 wire = target.wire_node(format!("{name}.sqrt"), math::sqrt(), &[wire])?[0];
62 }
63 LogSum => {
64 wire = target.wire_node(
65 format!("{name}.sum"),
66 TReduce::new(axes, TReducer::Sum),
67 &[wire],
68 )?[0];
69 wire = target.wire_node(format!("{name}.ln"), math::ln(), &[wire])?[0];
70 }
71 LogSumExp => {
72 wire = target.wire_node(format!("{name}.exp"), math::exp(), &[wire])?[0];
73 wire = target.wire_node(
74 format!("{name}.sum"),
75 TReduce::new(axes, TReducer::Sum),
76 &[wire],
77 )?[0];
78 wire = target.wire_node(format!("{name}.ln"), math::ln(), &[wire])?[0];
79 }
80 SumSquare => {
81 wire = target.wire_node(format!("{name}.sq"), math::square(), &[wire])?[0];
82 wire = target.wire_node(
83 name.to_string() + ".sum",
84 TReduce::new(axes, TReducer::Sum),
85 &[wire],
86 )?[0]
87 }
88 Mean => {
89 let fact = target.outlet_fact(wire)?.clone();
90 wire = target.wire_node(
91 name.to_string() + ".sum",
92 TReduce::new(axes.clone(), TReducer::Sum),
93 &[wire],
94 )?[0];
95 let size: TDim = axes.iter().map(|ax| &fact.shape[*ax]).product();
96 let size = tensor0(size).broadcast_into_rank(fact.rank())?;
97 let size = target.add_const(name.to_string() + ".size", size)?;
98 let size = target.wire_node(
99 name.to_string() + ".cast",
100 tract_core::ops::cast::cast(fact.datum_type),
101 &[size],
102 )?[0];
103 wire = target.wire_node(name.to_string() + ".norm", math::div(), &[wire, size])?[0];
104 }
105 };
106 Ok(wire)
107 }
108}
109
110#[derive(Clone, Debug, new, Hash)]
111pub struct Reduce {
112 pub axes: Option<Vec<i64>>,
113 pub keep_dims: bool,
114 pub reducer: Reducer,
115}
116
117
118
119impl Reduce {
120 pub fn must_reduce(&self, ax: usize, rank: usize) -> bool {
121 let resolved_axes: Option<Vec<usize>> = match &self.axes {
122 None => None,
123 Some(original_axes) => {
124 let mut ans: Vec<usize> = vec![];
125 for or_ax in original_axes.iter() {
126 ans.push(Self::resolve_axis(*or_ax, rank).unwrap());
127 }
128 Some(ans)
129 }
130 };
131
132 resolved_axes.as_ref().map(|axes| axes.contains(&ax)).unwrap_or(true)
133 }
134
135 pub fn output_shape(&self, shape: &[TDim]) -> TVec<TDim> {
136 shape
137 .iter()
138 .enumerate()
139 .filter_map(|(ix, d)| {
140 if self.must_reduce(ix, shape.len()) {
141 if self.keep_dims {
142 Some(1.to_dim())
143 } else {
144 None
145 }
146 } else {
147 Some(d.clone())
148 }
149 })
150 .collect()
151 }
152
153 fn resolve_axis(axis: i64, rank: usize) -> TractResult<usize> {
154 if 0 <= axis && axis < rank as i64 {
155 Ok(axis as usize)
156 } else if -(rank as i64) <= axis && axis < 0 {
157 Ok((axis + rank as i64) as usize)
158 } else {
159 bail!("Illegal combination of values for rank and axis: {} and {}", rank, axis)
160 }
161 }
162
163 fn resolve_axes(&self, input_rank: usize) -> TractResult<TVec<usize>> {
164 let mut axes: TVec<usize> = match self.axes.as_ref() {
165 None => Ok((0..input_rank).collect()),
166 Some(axis) => axis.iter().map(|&a| Self::resolve_axis(a, input_rank)).collect(),
167 }?;
168 axes.sort();
169 Ok(axes)
170 }
171}
172
173impl Expansion for Reduce {
174 fn name(&self) -> StaticName {
175 format!("Reduce<{:?}>", self.reducer).into()
176 }
177 fn info(&self) -> TractResult<Vec<String>> {
178 Ok(vec![format!("axes: {:?} keep_dims: {}", self.axes, self.keep_dims)])
179 }
180
181 fn rules<'r, 'p: 'r, 's: 'r>(
182 &'s self,
183 s: &mut Solver<'r>,
184 inputs: &'p [TensorProxy],
185 outputs: &'p [TensorProxy],
186 ) -> InferenceResult {
187 check_input_arity(inputs, 1)?;
188 check_output_arity(outputs, 1)?;
189 if let Reducer::ArgMax(_) | Reducer::ArgMin(_) = self.reducer {
190 s.equals(&outputs[0].datum_type, DatumType::I64)?;
191 } else {
192 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
193 }
194 if self.keep_dims {
195 s.equals(&inputs[0].rank, &outputs[0].rank)?;
196 } else if let Some(axes) = self.axes.as_ref() {
197 s.equals(inputs[0].rank.bex() - axes.len() as i64, &outputs[0].rank)?;
198 } else {
199 s.equals(&outputs[0].rank, 0)?;
200 }
201 s.given(&inputs[0].shape, move |s, shape| {
202 let out_shape = self.output_shape(&shape);
203 s.equals(&outputs[0].shape, out_shape)
204 })
205 }
206
207 fn wire(
208 &self,
209 name: &str,
210 target: &mut TypedModel,
211 inputs: &[OutletId],
212 ) -> TractResult<TVec<OutletId>> {
213 let mut wire = inputs[0];
214 let fact = target.outlet_fact(wire)?.clone();
215 let mut axes = self.resolve_axes(fact.rank())?;
216 axes.sort();
217 if fact.datum_type == TDim::datum_type() {
218 wire = target.wire_node(
219 format!("{name}.cast_from_tdim"),
220 tract_core::ops::cast::cast(i64::datum_type()),
221 &[wire],
222 )?[0];
223 }
224 wire = self.reducer.wire(axes.clone(), name, target, wire).context("wiring reducer")?;
225 if fact.datum_type == TDim::datum_type() {
226 wire = target.wire_node(
227 format!("{name}.cast_to_tdim"),
228 tract_core::ops::cast::cast(TDim::datum_type()),
229 &[wire],
230 )?[0];
231 }
232 if !self.keep_dims {
233 for axis in axes.into_iter().rev() {
234 wire = target.wire_node(
235 format!("{name}-dispose-dims-{axis}"),
236 AxisOp::Rm(axis),
237 &[wire],
238 )?[0];
239 }
240 }
241 Ok(tvec!(wire))
242 }
243}