1use crate::fact::StreamInfo;
2use crate::internal::*;
3use tract_core::ops::scan::{InputMapping, Scan};
4
// Hooks `pulsify` (below) up as the pulsification rule for typed `Scan` nodes.
// NOTE(review): the exact registration semantics live in the crate's `register_all!` macro.
register_all!(Scan: pulsify);
6
7fn pulsify(
8 op: &Scan,
9 source: &TypedModel,
10 node: &TypedNode,
11 target: &mut PulsedModel,
12 mapping: &HashMap<OutletId, OutletId>,
13 symbol: &Symbol,
14 pulse: &TDim,
15) -> TractResult<Option<TVec<OutletId>>> {
16
17
18let pulse_inputs = node.inputs.iter().map(|i| mapping[i]).collect::<TVec<_>>();
39
40 let axes_mapping = source.node_axes_mapping(node.id)?;
41 let first_scan_slot = op.input_mapping.iter().position(InputMapping::is_scan).unwrap();
42 let first_scan_axis = target.outlet_fact(pulse_inputs[first_scan_slot])?.stream.as_ref().unwrap().axis;
43 let scan_axis = axes_mapping.axis((InOut::In(first_scan_slot), first_scan_axis))?;
44 if first_scan_axis == op.input_mapping[first_scan_slot].as_scan().unwrap().axis {
45 let mut op = op.clone();
46 op.skip = target.outlet_fact(pulse_inputs[first_scan_slot])?.stream.as_ref().unwrap().delay;
47 for om in op.output_mapping.iter_mut() {
48 if om.scan.is_some() {
49 om.full_dim_hint = None;
50 }
51 }
52 Ok(Some(target.wire_node(&*node.name, op, &pulse_inputs)?))
53 } else if scan_axis.outputs.iter().all(|x| x.len() == 1) {
54 let body = PulsedModel::new(&op.body, symbol.clone(), pulse)?.into_typed()?;
55 let mut new_op = Scan::new(body, op.input_mapping.clone(), op.output_mapping.clone(), 0)?;
56 new_op.reset_every_turn = true;
57 target.wire_node(&node.name, new_op, &pulse_inputs).map(Some)
58 } else {
59 todo!("Unsupported pulsification")
60 }
61}
62
63impl PulsedOp for Scan {
64 fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>> {
65 let outer_output_count = self
66 .output_mapping
67 .iter()
68 .map(|om| om.scan.map(|s| s.0).unwrap_or(0).max(om.last_value_slot.unwrap_or(0)))
69 .max()
70 .context("no output?")?
71 + 1;
72
73 let first_scan_slot = self.input_mapping.iter().position(InputMapping::is_scan).unwrap();
74 let first_pulse_axis = inputs[first_scan_slot].stream.as_ref().unwrap().axis;
75 let first_scan_axis = self.input_mapping[first_scan_slot].as_scan().as_ref().unwrap().axis;
76 let tracking = self.body.axes_mapping()?;
77 let pulse_axis = tracking.axis((InOut::In(first_scan_slot), first_pulse_axis))?;
78 let mut facts = tvec!();
79 for output_slot in 0..outer_output_count {
80 let (output_body_ix, output_mapping) = self
81 .output_mapping
82 .iter()
83 .enumerate()
84 .find(|(_ix, om)| om.scan.map(|s| s.0) == Some(output_slot))
85 .context("Scan pulse only supports full outputs")?;
86 let output_body_fact = self.body.output_fact(output_body_ix)?;
87 let fact = if first_scan_axis == first_pulse_axis {
88 let shape: ShapeFact = output_body_fact
89 .shape
90 .iter()
91 .enumerate()
92 .map(|(axis, d)| {
93 if axis == output_mapping.scan.unwrap().1.axis {
94 inputs[first_scan_slot].pulse().unwrap().to_dim()
95 } else {
96 d.clone()
97 }
98 })
99 .collect();
100 PulsedFact {
101 datum_type: output_body_fact.datum_type,
102 shape,
103 stream: Some(StreamInfo {
104 axis: output_mapping.scan.unwrap().1.axis,
105 dim: inputs[first_scan_slot].stream.as_ref().unwrap().dim.clone(),
106 delay: inputs[first_scan_slot].stream.as_ref().unwrap().delay,
107 }),
108 }
109 } else {
110 let pulse_axis = pulse_axis.outputs[output_body_ix][0];
111 let mut shape = output_body_fact.shape.clone();
112 if let Some(info) = output_mapping.scan {
113 shape.set(info.0, inputs[first_scan_slot].shape[first_scan_axis].clone());
114 }
115 PulsedFact {
116 datum_type: output_body_fact.datum_type,
117 shape,
118 stream: Some(StreamInfo {
119 axis: pulse_axis,
120 dim: inputs[first_scan_slot].stream.as_ref().unwrap().dim.clone(),
121 delay: inputs[first_scan_slot].stream.as_ref().unwrap().delay,
122 }),
123 }
124 };
125 facts.push(fact);
126 }
127 Ok(facts)
128 }
129
130 as_op!();
131 pulsed_op_to_typed_op!();
132}