// tract_onnx/ops/cumsum.rs

1use tract_hir::internal::*;
2use tract_hir::tract_core::ops::scan::ScanInfo;
3
4use crate::model::{OnnxOpRegister, ParsingContext};
5use crate::pb::*;
6
7pub fn register_all_ops(reg: &mut OnnxOpRegister) {
8    reg.insert("CumSum", cumsum);
9}
10
11fn cumsum(
12    _ctx: &ParsingContext,
13    node: &NodeProto,
14) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
15    let reverse = node.get_attr_opt::<i64>("reverse")? == Some(1);
16    let exclusive = node.get_attr_opt::<i64>("exclusive")? == Some(1);
17    Ok((expand(CumSum { reverse, exclusive }), vec![]))
18}
19
/// ONNX CumSum: running sum of a tensor along one axis.
#[derive(Debug, Clone, Hash)]
pub struct CumSum {
    // When true, the sum is accumulated from the end of the axis backwards
    // (implemented in `wire` by scanning with a negative chunk).
    pub reverse: bool,
    // When true, each output element holds the sum of the elements strictly
    // before it along the axis, i.e. the element itself is excluded.
    pub exclusive: bool,
}
25
impl Expansion for CumSum {
    fn name(&self) -> StaticName {
        "CumSum".into()
    }

    /// Lowers CumSum to a core `Scan` op: the input is sliced along `axis`
    /// one element at a time, and a running-sum accumulator is threaded
    /// through as scan state.
    fn wire(
        &self,
        prefix: &str,
        model: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        use tract_core::ops::scan;
        // Input 1 is the axis; it must be a compile-time constant scalar.
        let axis =
            model.outlet_fact(inputs[1])?.konst.as_ref().context("Axis expected to be a const")?;
        let axis = axis.cast_to_scalar::<i64>()?;
        let data = model.outlet_fact(inputs[0])?.clone();
        let mut var_shape = data.shape.clone();
        // Normalize a negative axis (ONNX convention) against the data rank.
        let axis = if axis < 0 { (axis + data.rank() as i64) as usize } else { axis as usize };
        // Scalar zero of the data's dtype, used to seed the accumulator.
        let zero = model.add_const(
            format!("{prefix}.zero"),
            Tensor::zero_dt(data.datum_type, &[])?.into_arc_tensor(),
        )?;
        // Per-step slice shape: the full input shape with `axis` collapsed to 1.
        var_shape.set(axis, 1.to_dim());
        // Initial accumulator: the scalar zero broadcast to the per-step shape.
        let init = model.wire_node(
            format!("{prefix}.init"),
            tract_core::ops::array::MultiBroadcastTo::new(var_shape.clone()),
            &[zero],
        )?[0];
        // chunk = -1 makes the scan traverse the axis back-to-front (reverse mode).
        let chunk = if self.reverse { -1 } else { 1 };
        // Outer input 0 (data) is scanned slice-by-slice; outer input 1 (init)
        // enters as the loop-carried state.
        let input_mapping =
            vec![scan::InputMapping::Scan(ScanInfo { axis, chunk }), scan::InputMapping::State];
        // outputs will be
        // acc + x (!exclusive)
        // acc input (exclusive)
        let output_mapping = vec![
            // Body output 0 (acc + x): fed back as the next state AND scanned
            // into outer output slot 0.
            scan::OutputMapping {
                scan: Some((0, ScanInfo { axis, chunk })),
                full_dim_hint: None,
                last_value_slot: None,
                state: true,
            },
            // Body output 1 (acc before the add): scanned into outer output slot 1.
            scan::OutputMapping {
                scan: Some((1, ScanInfo { axis, chunk })),
                full_dim_hint: None,
                last_value_slot: None,
                state: false,
            },
        ];
        // Scan body: sum = scan_input + acc. Both the new sum and the incoming
        // accumulator are exposed so either cumsum flavor can be picked below.
        let mut body = TypedModel::default();
        let var_fact = data.datum_type.fact(var_shape);
        let x = body.add_source("scan_input", var_fact.clone())?;
        let acc = body.add_source("acc_input", var_fact)?;
        let sum = body.wire_node("add", tract_core::ops::math::add(), &[x, acc])?[0];
        body.set_output_outlets(&[sum, acc])?;
        let scan = scan::Scan::new(body, input_mapping, output_mapping, 0)?;
        let wires = model.wire_node(prefix, scan, &[inputs[0], init])?;
        // Inclusive cumsum is outer output 0 (the running sums); exclusive is
        // outer output 1 (the accumulator value before each element is added).
        let output = wires[self.exclusive as usize];
        Ok(tvec![output])
    }

    /// Type inference: the output mirrors the data input's dtype and shape,
    /// and the axis input must be a scalar (rank 0).
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_input_arity(inputs, 2)?;
        check_output_arity(outputs, 1)?;
        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
        s.equals(&inputs[0].shape, &outputs[0].shape)?;
        s.equals(&inputs[1].rank, 0)?;
        Ok(())
    }
}