1use crate::internal::Axis;
2use crate::internal::*;
3use crate::ops::binary::TypedBinOp;
4use crate::ops::cast::cast;
5use crate::ops::change_axes::wire_with_rank_broadcast;
6use crate::ops::element_wise::ElementWiseOp;
7use crate::ops::math::{div, square, Mul, Square};
8use std::convert::TryFrom;
9use std::iter::Sum;
10use std::mem::transmute;
11use tract_data::internal::ClampCast;
12use tract_data::itertools::Itertools;
13use tract_ndarray::prelude::*;
14use tract_num_traits::{AsPrimitive, Bounded};
15
/// Dispatch a generic function on a runtime `DatumType`.
///
/// First form: `r!(path(dt)(args...))` expands to `path::<T, _, _, _>(args...)`
/// with `T` chosen from `dt`; quantized QI8/QU8 map to their i8/u8 storage type.
/// Second form: `r!(path(dt)(args...); q_path(q_args...))` additionally routes
/// the QI8/QU8 arms to a dedicated quantization-aware function.
/// Any non-numeric datum type bails with an error.
macro_rules! r {
    ($($path:ident)::* ($dt:expr) ($($args:expr),*)) => {
        match $dt {
            DatumType::U8 => $($path)::*::<u8,_,_,_>($($args),*),
            DatumType::I8 => $($path)::*::<i8,_,_,_>($($args),*),
            DatumType::U16 => $($path)::*::<u16,_,_,_>($($args),*),
            DatumType::I16 => $($path)::*::<i16,_,_,_>($($args),*),
            DatumType::I32 => $($path)::*::<i32,_,_,_>($($args),*),
            DatumType::I64 => $($path)::*::<i64,_,_,_>($($args),*),
            DatumType::F16 => $($path)::*::<f16,_,_,_>($($args),*),
            DatumType::F32 => $($path)::*::<f32,_,_,_>($($args),*),
            DatumType::F64 => $($path)::*::<f64,_,_,_>($($args),*),
            // Quantized types reuse the kernels for their storage type.
            DatumType::QI8(_) => $($path)::*::<i8,_,_,_>($($args),*),
            DatumType::QU8(_) => $($path)::*::<u8,_,_,_>($($args),*),
            _ => bail!("{:?} is not a number", $dt)
        }
    };
    ($($path:ident)::* ($dt:expr) ($($args:expr),*); $($q_path:ident)::* ($($q_args:expr),*)) => {
        match $dt {
            DatumType::U8 => $($path)::*::<u8,_,_,_>($($args),*),
            DatumType::I8 => $($path)::*::<i8,_,_,_>($($args),*),
            DatumType::U16 => $($path)::*::<u16,_,_,_>($($args),*),
            DatumType::I16 => $($path)::*::<i16,_,_,_>($($args),*),
            DatumType::I32 => $($path)::*::<i32,_,_,_>($($args),*),
            DatumType::I64 => $($path)::*::<i64,_,_,_>($($args),*),
            DatumType::F16 => $($path)::*::<f16,_,_,_>($($args),*),
            DatumType::F32 => $($path)::*::<f32,_,_,_>($($args),*),
            DatumType::F64 => $($path)::*::<f64,_,_,_>($($args),*),
            // Quantized types take the dedicated zp/scale-aware path.
            DatumType::QI8(_) => $($q_path)::*::<i8,_,_,_>($($q_args),*),
            DatumType::QU8(_) => $($q_path)::*::<u8,_,_,_>($($q_args),*),
            _ => bail!("{:?} is not a number", $dt)
        }
    }
}
50
/// The supported reduction kinds.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Reducer {
    // The bool is the tie-breaking rule: true keeps the last matching index,
    // false the first (see argmax_t / argmin_t).
    ArgMax(bool), ArgMin(bool),
    Max,
    Min,
    Prod,
    Sum,
    MeanOfSquares,
}
61
impl Reducer {
    /// Apply the reduction over `axes` of `input`. Reduced axes are kept as
    /// size-1 dimensions, so output rank equals input rank.
    pub fn reduce(&self, axes: &[usize], input: &Tensor) -> TractResult<Tensor> {
        use Reducer::*;
        let dt = input.datum_type();
        // Output keeps the input's rank; each reduced axis collapses to 1.
        let output_shape: Vec<usize> = input
            .shape()
            .iter()
            .enumerate()
            .map(|(ax, &d)| if axes.contains(&ax) { 1 } else { d })
            .collect();
        let (zp, scale) = input.datum_type().zp_scale();
        unsafe {
            let mut t = match self {
                ArgMax(last) => {
                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, argmax_t, *last))
                }
                ArgMin(last) => {
                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, argmin_t, *last))
                }
                Min => r!(Self::reduce_t(dt)(self, axes, &output_shape, input, min_t, ())),
                Max => r!(Self::reduce_t(dt)(self, axes, &output_shape, input, max_t, ())),
                Prod => {
                    // Quantized product needs zero-point/scale compensation.
                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, prod_t, ()); Self::reduce_t(self, axes, &output_shape, input, q_prod_t, (zp, scale)))
                }
                Sum => {
                    if dt.is_float() {
                        // Floats go through the dedicated, linalg-accelerated
                        // sum path below.
                        dispatch_floatlike!(Self::sum(dt)(self, axes, input))
                    } else {
                        r!(Self::reduce_t(dt)(
                            self,
                            axes,
                            &output_shape,
                            input,
                            q_sum_t,
                            (zp, scale)
                        ))
                    }
                }
                MeanOfSquares => self.mean_of_squares(axes, input)?,
            };
            // Generic kernels work on the unquantized storage type; when the
            // storage type survived the reduction, reinstate the input's
            // quantization parameters on the result.
            if input.datum_type().is_quantized()
                && input.datum_type().unquantized() == t.datum_type().unquantized()
            {
                t.set_datum_type(input.datum_type());
            }
            Ok(t)
        }
    }

    /// Generic reduction kernel: for every output coordinate, slice the input
    /// across the reduced axes and apply `f` to that slice.
    ///
    /// Safety: caller must guarantee `input_tensor` actually stores `T`
    /// values (the view is taken unchecked).
    unsafe fn reduce_t<T, TO, F, A>(
        &self,
        axes: &[usize],
        output_shape: &[usize],
        input_tensor: &Tensor,
        f: F,
        args: A,
    ) -> Tensor
    where
        F: for<'a> Fn(ArrayViewD<'a, T>, A) -> TO,
        T: Copy + Datum,
        TO: Copy + Datum,
        A: Copy,
    {
        use ndarray::*;
        let input = unsafe { input_tensor.to_array_view_unchecked::<T>() };
        let result = Array::from_shape_fn(output_shape, |coords| {
            // Full-range slice on reduced axes, point slice elsewhere.
            let slice_spec: Vec<SliceInfoElem> = coords
                .slice()
                .iter()
                .enumerate()
                .map(|(ax, &d)| if axes.contains(&ax) { (..).into() } else { d.into() })
                .collect();
            let slice_info = SliceInfo::<_, IxDyn, IxDyn>::try_from(slice_spec).unwrap();
            let slice = input.slice(&slice_info);
            f(slice, args)
        });
        result.into_tensor()
    }

    /// Float sum with two strategies: for reductions not limited to the last
    /// axis, merge contiguous reduced axes and use ndarray's `sum_axis`;
    /// otherwise reduce axis-by-axis, using the vectorized linalg kernels on
    /// contiguous f16/f32 slices and a strided scalar loop as fallback.
    ///
    /// Safety: caller must guarantee `input` stores `T` values.
    unsafe fn sum<T>(&self, axes: &[usize], input: &Tensor) -> Tensor
    where
        T: Copy + Datum + num_traits::Zero + Sum,
        f16: AsPrimitive<T>,
        f32: AsPrimitive<T>,
    {
        if axes.len() == 0 {
            return input.to_owned();
        }

        if axes.len() > 1 || axes[0] != input.rank() - 1 {
            // Fuse runs of adjacent reduced axes into single axes so fewer
            // sum_axis passes are needed.
            let mut operative_axes = vec![];
            let mut operative_shape: Vec<usize> = vec![];
            for (ix, dim) in input.shape().iter().enumerate() {
                if ix > 0 && axes.contains(&ix) && axes.contains(&(ix - 1)) {
                    // Extend the previous (also reduced) axis.
                    *operative_shape.last_mut().unwrap() *= *dim;
                } else if axes.contains(&ix) {
                    operative_axes.push(operative_shape.len());
                    operative_shape.push(*dim);
                } else {
                    operative_shape.push(*dim);
                }
            }
            // Reduce from the innermost operative axis outwards so earlier
            // axis indices stay valid.
            let mut output = unsafe {
                input
                    .to_array_view_unchecked::<T>()
                    .into_shape_with_order(operative_shape)
                    .unwrap()
                    .sum_axis(Axis(*operative_axes.iter().max().unwrap()))
            };

            for axis in operative_axes.iter().rev().skip(1) {
                output = output.sum_axis(Axis(*axis));
            }

            let mut output = output.into_tensor();

            // sum_axis dropped the reduced axes; reinstate them as size-1
            // dims to preserve rank (axes is sorted ascending).
            for &axis in axes {
                output.insert_axis(axis).unwrap();
            }

            output
        } else {
            // Axis-by-axis reduction; each pass reads either the original
            // input or the previous pass's output.
            let mut output: Option<ArrayD<T>> = None;
            for axis in axes.iter().copied() {
                let input_view = output
                    .as_ref()
                    .map(|o| o.view())
                    .unwrap_or_else(|| unsafe { input.to_array_view_unchecked::<T>() });

                let reduced_dim = input_view.shape()[axis];
                let input_stride = input_view.strides()[axis] as usize;
                let output_shape = input_view
                    .shape()
                    .iter()
                    .enumerate()
                    .map(|(idx, dim)| if idx != axis { *dim } else { 1 })
                    .collect_vec();

                output = Some(ArrayD::from_shape_fn(output_shape.clone(), |coords| {
                    // Collapse every non-reduced axis to the current output
                    // coordinate, leaving a 1-D lane along `axis`.
                    let mut view = input_view.view();
                    for ix in 0..output_shape.len() {
                        if ix != axis {
                            view.collapse_axis(Axis(ix), coords[ix]);
                        }
                    }

                    if let Some(slice) = view.as_slice() {
                        // Contiguous lane: use the vectorized kernels for
                        // f16/f32, a plain iterator sum otherwise.
                        if T::datum_type() == f16::datum_type() {
                            // SAFETY: T is f16 (checked just above).
                            let slice: &[f16] = unsafe { std::mem::transmute(slice) };
                            (tract_linalg::ops().sum_f16)()
                                .run_with_params(slice, ())
                                .unwrap()
                                .as_()
                        } else if T::datum_type() == f32::datum_type() {
                            // SAFETY: T is f32 (checked just above).
                            let slice: &[f32] = unsafe { std::mem::transmute(slice) };
                            (tract_linalg::ops().sum_f32)()
                                .run_with_params(slice, ())
                                .unwrap()
                                .as_()
                        } else {
                            slice.iter().cloned().sum::<T>()
                        }
                    } else {
                        // Non-contiguous lane: walk it by raw stride.
                        let first: *const T = &input_view[coords];
                        let mut sum = T::zero();
                        for i in 0..reduced_dim {
                            sum = sum + unsafe { *(first.add(i * input_stride)) };
                        }
                        sum
                    }
                }));
            }
            output.unwrap().into_tensor()
        }
    }

    /// MeanOfSquares computed in f32: square element-wise, sum over `axis`,
    /// normalize by the reduced element count, then cast back to the input's
    /// datum type.
    fn mean_of_squares(&self, axis: &[usize], input: &Tensor) -> TractResult<Tensor> {
        let dt = input.datum_type();
        let mut input = input.cast_to::<f32>()?.into_owned();
        input.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x = *x * *x);
        let mut output = unsafe { self.sum::<f32>(axis, &input) };
        // output.len() / input.len() == 1 / (number of reduced elements)
        let norm = output.len() as f32 / input.len() as f32;
        output.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x *= norm);
        Ok(output.cast_to_dt(dt)?.into_owned())
    }
}
255
256fn argmax_t<T>(v: ArrayViewD<T>, last: bool) -> i64
257where
258 T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
259{
260 v.iter()
261 .copied()
262 .enumerate()
263 .fold(
264 (0usize, T::min_value()),
265 |acc, v| {
266 if v.1 > acc.1 || (last && acc.1 == v.1) {
267 v
268 } else {
269 acc
270 }
271 },
272 )
273 .0 as i64
274}
275
276fn argmin_t<T>(v: ArrayViewD<T>, last: bool) -> i64
277where
278 T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
279{
280 v.iter()
281 .copied()
282 .enumerate()
283 .fold(
284 (0usize, T::max_value()),
285 |acc, v| {
286 if v.1 < acc.1 || (last && acc.1 == v.1) {
287 v
288 } else {
289 acc
290 }
291 },
292 )
293 .0 as i64
294}
295
296fn max_t<T>(v: ArrayViewD<T>, _: ()) -> T
297where
298 T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
299{
300 if T::datum_type() == f32::datum_type() {
301 if let Some(slice) = v.as_slice() {
302 let slice = unsafe { transmute::<&[T], &[f32]>(slice) };
303 (tract_linalg::ops().max_f32)().run(slice).unwrap();
304 }
305 }
306 v.fold(T::min_value(), |acc, &v| if acc > v { acc } else { v })
307}
308
309fn min_t<T>(v: ArrayViewD<T>, _: ()) -> T
310where
311 T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
312{
313 v.fold(T::max_value(), |acc, &v| if acc < v { acc } else { v })
314}
315
316fn prod_t<T>(v: ArrayViewD<T>, _: ()) -> T
317where
318 T: Copy + Datum + num_traits::One,
319{
320 v.fold(T::one(), |acc, &v| acc * v)
321}
322
323fn q_prod_t<T>(v: ArrayViewD<T>, zp_scale: (i32, f32)) -> T
324where
325 T: Copy + num_traits::AsPrimitive<f32> + Bounded + Datum,
326 f32: num_traits::AsPrimitive<T>,
327{
328 let (zp, scale) = zp_scale;
329 (v.fold(1f32, |acc, &v| acc * (v.as_() - zp as f32)) * scale.powi(v.len() as i32 - 1)
330 + zp as f32)
331 .clamp_cast()
332}
333
334fn q_sum_t<T>(v: ArrayViewD<T>, zp_scale: (i32, f32)) -> T
335where
336 T: Copy + Bounded + num_traits::AsPrimitive<i32> + Datum,
337 i32: num_traits::AsPrimitive<T>,
338{
339 let (zp, _) = zp_scale;
340 (v.fold(0i32, |acc, &v| acc + v.as_()) - zp * (v.len() as i32 - 1)).clamp_cast()
341}
342
/// Typed reduction operator: applies `reducer` over `axes`, keeping the
/// reduced axes as size-1 dimensions.
#[derive(Clone, Debug, new, Hash)]
pub struct Reduce {
    // Axes to reduce over; must be strictly increasing (enforced in
    // output_facts).
    pub axes: TVec<usize>,
    // Which reduction to apply (Sum, Max, ArgMin, ...).
    pub reducer: Reducer,
}
348
349impl Op for Reduce {
350 fn name(&self) -> StaticName {
351 format!("Reduce<{:?}>", self.reducer).into()
352 }
353 fn info(&self) -> TractResult<Vec<String>> {
354 Ok(vec![format!("axes: {:?}", self.axes)])
355 }
356 op_as_typed_op!();
357}
358
359impl EvalOp for Reduce {
360 fn is_stateless(&self) -> bool {
361 true
362 }
363
364 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
365 Ok(tvec!(self.reducer.reduce(&self.axes, &inputs[0])?.into()))
366 }
367}
368
impl TypedOp for Reduce {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        // Invariant: axes must be strictly increasing (sorted, deduped).
        ensure!(self.axes.iter().tuple_windows().all(|(a, b)| a < b));
        if inputs[0].datum_type == TDim::datum_type() {
            bail!("Reduce input must be cast from TDim to i64 beforehand")
        }
        // Reduced axes are kept with dimension 1 (rank is preserved).
        let mut shape: TVec<_> = inputs[0].shape.to_tvec();
        for &ax in &self.axes {
            shape[ax] = 1.to_dim();
        }
        // ArgMax/ArgMin produce indices (i64); the others keep the input type.
        let dt = if let Reducer::ArgMax(_) | Reducer::ArgMin(_) = self.reducer {
            DatumType::I64
        } else {
            inputs[0].datum_type
        };
        Ok(tvec!(dt.fact(shape)))
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Try each rewrite in turn; first applicable one wins.
        if let Some(patch) = self.declutter_mean_of_square(model, node)? {
            return Ok(Some(patch));
        }
        if let Some(patch) = self.declutter_scalar_mul_then_sum(model, node)? {
            return Ok(Some(patch));
        }
        if let Some(patch) = self.declutter_reduce_reduce(model, node)? {
            return Ok(Some(patch));
        }
        Ok(None)
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let mut letters = 'a'..;
        let axes = (0..inputs[0].rank())
            .flat_map(|ix| {
                if self.axes.contains(&ix) {
                    // Reduced axis: input and output sides get distinct
                    // letters, so they are not tracked as the same axis.
                    tvec!(
                        Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
                            .input(0, ix),
                        Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
                            .output(0, ix),
                    )
                } else {
                    // Untouched axis: one shared letter on both sides.
                    tvec!(Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
                        .input(0, ix)
                        .output(0, ix))
                }
                .into_iter()
            })
            .collect_vec();
        AxesMapping::new(1, 1, axes)
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        // Map every reduced axis through the proposed change; give up if any
        // of them has no image under the change.
        let mut axes = tvec!();
        for reduced in &self.axes {
            if let Some(axis) = change.transform_axis(*reduced) {
                axes.push(axis);
            } else {
                return Ok(None);
            }
        }
        // Keep the sorted-axes invariant required by output_facts.
        axes.sort();
        let op = Some(Box::new(Self { axes, ..self.clone() }) as _);
        Ok(Some(AxisChangeConsequence::new(model, node, op, change)))
    }

    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        _model: &TypedModel,
        node: &TypedNode,
        _prefix: &str,
        inputs: &[OutletId],
        output_axis: usize,
        _start: &TDim,
        _end: &TDim,
    ) -> TractResult<Option<TVec<OutletId>>> {
        // Slicing along a reduced axis would change the reduction result, so
        // refuse; other axes commute with the reduction, reuse the op as-is.
        if self.axes.contains(&output_axis) {
            return Ok(None);
        }
        patch.wire_node(&node.name, &node.op, inputs).map(Some)
    }

    as_op!();
}
469
impl Reduce {
    /// Rewrite Reduce(Reduce(x)) into a single Reduce over the union of both
    /// axis sets, when both nodes use the same reducer and that reducer
    /// composes this way (Sum, Prod, Min, Max — not Arg* or MeanOfSquares).
    fn declutter_reduce_reduce(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let Some(prec) = model.linear_prec(node.id)? else {
            return Ok(None);
        };
        let Some(prec_reduce) = prec.op_as::<Self>() else {
            return Ok(None);
        };
        use Reducer::*;
        if prec_reduce.reducer != self.reducer || ![Sum, Prod, Min, Max].contains(&self.reducer) {
            return Ok(None);
        }
        let mut patch = TypedModelPatch::default();
        let wire = patch.tap_model(model, prec.inputs[0])?;
        let wire = patch.wire_node(
            &node.name,
            Self {
                reducer: self.reducer,
                // Merge both axis sets, keeping the sorted/deduped invariant.
                axes: prec_reduce
                    .axes
                    .iter()
                    .chain(self.axes.iter())
                    .copied()
                    .sorted()
                    .dedup()
                    .collect(),
            },
            &[wire],
        )?;
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

    /// Rewrite Sum(x * scalar_const) into Sum(x) * scalar_const, pulling the
    /// scalar multiplication out of the reduction.
    fn declutter_scalar_mul_then_sum(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.reducer == Reducer::Sum {
            let Some(prec) = model.linear_prec(node.id)? else {
                return Ok(None);
            };
            let Some(prec_bin) = prec.op_as::<TypedBinOp>() else {
                return Ok(None);
            };
            if !prec_bin.0.is::<Mul>() {
                return Ok(None);
            }
            // Find which Mul input is a single-element constant.
            let mul_input_fact = model.node_input_facts(prec.id)?;
            let Some(scalar_slot) = mul_input_fact
                .iter()
                .position(|f| f.konst.as_ref().is_some_and(|k| k.volume() == 1))
            else {
                return Ok(None);
            };
            let mut patch = TypedModelPatch::default();
            let scalar = patch.tap_model(model, prec.inputs[scalar_slot])?;
            let wire = patch.tap_model(model, prec.inputs[1 - scalar_slot])?;
            // Reduce the non-constant input first, then re-apply the Mul.
            let wire = patch.wire_node(&node.name, self.clone(), &[wire])?[0];
            let wire = patch.wire_node(&prec.name, prec_bin.clone(), &[wire, scalar])?[0];
            patch.shunt_outside(model, node.id.into(), wire)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

    /// Recognize the pattern Mul(Sum(Square(x)), 1/n) — with n the number of
    /// reduced elements — and fuse it into a single MeanOfSquares reduction.
    fn declutter_mean_of_square(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.reducer == Reducer::Sum {
            let Some(prec) = model.linear_prec(node.id)? else {
                return Ok(None);
            };
            let Some(prec_ew) = prec.op_as::<ElementWiseOp>() else {
                return Ok(None);
            };
            if !prec_ew.0.is::<Square>() {
                return Ok(None);
            }
            // The Sum must feed exactly one consumer (the normalizing Mul).
            if node.outputs.len() != 1 || node.outputs[0].successors.len() != 1 {
                return Ok(None);
            }
            let our_inlet = node.outputs[0].successors[0];
            let succ = model.node(our_inlet.node);
            let Some(succ_bin) = succ.op_as::<TypedBinOp>() else {
                return Ok(None);
            };
            if !succ_bin.0.is::<Mul>() {
                return Ok(None);
            }
            // The Mul's other input must be the constant 1/n.
            let other = succ.inputs[1 - our_inlet.slot];
            let Some(other_konst) = model.outlet_fact(other)?.uniform.as_ref() else {
                return Ok(None);
            };
            // n = product of the reduced dimensions; must be a known,
            // non-zero integer to compare against the constant.
            let norm: TDim = self.axes.iter().map(|&ax| &prec.outputs[0].fact.shape[ax]).product();
            let Some(norm) = norm.as_i64() else {
                return Ok(None);
            };
            if norm == 0 {
                return Ok(None);
            }
            let norm = tensor0((norm as f32).recip());
            if other_konst.close_enough(&norm, Approximation::Close).is_ok() {
                let mut patch = TypedModelPatch::default();
                let wire = patch.tap_model(model, prec.inputs[0])?;
                let wire = patch.wire_node(
                    &node.name,
                    Reduce::new(self.axes.clone(), Reducer::MeanOfSquares),
                    &[wire],
                )?[0];
                // Shunt the Mul's output: Square, Sum and Mul are replaced.
                patch.shunt_outside(model, succ.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
}
593
/// Rewrite pass: expand a MeanOfSquares reduction into primitive ops —
/// cast-to-f32, Square, Sum over the same axes, division by the reduced
/// element count, and a cast back to the original datum type.
pub fn expand_mean_of_squares(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    name: &str,
    op: &Reduce,
) -> TractResult<Option<TypedModelPatch>> {
    if op.reducer == Reducer::MeanOfSquares {
        let mut patch = TypedModelPatch::default();
        let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
        let input_fact = model.outlet_fact(node.inputs[0])?;
        let dt = input_fact.datum_type;
        // Compute in f32 regardless of the input type.
        if dt != f32::datum_type() {
            wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
        }
        wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
        wire = patch.wire_node(
            format!("{name}.sum"),
            Reduce::new(op.axes.clone(), Reducer::Sum),
            &wire,
        )?;
        // Cardinality: product of the reduced dimensions (may be symbolic).
        let card = input_fact
            .shape
            .iter()
            .enumerate()
            .filter(|(ix, _dim)| op.axes.contains(ix))
            .map(|(_ix, dim)| dim)
            .product::<TDim>();
        let card = patch.add_const(format!("{name}.card"), tensor0(card))?;
        let card =
            patch.wire_node(format!("{name}.card_to_f32"), cast(f32::datum_type()), &[card])?;

        // Normalize: sum / card, broadcasting the scalar over the sum's rank.
        wire = wire_with_rank_broadcast(
            format!("{name}.norm"),
            &mut patch,
            div(),
            &[wire[0], card[0]],
        )?;
        // Cast back to the original type if needed.
        if dt != f32::datum_type() {
            wire = patch.wire_node(format!("{name}.from_f32"), cast(dt), &wire)?;
        }
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    } else {
        Ok(None)
    }
}