1use tract_core::ndarray::*;
2use tract_core::ops::array::PadMode;
3use tract_nnef::internal::*;
4use tract_nnef::ser::tdim;
5use tract_nnef::tract_core::ops::OpStateFreeze;
6
7pub fn register(registry: &mut Registry) {
8 registry.register_primitive(
9 "tract_pulse_pulse_pad",
10 &[
11 TypeName::Scalar.tensor().named("input"),
12 TypeName::Integer.named("axis"),
13 TypeName::Integer.named("before"),
14 TypeName::Integer.named("after"),
15 TypeName::Integer.named("begin_input"),
16 TypeName::Integer.named("end_input"),
17 TypeName::String.named("border"),
18 TypeName::Scalar.named("value"),
19 TypeName::Integer.named("overlap"),
20 ],
21 &[("output", TypeName::Scalar.tensor())],
22 deser,
23 );
24 registry.register_dumper(ser)
25}
26
27fn ser(ast: &mut IntoAst, node: &TypedNode, op: &PulsePad) -> TractResult<Option<Arc<RValue>>> {
28 let wire = ast.mapping[&node.inputs[0]].clone();
29 let dt = ast.model.outlet_fact(node.inputs[0])?.datum_type;
30 let (border, value) = tract_nnef::ops::nnef::ser::pad_mode(&op.mode, dt)?;
31 let mut params = vec![
32 ("axis", numeric(op.axis)),
33 ("before", numeric(op.before)),
34 ("begin_input", numeric(op.begin_input)),
35 ("overlap", numeric(op.overlap)),
36 ("after", tdim(&op.after)),
37 ("end_input", tdim(&op.end_input)),
38 ];
39 params.push(("border", string(border)));
40 if let Some(value) = value {
41 params.push(("value", value));
42 }
43 Ok(Some(invocation("tract_pulse_pulse_pad", &[wire], ¶ms)))
44}
45
46fn deser(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
47 let wire = invocation.named_arg_as(builder, "input")?;
48 let axis = invocation.named_arg_as(builder, "axis")?;
49 let before = invocation.named_arg_as(builder, "before")?;
50 let begin_input = invocation.named_arg_as(builder, "begin_input")?;
51 let overlap = invocation.named_arg_as(builder, "overlap")?;
52 let border = invocation.named_arg_as::<String>(builder, "border")?;
53 let value: Tensor = tensor0(invocation.named_arg_as::<f32>(builder, "value")?);
54 let (after, end_input) = builder.allowing_new_symbols(|builder| {
55 TractResult::Ok((
56 invocation.named_arg_as(builder, "after")?,
57 invocation.named_arg_as(builder, "end_input")?,
58 ))
59 })?;
60
61 let mode = tract_nnef::ops::nnef::deser::pad_mode(&border, value)?;
62 let op = PulsePad { axis, before, after, begin_input, end_input, mode, overlap };
63 builder.wire(op, &[wire])
64}
65
66pub(crate) unsafe fn fill_slice_constant<T: Datum + Copy>(
67 data: &mut Tensor,
68 constant: &Tensor,
69 axis: usize,
70 range: std::ops::Range<usize>,
71) {
72 let c = constant.to_scalar_unchecked::<T>();
73 data.to_array_view_mut_unchecked::<T>().slice_axis_mut(Axis(axis), range.into()).fill(*c);
74}
75
76unsafe fn fill_slice_with_frame<T: Datum + Copy>(
77 data: &mut Tensor,
78 axis: usize,
79 valid: &Tensor,
80 range: std::ops::Range<usize>,
81) {
82 let mut data = data.to_array_view_mut_unchecked::<T>();
83 let valid = valid.to_array_view_unchecked::<T>();
84 for i in range {
85 data.slice_axis_mut(Axis(axis), (i..i + 1).into()).assign(&valid);
86 }
87}
88
/// Mutable per-session state of a `PulsePad` node.
#[derive(Debug, Clone, Default, Hash)]
struct PulsePadOpState {
    /// Stream position (along the pulse axis) of the first frame of the next pulse.
    /// Advanced by `pulse - overlap` on every call to `pad`.
    current_pos: usize,
    /// Last valid input frame recorded in `Edge` mode; replicated into right-side padding.
    last_valid_frame: Option<Tensor>,
}
94
95impl OpState for PulsePadOpState {
96 fn eval(
97 &mut self,
98 session: &mut SessionState,
99 op: &dyn Op,
100 inputs: TVec<TValue>,
101 ) -> TractResult<TVec<TValue>> {
102 let input = args_1!(inputs).into_tensor();
103 let op = op.downcast_ref::<PulsePad>().ok_or_else(|| format_err!("Wrong Op type"))?;
104 let tensor = self.pad(session, op, input)?;
105 Ok(tvec!(tensor.into_tvalue()))
106 }
107}
108
109impl PulsePadOpState {
110 unsafe fn save_frame<T: Datum + Copy>(&mut self, op: &PulsePad, input: &Tensor, frame: usize) {
111 let data = input.to_array_view_unchecked::<T>();
112 self.last_valid_frame =
113 Some(data.index_axis(Axis(op.axis), frame).to_owned().into_tensor());
114 }
115
116 fn pad(
117 &mut self,
118 session: &SessionState,
119 op: &PulsePad,
120 mut input: Tensor,
121 ) -> TractResult<Tensor> {
122 let pulse = input.shape()[op.axis];
123 let pulse_begin = self.current_pos;
124 let pulse_end = self.current_pos + pulse;
125 self.current_pos += pulse - op.overlap;
126 let end_input =
127 op.end_input.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);
128 let after = op.after.eval(&session.resolved_symbols).to_usize().unwrap_or(usize::MAX);
129
130 if let PadMode::Edge = op.mode {
131 if after != 0 && pulse_begin < end_input {
132 let latest_valid_frame = (end_input - pulse_begin).min(pulse) - 1;
133 unsafe {
134 dispatch_copy_by_size!(Self::save_frame(input.datum_type())(
135 self,
136 op,
137 &input,
138 latest_valid_frame
139 ))
140 }
141 }
142 }
143
144 if pulse_begin >= op.begin_input && pulse_end <= end_input {
146 return Ok(input);
147 }
148 if pulse_end <= op.begin_input - op.before || pulse_begin >= end_input.saturating_add(after)
150 {
151 return Ok(input);
152 }
153
154 if pulse_begin < op.begin_input {
155 let fill_up_to = (op.begin_input - pulse_begin).min(pulse);
156 match &op.mode {
157 PadMode::Constant(c) => unsafe {
158 dispatch_copy_by_size!(fill_slice_constant(input.datum_type())(
159 &mut input,
160 c,
161 op.axis,
162 0..fill_up_to
163 ))
164 },
165 PadMode::Edge => {
166 let frame = input.slice(op.axis, fill_up_to, fill_up_to + 1)?;
167 unsafe {
168 dispatch_copy_by_size!(fill_slice_with_frame(input.datum_type())(
169 &mut input,
170 op.axis,
171 &frame,
172 0..fill_up_to
173 ))
174 }
175 }
176 _ => unimplemented!(),
177 }
178 }
179 if pulse_end > end_input && after > 0 {
180 let fill_from = pulse - (pulse_end - end_input).min(pulse);
181 match &op.mode {
182 PadMode::Constant(c) => unsafe {
183 dispatch_copy_by_size!(fill_slice_constant(input.datum_type())(
184 &mut input,
185 c,
186 op.axis,
187 fill_from..pulse
188 ))
189 },
190 PadMode::Edge => {
191 let last_frame = self.last_valid_frame.as_ref().unwrap();
192 unsafe {
193 dispatch_copy_by_size!(fill_slice_with_frame(input.datum_type())(
194 &mut input,
195 op.axis,
196 last_frame,
197 fill_from..pulse
198 ))
199 }
200 }
201 _ => unimplemented!(),
202 }
203 }
204
205 Ok(input)
206 }
207}
208
/// Streaming (pulsed) padding along a single axis.
///
/// Each evaluation receives one pulse of the stream and overwrites, in place, the frames that
/// fall in the padding zones. Stream positions are tracked in `PulsePadOpState`.
#[derive(Debug, Clone, Default, Hash)]
pub struct PulsePad {
    /// Axis along which the input is pulsed and padded.
    pub axis: usize,
    /// Number of padding frames before the valid input.
    pub before: usize,
    /// Number of padding frames after the valid input; may be symbolic.
    pub after: TDim,
    /// Stream position of the first valid input frame.
    pub begin_input: usize,
    /// Stream position one past the last valid input frame. May be symbolic; treated as
    /// unbounded while it cannot be resolved to an integer.
    pub end_input: TDim,
    /// How padded frames are filled (the pulsed kernel handles `Constant` and `Edge`).
    pub mode: PadMode,
    /// Number of frames shared by consecutive pulses.
    pub overlap: usize,
}
219
220impl Op for PulsePad {
221 fn name(&self) -> Cow<str> {
222 "PulsePad".into()
223 }
224
225 fn info(&self) -> TractResult<Vec<String>> {
226 Ok(vec![format!(
227 "Mode: {:?}, axis: {} before: {} after: {}",
228 self.mode, self.axis, self.before, self.after,
229 )])
230 }
231
232 op_as_typed_op!();
233}
234
235impl EvalOp for PulsePad {
236 fn is_stateless(&self) -> bool {
237 false
238 }
239
240 fn state(
241 &self,
242 _session: &mut SessionState,
243 _node_id: usize,
244 ) -> TractResult<Option<Box<dyn OpState>>> {
245 Ok(Some(Box::<PulsePadOpState>::default()))
246 }
247}
248
249impl TypedOp for PulsePad {
250 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
251 Ok(tvec!(inputs[0].clone()))
252 }
253
254 as_op!();
255}
256
/// Immutable snapshot of a `PulsePadOpState`, produced by `freeze` and turned back into live
/// state by `unfreeze`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
struct FrozenPulsePadOpState {
    /// Copy of `PulsePadOpState::current_pos`.
    current_pos: usize,
    /// Shared copy of `PulsePadOpState::last_valid_frame`.
    last_valid_frame: Option<Arc<Tensor>>,
}
262
263impl OpStateFreeze for PulsePadOpState {
264 fn freeze(&self) -> Box<dyn FrozenOpState> {
265 Box::new(FrozenPulsePadOpState {
266 current_pos: self.current_pos,
267 last_valid_frame: self.last_valid_frame.as_ref().map(|t| t.clone().into_arc_tensor()),
268 })
269 }
270}
271
272impl FrozenOpState for FrozenPulsePadOpState {
273 fn unfreeze(&self) -> Box<dyn OpState> {
274 Box::new(PulsePadOpState {
275 current_pos: self.current_pos,
276 last_valid_frame: self.last_valid_frame.as_ref().map(|t| t.clone().into_tensor()),
277 })
278 }
279}