1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
use crate::internal::*;
use tract_pulse_opl::ops::Delay;

/// Registers `ser_delay` on `registry` as the dumper associated with the
/// `Delay` operator type.
pub fn register(registry: &mut Registry) {
    let delay_type = TypeId::of::<Delay>();
    registry.register_dumper(delay_type, ser_delay);
}

/// Dumps a `Delay` node as a `tract_pulse_delay` invocation.
///
/// The node's first input is looked up in `ast.mapping` and passed as the only
/// positional argument; the operator's `axis`, `delay` and `overlap` fields are
/// passed as named numeric arguments.
///
/// # Panics
///
/// Panics if the node's operator is not a `Delay`, or if the first input has
/// no entry in `ast.mapping`.
fn ser_delay(ast: &mut IntoAst, node: &TypedNode) -> TractResult<Option<Arc<RValue>>> {
    let delay_op = node.op().downcast_ref::<Delay>().unwrap();
    let input = ast.mapping[&node.inputs[0]].clone();
    let named_args = [
        ("axis", numeric(delay_op.axis)),
        ("delay", numeric(delay_op.delay)),
        ("overlap", numeric(delay_op.overlap)),
    ];
    let call = invocation("tract_pulse_delay", &[input], &named_args);
    Ok(Some(call))
}

impl PulsedOp for Delay {
    /// Computes the pulsed output fact from the single input fact.
    ///
    /// The output is a copy of the input fact in which the dimension along
    /// `self.axis` is increased by `self.overlap`, and the stream delay is
    /// increased by `self.delay + self.overlap`.
    ///
    /// # Errors
    ///
    /// Returns an error if there is not exactly one input, or if the input
    /// fact carries no stream information.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        ensure!(inputs.len() == 1);
        let mut fact = inputs[0].clone();
        // Report a missing stream as an error instead of panicking on `unwrap`.
        ensure!(fact.stream.is_some(), "Delay expects an input fact with stream information");
        fact.shape.set(self.axis, fact.shape[self.axis].clone() + self.overlap);
        if let Some(stream) = fact.stream.as_mut() {
            stream.delay += self.delay + self.overlap;
        }
        Ok(tvec!(fact))
    }

    as_op!();
    pulsed_op_to_typed_op!();
}

#[cfg(test)]
mod test {
    use crate::fact::StreamInfo;

    use super::*;

    /// Wires a single `Delay` behind a one-dimensional `u8` source streaming on
    /// axis 0 with the given `pulse` size, feeds it five consecutive pulses of
    /// increasing values, and checks each output pulse against the input
    /// sequence shifted by `delay + overlap`.
    ///
    /// The first `skip` elements of a pulse are excluded from the comparison.
    /// NOTE(review): presumably these are the positions preceding the first
    /// real sample, whose content is not pinned by this test — confirm.
    fn test_pulse_delay_over(pulse: usize, delay: usize, overlap: usize) {
        let mut model = PulsedModel::default();
        let stream_dim = model.symbol_table.sym("S").to_dim();
        let fact1 = PulsedFact {
            datum_type: u8::datum_type(),
            shape: (&[pulse]).into(),
            stream: Some(StreamInfo { axis: 0, dim: stream_dim, delay: 0 }),
        };
        let source = model.add_source("source", fact1.clone()).unwrap();
        model
            .wire_node(
                "delay",
                Delay::new_typed(&(&fact1).into(), fact1.stream.unwrap().axis, delay, overlap),
                &[source],
            )
            .unwrap();
        model.auto_outputs().unwrap();

        let plan = SimplePlan::new(model).unwrap();
        let mut state = tract_core::plan::SimpleState::new(plan).unwrap();

        for i in 0..5 {
            let input: Vec<u8> = (pulse * i..(pulse * (i + 1))).map(|a| a as u8).collect();
            // Expected pulse is `overlap` elements longer than the input pulse.
            let expect: Vec<u8> = (pulse * i..(pulse * (i + 1) + overlap))
                .map(|i| i.saturating_sub(delay + overlap) as u8)
                .collect();
            let output = state.run(tvec!(tensor1(&input).into())).unwrap();
            let skip = (delay + overlap).saturating_sub(i * pulse).min(pulse + overlap);
            assert_eq!(&output[0].as_slice::<u8>().unwrap()[skip..], &expect[skip..]);
        }
    }

    /// Delay shorter than one pulse, no overlap.
    #[test]
    fn sub_pulse() {
        test_pulse_delay_over(4, 1, 0);
    }

    /// Delay longer than one pulse, no overlap.
    #[test]
    fn supra_pulse() {
        test_pulse_delay_over(4, 5, 0);
    }

    /// No delay, overlap shorter than one pulse.
    #[test]
    fn sub_pulse_context() {
        test_pulse_delay_over(4, 0, 2);
    }

    /// No delay, overlap longer than one pulse.
    #[test]
    fn supra_pulse_context() {
        test_pulse_delay_over(4, 0, 6);
    }

    /// Chains two `Delay` nodes of 2 each (no overlap) and checks that the
    /// output matches the input sequence shifted by 4.
    #[test]
    fn test_two_delays() {
        let pulse = 4usize;
        let mut model = PulsedModel::default();
        let stream_dim = model.symbol_table.sym("S").to_dim();
        let fact_0 = PulsedFact {
            datum_type: u8::datum_type(),
            shape: (&[pulse]).into(),
            stream: Some(StreamInfo { axis: 0, dim: stream_dim, delay: 0 }),
        };
        let stream = fact_0.stream.as_ref().unwrap();
        let source = model.add_source("source", fact_0.clone()).unwrap();
        let delay_1 = model
            .wire_node("delay-1", Delay::new_typed(&(&fact_0).into(), stream.axis, 2, 0), &[source])
            .unwrap()[0];
        let fact_1 = model.outlet_fact(delay_1).unwrap().clone();
        // Give the second node its own name so node names stay unique.
        let delay_2 = model
            .wire_node(
                "delay-2",
                Delay::new_typed(&(&fact_1).into(), stream.axis, 2, 0),
                &[delay_1],
            )
            .unwrap();
        model.set_output_outlets(&delay_2).unwrap();

        let plan = SimplePlan::new(model).unwrap();
        let mut state = tract_core::plan::SimpleState::new(plan).unwrap();

        for i in 0..5 {
            let input: Vec<u8> = (pulse * i..(pulse * (i + 1))).map(|a| a as u8).collect();
            let expect: Vec<u8> =
                (pulse * i..(pulse * (i + 1))).map(|i| i.saturating_sub(4) as u8).collect();
            let skip = 4usize.saturating_sub(i * pulse).min(pulse);
            let output = state.run(tvec!(tensor1(&input).into())).unwrap();
            assert_eq!(&output[0].as_slice::<u8>().unwrap()[skip..], &expect[skip..]);
        }
    }
}