1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
use crate::internal::*;
use tract_core::ops::Downsample;
use tract_pulse_opl::ops::PulsedAxisSlice;
use tract_pulse_opl::tract_nnef::tract_num_traits::Zero;

// Associates the `pulsify` function below with the `Downsample` op through the
// crate's `register_all!` macro.
// NOTE(review): the macro is defined elsewhere; presumably this is what makes
// the pulsifier pick up this translation — confirm against its definition.
register_all!(Downsample: pulsify);

/// Translates a [`Downsample`] node into the pulsed `target` model.
///
/// Returns `Ok(None)` when the mapped input carries no stream, or when its
/// stream runs along an axis other than `op.axis`.
///
/// Otherwise two nodes are wired after the mapped input:
/// 1. a `Downsample` named `"<node name>.downsample"` whose modulo is
///    `(stream delay + op.modulo) % stride`;
/// 2. a [`PulsedAxisSlice`] named after the original node, skipping
///    `(stream delay + op.modulo) / stride` items and taking
///    `ceil((stream dim - op.modulo) / stride)` items on the streaming axis.
///
/// # Errors
/// Fails when `op.stride` is not strictly positive, when the pulse is not a
/// multiple of the stride, or when looking up the fact or wiring/renaming a
/// node fails.
///
/// # Panics
/// Panics if `node.inputs[0]` is absent from `mapping`, or if the input fact
/// reports no pulse.
fn pulsify(
    op: &Downsample,
    _source: &TypedModel,
    node: &TypedNode,
    target: &mut PulsedModel,
    mapping: &HashMap<OutletId, OutletId>,
    _symbol: &Symbol,
    _pulse: &TDim,
) -> TractResult<Option<TVec<OutletId>>> {
    let input = mapping[&node.inputs[0]];
    let fact = target.outlet_fact(input)?.clone();

    // Nothing to do here unless the input streams along the downsampled axis.
    let stream = match fact.stream.as_ref() {
        Some(stream) if stream.axis == op.axis => stream,
        _ => return Ok(None),
    };

    if op.stride <= 0 {
        bail!("Negative strides are not causal, can not pulsify.")
    }
    let stride = op.stride as usize;

    // Each pulse must hold a whole number of strides.
    let pulse = fact.pulse().unwrap();
    let remainder = pulse.clone() % stride;
    if !remainder.is_zero() {
        bail!("Pulsification requires pulse ({}) to be a stride ({}) multiple", pulse, stride)
    }

    // Position of the first kept sample, counted from the start of the
    // (delayed) stream.
    let first_offset = stream.delay + op.modulo;
    let skip = first_offset / stride;
    let take = (stream.dim.clone() - op.modulo).divceil(stride);

    let pulsed_downsample =
        Downsample { modulo: first_offset % stride, axis: op.axis, stride: op.stride };
    let downsampled = target.wire_node(
        format!("{}.downsample", node.name),
        pulsed_downsample,
        &[input],
    )?;

    let slice = PulsedAxisSlice { axis: stream.axis, skip, take };
    let sliced = target.wire_node(&node.name, slice, &downsampled)?;

    target.rename_node(sliced[0].node, &node.name)?;
    Ok(Some(sliced))
}

impl PulsedOp for Downsample {
    /// Derives the output fact from the first input fact.
    ///
    /// The shape entry at `self.axis` is divided by the stride, the stream
    /// dimension becomes `ceil((dim + delay) / stride)`, and the stream delay
    /// is reset to zero.
    ///
    /// # Panics
    /// Panics if `inputs` is empty or if the first input has no stream.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        let mut fact = inputs[0].clone();

        // Shrink the per-pulse extent on the downsampled axis.
        let downsampled_len = fact.shape[self.axis].clone() / self.stride as usize;
        fact.shape.set(self.axis, downsampled_len);

        // Fold the incoming delay into the stream length, then clear it.
        let stream = fact.stream.as_mut().unwrap();
        let delayed_dim = stream.dim.clone() + stream.delay;
        stream.dim = delayed_dim.divceil(self.stride as _);
        stream.delay = 0;

        Ok(tvec![fact])
    }

    as_op!();
    pulsed_op_to_typed_op!();
}