tract_pulse/ops/array/
pad.rs1use crate::internal::*;
2use tract_core::ops::array::{Pad, PadMode};
3use tract_pulse_opl::ops::{Delay, PulsePad};
4
5register_all!(Pad: pulsify);
6
7fn pulsify(
8 op: &Pad,
9 _source: &TypedModel,
10 node: &TypedNode,
11 target: &mut PulsedModel,
12 mapping: &HashMap<OutletId, OutletId>,
13 _symbol: &Symbol,
14 _pulse: &TDim,
15) -> TractResult<Option<TVec<OutletId>>> {
16 let mut input = mapping[&node.inputs[0]];
17 let fact = target.outlet_fact(input)?.clone();
18 let stream = fact.stream.as_ref().unwrap();
19 if !op.pads.iter().enumerate().all(|(ax, &(a, b))| ax == stream.axis || (a == 0 && b == 0)) {
20 return Ok(None);
21 }
22 let (before, after) = op.pads[stream.axis];
23 let pulse = fact.pulse().unwrap();
24 let mut extra_delay = before.saturating_sub(stream.delay);
25 match op.mode {
26 PadMode::Constant(_) => (),
27 PadMode::Edge => {
28 let pulse = if let Ok(pulse) = pulse.to_usize() {
29 pulse
30 } else {
31 bail!("Edge padding can only by pulsified with concrete integer values")
32 };
33 if before < pulse {
34 let start_offset = (stream.delay + extra_delay) % pulse;
35 if before > start_offset {
36 extra_delay += before - start_offset;
37 }
38 } else {
39 bail!(
40 "Edge padding mode needs pulse strictly bigger than left padding (pulse={} padding={})",
41 pulse,
42 before
43 )
44 }
45 }
46 PadMode::Reflect => bail!("Reflect padding mode pulsing is not supported"),
47 };
48 if extra_delay > 0 {
49 input = target.wire_node(
50 format!("{}.Delay", node.name),
51 Delay::new_typed(&(&fact).into(), stream.axis, extra_delay, 0),
52 &[input],
53 )?[0];
54 }
55 let op = PulsePad {
56 axis: stream.axis,
57 before,
58 after: after.into(),
59 begin_input: stream.delay + extra_delay,
60 end_input: stream.delay.to_dim() + extra_delay + &stream.dim,
61 mode: op.mode.clone(),
62 overlap: 0,
63 };
64 Ok(Some(target.wire_node(&*node.name, op, &[input])?))
65}
66
67impl PulsedOp for PulsePad {
68 fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
69 let mut fact = inputs[0].clone();
70 let stream = fact.stream.as_mut().unwrap();
71 stream.dim += self.before.to_dim() + &self.after;
72 stream.delay -= self.before;
73 Ok(tvec!(fact))
74 }
75
76 as_op!();
77 pulsed_op_to_typed_op!();
78}