1use crate::internal::Axis;
2use crate::internal::*;
3use crate::ops::binary::TypedBinOp;
4use crate::ops::cast::cast;
5use crate::ops::change_axes::wire_with_rank_broadcast;
6use crate::ops::element_wise::ElementWiseOp;
7use crate::ops::math::{Mul, Square, div, square};
8use std::convert::TryFrom;
9use std::iter::Sum;
10use std::mem::transmute;
11use tract_data::internal::ClampCast;
12use tract_data::itertools::Itertools;
13use tract_ndarray::prelude::*;
14use tract_num_traits::{AsPrimitive, Bounded};
15
/// Dispatch a call over the numeric `DatumType`s, instantiating the generic
/// function with the matching Rust scalar type.
///
/// First form: `r!(path(dt)(args...))` expands to `path::<T, _, _, _>(args...)`
/// where `T` is picked from `dt`. Quantized types dispatch on their storage
/// type (QI8 -> i8, QU8 -> u8).
///
/// Second form: `r!(path(dt)(args...); q_path(q_args...))` additionally routes
/// the quantized types to an alternate function and argument list, so
/// zero-point/scale-aware kernels can be substituted for QI8/QU8.
///
/// Bails at runtime for any non-numeric datum type.
macro_rules! r {
    ($($path:ident)::* ($dt:expr) ($($args:expr),*)) => {
        match $dt {
            DatumType::U8 => $($path)::*::<u8,_,_,_>($($args),*),
            DatumType::I8 => $($path)::*::<i8,_,_,_>($($args),*),
            DatumType::U16 => $($path)::*::<u16,_,_,_>($($args),*),
            DatumType::I16 => $($path)::*::<i16,_,_,_>($($args),*),
            DatumType::I32 => $($path)::*::<i32,_,_,_>($($args),*),
            DatumType::I64 => $($path)::*::<i64,_,_,_>($($args),*),
            DatumType::F16 => $($path)::*::<f16,_,_,_>($($args),*),
            DatumType::F32 => $($path)::*::<f32,_,_,_>($($args),*),
            DatumType::F64 => $($path)::*::<f64,_,_,_>($($args),*),
            // Quantized types run on their raw storage type.
            DatumType::QI8(_) => $($path)::*::<i8,_,_,_>($($args),*),
            DatumType::QU8(_) => $($path)::*::<u8,_,_,_>($($args),*),
            _ => bail!("{:?} is not a number", $dt)
        }
    };
    ($($path:ident)::* ($dt:expr) ($($args:expr),*); $($q_path:ident)::* ($($q_args:expr),*)) => {
        match $dt {
            DatumType::U8 => $($path)::*::<u8,_,_,_>($($args),*),
            DatumType::I8 => $($path)::*::<i8,_,_,_>($($args),*),
            DatumType::U16 => $($path)::*::<u16,_,_,_>($($args),*),
            DatumType::I16 => $($path)::*::<i16,_,_,_>($($args),*),
            DatumType::I32 => $($path)::*::<i32,_,_,_>($($args),*),
            DatumType::I64 => $($path)::*::<i64,_,_,_>($($args),*),
            DatumType::F16 => $($path)::*::<f16,_,_,_>($($args),*),
            DatumType::F32 => $($path)::*::<f32,_,_,_>($($args),*),
            DatumType::F64 => $($path)::*::<f64,_,_,_>($($args),*),
            // Quantized types take the dedicated quantized kernel instead.
            DatumType::QI8(_) => $($q_path)::*::<i8,_,_,_>($($q_args),*),
            DatumType::QU8(_) => $($q_path)::*::<u8,_,_,_>($($q_args),*),
            _ => bail!("{:?} is not a number", $dt)
        }
    }
}
50
/// The aggregation a `Reduce` op performs over its selected axes.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Reducer {
    // For ArgMax/ArgMin the boolean selects the tie-breaking rule: `true`
    // keeps the last index reaching the extremum, `false` the first.
    ArgMax(bool), ArgMin(bool),
    Max,
    Min,
    Prod,
    Sum,
    // Mean of the squared elements, computed in f32 (see mean_of_squares).
    MeanOfSquares,
    // Boolean reductions (logical AND / OR over the axis).
    All,
    Any,
}
63
impl Reducer {
    /// Reduce `input` over `axes`, keeping each reduced axis as a size-1
    /// dimension in the result.
    ///
    /// Dispatches on the input datum type through the `r!` macro. Sum on
    /// floats takes the optimized `sum` path; quantized Prod/Sum use
    /// zero-point/scale-aware kernels.
    pub fn reduce(&self, axes: &[usize], input: &Tensor) -> TractResult<Tensor> {
        use Reducer::*;
        let dt = input.datum_type();
        // Reduced axes collapse to 1, all other dimensions are preserved.
        let output_shape: Vec<usize> = input
            .shape()
            .iter()
            .enumerate()
            .map(|(ax, &d)| if axes.contains(&ax) { 1 } else { d })
            .collect();
        let (zp, scale) = input.datum_type().zp_scale();
        unsafe {
            let mut t = match self {
                ArgMax(last) => {
                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, argmax_t, *last))
                }
                ArgMin(last) => {
                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, argmin_t, *last))
                }
                Min => r!(Self::reduce_t(dt)(self, axes, &output_shape, input, min_t, ())),
                Max => r!(Self::reduce_t(dt)(self, axes, &output_shape, input, max_t, ())),
                Prod => {
                    // Quantized product needs the (zero-point, scale) pair to
                    // stay consistent in the quantized domain.
                    r!(Self::reduce_t(dt)(self, axes, &output_shape, input, prod_t, ()); Self::reduce_t(self, axes, &output_shape, input, q_prod_t, (zp, scale)))
                }
                Sum => {
                    if dt.is_float() {
                        // Floats go through the optimized summation path
                        // (linalg kernels / axis merging).
                        dispatch_floatlike!(Self::sum(dt)(self, axes, input))
                    } else {
                        r!(Self::reduce_t(dt)(
                            self,
                            axes,
                            &output_shape,
                            input,
                            q_sum_t,
                            (zp, scale)
                        ))
                    }
                }
                MeanOfSquares => self.mean_of_squares(axes, input)?,
                All => Self::reduce_t(self, axes, &output_shape, input, all_bool, ()),
                Any => Self::reduce_t(self, axes, &output_shape, input, any_bool, ()),
            };
            // Quantized inputs are reduced on their raw storage type; restore
            // the full quantized datum type (zp/scale) whenever the storage
            // type survived the reduction (e.g. Min/Max/Sum, but not ArgMax
            // which yields i64).
            if input.datum_type().is_quantized()
                && input.datum_type().unquantized() == t.datum_type().unquantized()
            {
                t.set_datum_type(input.datum_type());
            }
            Ok(t)
        }
    }

    /// Generic reduction kernel: for every output coordinate, slice the input
    /// across the reduced axes and collapse that slice with `f`.
    ///
    /// # Safety
    /// The caller must guarantee `input_tensor` actually stores `T` values:
    /// the array view is taken unchecked.
    unsafe fn reduce_t<T, TO, F, A>(
        &self,
        axes: &[usize],
        output_shape: &[usize],
        input_tensor: &Tensor,
        f: F,
        args: A,
    ) -> Tensor
    where
        F: for<'a> Fn(ArrayViewD<'a, T>, A) -> TO,
        T: Copy + Datum,
        TO: Copy + Datum,
        A: Copy,
    {
        use ndarray::*;
        let input = unsafe { input_tensor.to_array_view_unchecked::<T>() };
        let result = Array::from_shape_fn(output_shape, |coords| {
            // Select everything (`..`) along reduced axes, pin the output
            // coordinate on the others.
            let slice_spec: Vec<SliceInfoElem> = coords
                .slice()
                .iter()
                .enumerate()
                .map(|(ax, &d)| if axes.contains(&ax) { (..).into() } else { d.into() })
                .collect();
            let slice_info = SliceInfo::<_, IxDyn, IxDyn>::try_from(slice_spec).unwrap();
            let slice = input.slice(&slice_info);
            f(slice, args)
        });
        result.into_tensor()
    }

    /// Optimized float summation over `axes`.
    ///
    /// Two strategies:
    /// * general case (several axes, or a non-innermost axis): merge runs of
    ///   consecutive reduced axes, then fold them with ndarray's `sum_axis`,
    ///   highest axis first so lower indices stay valid;
    /// * single innermost axis: build each output cell by summing a lane,
    ///   using vectorized f16/f32 linalg kernels when the lane is contiguous.
    ///
    /// # Safety
    /// The caller must guarantee `input` stores `T` values (views are taken
    /// unchecked).
    unsafe fn sum<T>(&self, axes: &[usize], input: &Tensor) -> Tensor
    where
        T: Copy + Datum + num_traits::Zero + Sum,
        f16: AsPrimitive<T>,
        f32: AsPrimitive<T>,
    {
        if axes.len() == 0 {
            return input.to_owned();
        }

        if axes.len() > 1 || axes[0] != input.rank() - 1 {
            // Coalesce consecutive reduced axes into a single operative axis
            // so each run costs one sum_axis call.
            let mut operative_axes = vec![];
            let mut operative_shape: Vec<usize> = vec![];
            for (ix, dim) in input.shape().iter().enumerate() {
                if ix > 0 && axes.contains(&ix) && axes.contains(&(ix - 1)) {
                    // Contiguous with the previous reduced axis: merge dims.
                    *operative_shape.last_mut().unwrap() *= *dim;
                } else if axes.contains(&ix) {
                    operative_axes.push(operative_shape.len());
                    operative_shape.push(*dim);
                } else {
                    operative_shape.push(*dim);
                }
            }
            // Reduce from the highest operative axis down so the remaining
            // axis indices are not shifted by earlier reductions.
            let mut output = unsafe {
                input
                    .to_array_view_unchecked::<T>()
                    .into_shape_with_order(operative_shape)
                    .unwrap()
                    .sum_axis(Axis(*operative_axes.iter().max().unwrap()))
            };

            for axis in operative_axes.iter().rev().skip(1) {
                output = output.sum_axis(Axis(*axis));
            }

            let mut output = output.into_tensor();

            // sum_axis dropped the reduced axes; re-insert them as size-1
            // dims (axes are in increasing order, see output_facts).
            for &axis in axes {
                output.insert_axis(axis).unwrap();
            }

            output
        } else {
            // Single reduced axis, and it is the innermost one.
            let mut output: Option<ArrayD<T>> = None;
            for axis in axes.iter().copied() {
                let input_view = output
                    .as_ref()
                    .map(|o| o.view())
                    .unwrap_or_else(|| unsafe { input.to_array_view_unchecked::<T>() });

                let reduced_dim = input_view.shape()[axis];
                let input_stride = input_view.strides()[axis] as usize;
                let output_shape = input_view
                    .shape()
                    .iter()
                    .enumerate()
                    .map(|(idx, dim)| if idx != axis { *dim } else { 1 })
                    .collect_vec();

                output = Some(ArrayD::from_shape_fn(output_shape.clone(), |coords| {
                    // Collapse every non-reduced axis to the current output
                    // coordinate, leaving a 1-D lane along `axis`.
                    let mut view = input_view.view();
                    for ix in 0..output_shape.len() {
                        if ix != axis {
                            view.collapse_axis(Axis(ix), coords[ix]);
                        }
                    }

                    if let Some(slice) = view.as_slice() {
                        // Contiguous lane: use the vectorized kernels for
                        // f16/f32, plain iterator sum otherwise.
                        if T::datum_type() == f16::datum_type() {
                            let slice: &[f16] = unsafe { std::mem::transmute(slice) };
                            (tract_linalg::ops().sum_f16)()
                                .run_with_params(slice, ())
                                .unwrap()
                                .as_()
                        } else if T::datum_type() == f32::datum_type() {
                            let slice: &[f32] = unsafe { std::mem::transmute(slice) };
                            (tract_linalg::ops().sum_f32)()
                                .run_with_params(slice, ())
                                .unwrap()
                                .as_()
                        } else {
                            slice.iter().cloned().sum::<T>()
                        }
                    } else {
                        // Non-contiguous lane: walk it with the raw stride.
                        let first: *const T = &input_view[coords];
                        let mut sum = T::zero();
                        for i in 0..reduced_dim {
                            sum = sum + unsafe { *(first.add(i * input_stride)) };
                        }
                        sum
                    }
                }));
            }
            output.unwrap().into_tensor()
        }
    }

    /// Mean of squared values over `axis`: computed in f32 regardless of the
    /// input type, then cast back to the input datum type.
    fn mean_of_squares(&self, axis: &[usize], input: &Tensor) -> TractResult<Tensor> {
        let dt = input.datum_type();
        let mut input = input.cast_to::<f32>()?.into_owned();
        // Square every element in place on the f32 copy.
        input.try_as_plain_mut()?.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x = *x * *x);
        let mut output = unsafe { self.sum::<f32>(axis, &input) };
        // output.len() / input.len() == 1 / (count of reduced elements):
        // scaling the sums by it turns them into means.
        let norm = output.len() as f32 / input.len() as f32;
        output.try_as_plain_mut()?.as_slice_mut::<f32>()?.iter_mut().for_each(|x| *x *= norm);
        Ok(output.cast_to_dt(dt)?.into_owned())
    }
}
259
260fn argmax_t<T>(v: ArrayViewD<T>, last: bool) -> i64
261where
262 T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
263{
264 v.iter()
265 .copied()
266 .enumerate()
267 .fold(
268 (0usize, T::min_value()),
269 |acc, v| {
270 if v.1 > acc.1 || (last && acc.1 == v.1) { v } else { acc }
271 },
272 )
273 .0 as i64
274}
275
276fn argmin_t<T>(v: ArrayViewD<T>, last: bool) -> i64
277where
278 T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
279{
280 v.iter()
281 .copied()
282 .enumerate()
283 .fold(
284 (0usize, T::max_value()),
285 |acc, v| {
286 if v.1 < acc.1 || (last && acc.1 == v.1) { v } else { acc }
287 },
288 )
289 .0 as i64
290}
291
292fn max_t<T>(v: ArrayViewD<T>, _: ()) -> T
293where
294 T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
295{
296 if T::datum_type() == f32::datum_type()
297 && let Some(slice) = v.as_slice()
298 {
299 let slice = unsafe { transmute::<&[T], &[f32]>(slice) };
300 (tract_linalg::ops().max_f32)().run(slice).unwrap();
301 }
302 v.fold(T::min_value(), |acc, &v| if acc > v { acc } else { v })
303}
304
305fn min_t<T>(v: ArrayViewD<T>, _: ()) -> T
306where
307 T: Copy + Datum + num_traits::Bounded + ::std::cmp::PartialOrd,
308{
309 v.fold(T::max_value(), |acc, &v| if acc < v { acc } else { v })
310}
311
312fn prod_t<T>(v: ArrayViewD<T>, _: ()) -> T
313where
314 T: Copy + Datum + num_traits::One,
315{
316 v.fold(T::one(), |acc, &v| acc * v)
317}
318
319fn q_prod_t<T>(v: ArrayViewD<T>, zp_scale: (i32, f32)) -> T
320where
321 T: Copy + num_traits::AsPrimitive<f32> + Bounded + Datum,
322 f32: num_traits::AsPrimitive<T>,
323{
324 let (zp, scale) = zp_scale;
325 (v.fold(1f32, |acc, &v| acc * (v.as_() - zp as f32)) * scale.powi(v.len() as i32 - 1)
326 + zp as f32)
327 .clamp_cast()
328}
329
330fn q_sum_t<T>(v: ArrayViewD<T>, zp_scale: (i32, f32)) -> T
331where
332 T: Copy + Bounded + num_traits::AsPrimitive<i32> + Datum,
333 i32: num_traits::AsPrimitive<T>,
334{
335 let (zp, _) = zp_scale;
336 (v.fold(0i32, |acc, &v| acc + v.as_()) - zp * (v.len() as i32 - 1)).clamp_cast()
337}
338
339fn all_bool(v: ArrayViewD<bool>, _: ()) -> bool {
340 v.iter().all(|v| *v)
341}
342
343fn any_bool(v: ArrayViewD<bool>, _: ()) -> bool {
344 v.iter().any(|v| *v)
345}
346
/// Typed reduction operator: collapses the listed `axes` of its single input
/// to size-1 dimensions using `reducer`.
#[derive(Clone, Debug, new, Hash, PartialEq, Eq)]
pub struct Reduce {
    // Axes to reduce; expected strictly increasing (checked in output_facts).
    pub axes: TVec<usize>,
    pub reducer: Reducer,
}
352
353impl Op for Reduce {
354 fn name(&self) -> StaticName {
355 format!("Reduce<{:?}>", self.reducer).into()
356 }
357 fn info(&self) -> TractResult<Vec<String>> {
358 Ok(vec![format!("axes: {:?}", self.axes)])
359 }
360 op_as_typed_op!();
361}
362
363impl EvalOp for Reduce {
364 fn is_stateless(&self) -> bool {
365 true
366 }
367
368 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
369 Ok(tvec!(self.reducer.reduce(&self.axes, &inputs[0])?.into()))
370 }
371}
372
impl TypedOp for Reduce {
    // Region-of-interest propagation is delegated to the generic bubbling
    // helper from the optimizer.
    fn input_roi(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TVec<Option<TDim>>>> {
        crate::optim::propagate_roi::bubble_roi(model, node)
    }

    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        // Axes must be strictly increasing (i.e. sorted and deduplicated).
        ensure!(self.axes.iter().tuple_windows().all(|(a, b)| a < b));
        if inputs[0].datum_type == TDim::datum_type() {
            bail!("Reduce input must be cast from TDim to i64 beforehand")
        }
        // Reduced axes are kept, collapsed to dimension 1.
        let mut shape: TVec<_> = inputs[0].shape.to_tvec();
        for &ax in &self.axes {
            shape[ax] = 1.to_dim();
        }
        // ArgMax/ArgMin output indices (i64); every other reducer preserves
        // the input datum type.
        let dt = if let Reducer::ArgMax(_) | Reducer::ArgMin(_) = self.reducer {
            DatumType::I64
        } else {
            inputs[0].datum_type
        };
        Ok(tvec!(dt.fact(shape)))
    }

    // Try the rewrite rules in order and return the first that applies.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some(patch) = self.declutter_mean_of_square(model, node)? {
            return Ok(Some(patch));
        }
        if let Some(patch) = self.declutter_scalar_mul_then_sum(model, node)? {
            return Ok(Some(patch));
        }
        if let Some(patch) = self.declutter_reduce_reduce(model, node)? {
            return Ok(Some(patch));
        }
        if let Some(patch) = super::rms_norm::detect_rms_norm(self, model, node)? {
            return Ok(Some(patch));
        }
        Ok(None)
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let dt = inputs[0].datum_type;
        // Cost model: one accumulation per input element; MeanOfSquares is
        // billed double (squaring plus summing).
        let count: TDim = inputs[0].shape.iter().product();
        match self.reducer {
            Reducer::Sum
            | Reducer::Prod
            | Reducer::Min
            | Reducer::Max
            | Reducer::All
            | Reducer::Any => Ok(tvec!((Cost::FMA(dt), count))),
            Reducer::MeanOfSquares => Ok(tvec!((Cost::FMA(dt), count * 2))),
            Reducer::ArgMax(_) | Reducer::ArgMin(_) => Ok(tvec!((Cost::FMA(dt), count))),
        }
    }

    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        let mut letters = 'a'..;
        // A reduced axis yields two unrelated letters (its input and output
        // sides are independent); any other axis is tracked by one letter
        // present on both sides.
        let axes = (0..inputs[0].rank())
            .flat_map(|ix| {
                if self.axes.contains(&ix) {
                    tvec!(
                        Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
                            .input(0, ix),
                        Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
                            .output(0, ix),
                    )
                } else {
                    tvec!(
                        Axis::new(letters.next().unwrap(), inputs.len(), outputs.len())
                            .input(0, ix)
                            .output(0, ix)
                    )
                }
                .into_iter()
            })
            .collect_vec();
        AxesMapping::new(1, 1, axes)
    }

    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        // Translate every reduced axis through the axis change; bail out of
        // the rule (return None) if any reduced axis does not survive it.
        let mut axes = tvec!();
        for reduced in &self.axes {
            rule_if_some!(axis = change.transform_axis(*reduced));
            axes.push(axis);
        }
        axes.sort();
        let op = Some(Box::new(Self { axes, ..self.clone() }) as _);
        Ok(Some(AxisChangeConsequence::new(model, node, op, change)))
    }

    // Slicing commutes with the reduction as long as the sliced axis is not
    // one of the reduced axes, so the op can simply be rewired on the sliced
    // input.
    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        _model: &TypedModel,
        node: &TypedNode,
        _prefix: &str,
        inputs: &[OutletId],
        output_axis: usize,
        _start: &TDim,
        _end: &TDim,
    ) -> TractResult<Option<TVec<OutletId>>> {
        rule_if!(!self.axes.contains(&output_axis));
        patch.wire_node(&node.name, &node.op, inputs).map(Some)
    }

    as_op!();
}
496
impl Reduce {
    /// Rewrite `Reduce(Reduce(x))` into a single `Reduce` over the union of
    /// both axis sets, when both nodes use the same reducer and that reducer
    /// composes this way (Sum, Prod, Min, Max).
    fn declutter_reduce_reduce(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        use Reducer::*;
        rule_if_some!(prec = model.linear_prec(node.id)?);
        rule_if_some!(prec_reduce = prec.op_as::<Self>());
        rule_if!(prec_reduce.reducer == self.reducer);
        rule_if!([Sum, Prod, Min, Max].contains(&self.reducer));
        let mut patch = TypedModelPatch::default();
        let wire = patch.tap_model(model, prec.inputs[0])?;
        let wire = patch.wire_node(
            &node.name,
            Self {
                reducer: self.reducer,
                // Merged axis list must stay sorted and deduplicated
                // (output_facts enforces strictly increasing axes).
                axes: prec_reduce
                    .axes
                    .iter()
                    .chain(self.axes.iter())
                    .copied()
                    .sorted()
                    .dedup()
                    .collect(),
            },
            &[wire],
        )?;
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }

    /// Rewrite `Sum(scalar * x)` into `scalar * Sum(x)`, pulling a scalar
    /// constant multiplication out of the reduction.
    fn declutter_scalar_mul_then_sum(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.reducer == Reducer::Sum {
            rule_if_some!(prec = model.linear_prec(node.id)?);
            rule_if_some!(prec_bin = prec.op_as::<TypedBinOp>());
            rule_if!(prec_bin.0.is::<Mul>());
            let mul_input_fact = model.node_input_facts(prec.id)?;
            // One of the Mul inputs must be a single-element constant.
            rule_if_some!(
                scalar_slot = mul_input_fact
                    .iter()
                    .position(|f| f.konst.as_ref().is_some_and(|k| k.volume() == 1))
            );
            let mut patch = TypedModelPatch::default();
            let scalar = patch.tap_model(model, prec.inputs[scalar_slot])?;
            let wire = patch.tap_model(model, prec.inputs[1 - scalar_slot])?;
            // Reduce first, then apply the scalar multiplication once.
            let wire = patch.wire_node(&node.name, self.clone(), &[wire])?[0];
            let wire = patch.wire_node(&prec.name, prec_bin.clone(), &[wire, scalar])?[0];
            patch.shunt_outside(model, node.id.into(), wire)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

    /// Recognize `Sum(Square(x)) * (1/n)` — where `n` is the product of the
    /// reduced dimensions — and replace the whole pattern with a single
    /// `Reduce<MeanOfSquares>`.
    fn declutter_mean_of_square(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.reducer == Reducer::Sum {
            rule_if_some!(prec = model.linear_prec(node.id)?);
            rule_if_some!(prec_ew = prec.op_as::<ElementWiseOp>());
            rule_if!(prec_ew.0.is::<Square>());
            // The Sum must feed exactly one consumer: the normalizing Mul.
            rule_if!(node.outputs.len() == 1);
            rule_if!(node.outputs[0].successors.len() == 1);
            let our_inlet = node.outputs[0].successors[0];
            let succ = model.node(our_inlet.node);
            rule_if_some!(succ_bin = succ.op_as::<TypedBinOp>());
            rule_if!(succ_bin.0.is::<Mul>());
            let other = succ.inputs[1 - our_inlet.slot];
            rule_if_some!(other_konst = model.outlet_fact(other)?.uniform.as_ref());
            // n = product of reduced dims; must be a known positive integer.
            let norm: TDim = self.axes.iter().map(|&ax| &prec.outputs[0].fact.shape[ax]).product();
            rule_if_some!(norm = norm.as_i64());
            rule_if!(norm > 0);
            let norm = tensor0((norm as f32).recip());
            // The Mul constant must be (approximately) 1/n for the pattern
            // to actually be a mean.
            if other_konst.close_enough(&norm, Approximation::Close).is_ok() {
                let mut patch = TypedModelPatch::default();
                let wire = patch.tap_model(model, prec.inputs[0])?;
                let wire = patch.wire_node(
                    &node.name,
                    Reduce::new(self.axes.clone(), Reducer::MeanOfSquares),
                    &[wire],
                )?[0];
                // Shunt at the Mul successor: Square, Sum and Mul are all
                // replaced by the fused op.
                patch.shunt_outside(model, succ.id.into(), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
}
591
/// Rewrite rule lowering `Reducer::MeanOfSquares` into primitive ops:
/// cast to f32 (if needed) -> square -> Reduce<Sum> -> divide by the reduced
/// element count -> cast back to the original datum type.
pub fn expand_mean_of_squares(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    name: &str,
    op: &Reduce,
) -> TractResult<Option<TypedModelPatch>> {
    rule_if!(op.reducer == Reducer::MeanOfSquares);
    let mut patch = TypedModelPatch::default();
    let mut wire = tvec!(patch.tap_model(model, node.inputs[0])?);
    let input_fact = model.outlet_fact(node.inputs[0])?;
    let dt = input_fact.datum_type;
    // Compute in f32, mirroring Reducer::mean_of_squares.
    if dt != f32::datum_type() {
        wire = patch.wire_node(format!("{name}.to_f32"), cast(f32::datum_type()), &wire)?;
    }
    wire = patch.wire_node(format!("{name}.sqr"), square(), &wire)?;
    wire = patch.wire_node(
        format!("{name}.sum"),
        Reduce::new(op.axes.clone(), Reducer::Sum),
        &wire,
    )?;
    // Cardinality of the reduction: product of the reduced dimensions
    // (possibly symbolic, hence a TDim constant cast to f32 at runtime).
    let card = input_fact
        .shape
        .iter()
        .enumerate()
        .filter(|(ix, _dim)| op.axes.contains(ix))
        .map(|(_ix, dim)| dim)
        .product::<TDim>();
    let card = patch.add_const(format!("{name}.card"), tensor0(card))?;
    let card = patch.wire_node(format!("{name}.card_to_f32"), cast(f32::datum_type()), &[card])?;

    // Rank-broadcast the scalar cardinality against the reduced tensor.
    wire =
        wire_with_rank_broadcast(format!("{name}.norm"), &mut patch, div(), &[wire[0], card[0]])?;
    if dt != f32::datum_type() {
        wire = patch.wire_node(format!("{name}.from_f32"), cast(dt), &wire)?;
    }
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(Some(patch))
}