1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
use crate::internal::*;
use tract_core::ops::array::{Pad, PadMode};
use tract_pulse_opl::ops::{Delay, PulsePad};

// Presumably registers `pulsify` (defined below) as the pulsification rule for
// typed `Pad` nodes — see the `register_all!` macro definition to confirm.
register_all!(Pad: pulsify);

/// Translates a typed `Pad` node into its pulsed equivalent.
///
/// Only padding along the streaming axis is handled here: if any other axis
/// carries non-zero padding, `Ok(None)` is returned and no node is wired.
///
/// Padding along the streaming axis is realised with a `PulsePad` operator.
/// When the incoming stream delay is smaller than the left (`before`) padding,
/// a `Delay` node is inserted first so that the output delay cannot go
/// negative. In `Edge` mode the extra delay is further increased so that the
/// first real input frame sits at an offset `>= before` inside its pulse
/// (presumably so `PulsePad` can replicate it within a single pulse).
///
/// # Errors
/// * the input outlet is not streaming;
/// * `Edge` mode with a non-concrete pulse, or with a pulse that is not
///   strictly bigger than the left padding;
/// * `Reflect` mode, which is not supported.
fn pulsify(
    op: &Pad,
    _source: &TypedModel,
    node: &TypedNode,
    target: &mut PulsedModel,
    mapping: &HashMap<OutletId, OutletId>,
    _symbol: &Symbol,
    _pulse: &TDim,
) -> TractResult<Option<TVec<OutletId>>> {
    let mut input = mapping[&node.inputs[0]];
    let fact = target.outlet_fact(input)?.clone();
    let stream = match fact.stream.as_ref() {
        Some(stream) => stream,
        None => bail!("Pad pulsification expects a streaming input (node {})", node.name),
    };
    // Padding on any non-streaming axis is not handled by this translation.
    if !op.pads.iter().enumerate().all(|(ax, &(a, b))| ax == stream.axis || (a == 0 && b == 0)) {
        return Ok(None);
    }
    let (before, after) = op.pads[stream.axis];
    // `stream` is Some (checked above), so the pulse is always available.
    let pulse = fact.pulse().expect("streaming fact always has a pulse");
    // Make sure the stream is delayed by at least `before` frames so the
    // output delay (input delay - before) stays non-negative.
    let mut extra_delay = before.saturating_sub(stream.delay);
    match op.mode {
        PadMode::Constant(_) => (),
        PadMode::Edge => {
            let pulse = match pulse.to_usize() {
                Ok(pulse) => pulse,
                Err(_) => bail!("Edge padding can only be pulsified with concrete integer values"),
            };
            if before >= pulse {
                bail!(
                    "Edge padding mode needs pulse strictly bigger than left padding (pulse={} padding={})",
                    pulse,
                    before
                );
            }
            // Position of the first real frame inside its pulse; push it to
            // at least `before` so the left padding frames share its pulse.
            // `pulse > before >= 0` here, so the modulo cannot divide by zero.
            let start_offset = (stream.delay + extra_delay) % pulse;
            if before > start_offset {
                extra_delay += before - start_offset;
            }
        }
        PadMode::Reflect => bail!("Reflect padding mode pulsing is not supported"),
    };
    if extra_delay > 0 {
        input = target.wire_node(
            format!("{}.Delay", node.name),
            Delay::new_typed(&(&fact).into(), stream.axis, extra_delay, 0),
            &[input],
        )?[0];
    }
    // Stream position of the first real (non-padding) input frame.
    let begin_input = stream.delay + extra_delay;
    let pulse_pad = PulsePad {
        axis: stream.axis,
        before,
        after: after.into(),
        begin_input,
        end_input: begin_input.to_dim() + &stream.dim,
        mode: op.mode.clone(),
        overlap: 0,
    };
    Ok(Some(target.wire_node(&*node.name, pulse_pad, &[input])?))
}

impl PulsedOp for PulsePad {
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        let mut fact = inputs[0].clone();
        let mut stream = fact.stream.as_mut().unwrap();
        stream.dim += self.before.to_dim() + &self.after;
        stream.delay -= self.before;
        Ok(tvec!(fact))
    }

    as_op!();
    pulsed_op_to_typed_op!();
}