tract_pulse/ops/
downsample.rs1use crate::internal::*;
2use tract_core::ops::Downsample;
3use tract_pulse_opl::ops::PulsedAxisSlice;
4use tract_pulse_opl::tract_nnef::tract_num_traits::Zero;
5
6register_all!(Downsample: pulsify);
7
8fn pulsify(
9 op: &Downsample,
10 _source: &TypedModel,
11 node: &TypedNode,
12 target: &mut PulsedModel,
13 mapping: &HashMap<OutletId, OutletId>,
14 _symbol: &Symbol,
15 _pulse: &TDim,
16) -> TractResult<Option<TVec<OutletId>>> {
17 let input = mapping[&node.inputs[0]];
18 let fact = target.outlet_fact(input)?.clone();
19 if let Some(stream) = fact.stream.as_ref() {
20 if stream.axis != op.axis {
21 return Ok(None);
22 }
23 let stride = if op.stride > 0 {
24 op.stride as usize
25 } else {
26 bail!("Negative strides are not causal, can not pulsify.")
27 };
28 let pulse = fact.pulse().unwrap();
29 if !(pulse.clone() % stride).is_zero() {
30 bail!("Pulsification requires pulse ({}) to be a stride ({}) multiple", pulse, stride)
31 }
32 let mut wire = tvec!(input);
33 let first_offset = stream.delay + op.modulo;
34 let new_op = Downsample { modulo: first_offset % stride, axis: op.axis, stride: op.stride };
35 wire = target.wire_node(format!("{}.downsample", node.name), new_op, &wire)?;
36 wire = target.wire_node(
37 &node.name,
38 PulsedAxisSlice {
39 axis: stream.axis,
40 skip: first_offset / stride,
41 take: (stream.dim.to_owned() - op.modulo).divceil(stride),
42 },
43 &wire,
44 )?;
45 target.rename_node(wire[0].node, &node.name)?;
46 Ok(Some(wire))
47 } else {
48 Ok(None)
49 }
50}
51
52impl PulsedOp for Downsample {
53 fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
54 let mut fact = inputs[0].clone();
55 let stream = fact.stream.as_mut().unwrap();
56 fact.shape.set(self.axis, fact.shape[self.axis].clone() / self.stride as usize);
57 stream.dim = (stream.dim.clone() + stream.delay).divceil(self.stride as _);
58 stream.delay = 0;
59 Ok(tvec!(fact))
60 }
61
62 as_op!();
63 pulsed_op_to_typed_op!();
64}