1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
use std::ops::AddAssign;

use tract_ndarray::Axis;
use tract_nnef::internal::*;
use tract_nnef::tract_core::ops::OpStateFreeze;
use tract_num_traits::Zero;

/// Parameters of the stateful `DeconvDelay` operator.
///
/// As far as the code visible here shows, every evaluation adds a carried-over buffer of
/// `overlap` positions onto the start of the input along `axis`, stores the last `overlap`
/// positions of the input for the next call, and emits the input without those last positions.
///
/// NOTE(review): the name suggests this post-processes the output of a pulsed deconvolution;
/// confirm against the code that instantiates it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DeconvDelay {
    /// Index of the axis along which the input is buffered and trimmed.
    pub axis: usize,
    /// Number of positions along `axis` that are carried from one evaluation to the next; the
    /// output is shorter than the input by this amount.
    pub overlap: usize,
    /// A freshly created state starts its `valid_inputed` counter at `-delay`.
    pub delay: usize,
    /// Not read by the code visible here (it only appears in a disabled buffer-size formula).
    /// Presumably the deconvolution stride — TODO confirm.
    pub stride: usize,
    /// Not read by the code visible here (it only appears in a disabled buffer-size formula).
    /// Presumably the pulse length — TODO confirm.
    pub pulse: TDim,
    /// Evaluated against the session's resolved symbols on every call; when it resolves to an
    /// integer and the running counter exceeds it, the surplus tail of the input is zeroed.
    pub deconv_input_dim: TDim,
    /// Not read by the code visible here. Presumably the full deconvolution output length —
    /// TODO confirm.
    pub deconv_output_dim: TDim,
}

// Macro defined outside this file; NOTE(review): presumably exposes the derived `Hash` impl
// through tract's dynamic-hash machinery — confirm.
impl_dyn_hash!(DeconvDelay);

/// Generic operator plumbing for `DeconvDelay`.
impl Op for DeconvDelay {
    /// Returns the fixed operator name, `"DeconvDelay"`.
    fn name(&self) -> Cow<str> {
        // The name is a static literal, so a borrowed `Cow` is built without allocating.
        Cow::Borrowed("DeconvDelay")
    }

    // Macro-generated remainder of the `Op` impl (typed-op accessors).
    op_as_typed_op!();
}

/// Evaluation entry points: `DeconvDelay` only runs through a per-session state object.
impl EvalOp for DeconvDelay {
    /// Always `false`: evaluation goes through the state returned by [`EvalOp::state`].
    fn is_stateless(&self) -> bool {
        false
    }

    /// Stateless evaluation is not supported; calling this panics via `unreachable!()`.
    fn eval(&self, _inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        unreachable!()
    }

    /// Creates a fresh [`DeconvDelayState`] with no buffer allocated yet and its counter set to
    /// `-delay`.
    fn state(
        &self,
        _session: &mut SessionState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        // The counter starts negative by `delay` positions.
        let initial_count = -(self.delay as isize);
        let fresh = DeconvDelayState { valid_inputed: initial_count, buffer: None };
        Ok(Some(Box::new(fresh)))
    }
}

/// Type and shape inference for `DeconvDelay`.
impl TypedOp for DeconvDelay {
    /// Returns a single fact cloned from the first input, with the dimension at `axis` reduced
    /// by `overlap`.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let mut output = inputs[0].clone();
        // Compute the shortened length first, then write it back into the cloned shape.
        let shortened = output.shape[self.axis].clone() - self.overlap;
        output.shape.set(self.axis, shortened);
        Ok(tvec!(output))
    }

    // Macro-generated `as_op` accessors.
    as_op!();
}

/// Mutable per-session state of a `DeconvDelay` node.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub struct DeconvDelayState {
    /// Running position counter: initialised to `-delay` when the state is created and advanced
    /// on every evaluation by the input length along the axis minus `overlap`. It is compared
    /// with the resolved `deconv_input_dim` to decide how much of the input tail to zero.
    valid_inputed: isize,
    /// Carry-over between evaluations: `None` until the first call, then a tensor shaped like
    /// the input except for `overlap` positions along the axis. It starts zero-filled and is
    /// overwritten with the tail of each processed input.
    buffer: Option<Tensor>,
}

/// Stateful evaluation of `DeconvDelay`.
impl OpState for DeconvDelayState {
    /// Processes one input tensor and returns it without its last `overlap` positions along
    /// the operator axis.
    ///
    /// The carry-over buffer is allocated (zero-filled) on the first call. The input is then
    /// copied into an owned tensor, updated in place by `eval_t`, and sliced.
    ///
    /// # Errors
    /// Fails with context `"Wrong op"` if `op` is not a `DeconvDelay`, and propagates any error
    /// from buffer allocation, the typed evaluation, or the final slice.
    fn eval(
        &mut self,
        session: &mut SessionState,
        op: &dyn Op,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let op = op.downcast_ref::<DeconvDelay>().context("Wrong op")?;
        let incoming = &inputs[0];

        if self.buffer.is_none() {
            // Same shape as the input, except `overlap` positions along the operator axis.
            // NOTE(review): a disabled variant of this size also added
            // `(op.stride - 1) * (op.pulse - 1)`; it is not applied.
            let mut carry_shape: TVec<usize> = incoming.shape().into();
            carry_shape[op.axis] = op.overlap;
            let zeros = Tensor::zero_dt(incoming.datum_type(), &carry_shape)?;
            self.buffer = Some(zeros);
        }

        // Work on an owned copy so it can be modified in place.
        let mut frame = incoming.clone().into_tensor();
        dispatch_numbers!(Self::eval_t(frame.datum_type())(self, session, op, &mut frame))?;

        // The trailing `overlap` positions now live in the buffer; emit the rest.
        let kept = frame.shape()[op.axis] - op.overlap;
        let output = frame.slice(op.axis, 0, kept)?;
        Ok(tvec!(output.into_tvalue()))
    }
}

impl DeconvDelayState {
    /// Typed in-place update of `input` against the carry-over buffer.
    ///
    /// Steps, in order:
    /// 1. advance `valid_inputed` by the input length along the axis minus `overlap`;
    /// 2. if `deconv_input_dim` resolves to an integer and the counter now exceeds it, zero the
    ///    surplus positions at the end of `input` (at most the whole input);
    /// 3. add the buffer onto the first `overlap` positions of `input`;
    /// 4. copy the positions from `input length - overlap` onward into the buffer.
    ///
    /// # Panics
    /// Panics if the buffer has not been allocated yet (`unwrap` on `None`).
    ///
    /// # Errors
    /// Propagates failures to view the buffer or the input as an array of `T`.
    fn eval_t<T: Datum + AddAssign + Zero>(
        &mut self,
        session: &SessionState,
        op: &DeconvDelay,
        input: &mut Tensor,
    ) -> TractResult<()> {
        let axis = Axis(op.axis);
        let carry = self.buffer.as_mut().unwrap();
        let mut carry = carry.to_array_view_mut::<T>()?;
        let mut frame = input.to_array_view_mut::<T>()?;

        let input_pulse = frame.shape()[op.axis];
        let output_pulse = input_pulse - op.overlap;
        self.valid_inputed += output_pulse as isize;

        // An input dimension that does not resolve to an integer skips the zeroing step.
        match op.deconv_input_dim.eval(&session.resolved_symbols).to_isize() {
            Ok(input_dim) if self.valid_inputed > input_dim => {
                let surplus = (self.valid_inputed - input_dim) as usize;
                let first_zeroed = input_pulse - surplus.min(input_pulse);
                frame.slice_axis_mut(axis, (first_zeroed..).into()).fill(T::zero());
            }
            _ => {}
        }

        {
            // Overlap-add: fold the previous tail into the head of this input.
            let mut head = frame.slice_axis_mut(axis, (0..op.overlap).into());
            head += &carry;
        }

        // Keep this input's tail for the next call.
        let tail = frame.slice_axis(axis, (output_pulse..).into());
        carry.assign(&tail);

        Ok(())
    }
}

/// Frozen counterpart of `DeconvDelayState`: the same two fields, with the buffer held behind
/// an `Arc`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
struct FrozenDeconvDelayState {
    /// Copy of the live state's position counter.
    valid_inputed: isize,
    /// Copy of the live state's carry-over buffer, if it had been allocated.
    buffer: Option<Arc<Tensor>>,
}

/// Snapshotting of the live state.
impl OpStateFreeze for DeconvDelayState {
    /// Builds a `FrozenDeconvDelayState` holding the current counter and a copy of the buffer
    /// (if any) wrapped in an `Arc`.
    fn freeze(&self) -> Box<dyn FrozenOpState> {
        // Clone the optional tensor first, then move the clone into an `Arc`.
        let buffer = self.buffer.clone().map(|tensor| tensor.into_arc_tensor());
        let frozen = FrozenDeconvDelayState { valid_inputed: self.valid_inputed, buffer };
        Box::new(frozen)
    }
}

impl FrozenOpState for FrozenDeconvDelayState {
    fn unfreeze(&self) -> Box<dyn OpState> {
        Box::new(DeconvDelayState {
            valid_inputed: self.valid_inputed,
            buffer: self.buffer.as_ref().map(|t| t.clone().into_tensor()),
        })
    }
}