tract_pulse/ops/scan.rs

1use crate::fact::StreamInfo;
2use crate::internal::*;
3use tract_core::ops::scan::{InputMapping, Scan};
4
5register_all!(Scan: pulsify);
6
7fn pulsify(
8    op: &Scan,
9    source: &TypedModel,
10    node: &TypedNode,
11    target: &mut PulsedModel,
12    mapping: &HashMap<OutletId, OutletId>,
13    symbol: &Symbol,
14    pulse: &TDim,
15) -> TractResult<Option<TVec<OutletId>>> {
16
17
18/*
19
20    dbg!(source.node_axes_mapping(node.id)?.to_string());
21    for input_id in &node.inputs {
22        dbg!(target.outlet_fact(mapping[input_id]))?;
23    }
24    for input_id in 0..node.inputs.len() {
25        let input = mapping[&node.inputs[input_id]];
26        let input_fact = target.outlet_fact(input)?;
27        if let Some(info) = op.input_mapping[input_id].as_scan() {
28            if info.chunk < 0 {
29                bail!("Can not pulsify a backward scan.")
30            }
31            if input_fact.stream.as_ref().context("scan on non-streamed input")?.axis != info.axis {
32                bail!("Scan pulsification limited to scanning axis");
33            }
34        }
35    }
36*/
37
38    let pulse_inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
39
40    let axes_mapping = source.node_axes_mapping(node.id)?;
41    let first_scan_slot = op.input_mapping.iter().position(InputMapping::is_scan).unwrap();
42    let first_scan_axis = target.outlet_fact(pulse_inputs[first_scan_slot])?.stream.as_ref().unwrap().axis;
43    let scan_axis = axes_mapping.axis((InOut::In(first_scan_slot), first_scan_axis))?;
44    if first_scan_axis == op.input_mapping[first_scan_slot].as_scan().unwrap().axis {
45        let mut op = op.clone();
46        op.skip = target.outlet_fact(pulse_inputs[first_scan_slot])?.stream.as_ref().unwrap().delay;
47        for om in op.output_mapping.iter_mut() {
48            if om.scan.is_some() {
49                om.full_dim_hint = None;
50            }
51        }
52        Ok(Some(target.wire_node(&*node.name, op, &pulse_inputs)?))
53    } else if scan_axis.outputs.iter().all(|x| x.len() == 1) {
54        let body = PulsedModel::new(&op.body, symbol.clone(), pulse)?.into_typed()?;
55        let mut new_op = Scan::new(body, op.input_mapping.clone(), op.output_mapping.clone(), 0)?;
56        new_op.reset_every_turn = true;
57        target.wire_node(&node.name, new_op, &pulse_inputs).map(Some)
58    } else {
59        todo!("Unsupported pulsification")
60    }
61}
62
63impl PulsedOp for Scan {
64    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
65        let outer_output_count = self
66            .output_mapping
67            .iter()
68            .map(|om| om.scan.map(|s| s.0).unwrap_or(0).max(om.last_value_slot.unwrap_or(0)))
69            .max()
70            .context("no output?")?
71            + 1;
72
73        let first_scan_slot = self.input_mapping.iter().position(InputMapping::is_scan).unwrap();
74        let first_pulse_axis = inputs[first_scan_slot].stream.as_ref().unwrap().axis;
75        let first_scan_axis = self.input_mapping[first_scan_slot].as_scan().as_ref().unwrap().axis;
76        let tracking = self.body.axes_mapping()?;
77        let pulse_axis = tracking.axis((InOut::In(first_scan_slot), first_pulse_axis))?;
78        let mut facts = tvec!();
79        for output_slot in 0..outer_output_count {
80            let (output_body_ix, output_mapping) = self
81                .output_mapping
82                .iter()
83                .enumerate()
84                .find(|(_ix, om)| om.scan.map(|s| s.0) == Some(output_slot))
85                .context("Scan pulse only supports full outputs")?;
86            let output_body_fact = self.body.output_fact(output_body_ix)?;
87            let fact = if first_scan_axis == first_pulse_axis {
88                let shape: ShapeFact = output_body_fact
89                    .shape
90                    .iter()
91                    .enumerate()
92                    .map(|(axis, d)| {
93                        if axis == output_mapping.scan.unwrap().1.axis {
94                            inputs[first_scan_slot].pulse().unwrap().to_dim()
95                        } else {
96                            d.clone()
97                        }
98                    })
99                    .collect();
100                PulsedFact {
101                    datum_type: output_body_fact.datum_type,
102                    shape,
103                    stream: Some(StreamInfo {
104                        axis: output_mapping.scan.unwrap().1.axis,
105                        dim: inputs[first_scan_slot].stream.as_ref().unwrap().dim.clone(),
106                        delay: inputs[first_scan_slot].stream.as_ref().unwrap().delay,
107                    }),
108                }
109            } else {
110                let pulse_axis = pulse_axis.outputs[output_body_ix][0];
111                let mut shape = output_body_fact.shape.clone();
112                if let Some(info) = output_mapping.scan {
113                    shape.set(info.0, inputs[first_scan_slot].shape[first_scan_axis].clone());
114                }
115                PulsedFact {
116                    datum_type: output_body_fact.datum_type,
117                    shape,
118                    stream: Some(StreamInfo {
119                        axis: pulse_axis,
120                        dim: inputs[first_scan_slot].stream.as_ref().unwrap().dim.clone(),
121                        delay: inputs[first_scan_slot].stream.as_ref().unwrap().delay,
122                    }),
123                }
124            };
125            facts.push(fact);
126        }
127        Ok(facts)
128    }
129
130    as_op!();
131    pulsed_op_to_typed_op!();
132}