// Dumper registration and `PulsedOp` implementation for the pulse `Delay` operator.
use crate::internal::*;
use tract_pulse_opl::ops::Delay;

/// Registers the dumper for the [`Delay`] operator in `registry`.
///
/// After this call, nodes whose op type is `Delay` are handed to `ser_delay`.
pub fn register(registry: &mut Registry) {
    registry.register_dumper(TypeId::of::<Delay>(), ser_delay)
}

/// Dumps a `Delay` node as a `tract_pulse_delay` invocation.
///
/// The single positional argument is the AST value already mapped to the node's
/// first input; the op's `axis`, `delay` and `overlap` fields are emitted as
/// named numeric arguments.
///
/// # Panics
///
/// Panics if `node`'s op is not a `Delay`, or if the node's first input has no
/// entry in `ast.mapping`. The dumper is registered against `TypeId::of::<Delay>()`
/// in `register`, so the downcast is expected to always succeed.
fn ser_delay(ast: &mut IntoAst, node: &TypedNode) -> TractResult<Option<Arc<RValue>>> {
    let op = node.op().downcast_ref::<Delay>().unwrap();
    let wire = ast.mapping[&node.inputs[0]].clone();
    Ok(Some(invocation(
        "tract_pulse_delay",
        &[wire],
        &[
            ("axis", numeric(op.axis)),
            ("delay", numeric(op.delay)),
            ("overlap", numeric(op.overlap)),
        ],
    )))
}

/// Pulsed-model behaviour of the `Delay` operator.
impl PulsedOp for Delay {
    /// Derives the output fact from the first input fact.
    ///
    /// The output is a copy of `inputs[0]` in which the dimension at `self.axis`
    /// is grown by `self.overlap`, and the accumulated `delay` is increased by
    /// `self.delay + self.overlap`.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        let mut fact = inputs[0].clone();
        fact.shape[self.axis] += self.overlap;
        fact.delay += self.delay + self.overlap;
        Ok(tvec!(fact))
    }

    as_op!();
    pulsed_op_to_typed_op!();
}

/// Unit tests running `Delay` inside a small pulsed model.
#[cfg(test)]
mod test {
    use super::*;

    /// Runs a single `Delay` node over five pulses of `pulse` consecutive `u8`
    /// values and checks each output pulse against the input stream shifted by
    /// `delay + overlap` positions.
    ///
    /// Each output pulse is compared from index `skip` onwards only; `skip` is
    /// the part of the first `delay + overlap` stream positions that falls in
    /// the current pulse (presumably because those samples precede any real
    /// input and are therefore unspecified).
    fn test_pulse_delay_over(pulse: usize, delay: usize, overlap: usize) {
        let mut model = PulsedModel::default();
        let fact1 = PulsedFact {
            datum_type: u8::datum_type(),
            shape: tvec![pulse.to_dim()],
            axis: 0,
            dim: stream_dim(),
            delay: 0,
        };
        let source = model.add_source("source", fact1.clone()).unwrap();
        model
            .wire_node("delay", Delay::new(fact1.axis, &(&fact1).into(), delay, overlap), &[source])
            .unwrap();
        model.auto_outputs().unwrap();

        let plan = SimplePlan::new(model).unwrap();
        let mut state = tract_core::plan::SimpleState::new(plan).unwrap();

        for i in 0..5 {
            // Pulse `i` carries stream positions pulse*i .. pulse*(i+1).
            let input: Vec<u8> = (pulse * i..(pulse * (i + 1))).map(|a| a as u8).collect();
            // The output pulse is `overlap` samples longer than the input pulse.
            let expect: Vec<u8> = (pulse * i..(pulse * (i + 1) + overlap))
                .map(|pos| pos.saturating_sub(delay + overlap) as u8)
                .collect();
            let output = state.run(tvec!(tensor1(&input))).unwrap();
            // Clamp to the output pulse length so the slice below cannot go out of range.
            let skip = (delay + overlap).saturating_sub(i * pulse).min(pulse + overlap);
            assert_eq!(&output[0].as_slice::<u8>().unwrap()[skip..], &expect[skip..]);
        }
    }

    /// Delay shorter than one pulse, no overlap.
    #[test]
    fn sub_pulse() {
        test_pulse_delay_over(4, 1, 0);
    }

    /// Delay longer than one pulse, no overlap.
    #[test]
    fn supra_pulse() {
        test_pulse_delay_over(4, 5, 0);
    }

    /// No delay, overlap shorter than one pulse.
    #[test]
    fn sub_pulse_context() {
        test_pulse_delay_over(4, 0, 2);
    }

    /// No delay, overlap longer than one pulse.
    #[test]
    fn supra_pulse_context() {
        test_pulse_delay_over(4, 0, 6);
    }

    /// Chains two `Delay` nodes of 2 samples each and checks that the combined
    /// output is the input stream shifted by 4 positions.
    #[test]
    fn test_two_delays() {
        let pulse = 4usize;
        let mut model = PulsedModel::default();
        let fact_0 = PulsedFact {
            datum_type: u8::datum_type(),
            shape: tvec![pulse.to_dim()],
            axis: 0,
            dim: stream_dim(),
            delay: 0,
        };
        let source = model.add_source("source", fact_0.clone()).unwrap();
        let delay_1 = model
            .wire_node("delay-1", Delay::new(fact_0.axis, &(&fact_0).into(), 2, 0), &[source])
            .unwrap()[0];
        let fact_1 = model.outlet_fact(delay_1).unwrap().clone();
        // Each node gets its own name; reusing "delay-1" here was a copy-paste slip.
        let delay_2 = model
            .wire_node("delay-2", Delay::new(fact_1.axis, &(&fact_1).into(), 2, 0), &[delay_1])
            .unwrap();
        model.set_output_outlets(&delay_2).unwrap();

        let plan = SimplePlan::new(model).unwrap();
        let mut state = tract_core::plan::SimpleState::new(plan).unwrap();

        for i in 0..5 {
            let input: Vec<u8> = (pulse * i..(pulse * (i + 1))).map(|a| a as u8).collect();
            let expect: Vec<u8> =
                (pulse * i..(pulse * (i + 1))).map(|pos| pos.saturating_sub(4) as u8).collect();
            // Only compare from index `skip` on: the first 4 stream positions are excluded.
            let skip = 4usize.saturating_sub(i * pulse).min(pulse);
            let output = state.run(tvec!(tensor1(&input))).unwrap();
            assert_eq!(&output[0].as_slice::<u8>().unwrap()[skip..], &expect[skip..]);
        }
    }
}