1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
use crate::fact::StreamInfo;
use crate::internal::*;
use tract_core::ops::scan::{InputMapping, Scan};

// Hooks `pulsify` (below) up for `Scan` ops. NOTE(review): `register_all!` is defined
// elsewhere in the crate; presumably it registers `pulsify` as the pulsifier used when a
// typed `Scan` node is converted into a pulsed network — confirm against the macro.
register_all!(Scan: pulsify);

/// Converts a typed `Scan` node into its pulsed counterpart in `target`.
///
/// Every scanning input must be a forward scan (`chunk >= 0`) over a streamed input, and it
/// must scan along that input's streaming axis. The op is cloned and adjusted as follows:
/// - `skip` is set to the `delay` of the first scanned input's stream. NOTE(review): presumably
///   this drops the delayed leading frames; confirm against `Scan::skip` semantics.
/// - `full_dim_hint` is cleared on every scan output, since the hint does not carry over to
///   the pulsed network (assumption — confirm).
///
/// The adjusted op is then wired on the mapped inputs.
///
/// # Errors
/// Fails on a backward scan, on a scan over a non-streamed input, on a scan along a
/// non-streaming axis, or when the op has no scanning input at all. The last case used to
/// panic.
fn pulsify(
    op: &Scan,
    _source: &TypedModel,
    node: &TypedNode,
    target: &mut PulsedModel,
    mapping: &HashMap<OutletId, OutletId>,
    _symbol: &Symbol,
    _pulse: &TDim,
) -> TractResult<Option<TVec<OutletId>>> {
    // Outlets in `target` that correspond to this node's inputs.
    let pulse_inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();

    // Only forward scans along the streaming axis can be pulsified.
    for (input_id, &input) in pulse_inputs.iter().enumerate() {
        let input_fact = target.outlet_fact(input)?;
        if let Some(info) = op.input_mapping[input_id].as_scan() {
            if info.chunk < 0 {
                bail!("Can not pulsify a backward scan.")
            }
            let stream = input_fact.stream.as_ref().context("scan on non-streamed input")?;
            if stream.axis != info.axis {
                bail!("Scan pulsification limited to scanning axis");
            }
        }
    }

    let mut op = op.clone();
    let first_scan_slot = op
        .input_mapping
        .iter()
        .position(InputMapping::is_scan)
        .context("Can not pulsify a Scan without scanning input")?;
    op.skip = target
        .outlet_fact(pulse_inputs[first_scan_slot])?
        .stream
        .as_ref()
        .context("scan on non-streamed input")?
        .delay;
    // `om` is already `&mut`; the binding itself does not need to be mutable.
    for om in op.output_mapping.iter_mut() {
        if om.scan.is_some() {
            om.full_dim_hint = None;
        }
    }
    Ok(Some(target.wire_node(&*node.name, op, &pulse_inputs)?))
}

impl PulsedOp for Scan {
    /// Computes the pulsed fact of every output of this scan.
    ///
    /// Only "full" (scanned) outputs are supported: each output slot must be produced by an
    /// output mapping with a `scan` entry. Each output fact works as follows:
    /// - The datum type and shape come from the corresponding body output.
    /// - The scanned axis is resized to the pulse of the first scanned input.
    /// - Its stream info uses the scan axis, plus that input's stream `dim` and `delay`.
    ///
    /// # Errors
    /// Fails under any of these conditions; the last two used to panic:
    /// - the op has no output mapping;
    /// - an output slot is not a scan output;
    /// - the op has no scanning input;
    /// - the first scanned input is not streamed.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        // One past the highest output slot referenced by any mapping (scan or last value).
        let output_count = self
            .output_mapping
            .iter()
            .map(|om| om.scan.map(|s| s.0).unwrap_or(0).max(om.last_value_slot.unwrap_or(0)))
            .max()
            .context("no output?")?
            + 1;

        // Every scanned output is paced by the first scanned input, so look it up once.
        let first_scan_slot = self
            .input_mapping
            .iter()
            .position(InputMapping::is_scan)
            .context("Scan has no scanning input")?;
        let scan_input = inputs[first_scan_slot];
        let input_stream = scan_input.stream.as_ref().context("scan on non-streamed input")?;
        let pulse = scan_input.pulse().context("scan on non-streamed input")?.to_dim();

        let mut facts = tvec!();
        for output_slot in 0..output_count {
            // Find the body output feeding this slot as a scan output, and its scan axis.
            let (output_body_ix, scan_axis) = self
                .output_mapping
                .iter()
                .enumerate()
                .find_map(|(ix, om)| {
                    om.scan.filter(|s| s.0 == output_slot).map(|s| (ix, s.1.axis))
                })
                .context("Scan pulse only supports full outputs")?;
            let output_body_fact = self.body.output_fact(output_body_ix)?;
            let shape: ShapeFact = output_body_fact
                .shape
                .iter()
                .enumerate()
                .map(|(axis, d)| if axis == scan_axis { pulse.clone() } else { d })
                .collect();
            facts.push(PulsedFact {
                datum_type: output_body_fact.datum_type,
                shape,
                stream: Some(StreamInfo {
                    axis: scan_axis,
                    dim: input_stream.dim.clone(),
                    delay: input_stream.delay,
                }),
            });
        }
        Ok(facts)
    }

    as_op!();
    pulsed_op_to_typed_op!();
}