1use std::collections::HashSet;
2
3use crate::ops::einsum::EinSum;
4use crate::ops::konst::Const;
5use crate::optim::OptimizerSession;
6
7use super::optimized::{OptScan, ScanOpParams};
8use tract_data::internal::*;
9
10use super::*;
11
/// Scan operator: a body model plus the mappings that bind the body's inputs
/// and outputs to the operator's own (outer) inputs and outputs.
#[derive(Debug, Clone, Default)]
pub struct Scan {
    /// Forwarded unchanged to `ScanOpParams::new` by `to_codegen_op`.
    /// NOTE(review): presumably a number of leading iterations to skip — confirm in `ScanOpParams`.
    pub skip: usize,
    /// Forwarded unchanged to `ScanOpParams::new`; `Scan::new` initializes it to `false`.
    pub reset_every_turn: bool,
    /// Model wrapped by this operator; the declutter passes rewrite it.
    pub body: TypedModel,
    /// `true` once `declutter_body` has run the optimizer session on `body`.
    /// Several declutter passes that rewrite the body reset it to `false`.
    pub decluttered: bool,
    /// One entry per body input (checked by `Scan::new` and `output_facts`).
    pub input_mapping: Vec<InputMapping>,
    /// One entry per body output (checked by `Scan::new`).
    pub output_mapping: Vec<OutputMapping<TDim>>,
}
21
22impl Scan {
23 pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<OptScan> {
24 let mut model = self.body.clone();
25 if optimize_inner {
26 model = model.into_optimized()?;
27 }
28 let plan = SimplePlan::new(model)?;
29
30 Ok(OptScan::new(Arc::new(ScanOpParams::new(
31 self.skip,
32 self.reset_every_turn,
33 Arc::new(plan),
34 self.input_mapping.clone(),
35 self.output_mapping.clone(),
36 ))))
37 }
38
39 pub fn new(
40 body: TypedModel,
41 input_mapping: Vec<InputMapping>,
42 output_mapping: Vec<OutputMapping<TDim>>,
43 skip: usize,
44 ) -> TractResult<Scan> {
45 body.check_consistency()?;
46 ensure!(input_mapping.len() == body.input_outlets()?.len());
47 ensure!(output_mapping.len() == body.output_outlets()?.len());
48 Ok(Scan {
49 skip,
50 reset_every_turn: false,
51 body,
52 decluttered: false,
53 input_mapping,
54 output_mapping,
55 })
56 }
57
58 pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
59 self.to_codegen_op(false).unwrap().iteration_count(inputs)
60 }
61
62 fn declutter_body(
63 &self,
64 session: &mut OptimizerSession,
65 model: &TypedModel,
66 node: &TypedNode,
67 ) -> TractResult<Option<TypedModelPatch>> {
68 if !self.decluttered {
69 let mut new = self.clone();
70 let mut body = self.body.clone();
71 session.optimize(&mut body)?;
72 new.body = body;
73 new.decluttered = true;
74 Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
75 } else {
76 Ok(None)
77 }
78 }
79
80 fn declutter_body_axes(
81 &self,
82 _session: &mut OptimizerSession,
83 model: &TypedModel,
84 node: &TypedNode,
85 ) -> TractResult<Option<TypedModelPatch>> {
86 let mut suggestions = vec![];
87 for n in self.body.eval_order()? {
88 let node = self.body.node(n);
89 for suggestion in node.op.suggested_axis_changes()? {
90 let outlet = suggestion.0.as_outlet(node);
91 suggestions.push(AxisChange { outlet, op: suggestion.1 })
92 }
93 for (slot, fact) in node.outputs.iter().enumerate() {
94 for (ix, dim) in fact.fact.shape.iter().enumerate() {
95 if dim.is_one() {
96 suggestions.push(AxisChange {
97 outlet: OutletId::new(n, slot),
98 op: AxisOp::Rm(ix),
99 });
100 }
101 }
102 }
103 }
104 let node_input_facts = model.node_input_facts(node.id)?;
105 for suggestion in suggestions.into_iter() {
106 if let Some(conseq) = self.try_body_axes_change(suggestion, true, &node_input_facts)? {
107 let mut patch = TypedModelPatch::default();
108 let mut inputs = tvec!();
109 for outlet in &node.inputs {
110 inputs.push(patch.tap_model(model, *outlet)?);
111 }
112 for change in conseq.wire_changes {
113 if let InOut::In(i) = change.0 {
114 let mut value = patch
115 .outlet_fact(inputs[i])?
116 .konst
117 .clone()
118 .context("Will only reshape constants")?
119 .into_tensor();
120 change.1.change_tensor(&mut value, false)?;
121 let konst_name = patch.node(inputs[i].node).name.clone();
122 inputs[i] = patch.add_const(konst_name, value)?;
123 }
124 }
125 let wires = patch.wire_node(
126 &node.name,
127 conseq.substitute_op.unwrap_or_else(|| Box::new(self.clone())),
128 &inputs,
129 )?;
130 for (ix, new) in wires.into_iter().enumerate() {
131 patch.shunt_outside(model, OutletId::new(node.id, ix), new)?;
132 }
133 return Ok(Some(patch));
134 }
135 }
136 Ok(None)
137 }
138
139 fn remove_outer_output_from_mappings(
140 mappings: &[OutputMapping<TDim>],
141 discarded: usize,
142 ) -> Vec<OutputMapping<TDim>> {
143 mappings
144 .iter()
145 .map(|m| OutputMapping {
146 scan: m.scan.map(|(slot, info)| (slot - (slot > discarded) as usize, info)),
147 last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
148 full_dim_hint: m.full_dim_hint.clone(),
149 state: m.state,
150 })
151 .collect()
152 }
153
154 fn declutter_const_input(
155 &self,
156 _session: &mut OptimizerSession,
157 model: &TypedModel,
158 node: &TypedNode,
159 ) -> TractResult<Option<TypedModelPatch>> {
160 let inputs = model.node_input_facts(node.id)?;
161 for (slot, mapping) in self.input_mapping.iter().enumerate() {
162 if let InputMapping::Full = mapping {
163 if let Some(konst) = inputs[slot].konst.as_ref() {
164 let mut op = self.clone();
165 let src = op.body.inputs[slot];
166 op.body.inputs.remove(slot);
167 op.body.nodes[src.node].inputs.clear();
168 op.body.nodes[src.node].op = Box::new(Const::new(konst.clone()));
169 op.input_mapping.remove(slot);
170 let mut inputs = node.inputs.clone();
171 inputs.remove(slot);
172 return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
173 }
174 }
175 }
176 Ok(None)
177 }
178
179 fn declutter_discard_unused_input(
180 &self,
181 _session: &mut OptimizerSession,
182 model: &TypedModel,
183 node: &TypedNode,
184 ) -> TractResult<Option<TypedModelPatch>> {
185 for (slot, input) in self.body.input_outlets()?.iter().enumerate() {
186 let source_node = self.body.node(input.node);
187 if source_node.outputs[0].successors.len() == 0
188 && !self.body.output_outlets()?.contains(input)
189 {
190 let mut new_inputs = node.inputs.clone();
191 new_inputs.remove(slot);
192 let mut new_mappings: Vec<_> = self.input_mapping.clone();
193 new_mappings.remove(slot);
194 let mut model_inputs = self.body.input_outlets()?.to_vec();
195 model_inputs.remove(slot);
196 let mut body = self.body.clone();
197 let mut patch = TypedModelPatch::default();
198 patch.obliterate(source_node.id)?;
199 patch.apply(&mut body)?;
200 body.set_input_outlets(&model_inputs)?;
201 body.declutter()?;
202 let op =
203 Self { body, input_mapping: new_mappings, decluttered: true, ..self.clone() };
204 return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
205 }
206 }
207 Ok(None)
208 }
209
210 fn declutter_discard_useless_outer_output(
211 &self,
212 _session: &mut OptimizerSession,
213 model: &TypedModel,
214 node: &TypedNode,
215 ) -> TractResult<Option<TypedModelPatch>> {
216 for (ix, o) in node.outputs.iter().enumerate() {
217 if o.successors.len() == 0
218 && !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
219 {
220 let mappings = self
221 .output_mapping
222 .iter()
223 .map(|m| OutputMapping {
224 scan: m.scan.filter(|(slot, _info)| *slot != ix),
225 last_value_slot: m.last_value_slot.filter(|s| *s != ix),
226 full_dim_hint: m.full_dim_hint.clone(),
227 state: m.state,
228 })
229 .collect::<Vec<_>>();
230 let mut op = self.clone();
231 op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
232 let mut patch = TypedModelPatch::default();
233 let inputs = node
234 .inputs
235 .iter()
236 .map(|&i| patch.tap_model(model, i))
237 .collect::<TractResult<Vec<_>>>()?;
238 let wires = patch.wire_node(&*node.name, op, &inputs)?;
239 for oix in 0..node.outputs.len() {
240 if oix != ix {
241 patch.shunt_outside(
242 model,
243 OutletId::new(node.id, oix),
244 wires[oix - (oix > ix) as usize],
245 )?;
246 }
247 }
248 return Ok(Some(patch));
249 }
250 }
251 Ok(None)
252 }
253
254 fn declutter_discard_empty_output_mapping_with_body_output(
255 &self,
256 _session: &mut OptimizerSession,
257 model: &TypedModel,
258 node: &TypedNode,
259 ) -> TractResult<Option<TypedModelPatch>> {
260 for (ix, om) in self.output_mapping.iter().enumerate() {
261 if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
262 let mut new_op = self.clone();
263 new_op.output_mapping.remove(ix);
264 new_op.body.outputs.remove(ix);
265 new_op.decluttered = false;
266 return Ok(Some(TypedModelPatch::replace_single_op(
267 model,
268 node,
269 &node.inputs,
270 new_op,
271 )?));
272 }
273 }
274 Ok(None)
275 }
276
    /// Moves an op that directly consumes a scanned input out of the body.
    ///
    /// For each scanned input, looks at the successors of its body source. A
    /// successor can be extracted when all its other inputs are constants, it
    /// has a single output, and the scanned axis maps to exactly one axis of
    /// that output. The op is then wired outside the scan on the whole outer
    /// input, its result is fed to the scan as an additional scanned input, and
    /// the op's output inside the body is replaced by a new source.
    ///
    /// Returns the outer patch for the first extraction that works, or `None`.
    fn declutter_pull_batcheable_input(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        'candidate: for (slot, input) in self.input_mapping.iter().enumerate() {
            if let Some(scan_info) = input.as_scan() {
                let scan_source = self.body.input_outlets()?[slot];
                let scan_source_node = self.body.node(scan_source.node);
                for mut succ in &scan_source_node.outputs[0].successors {
                    // A successor with a non-constant input other than the scanned one
                    // makes us give up on this whole input slot, not just on this successor.
                    for &succ_input in &self.body.node(succ.node).inputs {
                        if succ_input != scan_source
                            && self.body.outlet_fact(succ_input)?.konst.is_none()
                        {
                            continue 'candidate;
                        }
                    }
                    if self.body.node(succ.node).outputs.len() != 1 {
                        continue;
                    }
                    let mut new_body = self.body.clone();
                    // For an EinSum, first ask it for a patch propagating the scan axis.
                    if let Some(einsum) = new_body.node(succ.node).op_as::<EinSum>() {
                        if let Some(patch) = einsum
                            .propagate_axis(
                                &new_body,
                                new_body.node(succ.node),
                                InOut::In(succ.slot),
                                scan_info.axis,
                            )
                            .context("building axis propagating patch")?
                        {
                            patch.apply(&mut new_body)?;
                            new_body.compute_const_facts()?;
                            let new_body_scan_input = new_body.input_outlets()?[slot];
                            // NOTE(review): assumes the patch leaves the rewritten op as the
                            // last successor of the scanned input — confirm against `propagate_axis`.
                            succ = new_body.node(new_body_scan_input.node).outputs[0]
                                .successors
                                .last()
                                .unwrap();
                        }
                    }

                    let axes_mapping = {
                        let (input_facts, output_facts) =
                            new_body.node_facts(new_body.node(succ.node).id)?;
                        new_body.node(succ.node).op.axes_mapping(&input_facts, &output_facts)?
                    };
                    let axis_info = axes_mapping.axis((InOut::In(succ.slot), scan_info.axis))?;
                    // Only proceed when the scanned axis lands on exactly one output axis.
                    if let &[axis_after] = &*axis_info.outputs[0] {
                        let mut outside_patch = TypedModelPatch::new(format!(
                            "Outer patch for input extraction of {}",
                            new_body.node(succ.node)
                        ));
                        let mut patch_inputs = node
                            .inputs
                            .iter()
                            .map(|&i| outside_patch.tap_model(model, i))
                            .collect::<TractResult<TVec<_>>>()?;
                        // Inputs of the extracted op: the outer scanned input for the slot it
                        // used to read, fresh outer constants for everything else.
                        let mut extracted_op_inputs = tvec!();
                        for (ix, outlet) in new_body.node(succ.node).inputs.iter().enumerate() {
                            let wire = if ix == succ.slot {
                                patch_inputs[slot]
                            } else if let Some(konst) =
                                new_body.outlet_fact(*outlet)?.konst.as_ref()
                            {
                                outside_patch.add_const(
                                    format!(
                                        "{}.extracted.{}",
                                        node.name,
                                        new_body.node(outlet.node).name
                                    ),
                                    konst.clone(),
                                )?
                            } else {
                                unreachable!();
                            };
                            extracted_op_inputs.push(wire);
                        }
                        let new_input_wire = outside_patch.wire_node(
                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
                            new_body.node(succ.node).op.clone(),
                            &extracted_op_inputs,
                        )?[0];
                        // The extracted result becomes an extra outer input of the scan.
                        patch_inputs.push(new_input_wire);
                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                        // Inside the body the new input has the outer fact with the scanned
                        // axis set to the absolute chunk size.
                        let mut new_input_inner_fact = new_input_outer_fact.clone();
                        new_input_inner_fact.shape.set(axis_after, scan_info.chunk.abs().to_dim());

                        let mut new_body = new_body.clone();
                        let new_source_wire = new_body.add_source(
                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
                            new_input_inner_fact,
                        )?;
                        let mut inner_patch = TypedModelPatch::new(format!(
                            "Inner body patch for extraction of {}",
                            new_body.node(succ.node)
                        ));
                        let new_source_wire_in_patch =
                            inner_patch.tap_model(&new_body, new_source_wire)?;
                        // Consumers of the extracted op now read the new source instead.
                        inner_patch
                            .shunt_outside(
                                &new_body,
                                OutletId::new(succ.node, 0),
                                new_source_wire_in_patch,
                            )
                            .with_context(|| "patching inner model")?;
                        inner_patch.apply(&mut new_body)?;

                        // The new input is scanned along the output axis, with the same chunk.
                        let mut input_mapping = self.input_mapping.clone();
                        input_mapping.push(InputMapping::Scan(ScanInfo {
                            axis: axis_after,
                            chunk: scan_info.chunk,
                        }));

                        let new_op = Self {
                            input_mapping,
                            decluttered: false,
                            body: new_body,
                            ..self.clone()
                        };
                        let output_wires =
                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                        for w in output_wires {
                            outside_patch
                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                                .with_context(|| "patching outer model")?;
                        }
                        return Ok(Some(outside_patch));
                    }
                }
            }
        }
        Ok(None)
    }
414
415 fn declutter_pull_constant_outputs(
416 &self,
417 _session: &mut OptimizerSession,
418 model: &TypedModel,
419 node: &TypedNode,
420 ) -> TractResult<Option<TypedModelPatch>> {
421 for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
422 if let Some(slot) = mapping.last_value_slot {
423 if let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone() {
424 let inner_node = self.body.output_outlets()?[model_output_ix].node;
425 let inner_node = self.body.node(inner_node);
426 let mut patch =
427 TypedModelPatch::new(format!("Extract const node {inner_node}"));
428 let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
429 patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
430 return Ok(Some(patch));
431 }
432 }
433 }
434 Ok(None)
435 }
436
    /// Moves the op that produces a scanned output out of the body.
    ///
    /// Considers body outputs exported as a scan whose emitter has no successor
    /// inside the body, is not a body input, is not a state, and whose chunk is
    /// not greater than one. The emitter's inputs are exported from the body
    /// instead: constant inputs as last values, the others as scans along the
    /// axis tracked from the scanned output axis. The emitter op is then
    /// re-created outside the scan on those new outer outputs.
    ///
    /// Returns `None` when the scanned axis maps to several axes of one output,
    /// or when a non-constant emitter input does not track it as exactly one axis.
    fn declutter_pull_batcheable_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (mapping_ix, mapping) in self.output_mapping.iter().enumerate() {
            if let Some((_, scan_info)) = mapping.scan {
                let emitter_outlet = self.body.output_outlets()?[mapping_ix];
                // Skip emitters that are still used inside the body, that are plain body
                // inputs, that carry state, or that are scanned with a chunk above one.
                if self.body.node(emitter_outlet.node).outputs[emitter_outlet.slot].successors.len()
                    > 0
                    || self.body.inputs.contains(&emitter_outlet)
                    || mapping.state
                    || mapping.scan.map(|(_slot, i)| i.chunk > 1).unwrap_or(true)
                {
                    continue;
                }
                let mut new_body = self.body.clone();
                // For an EinSum, first ask it for a patch propagating the scan axis.
                if let Some(einsum) = new_body.node(emitter_outlet.node).op_as::<EinSum>() {
                    if let Some(patch) = einsum
                        .propagate_axis(
                            &new_body,
                            new_body.node(emitter_outlet.node),
                            InOut::Out(0),
                            scan_info.axis,
                        )
                        .context("building axis propagating patch")?
                    {
                        patch.apply(&mut new_body)?;
                        new_body.prop_consts()?;
                    }
                }
                // Re-read the emitter: the patch above may have changed the body output.
                let emitter_outlet = new_body.output_outlets()?[mapping_ix];
                let invariants = {
                    let (input_facts, output_facts) = new_body.node_facts(emitter_outlet.node)?;
                    new_body
                        .node(emitter_outlet.node)
                        .op
                        .axes_mapping(&input_facts, &output_facts)?
                };
                let axis_tracking =
                    invariants.axis((InOut::Out(emitter_outlet.slot), scan_info.axis))?;
                if axis_tracking.outputs.iter().any(|o| o.len() > 1) {
                    return Ok(None);
                }
                let mut new_output_mapping = self.output_mapping.clone();
                // Next free outer output slot: new exports are appended after existing ones.
                let mut new_scan_outputs = node.outputs.len();
                // Outer slot feeding each emitter input, in emitter input order.
                let mut outer_slots = vec![];

                for (input_slot, input) in
                    new_body.node(emitter_outlet.node).inputs.clone().iter().enumerate()
                {
                    // Make sure the emitter input is a body output, adding one if needed.
                    if new_body.outputs.iter().all(|o| o != input) {
                        new_output_mapping.push(OutputMapping::default());
                        new_body.outputs.push(*input);
                    }
                    let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
                    // This `mapping` shadows the loop's `mapping` for the rest of this block only.
                    let mapping = &mut new_output_mapping[body_output_id];
                    let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                        // Constant input: export it as a last value.
                        if mapping.last_value_slot.is_none() {
                            mapping.last_value_slot = Some(new_scan_outputs);
                            new_scan_outputs += 1;
                        }
                        mapping.last_value_slot.unwrap()
                    } else if let &[axis] = &*axis_tracking.inputs[input_slot] {
                        // Varying input: export it as a scan along its tracked axis.
                        if mapping.scan.is_none() {
                            mapping.scan =
                                Some((new_scan_outputs, ScanInfo { axis, chunk: scan_info.chunk }));
                            new_scan_outputs += 1;
                        }
                        mapping.scan.unwrap().0
                    } else {
                        return Ok(None);
                    };
                    outer_slots.push(outer_slot);
                }
                let mut outside_patch = TypedModelPatch::new(format!(
                    "Outside patch for output extraction of {}",
                    new_body.node(emitter_outlet.node)
                ));
                let inputs = node
                    .inputs
                    .iter()
                    .map(|&i| outside_patch.tap_model(model, i))
                    .collect::<TractResult<TVec<_>>>()?;
                let new_op = Self {
                    output_mapping: new_output_mapping,
                    decluttered: false,
                    body: new_body.clone(), ..self.clone()
                };
                let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
                // `mapping` is the loop's original mapping again: this is the outer slot
                // the extracted emitter has to take over.
                let output = mapping.scan.unwrap();
                let inputs =
                    outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
                let wire = outside_patch.wire_node(
                    &new_body.node(emitter_outlet.node).name,
                    new_body.node(emitter_outlet.node).op.clone(),
                    &inputs,
                )?[0];
                outside_patch.shunt_outside(model, OutletId::new(node.id, output.0), wire)?;
                // Every other pre-existing outer output keeps its slot on the new scan node.
                for output_slot in 0..node.outputs.len() {
                    if output_slot != output.0 {
                        outside_patch.shunt_outside(
                            model,
                            OutletId::new(node.id, output_slot),
                            OutletId::new(scan_outputs[0].node, output_slot),
                        )?;
                    }
                }
                return Ok(Some(outside_patch));
            }
        }
        Ok(None)
    }
554
555 fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
556 let input_state_outlets = self
557 .input_mapping
558 .iter()
559 .zip(self.body.input_outlets()?.iter())
560 .filter(|(m, _)| m.is_state())
561 .map(|(_, o)| o);
562 let output_state_outlets = self
563 .output_mapping
564 .iter()
565 .zip(self.body.output_outlets()?.iter())
566 .filter(|(m, _)| m.state)
567 .map(|(_, o)| o);
568 Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
569 }
570
571 fn body_locked_outlets(&self, node_input_facts: &[&TypedFact]) -> TractResult<TVec<OutletId>> {
572 let input_outlets =
573 self.body.input_outlets()?.iter().enumerate().filter_map(|(slot, o)| {
574 if node_input_facts[slot].konst.is_none() {
575 Some(o)
576 } else {
577 None
578 }
579 });
580 let output_outlets = self
581 .output_mapping
582 .iter()
583 .zip(self.body.output_outlets()?.iter())
584 .filter(|(m, _)| !m.invisible())
585 .map(|(_, o)| o);
586 Ok(input_outlets.chain(output_outlets).cloned().collect())
587 }
588
    /// Tries to apply an axis change inside the body.
    ///
    /// `change` is handed to `crate::optim::change_axes::change_axes` on the
    /// body. When `locked_interface` is set, the outlets returned by
    /// `body_locked_outlets` are passed as locked; otherwise nothing is locked.
    /// State input/output pairs from `body_bounds` are always passed along.
    ///
    /// Returns `Ok(None)` when the change is rejected, or when the scan axis of
    /// an affected mapping cannot be transformed by the change. Otherwise
    /// returns a substitute `Scan` (patched body, updated scan axes, not
    /// decluttered) with the list of outer wires whose axes changed.
    fn try_body_axes_change(
        &self,
        change: AxisChange,
        locked_interface: bool,
        node_input_facts: &[&TypedFact],
    ) -> TractResult<Option<AxisChangeConsequence>> {
        self.body.check_consistency()?;
        let locked_outlets = self.body_locked_outlets(node_input_facts)?;
        let mut explored: HashSet<AxisChange> = Default::default();
        let (body_patch, body_changed_wires) = if let Some(changes) =
            crate::optim::change_axes::change_axes(
                &self.body,
                &change,
                if locked_interface { &locked_outlets } else { &[] },
                &self.body_bounds()?,
                &mut explored
            )? {
            changes
        } else {
            return Ok(None);
        };
        let mut body = self.body.clone();
        body_patch.apply(&mut body)?;
        body.compact()?;
        let mut wire_changes = tvec!();
        // Body inputs: report each changed input and follow the change on scanned axes.
        let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
        for (slot, m) in input_mapping.iter_mut().enumerate() {
            if let Some(change) = body_changed_wires
                .iter()
                .find(|(iface, _change)| iface == &InOut::In(slot))
                .map(|pair| pair.1.clone())
            {
                wire_changes.push((InOut::In(slot), change.clone()));
                if let InputMapping::Scan(info) = m {
                    if let Some(axis) = change.transform_axis(info.axis) {
                        info.axis = axis;
                    } else {
                        // The scanned axis does not survive the change: give up.
                        return Ok(None);
                    };
                };
            }
        }
        // Body outputs: `ix` is a body output index, while the reported wire changes
        // use the outer slots taken from the mapping.
        let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
        for (ix, m) in output_mapping.iter_mut().enumerate() {
            if let Some(change) = body_changed_wires
                .iter()
                .find(|(iface, _change)| iface == &InOut::Out(ix))
                .map(|pair| pair.1.clone())
            {
                if let Some((slot, info)) = m.scan.as_mut() {
                    if let Some(new_axis) = change.transform_axis(info.axis) {
                        info.axis = new_axis;
                    } else {
                        return Ok(None);
                    }
                    wire_changes.push((InOut::Out(*slot), change.clone()));
                }
                if let Some(slot) = m.last_value_slot {
                    wire_changes.push((InOut::Out(slot), change.clone()));
                }
            };
        }
        body.check_consistency()?;
        let op = Some(Box::new(Scan {
            body,
            input_mapping,
            output_mapping,
            decluttered: false,
            ..self.clone()
        }) as _);
        Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
    }
661}
662
663impl Op for Scan {
664 fn name(&self) -> Cow<str> {
665 "Scan".into()
666 }
667
668 fn info(&self) -> TractResult<Vec<String>> {
669 let mut lines = vec![];
670 for (ix, im) in self.input_mapping.iter().enumerate() {
671 lines.push(format!("Model input #{ix}: {im:?}"));
672 }
673 for (ix, om) in self.output_mapping.iter().enumerate() {
674 lines.push(format!("Model output #{ix}: {om:?}"));
675 }
676 lines.push(format!("skip:{} reset_every_turn:{:?}", self.skip, self.reset_every_turn));
677 Ok(lines)
678 }
679
680 fn validation(&self) -> Validation {
681 Validation::Rounding
682 }
683
684 op_as_typed_op!();
685}
686
687impl EvalOp for Scan {
688 fn is_stateless(&self) -> bool {
689 false
690 }
691 fn state(
692 &self,
693 session: &mut SessionState,
694 node_id: usize,
695 ) -> TractResult<Option<Box<dyn OpState>>> {
696 self.to_codegen_op(false)?.state(session, node_id)
697 }
698}
699
700impl TypedOp for Scan {
701 as_op!();
702
703 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
704 anyhow::ensure!(inputs.len() == self.body.inputs.len());
705 anyhow::ensure!(self.input_mapping.len() == self.body.inputs.len());
706 anyhow::ensure!(
707 self.input_mapping.iter().filter(|m| m.is_state()).count()
708 == self.output_mapping.iter().filter(|m| m.state).count()
709 );
710 for (i, o) in
711 self.input_mapping.iter().enumerate().filter(|(_, m)| m.is_state()).map(|(i, _)| i).zip(
712 self.output_mapping.iter().enumerate().filter(|(_, m)| m.state).map(|(o, _)| o),
713 )
714 {
715 let ifact = self.body.outlet_fact(self.body.inputs[i])?;
716 let ofact = self.body.outlet_fact(self.body.outputs[o])?;
717 anyhow::ensure!(ifact == ofact,
718 "inconsistent state shape: body input {i} is {ifact:?} and body output {o} is {ofact:?}",
719 )
720 }
721 let mut outputs = tvec!();
722 let iters = super::iteration_count(&self.input_mapping, inputs).context("No scan input")?;
723 for (ix, output) in self.output_mapping.iter().enumerate() {
724 let fact = self.body.output_fact(ix)?;
725 if let Some((slot, info)) = output.scan {
726 let mut shape = fact.shape.clone();
727 let scanning_dim =
728 output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
729 shape.set(info.axis, scanning_dim);
730 outputs.push((slot, fact.datum_type.fact(shape)));
731 }
732 if let Some(slot) = output.last_value_slot {
733 outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
734 }
735 }
736 outputs.sort_by_key(|a| a.0);
737 anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
738 let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
739 Ok(outputs)
740 }
741
742 fn axes_mapping(
743 &self,
744 inputs: &[&TypedFact],
745 outputs: &[&TypedFact],
746 ) -> TractResult<AxesMapping> {
747 let mut mappings = vec![];
748 let body_invs = self.body.axes_mapping().with_context(|| "Computing body axes mapping")?;
749 for body_axis in body_invs.iter_all_axes() {
750 let mut info = Axis::new(body_axis.repr, inputs.len(), outputs.len());
751 info.inputs.clone_from(&body_axis.inputs);
752 for (ix, output_mapping) in self.output_mapping.iter().enumerate() {
753 let mut slots = vec![];
754 if let Some((slot, _scan)) = output_mapping.scan {
755 slots.push(slot);
756 }
757 if let Some(slot) = output_mapping.last_value_slot {
758 slots.push(slot);
759 }
760 for slot in slots {
761 info.outputs[slot].clone_from(&body_axis.outputs[ix]);
762 }
763 }
764 if info.inputs.iter().any(|i| i.len() > 0) || info.outputs.iter().any(|i| i.len() > 0) {
765 mappings.push(info);
766 }
767 }
768 AxesMapping::new(inputs.len(), outputs.len(), mappings)
769 }
770
771 fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
772 let mut suggestions = tvec!();
773 for (slot, input) in self.input_mapping.iter().enumerate() {
774 if let InputMapping::Scan(info) = input {
775 if info.axis != 0 {
776 suggestions.push((InOut::In(slot), AxisOp::Move(info.axis, 0)))
777 }
778 }
779 }
780 for output in &self.output_mapping {
781 if let Some((slot, scan)) = output.scan {
782 if scan.axis != 0 {
783 suggestions.push((InOut::Out(slot), AxisOp::Move(scan.axis, 0)))
784 }
785 }
786 }
787 Ok(suggestions)
788 }
789
790 fn change_axes(
791 &self,
792 model: &TypedModel,
793 node: &TypedNode,
794 io: InOut,
795 change: &AxisOp,
796 ) -> TractResult<Option<AxisChangeConsequence>> {
797 trace!("Propagating through {}: {:?} {:?}", node, io, change);
798 let body_leading_outlet = match io {
799 InOut::In(ix) => self.body.input_outlets()?[ix],
800 InOut::Out(slot) => {
801 let output = self
802 .output_mapping
803 .iter()
804 .position(|im| {
805 im.scan.map(|(slot, _i)| slot) == Some(slot)
806 || im.last_value_slot == Some(slot)
807 })
808 .unwrap();
809 self.body.output_outlets()?[output]
810 }
811 };
812 let axis_change = AxisChange { outlet: body_leading_outlet, op: change.clone() };
813 let node_input_facts = model.node_input_facts(node.id)?;
814 let result = self
815 .try_body_axes_change(axis_change, false, &node_input_facts)
816 .with_context(|| "Attemping to run change through scan body".to_string())?;
817 if result.is_some() {
818 trace!("{} accepted axis change", node);
819 } else {
820 trace!("{} rejected axis change", node);
821 }
822 Ok(result)
823 }
824
825 fn declutter_with_session(
826 &self,
827 session: &mut OptimizerSession,
828 model: &TypedModel,
829 node: &TypedNode,
830 ) -> TractResult<Option<TypedModelPatch>> {
831 macro_rules! pass {
832 ($func:ident) => {
833 if let Some(mut r) = self
834 .$func(session, model, node)
835 .with_context(|| format!("{}", stringify!($func)))?
836 {
837 trace!(stringify!($func));
838 r.push_context(stringify!($func));
839 return Ok(Some(r));
840 }
841 };
842 }
843 pass!(declutter_const_input);
844 pass!(declutter_discard_unused_input);
845 pass!(declutter_discard_useless_outer_output);
846 pass!(declutter_discard_empty_output_mapping_with_body_output);
847 pass!(declutter_body);
848 pass!(declutter_body_axes);
849 pass!(declutter_pull_constant_outputs);
850 pass!(declutter_pull_batcheable_input);
851 pass!(declutter_pull_batcheable_output);
852 Ok(None)
853 }
854
855 fn concretize_dims(
856 &self,
857 _source: &TypedModel,
858 node: &TypedNode,
859 target: &mut TypedModel,
860 mapping: &HashMap<OutletId, OutletId>,
861 values: &SymbolValues,
862 ) -> TractResult<TVec<OutletId>> {
863 let inputs = node.inputs.iter().map(|o| mapping[o]).collect::<TVec<_>>();
864 let op = Self {
865 output_mapping: self
866 .output_mapping
867 .iter()
868 .map(|om| om.concretize_dims(values))
869 .collect::<TractResult<Vec<_>>>()?,
870 body: self.body.concretize_dims(values)?,
871 ..self.clone()
872 };
873 target.wire_node(&node.name, op, &inputs)
874 }
875
876 fn codegen(
877 &self,
878 model: &TypedModel,
879 node: &TypedNode,
880 ) -> TractResult<Option<TypedModelPatch>> {
881 Ok(Some(TypedModelPatch::replace_single_op(
882 model,
883 node,
884 &node.inputs,
885 self.to_codegen_op(true)?,
886 )?))
887 }
888}