// tract_pulse/ops/cnn/deconv.rs

1use crate::internal::*;
2use tract_core::num_traits::Zero;
3use tract_core::ops::cnn::Deconv;
4use tract_core::ops::cnn::PaddingSpec;
5use tract_pulse_opl::ops::DeconvDelay;
6use tract_pulse_opl::ops::PulseMask;
7
8register_all!(Deconv: pulsify);
9
10fn pulsify(
11    op: &Deconv,
12    source: &TypedModel,
13    node: &TypedNode,
14    target: &mut PulsedModel,
15    mapping: &HashMap<OutletId, OutletId>,
16    _symbol: &Symbol,
17    _pulse: &TDim,
18) -> TractResult<Option<TVec<OutletId>>> {
19    let fact = target.outlet_fact(mapping[&node.inputs[0]])?.clone();
20    let pulse = fact.pulse().unwrap();
21    let stream = fact.stream.as_ref().unwrap();
22    let c_axis = op.pool_spec.data_format.shape(&fact.shape)?.c_axis();
23    if c_axis == stream.axis {
24        bail!("Pulsification on C axis is not supported");
25    }
26    if op
27        .axes_mapping(&source.node_input_facts(node.id)?, &source.node_output_facts(node.id)?)?
28        .axis((InOut::In(0), stream.axis))?
29        .outputs[0]
30        .len()
31        == 1
32    {
33        // general case for invariants will manage
34        return Ok(None);
35    }
36    let geo_axis = stream.axis - op.pool_spec.data_format.h_axis();
37    let stride = op.pool_spec.stride(geo_axis);
38    let mut pulse_op = op.clone();
39    pulse_op.adjustments[geo_axis] = stride - 1;
40    pulse_op.pool_spec.padding = PaddingSpec::Valid;
41    let mut wire = tvec![mapping[&node.inputs[0]]];
42    let mask = PulseMask {
43        axis: stream.axis,
44        begin: stream.delay,
45        end: stream.dim.clone() + stream.delay,
46        value: Tensor::zero_scalar_dt(fact.datum_type)?,
47    };
48    wire = target.wire_node(format!("{}.mask", node.name), mask, &wire)?;
49    wire.push(mapping[&node.inputs[1]]);
50    wire.push(mapping[&node.inputs[2]]);
51    wire = target.wire_node(format!("{}.deconv", node.name), pulse_op, &wire)?;
52    let overlap = overlap(stream.axis, op);
53    let deconv_input_dim = (stream.dim.clone() - 1) * stride + 1;
54    let output_shape = tract_core::ops::cnn::deconv::output_shape(
55        &op.pool_spec,
56        &fact.streaming_shape(),
57        &op.adjustments,
58    )?;
59    let kernel_spatial_shape = &op.pool_spec.kernel_shape;
60    let shape = op.pool_spec.data_format.shape(fact.streaming_shape())?;
61    let paddings = op.pool_spec.padding.compute_for_deconv(
62        shape.hw_dims(),
63        kernel_spatial_shape,
64        &op.pool_spec.dilations(),
65        &op.pool_spec.strides(),
66        &op.adjustments,
67    )?;
68    wire = target.wire_node(
69        &node.name,
70        DeconvDelay {
71            axis: stream.axis,
72            overlap,
73            delay: paddings[geo_axis].pad_before.to_usize()? + stream.delay,
74            deconv_input_dim,
75            stride,
76            pulse: pulse.to_owned(),
77            deconv_output_dim: output_shape[stream.axis].clone(),
78        },
79        &wire,
80    )?;
81
82    for (geo_axis, padding) in paddings.iter().enumerate() {
83        if !padding.pad_before.is_zero() || !padding.pad_after.is_zero() {
84            let axis = geo_axis + shape.h_axis();
85            if axis == stream.axis {
86                continue;
87            };
88            let op = crate::model::PulseWrappingOp(Box::new(tract_core::ops::array::Slice::new(
89                axis,
90                padding.pad_before.clone(),
91                padding.deconvoluted.clone() + &padding.pad_before,
92            )));
93            wire = target.wire_node(format!("{}.padding.{}", node.name, geo_axis), op, &wire)?;
94        }
95    }
96
97    Ok(Some(wire))
98}
99
100fn overlap(pulse_axis: usize, op: &Deconv) -> usize {
101    let geo_axis = pulse_axis - op.pool_spec.data_format.h_axis();
102    (op.pool_spec.kernel_shape[geo_axis] - 1) * op.pool_spec.dilation(geo_axis)
103}
104
105impl PulsedOp for Deconv {
106    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
107        let mut fact = inputs[0].clone();
108        let stream = fact.stream.as_mut().unwrap();
109        let overlap = overlap(stream.axis, self);
110        let geo_axis = stream.axis - self.pool_spec.data_format.h_axis();
111        let stride = self.pool_spec.stride(geo_axis);
112        let mut output_shape = tract_core::ops::cnn::deconv::output_shape(
113            &self.pool_spec,
114            &inputs[0].streaming_shape(),
115            &self.adjustments,
116        )?;
117        stream.dim = output_shape[stream.axis].clone();
118        let pulse_len = fact.shape[stream.axis].clone() * stride;
119        output_shape[stream.axis] = pulse_len + overlap;
120        let c_axis = self.pool_spec.data_format.shape(&output_shape)?.c_axis();
121        output_shape[c_axis] = self.pool_spec.output_channels.into();
122        fact.shape = output_shape.into();
123        Ok(tvec!(fact))
124    }
125
126    as_op!();
127    pulsed_op_to_typed_op!();
128}
129
130impl PulsedOp for DeconvDelay {
131    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
132        let mut fact = inputs[0].clone();
133        let stream = fact.stream.as_mut().unwrap();
134        stream.dim = self.deconv_output_dim.clone();
135        let pulse_len = fact.shape[stream.axis].clone();
136        fact.shape.set(stream.axis, pulse_len - self.overlap);
137        stream.delay = self.delay;
138        Ok(tvec!(fact))
139    }
140
141    as_op!();
142    pulsed_op_to_typed_op!();
143}