1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
use crate::fact::StreamInfo;
use crate::internal::*;
use tract_core::ops::scan::{InputMapping, Scan};

// Hook `pulsify` (below) up as the pulsification rule for `Scan` ops via the
// crate's `register_all!` macro.
register_all!(Scan: pulsify);

fn pulsify(
    op: &Scan,
    source: &TypedModel,
    node: &TypedNode,
    target: &mut PulsedModel,
    mapping: &HashMap<OutletId, OutletId>,
    symbol: &Symbol,
    pulse: &TDim,
) -> TractResult<Option<TVec<OutletId>>> {


/*

    dbg!(source.node_axes_mapping(node.id)?.to_string());
    for input_id in &node.inputs {
        dbg!(target.outlet_fact(mapping[input_id]))?;
    }
    for input_id in 0..node.inputs.len() {
        let input = mapping[&node.inputs[input_id]];
        let input_fact = target.outlet_fact(input)?;
        if let Some(info) = op.input_mapping[input_id].as_scan() {
            if info.chunk < 0 {
                bail!("Can not pulsify a backward scan.")
            }
            if input_fact.stream.as_ref().context("scan on non-streamed input")?.axis != info.axis {
                bail!("Scan pulsification limited to scanning axis");
            }
        }
    }
*/

    let pulse_inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();

    let axes_mapping = source.node_axes_mapping(node.id)?;
    let first_scan_slot = op.input_mapping.iter().position(InputMapping::is_scan).unwrap();
    let first_scan_axis = target.outlet_fact(pulse_inputs[first_scan_slot])?.stream.as_ref().unwrap().axis;
    let scan_axis = axes_mapping.axis((InOut::In(first_scan_slot), first_scan_axis))?;
    if first_scan_axis == op.input_mapping[first_scan_slot].as_scan().unwrap().axis {
        let mut op = op.clone();
        op.skip = target.outlet_fact(pulse_inputs[first_scan_slot])?.stream.as_ref().unwrap().delay;
        for om in op.output_mapping.iter_mut() {
            if om.scan.is_some() {
                om.full_dim_hint = None;
            }
        }
        Ok(Some(target.wire_node(&*node.name, op, &pulse_inputs)?))
    } else if scan_axis.outputs.iter().all(|x| x.len() == 1) {
        let body = PulsedModel::new(&op.body, symbol.clone(), pulse)?.into_typed()?;
        let mut new_op = Scan::new(body, op.input_mapping.clone(), op.output_mapping.clone(), 0)?;
        new_op.reset_every_turn = true;
        target.wire_node(&node.name, new_op, &pulse_inputs).map(Some)
    } else {
        todo!("Unsupported pulsification")
    }
}

impl PulsedOp for Scan {
    /// Compute the pulsed facts of the outer outputs of a `Scan`.
    ///
    /// Only scanned ("full") outputs are supported: every outer output slot must
    /// be produced by an output mapping with a `scan` entry, otherwise this
    /// errors. Every output inherits the stream `dim` and `delay` of the first
    /// scanned input.
    ///
    /// Two layouts are handled:
    /// * the first scanned input streams along its own scan axis: each output's
    ///   scan axis becomes its streaming axis, sized to the input pulse;
    /// * otherwise the input's streaming axis is tracked through the body to each
    ///   output, and the output's scan axis is sized like the first scanned
    ///   input along that input's scan axis.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
        // Outer output slots are numbered across both scanned and last-value outputs.
        let outer_output_count = self
            .output_mapping
            .iter()
            .map(|om| om.scan.map(|s| s.0).unwrap_or(0).max(om.last_value_slot.unwrap_or(0)))
            .max()
            .context("no output?")?
            + 1;

        let first_scan_slot = self
            .input_mapping
            .iter()
            .position(InputMapping::is_scan)
            .context("Scan pulse requires at least one scanned input")?;
        let first_input = inputs[first_scan_slot];
        let first_stream = first_input
            .stream
            .as_ref()
            .context("Scan pulse requires its first scanned input to be streamed")?;
        let first_pulse_axis = first_stream.axis;
        let first_scan_axis = self.input_mapping[first_scan_slot]
            .as_scan()
            .context("Inconsistent scan input mapping")?
            .axis;
        let tracking = self.body.axes_mapping()?;
        let tracked_axis = tracking.axis((InOut::In(first_scan_slot), first_pulse_axis))?;

        let mut facts = tvec!();
        for output_slot in 0..outer_output_count {
            // `scan` is `(outer slot, scan info)`: match on the slot, keep the info.
            let (output_body_ix, scan_info) = self
                .output_mapping
                .iter()
                .enumerate()
                .find_map(|(ix, om)| match om.scan {
                    Some((slot, info)) if slot == output_slot => Some((ix, info)),
                    _ => None,
                })
                .context("Scan pulse only supports full outputs")?;
            let output_body_fact = self.body.output_fact(output_body_ix)?;

            let (shape, stream_axis) = if first_scan_axis == first_pulse_axis {
                let pulse = first_input
                    .pulse()
                    .context("first scanned input has no pulse")?
                    .to_dim();
                let shape: ShapeFact = output_body_fact
                    .shape
                    .iter()
                    .enumerate()
                    .map(|(axis, d)| if axis == scan_info.axis { pulse.clone() } else { d })
                    .collect();
                (shape, scan_info.axis)
            } else {
                let stream_axis = tracked_axis
                    .outputs
                    .get(output_body_ix)
                    .and_then(|axes| axes.first().copied())
                    .context("streaming axis not found in scan body output")?;
                let mut shape = output_body_fact.shape.clone();
                // The scan axis is addressed by `scan_info.axis`; the slot index
                // (`scan.0`) is not an axis and must not be used here.
                shape.set(scan_info.axis, first_input.shape[first_scan_axis].clone());
                (shape, stream_axis)
            };

            facts.push(PulsedFact {
                datum_type: output_body_fact.datum_type,
                shape,
                stream: Some(StreamInfo {
                    axis: stream_axis,
                    dim: first_stream.dim.clone(),
                    delay: first_stream.delay,
                }),
            });
        }
        Ok(facts)
    }

    as_op!();
    pulsed_op_to_typed_op!();
}