1use tract_nnef::internal::*;
2use tract_nnef::tract_core::ops::OpStateFreeze;
3
4pub fn register(registry: &mut Registry) {
5 registry.register_primitive(
6 "tract_pulse_delay",
7 &[
8 TypeName::Scalar.tensor().named("input"),
9 TypeName::Integer.named("axis"),
10 TypeName::Integer.named("delay"),
11 TypeName::Integer.named("overlap"),
12 ],
13 &[("output", TypeName::Scalar.tensor())],
14 de_delay,
15 );
16}
17
18fn de_delay(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
19 let wire = invocation.named_arg_as(builder, "input")?;
20 let axis = invocation.named_arg_as::<i64>(builder, "axis")? as usize;
21 let delay = invocation.named_arg_as::<i64>(builder, "delay")? as usize;
22 let overlap = invocation.named_arg_as::<i64>(builder, "overlap")? as usize;
23 let input_fact = builder.model.outlet_fact(wire)?;
24 let op = Delay::new_typed(input_fact, axis, delay, overlap);
25 builder.wire(op, &[wire])
26}
27
28#[derive(Debug, Clone)]
29pub struct DelayState {
30 pub buffer: Option<Tensor>,
31}
32
33impl DelayState {
34 pub unsafe fn apply_delay_unchecked(
42 &mut self,
43 op: &Delay,
44 input: &Tensor,
45 output: &mut Tensor,
46 ) {
47 let buffered = op.delay + op.overlap;
48 let input_pulse = input.shape()[op.axis];
49 let output_pulse = input_pulse + op.overlap;
50 let buffer = self.buffer.as_mut().unwrap();
51 if op.delay < input_pulse {
52 let from_input = input_pulse - op.delay;
53 let from_buffer = output_pulse - from_input;
54 output.assign_slice_unchecked(..from_buffer, buffer, ..from_buffer, op.axis);
55 output.assign_slice_unchecked(from_buffer.., input, ..from_input, op.axis);
56 } else {
57 output.assign_slice_unchecked(.., buffer, ..output_pulse, op.axis);
58 };
59 if buffered < input_pulse {
61 buffer.assign_slice_unchecked(.., input, (input_pulse - buffered).., op.axis);
62 } else {
63 let stride = buffer.shape().iter().skip(op.axis + 1).product::<usize>()
64 * input.datum_type().size_of()
65 * input_pulse;
66 std::slice::from_raw_parts_mut(
67 buffer.as_ptr_mut_unchecked::<u8>(),
68 buffer.len() * input.datum_type().size_of(),
69 )
70 .rotate_left(stride);
71 buffer.assign_slice_unchecked((buffered - input_pulse).., input, .., op.axis);
72 }
73 }
74}
75
76impl OpState for DelayState {
77 fn eval(
78 &mut self,
79 _state: &mut SessionState,
80 op: &dyn Op,
81 inputs: TVec<TValue>,
82 ) -> TractResult<TVec<TValue>> {
83 let input = args_1!(inputs);
84 let op = op.downcast_ref::<Delay>().ok_or_else(|| format_err!("Wrong Op type"))?;
85 let buffered = op.delay + op.overlap;
86 let input_pulse = input.shape()[op.axis];
87 let output_pulse = input_pulse + op.overlap;
88 let mut output_shape: TVec<usize> = input.shape().into();
89 output_shape[op.axis] = output_pulse;
90 unsafe {
92 if self.buffer.is_none() {
93 let mut shape = input.shape().to_owned();
94 shape[op.axis] = buffered;
95 self.buffer = Some(Tensor::uninitialized_dt(input.datum_type(), &shape)?);
96 };
97 let mut output = Tensor::uninitialized_dt(input.datum_type(), &output_shape)?;
98 self.apply_delay_unchecked(op, &input, &mut output);
99 Ok(tvec!(output.into()))
100 }
101 }
102}
103
104#[derive(Clone, Debug, PartialEq, Eq, Hash)]
105pub struct Delay {
106 pub buffer_shape: TVec<TDim>,
107 pub axis: usize,
108 pub delay: usize,
109 pub overlap: usize,
110}
111
112impl Delay {
113 pub fn new_typed(input_fact: &TypedFact, axis: usize, delay: usize, overlap: usize) -> Delay {
114 let mut buffer_shape: TVec<TDim> = input_fact.shape.to_tvec();
115 buffer_shape[axis] = (delay + overlap).to_dim();
116 Delay { buffer_shape, axis, delay, overlap }
117 }
118}
119
120impl Op for Delay {
121 fn name(&self) -> Cow<str> {
122 "Delay".into()
123 }
124
125 fn info(&self) -> TractResult<Vec<String>> {
126 Ok(vec![
127 format!("axis: {} delay: {} overlap: {}", self.axis, self.delay, self.overlap),
128 format!("buffer: {:?}", self.buffer_shape),
129 ])
130 }
131
132 impl_op_same_as!();
133 op_as_typed_op!();
134}
135
136impl EvalOp for Delay {
137 fn is_stateless(&self) -> bool {
138 false
139 }
140
141 fn state(
142 &self,
143 _session: &mut SessionState,
144 _node_id: usize,
145 ) -> TractResult<Option<Box<dyn OpState>>> {
146 Ok(Some(Box::new(DelayState { buffer: None })))
147 }
148}
149
150impl TypedOp for Delay {
151 as_op!();
152
153 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
154 let mut fact = inputs[0].clone();
155 fact.shape.set(self.axis, fact.shape[self.axis].clone() + self.overlap.to_dim());
156 Ok(tvec!(fact))
157 }
158
159 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
160 Ok(tvec!((Cost::Buffer(inputs[0].datum_type), self.buffer_shape.iter().product())))
161 }
162
163 fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
164 if self.axis != 0 {
165 Ok(tvec!((InOut::In(0), AxisOp::Move(self.axis, 0))))
166 } else {
167 Ok(tvec!())
168 }
169 }
170
171 fn change_axes(
172 &self,
173 model: &TypedModel,
174 node: &TypedNode,
175 _io: InOut,
176 change: &AxisOp,
177 ) -> TractResult<Option<AxisChangeConsequence>> {
178 if let Some(axis) = change.transform_axis(self.axis) {
179 if axis != self.axis {
180 Ok(Some(AxisChangeConsequence::new(
181 model,
182 node,
183 Some(Box::new(Self { axis, ..self.clone() }) as _),
184 change,
185 )))
186 } else {
187 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
188 }
189 } else {
190 Ok(None)
191 }
192 }
193}
194
195#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
196struct FrozenDelayState {
197 buffer: Option<Arc<Tensor>>,
198}
199
200impl OpStateFreeze for DelayState {
201 fn freeze(&self) -> Box<dyn FrozenOpState> {
202 Box::new(FrozenDelayState { buffer: self.buffer.as_ref().map(|t| t.clone().into_arc_tensor()) })
203 }
204}
205
206impl FrozenOpState for FrozenDelayState {
207 fn unfreeze(&self) -> Box<dyn OpState> {
208 Box::new(DelayState { buffer: self.buffer.as_ref().map(|t| t.clone().into_tensor()) })
209 }
210}