// tract_tensorflow/ops/math/reduce.rs
use tract_hir::internal::*;
use tract_hir::ops::nn;

use crate::model::ParsingContext;
use crate::tfpb::tensorflow::NodeDef;

7#[derive(Debug, Clone, new, Hash)]
8pub struct Reduce {
9 t: DatumType,
10 t_idx: DatumType,
11 keep_dims: bool,
12 reducer: nn::Reducer,
13}
14
15
16
17pub fn max(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
18 reduce(pb, nn::Reducer::Max)
19}
20
21pub fn mean(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
22 reduce(pb, nn::Reducer::Mean)
23}
24
25pub fn min(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
26 reduce(pb, nn::Reducer::Min)
27}
28
29pub fn prod(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
30 reduce(pb, nn::Reducer::Prod)
31}
32
33pub fn sum(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
34 reduce(pb, nn::Reducer::Sum)
35}
36
37pub fn reduce(pb: &NodeDef, op: nn::Reducer) -> TractResult<Box<dyn InferenceOp>> {
38 let t = pb.get_attr_datum_type("T")?;
39 let t_idx = pb.get_attr_datum_type("Tidx")?;
40 let keep_dims = pb.get_attr_bool("keep_dims")?;
41 Ok(Box::new(Reduce::new(t, t_idx, keep_dims, op)))
42}
43
44impl Op for Reduce {
45 fn name(&self) -> StaticName {
46 format!("{:?}", self.reducer).into()
47 }
48
49 not_a_typed_op!();
50}
51
52impl EvalOp for Reduce {
53 fn is_stateless(&self) -> bool {
54 true
55 }
56
57 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
58 let (input, axes) = args_2!(inputs);
59 let axes: Vec<i64> = axes.cast_to::<i64>()?.as_slice::<i64>()?.to_vec();
60 let op = nn::Reduce::new(Some(axes), self.keep_dims, self.reducer);
61 expand(op).eval(tvec!(input))
62 }
63}
64
65impl InferenceRulesOp for Reduce {
66 fn rules<'r, 'p: 'r, 's: 'r>(
67 &'s self,
68 s: &mut Solver<'r>,
69 inputs: &'p [TensorProxy],
70 outputs: &'p [TensorProxy],
71 ) -> InferenceResult {
72 check_input_arity(inputs, 2)?;
73 check_output_arity(outputs, 1)?;
74 s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
75 if self.keep_dims {
76 s.equals(&inputs[0].rank, &outputs[0].rank)?;
77 } else {
78 s.given(&inputs[1].rank, move |s, rank| {
79 if rank == 1 {
80 s.equals(
81 inputs[0].rank.bex().to_dim(),
82 inputs[1].shape[0].bex() + outputs[0].rank.bex().to_dim(),
83 )
84 } else {
85 s.equals(
86 inputs[0].rank.bex().to_dim(),
87 outputs[0].rank.bex().to_dim() + 1.to_dim(),
88 )
89 }
90 })?;
91 }
92 s.given_3(
93 &inputs[0].rank,
94 &outputs[0].rank,
95 &inputs[1].value,
96 move |s, irank, orank, axes| {
97 let axes: TVec<usize> = axes
98 .cast_to::<i64>()?
99 .as_slice::<i64>()?
100 .iter()
101 .map(|&ax| if ax > 0 { ax } else { ax + irank } as usize)
102 .collect();
103 let mut od = 0;
104 for id in 0..(irank as usize) {
105 if axes.contains(&id) {
106 if self.keep_dims {
107 s.equals(&outputs[0].shape[od], 1.to_dim())?;
108 od += 1;
109 }
110 } else if od < orank as usize {
111 s.equals(&outputs[0].shape[od], &inputs[0].shape[id])?;
112 od += 1;
113 }
114 }
115 Ok(())
116 },
117 )?;
118 Ok(())
119 }
120
121 fn to_typed(
122 &self,
123 _source: &InferenceModel,
124 node: &InferenceNode,
125 target: &mut TypedModel,
126 mapping: &HashMap<OutletId, OutletId>,
127 ) -> TractResult<TVec<OutletId>> {
128 if let Some(ref axes) = target.outlet_fact(mapping[&node.inputs[1]])?.konst {
129 let axes: Vec<i64> = axes.cast_to::<i64>()?.as_slice::<i64>()?.to_vec();
130 let op = nn::Reduce::new(Some(axes), self.keep_dims, self.reducer);
131 op.wire(&node.name, target, &[mapping[&node.inputs[0]]])
132 } else {
133 bail!("Nees axes to be const")
134 }
135 }
136
137 as_op!();
138}