tract_pulse_opl/
mask.rs

1use tract_nnef::internal::*;
2use tract_nnef::ser::tdim;
3use tract_nnef::tract_core::trivial_op_state_freeeze;
4
5pub fn register(registry: &mut Registry) {
6    registry.register_primitive(
7        "tract_pulse_mask",
8        &[
9            TypeName::Scalar.tensor().named("input"),
10            TypeName::Integer.named("axis"),
11            TypeName::Integer.named("begin"),
12            TypeName::Integer.named("end"),
13            TypeName::Scalar.named("value"),
14        ],
15        &[("output", TypeName::Scalar.tensor())],
16        deser,
17    );
18    registry.register_dumper(ser)
19}
20
21fn ser(ast: &mut IntoAst, node: &TypedNode, op: &PulseMask) -> TractResult<Option<Arc<RValue>>> {
22    let wire = ast.mapping[&node.inputs[0]].clone();
23    let params = vec![
24        ("axis", numeric(op.axis)),
25        ("begin", numeric(op.begin)),
26        ("end", tdim(&op.end)),
27        ("value", numeric(op.value.cast_to_scalar::<f32>())),
28    ];
29    Ok(Some(invocation("tract_pulse_mask", &[wire], &params)))
30}
31
32fn deser(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
33    let wire = invocation.named_arg_as(builder, "input")?;
34    let axis = invocation.named_arg_as(builder, "axis")?;
35    let begin = invocation.named_arg_as(builder, "begin")?;
36    let value: Tensor = tensor0(invocation.named_arg_as::<f32>(builder, "value")?);
37    let end = builder.allowing_new_symbols(|builder| invocation.named_arg_as(builder, "end"))?;
38    let op = PulseMask { axis, begin, end, value };
39    builder.wire(op, &[wire])
40}
41
/// Per-node runtime state of `PulseMask`.
///
/// `current_pos` is the absolute position along the masked axis where the next pulse starts,
/// i.e. the summed axis length of all pulses processed so far.
#[derive(Clone, Debug, Default, Hash)]
struct PulseMaskOpState {
    current_pos: usize,
}
46
47impl OpState for PulseMaskOpState {
48    fn eval(
49        &mut self,
50        session: &mut SessionState,
51        op: &dyn Op,
52        inputs: TVec<TValue>,
53    ) -> TractResult<TVec<TValue>> {
54        let input = args_1!(inputs).into_tensor();
55        let op = op.downcast_ref::<PulseMask>().ok_or_else(|| format_err!("Wrong Op type"))?;
56        let tensor = self.pad(session, op, input)?;
57        Ok(tvec!(tensor.into_tvalue()))
58    }
59}
60
61impl PulseMaskOpState {
62    fn pad(
63        &mut self,
64        session: &SessionState,
65        op: &PulseMask,
66        mut input: Tensor,
67    ) -> TractResult<Tensor> {
68        let pulse = input.shape()[op.axis];
69        let pulse_begin = self.current_pos;
70        let pulse_end = self.current_pos + pulse;
71        self.current_pos += pulse;
72        let end = op.end.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);
73
74        // pulse is entirely in valid input, just forward
75        if pulse_begin >= op.begin && pulse_end <= end {
76            return Ok(input);
77        }
78
79        if pulse_begin < op.begin {
80            let fill_up_to = (op.begin - pulse_begin).min(pulse);
81            unsafe {
82                dispatch_copy_by_size!(crate::pad::fill_slice_constant(input.datum_type())(
83                    &mut input,
84                    &op.value,
85                    op.axis,
86                    0..fill_up_to
87                ))
88            };
89        }
90        if pulse_end > end {
91            let fill_from = pulse - (pulse_end - end).min(pulse);
92            unsafe {
93                dispatch_copy_by_size!(crate::pad::fill_slice_constant(input.datum_type())(
94                    &mut input,
95                    &op.value,
96                    op.axis,
97                    fill_from..pulse
98                ))
99            }
100        }
101
102        Ok(input)
103    }
104}
105
106#[derive(Debug, Clone, Default, Hash)]
107pub struct PulseMask {
108    pub axis: usize,
109    pub begin: usize,
110    pub end: TDim,
111    pub value: Tensor,
112}
113
114impl Op for PulseMask {
115    fn name(&self) -> Cow<str> {
116        "PulseMask".into()
117    }
118
119    fn info(&self) -> TractResult<Vec<String>> {
120        Ok(vec![format!("axis: {} begin: {} end: {}", self.axis, self.begin, self.end,)])
121    }
122
123    op_as_typed_op!();
124}
125
126impl EvalOp for PulseMask {
127    fn is_stateless(&self) -> bool {
128        false
129    }
130
131    fn state(
132        &self,
133        _session: &mut SessionState,
134        _node_id: usize,
135    ) -> TractResult<Option<Box<dyn OpState>>> {
136        Ok(Some(Box::<PulseMaskOpState>::default()))
137    }
138}
139
140impl TypedOp for PulseMask {
141    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
142        Ok(tvec!(inputs[0].clone()))
143    }
144
145    as_op!();
146}
147
// Derives the state-freezing support for `PulseMaskOpState` via tract_core's helper macro.
// NOTE(review): presumably a trivial clone-based freeze, which suits a state that is just a
// `usize` position counter — confirm against the macro definition in tract_core.
trivial_op_state_freeeze!(PulseMaskOpState);