1use std::collections::HashSet;
2
3use crate::ops::einsum::EinSum;
4use crate::ops::konst::Const;
5use crate::optim::OptimizerSession;
6
7use super::optimized::{OptScan, ScanOpParams};
8use tract_data::internal::*;
9use tract_data::itertools::izip;
10
11use super::*;
12
/// Typed (not yet lowered) scan operator: an inner `body` model together with
/// the mappings that connect the outer node's inputs and outputs to the body's.
#[derive(Debug, Clone, Default)]
pub struct Scan {
    /// Forwarded unchanged to `ScanOpParams::new` by `to_codegen_op`.
    // NOTE(review): presumably a number of leading iterations/frames to skip — confirm.
    pub skip: usize,
    /// Forwarded unchanged to `ScanOpParams::new`; `Scan::new` initializes it to `false`.
    pub reset_every_turn: bool,
    /// Inner model evaluated by the scan.
    pub body: TypedModel,
    /// Set to `true` once `declutter_body` has run the optimizer session on
    /// `body`; rewrites that alter the body reset it to `false`.
    pub decluttered: bool,
    /// One entry per body input (`Scan::new` and `output_facts` check the lengths match).
    pub input_mapping: Vec<InputMapping>,
    /// One entry per body output (`Scan::new` checks the lengths match).
    pub output_mapping: Vec<OutputMapping<TDim>>,
}
22
/// Two `Scan` values never compare equal: `eq` ignores both operands and
/// always answers `false`, even when a value is compared with itself.
impl PartialEq for Scan {
    fn eq(&self, _other: &Self) -> bool {
        false
    }
}
// NOTE(review): `Eq` documents reflexivity (`a == a`), which the impl above
// does not provide. Presumably `Eq` is only here to satisfy a trait bound —
// confirm before relying on `Scan` as a key or in equality-based deduplication.
impl Eq for Scan {}
29
30impl Scan {
31 pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<OptScan> {
32 let mut model = self.body.clone();
33 if optimize_inner {
34 model = model.into_optimized()?;
35 }
36 let plan = SimplePlan::new(model)?;
37
38 Ok(OptScan::new(Arc::new(ScanOpParams::new(
39 self.skip,
40 self.reset_every_turn,
41 plan,
42 self.input_mapping.clone(),
43 self.output_mapping.clone(),
44 ))))
45 }
46
47 pub fn new(
48 body: TypedModel,
49 input_mapping: Vec<InputMapping>,
50 output_mapping: Vec<OutputMapping<TDim>>,
51 skip: usize,
52 ) -> TractResult<Scan> {
53 body.check_consistency()?;
54 ensure!(input_mapping.len() == body.input_outlets()?.len());
55 ensure!(output_mapping.len() == body.output_outlets()?.len());
56 Ok(Scan {
57 skip,
58 reset_every_turn: false,
59 body,
60 decluttered: false,
61 input_mapping,
62 output_mapping,
63 })
64 }
65
66 pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
67 self.to_codegen_op(false).unwrap().iteration_count(inputs)
68 }
69
70 fn declutter_body(
71 &self,
72 session: &mut OptimizerSession,
73 model: &TypedModel,
74 node: &TypedNode,
75 ) -> TractResult<Option<TypedModelPatch>> {
76 rule_if!(!self.decluttered);
77 let mut new = self.clone();
78 let mut body = self.body.clone();
79 session.optimize(&mut body)?;
80 new.body = body;
81 new.decluttered = true;
82 Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
83 }
84
85 fn declutter_single_loop(
86 &self,
87 _session: &mut OptimizerSession,
88 model: &TypedModel,
89 node: &TypedNode,
90 ) -> TractResult<Option<TypedModelPatch>> {
91 let inputs = model.node_input_facts(node.id)?;
92 let iters =
93 super::iteration_count(&self.input_mapping, &inputs).context("No scan input")?;
94 if !iters.is_one() {
95 return Ok(None);
96 }
97 let mut patch = TypedModelPatch::new("Inline single loop scan");
98 patch.model = self.body.clone();
99 for (outer_wire, inner_wire) in izip!(&node.inputs, &self.body.inputs) {
100 patch.taps.insert(*inner_wire, *outer_wire);
101 }
102 for (inner_wire, mapping) in izip!(&self.body.outputs, &self.output_mapping) {
103 if let Some((slot, _)) = mapping.scan {
104 patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
105 }
106 if let Some(slot) = mapping.last_value_slot {
107 patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
108 }
109 }
110 Ok(Some(patch))
111 }
112
/// Tries to apply an axis change inside the body and returns a patch for the
/// first change the body accepts.
///
/// Candidate changes are gathered in body evaluation order: every change
/// suggested by a body op, plus the removal of every output axis whose
/// dimension is one. Each candidate is tried through `try_body_axes_change`
/// with the interface locked. Returns `Ok(None)` when no candidate is accepted.
///
/// # Errors
/// Fails with "Will only reshape constants" if an accepted change alters an
/// outer input that carries no constant value.
fn declutter_body_axes(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let mut suggestions = vec![];
    for n in self.body.eval_order()? {
        // Shadows the outer `node` parameter for the duration of this loop only.
        let node = self.body.node(n);
        for suggestion in node.op.suggested_axis_changes()? {
            let outlet = suggestion.0.as_outlet(node);
            suggestions.push(AxisChange { outlet, op: suggestion.1 })
        }
        // Every axis of dimension one is a removal candidate.
        for (slot, fact) in node.outputs.iter().enumerate() {
            for (ix, dim) in fact.fact.shape.iter().enumerate() {
                if dim.is_one() {
                    suggestions.push(AxisChange {
                        outlet: OutletId::new(n, slot),
                        op: AxisOp::Rm(ix),
                    });
                }
            }
        }
    }
    let node_input_facts = model.node_input_facts(node.id)?;
    for suggestion in suggestions.into_iter() {
        if let Some(conseq) = self.try_body_axes_change(suggestion, true, &node_input_facts)? {
            let mut patch = TypedModelPatch::default();
            let mut inputs = tvec!();
            for outlet in &node.inputs {
                inputs.push(patch.tap_model(model, *outlet)?);
            }
            for change in conseq.wire_changes {
                // Only input-side changes are handled here: the affected outer
                // input must be a constant, which is reshaped and re-added.
                if let InOut::In(i) = change.0 {
                    let mut value = patch
                        .outlet_fact(inputs[i])?
                        .konst
                        .clone()
                        .context("Will only reshape constants")?
                        .into_tensor();
                    change.1.change_tensor(&mut value, false)?;
                    let konst_name = patch.node(inputs[i].node).name.clone();
                    inputs[i] = patch.add_const(konst_name, value)?;
                }
            }
            let wires = patch.wire_node(
                &node.name,
                conseq.substitute_op.unwrap_or_else(|| Box::new(self.clone())),
                &inputs,
            )?;
            for (ix, new) in wires.into_iter().enumerate() {
                patch.shunt_outside(model, OutletId::new(node.id, ix), new)?;
            }
            return Ok(Some(patch));
        }
    }
    Ok(None)
}
171
172 fn remove_outer_output_from_mappings(
173 mappings: &[OutputMapping<TDim>],
174 discarded: usize,
175 ) -> Vec<OutputMapping<TDim>> {
176 mappings
177 .iter()
178 .map(|m| OutputMapping {
179 scan: m.scan.map(|(slot, info)| (slot - (slot > discarded) as usize, info)),
180 last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
181 full_dim_hint: m.full_dim_hint.clone(),
182 state: m.state,
183 })
184 .collect()
185 }
186
187 fn declutter_const_input(
188 &self,
189 _session: &mut OptimizerSession,
190 model: &TypedModel,
191 node: &TypedNode,
192 ) -> TractResult<Option<TypedModelPatch>> {
193 let inputs = model.node_input_facts(node.id)?;
194 for (slot, mapping) in self.input_mapping.iter().enumerate() {
195 if let InputMapping::Full = mapping
196 && let Some(konst) = inputs[slot].konst.as_ref()
197 {
198 let mut op = self.clone();
199 let src = op.body.inputs[slot];
200 op.body.inputs.remove(slot);
201 op.body.nodes[src.node].inputs.clear();
202 op.body.nodes[src.node].op = Box::new(Const::new(konst.clone())?);
203 op.input_mapping.remove(slot);
204 let mut inputs = node.inputs.clone();
205 inputs.remove(slot);
206 return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
207 }
208 }
209 Ok(None)
210 }
211
212 fn declutter_discard_unused_input(
213 &self,
214 _session: &mut OptimizerSession,
215 model: &TypedModel,
216 node: &TypedNode,
217 ) -> TractResult<Option<TypedModelPatch>> {
218 for (slot, input) in self.body.input_outlets()?.iter().enumerate() {
219 let source_node = self.body.node(input.node);
220 if source_node.outputs[0].successors.len() == 0
221 && !self.body.output_outlets()?.contains(input)
222 {
223 let mut new_inputs = node.inputs.clone();
224 new_inputs.remove(slot);
225 let mut new_mappings: Vec<_> = self.input_mapping.clone();
226 new_mappings.remove(slot);
227 let mut model_inputs = self.body.input_outlets()?.to_vec();
228 model_inputs.remove(slot);
229 let mut body = self.body.clone();
230 let mut patch = TypedModelPatch::default();
231 patch.obliterate(source_node.id)?;
232 patch.apply(&mut body)?;
233 body.set_input_outlets(&model_inputs)?;
234 body.declutter()?;
235 let op =
236 Self { body, input_mapping: new_mappings, decluttered: true, ..self.clone() };
237 return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
238 }
239 }
240 Ok(None)
241 }
242
/// Removes the first outer output that has no successor and is not an
/// output of the enclosing model.
///
/// The slot is cleared from whichever mapping referenced it (scan or
/// last-value), the remaining slot numbers are shifted down, and every other
/// outer output is rewired to the replacement node.
fn declutter_discard_useless_outer_output(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (ix, o) in node.outputs.iter().enumerate() {
        if o.successors.len() == 0
            && !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
        {
            // Detach slot `ix` from any mapping that points at it.
            let mappings = self
                .output_mapping
                .iter()
                .map(|m| OutputMapping {
                    scan: m.scan.filter(|(slot, _info)| *slot != ix),
                    last_value_slot: m.last_value_slot.filter(|s| *s != ix),
                    full_dim_hint: m.full_dim_hint.clone(),
                    state: m.state,
                })
                .collect::<Vec<_>>();
            let mut op = self.clone();
            // Renumber the slots above `ix` so outer outputs stay contiguous.
            op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
            let mut patch = TypedModelPatch::default();
            let inputs = node
                .inputs
                .iter()
                .map(|&i| patch.tap_model(model, i))
                .collect::<TractResult<Vec<_>>>()?;
            let wires = patch.wire_node(&*node.name, op, &inputs)?;
            for oix in 0..node.outputs.len() {
                if oix != ix {
                    // Old slots above `ix` map to the new slot one below.
                    patch.shunt_outside(
                        model,
                        OutletId::new(node.id, oix),
                        wires[oix - (oix > ix) as usize],
                    )?;
                }
            }
            return Ok(Some(patch));
        }
    }
    Ok(None)
}
286
287 fn declutter_discard_empty_output_mapping_with_body_output(
288 &self,
289 _session: &mut OptimizerSession,
290 model: &TypedModel,
291 node: &TypedNode,
292 ) -> TractResult<Option<TypedModelPatch>> {
293 for (ix, om) in self.output_mapping.iter().enumerate() {
294 if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
295 let mut new_op = self.clone();
296 new_op.output_mapping.remove(ix);
297 new_op.body.outputs.remove(ix);
298 new_op.decluttered = false;
299 return Ok(Some(TypedModelPatch::replace_single_op(
300 model,
301 node,
302 &node.inputs,
303 new_op,
304 )?));
305 }
306 }
307 Ok(None)
308 }
309
/// Moves an op that directly consumes a scanned input out of the body, so it
/// is applied to the outer input before the scan.
///
/// A scanned input is a candidate only if every consumer of its body source
/// takes nothing but that source and constants. For a single-output
/// consumer whose axes mapping sends the scanned axis to exactly one axis of
/// its first output, the op is rewired outside (constants are re-created
/// outside), its result becomes a new scanned input appended to the scan,
/// and a new body source replaces the op's output inside the body.
/// Returns `Ok(None)` when no consumer qualifies.
fn declutter_pull_batcheable_input(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    'candidate: for (slot, input) in self.input_mapping.iter().enumerate() {
        if let Some(scan_info) = input.as_scan() {
            let scan_source = self.body.input_outlets()?[slot];
            let scan_source_node = self.body.node(scan_source.node);
            for mut succ in &scan_source_node.outputs[0].successors {
                // Any non-constant sibling input disqualifies this whole scan input.
                for &succ_input in &self.body.node(succ.node).inputs {
                    if succ_input != scan_source
                        && self.body.outlet_fact(succ_input)?.konst.is_none()
                    {
                        continue 'candidate;
                    }
                }
                if self.body.node(succ.node).outputs.len() != 1 {
                    continue;
                }
                let mut new_body = self.body.clone();
                if let Some(einsum) = new_body.node(succ.node).op_as::<EinSum>()
                    && let Some(patch) = einsum
                        .propagate_axis(
                            &new_body,
                            new_body.node(succ.node),
                            InOut::In(succ.slot),
                            scan_info.axis,
                        )
                        .context("building axis propagating patch")?
                {
                    patch.apply(&mut new_body)?;
                    new_body.compute_const_facts()?;
                    let new_body_scan_input = new_body.input_outlets()?[slot];
                    // Continue with the last successor of the scan input in the
                    // patched body.
                    // NOTE(review): presumably the propagation patch appends its
                    // replacement node last — confirm against `propagate_axis`.
                    succ = new_body.node(new_body_scan_input.node).outputs[0]
                        .successors
                        .last()
                        .unwrap();
                }

                let axes_mapping = {
                    let (input_facts, output_facts) =
                        new_body.node_facts(new_body.node(succ.node).id)?;
                    new_body.node(succ.node).op.axes_mapping(&input_facts, &output_facts)?
                };
                let axis_info = axes_mapping.axis((InOut::In(succ.slot), scan_info.axis))?;
                // Proceed only if the scanned axis maps to exactly one axis of output 0.
                if let &[axis_after] = &*axis_info.outputs[0] {
                    let mut outside_patch = TypedModelPatch::new(format!(
                        "Outer patch for input extraction of {}",
                        new_body.node(succ.node)
                    ));
                    let mut patch_inputs = node
                        .inputs
                        .iter()
                        .map(|&i| outside_patch.tap_model(model, i))
                        .collect::<TractResult<TVec<_>>>()?;
                    let mut extracted_op_inputs = tvec!();
                    for (ix, outlet) in new_body.node(succ.node).inputs.iter().enumerate() {
                        let wire = if ix == succ.slot {
                            patch_inputs[slot]
                        } else if let Some(konst) =
                            new_body.outlet_fact(*outlet)?.konst.as_ref()
                        {
                            outside_patch.add_const(
                                format!(
                                    "{}.extracted.{}",
                                    node.name,
                                    new_body.node(outlet.node).name
                                ),
                                konst.clone(),
                            )?
                        } else {
                            // NOTE(review): the constness check above ran on the
                            // original body; this relies on it still holding for
                            // the (possibly patched) node in `new_body` — confirm.
                            unreachable!();
                        };
                        extracted_op_inputs.push(wire);
                    }
                    let new_input_wire = outside_patch.wire_node(
                        format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
                        new_body.node(succ.node).op.clone(),
                        &extracted_op_inputs,
                    )?[0];
                    patch_inputs.push(new_input_wire);
                    let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                    // Inside the body the new input has the chunk size along the scanned axis.
                    let mut new_input_inner_fact = new_input_outer_fact.clone();
                    new_input_inner_fact.shape.set(axis_after, scan_info.chunk.abs().to_dim());

                    let mut new_body = new_body.clone();
                    let new_source_wire = new_body.add_source(
                        format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
                        new_input_inner_fact,
                    )?;
                    let mut inner_patch = TypedModelPatch::new(format!(
                        "Inner body patch for extraction of {}",
                        new_body.node(succ.node)
                    ));
                    let new_source_wire_in_patch =
                        inner_patch.tap_model(&new_body, new_source_wire)?;
                    // Consumers of the extracted op now read the new body source.
                    inner_patch
                        .shunt_outside(
                            &new_body,
                            OutletId::new(succ.node, 0),
                            new_source_wire_in_patch,
                        )
                        .with_context(|| "patching inner model")?;
                    inner_patch.apply(&mut new_body)?;

                    let mut input_mapping = self.input_mapping.clone();
                    input_mapping.push(InputMapping::Scan(ScanInfo {
                        axis: axis_after,
                        chunk: scan_info.chunk,
                    }));

                    let new_op = Self {
                        input_mapping,
                        decluttered: false,
                        body: new_body,
                        ..self.clone()
                    };
                    let output_wires =
                        outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                    for w in output_wires {
                        outside_patch
                            .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                            .with_context(|| "patching outer model")?;
                    }
                    return Ok(Some(outside_patch));
                }
            }
        }
    }
    Ok(None)
}
446
447 fn declutter_pull_constant_outputs(
448 &self,
449 _session: &mut OptimizerSession,
450 model: &TypedModel,
451 node: &TypedNode,
452 ) -> TractResult<Option<TypedModelPatch>> {
453 for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
454 if let Some(slot) = mapping.last_value_slot
455 && let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone()
456 {
457 let inner_node = self.body.output_outlets()?[model_output_ix].node;
458 let inner_node = self.body.node(inner_node);
459 let mut patch = TypedModelPatch::new(format!("Extract const node {inner_node}"));
460 let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
461 patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
462 return Ok(Some(patch));
463 }
464 }
465 Ok(None)
466 }
467
/// Moves the op that produces a scanned output out of the body, so it is
/// applied to the scan's outputs after the loop.
///
/// A scanned output is skipped when its emitter outlet has successors in the
/// body, is a body input, is a state, or has a chunk greater than one. The
/// emitter's inputs are exported from the scan instead (constants as
/// last-value outputs, others as scanned outputs along the tracked axis),
/// and the emitter op is rewired outside on those exports. Returns
/// `Ok(None)` when an emitter input cannot be tracked to a single axis.
fn declutter_pull_batcheable_output(
    &self,
    _session: &mut OptimizerSession,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    for (mapping_ix, mapping) in self.output_mapping.iter().enumerate() {
        if let Some((_, scan_info)) = mapping.scan {
            let emitter_outlet = self.body.output_outlets()?[mapping_ix];
            if self.body.node(emitter_outlet.node).outputs[emitter_outlet.slot].successors.len()
                > 0
                || self.body.inputs.contains(&emitter_outlet)
                || mapping.state
                || mapping.scan.map(|(_slot, i)| i.chunk > 1).unwrap_or(true)
            {
                continue;
            }
            let mut new_body = self.body.clone();
            if let Some(einsum) = new_body.node(emitter_outlet.node).op_as::<EinSum>()
                && let Some(patch) = einsum
                    .propagate_axis(
                        &new_body,
                        new_body.node(emitter_outlet.node),
                        InOut::Out(0),
                        scan_info.axis,
                    )
                    .context("building axis propagating patch")?
            {
                patch.apply(&mut new_body)?;
                new_body.prop_consts()?;
            }
            // Re-read the emitter: the patch above may have replaced it.
            let emitter_outlet = new_body.output_outlets()?[mapping_ix];
            let invariants = {
                let (input_facts, output_facts) = new_body.node_facts(emitter_outlet.node)?;
                new_body
                    .node(emitter_outlet.node)
                    .op
                    .axes_mapping(&input_facts, &output_facts)?
            };
            let axis_tracking =
                invariants.axis((InOut::Out(emitter_outlet.slot), scan_info.axis))?;
            // The scanned axis must appear exactly once in every emitter output.
            rule_if!(axis_tracking.outputs.iter().all(|o| o.len() == 1));
            let mut new_output_mapping = self.output_mapping.clone();
            // Next free outer slot; new exports are appended after the existing ones.
            let mut new_scan_outputs = node.outputs.len();
            let mut outer_slots = vec![];

            for (input_slot, input) in
                new_body.node(emitter_outlet.node).inputs.clone().iter().enumerate()
            {
                // Make sure the emitter input is a body output, adding one if needed.
                if new_body.outputs.iter().all(|o| o != input) {
                    new_output_mapping.push(OutputMapping::default());
                    new_body.outputs.push(*input);
                }
                let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
                // Shadows the outer loop's `mapping` inside this inner loop only.
                let mapping = &mut new_output_mapping[body_output_id];
                let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                    // Constant input: export its last value.
                    if mapping.last_value_slot.is_none() {
                        mapping.last_value_slot = Some(new_scan_outputs);
                        new_scan_outputs += 1;
                    }
                    mapping.last_value_slot.unwrap()
                } else if let &[axis] = &*axis_tracking.inputs[input_slot] {
                    // Non-constant input: export it scanned along its tracked axis.
                    if mapping.scan.is_none() {
                        mapping.scan =
                            Some((new_scan_outputs, ScanInfo { axis, chunk: scan_info.chunk }));
                        new_scan_outputs += 1;
                    }
                    mapping.scan.unwrap().0
                } else {
                    return Ok(None);
                };
                outer_slots.push(outer_slot);
            }
            let mut outside_patch = TypedModelPatch::new(format!(
                "Outside patch for output extraction of {}",
                new_body.node(emitter_outlet.node)
            ));
            let inputs = node
                .inputs
                .iter()
                .map(|&i| outside_patch.tap_model(model, i))
                .collect::<TractResult<TVec<_>>>()?;
            let new_op = Self {
                output_mapping: new_output_mapping,
                decluttered: false,
                body: new_body.clone(),
                ..self.clone()
            };
            let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
            // `mapping` is the outer loop's mapping again here; its `scan` is `Some`.
            let output = mapping.scan.unwrap();
            let inputs =
                outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
            let wire = outside_patch.wire_node(
                &new_body.node(emitter_outlet.node).name,
                new_body.node(emitter_outlet.node).op.clone(),
                &inputs,
            )?[0];
            outside_patch.shunt_outside(model, OutletId::new(node.id, output.0), wire)?;
            // All other original outputs keep their slot on the new scan node.
            for output_slot in 0..node.outputs.len() {
                if output_slot != output.0 {
                    outside_patch.shunt_outside(
                        model,
                        OutletId::new(node.id, output_slot),
                        OutletId::new(scan_outputs[0].node, output_slot),
                    )?;
                }
            }
            return Ok(Some(outside_patch));
        }
    }
    Ok(None)
}
582
583 fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
584 let input_state_outlets = self
585 .input_mapping
586 .iter()
587 .zip(self.body.input_outlets()?.iter())
588 .filter(|(m, _)| m.is_state())
589 .map(|(_, o)| o);
590 let output_state_outlets = self
591 .output_mapping
592 .iter()
593 .zip(self.body.output_outlets()?.iter())
594 .filter(|(m, _)| m.state)
595 .map(|(_, o)| o);
596 Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
597 }
598
599 fn body_locked_outlets(&self, node_input_facts: &[&TypedFact]) -> TractResult<TVec<OutletId>> {
600 let input_outlets =
601 self.body.input_outlets()?.iter().enumerate().filter_map(|(slot, o)| {
602 if node_input_facts[slot].konst.is_none() { Some(o) } else { None }
603 });
604 let output_outlets = self
605 .output_mapping
606 .iter()
607 .zip(self.body.output_outlets()?.iter())
608 .filter(|(m, _)| !m.invisible())
609 .map(|(_, o)| o);
610 Ok(input_outlets.chain(output_outlets).cloned().collect())
611 }
612
/// Attempts to apply `change` inside the body and reports what it implies
/// for the scan node.
///
/// When `locked_interface` is set, the outlets from `body_locked_outlets`
/// are passed to the body-level `change_axes` as locked; otherwise none are.
/// On success the result carries a substitute `Scan` (patched and compacted
/// body, updated scan axes, `decluttered` reset) and the changes to apply on
/// the outer wires.
///
/// The `rule_if_some!` invocations guard the rule on an optional value.
// NOTE(review): presumably they return `Ok(None)` when the value is `None`
// (body rejects the change, or a scan axis cannot be transformed) — confirm
// against the macro definition.
fn try_body_axes_change(
    &self,
    change: AxisChange,
    locked_interface: bool,
    node_input_facts: &[&TypedFact],
) -> TractResult<Option<AxisChangeConsequence>> {
    self.body.check_consistency()?;
    let locked_outlets = self.body_locked_outlets(node_input_facts)?;
    let mut explored: HashSet<AxisChange> = Default::default();
    rule_if_some!(
        (body_patch, body_changed_wires) = crate::optim::change_axes::change_axes(
            &self.body,
            &change,
            if locked_interface { &locked_outlets } else { &[] },
            &self.body_bounds()?,
            &mut explored,
        )?
    );
    let mut body = self.body.clone();
    body_patch.apply(&mut body)?;
    body.compact()?;
    let mut wire_changes = tvec!();
    let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
    for (slot, m) in input_mapping.iter_mut().enumerate() {
        // Look up the change reported for this body input, if any.
        if let Some(change) = body_changed_wires
            .iter()
            .find(|(iface, _change)| iface == &InOut::In(slot))
            .map(|pair| pair.1.clone())
        {
            wire_changes.push((InOut::In(slot), change.clone()));
            // A scanned input must follow its scan axis through the change.
            if let InputMapping::Scan(info) = m {
                rule_if_some!(axis = change.transform_axis(info.axis));
                info.axis = axis;
            };
        }
    }
    let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
    for (ix, m) in output_mapping.iter_mut().enumerate() {
        // Look up the change reported for this body output, if any.
        if let Some(change) = body_changed_wires
            .iter()
            .find(|(iface, _change)| iface == &InOut::Out(ix))
            .map(|pair| pair.1.clone())
        {
            // Body output index `ix` is translated to outer slots here.
            if let Some((slot, info)) = m.scan.as_mut() {
                rule_if_some!(new_axis = change.transform_axis(info.axis));
                info.axis = new_axis;
                wire_changes.push((InOut::Out(*slot), change.clone()));
            }
            if let Some(slot) = m.last_value_slot {
                wire_changes.push((InOut::Out(slot), change.clone()));
            }
        };
    }
    body.check_consistency()?;
    let op = Some(Box::new(Scan {
        body,
        input_mapping,
        output_mapping,
        decluttered: false,
        ..self.clone()
    }) as _);
    Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
}
676}
677
678impl Op for Scan {
679 fn name(&self) -> StaticName {
680 "Scan".into()
681 }
682
683 fn info(&self) -> TractResult<Vec<String>> {
684 let mut lines = vec![];
685 for (ix, im) in self.input_mapping.iter().enumerate() {
686 lines.push(format!("Model input #{ix}: {im:?}"));
687 }
688 for (ix, om) in self.output_mapping.iter().enumerate() {
689 lines.push(format!("Model output #{ix}: {om:?}"));
690 }
691 lines.push(format!("skip:{} reset_every_turn:{:?}", self.skip, self.reset_every_turn));
692 Ok(lines)
693 }
694
695 fn validation(&self) -> Validation {
696 Validation::Rounding
697 }
698
699 op_as_typed_op!();
700}
701
impl EvalOp for Scan {
    /// `Scan` always reports itself as stateful.
    fn is_stateless(&self) -> bool {
        false
    }
    /// Creates the evaluation state by lowering to the codegen op (without
    /// optimizing the body) and delegating to its `state` method.
    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
        self.to_codegen_op(false)?.state(session, node_id)
    }
}
710
711impl TypedOp for Scan {
712 as_op!();
713
714 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
715 anyhow::ensure!(inputs.len() == self.body.inputs.len());
716 anyhow::ensure!(self.input_mapping.len() == self.body.inputs.len());
717 anyhow::ensure!(
718 self.input_mapping.iter().filter(|m| m.is_state()).count()
719 == self.output_mapping.iter().filter(|m| m.state).count()
720 );
721 for (i, o) in
722 self.input_mapping.iter().enumerate().filter(|(_, m)| m.is_state()).map(|(i, _)| i).zip(
723 self.output_mapping.iter().enumerate().filter(|(_, m)| m.state).map(|(o, _)| o),
724 )
725 {
726 let ifact = self.body.outlet_fact(self.body.inputs[i])?;
727 let ofact = self.body.outlet_fact(self.body.outputs[o])?;
728 anyhow::ensure!(
729 ifact == ofact,
730 "inconsistent fact: body input {i} is {ifact:?} and body output {o} is {ofact:?}\n{}",
731 self.body
732 )
733 }
734 let mut outputs = tvec!();
735 let iters = super::iteration_count(&self.input_mapping, inputs).context("No scan input")?;
736 for (ix, output) in self.output_mapping.iter().enumerate() {
737 let fact = self.body.output_fact(ix)?;
738 if let Some((slot, info)) = output.scan {
739 let mut shape = fact.shape.clone();
740 let scanning_dim =
741 output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
742 shape.set(info.axis, scanning_dim);
743 outputs.push((slot, fact.datum_type.fact(shape)));
744 }
745 if let Some(slot) = output.last_value_slot {
746 outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
747 }
748 }
749 outputs.sort_by_key(|a| a.0);
750 anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
751 let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
752 Ok(outputs)
753 }
754
755 fn axes_mapping(
756 &self,
757 inputs: &[&TypedFact],
758 outputs: &[&TypedFact],
759 ) -> TractResult<AxesMapping> {
760 let mut mappings = vec![];
761 let body_invs = self.body.axes_mapping().with_context(|| "Computing body axes mapping")?;
762 for body_axis in body_invs.iter_all_axes() {
763 let mut info = Axis::new(body_axis.repr, inputs.len(), outputs.len());
764 info.inputs.clone_from(&body_axis.inputs);
765 for (ix, output_mapping) in self.output_mapping.iter().enumerate() {
766 let mut slots = vec![];
767 if let Some((slot, _scan)) = output_mapping.scan {
768 slots.push(slot);
769 }
770 if let Some(slot) = output_mapping.last_value_slot {
771 slots.push(slot);
772 }
773 for slot in slots {
774 info.outputs[slot].clone_from(&body_axis.outputs[ix]);
775 }
776 }
777 if info.inputs.iter().any(|i| i.len() > 0) || info.outputs.iter().any(|i| i.len() > 0) {
778 mappings.push(info);
779 }
780 }
781 AxesMapping::new(inputs.len(), outputs.len(), mappings)
782 }
783
784 fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
785 let mut suggestions = tvec!();
786 for (slot, input) in self.input_mapping.iter().enumerate() {
787 if let InputMapping::Scan(info) = input
788 && info.axis != 0
789 {
790 suggestions.push((InOut::In(slot), AxisOp::Move(info.axis, 0)))
791 }
792 }
793 for output in &self.output_mapping {
794 if let Some((slot, scan)) = output.scan
795 && scan.axis != 0
796 {
797 suggestions.push((InOut::Out(slot), AxisOp::Move(scan.axis, 0)))
798 }
799 }
800 Ok(suggestions)
801 }
802
803 fn change_axes(
804 &self,
805 model: &TypedModel,
806 node: &TypedNode,
807 io: InOut,
808 change: &AxisOp,
809 ) -> TractResult<Option<AxisChangeConsequence>> {
810 trace!("Propagating through {node}: {io:?} {change:?}");
811 let body_leading_outlet = match io {
812 InOut::In(ix) => self.body.input_outlets()?[ix],
813 InOut::Out(slot) => {
814 let output = self
815 .output_mapping
816 .iter()
817 .position(|im| {
818 im.scan.map(|(slot, _i)| slot) == Some(slot)
819 || im.last_value_slot == Some(slot)
820 })
821 .unwrap();
822 self.body.output_outlets()?[output]
823 }
824 };
825 let axis_change = AxisChange { outlet: body_leading_outlet, op: change.clone() };
826 let node_input_facts = model.node_input_facts(node.id)?;
827 let result = self
828 .try_body_axes_change(axis_change, false, &node_input_facts)
829 .with_context(|| "Attemping to run change through scan body".to_string())?;
830 if result.is_some() {
831 trace!("{node} accepted axis change");
832 } else {
833 trace!("{node} rejected axis change");
834 }
835 Ok(result)
836 }
837
838 fn declutter_with_session(
839 &self,
840 session: &mut OptimizerSession,
841 model: &TypedModel,
842 node: &TypedNode,
843 ) -> TractResult<Option<TypedModelPatch>> {
844 macro_rules! pass {
845 ($func:ident) => {
846 if let Some(mut r) = self
847 .$func(session, model, node)
848 .with_context(|| format!("{}", stringify!($func)))?
849 {
850 trace!(stringify!($func));
851 r.push_context(stringify!($func));
852 return Ok(Some(r));
853 }
854 };
855 }
856 pass!(declutter_single_loop);
857 pass!(declutter_const_input);
858 pass!(declutter_discard_unused_input);
859 pass!(declutter_discard_useless_outer_output);
860 pass!(declutter_discard_empty_output_mapping_with_body_output);
861 pass!(declutter_body);
862 pass!(declutter_body_axes);
863 pass!(declutter_pull_constant_outputs);
864 pass!(declutter_pull_batcheable_input);
865 pass!(declutter_pull_batcheable_output);
866 Ok(None)
867 }
868
869 fn concretize_dims(
870 &self,
871 _source: &TypedModel,
872 node: &TypedNode,
873 target: &mut TypedModel,
874 mapping: &HashMap<OutletId, OutletId>,
875 values: &SymbolValues,
876 ) -> TractResult<TVec<OutletId>> {
877 let inputs = node.inputs.iter().map(|o| mapping[o]).collect::<TVec<_>>();
878 let op = Self {
879 output_mapping: self
880 .output_mapping
881 .iter()
882 .map(|om| om.concretize_dims(values))
883 .collect::<TractResult<Vec<_>>>()?,
884 body: self.body.concretize_dims(values)?,
885 ..self.clone()
886 };
887 target.wire_node(&node.name, op, &inputs)
888 }
889
890 fn codegen(
891 &self,
892 model: &TypedModel,
893 node: &TypedNode,
894 ) -> TractResult<Option<TypedModelPatch>> {
895 Ok(Some(TypedModelPatch::replace_single_op(
896 model,
897 node,
898 &node.inputs,
899 self.to_codegen_op(true)?,
900 )?))
901 }
902}