1use tract_nnef::internal::*;
2use tract_nnef::ser::tdim;
3use tract_nnef::tract_core::trivial_op_state_freeeze;
4
5pub fn register(registry: &mut Registry) {
6 registry.register_primitive(
7 "tract_pulse_mask",
8 &[
9 TypeName::Scalar.tensor().named("input"),
10 TypeName::Integer.named("axis"),
11 TypeName::Integer.named("begin"),
12 TypeName::Integer.named("end"),
13 TypeName::Scalar.named("value"),
14 ],
15 &[("output", TypeName::Scalar.tensor())],
16 deser,
17 );
18 registry.register_dumper(ser)
19}
20
21fn ser(ast: &mut IntoAst, node: &TypedNode, op: &PulseMask) -> TractResult<Option<Arc<RValue>>> {
22 let wire = ast.mapping[&node.inputs[0]].clone();
23 let params = vec![
24 ("axis", numeric(op.axis)),
25 ("begin", numeric(op.begin)),
26 ("end", tdim(&op.end)),
27 ("value", numeric(op.value.cast_to_scalar::<f32>())),
28 ];
29 Ok(Some(invocation("tract_pulse_mask", &[wire], ¶ms)))
30}
31
32fn deser(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
33 let wire = invocation.named_arg_as(builder, "input")?;
34 let axis = invocation.named_arg_as(builder, "axis")?;
35 let begin = invocation.named_arg_as(builder, "begin")?;
36 let value: Tensor = tensor0(invocation.named_arg_as::<f32>(builder, "value")?);
37 let end = builder.allowing_new_symbols(|builder| invocation.named_arg_as(builder, "end"))?;
38 let op = PulseMask { axis, begin, end, value };
39 builder.wire(op, &[wire])
40}
41
/// Runtime state of a [`PulseMask`] node: how far into the stream previous
/// pulses have reached.
#[derive(Debug, Clone, Default, Hash)]
struct PulseMaskOpState {
    /// Total length, along the masked axis, of all pulses processed so far.
    /// Starts at 0 (derived `Default`).
    current_pos: usize,
}
46
47impl OpState for PulseMaskOpState {
48 fn eval(
49 &mut self,
50 session: &mut SessionState,
51 op: &dyn Op,
52 inputs: TVec<TValue>,
53 ) -> TractResult<TVec<TValue>> {
54 let input = args_1!(inputs).into_tensor();
55 let op = op.downcast_ref::<PulseMask>().ok_or_else(|| format_err!("Wrong Op type"))?;
56 let tensor = self.pad(session, op, input)?;
57 Ok(tvec!(tensor.into_tvalue()))
58 }
59}
60
61impl PulseMaskOpState {
62 fn pad(
63 &mut self,
64 session: &SessionState,
65 op: &PulseMask,
66 mut input: Tensor,
67 ) -> TractResult<Tensor> {
68 let pulse = input.shape()[op.axis];
69 let pulse_begin = self.current_pos;
70 let pulse_end = self.current_pos + pulse;
71 self.current_pos += pulse;
72 let end = op.end.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);
73
74 if pulse_begin >= op.begin && pulse_end <= end {
76 return Ok(input);
77 }
78
79 if pulse_begin < op.begin {
80 let fill_up_to = (op.begin - pulse_begin).min(pulse);
81 unsafe {
82 dispatch_copy_by_size!(crate::pad::fill_slice_constant(input.datum_type())(
83 &mut input,
84 &op.value,
85 op.axis,
86 0..fill_up_to
87 ))
88 };
89 }
90 if pulse_end > end {
91 let fill_from = pulse - (pulse_end - end).min(pulse);
92 unsafe {
93 dispatch_copy_by_size!(crate::pad::fill_slice_constant(input.datum_type())(
94 &mut input,
95 &op.value,
96 op.axis,
97 fill_from..pulse
98 ))
99 }
100 }
101
102 Ok(input)
103 }
104}
105
/// Streaming mask: along `axis`, slices whose stream position is outside
/// `[begin, end)` are filled with `value`; other slices pass through unchanged.
#[derive(Debug, Clone, Default, Hash)]
pub struct PulseMask {
    /// Axis along which pulses arrive and along which masking is applied.
    pub axis: usize,
    /// First stream position left untouched; earlier positions are filled.
    pub begin: usize,
    /// Stream position from which values are filled; may be symbolic. If it does
    /// not evaluate to a `usize` at runtime, no trailing fill is applied.
    pub end: TDim,
    /// Fill value. `deser` builds it as an `f32` scalar tensor.
    pub value: Tensor,
}
113
114impl Op for PulseMask {
115 fn name(&self) -> Cow<str> {
116 "PulseMask".into()
117 }
118
119 fn info(&self) -> TractResult<Vec<String>> {
120 Ok(vec![format!("axis: {} begin: {} end: {}", self.axis, self.begin, self.end,)])
121 }
122
123 op_as_typed_op!();
124}
125
126impl EvalOp for PulseMask {
127 fn is_stateless(&self) -> bool {
128 false
129 }
130
131 fn state(
132 &self,
133 _session: &mut SessionState,
134 _node_id: usize,
135 ) -> TractResult<Option<Box<dyn OpState>>> {
136 Ok(Some(Box::<PulseMaskOpState>::default()))
137 }
138}
139
140impl TypedOp for PulseMask {
141 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
142 Ok(tvec!(inputs[0].clone()))
143 }
144
145 as_op!();
146}
147
// Hooks PulseMaskOpState into tract_core's "trivial" state-freezing support.
// The macro name is spelled exactly as imported from `tract_core`; do not rename it.
// NOTE(review): presumably this freezes the state by cloning it (the type derives
// Clone). Confirm against tract_core.
trivial_op_state_freeeze!(PulseMaskOpState);