tract_pulse/ops/
downsample.rs

1use crate::internal::*;
2use tract_core::ops::Downsample;
3use tract_pulse_opl::ops::PulsedAxisSlice;
4use tract_pulse_opl::tract_nnef::tract_num_traits::Zero;
5
6register_all!(Downsample: pulsify);
7
8fn pulsify(
9    op: &Downsample,
10    _source: &TypedModel,
11    node: &TypedNode,
12    target: &mut PulsedModel,
13    mapping: &HashMap<OutletId, OutletId>,
14    _symbol: &Symbol,
15    _pulse: &TDim,
16) -> TractResult<Option<TVec<OutletId>>> {
17    let input = mapping[&node.inputs[0]];
18    let fact = target.outlet_fact(input)?.clone();
19    if let Some(stream) = fact.stream.as_ref() {
20        if stream.axis != op.axis {
21            return Ok(None);
22        }
23        let stride = if op.stride > 0 {
24            op.stride as usize
25        } else {
26            bail!("Negative strides are not causal, can not pulsify.")
27        };
28        let pulse = fact.pulse().unwrap();
29        if !(pulse.clone() % stride).is_zero() {
30            bail!("Pulsification requires pulse ({}) to be a stride ({}) multiple", pulse, stride)
31        }
32        let mut wire = tvec!(input);
33        let first_offset = stream.delay + op.modulo;
34        let new_op = Downsample { modulo: first_offset % stride, axis: op.axis, stride: op.stride };
35        wire = target.wire_node(format!("{}.downsample", node.name), new_op, &wire)?;
36        wire = target.wire_node(
37            &node.name,
38            PulsedAxisSlice {
39                axis: stream.axis,
40                skip: first_offset / stride,
41                take: (stream.dim.to_owned() - op.modulo).divceil(stride),
42            },
43            &wire,
44        )?;
45        target.rename_node(wire[0].node, &node.name)?;
46        Ok(Some(wire))
47    } else {
48        Ok(None)
49    }
50}
51
52impl PulsedOp for Downsample {
53    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
54        let mut fact = inputs[0].clone();
55        let stream = fact.stream.as_mut().unwrap();
56        fact.shape.set(self.axis, fact.shape[self.axis].clone() / self.stride as usize);
57        stream.dim = (stream.dim.clone() + stream.delay).divceil(self.stride as _);
58        stream.delay = 0;
59        Ok(tvec!(fact))
60    }
61
62    as_op!();
63    pulsed_op_to_typed_op!();
64}