tract_pulse/ops/cnn/
deconv.rs1use crate::internal::*;
2use tract_core::num_traits::Zero;
3use tract_core::ops::cnn::Deconv;
4use tract_core::ops::cnn::PaddingSpec;
5use tract_pulse_opl::ops::DeconvDelay;
6use tract_pulse_opl::ops::PulseMask;
7
8register_all!(Deconv: pulsify);
9
10fn pulsify(
11 op: &Deconv,
12 source: &TypedModel,
13 node: &TypedNode,
14 target: &mut PulsedModel,
15 mapping: &HashMap<OutletId, OutletId>,
16 _symbol: &Symbol,
17 _pulse: &TDim,
18) -> TractResult<Option<TVec<OutletId>>> {
19 let fact = target.outlet_fact(mapping[&node.inputs[0]])?.clone();
20 let pulse = fact.pulse().unwrap();
21 let stream = fact.stream.as_ref().unwrap();
22 let c_axis = op.pool_spec.data_format.shape(&fact.shape)?.c_axis();
23 if c_axis == stream.axis {
24 bail!("Pulsification on C axis is not supported");
25 }
26 if op
27 .axes_mapping(&source.node_input_facts(node.id)?, &source.node_output_facts(node.id)?)?
28 .axis((InOut::In(0), stream.axis))?
29 .outputs[0]
30 .len()
31 == 1
32 {
33 return Ok(None);
35 }
36 let geo_axis = stream.axis - op.pool_spec.data_format.h_axis();
37 let stride = op.pool_spec.stride(geo_axis);
38 let mut pulse_op = op.clone();
39 pulse_op.adjustments[geo_axis] = stride - 1;
40 pulse_op.pool_spec.padding = PaddingSpec::Valid;
41 let mut wire = tvec![mapping[&node.inputs[0]]];
42 let mask = PulseMask {
43 axis: stream.axis,
44 begin: stream.delay,
45 end: stream.dim.clone() + stream.delay,
46 value: Tensor::zero_scalar_dt(fact.datum_type)?,
47 };
48 wire = target.wire_node(format!("{}.mask", node.name), mask, &wire)?;
49 wire.push(mapping[&node.inputs[1]]);
50 wire.push(mapping[&node.inputs[2]]);
51 wire = target.wire_node(format!("{}.deconv", node.name), pulse_op, &wire)?;
52 let overlap = overlap(stream.axis, op);
53 let deconv_input_dim = (stream.dim.clone() - 1) * stride + 1;
54 let output_shape = tract_core::ops::cnn::deconv::output_shape(
55 &op.pool_spec,
56 &fact.streaming_shape(),
57 &op.adjustments,
58 )?;
59 let kernel_spatial_shape = &op.pool_spec.kernel_shape;
60 let shape = op.pool_spec.data_format.shape(fact.streaming_shape())?;
61 let paddings = op.pool_spec.padding.compute_for_deconv(
62 shape.hw_dims(),
63 kernel_spatial_shape,
64 &op.pool_spec.dilations(),
65 &op.pool_spec.strides(),
66 &op.adjustments,
67 )?;
68 wire = target.wire_node(
69 &node.name,
70 DeconvDelay {
71 axis: stream.axis,
72 overlap,
73 delay: paddings[geo_axis].pad_before.to_usize()? + stream.delay,
74 deconv_input_dim,
75 stride,
76 pulse: pulse.to_owned(),
77 deconv_output_dim: output_shape[stream.axis].clone(),
78 },
79 &wire,
80 )?;
81
82 for (geo_axis, padding) in paddings.iter().enumerate() {
83 if !padding.pad_before.is_zero() || !padding.pad_after.is_zero() {
84 let axis = geo_axis + shape.h_axis();
85 if axis == stream.axis {
86 continue;
87 };
88 let op = crate::model::PulseWrappingOp(Box::new(tract_core::ops::array::Slice::new(
89 axis,
90 padding.pad_before.clone(),
91 padding.deconvoluted.clone() + &padding.pad_before,
92 )));
93 wire = target.wire_node(format!("{}.padding.{}", node.name, geo_axis), op, &wire)?;
94 }
95 }
96
97 Ok(Some(wire))
98}
99
100fn overlap(pulse_axis: usize, op: &Deconv) -> usize {
101 let geo_axis = pulse_axis - op.pool_spec.data_format.h_axis();
102 (op.pool_spec.kernel_shape[geo_axis] - 1) * op.pool_spec.dilation(geo_axis)
103}
104
105impl PulsedOp for Deconv {
106 fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
107 let mut fact = inputs[0].clone();
108 let stream = fact.stream.as_mut().unwrap();
109 let overlap = overlap(stream.axis, self);
110 let geo_axis = stream.axis - self.pool_spec.data_format.h_axis();
111 let stride = self.pool_spec.stride(geo_axis);
112 let mut output_shape = tract_core::ops::cnn::deconv::output_shape(
113 &self.pool_spec,
114 &inputs[0].streaming_shape(),
115 &self.adjustments,
116 )?;
117 stream.dim = output_shape[stream.axis].clone();
118 let pulse_len = fact.shape[stream.axis].clone() * stride;
119 output_shape[stream.axis] = pulse_len + overlap;
120 let c_axis = self.pool_spec.data_format.shape(&output_shape)?.c_axis();
121 output_shape[c_axis] = self.pool_spec.output_channels.into();
122 fact.shape = output_shape.into();
123 Ok(tvec!(fact))
124 }
125
126 as_op!();
127 pulsed_op_to_typed_op!();
128}
129
130impl PulsedOp for DeconvDelay {
131 fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
132 let mut fact = inputs[0].clone();
133 let stream = fact.stream.as_mut().unwrap();
134 stream.dim = self.deconv_output_dim.clone();
135 let pulse_len = fact.shape[stream.axis].clone();
136 fact.shape.set(stream.axis, pulse_len - self.overlap);
137 stream.delay = self.delay;
138 Ok(tvec!(fact))
139 }
140
141 as_op!();
142 pulsed_op_to_typed_op!();
143}