1use crate::internal::*;
2use tract_pulse_opl::ops::Delay;
3
4pub fn register(registry: &mut Registry) {
5 registry.register_dumper(ser_delay)
6}
7
8fn ser_delay(ast: &mut IntoAst, node: &TypedNode, op: &Delay) -> TractResult<Option<Arc<RValue>>> {
9 let wire = ast.mapping[&node.inputs[0]].clone();
10 Ok(Some(invocation(
11 "tract_pulse_delay",
12 &[wire],
13 &[
14 ("axis", numeric(op.axis)),
15 ("delay", numeric(op.delay)),
16 ("overlap", numeric(op.overlap)),
17 ],
18 )))
19}
20
21impl PulsedOp for Delay {
22 fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
23 ensure!(inputs.len() == 1);
24 let mut fact = inputs[0].clone();
25 let stream = fact.stream.as_mut().unwrap();
26 fact.shape.set(self.axis, fact.shape[self.axis].clone() + self.overlap);
27 stream.delay += self.delay + self.overlap;
28 Ok(tvec!(fact))
29 }
30
31 as_op!();
32 pulsed_op_to_typed_op!();
33}
34
#[cfg(test)]
mod test {
    use crate::fact::StreamInfo;

    use super::*;

    /// Streams five pulses of `pulse` consecutive `u8` values through a single `Delay`.
    ///
    /// For each pulse, the output (length `pulse + overlap`) is compared with the input
    /// shifted by `delay + overlap`. Leading positions that fall before the start of the
    /// delayed stream are skipped, and their content is not checked.
    fn test_pulse_delay_over(pulse: usize, delay: usize, overlap: usize) {
        let mut model = PulsedModel::default();
        let stream_dim = model.symbols.sym("S").to_dim();
        let fact1 = PulsedFact {
            datum_type: u8::datum_type(),
            shape: (&[pulse]).into(),
            stream: Some(StreamInfo { axis: 0, dim: stream_dim, delay: 0 }),
        };
        let source = model.add_source("source", fact1.clone()).unwrap();
        model
            .wire_node(
                "delay",
                Delay::new_typed(&(&fact1).into(), fact1.stream.unwrap().axis, delay, overlap),
                &[source],
            )
            .unwrap();
        model.auto_outputs().unwrap();

        let plan = SimplePlan::new(model).unwrap();
        let mut state = tract_core::plan::SimpleState::new(plan).unwrap();

        for i in 0..5 {
            let input: Vec<u8> = (pulse * i..(pulse * (i + 1))).map(|a| a as u8).collect();
            // `t` is the absolute output time step; the value seen there is the input
            // from `delay + overlap` steps earlier.
            let expect: Vec<u8> = (pulse * i..(pulse * (i + 1) + overlap))
                .map(|t| t.saturating_sub(delay + overlap) as u8)
                .collect();
            let output = state.run(tvec!(tensor1(&input).into())).unwrap();
            let skip = (delay + overlap).saturating_sub(i * pulse).min(pulse + overlap);
            assert_eq!(&output[0].as_slice::<u8>().unwrap()[skip..], &expect[skip..]);
        }
    }

    #[test]
    fn sub_pulse() {
        test_pulse_delay_over(4, 1, 0);
    }

    #[test]
    fn supra_pulse() {
        test_pulse_delay_over(4, 5, 0);
    }

    #[test]
    fn sub_pulse_context() {
        test_pulse_delay_over(4, 0, 2);
    }

    #[test]
    fn supra_pulse_context() {
        test_pulse_delay_over(4, 0, 6);
    }

    /// Chains two `Delay`s of 2 each and checks the combined output is delayed by 4.
    #[test]
    fn test_two_delays() {
        let pulse = 4usize;
        let mut model = PulsedModel::default();
        let stream_dim = model.symbols.sym("S").to_dim();
        let fact_0 = PulsedFact {
            datum_type: u8::datum_type(),
            shape: (&[pulse]).into(),
            stream: Some(StreamInfo { axis: 0, dim: stream_dim, delay: 0 }),
        };
        let stream = fact_0.stream.as_ref().unwrap();
        let source = model.add_source("source", fact_0.clone()).unwrap();
        let delay_1 = model
            .wire_node("delay-1", Delay::new_typed(&(&fact_0).into(), stream.axis, 2, 0), &[source])
            .unwrap()[0];
        let fact_1 = model.outlet_fact(delay_1).unwrap().clone();
        // Distinct name for the second node. It previously reused "delay-1".
        let delay_2 = model
            .wire_node(
                "delay-2",
                Delay::new_typed(&(&fact_1).into(), stream.axis, 2, 0),
                &[delay_1],
            )
            .unwrap();
        model.set_output_outlets(&delay_2).unwrap();

        let plan = SimplePlan::new(model).unwrap();
        let mut state = tract_core::plan::SimpleState::new(plan).unwrap();

        for i in 0..5 {
            let input: Vec<u8> = (pulse * i..(pulse * (i + 1))).map(|a| a as u8).collect();
            let expect: Vec<u8> =
                (pulse * i..(pulse * (i + 1))).map(|t| t.saturating_sub(4) as u8).collect();
            let skip = 4usize.saturating_sub(i * pulse).min(pulse);
            let output = state.run(tvec!(tensor1(&input).into())).unwrap();
            assert_eq!(&output[0].as_slice::<u8>().unwrap()[skip..], &expect[skip..]);
        }
    }
}