tract_pulse/ops/
delay.rs

1use crate::internal::*;
2use tract_pulse_opl::ops::Delay;
3
4pub fn register(registry: &mut Registry) {
5    registry.register_dumper(ser_delay)
6}
7
8fn ser_delay(ast: &mut IntoAst, node: &TypedNode, op: &Delay) -> TractResult<Option<Arc<RValue>>> {
9    let wire = ast.mapping[&node.inputs[0]].clone();
10    Ok(Some(invocation(
11        "tract_pulse_delay",
12        &[wire],
13        &[
14            ("axis", numeric(op.axis)),
15            ("delay", numeric(op.delay)),
16            ("overlap", numeric(op.overlap)),
17        ],
18    )))
19}
20
21impl PulsedOp for Delay {
22    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
23        ensure!(inputs.len() == 1);
24        let mut fact = inputs[0].clone();
25        let stream = fact.stream.as_mut().unwrap();
26        fact.shape.set(self.axis, fact.shape[self.axis].clone() + self.overlap);
27        stream.delay += self.delay + self.overlap;
28        Ok(tvec!(fact))
29    }
30
31    as_op!();
32    pulsed_op_to_typed_op!();
33}
34
#[cfg(test)]
mod test {
    use crate::fact::StreamInfo;

    use super::*;

    /// Runs a single `Delay(delay, overlap)` node over five consecutive pulses of
    /// `pulse` bytes and checks each output pulse against the expected delayed stream.
    ///
    /// Input position `k` carries the value `k`. Output position `pos` is therefore
    /// expected to hold `pos - (delay + overlap)`. Leading positions that precede any
    /// real input are skipped, since their content is not checked here.
    fn test_pulse_delay_over(pulse: usize, delay: usize, overlap: usize) {
        let mut model = PulsedModel::default();
        let stream_dim = model.symbols.sym("S").to_dim();
        let fact1 = PulsedFact {
            datum_type: u8::datum_type(),
            shape: (&[pulse]).into(),
            stream: Some(StreamInfo { axis: 0, dim: stream_dim, delay: 0 }),
        };
        // Borrow rather than partially move `fact1.stream`.
        let axis = fact1.stream.as_ref().unwrap().axis;
        let source = model.add_source("source", fact1.clone()).unwrap();
        model
            .wire_node(
                "delay",
                Delay::new_typed(&(&fact1).into(), axis, delay, overlap),
                &[source],
            )
            .unwrap();
        model.auto_outputs().unwrap();

        let plan = SimplePlan::new(model).unwrap();
        let mut state = tract_core::plan::SimpleState::new(plan).unwrap();

        let total_delay = delay + overlap;
        for i in 0..5 {
            let input: Vec<u8> = (pulse * i..pulse * (i + 1)).map(|a| a as u8).collect();
            let expect: Vec<u8> = (pulse * i..pulse * (i + 1) + overlap)
                .map(|pos| pos.saturating_sub(total_delay) as u8)
                .collect();
            let output = state.run(tvec!(tensor1(&input).into())).unwrap();
            let skip = total_delay.saturating_sub(i * pulse).min(pulse + overlap);
            assert_eq!(&output[0].as_slice::<u8>().unwrap()[skip..], &expect[skip..]);
        }
    }

    /// Delay shorter than one pulse.
    #[test]
    fn sub_pulse() {
        test_pulse_delay_over(4, 1, 0);
    }

    /// Delay longer than one pulse.
    #[test]
    fn supra_pulse() {
        test_pulse_delay_over(4, 5, 0);
    }

    /// Overlap shorter than one pulse, no extra delay.
    #[test]
    fn sub_pulse_context() {
        test_pulse_delay_over(4, 0, 2);
    }

    /// Overlap longer than one pulse, no extra delay.
    #[test]
    fn supra_pulse_context() {
        test_pulse_delay_over(4, 0, 6);
    }

    /// Chains two `Delay` nodes and checks that their delays accumulate.
    #[test]
    fn test_two_delays() {
        let pulse = 4usize;
        let delay = 2usize;
        let mut model = PulsedModel::default();
        let stream_dim = model.symbols.sym("S").to_dim();
        let fact_0 = PulsedFact {
            datum_type: u8::datum_type(),
            shape: (&[pulse]).into(),
            stream: Some(StreamInfo { axis: 0, dim: stream_dim, delay: 0 }),
        };
        let axis = fact_0.stream.as_ref().unwrap().axis;
        let source = model.add_source("source", fact_0.clone()).unwrap();
        let delay_1 = model
            .wire_node("delay-1", Delay::new_typed(&(&fact_0).into(), axis, delay, 0), &[source])
            .unwrap()[0];
        let fact_1 = model.outlet_fact(delay_1).unwrap().clone();
        // Distinct node name: the second delay must not collide with "delay-1".
        let delay_2 = model
            .wire_node("delay-2", Delay::new_typed(&(&fact_1).into(), axis, delay, 0), &[delay_1])
            .unwrap();
        model.set_output_outlets(&delay_2).unwrap();

        let plan = SimplePlan::new(model).unwrap();
        let mut state = tract_core::plan::SimpleState::new(plan).unwrap();

        // Both nodes delay by `delay`, so the stream is shifted by their sum.
        let total_delay = 2 * delay;
        for i in 0..5 {
            let input: Vec<u8> = (pulse * i..pulse * (i + 1)).map(|a| a as u8).collect();
            let expect: Vec<u8> = (pulse * i..pulse * (i + 1))
                .map(|pos| pos.saturating_sub(total_delay) as u8)
                .collect();
            let skip = total_delay.saturating_sub(i * pulse).min(pulse);
            let output = state.run(tvec!(tensor1(&input).into())).unwrap();
            assert_eq!(&output[0].as_slice::<u8>().unwrap()[skip..], &expect[skip..]);
        }
    }
}
128        }
129    }
130}