1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
use crate::internal::*;
use tract_core::ops::array::{Pad, PadMode};
use tract_pulse_opl::ops::{Delay, PulsePad};

// Pairs the `Pad` operator with the `pulsify` function defined below.
// NOTE(review): `register_all!` is defined outside this file; presumably it
// registers `pulsify` as the pulsifier for `Pad` nodes — confirm at the macro.
register_all!(Pad: pulsify);

/// Translates a `Pad` node into its pulsed form inside `target`.
///
/// The input outlet is looked up in `mapping`, and its pulsed fact gives the
/// axis, delay, pulse and dim used below.
///
/// Returns `Ok(None)` when any axis other than the fact's axis carries a
/// non-zero padding. Otherwise a `PulsePad` is wired under the node's own
/// name, preceded by a `Delay` node named `"<node name>.Delay"` when extra
/// delay is required, and the `PulsePad` outlets are returned.
///
/// # Errors
///
/// Fails for `PadMode::Edge` when the left padding is not strictly smaller
/// than the pulse, for `PadMode::Reflect` (unsupported), and whenever the
/// fact lookup or node wiring fails.
fn pulsify(
    op: &Pad,
    _source: &TypedModel,
    node: &TypedNode,
    target: &mut PulsedModel,
    mapping: &HashMap<OutletId, OutletId>,
    _pulse: usize,
) -> TractResult<Option<TVec<OutletId>>> {
    let mut wire = mapping[&node.inputs[0]];
    let fact = target.outlet_fact(wire)?.clone();
    let axis = fact.axis;

    // Give up as soon as some other axis asks for padding.
    let pads_other_axis =
        op.pads.iter().enumerate().any(|(ax, &(a, b))| ax != axis && (a != 0 || b != 0));
    if pads_other_axis {
        return Ok(None);
    }

    let (before, after) = op.pads[axis];
    let pulse = fact.pulse();

    // The incoming delay must cover the left padding; make up any shortfall.
    let mut extra_delay = before.saturating_sub(fact.delay);
    match op.mode {
        PadMode::Constant(_) => {}
        PadMode::Reflect => bail!("Reflect padding mode pulsing is not supported"),
        PadMode::Edge => {
            if before >= pulse {
                bail!(
                    "Edge padding mode needs pulse strictly bigger than left padding (pulse={} padding={})",
                    pulse,
                    before
                );
            }
            // Push the stream start further so that its offset within a pulse
            // is at least `before`.
            let start_offset = (fact.delay + extra_delay) % pulse;
            extra_delay += before.saturating_sub(start_offset);
        }
    }

    if extra_delay > 0 {
        let delay_op = Delay::new_typed(&(&fact).into(), axis, extra_delay, 0);
        wire = target.wire_node(format!("{}.Delay", node.name), delay_op, &[wire])?[0];
    }

    // Input bounds as seen after the optional extra delay.
    let begin_input = fact.delay + extra_delay;
    let end_input = fact.delay.to_dim() + extra_delay + fact.dim;
    let pulse_pad = PulsePad {
        axis,
        before,
        after: after.into(),
        begin_input,
        end_input,
        mode: op.mode.clone(),
        overlap: 0,
    };
    let outlets = target.wire_node(&*node.name, pulse_pad, &[wire])?;
    Ok(Some(outlets))
}

impl PulsedOp for PulsePad {
    /// Derives the output fact from the first input fact: `dim` grows by
    /// `before + after` and `delay` shrinks by `before`. Everything else is
    /// copied from the input unchanged.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        let mut output = inputs[0].clone();
        let total_padding = self.before.to_dim() + &self.after;
        output.dim += total_padding;
        // NOTE(review): assumes the input delay is at least `before`; the
        // pulsifier in this file adds extra delay to guarantee it — confirm
        // for any other construction site.
        output.delay -= self.before;
        Ok(tvec!(output))
    }

    // Remaining trait items are generated by crate-provided macros.
    as_op!();
    pulsed_op_to_typed_op!();
}