1use std::collections::HashSet;
2
3use crate::ops::einsum::EinSum;
4use crate::ops::konst::Const;
5use crate::optim::OptimizerSession;
6
7use super::optimized::{OptScan, ScanOpParams};
8use tract_data::internal::*;
9use tract_data::itertools::izip;
10
11use super::*;
12
/// Typed `Scan` operator: a loop node whose per-iteration computation is an embedded
/// [`TypedModel`] (`body`).
///
/// Body inputs and body outputs are tied to the outer node through `input_mapping` and
/// `output_mapping`; `Scan::new` checks that there is exactly one mapping per body input and
/// per body output.
#[derive(Debug, Clone, Default)]
pub struct Scan {
    /// Passed unchanged to `ScanOpParams::new` by `to_codegen_op`.
    // NOTE(review): presumably a number of leading iterations to skip -- confirm in `OptScan`.
    pub skip: usize,
    /// Passed unchanged to `ScanOpParams::new`; also one of the conditions under which
    /// `declutter_single_loop` accepts to inline a body that carries state.
    pub reset_every_turn: bool,
    /// Another condition under which `declutter_single_loop` accepts to inline a stateful body.
    // NOTE(review): not forwarded to `ScanOpParams::new` in `to_codegen_op` -- confirm intended.
    pub external_state: bool,
    /// The model evaluated by the loop.
    pub body: TypedModel,
    /// Set to `true` by `declutter_body` once the body went through the optimizer session;
    /// several rewrites that alter the body set it back to `false`.
    pub decluttered: bool,
    /// One entry per body input, in body input order.
    pub input_mapping: Vec<InputMapping>,
    /// One entry per body output, in body output order.
    pub output_mapping: Vec<OutputMapping<TDim>>,
}
33
// `eq` ignores both operands and unconditionally answers `false`: no two `Scan` values, not
// even a value and itself, ever compare equal.
// NOTE(review): presumably a way to opt out of op comparison because the body model is not
// compared -- confirm against the callers that rely on op equality.
impl PartialEq for Scan {
    fn eq(&self, _other: &Self) -> bool {
        false
    }
}
// Marker impl only.
// NOTE(review): `Eq` requires `a == a`, which the always-`false` `PartialEq::eq` of `Scan`
// does not provide -- confirm nothing depends on reflexivity for this type.
impl Eq for Scan {}
40
41impl Scan {
42 pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<OptScan> {
43 let mut model = self.body.clone();
44 if optimize_inner {
45 model = model.into_optimized()?;
46 }
47 let plan = SimplePlan::new(model)?;
48
49 Ok(OptScan::new(Arc::new(ScanOpParams::new(
50 self.skip,
51 self.reset_every_turn,
52 plan,
53 self.input_mapping.clone(),
54 self.output_mapping.clone(),
55 ))))
56 }
57
58 pub fn new(
59 body: TypedModel,
60 input_mapping: Vec<InputMapping>,
61 output_mapping: Vec<OutputMapping<TDim>>,
62 skip: usize,
63 ) -> TractResult<Scan> {
64 body.check_consistency()?;
65 ensure!(input_mapping.len() == body.input_outlets()?.len());
66 ensure!(output_mapping.len() == body.output_outlets()?.len());
67 Ok(Scan {
68 skip,
69 reset_every_turn: false,
70 external_state: false,
71 body,
72 decluttered: false,
73 input_mapping,
74 output_mapping,
75 })
76 }
77
78 pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
79 self.to_codegen_op(false).unwrap().iteration_count(inputs)
80 }
81
82 fn declutter_body(
83 &self,
84 session: &mut OptimizerSession,
85 model: &TypedModel,
86 node: &TypedNode,
87 ) -> TractResult<Option<TypedModelPatch>> {
88 rule_if!(!self.decluttered);
89 let mut new = self.clone();
90 let mut body = self.body.clone();
91 session.optimize(&mut body)?;
92 new.body = body;
93 new.decluttered = true;
94 Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
95 }
96
97 fn declutter_single_loop(
98 &self,
99 _session: &mut OptimizerSession,
100 model: &TypedModel,
101 node: &TypedNode,
102 ) -> TractResult<Option<TypedModelPatch>> {
103 let inputs = model.node_input_facts(node.id)?;
104 let iters =
105 super::iteration_count(&self.input_mapping, &inputs).context("No scan input")?;
106 rule_if!(iters.is_one());
107 let has_state = self.input_mapping.iter().any(InputMapping::is_state);
125 let state_outputs: Vec<_> = self.output_mapping.iter().filter(|m| m.state).collect();
126 let state_exported = has_state
127 && !state_outputs.is_empty()
128 && state_outputs.iter().all(|m| {
129 m.last_value_slot.is_some_and(|slot| {
130 Self::outlet_reaches_model_output(model, OutletId::new(node.id, slot))
131 })
132 });
133 rule_if!(self.reset_every_turn || self.external_state || !has_state || state_exported);
134 let mut patch = TypedModelPatch::new("Inline single loop scan");
135 patch.model = self.body.clone();
136 for (outer_wire, inner_wire) in izip!(&node.inputs, &self.body.inputs) {
137 patch.taps.insert(*inner_wire, *outer_wire);
138 }
139 for (inner_wire, mapping) in izip!(&self.body.outputs, &self.output_mapping) {
140 if let Some((slot, _)) = mapping.scan {
141 patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
142 }
143 if let Some(slot) = mapping.last_value_slot {
144 patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
145 }
146 }
147 Ok(Some(patch))
148 }
149
150 fn outlet_reaches_model_output(model: &TypedModel, start: OutletId) -> bool {
156 let outputs: std::collections::HashSet<OutletId> = model.outputs.iter().copied().collect();
157 let mut seen: std::collections::HashSet<OutletId> = Default::default();
158 let mut stack = vec![start];
159 while let Some(o) = stack.pop() {
160 if outputs.contains(&o) {
161 return true;
162 }
163 if !seen.insert(o) {
164 continue;
165 }
166 for succ in &model.node(o.node).outputs[o.slot].successors {
167 for slot in 0..model.node(succ.node).outputs.len() {
168 stack.push(OutletId::new(succ.node, slot));
169 }
170 }
171 }
172 false
173 }
174
175 fn declutter_body_axes(
176 &self,
177 _session: &mut OptimizerSession,
178 model: &TypedModel,
179 node: &TypedNode,
180 ) -> TractResult<Option<TypedModelPatch>> {
181 let mut suggestions = vec![];
182 for n in self.body.eval_order()? {
183 let node = self.body.node(n);
184 for suggestion in node.op.suggested_axis_changes()? {
185 let outlet = suggestion.0.as_outlet(node);
186 suggestions.push(AxisChange { outlet, op: suggestion.1 })
187 }
188 for (slot, fact) in node.outputs.iter().enumerate() {
189 for (ix, dim) in fact.fact.shape.iter().enumerate() {
190 if dim.is_one() {
191 suggestions.push(AxisChange {
192 outlet: OutletId::new(n, slot),
193 op: AxisOp::Rm(ix),
194 });
195 }
196 }
197 }
198 }
199 let node_input_facts = model.node_input_facts(node.id)?;
200 for suggestion in suggestions.into_iter() {
201 if let Some(conseq) = self.try_body_axes_change(suggestion, true, &node_input_facts)? {
202 let mut patch = TypedModelPatch::default();
203 let mut inputs = tvec!();
204 for outlet in &node.inputs {
205 inputs.push(patch.tap_model(model, *outlet)?);
206 }
207 for change in conseq.wire_changes {
208 if let InOut::In(i) = change.0 {
209 let mut value = patch
210 .outlet_fact(inputs[i])?
211 .konst
212 .clone()
213 .context("Will only reshape constants")?
214 .into_tensor();
215 change.1.change_tensor(&mut value, false)?;
216 let konst_name = patch.node(inputs[i].node).name.clone();
217 inputs[i] = patch.add_const(konst_name, value)?;
218 }
219 }
220 let wires = patch.wire_node(
221 &node.name,
222 conseq.substitute_op.unwrap_or_else(|| Box::new(self.clone())),
223 &inputs,
224 )?;
225 for (ix, new) in wires.into_iter().enumerate() {
226 patch.shunt_outside(model, OutletId::new(node.id, ix), new)?;
227 }
228 return Ok(Some(patch));
229 }
230 }
231 Ok(None)
232 }
233
234 fn remove_outer_output_from_mappings(
235 mappings: &[OutputMapping<TDim>],
236 discarded: usize,
237 ) -> Vec<OutputMapping<TDim>> {
238 mappings
239 .iter()
240 .map(|m| OutputMapping {
241 scan: m.scan.map(|(slot, info)| (slot - (slot > discarded) as usize, info)),
242 last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
243 full_dim_hint: m.full_dim_hint.clone(),
244 state: m.state,
245 })
246 .collect()
247 }
248
249 fn declutter_const_input(
250 &self,
251 _session: &mut OptimizerSession,
252 model: &TypedModel,
253 node: &TypedNode,
254 ) -> TractResult<Option<TypedModelPatch>> {
255 let inputs = model.node_input_facts(node.id)?;
256 for (slot, mapping) in self.input_mapping.iter().enumerate() {
257 if let InputMapping::Full = mapping
258 && let Some(konst) = inputs[slot].konst.as_ref()
259 {
260 let mut op = self.clone();
261 let src = op.body.inputs[slot];
262 op.body.inputs.remove(slot);
263 op.body.nodes[src.node].inputs.clear();
264 op.body.nodes[src.node].op = Box::new(Const::new(konst.clone())?);
265 op.input_mapping.remove(slot);
266 let mut inputs = node.inputs.clone();
267 inputs.remove(slot);
268 return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
269 }
270 }
271 Ok(None)
272 }
273
274 fn declutter_discard_unused_input(
275 &self,
276 _session: &mut OptimizerSession,
277 model: &TypedModel,
278 node: &TypedNode,
279 ) -> TractResult<Option<TypedModelPatch>> {
280 for (slot, input) in self.body.input_outlets()?.iter().enumerate() {
281 let source_node = self.body.node(input.node);
282 if source_node.outputs[0].successors.len() == 0
283 && !self.body.output_outlets()?.contains(input)
284 {
285 let mut new_inputs = node.inputs.clone();
286 new_inputs.remove(slot);
287 let mut new_mappings: Vec<_> = self.input_mapping.clone();
288 new_mappings.remove(slot);
289 let mut model_inputs = self.body.input_outlets()?.to_vec();
290 model_inputs.remove(slot);
291 let mut body = self.body.clone();
292 let mut patch = TypedModelPatch::default();
293 patch.obliterate(source_node.id)?;
294 patch.apply(&mut body)?;
295 body.set_input_outlets(&model_inputs)?;
296 body.declutter()?;
297 let op =
298 Self { body, input_mapping: new_mappings, decluttered: true, ..self.clone() };
299 return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
300 }
301 }
302 Ok(None)
303 }
304
305 fn declutter_discard_useless_outer_output(
306 &self,
307 _session: &mut OptimizerSession,
308 model: &TypedModel,
309 node: &TypedNode,
310 ) -> TractResult<Option<TypedModelPatch>> {
311 for (ix, o) in node.outputs.iter().enumerate() {
312 if o.successors.len() == 0
313 && !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
314 {
315 let mappings = self
316 .output_mapping
317 .iter()
318 .map(|m| OutputMapping {
319 scan: m.scan.filter(|(slot, _info)| *slot != ix),
320 last_value_slot: m.last_value_slot.filter(|s| *s != ix),
321 full_dim_hint: m.full_dim_hint.clone(),
322 state: m.state,
323 })
324 .collect::<Vec<_>>();
325 let mut op = self.clone();
326 op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
327 let mut patch = TypedModelPatch::default();
328 let inputs = node
329 .inputs
330 .iter()
331 .map(|&i| patch.tap_model(model, i))
332 .collect::<TractResult<Vec<_>>>()?;
333 let wires = patch.wire_node(&*node.name, op, &inputs)?;
334 for oix in 0..node.outputs.len() {
335 if oix != ix {
336 patch.shunt_outside(
337 model,
338 OutletId::new(node.id, oix),
339 wires[oix - (oix > ix) as usize],
340 )?;
341 }
342 }
343 return Ok(Some(patch));
344 }
345 }
346 Ok(None)
347 }
348
349 fn declutter_discard_empty_output_mapping_with_body_output(
350 &self,
351 _session: &mut OptimizerSession,
352 model: &TypedModel,
353 node: &TypedNode,
354 ) -> TractResult<Option<TypedModelPatch>> {
355 for (ix, om) in self.output_mapping.iter().enumerate() {
356 if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
357 let mut new_op = self.clone();
358 new_op.output_mapping.remove(ix);
359 new_op.body.outputs.remove(ix);
360 new_op.decluttered = false;
361 return Ok(Some(TypedModelPatch::replace_single_op(
362 model,
363 node,
364 &node.inputs,
365 new_op,
366 )?));
367 }
368 }
369 Ok(None)
370 }
371
    /// Declutter rule: hoist out of the loop an op that directly consumes a scanned input.
    ///
    /// For each scanned input, the successors of its body source node (output 0) are
    /// examined. A successor is extracted when its other inputs are all constants, it has a
    /// single output, and the scanned axis maps to exactly one axis of that output. The op is
    /// then re-wired in the outer model on the original outer input, its result becomes a new
    /// scanned input of the loop, and inside the body the op's output is replaced by a new
    /// source.
    ///
    /// Returns the outer patch for the first extraction found, `Ok(None)` otherwise.
    fn declutter_pull_batcheable_input(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        'candidate: for (slot, input) in self.input_mapping.iter().enumerate() {
            if let Some(scan_info) = input.as_scan() {
                let scan_source = self.body.input_outlets()?[slot];
                let scan_source_node = self.body.node(scan_source.node);
                for mut succ in &scan_source_node.outputs[0].successors {
                    // Any non-constant input other than the scanned one disqualifies the
                    // whole scanned input, not just this successor.
                    for &succ_input in &self.body.node(succ.node).inputs {
                        if succ_input != scan_source
                            && self.body.outlet_fact(succ_input)?.konst.is_none()
                        {
                            continue 'candidate;
                        }
                    }
                    // Only single-output ops are considered; move on to the next successor.
                    if self.body.node(succ.node).outputs.len() != 1 {
                        continue;
                    }
                    let mut new_body = self.body.clone();
                    // For an EinSum, first ask the op for a patch propagating the scan axis.
                    if let Some(einsum) = new_body.node(succ.node).op_as::<EinSum>()
                        && let Some(patch) = einsum
                            .propagate_axis(
                                &new_body,
                                new_body.node(succ.node),
                                InOut::In(succ.slot),
                                scan_info.axis,
                            )
                            .context("building axis propagating patch")?
                    {
                        patch.apply(&mut new_body)?;
                        new_body.compute_const_facts()?;
                        let new_body_scan_input = new_body.input_outlets()?[slot];
                        // The candidate becomes the last successor of the scanned input in
                        // the patched body.
                        // NOTE(review): presumably because the patch appends its nodes last --
                        // confirm against `propagate_axis`.
                        succ = new_body.node(new_body_scan_input.node).outputs[0]
                            .successors
                            .last()
                            .unwrap();
                    }

                    let axes_mapping = {
                        let (input_facts, output_facts) =
                            new_body.node_facts(new_body.node(succ.node).id)?;
                        new_body.node(succ.node).op.axes_mapping(&input_facts, &output_facts)?
                    };
                    let axis_info = axes_mapping.axis((InOut::In(succ.slot), scan_info.axis))?;
                    // Proceed only if the scanned axis shows up as exactly one axis of output 0.
                    if let &[axis_after] = &*axis_info.outputs[0] {
                        let mut outside_patch = TypedModelPatch::new(format!(
                            "Outer patch for input extraction of {}",
                            new_body.node(succ.node)
                        ));
                        let mut patch_inputs = node
                            .inputs
                            .iter()
                            .map(|&i| outside_patch.tap_model(model, i))
                            .collect::<TractResult<TVec<_>>>()?;
                        // Rebuild the op's inputs outside: the scanned one is the matching
                        // outer input, the others are constants re-created in the patch.
                        let mut extracted_op_inputs = tvec!();
                        for (ix, outlet) in new_body.node(succ.node).inputs.iter().enumerate() {
                            let wire = if ix == succ.slot {
                                patch_inputs[slot]
                            } else if let Some(konst) =
                                new_body.outlet_fact(*outlet)?.konst.as_ref()
                            {
                                outside_patch.add_const(
                                    format!(
                                        "{}.extracted.{}",
                                        node.name,
                                        new_body.node(outlet.node).name
                                    ),
                                    konst.clone(),
                                )?
                            } else {
                                // NOTE(review): the constness check above ran on the original
                                // body, before the optional EinSum patch -- confirm this
                                // branch cannot be hit on the patched successor.
                                unreachable!();
                            };
                            extracted_op_inputs.push(wire);
                        }
                        let new_input_wire = outside_patch.wire_node(
                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
                            new_body.node(succ.node).op.clone(),
                            &extracted_op_inputs,
                        )?[0];
                        // The extracted result is appended as an extra outer input.
                        patch_inputs.push(new_input_wire);
                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                        // Inside the body the same fact is used, except that the scanned axis
                        // has the absolute value of the chunk as its dimension.
                        let mut new_input_inner_fact = new_input_outer_fact.clone();
                        new_input_inner_fact.shape.set(axis_after, scan_info.chunk.abs().to_dim());

                        // Work on a fresh copy: `succ` may still borrow the previous one.
                        let mut new_body = new_body.clone();
                        let new_source_wire = new_body.add_source(
                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
                            new_input_inner_fact,
                        )?;
                        let mut inner_patch = TypedModelPatch::new(format!(
                            "Inner body patch for extraction of {}",
                            new_body.node(succ.node)
                        ));
                        let new_source_wire_in_patch =
                            inner_patch.tap_model(&new_body, new_source_wire)?;
                        // Consumers of the extracted op now read the new source instead.
                        inner_patch
                            .shunt_outside(
                                &new_body,
                                OutletId::new(succ.node, 0),
                                new_source_wire_in_patch,
                            )
                            .with_context(|| "patching inner model")?;
                        inner_patch.apply(&mut new_body)?;

                        // The new input is scanned along `axis_after` with the same chunk.
                        let mut input_mapping = self.input_mapping.clone();
                        input_mapping.push(InputMapping::Scan(ScanInfo {
                            axis: axis_after,
                            chunk: scan_info.chunk,
                        }));

                        let new_op = Self {
                            input_mapping,
                            decluttered: false,
                            body: new_body,
                            ..self.clone()
                        };
                        let output_wires =
                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                        // Outputs keep their slots on the replacement node.
                        for w in output_wires {
                            outside_patch
                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                                .with_context(|| "patching outer model")?;
                        }
                        return Ok(Some(outside_patch));
                    }
                }
            }
        }
        Ok(None)
    }
508
509 fn declutter_pull_constant_outputs(
510 &self,
511 _session: &mut OptimizerSession,
512 model: &TypedModel,
513 node: &TypedNode,
514 ) -> TractResult<Option<TypedModelPatch>> {
515 for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
516 if let Some(slot) = mapping.last_value_slot
517 && let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone()
518 {
519 let inner_node = self.body.output_outlets()?[model_output_ix].node;
520 let inner_node = self.body.node(inner_node);
521 let mut patch = TypedModelPatch::new(format!("Extract const node {inner_node}"));
522 let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
523 patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
524 return Ok(Some(patch));
525 }
526 }
527 Ok(None)
528 }
529
    /// Declutter rule: hoist out of the loop the op that produces a scanned output.
    ///
    /// For a scanned body output whose emitting op can track the scan axis back to its
    /// inputs, the inputs of that op are exported from the loop instead (constants as last
    /// values, the others as scanned outputs), and the op is re-wired in the outer model on
    /// these new outputs.
    ///
    /// Returns the outer patch for the first extraction performed, `Ok(None)` otherwise.
    fn declutter_pull_batcheable_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (mapping_ix, mapping) in self.output_mapping.iter().enumerate() {
            if let Some((_, scan_info)) = mapping.scan {
                let emitter_outlet = self.body.output_outlets()?[mapping_ix];
                // Skip this output when the emitter is also consumed inside the body, is a
                // body input, carries state, or is scanned with a chunk greater than one.
                if self.body.node(emitter_outlet.node).outputs[emitter_outlet.slot].successors.len()
                    > 0
                    || self.body.inputs.contains(&emitter_outlet)
                    || mapping.state
                    || mapping.scan.map(|(_slot, i)| i.chunk > 1).unwrap_or(true)
                {
                    continue;
                }
                let mut new_body = self.body.clone();
                // For an EinSum, first ask the op for a patch propagating the scan axis.
                // NOTE(review): the axis is propagated from `InOut::Out(0)` whereas the
                // tracking below uses `emitter_outlet.slot` -- confirm both always agree.
                if let Some(einsum) = new_body.node(emitter_outlet.node).op_as::<EinSum>()
                    && let Some(patch) = einsum
                        .propagate_axis(
                            &new_body,
                            new_body.node(emitter_outlet.node),
                            InOut::Out(0),
                            scan_info.axis,
                        )
                        .context("building axis propagating patch")?
                {
                    patch.apply(&mut new_body)?;
                    new_body.prop_consts()?;
                }
                // Read the emitter again from the (possibly patched) body.
                let emitter_outlet = new_body.output_outlets()?[mapping_ix];
                let invariants = {
                    let (input_facts, output_facts) = new_body.node_facts(emitter_outlet.node)?;
                    new_body
                        .node(emitter_outlet.node)
                        .op
                        .axes_mapping(&input_facts, &output_facts)?
                };
                let axis_tracking =
                    invariants.axis((InOut::Out(emitter_outlet.slot), scan_info.axis))?;
                // The tracked axis must appear exactly once in every output of the emitter.
                rule_if!(axis_tracking.outputs.iter().all(|o| o.len() == 1));
                let mut new_output_mapping = self.output_mapping.clone();
                // Next free outer slot: new outputs are appended after the existing ones.
                let mut new_scan_outputs = node.outputs.len();
                // Outer slot feeding each input of the extracted op, in input order.
                let mut outer_slots = vec![];

                // The input list is cloned because `new_body.outputs` is edited in the loop.
                for (input_slot, input) in
                    new_body.node(emitter_outlet.node).inputs.clone().iter().enumerate()
                {
                    // Make sure the emitter input is a body output, with a mapping of its own.
                    if new_body.outputs.iter().all(|o| o != input) {
                        new_output_mapping.push(OutputMapping::default());
                        new_body.outputs.push(*input);
                    }
                    let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
                    let mapping = &mut new_output_mapping[body_output_id];
                    // Constants are exported as last values; other inputs are exported as
                    // scanned outputs along the axis the tracked axis maps to. An existing
                    // slot is reused when the mapping already has one.
                    let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                        if mapping.last_value_slot.is_none() {
                            mapping.last_value_slot = Some(new_scan_outputs);
                            new_scan_outputs += 1;
                        }
                        mapping.last_value_slot.unwrap()
                    } else if let &[axis] = &*axis_tracking.inputs[input_slot] {
                        if mapping.scan.is_none() {
                            mapping.scan =
                                Some((new_scan_outputs, ScanInfo { axis, chunk: scan_info.chunk }));
                            new_scan_outputs += 1;
                        }
                        mapping.scan.unwrap().0
                    } else {
                        // The axis cannot be followed through this input: the whole rule
                        // gives up, later output mappings are not examined.
                        return Ok(None);
                    };
                    outer_slots.push(outer_slot);
                }
                let mut outside_patch = TypedModelPatch::new(format!(
                    "Outside patch for output extraction of {}",
                    new_body.node(emitter_outlet.node)
                ));
                let inputs = node
                    .inputs
                    .iter()
                    .map(|&i| outside_patch.tap_model(model, i))
                    .collect::<TractResult<TVec<_>>>()?;
                let new_op = Self {
                    output_mapping: new_output_mapping,
                    decluttered: false,
                    body: new_body.clone(), ..self.clone()
                };
                let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
                // `mapping` is the loop variable again here (original mapping, `scan` is Some).
                let output = mapping.scan.unwrap();
                let inputs =
                    outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
                // Re-create the emitter op outside, fed by the newly exported outputs.
                let wire = outside_patch.wire_node(
                    &new_body.node(emitter_outlet.node).name,
                    new_body.node(emitter_outlet.node).op.clone(),
                    &inputs,
                )?[0];
                outside_patch.shunt_outside(model, OutletId::new(node.id, output.0), wire)?;
                // Every other pre-existing output keeps its slot on the new scan node.
                for output_slot in 0..node.outputs.len() {
                    if output_slot != output.0 {
                        outside_patch.shunt_outside(
                            model,
                            OutletId::new(node.id, output_slot),
                            OutletId::new(scan_outputs[0].node, output_slot),
                        )?;
                    }
                }
                return Ok(Some(outside_patch));
            }
        }
        Ok(None)
    }
644
645 fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
646 let input_state_outlets = self
647 .input_mapping
648 .iter()
649 .zip(self.body.input_outlets()?.iter())
650 .filter(|(m, _)| m.is_state())
651 .map(|(_, o)| o);
652 let output_state_outlets = self
653 .output_mapping
654 .iter()
655 .zip(self.body.output_outlets()?.iter())
656 .filter(|(m, _)| m.state)
657 .map(|(_, o)| o);
658 Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
659 }
660
661 fn body_locked_outlets(&self, node_input_facts: &[&TypedFact]) -> TractResult<TVec<OutletId>> {
662 let input_outlets =
663 self.body.input_outlets()?.iter().enumerate().filter_map(|(slot, o)| {
664 if node_input_facts[slot].konst.is_none() { Some(o) } else { None }
665 });
666 let output_outlets = self
667 .output_mapping
668 .iter()
669 .zip(self.body.output_outlets()?.iter())
670 .filter(|(m, _)| !m.invisible())
671 .map(|(_, o)| o);
672 Ok(input_outlets.chain(output_outlets).cloned().collect())
673 }
674
675 fn try_body_axes_change(
676 &self,
677 change: AxisChange,
678 locked_interface: bool,
679 node_input_facts: &[&TypedFact],
680 ) -> TractResult<Option<AxisChangeConsequence>> {
681 self.body.check_consistency()?;
682 let locked_outlets = self.body_locked_outlets(node_input_facts)?;
683 let mut explored: HashSet<AxisChange> = Default::default();
684 rule_if_some!(
685 (body_patch, body_changed_wires) = crate::optim::change_axes::change_axes(
686 &self.body,
687 &change,
688 if locked_interface { &locked_outlets } else { &[] },
689 &self.body_bounds()?,
690 &mut explored,
691 )?
692 );
693 let mut body = self.body.clone();
694 body_patch.apply(&mut body)?;
695 body.compact()?;
696 let mut wire_changes = tvec!();
697 let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
698 for (slot, m) in input_mapping.iter_mut().enumerate() {
699 if let Some(change) = body_changed_wires
700 .iter()
701 .find(|(iface, _change)| iface == &InOut::In(slot))
702 .map(|pair| pair.1.clone())
703 {
704 wire_changes.push((InOut::In(slot), change.clone()));
705 if let InputMapping::Scan(info) = m {
706 rule_if_some!(axis = change.transform_axis(info.axis));
707 info.axis = axis;
708 };
709 }
710 }
711 let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
712 for (ix, m) in output_mapping.iter_mut().enumerate() {
713 if let Some(change) = body_changed_wires
714 .iter()
715 .find(|(iface, _change)| iface == &InOut::Out(ix))
716 .map(|pair| pair.1.clone())
717 {
718 if let Some((slot, info)) = m.scan.as_mut() {
719 rule_if_some!(new_axis = change.transform_axis(info.axis));
720 info.axis = new_axis;
721 wire_changes.push((InOut::Out(*slot), change.clone()));
722 }
723 if let Some(slot) = m.last_value_slot {
724 wire_changes.push((InOut::Out(slot), change.clone()));
725 }
726 };
727 }
728 body.check_consistency()?;
729 let op = Some(Box::new(Scan {
730 body,
731 input_mapping,
732 output_mapping,
733 decluttered: false,
734 ..self.clone()
735 }) as _);
736 Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
737 }
738}
739
740impl Op for Scan {
741 fn name(&self) -> StaticName {
742 "Scan".into()
743 }
744
745 fn info(&self) -> TractResult<Vec<String>> {
746 let mut lines = vec![];
747 for (ix, im) in self.input_mapping.iter().enumerate() {
748 lines.push(format!("Model input #{ix}: {im:?}"));
749 }
750 for (ix, om) in self.output_mapping.iter().enumerate() {
751 lines.push(format!("Model output #{ix}: {om:?}"));
752 }
753 lines.push(format!("skip:{} reset_every_turn:{:?}", self.skip, self.reset_every_turn));
754 Ok(lines)
755 }
756
757 fn validation(&self) -> Validation {
758 Validation::Rounding
759 }
760
761 op_as_typed_op!();
762}
763
764impl EvalOp for Scan {
765 fn is_stateless(&self) -> bool {
766 false
767 }
768 fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
769 self.to_codegen_op(false)?.state(session, node_id)
770 }
771}
772
773impl TypedOp for Scan {
774 as_op!();
775
776 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
777 anyhow::ensure!(inputs.len() == self.body.inputs.len());
778 anyhow::ensure!(self.input_mapping.len() == self.body.inputs.len());
779 anyhow::ensure!(
780 self.input_mapping.iter().filter(|m| m.is_state()).count()
781 == self.output_mapping.iter().filter(|m| m.state).count()
782 );
783 for (i, o) in
784 self.input_mapping.iter().enumerate().filter(|(_, m)| m.is_state()).map(|(i, _)| i).zip(
785 self.output_mapping.iter().enumerate().filter(|(_, m)| m.state).map(|(o, _)| o),
786 )
787 {
788 let ifact = self.body.outlet_fact(self.body.inputs[i])?;
789 let ofact = self.body.outlet_fact(self.body.outputs[o])?;
790 anyhow::ensure!(
791 ifact == ofact,
792 "inconsistent fact: body input {i} is {ifact:?} and body output {o} is {ofact:?}\n{}",
793 self.body
794 )
795 }
796 let mut outputs = tvec!();
797 let iters = super::iteration_count(&self.input_mapping, inputs).context("No scan input")?;
798 for (ix, output) in self.output_mapping.iter().enumerate() {
799 let fact = self.body.output_fact(ix)?;
800 if let Some((slot, info)) = output.scan {
801 let mut shape = fact.shape.clone();
802 let scanning_dim =
803 output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
804 shape.set(info.axis, scanning_dim);
805 outputs.push((slot, fact.datum_type.fact(shape)));
806 }
807 if let Some(slot) = output.last_value_slot {
808 outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
809 }
810 }
811 outputs.sort_by_key(|a| a.0);
812 anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
813 let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
814 Ok(outputs)
815 }
816
817 fn axes_mapping(
818 &self,
819 inputs: &[&TypedFact],
820 outputs: &[&TypedFact],
821 ) -> TractResult<AxesMapping> {
822 let mut mappings = vec![];
823 let body_invs = self.body.axes_mapping().with_context(|| "Computing body axes mapping")?;
824 for body_axis in body_invs.iter_all_axes() {
825 let mut info = Axis::new(body_axis.repr, inputs.len(), outputs.len());
826 info.inputs.clone_from(&body_axis.inputs);
827 for (ix, output_mapping) in self.output_mapping.iter().enumerate() {
828 let mut slots = vec![];
829 if let Some((slot, _scan)) = output_mapping.scan {
830 slots.push(slot);
831 }
832 if let Some(slot) = output_mapping.last_value_slot {
833 slots.push(slot);
834 }
835 for slot in slots {
836 info.outputs[slot].clone_from(&body_axis.outputs[ix]);
837 }
838 }
839 if info.inputs.iter().any(|i| i.len() > 0) || info.outputs.iter().any(|i| i.len() > 0) {
840 mappings.push(info);
841 }
842 }
843 AxesMapping::new(inputs.len(), outputs.len(), mappings)
844 }
845
846 fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
847 let mut suggestions = tvec!();
848 for (slot, input) in self.input_mapping.iter().enumerate() {
849 if let InputMapping::Scan(info) = input
850 && info.axis != 0
851 {
852 suggestions.push((InOut::In(slot), AxisOp::Move(info.axis, 0)))
853 }
854 }
855 for output in &self.output_mapping {
856 if let Some((slot, scan)) = output.scan
857 && scan.axis != 0
858 {
859 suggestions.push((InOut::Out(slot), AxisOp::Move(scan.axis, 0)))
860 }
861 }
862 Ok(suggestions)
863 }
864
865 fn change_axes(
866 &self,
867 model: &TypedModel,
868 node: &TypedNode,
869 io: InOut,
870 change: &AxisOp,
871 ) -> TractResult<Option<AxisChangeConsequence>> {
872 trace!("Propagating through {node}: {io:?} {change:?}");
873 let body_leading_outlet = match io {
874 InOut::In(ix) => self.body.input_outlets()?[ix],
875 InOut::Out(slot) => {
876 let output = self
877 .output_mapping
878 .iter()
879 .position(|im| {
880 im.scan.map(|(slot, _i)| slot) == Some(slot)
881 || im.last_value_slot == Some(slot)
882 })
883 .unwrap();
884 self.body.output_outlets()?[output]
885 }
886 };
887 let axis_change = AxisChange { outlet: body_leading_outlet, op: change.clone() };
888 let node_input_facts = model.node_input_facts(node.id)?;
889 let result = self
890 .try_body_axes_change(axis_change, false, &node_input_facts)
891 .with_context(|| "Attemping to run change through scan body".to_string())?;
892 if result.is_some() {
893 trace!("{node} accepted axis change");
894 } else {
895 trace!("{node} rejected axis change");
896 }
897 Ok(result)
898 }
899
    /// Entry point of the scan declutter rules.
    ///
    /// The rules are tried in the fixed order listed below; the first one that produces a
    /// patch wins and the others are not run. Returns `Ok(None)` when no rule applies.
    fn declutter_with_session(
        &self,
        session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Runs one rule: errors are annotated with the rule name, and a produced patch is
        // traced, tagged with the rule name and returned right away.
        macro_rules! pass {
            ($func:ident) => {
                if let Some(mut r) = self
                    .$func(session, model, node)
                    .with_context(|| format!("{}", stringify!($func)))?
                {
                    trace!(stringify!($func));
                    r.push_context(stringify!($func));
                    return Ok(Some(r));
                }
            };
        }
        pass!(declutter_single_loop);
        pass!(declutter_const_input);
        pass!(declutter_discard_unused_input);
        pass!(declutter_discard_useless_outer_output);
        pass!(declutter_discard_empty_output_mapping_with_body_output);
        pass!(declutter_body);
        pass!(declutter_body_axes);
        pass!(declutter_pull_constant_outputs);
        pass!(declutter_pull_batcheable_input);
        pass!(declutter_pull_batcheable_output);
        Ok(None)
    }
930
931 fn set_symbols(
932 &self,
933 _source: &TypedModel,
934 node: &TypedNode,
935 target: &mut TypedModel,
936 mapping: &HashMap<OutletId, OutletId>,
937 subs: &HashMap<Symbol, TDim>,
938 ) -> TractResult<TVec<OutletId>> {
939 let inputs = node.inputs.iter().map(|o| mapping[o]).collect::<TVec<_>>();
940 let op = Self {
941 output_mapping: self
942 .output_mapping
943 .iter()
944 .map(|om| om.set_symbols(subs))
945 .collect::<TractResult<Vec<_>>>()?,
946 body: self.body.set_symbols(subs)?,
947 ..self.clone()
948 };
949 target.wire_node(&node.name, op, &inputs)
950 }
951
952 fn codegen(
953 &self,
954 model: &TypedModel,
955 node: &TypedNode,
956 ) -> TractResult<Option<TypedModelPatch>> {
957 Ok(Some(TypedModelPatch::replace_single_op(
958 model,
959 node,
960 &node.inputs,
961 self.to_codegen_op(true)?,
962 )?))
963 }
964}