tract_pulse_opl/
deconv_delay.rs

1use std::ops::AddAssign;
2
3use tract_ndarray::Axis;
4use tract_nnef::internal::*;
5use tract_nnef::tract_core::ops::OpStateFreeze;
6use tract_num_traits::Zero;
7
8#[derive(Debug, Clone, PartialEq, Eq, Hash)]
9pub struct DeconvDelay {
10    pub axis: usize,
11    pub overlap: usize,
12    pub delay: usize,
13    pub stride: usize,
14    pub pulse: TDim,
15    pub deconv_input_dim: TDim,
16    pub deconv_output_dim: TDim,
17}
18
19
20
21impl Op for DeconvDelay {
22    fn name(&self) -> Cow<str> {
23        "DeconvDelay".into()
24    }
25
26    op_as_typed_op!();
27}
28
29impl EvalOp for DeconvDelay {
30    fn is_stateless(&self) -> bool {
31        false
32    }
33
34    fn eval(&self, _inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
35        unreachable!()
36    }
37
38    fn state(
39        &self,
40        _session: &mut SessionState,
41        _node_id: usize,
42    ) -> TractResult<Option<Box<dyn OpState>>> {
43        Ok(Some(Box::new(DeconvDelayState { valid_inputed: -(self.delay as isize), buffer: None })))
44    }
45}
46
47impl TypedOp for DeconvDelay {
48    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
49        let mut fact = inputs[0].clone();
50        let len = fact.shape[self.axis].clone();
51        fact.shape.set(self.axis, len - self.overlap);
52        Ok(tvec!(fact))
53    }
54
55    as_op!();
56}
57
58#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
59pub struct DeconvDelayState {
60    valid_inputed: isize,
61    buffer: Option<Tensor>,
62}
63
64impl OpState for DeconvDelayState {
65    fn eval(
66        &mut self,
67        session: &mut SessionState,
68        op: &dyn Op,
69        inputs: TVec<TValue>,
70    ) -> TractResult<TVec<TValue>> {
71        let op = op.downcast_ref::<DeconvDelay>().context("Wrong op")?;
72        if self.buffer.is_none() {
73            let mut buffer_size: TVec<usize> = inputs[0].shape().into();
74            buffer_size[op.axis] = op.overlap; //+ (op.stride - 1) * (op.pulse - 1);
75            self.buffer = Some(Tensor::zero_dt(inputs[0].datum_type(), &buffer_size)?);
76        }
77        let mut input = inputs[0].clone().into_tensor();
78        dispatch_numbers!(Self::eval_t(input.datum_type())(self, session, op, &mut input))?;
79        let output = input.slice(op.axis, 0, input.shape()[op.axis] - op.overlap)?;
80        Ok(tvec!(output.into_tvalue()))
81    }
82}
83
84impl DeconvDelayState {
85    fn eval_t<T: Datum + AddAssign + Zero>(
86        &mut self,
87        session: &SessionState,
88        op: &DeconvDelay,
89        input: &mut Tensor,
90    ) -> TractResult<()> {
91        let buffer = self.buffer.as_mut().unwrap();
92        let mut buffer = buffer.to_array_view_mut::<T>()?;
93        let mut input = input.to_array_view_mut::<T>()?;
94        let input_pulse = input.shape()[op.axis];
95        let output_pulse = input_pulse - op.overlap;
96        self.valid_inputed += output_pulse as isize;
97        if let Ok(input_dim) = op.deconv_input_dim.eval(&session.resolved_symbols).to_isize() {
98            if self.valid_inputed > input_dim {
99                let to_be_zeroed = ((self.valid_inputed - input_dim) as usize).min(input_pulse);
100                let mut zeroed =
101                    input.slice_axis_mut(Axis(op.axis), (input_pulse - to_be_zeroed..).into());
102                zeroed.fill(T::zero());
103            }
104        }
105        {
106            let mut input_view = input.slice_axis_mut(Axis(op.axis), (0..op.overlap).into());
107            input_view += &buffer;
108        }
109        buffer.assign(&input.slice_axis(Axis(op.axis), (output_pulse..).into()));
110
111        Ok(())
112    }
113}
114
115#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
116struct FrozenDeconvDelayState {
117    valid_inputed: isize,
118    buffer: Option<Arc<Tensor>>,
119}
120
121impl OpStateFreeze for DeconvDelayState {
122    fn freeze(&self) -> Box<dyn FrozenOpState> {
123        Box::new(FrozenDeconvDelayState {
124            valid_inputed: self.valid_inputed,
125            buffer: self.buffer.as_ref().map(|t| t.clone().into_arc_tensor()),
126        })
127    }
128}
129
130impl FrozenOpState for FrozenDeconvDelayState {
131    fn unfreeze(&self) -> Box<dyn OpState> {
132        Box::new(DeconvDelayState {
133            valid_inputed: self.valid_inputed,
134            buffer: self.buffer.as_ref().map(|t| t.clone().into_tensor()),
135        })
136    }
137}