tract_pulse_opl/
delay.rs

1use tract_nnef::internal::*;
2use tract_nnef::tract_core::ops::OpStateFreeze;
3
4pub fn register(registry: &mut Registry) {
5    registry.register_primitive(
6        "tract_pulse_delay",
7        &[
8            TypeName::Scalar.tensor().named("input"),
9            TypeName::Integer.named("axis"),
10            TypeName::Integer.named("delay"),
11            TypeName::Integer.named("overlap"),
12        ],
13        &[("output", TypeName::Scalar.tensor())],
14        de_delay,
15    );
16}
17
18fn de_delay(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
19    let wire = invocation.named_arg_as(builder, "input")?;
20    let axis = invocation.named_arg_as::<i64>(builder, "axis")? as usize;
21    let delay = invocation.named_arg_as::<i64>(builder, "delay")? as usize;
22    let overlap = invocation.named_arg_as::<i64>(builder, "overlap")? as usize;
23    let input_fact = builder.model.outlet_fact(wire)?;
24    let op = Delay::new_typed(input_fact, axis, delay, overlap);
25    builder.wire(op, &[wire])
26}
27
28#[derive(Debug, Clone)]
29pub struct DelayState {
30    pub buffer: Option<Tensor>,
31}
32
33impl DelayState {
34    /// Apply delay op on input and store the result in the output tensor
35    /// This method doesn't use allocation.
36    ///
37    /// # Safety
38    ///
39    /// Input and Ouput tensors shape must be compatible with this operator, otherwise it could lead
40    /// to an undefined behaviour.
41    pub unsafe fn apply_delay_unchecked(
42        &mut self,
43        op: &Delay,
44        input: &Tensor,
45        output: &mut Tensor,
46    ) {
47        let buffered = op.delay + op.overlap;
48        let input_pulse = input.shape()[op.axis];
49        let output_pulse = input_pulse + op.overlap;
50        let buffer = self.buffer.as_mut().unwrap();
51        if op.delay < input_pulse {
52            let from_input = input_pulse - op.delay;
53            let from_buffer = output_pulse - from_input;
54            output.assign_slice_unchecked(..from_buffer, buffer, ..from_buffer, op.axis);
55            output.assign_slice_unchecked(from_buffer.., input, ..from_input, op.axis);
56        } else {
57            output.assign_slice_unchecked(.., buffer, ..output_pulse, op.axis);
58        };
59        // maintain buffer
60        if buffered < input_pulse {
61            buffer.assign_slice_unchecked(.., input, (input_pulse - buffered).., op.axis);
62        } else {
63            let stride = buffer.shape().iter().skip(op.axis + 1).product::<usize>()
64                * input.datum_type().size_of()
65                * input_pulse;
66            std::slice::from_raw_parts_mut(
67                buffer.as_ptr_mut_unchecked::<u8>(),
68                buffer.len() * input.datum_type().size_of(),
69            )
70            .rotate_left(stride);
71            buffer.assign_slice_unchecked((buffered - input_pulse).., input, .., op.axis);
72        }
73    }
74}
75
76impl OpState for DelayState {
77    fn eval(
78        &mut self,
79        _state: &mut SessionState,
80        op: &dyn Op,
81        inputs: TVec<TValue>,
82    ) -> TractResult<TVec<TValue>> {
83        let input = args_1!(inputs);
84        let op = op.downcast_ref::<Delay>().ok_or_else(|| format_err!("Wrong Op type"))?;
85        let buffered = op.delay + op.overlap;
86        let input_pulse = input.shape()[op.axis];
87        let output_pulse = input_pulse + op.overlap;
88        let mut output_shape: TVec<usize> = input.shape().into();
89        output_shape[op.axis] = output_pulse;
90        // build output
91        unsafe {
92            if self.buffer.is_none() {
93                let mut shape = input.shape().to_owned();
94                shape[op.axis] = buffered;
95                self.buffer = Some(Tensor::uninitialized_dt(input.datum_type(), &shape)?);
96            };
97            let mut output = Tensor::uninitialized_dt(input.datum_type(), &output_shape)?;
98            self.apply_delay_unchecked(op, &input, &mut output);
99            Ok(tvec!(output.into()))
100        }
101    }
102}
103
104#[derive(Clone, Debug, PartialEq, Eq, Hash)]
105pub struct Delay {
106    pub buffer_shape: TVec<TDim>,
107    pub axis: usize,
108    pub delay: usize,
109    pub overlap: usize,
110}
111
112impl Delay {
113    pub fn new_typed(input_fact: &TypedFact, axis: usize, delay: usize, overlap: usize) -> Delay {
114        let mut buffer_shape: TVec<TDim> = input_fact.shape.to_tvec();
115        buffer_shape[axis] = (delay + overlap).to_dim();
116        Delay { buffer_shape, axis, delay, overlap }
117    }
118}
119
120impl Op for Delay {
121    fn name(&self) -> Cow<str> {
122        "Delay".into()
123    }
124
125    fn info(&self) -> TractResult<Vec<String>> {
126        Ok(vec![
127            format!("axis: {} delay: {} overlap: {}", self.axis, self.delay, self.overlap),
128            format!("buffer: {:?}", self.buffer_shape),
129        ])
130    }
131
132    impl_op_same_as!();
133    op_as_typed_op!();
134}
135
136impl EvalOp for Delay {
137    fn is_stateless(&self) -> bool {
138        false
139    }
140
141    fn state(
142        &self,
143        _session: &mut SessionState,
144        _node_id: usize,
145    ) -> TractResult<Option<Box<dyn OpState>>> {
146        Ok(Some(Box::new(DelayState { buffer: None })))
147    }
148}
149
150impl TypedOp for Delay {
151    as_op!();
152
153    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
154        let mut fact = inputs[0].clone();
155        fact.shape.set(self.axis, fact.shape[self.axis].clone() + self.overlap.to_dim());
156        Ok(tvec!(fact))
157    }
158
159    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
160        Ok(tvec!((Cost::Buffer(inputs[0].datum_type), self.buffer_shape.iter().product())))
161    }
162
163    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
164        if self.axis != 0 {
165            Ok(tvec!((InOut::In(0), AxisOp::Move(self.axis, 0))))
166        } else {
167            Ok(tvec!())
168        }
169    }
170
171    fn change_axes(
172        &self,
173        model: &TypedModel,
174        node: &TypedNode,
175        _io: InOut,
176        change: &AxisOp,
177    ) -> TractResult<Option<AxisChangeConsequence>> {
178        if let Some(axis) = change.transform_axis(self.axis) {
179            if axis != self.axis {
180                Ok(Some(AxisChangeConsequence::new(
181                    model,
182                    node,
183                    Some(Box::new(Self { axis, ..self.clone() }) as _),
184                    change,
185                )))
186            } else {
187                Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
188            }
189        } else {
190            Ok(None)
191        }
192    }
193}
194
195#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
196struct FrozenDelayState {
197    buffer: Option<Arc<Tensor>>,
198}
199
200impl OpStateFreeze for DelayState {
201    fn freeze(&self) -> Box<dyn FrozenOpState> {
202        Box::new(FrozenDelayState { buffer: self.buffer.as_ref().map(|t| t.clone().into_arc_tensor()) })
203    }
204}
205
206impl FrozenOpState for FrozenDelayState {
207    fn unfreeze(&self) -> Box<dyn OpState> {
208        Box::new(DelayState { buffer: self.buffer.as_ref().map(|t| t.clone().into_tensor()) })
209    }
210}