tract_pulse/ops/array/
pad.rs

1use crate::internal::*;
2use tract_core::ops::array::{Pad, PadMode};
3use tract_pulse_opl::ops::{Delay, PulsePad};
4
// NOTE(review): presumably registers `pulsify` (below) as the pulsification
// handler for `Pad` nodes; `register_all!` is defined outside this file — confirm.
5register_all!(Pad: pulsify);
6
7fn pulsify(
8    op: &Pad,
9    _source: &TypedModel,
10    node: &TypedNode,
11    target: &mut PulsedModel,
12    mapping: &HashMap<OutletId, OutletId>,
13    _symbol: &Symbol,
14    _pulse: &TDim,
15) -> TractResult<Option<TVec<OutletId>>> {
16    let mut input = mapping[&node.inputs[0]];
17    let fact = target.outlet_fact(input)?.clone();
18    let stream = fact.stream.as_ref().unwrap();
19    if !op.pads.iter().enumerate().all(|(ax, &(a, b))| ax == stream.axis || (a == 0 && b == 0)) {
20        return Ok(None);
21    }
22    let (before, after) = op.pads[stream.axis];
23    let pulse = fact.pulse().unwrap();
24    let mut extra_delay = before.saturating_sub(stream.delay);
25    match op.mode {
26        PadMode::Constant(_) => (),
27        PadMode::Edge => {
28            let pulse = if let Ok(pulse) = pulse.to_usize() {
29                pulse
30            } else {
31                bail!("Edge padding can only by pulsified with concrete integer values")
32            };
33            if before < pulse {
34                let start_offset = (stream.delay + extra_delay) % pulse;
35                if before > start_offset {
36                    extra_delay += before - start_offset;
37                }
38            } else {
39                bail!(
40                    "Edge padding mode needs pulse strictly bigger than left padding (pulse={} padding={})",
41                    pulse,
42                    before
43                    )
44            }
45        }
46        PadMode::Reflect => bail!("Reflect padding mode pulsing is not supported"),
47    };
48    if extra_delay > 0 {
49        input = target.wire_node(
50            format!("{}.Delay", node.name),
51            Delay::new_typed(&(&fact).into(), stream.axis, extra_delay, 0),
52            &[input],
53        )?[0];
54    }
55    let op = PulsePad {
56        axis: stream.axis,
57        before,
58        after: after.into(),
59        begin_input: stream.delay + extra_delay,
60        end_input: stream.delay.to_dim() + extra_delay + &stream.dim,
61        mode: op.mode.clone(),
62        overlap: 0,
63    };
64    Ok(Some(target.wire_node(&*node.name, op, &[input])?))
65}
66
67impl PulsedOp for PulsePad {
68    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
69        let mut fact = inputs[0].clone();
70        let stream = fact.stream.as_mut().unwrap();
71        stream.dim += self.before.to_dim() + &self.after;
72        stream.delay -= self.before;
73        Ok(tvec!(fact))
74    }
75
76    as_op!();
77    pulsed_op_to_typed_op!();
78}