1use std::collections::HashSet;
2
3use crate::ops::einsum::EinSum;
4use crate::ops::konst::Const;
5use crate::optim::OptimizerSession;
6
7use super::optimized::{OptScan, ScanOpParams};
8use tract_data::internal::*;
9use tract_data::itertools::izip;
10
11use super::*;
12
/// Typed `Scan` operator: an inner `body` model plus the mappings that connect the operator's
/// outer inputs and outputs to the inputs and outputs of that body.
#[derive(Debug, Clone, Default)]
pub struct Scan {
    /// Forwarded unchanged to `ScanOpParams::new` by `to_codegen_op`.
    /// NOTE(review): presumably a number of leading elements to skip — confirm in `ScanOpParams`.
    pub skip: usize,
    /// Forwarded unchanged to `ScanOpParams::new`; `Scan::new` initialises it to `false`.
    /// NOTE(review): exact runtime semantics live in the optimized op — confirm there.
    pub reset_every_turn: bool,
    /// Inner model evaluated by the scan.
    pub body: TypedModel,
    /// Set to `true` once `declutter_body` has optimized `body`; rewrites that alter the body
    /// reset it to `false` so the body is optimized again.
    pub decluttered: bool,
    /// One entry per body input (`Scan::new` checks the lengths match).
    pub input_mapping: Vec<InputMapping>,
    /// One entry per body output (`Scan::new` checks the lengths match).
    pub output_mapping: Vec<OutputMapping<TDim>>,
}
22
23impl Scan {
24 pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<OptScan> {
25 let mut model = self.body.clone();
26 if optimize_inner {
27 model = model.into_optimized()?;
28 }
29 let plan = SimplePlan::new(model)?;
30
31 Ok(OptScan::new(Arc::new(ScanOpParams::new(
32 self.skip,
33 self.reset_every_turn,
34 plan,
35 self.input_mapping.clone(),
36 self.output_mapping.clone(),
37 ))))
38 }
39
40 pub fn new(
41 body: TypedModel,
42 input_mapping: Vec<InputMapping>,
43 output_mapping: Vec<OutputMapping<TDim>>,
44 skip: usize,
45 ) -> TractResult<Scan> {
46 body.check_consistency()?;
47 ensure!(input_mapping.len() == body.input_outlets()?.len());
48 ensure!(output_mapping.len() == body.output_outlets()?.len());
49 Ok(Scan {
50 skip,
51 reset_every_turn: false,
52 body,
53 decluttered: false,
54 input_mapping,
55 output_mapping,
56 })
57 }
58
59 pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
60 self.to_codegen_op(false).unwrap().iteration_count(inputs)
61 }
62
63 fn declutter_body(
64 &self,
65 session: &mut OptimizerSession,
66 model: &TypedModel,
67 node: &TypedNode,
68 ) -> TractResult<Option<TypedModelPatch>> {
69 rule_if!(!self.decluttered);
70 let mut new = self.clone();
71 let mut body = self.body.clone();
72 session.optimize(&mut body)?;
73 new.body = body;
74 new.decluttered = true;
75 Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
76 }
77
78 fn declutter_single_loop(
79 &self,
80 _session: &mut OptimizerSession,
81 model: &TypedModel,
82 node: &TypedNode,
83 ) -> TractResult<Option<TypedModelPatch>> {
84 let inputs = model.node_input_facts(node.id)?;
85 let iters =
86 super::iteration_count(&self.input_mapping, &inputs).context("No scan input")?;
87 if !iters.is_one() {
88 return Ok(None);
89 }
90 let mut patch = TypedModelPatch::new("Inline single loop scan");
91 patch.model = self.body.clone();
92 for (outer_wire, inner_wire) in izip!(&node.inputs, &self.body.inputs) {
93 patch.taps.insert(*inner_wire, *outer_wire);
94 }
95 for (inner_wire, mapping) in izip!(&self.body.outputs, &self.output_mapping) {
96 if let Some((slot, _)) = mapping.scan {
97 patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
98 }
99 if let Some(slot) = mapping.last_value_slot {
100 patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
101 }
102 }
103 Ok(Some(patch))
104 }
105
106 fn declutter_body_axes(
107 &self,
108 _session: &mut OptimizerSession,
109 model: &TypedModel,
110 node: &TypedNode,
111 ) -> TractResult<Option<TypedModelPatch>> {
112 let mut suggestions = vec![];
113 for n in self.body.eval_order()? {
114 let node = self.body.node(n);
115 for suggestion in node.op.suggested_axis_changes()? {
116 let outlet = suggestion.0.as_outlet(node);
117 suggestions.push(AxisChange { outlet, op: suggestion.1 })
118 }
119 for (slot, fact) in node.outputs.iter().enumerate() {
120 for (ix, dim) in fact.fact.shape.iter().enumerate() {
121 if dim.is_one() {
122 suggestions.push(AxisChange {
123 outlet: OutletId::new(n, slot),
124 op: AxisOp::Rm(ix),
125 });
126 }
127 }
128 }
129 }
130 let node_input_facts = model.node_input_facts(node.id)?;
131 for suggestion in suggestions.into_iter() {
132 if let Some(conseq) = self.try_body_axes_change(suggestion, true, &node_input_facts)? {
133 let mut patch = TypedModelPatch::default();
134 let mut inputs = tvec!();
135 for outlet in &node.inputs {
136 inputs.push(patch.tap_model(model, *outlet)?);
137 }
138 for change in conseq.wire_changes {
139 if let InOut::In(i) = change.0 {
140 let mut value = patch
141 .outlet_fact(inputs[i])?
142 .konst
143 .clone()
144 .context("Will only reshape constants")?
145 .into_tensor();
146 change.1.change_tensor(&mut value, false)?;
147 let konst_name = patch.node(inputs[i].node).name.clone();
148 inputs[i] = patch.add_const(konst_name, value)?;
149 }
150 }
151 let wires = patch.wire_node(
152 &node.name,
153 conseq.substitute_op.unwrap_or_else(|| Box::new(self.clone())),
154 &inputs,
155 )?;
156 for (ix, new) in wires.into_iter().enumerate() {
157 patch.shunt_outside(model, OutletId::new(node.id, ix), new)?;
158 }
159 return Ok(Some(patch));
160 }
161 }
162 Ok(None)
163 }
164
165 fn remove_outer_output_from_mappings(
166 mappings: &[OutputMapping<TDim>],
167 discarded: usize,
168 ) -> Vec<OutputMapping<TDim>> {
169 mappings
170 .iter()
171 .map(|m| OutputMapping {
172 scan: m.scan.map(|(slot, info)| (slot - (slot > discarded) as usize, info)),
173 last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
174 full_dim_hint: m.full_dim_hint.clone(),
175 state: m.state,
176 })
177 .collect()
178 }
179
180 fn declutter_const_input(
181 &self,
182 _session: &mut OptimizerSession,
183 model: &TypedModel,
184 node: &TypedNode,
185 ) -> TractResult<Option<TypedModelPatch>> {
186 let inputs = model.node_input_facts(node.id)?;
187 for (slot, mapping) in self.input_mapping.iter().enumerate() {
188 if let InputMapping::Full = mapping {
189 if let Some(konst) = inputs[slot].konst.as_ref() {
190 let mut op = self.clone();
191 let src = op.body.inputs[slot];
192 op.body.inputs.remove(slot);
193 op.body.nodes[src.node].inputs.clear();
194 op.body.nodes[src.node].op = Box::new(Const::new(konst.clone())?);
195 op.input_mapping.remove(slot);
196 let mut inputs = node.inputs.clone();
197 inputs.remove(slot);
198 return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
199 }
200 }
201 }
202 Ok(None)
203 }
204
205 fn declutter_discard_unused_input(
206 &self,
207 _session: &mut OptimizerSession,
208 model: &TypedModel,
209 node: &TypedNode,
210 ) -> TractResult<Option<TypedModelPatch>> {
211 for (slot, input) in self.body.input_outlets()?.iter().enumerate() {
212 let source_node = self.body.node(input.node);
213 if source_node.outputs[0].successors.len() == 0
214 && !self.body.output_outlets()?.contains(input)
215 {
216 let mut new_inputs = node.inputs.clone();
217 new_inputs.remove(slot);
218 let mut new_mappings: Vec<_> = self.input_mapping.clone();
219 new_mappings.remove(slot);
220 let mut model_inputs = self.body.input_outlets()?.to_vec();
221 model_inputs.remove(slot);
222 let mut body = self.body.clone();
223 let mut patch = TypedModelPatch::default();
224 patch.obliterate(source_node.id)?;
225 patch.apply(&mut body)?;
226 body.set_input_outlets(&model_inputs)?;
227 body.declutter()?;
228 let op =
229 Self { body, input_mapping: new_mappings, decluttered: true, ..self.clone() };
230 return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
231 }
232 }
233 Ok(None)
234 }
235
236 fn declutter_discard_useless_outer_output(
237 &self,
238 _session: &mut OptimizerSession,
239 model: &TypedModel,
240 node: &TypedNode,
241 ) -> TractResult<Option<TypedModelPatch>> {
242 for (ix, o) in node.outputs.iter().enumerate() {
243 if o.successors.len() == 0
244 && !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
245 {
246 let mappings = self
247 .output_mapping
248 .iter()
249 .map(|m| OutputMapping {
250 scan: m.scan.filter(|(slot, _info)| *slot != ix),
251 last_value_slot: m.last_value_slot.filter(|s| *s != ix),
252 full_dim_hint: m.full_dim_hint.clone(),
253 state: m.state,
254 })
255 .collect::<Vec<_>>();
256 let mut op = self.clone();
257 op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
258 let mut patch = TypedModelPatch::default();
259 let inputs = node
260 .inputs
261 .iter()
262 .map(|&i| patch.tap_model(model, i))
263 .collect::<TractResult<Vec<_>>>()?;
264 let wires = patch.wire_node(&*node.name, op, &inputs)?;
265 for oix in 0..node.outputs.len() {
266 if oix != ix {
267 patch.shunt_outside(
268 model,
269 OutletId::new(node.id, oix),
270 wires[oix - (oix > ix) as usize],
271 )?;
272 }
273 }
274 return Ok(Some(patch));
275 }
276 }
277 Ok(None)
278 }
279
280 fn declutter_discard_empty_output_mapping_with_body_output(
281 &self,
282 _session: &mut OptimizerSession,
283 model: &TypedModel,
284 node: &TypedNode,
285 ) -> TractResult<Option<TypedModelPatch>> {
286 for (ix, om) in self.output_mapping.iter().enumerate() {
287 if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
288 let mut new_op = self.clone();
289 new_op.output_mapping.remove(ix);
290 new_op.body.outputs.remove(ix);
291 new_op.decluttered = false;
292 return Ok(Some(TypedModelPatch::replace_single_op(
293 model,
294 node,
295 &node.inputs,
296 new_op,
297 )?));
298 }
299 }
300 Ok(None)
301 }
302
    /// Tries to hoist, out of the loop, an operator that directly consumes a scanned input.
    ///
    /// For each scanned input, every successor of the body source node is considered. The
    /// successor qualifies when all its other inputs are constants and it has exactly one
    /// output. If the scanned axis maps to exactly one axis of that output, the operator is
    /// replicated outside the scan, applied to the full outer input, and its result becomes an
    /// additional scanned input that replaces the operator's output inside the body.
    ///
    /// Returns `Ok(None)` when no input/successor pair qualifies.
    fn declutter_pull_batcheable_input(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        'candidate: for (slot, input) in self.input_mapping.iter().enumerate() {
            if let Some(scan_info) = input.as_scan() {
                let scan_source = self.body.input_outlets()?[slot];
                let scan_source_node = self.body.node(scan_source.node);
                for mut succ in &scan_source_node.outputs[0].successors {
                    // Any non-constant input other than the scanned one disqualifies the
                    // whole scanned input, not just this successor.
                    for &succ_input in &self.body.node(succ.node).inputs {
                        if succ_input != scan_source
                            && self.body.outlet_fact(succ_input)?.konst.is_none()
                        {
                            continue 'candidate;
                        }
                    }
                    // Only single-output operators are extracted.
                    if self.body.node(succ.node).outputs.len() != 1 {
                        continue;
                    }
                    let mut new_body = self.body.clone();
                    // For an EinSum, first ask it for an axis-propagating patch on the scanned
                    // axis. If one is produced, apply it and retarget `succ` to the last
                    // successor of the scanned input in the patched body.
                    if let Some(einsum) = new_body.node(succ.node).op_as::<EinSum>() {
                        if let Some(patch) = einsum
                            .propagate_axis(
                                &new_body,
                                new_body.node(succ.node),
                                InOut::In(succ.slot),
                                scan_info.axis,
                            )
                            .context("building axis propagating patch")?
                        {
                            patch.apply(&mut new_body)?;
                            new_body.compute_const_facts()?;
                            let new_body_scan_input = new_body.input_outlets()?[slot];
                            succ = new_body.node(new_body_scan_input.node).outputs[0]
                                .successors
                                .last()
                                .unwrap();
                        }
                    }

                    let axes_mapping = {
                        let (input_facts, output_facts) =
                            new_body.node_facts(new_body.node(succ.node).id)?;
                        new_body.node(succ.node).op.axes_mapping(&input_facts, &output_facts)?
                    };
                    let axis_info = axes_mapping.axis((InOut::In(succ.slot), scan_info.axis))?;
                    // Proceed only if the scanned axis lands on exactly one axis of output 0.
                    if let &[axis_after] = &*axis_info.outputs[0] {
                        let mut outside_patch = TypedModelPatch::new(format!(
                            "Outer patch for input extraction of {}",
                            new_body.node(succ.node)
                        ));
                        let mut patch_inputs = node
                            .inputs
                            .iter()
                            .map(|&i| outside_patch.tap_model(model, i))
                            .collect::<TractResult<TVec<_>>>()?;
                        // Inputs of the extracted operator: the full outer tensor in place of
                        // the scanned input, and outer copies of the constants elsewhere.
                        let mut extracted_op_inputs = tvec!();
                        for (ix, outlet) in new_body.node(succ.node).inputs.iter().enumerate() {
                            let wire = if ix == succ.slot {
                                patch_inputs[slot]
                            } else if let Some(konst) =
                                new_body.outlet_fact(*outlet)?.konst.as_ref()
                            {
                                outside_patch.add_const(
                                    format!(
                                        "{}.extracted.{}",
                                        node.name,
                                        new_body.node(outlet.node).name
                                    ),
                                    konst.clone(),
                                )?
                            } else {
                                // Relies on the constant-inputs check above still holding for
                                // `succ` in `new_body`.
                                unreachable!();
                            };
                            extracted_op_inputs.push(wire);
                        }
                        let new_input_wire = outside_patch.wire_node(
                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
                            new_body.node(succ.node).op.clone(),
                            &extracted_op_inputs,
                        )?[0];
                        patch_inputs.push(new_input_wire);
                        // Inside the body the new input has the outer shape, except along the
                        // scanned axis where it is one chunk long.
                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                        let mut new_input_inner_fact = new_input_outer_fact.clone();
                        new_input_inner_fact.shape.set(axis_after, scan_info.chunk.abs().to_dim());

                        let mut new_body = new_body.clone();
                        let new_source_wire = new_body.add_source(
                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
                            new_input_inner_fact,
                        )?;
                        // Replace the operator's output inside the body by the new source.
                        let mut inner_patch = TypedModelPatch::new(format!(
                            "Inner body patch for extraction of {}",
                            new_body.node(succ.node)
                        ));
                        let new_source_wire_in_patch =
                            inner_patch.tap_model(&new_body, new_source_wire)?;
                        inner_patch
                            .shunt_outside(
                                &new_body,
                                OutletId::new(succ.node, 0),
                                new_source_wire_in_patch,
                            )
                            .with_context(|| "patching inner model")?;
                        inner_patch.apply(&mut new_body)?;

                        // The extracted result is scanned along `axis_after` with the same
                        // chunk as the original input.
                        let mut input_mapping = self.input_mapping.clone();
                        input_mapping.push(InputMapping::Scan(ScanInfo {
                            axis: axis_after,
                            chunk: scan_info.chunk,
                        }));

                        let new_op = Self {
                            input_mapping,
                            decluttered: false,
                            body: new_body,
                            ..self.clone()
                        };
                        let output_wires =
                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                        for w in output_wires {
                            outside_patch
                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                                .with_context(|| "patching outer model")?;
                        }
                        return Ok(Some(outside_patch));
                    }
                }
            }
        }
        Ok(None)
    }
440
441 fn declutter_pull_constant_outputs(
442 &self,
443 _session: &mut OptimizerSession,
444 model: &TypedModel,
445 node: &TypedNode,
446 ) -> TractResult<Option<TypedModelPatch>> {
447 for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
448 if let Some(slot) = mapping.last_value_slot {
449 if let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone() {
450 let inner_node = self.body.output_outlets()?[model_output_ix].node;
451 let inner_node = self.body.node(inner_node);
452 let mut patch =
453 TypedModelPatch::new(format!("Extract const node {inner_node}"));
454 let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
455 patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
456 return Ok(Some(patch));
457 }
458 }
459 }
460 Ok(None)
461 }
462
    /// Tries to hoist, out of the loop, the operator that produces a scanned output.
    ///
    /// A scanned output qualifies when its emitting outlet has no successor in the body, is
    /// not a body input, is not a state, and its chunk is not greater than 1. The emitter's
    /// inputs are then exposed as scan outputs (constants through a last-value slot, the
    /// others as scanned outputs along the tracked axis) and the operator is re-wired outside
    /// the scan on top of those new outputs.
    ///
    /// Returns `Ok(None)` when no output qualifies or when the scanned axis cannot be tracked
    /// to exactly one axis on the emitter's interface.
    fn declutter_pull_batcheable_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (mapping_ix, mapping) in self.output_mapping.iter().enumerate() {
            if let Some((_, scan_info)) = mapping.scan {
                let emitter_outlet = self.body.output_outlets()?[mapping_ix];
                // Skip outputs that are still consumed inside the body, are plain body inputs,
                // carry state, or are scanned with a chunk greater than 1.
                if self.body.node(emitter_outlet.node).outputs[emitter_outlet.slot].successors.len()
                    > 0
                    || self.body.inputs.contains(&emitter_outlet)
                    || mapping.state
                    || mapping.scan.map(|(_slot, i)| i.chunk > 1).unwrap_or(true)
                {
                    continue;
                }
                let mut new_body = self.body.clone();
                // For an EinSum, first ask it for an axis-propagating patch on the scanned
                // output axis and apply it to the working copy of the body.
                if let Some(einsum) = new_body.node(emitter_outlet.node).op_as::<EinSum>() {
                    if let Some(patch) = einsum
                        .propagate_axis(
                            &new_body,
                            new_body.node(emitter_outlet.node),
                            InOut::Out(0),
                            scan_info.axis,
                        )
                        .context("building axis propagating patch")?
                    {
                        patch.apply(&mut new_body)?;
                        new_body.prop_consts()?;
                    }
                }
                // The emitter may have changed after the patch: look it up again.
                let emitter_outlet = new_body.output_outlets()?[mapping_ix];
                let invariants = {
                    let (input_facts, output_facts) = new_body.node_facts(emitter_outlet.node)?;
                    new_body
                        .node(emitter_outlet.node)
                        .op
                        .axes_mapping(&input_facts, &output_facts)?
                };
                let axis_tracking =
                    invariants.axis((InOut::Out(emitter_outlet.slot), scan_info.axis))?;
                rule_if!(axis_tracking.outputs.iter().all(|o| o.len() == 1));
                let mut new_output_mapping = self.output_mapping.clone();
                // Next free outer output slot; new slots are appended after the existing ones.
                let mut new_scan_outputs = node.outputs.len();
                // Outer slot feeding each input of the extracted operator, in input order.
                let mut outer_slots = vec![];

                for (input_slot, input) in
                    new_body.node(emitter_outlet.node).inputs.clone().iter().enumerate()
                {
                    // Make sure the emitter input is a body output, adding one if needed.
                    if new_body.outputs.iter().all(|o| o != input) {
                        new_output_mapping.push(OutputMapping::default());
                        new_body.outputs.push(*input);
                    }
                    let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
                    let mapping = &mut new_output_mapping[body_output_id];
                    let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                        // Constants are exposed through a last-value slot.
                        if mapping.last_value_slot.is_none() {
                            mapping.last_value_slot = Some(new_scan_outputs);
                            new_scan_outputs += 1;
                        }
                        mapping.last_value_slot.unwrap()
                    } else if let &[axis] = &*axis_tracking.inputs[input_slot] {
                        // Other inputs are scanned along the single axis the output axis
                        // tracks to on this input.
                        if mapping.scan.is_none() {
                            mapping.scan =
                                Some((new_scan_outputs, ScanInfo { axis, chunk: scan_info.chunk }));
                            new_scan_outputs += 1;
                        }
                        mapping.scan.unwrap().0
                    } else {
                        return Ok(None);
                    };
                    outer_slots.push(outer_slot);
                }
                let mut outside_patch = TypedModelPatch::new(format!(
                    "Outside patch for output extraction of {}",
                    new_body.node(emitter_outlet.node)
                ));
                let inputs = node
                    .inputs
                    .iter()
                    .map(|&i| outside_patch.tap_model(model, i))
                    .collect::<TractResult<TVec<_>>>()?;
                let new_op = Self {
                    output_mapping: new_output_mapping,
                    decluttered: false,
                    body: new_body.clone(),
                    ..self.clone()
                };
                let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
                // `mapping` is the outer loop's mapping again here: the scanned output being
                // replaced by the extracted operator.
                let output = mapping.scan.unwrap();
                let inputs =
                    outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
                let wire = outside_patch.wire_node(
                    &new_body.node(emitter_outlet.node).name,
                    new_body.node(emitter_outlet.node).op.clone(),
                    &inputs,
                )?[0];
                outside_patch.shunt_outside(model, OutletId::new(node.id, output.0), wire)?;
                // Every other pre-existing output keeps its slot on the new scan node.
                for output_slot in 0..node.outputs.len() {
                    if output_slot != output.0 {
                        outside_patch.shunt_outside(
                            model,
                            OutletId::new(node.id, output_slot),
                            OutletId::new(scan_outputs[0].node, output_slot),
                        )?;
                    }
                }
                return Ok(Some(outside_patch));
            }
        }
        Ok(None)
    }
578
579 fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
580 let input_state_outlets = self
581 .input_mapping
582 .iter()
583 .zip(self.body.input_outlets()?.iter())
584 .filter(|(m, _)| m.is_state())
585 .map(|(_, o)| o);
586 let output_state_outlets = self
587 .output_mapping
588 .iter()
589 .zip(self.body.output_outlets()?.iter())
590 .filter(|(m, _)| m.state)
591 .map(|(_, o)| o);
592 Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
593 }
594
595 fn body_locked_outlets(&self, node_input_facts: &[&TypedFact]) -> TractResult<TVec<OutletId>> {
596 let input_outlets =
597 self.body.input_outlets()?.iter().enumerate().filter_map(|(slot, o)| {
598 if node_input_facts[slot].konst.is_none() { Some(o) } else { None }
599 });
600 let output_outlets = self
601 .output_mapping
602 .iter()
603 .zip(self.body.output_outlets()?.iter())
604 .filter(|(m, _)| !m.invisible())
605 .map(|(_, o)| o);
606 Ok(input_outlets.chain(output_outlets).cloned().collect())
607 }
608
609 fn try_body_axes_change(
610 &self,
611 change: AxisChange,
612 locked_interface: bool,
613 node_input_facts: &[&TypedFact],
614 ) -> TractResult<Option<AxisChangeConsequence>> {
615 self.body.check_consistency()?;
616 let locked_outlets = self.body_locked_outlets(node_input_facts)?;
617 let mut explored: HashSet<AxisChange> = Default::default();
618 rule_if_some!(
619 (body_patch, body_changed_wires) = crate::optim::change_axes::change_axes(
620 &self.body,
621 &change,
622 if locked_interface { &locked_outlets } else { &[] },
623 &self.body_bounds()?,
624 &mut explored,
625 )?
626 );
627 let mut body = self.body.clone();
628 body_patch.apply(&mut body)?;
629 body.compact()?;
630 let mut wire_changes = tvec!();
631 let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
632 for (slot, m) in input_mapping.iter_mut().enumerate() {
633 if let Some(change) = body_changed_wires
634 .iter()
635 .find(|(iface, _change)| iface == &InOut::In(slot))
636 .map(|pair| pair.1.clone())
637 {
638 wire_changes.push((InOut::In(slot), change.clone()));
639 if let InputMapping::Scan(info) = m {
640 rule_if_some!(axis = change.transform_axis(info.axis));
641 info.axis = axis;
642 };
643 }
644 }
645 let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
646 for (ix, m) in output_mapping.iter_mut().enumerate() {
647 if let Some(change) = body_changed_wires
648 .iter()
649 .find(|(iface, _change)| iface == &InOut::Out(ix))
650 .map(|pair| pair.1.clone())
651 {
652 if let Some((slot, info)) = m.scan.as_mut() {
653 rule_if_some!(new_axis = change.transform_axis(info.axis));
654 info.axis = new_axis;
655 wire_changes.push((InOut::Out(*slot), change.clone()));
656 }
657 if let Some(slot) = m.last_value_slot {
658 wire_changes.push((InOut::Out(slot), change.clone()));
659 }
660 };
661 }
662 body.check_consistency()?;
663 let op = Some(Box::new(Scan {
664 body,
665 input_mapping,
666 output_mapping,
667 decluttered: false,
668 ..self.clone()
669 }) as _);
670 Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
671 }
672}
673
674impl Op for Scan {
675 fn name(&self) -> StaticName {
676 "Scan".into()
677 }
678
679 fn info(&self) -> TractResult<Vec<String>> {
680 let mut lines = vec![];
681 for (ix, im) in self.input_mapping.iter().enumerate() {
682 lines.push(format!("Model input #{ix}: {im:?}"));
683 }
684 for (ix, om) in self.output_mapping.iter().enumerate() {
685 lines.push(format!("Model output #{ix}: {om:?}"));
686 }
687 lines.push(format!("skip:{} reset_every_turn:{:?}", self.skip, self.reset_every_turn));
688 Ok(lines)
689 }
690
691 fn validation(&self) -> Validation {
692 Validation::Rounding
693 }
694
695 op_as_typed_op!();
696}
697
698impl EvalOp for Scan {
699 fn is_stateless(&self) -> bool {
700 false
701 }
702 fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
703 self.to_codegen_op(false)?.state(session, node_id)
704 }
705}
706
707impl TypedOp for Scan {
708 as_op!();
709
710 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
711 anyhow::ensure!(inputs.len() == self.body.inputs.len());
712 anyhow::ensure!(self.input_mapping.len() == self.body.inputs.len());
713 anyhow::ensure!(
714 self.input_mapping.iter().filter(|m| m.is_state()).count()
715 == self.output_mapping.iter().filter(|m| m.state).count()
716 );
717 for (i, o) in
718 self.input_mapping.iter().enumerate().filter(|(_, m)| m.is_state()).map(|(i, _)| i).zip(
719 self.output_mapping.iter().enumerate().filter(|(_, m)| m.state).map(|(o, _)| o),
720 )
721 {
722 let ifact = self.body.outlet_fact(self.body.inputs[i])?;
723 let ofact = self.body.outlet_fact(self.body.outputs[o])?;
724 anyhow::ensure!(
725 ifact == ofact,
726 "inconsistent fact: body input {i} is {ifact:?} and body output {o} is {ofact:?}\n{}",
727 self.body
728 )
729 }
730 let mut outputs = tvec!();
731 let iters = super::iteration_count(&self.input_mapping, inputs).context("No scan input")?;
732 for (ix, output) in self.output_mapping.iter().enumerate() {
733 let fact = self.body.output_fact(ix)?;
734 if let Some((slot, info)) = output.scan {
735 let mut shape = fact.shape.clone();
736 let scanning_dim =
737 output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
738 shape.set(info.axis, scanning_dim);
739 outputs.push((slot, fact.datum_type.fact(shape)));
740 }
741 if let Some(slot) = output.last_value_slot {
742 outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
743 }
744 }
745 outputs.sort_by_key(|a| a.0);
746 anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
747 let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
748 Ok(outputs)
749 }
750
751 fn axes_mapping(
752 &self,
753 inputs: &[&TypedFact],
754 outputs: &[&TypedFact],
755 ) -> TractResult<AxesMapping> {
756 let mut mappings = vec![];
757 let body_invs = self.body.axes_mapping().with_context(|| "Computing body axes mapping")?;
758 for body_axis in body_invs.iter_all_axes() {
759 let mut info = Axis::new(body_axis.repr, inputs.len(), outputs.len());
760 info.inputs.clone_from(&body_axis.inputs);
761 for (ix, output_mapping) in self.output_mapping.iter().enumerate() {
762 let mut slots = vec![];
763 if let Some((slot, _scan)) = output_mapping.scan {
764 slots.push(slot);
765 }
766 if let Some(slot) = output_mapping.last_value_slot {
767 slots.push(slot);
768 }
769 for slot in slots {
770 info.outputs[slot].clone_from(&body_axis.outputs[ix]);
771 }
772 }
773 if info.inputs.iter().any(|i| i.len() > 0) || info.outputs.iter().any(|i| i.len() > 0) {
774 mappings.push(info);
775 }
776 }
777 AxesMapping::new(inputs.len(), outputs.len(), mappings)
778 }
779
780 fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
781 let mut suggestions = tvec!();
782 for (slot, input) in self.input_mapping.iter().enumerate() {
783 if let InputMapping::Scan(info) = input {
784 if info.axis != 0 {
785 suggestions.push((InOut::In(slot), AxisOp::Move(info.axis, 0)))
786 }
787 }
788 }
789 for output in &self.output_mapping {
790 if let Some((slot, scan)) = output.scan {
791 if scan.axis != 0 {
792 suggestions.push((InOut::Out(slot), AxisOp::Move(scan.axis, 0)))
793 }
794 }
795 }
796 Ok(suggestions)
797 }
798
799 fn change_axes(
800 &self,
801 model: &TypedModel,
802 node: &TypedNode,
803 io: InOut,
804 change: &AxisOp,
805 ) -> TractResult<Option<AxisChangeConsequence>> {
806 trace!("Propagating through {node}: {io:?} {change:?}");
807 let body_leading_outlet = match io {
808 InOut::In(ix) => self.body.input_outlets()?[ix],
809 InOut::Out(slot) => {
810 let output = self
811 .output_mapping
812 .iter()
813 .position(|im| {
814 im.scan.map(|(slot, _i)| slot) == Some(slot)
815 || im.last_value_slot == Some(slot)
816 })
817 .unwrap();
818 self.body.output_outlets()?[output]
819 }
820 };
821 let axis_change = AxisChange { outlet: body_leading_outlet, op: change.clone() };
822 let node_input_facts = model.node_input_facts(node.id)?;
823 let result = self
824 .try_body_axes_change(axis_change, false, &node_input_facts)
825 .with_context(|| "Attemping to run change through scan body".to_string())?;
826 if result.is_some() {
827 trace!("{node} accepted axis change");
828 } else {
829 trace!("{node} rejected axis change");
830 }
831 Ok(result)
832 }
833
834 fn declutter_with_session(
835 &self,
836 session: &mut OptimizerSession,
837 model: &TypedModel,
838 node: &TypedNode,
839 ) -> TractResult<Option<TypedModelPatch>> {
840 macro_rules! pass {
841 ($func:ident) => {
842 if let Some(mut r) = self
843 .$func(session, model, node)
844 .with_context(|| format!("{}", stringify!($func)))?
845 {
846 trace!(stringify!($func));
847 r.push_context(stringify!($func));
848 return Ok(Some(r));
849 }
850 };
851 }
852 pass!(declutter_single_loop);
853 pass!(declutter_const_input);
854 pass!(declutter_discard_unused_input);
855 pass!(declutter_discard_useless_outer_output);
856 pass!(declutter_discard_empty_output_mapping_with_body_output);
857 pass!(declutter_body);
858 pass!(declutter_body_axes);
859 pass!(declutter_pull_constant_outputs);
860 pass!(declutter_pull_batcheable_input);
861 pass!(declutter_pull_batcheable_output);
862 Ok(None)
863 }
864
865 fn concretize_dims(
866 &self,
867 _source: &TypedModel,
868 node: &TypedNode,
869 target: &mut TypedModel,
870 mapping: &HashMap<OutletId, OutletId>,
871 values: &SymbolValues,
872 ) -> TractResult<TVec<OutletId>> {
873 let inputs = node.inputs.iter().map(|o| mapping[o]).collect::<TVec<_>>();
874 let op = Self {
875 output_mapping: self
876 .output_mapping
877 .iter()
878 .map(|om| om.concretize_dims(values))
879 .collect::<TractResult<Vec<_>>>()?,
880 body: self.body.concretize_dims(values)?,
881 ..self.clone()
882 };
883 target.wire_node(&node.name, op, &inputs)
884 }
885
886 fn codegen(
887 &self,
888 model: &TypedModel,
889 node: &TypedNode,
890 ) -> TractResult<Option<TypedModelPatch>> {
891 Ok(Some(TypedModelPatch::replace_single_op(
892 model,
893 node,
894 &node.inputs,
895 self.to_codegen_op(true)?,
896 )?))
897 }
898}