1use std::collections::HashSet;
2
3use crate::ops::einsum::EinSum;
4use crate::ops::konst::Const;
5use crate::optim::OptimizerSession;
6
7use super::optimized::{OptScan, ScanOpParams};
8use tract_data::internal::*;
9use tract_data::itertools::izip;
10
11use super::*;
12
/// Typed scan operator: an inner `body` model together with the mappings that
/// connect the body inputs and outputs to the inputs and outputs of the outer node.
#[derive(Debug, Clone, Default)]
pub struct Scan {
    /// Forwarded untouched to `ScanOpParams::new` by `to_codegen_op`.
    /// NOTE(review): presumably a number of leading iterations to skip — confirm.
    pub skip: usize,
    /// Forwarded to `ScanOpParams::new`. When set, `declutter_single_loop` may
    /// inline a single-iteration scan even if some inputs are state inputs.
    pub reset_every_turn: bool,
    /// When set, `declutter_single_loop` may inline a single-iteration scan even
    /// if some inputs are state inputs.
    /// NOTE(review): not forwarded by `to_codegen_op` — confirm this is intended.
    pub external_state: bool,
    /// Inner model wrapped by this operator.
    pub body: TypedModel,
    /// Set to `true` by `declutter_body` once the body went through the optimizer
    /// session; rewrites that alter the body reset it to `false`.
    pub decluttered: bool,
    /// One entry per body input (checked by `Scan::new` and `output_facts`).
    pub input_mapping: Vec<InputMapping>,
    /// One entry per body output (checked by `Scan::new`).
    pub output_mapping: Vec<OutputMapping<TDim>>,
}
33
/// Two `Scan` operators never compare equal, not even a value with itself.
impl PartialEq for Scan {
    /// Always returns `false`; the other operand is ignored.
    fn eq(&self, _other: &Self) -> bool {
        // NOTE(review): presumably avoids comparing whole body models — confirm intent.
        false
    }
}
// Marker impl only. Note that `eq` above is not reflexive, which `Eq` formally expects.
impl Eq for Scan {}
40
41impl Scan {
42 pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<OptScan> {
43 let mut model = self.body.clone();
44 if optimize_inner {
45 model = model.into_optimized()?;
46 }
47 let plan = SimplePlan::new(model)?;
48
49 Ok(OptScan::new(Arc::new(ScanOpParams::new(
50 self.skip,
51 self.reset_every_turn,
52 plan,
53 self.input_mapping.clone(),
54 self.output_mapping.clone(),
55 ))))
56 }
57
58 pub fn new(
59 body: TypedModel,
60 input_mapping: Vec<InputMapping>,
61 output_mapping: Vec<OutputMapping<TDim>>,
62 skip: usize,
63 ) -> TractResult<Scan> {
64 body.check_consistency()?;
65 ensure!(input_mapping.len() == body.input_outlets()?.len());
66 ensure!(output_mapping.len() == body.output_outlets()?.len());
67 Ok(Scan {
68 skip,
69 reset_every_turn: false,
70 external_state: false,
71 body,
72 decluttered: false,
73 input_mapping,
74 output_mapping,
75 })
76 }
77
78 pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
79 self.to_codegen_op(false).unwrap().iteration_count(inputs)
80 }
81
82 fn declutter_body(
83 &self,
84 session: &mut OptimizerSession,
85 model: &TypedModel,
86 node: &TypedNode,
87 ) -> TractResult<Option<TypedModelPatch>> {
88 rule_if!(!self.decluttered);
89 let mut new = self.clone();
90 let mut body = self.body.clone();
91 session.optimize(&mut body)?;
92 new.body = body;
93 new.decluttered = true;
94 Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
95 }
96
97 fn declutter_single_loop(
98 &self,
99 _session: &mut OptimizerSession,
100 model: &TypedModel,
101 node: &TypedNode,
102 ) -> TractResult<Option<TypedModelPatch>> {
103 let inputs = model.node_input_facts(node.id)?;
104 let iters =
105 super::iteration_count(&self.input_mapping, &inputs).context("No scan input")?;
106 rule_if!(iters.is_one());
107 rule_if!(
121 self.reset_every_turn
122 || self.external_state
123 || !self.input_mapping.iter().any(InputMapping::is_state)
124 );
125 let mut patch = TypedModelPatch::new("Inline single loop scan");
126 patch.model = self.body.clone();
127 for (outer_wire, inner_wire) in izip!(&node.inputs, &self.body.inputs) {
128 patch.taps.insert(*inner_wire, *outer_wire);
129 }
130 for (inner_wire, mapping) in izip!(&self.body.outputs, &self.output_mapping) {
131 if let Some((slot, _)) = mapping.scan {
132 patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
133 }
134 if let Some(slot) = mapping.last_value_slot {
135 patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
136 }
137 }
138 Ok(Some(patch))
139 }
140
141 fn declutter_body_axes(
142 &self,
143 _session: &mut OptimizerSession,
144 model: &TypedModel,
145 node: &TypedNode,
146 ) -> TractResult<Option<TypedModelPatch>> {
147 let mut suggestions = vec![];
148 for n in self.body.eval_order()? {
149 let node = self.body.node(n);
150 for suggestion in node.op.suggested_axis_changes()? {
151 let outlet = suggestion.0.as_outlet(node);
152 suggestions.push(AxisChange { outlet, op: suggestion.1 })
153 }
154 for (slot, fact) in node.outputs.iter().enumerate() {
155 for (ix, dim) in fact.fact.shape.iter().enumerate() {
156 if dim.is_one() {
157 suggestions.push(AxisChange {
158 outlet: OutletId::new(n, slot),
159 op: AxisOp::Rm(ix),
160 });
161 }
162 }
163 }
164 }
165 let node_input_facts = model.node_input_facts(node.id)?;
166 for suggestion in suggestions.into_iter() {
167 if let Some(conseq) = self.try_body_axes_change(suggestion, true, &node_input_facts)? {
168 let mut patch = TypedModelPatch::default();
169 let mut inputs = tvec!();
170 for outlet in &node.inputs {
171 inputs.push(patch.tap_model(model, *outlet)?);
172 }
173 for change in conseq.wire_changes {
174 if let InOut::In(i) = change.0 {
175 let mut value = patch
176 .outlet_fact(inputs[i])?
177 .konst
178 .clone()
179 .context("Will only reshape constants")?
180 .into_tensor();
181 change.1.change_tensor(&mut value, false)?;
182 let konst_name = patch.node(inputs[i].node).name.clone();
183 inputs[i] = patch.add_const(konst_name, value)?;
184 }
185 }
186 let wires = patch.wire_node(
187 &node.name,
188 conseq.substitute_op.unwrap_or_else(|| Box::new(self.clone())),
189 &inputs,
190 )?;
191 for (ix, new) in wires.into_iter().enumerate() {
192 patch.shunt_outside(model, OutletId::new(node.id, ix), new)?;
193 }
194 return Ok(Some(patch));
195 }
196 }
197 Ok(None)
198 }
199
200 fn remove_outer_output_from_mappings(
201 mappings: &[OutputMapping<TDim>],
202 discarded: usize,
203 ) -> Vec<OutputMapping<TDim>> {
204 mappings
205 .iter()
206 .map(|m| OutputMapping {
207 scan: m.scan.map(|(slot, info)| (slot - (slot > discarded) as usize, info)),
208 last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
209 full_dim_hint: m.full_dim_hint.clone(),
210 state: m.state,
211 })
212 .collect()
213 }
214
215 fn declutter_const_input(
216 &self,
217 _session: &mut OptimizerSession,
218 model: &TypedModel,
219 node: &TypedNode,
220 ) -> TractResult<Option<TypedModelPatch>> {
221 let inputs = model.node_input_facts(node.id)?;
222 for (slot, mapping) in self.input_mapping.iter().enumerate() {
223 if let InputMapping::Full = mapping
224 && let Some(konst) = inputs[slot].konst.as_ref()
225 {
226 let mut op = self.clone();
227 let src = op.body.inputs[slot];
228 op.body.inputs.remove(slot);
229 op.body.nodes[src.node].inputs.clear();
230 op.body.nodes[src.node].op = Box::new(Const::new(konst.clone())?);
231 op.input_mapping.remove(slot);
232 let mut inputs = node.inputs.clone();
233 inputs.remove(slot);
234 return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
235 }
236 }
237 Ok(None)
238 }
239
240 fn declutter_discard_unused_input(
241 &self,
242 _session: &mut OptimizerSession,
243 model: &TypedModel,
244 node: &TypedNode,
245 ) -> TractResult<Option<TypedModelPatch>> {
246 for (slot, input) in self.body.input_outlets()?.iter().enumerate() {
247 let source_node = self.body.node(input.node);
248 if source_node.outputs[0].successors.len() == 0
249 && !self.body.output_outlets()?.contains(input)
250 {
251 let mut new_inputs = node.inputs.clone();
252 new_inputs.remove(slot);
253 let mut new_mappings: Vec<_> = self.input_mapping.clone();
254 new_mappings.remove(slot);
255 let mut model_inputs = self.body.input_outlets()?.to_vec();
256 model_inputs.remove(slot);
257 let mut body = self.body.clone();
258 let mut patch = TypedModelPatch::default();
259 patch.obliterate(source_node.id)?;
260 patch.apply(&mut body)?;
261 body.set_input_outlets(&model_inputs)?;
262 body.declutter()?;
263 let op =
264 Self { body, input_mapping: new_mappings, decluttered: true, ..self.clone() };
265 return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
266 }
267 }
268 Ok(None)
269 }
270
271 fn declutter_discard_useless_outer_output(
272 &self,
273 _session: &mut OptimizerSession,
274 model: &TypedModel,
275 node: &TypedNode,
276 ) -> TractResult<Option<TypedModelPatch>> {
277 for (ix, o) in node.outputs.iter().enumerate() {
278 if o.successors.len() == 0
279 && !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
280 {
281 let mappings = self
282 .output_mapping
283 .iter()
284 .map(|m| OutputMapping {
285 scan: m.scan.filter(|(slot, _info)| *slot != ix),
286 last_value_slot: m.last_value_slot.filter(|s| *s != ix),
287 full_dim_hint: m.full_dim_hint.clone(),
288 state: m.state,
289 })
290 .collect::<Vec<_>>();
291 let mut op = self.clone();
292 op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
293 let mut patch = TypedModelPatch::default();
294 let inputs = node
295 .inputs
296 .iter()
297 .map(|&i| patch.tap_model(model, i))
298 .collect::<TractResult<Vec<_>>>()?;
299 let wires = patch.wire_node(&*node.name, op, &inputs)?;
300 for oix in 0..node.outputs.len() {
301 if oix != ix {
302 patch.shunt_outside(
303 model,
304 OutletId::new(node.id, oix),
305 wires[oix - (oix > ix) as usize],
306 )?;
307 }
308 }
309 return Ok(Some(patch));
310 }
311 }
312 Ok(None)
313 }
314
315 fn declutter_discard_empty_output_mapping_with_body_output(
316 &self,
317 _session: &mut OptimizerSession,
318 model: &TypedModel,
319 node: &TypedNode,
320 ) -> TractResult<Option<TypedModelPatch>> {
321 for (ix, om) in self.output_mapping.iter().enumerate() {
322 if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
323 let mut new_op = self.clone();
324 new_op.output_mapping.remove(ix);
325 new_op.body.outputs.remove(ix);
326 new_op.decluttered = false;
327 return Ok(Some(TypedModelPatch::replace_single_op(
328 model,
329 node,
330 &node.inputs,
331 new_op,
332 )?));
333 }
334 }
335 Ok(None)
336 }
337
    /// Tries to move an operator that consumes a scanned input out of the loop.
    ///
    /// For each scanned input, the successors of the matching body source are
    /// examined. When one of them can be extracted, the operator is re-created in
    /// the outer model (fed by the outer scan input and by constants), its result
    /// becomes an additional scanned input of the new `Scan`, and inside the body
    /// the operator output is replaced by a new source.
    ///
    /// Returns the outer patch for the first successful extraction, `Ok(None)` if
    /// nothing can be extracted.
    fn declutter_pull_batcheable_input(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        'candidate: for (slot, input) in self.input_mapping.iter().enumerate() {
            if let Some(scan_info) = input.as_scan() {
                let scan_source = self.body.input_outlets()?[slot];
                let scan_source_node = self.body.node(scan_source.node);
                for mut succ in &scan_source_node.outputs[0].successors {
                    // Any successor with an input that is neither the scanned source
                    // nor a constant disqualifies the whole scanned input.
                    for &succ_input in &self.body.node(succ.node).inputs {
                        if succ_input != scan_source
                            && self.body.outlet_fact(succ_input)?.konst.is_none()
                        {
                            continue 'candidate;
                        }
                    }
                    // Only single-output operators are considered.
                    if self.body.node(succ.node).outputs.len() != 1 {
                        continue;
                    }
                    let mut new_body = self.body.clone();
                    // For an EinSum successor, first ask it for a patch propagating the
                    // scan axis; if one is produced, apply it to the body copy.
                    if let Some(einsum) = new_body.node(succ.node).op_as::<EinSum>()
                        && let Some(patch) = einsum
                            .propagate_axis(
                                &new_body,
                                new_body.node(succ.node),
                                InOut::In(succ.slot),
                                scan_info.axis,
                            )
                            .context("building axis propagating patch")?
                    {
                        patch.apply(&mut new_body)?;
                        new_body.compute_const_facts()?;
                        let new_body_scan_input = new_body.input_outlets()?[slot];
                        // Continue with the last successor of the scanned source in the
                        // patched body.
                        // NOTE(review): assumes the patch appends the replacement node
                        // last — confirm against `propagate_axis`.
                        succ = new_body.node(new_body_scan_input.node).outputs[0]
                            .successors
                            .last()
                            .unwrap();
                    }

                    let axes_mapping = {
                        let (input_facts, output_facts) =
                            new_body.node_facts(new_body.node(succ.node).id)?;
                        new_body.node(succ.node).op.axes_mapping(&input_facts, &output_facts)?
                    };
                    let axis_info = axes_mapping.axis((InOut::In(succ.slot), scan_info.axis))?;
                    // Extraction requires the scan axis to appear exactly once in the
                    // operator output.
                    if let &[axis_after] = &*axis_info.outputs[0] {
                        let mut outside_patch = TypedModelPatch::new(format!(
                            "Outer patch for input extraction of {}",
                            new_body.node(succ.node)
                        ));
                        let mut patch_inputs = node
                            .inputs
                            .iter()
                            .map(|&i| outside_patch.tap_model(model, i))
                            .collect::<TractResult<TVec<_>>>()?;
                        // Inputs of the extracted operator: the outer scan input at the
                        // scanned position, outer copies of the constants elsewhere.
                        let mut extracted_op_inputs = tvec!();
                        for (ix, outlet) in new_body.node(succ.node).inputs.iter().enumerate() {
                            let wire = if ix == succ.slot {
                                patch_inputs[slot]
                            } else if let Some(konst) =
                                new_body.outlet_fact(*outlet)?.konst.as_ref()
                            {
                                outside_patch.add_const(
                                    format!(
                                        "{}.extracted.{}",
                                        node.name,
                                        new_body.node(outlet.node).name
                                    ),
                                    konst.clone(),
                                )?
                            } else {
                                // NOTE(review): relies on the constant check performed on
                                // the original body still holding for `new_body`.
                                unreachable!();
                            };
                            extracted_op_inputs.push(wire);
                        }
                        let new_input_wire = outside_patch.wire_node(
                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
                            new_body.node(succ.node).op.clone(),
                            &extracted_op_inputs,
                        )?[0];
                        // The extracted result becomes one more outer input of the scan.
                        patch_inputs.push(new_input_wire);
                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                        // Inside the body the new source has the outer fact with the
                        // scanned axis set to the absolute chunk value.
                        let mut new_input_inner_fact = new_input_outer_fact.clone();
                        new_input_inner_fact.shape.set(axis_after, scan_info.chunk.abs().to_dim());

                        let mut new_body = new_body.clone();
                        let new_source_wire = new_body.add_source(
                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
                            new_input_inner_fact,
                        )?;
                        let mut inner_patch = TypedModelPatch::new(format!(
                            "Inner body patch for extraction of {}",
                            new_body.node(succ.node)
                        ));
                        let new_source_wire_in_patch =
                            inner_patch.tap_model(&new_body, new_source_wire)?;
                        // Consumers of the extracted operator now read the new source.
                        inner_patch
                            .shunt_outside(
                                &new_body,
                                OutletId::new(succ.node, 0),
                                new_source_wire_in_patch,
                            )
                            .with_context(|| "patching inner model")?;
                        inner_patch.apply(&mut new_body)?;

                        let mut input_mapping = self.input_mapping.clone();
                        input_mapping.push(InputMapping::Scan(ScanInfo {
                            axis: axis_after,
                            chunk: scan_info.chunk,
                        }));

                        // The body changed, so it has to be decluttered again.
                        let new_op = Self {
                            input_mapping,
                            decluttered: false,
                            body: new_body,
                            ..self.clone()
                        };
                        let output_wires =
                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                        for w in output_wires {
                            outside_patch
                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                                .with_context(|| "patching outer model")?;
                        }
                        return Ok(Some(outside_patch));
                    }
                }
            }
        }
        Ok(None)
    }
474
475 fn declutter_pull_constant_outputs(
476 &self,
477 _session: &mut OptimizerSession,
478 model: &TypedModel,
479 node: &TypedNode,
480 ) -> TractResult<Option<TypedModelPatch>> {
481 for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
482 if let Some(slot) = mapping.last_value_slot
483 && let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone()
484 {
485 let inner_node = self.body.output_outlets()?[model_output_ix].node;
486 let inner_node = self.body.node(inner_node);
487 let mut patch = TypedModelPatch::new(format!("Extract const node {inner_node}"));
488 let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
489 patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
490 return Ok(Some(patch));
491 }
492 }
493 Ok(None)
494 }
495
    /// Tries to move the operator producing a scanned output out of the loop.
    ///
    /// When the body node emitting a scanned output qualifies, its inputs are
    /// exposed as outputs of a new `Scan` (constants as last values, the others as
    /// scanned outputs) and the operator is re-created after the scan in the
    /// outer model.
    ///
    /// Returns the outer patch for the first successful extraction, `Ok(None)`
    /// otherwise.
    fn declutter_pull_batcheable_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (mapping_ix, mapping) in self.output_mapping.iter().enumerate() {
            if let Some((_, scan_info)) = mapping.scan {
                let emitter_outlet = self.body.output_outlets()?[mapping_ix];
                // Skip when the emitted value is also consumed inside the body, is a
                // body input itself, is a state, or is scanned with a chunk above one.
                if self.body.node(emitter_outlet.node).outputs[emitter_outlet.slot].successors.len()
                    > 0
                    || self.body.inputs.contains(&emitter_outlet)
                    || mapping.state
                    || mapping.scan.map(|(_slot, i)| i.chunk > 1).unwrap_or(true)
                {
                    continue;
                }
                let mut new_body = self.body.clone();
                // For an EinSum emitter, first ask it for a patch propagating the scan
                // axis; if one is produced, apply it to the body copy.
                if let Some(einsum) = new_body.node(emitter_outlet.node).op_as::<EinSum>()
                    && let Some(patch) = einsum
                        .propagate_axis(
                            &new_body,
                            new_body.node(emitter_outlet.node),
                            InOut::Out(0),
                            scan_info.axis,
                        )
                        .context("building axis propagating patch")?
                {
                    patch.apply(&mut new_body)?;
                    new_body.prop_consts()?;
                }
                // Re-read the emitter from the (possibly patched) body.
                let emitter_outlet = new_body.output_outlets()?[mapping_ix];
                let invariants = {
                    let (input_facts, output_facts) = new_body.node_facts(emitter_outlet.node)?;
                    new_body
                        .node(emitter_outlet.node)
                        .op
                        .axes_mapping(&input_facts, &output_facts)?
                };
                let axis_tracking =
                    invariants.axis((InOut::Out(emitter_outlet.slot), scan_info.axis))?;
                // The scan axis must appear exactly once in every emitter output.
                rule_if!(axis_tracking.outputs.iter().all(|o| o.len() == 1));
                let mut new_output_mapping = self.output_mapping.clone();
                // Next free outer slot of the rewritten scan.
                let mut new_scan_outputs = node.outputs.len();
                // Outer slot carrying each emitter input, in emitter input order.
                let mut outer_slots = vec![];

                for (input_slot, input) in
                    new_body.node(emitter_outlet.node).inputs.clone().iter().enumerate()
                {
                    // Make sure the emitter input is a body output.
                    if new_body.outputs.iter().all(|o| o != input) {
                        new_output_mapping.push(OutputMapping::default());
                        new_body.outputs.push(*input);
                    }
                    let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
                    // Shadows the outer `mapping` for the duration of this loop body only.
                    let mapping = &mut new_output_mapping[body_output_id];
                    let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                        // Constants are exported through a last value slot.
                        if mapping.last_value_slot.is_none() {
                            mapping.last_value_slot = Some(new_scan_outputs);
                            new_scan_outputs += 1;
                        }
                        mapping.last_value_slot.unwrap()
                    } else if let &[axis] = &*axis_tracking.inputs[input_slot] {
                        // Other inputs are exported as scanned outputs along the axis
                        // tracked to the scan axis.
                        if mapping.scan.is_none() {
                            mapping.scan =
                                Some((new_scan_outputs, ScanInfo { axis, chunk: scan_info.chunk }));
                            new_scan_outputs += 1;
                        }
                        mapping.scan.unwrap().0
                    } else {
                        // The scan axis does not map to exactly one axis of this input.
                        return Ok(None);
                    };
                    outer_slots.push(outer_slot);
                }
                let mut outside_patch = TypedModelPatch::new(format!(
                    "Outside patch for output extraction of {}",
                    new_body.node(emitter_outlet.node)
                ));
                let inputs = node
                    .inputs
                    .iter()
                    .map(|&i| outside_patch.tap_model(model, i))
                    .collect::<TractResult<TVec<_>>>()?;
                // The body changed, so it has to be decluttered again.
                let new_op = Self {
                    output_mapping: new_output_mapping,
                    decluttered: false,
                    body: new_body.clone(),
                    ..self.clone()
                };
                let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
                // `mapping` is the outer loop variable again: the original mapping of
                // the extracted output, known to have a scan entry.
                let output = mapping.scan.unwrap();
                let inputs =
                    outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
                // Re-create the emitter after the scan, fed by the newly exported outputs.
                let wire = outside_patch.wire_node(
                    &new_body.node(emitter_outlet.node).name,
                    new_body.node(emitter_outlet.node).op.clone(),
                    &inputs,
                )?[0];
                outside_patch.shunt_outside(model, OutletId::new(node.id, output.0), wire)?;
                // All other pre-existing outer outputs keep their slot on the new scan.
                for output_slot in 0..node.outputs.len() {
                    if output_slot != output.0 {
                        outside_patch.shunt_outside(
                            model,
                            OutletId::new(node.id, output_slot),
                            OutletId::new(scan_outputs[0].node, output_slot),
                        )?;
                    }
                }
                return Ok(Some(outside_patch));
            }
        }
        Ok(None)
    }
610
611 fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
612 let input_state_outlets = self
613 .input_mapping
614 .iter()
615 .zip(self.body.input_outlets()?.iter())
616 .filter(|(m, _)| m.is_state())
617 .map(|(_, o)| o);
618 let output_state_outlets = self
619 .output_mapping
620 .iter()
621 .zip(self.body.output_outlets()?.iter())
622 .filter(|(m, _)| m.state)
623 .map(|(_, o)| o);
624 Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
625 }
626
627 fn body_locked_outlets(&self, node_input_facts: &[&TypedFact]) -> TractResult<TVec<OutletId>> {
628 let input_outlets =
629 self.body.input_outlets()?.iter().enumerate().filter_map(|(slot, o)| {
630 if node_input_facts[slot].konst.is_none() { Some(o) } else { None }
631 });
632 let output_outlets = self
633 .output_mapping
634 .iter()
635 .zip(self.body.output_outlets()?.iter())
636 .filter(|(m, _)| !m.invisible())
637 .map(|(_, o)| o);
638 Ok(input_outlets.chain(output_outlets).cloned().collect())
639 }
640
    /// Attempts to run `change` through the body and derives the resulting operator.
    ///
    /// * `change` - axis change to apply, expressed on a body outlet.
    /// * `locked_interface` - when true, the outlets from `body_locked_outlets`
    ///   are passed to `change_axes` as locked; otherwise none are.
    /// * `node_input_facts` - facts of the outer inputs of the scan node.
    ///
    /// On success, returns the substitute `Scan` (with updated body and mappings,
    /// `decluttered` cleared) together with the axis changes to apply on the
    /// outer wires.
    /// NOTE(review): the `rule_if_some!` invocations presumably return `Ok(None)`
    /// when their expression is `None` — confirm against the macro definition.
    fn try_body_axes_change(
        &self,
        change: AxisChange,
        locked_interface: bool,
        node_input_facts: &[&TypedFact],
    ) -> TractResult<Option<AxisChangeConsequence>> {
        self.body.check_consistency()?;
        let locked_outlets = self.body_locked_outlets(node_input_facts)?;
        let mut explored: HashSet<AxisChange> = Default::default();
        // State input/output pairs from `body_bounds` are handed over as well.
        rule_if_some!(
            (body_patch, body_changed_wires) = crate::optim::change_axes::change_axes(
                &self.body,
                &change,
                if locked_interface { &locked_outlets } else { &[] },
                &self.body_bounds()?,
                &mut explored,
            )?
        );
        let mut body = self.body.clone();
        body_patch.apply(&mut body)?;
        body.compact()?;
        let mut wire_changes = tvec!();
        let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
        for (slot, m) in input_mapping.iter_mut().enumerate() {
            // Was body input `slot` affected by the change?
            if let Some(change) = body_changed_wires
                .iter()
                .find(|(iface, _change)| iface == &InOut::In(slot))
                .map(|pair| pair.1.clone())
            {
                wire_changes.push((InOut::In(slot), change.clone()));
                // A scanned input must follow its scan axis through the change.
                if let InputMapping::Scan(info) = m {
                    rule_if_some!(axis = change.transform_axis(info.axis));
                    info.axis = axis;
                };
            }
        }
        let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
        for (ix, m) in output_mapping.iter_mut().enumerate() {
            // Was body output `ix` affected by the change?
            if let Some(change) = body_changed_wires
                .iter()
                .find(|(iface, _change)| iface == &InOut::Out(ix))
                .map(|pair| pair.1.clone())
            {
                // Outer changes are reported on outer slots, not on body output indices.
                if let Some((slot, info)) = m.scan.as_mut() {
                    rule_if_some!(new_axis = change.transform_axis(info.axis));
                    info.axis = new_axis;
                    wire_changes.push((InOut::Out(*slot), change.clone()));
                }
                if let Some(slot) = m.last_value_slot {
                    wire_changes.push((InOut::Out(slot), change.clone()));
                }
            };
        }
        body.check_consistency()?;
        let op = Some(Box::new(Scan {
            body,
            input_mapping,
            output_mapping,
            decluttered: false,
            ..self.clone()
        }) as _);
        Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
    }
704}
705
706impl Op for Scan {
707 fn name(&self) -> StaticName {
708 "Scan".into()
709 }
710
711 fn info(&self) -> TractResult<Vec<String>> {
712 let mut lines = vec![];
713 for (ix, im) in self.input_mapping.iter().enumerate() {
714 lines.push(format!("Model input #{ix}: {im:?}"));
715 }
716 for (ix, om) in self.output_mapping.iter().enumerate() {
717 lines.push(format!("Model output #{ix}: {om:?}"));
718 }
719 lines.push(format!("skip:{} reset_every_turn:{:?}", self.skip, self.reset_every_turn));
720 Ok(lines)
721 }
722
723 fn validation(&self) -> Validation {
724 Validation::Rounding
725 }
726
727 op_as_typed_op!();
728}
729
/// Evaluation is delegated to the codegen form of the operator.
impl EvalOp for Scan {
    /// A scan is always reported as stateful.
    fn is_stateless(&self) -> bool {
        false
    }
    /// Builds the codegen operator (without optimising the body) and returns
    /// the state it creates for this session and node.
    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
        self.to_codegen_op(false)?.state(session, node_id)
    }
}
738
739impl TypedOp for Scan {
740 as_op!();
741
742 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
743 anyhow::ensure!(inputs.len() == self.body.inputs.len());
744 anyhow::ensure!(self.input_mapping.len() == self.body.inputs.len());
745 anyhow::ensure!(
746 self.input_mapping.iter().filter(|m| m.is_state()).count()
747 == self.output_mapping.iter().filter(|m| m.state).count()
748 );
749 for (i, o) in
750 self.input_mapping.iter().enumerate().filter(|(_, m)| m.is_state()).map(|(i, _)| i).zip(
751 self.output_mapping.iter().enumerate().filter(|(_, m)| m.state).map(|(o, _)| o),
752 )
753 {
754 let ifact = self.body.outlet_fact(self.body.inputs[i])?;
755 let ofact = self.body.outlet_fact(self.body.outputs[o])?;
756 anyhow::ensure!(
757 ifact == ofact,
758 "inconsistent fact: body input {i} is {ifact:?} and body output {o} is {ofact:?}\n{}",
759 self.body
760 )
761 }
762 let mut outputs = tvec!();
763 let iters = super::iteration_count(&self.input_mapping, inputs).context("No scan input")?;
764 for (ix, output) in self.output_mapping.iter().enumerate() {
765 let fact = self.body.output_fact(ix)?;
766 if let Some((slot, info)) = output.scan {
767 let mut shape = fact.shape.clone();
768 let scanning_dim =
769 output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
770 shape.set(info.axis, scanning_dim);
771 outputs.push((slot, fact.datum_type.fact(shape)));
772 }
773 if let Some(slot) = output.last_value_slot {
774 outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
775 }
776 }
777 outputs.sort_by_key(|a| a.0);
778 anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
779 let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
780 Ok(outputs)
781 }
782
783 fn axes_mapping(
784 &self,
785 inputs: &[&TypedFact],
786 outputs: &[&TypedFact],
787 ) -> TractResult<AxesMapping> {
788 let mut mappings = vec![];
789 let body_invs = self.body.axes_mapping().with_context(|| "Computing body axes mapping")?;
790 for body_axis in body_invs.iter_all_axes() {
791 let mut info = Axis::new(body_axis.repr, inputs.len(), outputs.len());
792 info.inputs.clone_from(&body_axis.inputs);
793 for (ix, output_mapping) in self.output_mapping.iter().enumerate() {
794 let mut slots = vec![];
795 if let Some((slot, _scan)) = output_mapping.scan {
796 slots.push(slot);
797 }
798 if let Some(slot) = output_mapping.last_value_slot {
799 slots.push(slot);
800 }
801 for slot in slots {
802 info.outputs[slot].clone_from(&body_axis.outputs[ix]);
803 }
804 }
805 if info.inputs.iter().any(|i| i.len() > 0) || info.outputs.iter().any(|i| i.len() > 0) {
806 mappings.push(info);
807 }
808 }
809 AxesMapping::new(inputs.len(), outputs.len(), mappings)
810 }
811
812 fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
813 let mut suggestions = tvec!();
814 for (slot, input) in self.input_mapping.iter().enumerate() {
815 if let InputMapping::Scan(info) = input
816 && info.axis != 0
817 {
818 suggestions.push((InOut::In(slot), AxisOp::Move(info.axis, 0)))
819 }
820 }
821 for output in &self.output_mapping {
822 if let Some((slot, scan)) = output.scan
823 && scan.axis != 0
824 {
825 suggestions.push((InOut::Out(slot), AxisOp::Move(scan.axis, 0)))
826 }
827 }
828 Ok(suggestions)
829 }
830
831 fn change_axes(
832 &self,
833 model: &TypedModel,
834 node: &TypedNode,
835 io: InOut,
836 change: &AxisOp,
837 ) -> TractResult<Option<AxisChangeConsequence>> {
838 trace!("Propagating through {node}: {io:?} {change:?}");
839 let body_leading_outlet = match io {
840 InOut::In(ix) => self.body.input_outlets()?[ix],
841 InOut::Out(slot) => {
842 let output = self
843 .output_mapping
844 .iter()
845 .position(|im| {
846 im.scan.map(|(slot, _i)| slot) == Some(slot)
847 || im.last_value_slot == Some(slot)
848 })
849 .unwrap();
850 self.body.output_outlets()?[output]
851 }
852 };
853 let axis_change = AxisChange { outlet: body_leading_outlet, op: change.clone() };
854 let node_input_facts = model.node_input_facts(node.id)?;
855 let result = self
856 .try_body_axes_change(axis_change, false, &node_input_facts)
857 .with_context(|| "Attemping to run change through scan body".to_string())?;
858 if result.is_some() {
859 trace!("{node} accepted axis change");
860 } else {
861 trace!("{node} rejected axis change");
862 }
863 Ok(result)
864 }
865
866 fn declutter_with_session(
867 &self,
868 session: &mut OptimizerSession,
869 model: &TypedModel,
870 node: &TypedNode,
871 ) -> TractResult<Option<TypedModelPatch>> {
872 macro_rules! pass {
873 ($func:ident) => {
874 if let Some(mut r) = self
875 .$func(session, model, node)
876 .with_context(|| format!("{}", stringify!($func)))?
877 {
878 trace!(stringify!($func));
879 r.push_context(stringify!($func));
880 return Ok(Some(r));
881 }
882 };
883 }
884 pass!(declutter_single_loop);
885 pass!(declutter_const_input);
886 pass!(declutter_discard_unused_input);
887 pass!(declutter_discard_useless_outer_output);
888 pass!(declutter_discard_empty_output_mapping_with_body_output);
889 pass!(declutter_body);
890 pass!(declutter_body_axes);
891 pass!(declutter_pull_constant_outputs);
892 pass!(declutter_pull_batcheable_input);
893 pass!(declutter_pull_batcheable_output);
894 Ok(None)
895 }
896
897 fn set_symbols(
898 &self,
899 _source: &TypedModel,
900 node: &TypedNode,
901 target: &mut TypedModel,
902 mapping: &HashMap<OutletId, OutletId>,
903 subs: &HashMap<Symbol, TDim>,
904 ) -> TractResult<TVec<OutletId>> {
905 let inputs = node.inputs.iter().map(|o| mapping[o]).collect::<TVec<_>>();
906 let op = Self {
907 output_mapping: self
908 .output_mapping
909 .iter()
910 .map(|om| om.set_symbols(subs))
911 .collect::<TractResult<Vec<_>>>()?,
912 body: self.body.set_symbols(subs)?,
913 ..self.clone()
914 };
915 target.wire_node(&node.name, op, &inputs)
916 }
917
918 fn codegen(
919 &self,
920 model: &TypedModel,
921 node: &TypedNode,
922 ) -> TractResult<Option<TypedModelPatch>> {
923 Ok(Some(TypedModelPatch::replace_single_op(
924 model,
925 node,
926 &node.inputs,
927 self.to_codegen_op(true)?,
928 )?))
929 }
930}