1use tract_nnef::internal::*;
2use tract_nnef::tract_core::ops::OpStateFreeze;
3
4pub fn register(registry: &mut Registry) {
5 registry.register_primitive(
6 "tract_pulse_delay",
7 &[
8 TypeName::Scalar.tensor().named("input"),
9 TypeName::Integer.named("axis"),
10 TypeName::Integer.named("delay"),
11 TypeName::Integer.named("overlap"),
12 ],
13 &[("output", TypeName::Scalar.tensor())],
14 de_delay,
15 );
16}
17
18fn de_delay(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
19 let wire = invocation.named_arg_as(builder, "input")?;
20 let axis = invocation.named_arg_as::<i64>(builder, "axis")? as usize;
21 let delay = invocation.named_arg_as::<i64>(builder, "delay")? as usize;
22 let overlap = invocation.named_arg_as::<i64>(builder, "overlap")? as usize;
23 let input_fact = builder.model.outlet_fact(wire)?;
24 let op = Delay::new_typed(input_fact, axis, delay, overlap);
25 builder.wire(op, &[wire])
26}
27
28#[derive(Debug, Clone)]
29pub struct DelayState {
30 pub buffer: Option<Tensor>,
31}
32
33impl DelayState {
34 pub unsafe fn apply_delay_unchecked(
42 &mut self,
43 op: &Delay,
44 input: &Tensor,
45 output: &mut Tensor,
46 ) {
47 unsafe {
48 let buffered = op.delay + op.overlap;
49 let input_pulse = input.shape()[op.axis];
50 let output_pulse = input_pulse + op.overlap;
51 let buffer = self.buffer.as_mut().unwrap();
52 if op.delay < input_pulse {
53 let from_input = input_pulse - op.delay;
54 let from_buffer = output_pulse - from_input;
55 output.assign_slice_unchecked(..from_buffer, buffer, ..from_buffer, op.axis);
56 output.assign_slice_unchecked(from_buffer.., input, ..from_input, op.axis);
57 } else {
58 output.assign_slice_unchecked(.., buffer, ..output_pulse, op.axis);
59 };
60 if buffered < input_pulse {
62 buffer.assign_slice_unchecked(.., input, (input_pulse - buffered).., op.axis);
63 } else {
64 let stride = buffer.shape().iter().skip(op.axis + 1).product::<usize>()
65 * input.datum_type().size_of()
66 * input_pulse;
67 std::slice::from_raw_parts_mut(
68 buffer.as_ptr_mut_unchecked::<u8>(),
69 buffer.len() * input.datum_type().size_of(),
70 )
71 .rotate_left(stride);
72 buffer.assign_slice_unchecked((buffered - input_pulse).., input, .., op.axis);
73 }
74 }
75 }
76}
77
78impl OpState for DelayState {
79 fn eval(
80 &mut self,
81 _state: &mut SessionState,
82 op: &dyn Op,
83 inputs: TVec<TValue>,
84 ) -> TractResult<TVec<TValue>> {
85 let input = args_1!(inputs);
86 let op = op.downcast_ref::<Delay>().ok_or_else(|| format_err!("Wrong Op type"))?;
87 let buffered = op.delay + op.overlap;
88 let input_pulse = input.shape()[op.axis];
89 let output_pulse = input_pulse + op.overlap;
90 let mut output_shape: TVec<usize> = input.shape().into();
91 output_shape[op.axis] = output_pulse;
92 unsafe {
94 if self.buffer.is_none() {
95 let mut shape = input.shape().to_owned();
96 shape[op.axis] = buffered;
97 self.buffer = Some(Tensor::uninitialized_dt(input.datum_type(), &shape)?);
98 };
99 let mut output = Tensor::uninitialized_dt(input.datum_type(), &output_shape)?;
100 self.apply_delay_unchecked(op, &input, &mut output);
101 Ok(tvec!(output.into()))
102 }
103 }
104}
105
106#[derive(Clone, Debug, PartialEq, Eq, Hash)]
107pub struct Delay {
108 pub buffer_shape: TVec<TDim>,
109 pub axis: usize,
110 pub delay: usize,
111 pub overlap: usize,
112}
113
114impl Delay {
115 pub fn new_typed(input_fact: &TypedFact, axis: usize, delay: usize, overlap: usize) -> Delay {
116 let mut buffer_shape: TVec<TDim> = input_fact.shape.to_tvec();
117 buffer_shape[axis] = (delay + overlap).to_dim();
118 Delay { buffer_shape, axis, delay, overlap }
119 }
120}
121
122impl Op for Delay {
123 fn name(&self) -> StaticName {
124 "Delay".into()
125 }
126
127 fn info(&self) -> TractResult<Vec<String>> {
128 Ok(vec![
129 format!("axis: {} delay: {} overlap: {}", self.axis, self.delay, self.overlap),
130 format!("buffer: {:?}", self.buffer_shape),
131 ])
132 }
133
134 impl_op_same_as!();
135 op_as_typed_op!();
136}
137
138impl EvalOp for Delay {
139 fn is_stateless(&self) -> bool {
140 false
141 }
142
143 fn state(
144 &self,
145 _session: &mut SessionState,
146 _node_id: usize,
147 ) -> TractResult<Option<Box<dyn OpState>>> {
148 Ok(Some(Box::new(DelayState { buffer: None })))
149 }
150}
151
152impl TypedOp for Delay {
153 as_op!();
154
155 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
156 let mut fact = inputs[0].clone();
157 fact.shape.set(self.axis, fact.shape[self.axis].clone() + self.overlap.to_dim());
158 Ok(tvec!(fact))
159 }
160
161 fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
162 Ok(tvec!((Cost::Buffer(inputs[0].datum_type), self.buffer_shape.iter().product())))
163 }
164
165 fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
166 if self.axis != 0 {
167 Ok(tvec!((InOut::In(0), AxisOp::Move(self.axis, 0))))
168 } else {
169 Ok(tvec!())
170 }
171 }
172
173 fn change_axes(
174 &self,
175 model: &TypedModel,
176 node: &TypedNode,
177 _io: InOut,
178 change: &AxisOp,
179 ) -> TractResult<Option<AxisChangeConsequence>> {
180 if let Some(axis) = change.transform_axis(self.axis) {
181 if axis != self.axis {
182 Ok(Some(AxisChangeConsequence::new(
183 model,
184 node,
185 Some(Box::new(Self { axis, ..self.clone() }) as _),
186 change,
187 )))
188 } else {
189 Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
190 }
191 } else {
192 Ok(None)
193 }
194 }
195}
196
197#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
198struct FrozenDelayState {
199 buffer: Option<Arc<Tensor>>,
200}
201
202impl OpStateFreeze for DelayState {
203 fn freeze(&self) -> Box<dyn FrozenOpState> {
204 Box::new(FrozenDelayState {
205 buffer: self.buffer.as_ref().map(|t| t.clone().into_arc_tensor()),
206 })
207 }
208}
209
210impl FrozenOpState for FrozenDelayState {
211 fn unfreeze(&self) -> Box<dyn OpState> {
212 Box::new(DelayState { buffer: self.buffer.as_ref().map(|t| t.clone().into_tensor()) })
213 }
214}