1use tract_hir::internal::*;
2use tract_hir::tract_core::ops::scan::ScanInfo;
3
4use crate::model::{OnnxOpRegister, ParsingContext};
5use crate::pb::*;
6
7pub fn register_all_ops(reg: &mut OnnxOpRegister) {
8 reg.insert("CumSum", cumsum);
9}
10
11fn cumsum(
12 _ctx: &ParsingContext,
13 node: &NodeProto,
14) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
15 let reverse = node.get_attr_opt::<i64>("reverse")? == Some(1);
16 let exclusive = node.get_attr_opt::<i64>("exclusive")? == Some(1);
17 Ok((expand(CumSum { reverse, exclusive }), vec![]))
18}
19
/// ONNX `CumSum` operator: running sum of the data input along a constant axis.
#[derive(Clone, Debug, Hash)]
pub struct CumSum {
    /// Walk the axis from its last element towards its first.
    pub reverse: bool,
    /// Leave each element out of its own running sum (output starts at zero).
    pub exclusive: bool,
}
25
26impl Expansion for CumSum {
27 fn name(&self) -> StaticName {
28 "CumSum".into()
29 }
30
31 fn wire(
32 &self,
33 prefix: &str,
34 model: &mut TypedModel,
35 inputs: &[OutletId],
36 ) -> TractResult<TVec<OutletId>> {
37 use tract_core::ops::scan;
38 let axis =
39 model.outlet_fact(inputs[1])?.konst.as_ref().context("Axis expected to be a const")?;
40 let axis = axis.cast_to_scalar::<i64>()?;
41 let data = model.outlet_fact(inputs[0])?.clone();
42 let mut var_shape = data.shape.clone();
43 let axis = if axis < 0 { (axis + data.rank() as i64) as usize } else { axis as usize };
44 let zero = model.add_const(
45 format!("{prefix}.zero"),
46 Tensor::zero_dt(data.datum_type, &[])?.into_arc_tensor(),
47 )?;
48 var_shape.set(axis, 1.to_dim());
49 let init = model.wire_node(
50 format!("{prefix}.init"),
51 tract_core::ops::array::MultiBroadcastTo::new(var_shape.clone()),
52 &[zero],
53 )?[0];
54 let chunk = if self.reverse { -1 } else { 1 };
55 let input_mapping =
56 vec![scan::InputMapping::Scan(ScanInfo { axis, chunk }), scan::InputMapping::State];
57 let output_mapping = vec![
61 scan::OutputMapping {
62 scan: Some((0, ScanInfo { axis, chunk })),
63 full_dim_hint: None,
64 last_value_slot: None,
65 state: true,
66 },
67 scan::OutputMapping {
68 scan: Some((1, ScanInfo { axis, chunk })),
69 full_dim_hint: None,
70 last_value_slot: None,
71 state: false,
72 },
73 ];
74 let mut body = TypedModel::default();
75 let var_fact = data.datum_type.fact(var_shape);
76 let x = body.add_source("scan_input", var_fact.clone())?;
77 let acc = body.add_source("acc_input", var_fact)?;
78 let sum = body.wire_node("add", tract_core::ops::math::add(), &[x, acc])?[0];
79 body.set_output_outlets(&[sum, acc])?;
80 let scan = scan::Scan::new(body, input_mapping, output_mapping, 0)?;
81 let wires = model.wire_node(prefix, scan, &[inputs[0], init])?;
82 let output = wires[self.exclusive as usize];
83 Ok(tvec![output])
84 }
85
86 fn rules<'r, 'p: 'r, 's: 'r>(
87 &'s self,
88 s: &mut Solver<'r>,
89 inputs: &'p [TensorProxy],
90 outputs: &'p [TensorProxy],
91 ) -> InferenceResult {
92 check_input_arity(inputs, 2)?;
93 check_output_arity(outputs, 1)?;
94 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
95 s.equals(&inputs[0].shape, &outputs[0].shape)?;
96 s.equals(&inputs[1].rank, 0)?;
97 Ok(())
98 }
99}