Skip to main content

tract_tensorflow/ops/math/
reduce.rs

1use tract_hir::internal::*;
2use tract_hir::ops::nn;
3
4use crate::model::ParsingContext;
5use crate::tfpb::tensorflow::NodeDef;
6
7#[derive(Debug, Clone, new, Hash)]
8pub struct Reduce {
9    t: DatumType,
10    t_idx: DatumType,
11    keep_dims: bool,
12    reducer: nn::Reducer,
13}
14
15
16
17pub fn max(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
18    reduce(pb, nn::Reducer::Max)
19}
20
21pub fn mean(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
22    reduce(pb, nn::Reducer::Mean)
23}
24
25pub fn min(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
26    reduce(pb, nn::Reducer::Min)
27}
28
29pub fn prod(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
30    reduce(pb, nn::Reducer::Prod)
31}
32
33pub fn sum(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
34    reduce(pb, nn::Reducer::Sum)
35}
36
37pub fn reduce(pb: &NodeDef, op: nn::Reducer) -> TractResult<Box<dyn InferenceOp>> {
38    let t = pb.get_attr_datum_type("T")?;
39    let t_idx = pb.get_attr_datum_type("Tidx")?;
40    let keep_dims = pb.get_attr_bool("keep_dims")?;
41    Ok(Box::new(Reduce::new(t, t_idx, keep_dims, op)))
42}
43
44impl Op for Reduce {
45    fn name(&self) -> StaticName {
46        format!("{:?}", self.reducer).into()
47    }
48
49    not_a_typed_op!();
50}
51
52impl EvalOp for Reduce {
53    fn is_stateless(&self) -> bool {
54        true
55    }
56
57    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
58        let (input, axes) = args_2!(inputs);
59        let axes: Vec<i64> = axes.cast_to::<i64>()?.as_slice::<i64>()?.to_vec();
60        let op = nn::Reduce::new(Some(axes), self.keep_dims, self.reducer);
61        expand(op).eval(tvec!(input))
62    }
63}
64
65impl InferenceRulesOp for Reduce {
66    fn rules<'r, 'p: 'r, 's: 'r>(
67        &'s self,
68        s: &mut Solver<'r>,
69        inputs: &'p [TensorProxy],
70        outputs: &'p [TensorProxy],
71    ) -> InferenceResult {
72        check_input_arity(inputs, 2)?;
73        check_output_arity(outputs, 1)?;
74        s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
75        if self.keep_dims {
76            s.equals(&inputs[0].rank, &outputs[0].rank)?;
77        } else {
78            s.given(&inputs[1].rank, move |s, rank| {
79                if rank == 1 {
80                    s.equals(
81                        inputs[0].rank.bex().to_dim(),
82                        inputs[1].shape[0].bex() + outputs[0].rank.bex().to_dim(),
83                    )
84                } else {
85                    s.equals(
86                        inputs[0].rank.bex().to_dim(),
87                        outputs[0].rank.bex().to_dim() + 1.to_dim(),
88                    )
89                }
90            })?;
91        }
92        s.given_3(
93            &inputs[0].rank,
94            &outputs[0].rank,
95            &inputs[1].value,
96            move |s, irank, orank, axes| {
97                let axes: TVec<usize> = axes
98                    .cast_to::<i64>()?
99                    .as_slice::<i64>()?
100                    .iter()
101                    .map(|&ax| if ax > 0 { ax } else { ax + irank } as usize)
102                    .collect();
103                let mut od = 0;
104                for id in 0..(irank as usize) {
105                    if axes.contains(&id) {
106                        if self.keep_dims {
107                            s.equals(&outputs[0].shape[od], 1.to_dim())?;
108                            od += 1;
109                        }
110                    } else if od < orank as usize {
111                        s.equals(&outputs[0].shape[od], &inputs[0].shape[id])?;
112                        od += 1;
113                    }
114                }
115                Ok(())
116            },
117        )?;
118        Ok(())
119    }
120
121    fn to_typed(
122        &self,
123        _source: &InferenceModel,
124        node: &InferenceNode,
125        target: &mut TypedModel,
126        mapping: &HashMap<OutletId, OutletId>,
127    ) -> TractResult<TVec<OutletId>> {
128        if let Some(ref axes) = target.outlet_fact(mapping[&node.inputs[1]])?.konst {
129            let axes: Vec<i64> = axes.cast_to::<i64>()?.as_slice::<i64>()?.to_vec();
130            let op = nn::Reduce::new(Some(axes), self.keep_dims, self.reducer);
131            op.wire(&node.name, target, &[mapping[&node.inputs[0]]])
132        } else {
133            bail!("Nees axes to be const")
134        }
135    }
136
137    as_op!();
138}