Skip to main content

tract_core/ops/scan/
decluttered.rs

1use std::collections::HashSet;
2
3use crate::ops::einsum::EinSum;
4use crate::ops::konst::Const;
5use crate::optim::OptimizerSession;
6
7use super::optimized::{OptScan, ScanOpParams};
8use tract_data::internal::*;
9use tract_data::itertools::izip;
10
11use super::*;
12
13#[derive(Debug, Clone, Default)]
14pub struct Scan {
15    pub skip: usize,
16    pub reset_every_turn: bool,
17    pub body: TypedModel,
18    pub decluttered: bool,
19    pub input_mapping: Vec<InputMapping>,
20    pub output_mapping: Vec<OutputMapping<TDim>>,
21}
22
23impl PartialEq for Scan {
24    fn eq(&self, _other: &Self) -> bool {
25        false
26    }
27}
28impl Eq for Scan {}
29
30impl Scan {
31    pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<OptScan> {
32        let mut model = self.body.clone();
33        if optimize_inner {
34            model = model.into_optimized()?;
35        }
36        let plan = SimplePlan::new(model)?;
37
38        Ok(OptScan::new(Arc::new(ScanOpParams::new(
39            self.skip,
40            self.reset_every_turn,
41            plan,
42            self.input_mapping.clone(),
43            self.output_mapping.clone(),
44        ))))
45    }
46
47    pub fn new(
48        body: TypedModel,
49        input_mapping: Vec<InputMapping>,
50        output_mapping: Vec<OutputMapping<TDim>>,
51        skip: usize,
52    ) -> TractResult<Scan> {
53        body.check_consistency()?;
54        ensure!(input_mapping.len() == body.input_outlets()?.len());
55        ensure!(output_mapping.len() == body.output_outlets()?.len());
56        Ok(Scan {
57            skip,
58            reset_every_turn: false,
59            body,
60            decluttered: false,
61            input_mapping,
62            output_mapping,
63        })
64    }
65
66    pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
67        self.to_codegen_op(false).unwrap().iteration_count(inputs)
68    }
69
70    fn declutter_body(
71        &self,
72        session: &mut OptimizerSession,
73        model: &TypedModel,
74        node: &TypedNode,
75    ) -> TractResult<Option<TypedModelPatch>> {
76        rule_if!(!self.decluttered);
77        let mut new = self.clone();
78        let mut body = self.body.clone();
79        session.optimize(&mut body)?;
80        new.body = body;
81        new.decluttered = true;
82        Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
83    }
84
85    fn declutter_single_loop(
86        &self,
87        _session: &mut OptimizerSession,
88        model: &TypedModel,
89        node: &TypedNode,
90    ) -> TractResult<Option<TypedModelPatch>> {
91        let inputs = model.node_input_facts(node.id)?;
92        let iters =
93            super::iteration_count(&self.input_mapping, &inputs).context("No scan input")?;
94        if !iters.is_one() {
95            return Ok(None);
96        }
97        let mut patch = TypedModelPatch::new("Inline single loop scan");
98        patch.model = self.body.clone();
99        for (outer_wire, inner_wire) in izip!(&node.inputs, &self.body.inputs) {
100            patch.taps.insert(*inner_wire, *outer_wire);
101        }
102        for (inner_wire, mapping) in izip!(&self.body.outputs, &self.output_mapping) {
103            if let Some((slot, _)) = mapping.scan {
104                patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
105            }
106            if let Some(slot) = mapping.last_value_slot {
107                patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
108            }
109        }
110        Ok(Some(patch))
111    }
112
113    fn declutter_body_axes(
114        &self,
115        _session: &mut OptimizerSession,
116        model: &TypedModel,
117        node: &TypedNode,
118    ) -> TractResult<Option<TypedModelPatch>> {
119        let mut suggestions = vec![];
120        for n in self.body.eval_order()? {
121            let node = self.body.node(n);
122            for suggestion in node.op.suggested_axis_changes()? {
123                let outlet = suggestion.0.as_outlet(node);
124                suggestions.push(AxisChange { outlet, op: suggestion.1 })
125            }
126            for (slot, fact) in node.outputs.iter().enumerate() {
127                for (ix, dim) in fact.fact.shape.iter().enumerate() {
128                    if dim.is_one() {
129                        suggestions.push(AxisChange {
130                            outlet: OutletId::new(n, slot),
131                            op: AxisOp::Rm(ix),
132                        });
133                    }
134                }
135            }
136        }
137        let node_input_facts = model.node_input_facts(node.id)?;
138        for suggestion in suggestions.into_iter() {
139            if let Some(conseq) = self.try_body_axes_change(suggestion, true, &node_input_facts)? {
140                let mut patch = TypedModelPatch::default();
141                let mut inputs = tvec!();
142                for outlet in &node.inputs {
143                    inputs.push(patch.tap_model(model, *outlet)?);
144                }
145                for change in conseq.wire_changes {
146                    if let InOut::In(i) = change.0 {
147                        let mut value = patch
148                            .outlet_fact(inputs[i])?
149                            .konst
150                            .clone()
151                            .context("Will only reshape constants")?
152                            .into_tensor();
153                        change.1.change_tensor(&mut value, false)?;
154                        let konst_name = patch.node(inputs[i].node).name.clone();
155                        inputs[i] = patch.add_const(konst_name, value)?;
156                    }
157                }
158                let wires = patch.wire_node(
159                    &node.name,
160                    conseq.substitute_op.unwrap_or_else(|| Box::new(self.clone())),
161                    &inputs,
162                )?;
163                for (ix, new) in wires.into_iter().enumerate() {
164                    patch.shunt_outside(model, OutletId::new(node.id, ix), new)?;
165                }
166                return Ok(Some(patch));
167            }
168        }
169        Ok(None)
170    }
171
172    fn remove_outer_output_from_mappings(
173        mappings: &[OutputMapping<TDim>],
174        discarded: usize,
175    ) -> Vec<OutputMapping<TDim>> {
176        mappings
177            .iter()
178            .map(|m| OutputMapping {
179                scan: m.scan.map(|(slot, info)| (slot - (slot > discarded) as usize, info)),
180                last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
181                full_dim_hint: m.full_dim_hint.clone(),
182                state: m.state,
183            })
184            .collect()
185    }
186
187    fn declutter_const_input(
188        &self,
189        _session: &mut OptimizerSession,
190        model: &TypedModel,
191        node: &TypedNode,
192    ) -> TractResult<Option<TypedModelPatch>> {
193        let inputs = model.node_input_facts(node.id)?;
194        for (slot, mapping) in self.input_mapping.iter().enumerate() {
195            if let InputMapping::Full = mapping
196                && let Some(konst) = inputs[slot].konst.as_ref()
197            {
198                let mut op = self.clone();
199                let src = op.body.inputs[slot];
200                op.body.inputs.remove(slot);
201                op.body.nodes[src.node].inputs.clear();
202                op.body.nodes[src.node].op = Box::new(Const::new(konst.clone())?);
203                op.input_mapping.remove(slot);
204                let mut inputs = node.inputs.clone();
205                inputs.remove(slot);
206                return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
207            }
208        }
209        Ok(None)
210    }
211
212    fn declutter_discard_unused_input(
213        &self,
214        _session: &mut OptimizerSession,
215        model: &TypedModel,
216        node: &TypedNode,
217    ) -> TractResult<Option<TypedModelPatch>> {
218        for (slot, input) in self.body.input_outlets()?.iter().enumerate() {
219            let source_node = self.body.node(input.node);
220            if source_node.outputs[0].successors.len() == 0
221                && !self.body.output_outlets()?.contains(input)
222            {
223                let mut new_inputs = node.inputs.clone();
224                new_inputs.remove(slot);
225                let mut new_mappings: Vec<_> = self.input_mapping.clone();
226                new_mappings.remove(slot);
227                let mut model_inputs = self.body.input_outlets()?.to_vec();
228                model_inputs.remove(slot);
229                let mut body = self.body.clone();
230                let mut patch = TypedModelPatch::default();
231                patch.obliterate(source_node.id)?;
232                patch.apply(&mut body)?;
233                body.set_input_outlets(&model_inputs)?;
234                body.declutter()?;
235                let op =
236                    Self { body, input_mapping: new_mappings, decluttered: true, ..self.clone() };
237                return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
238            }
239        }
240        Ok(None)
241    }
242
243    fn declutter_discard_useless_outer_output(
244        &self,
245        _session: &mut OptimizerSession,
246        model: &TypedModel,
247        node: &TypedNode,
248    ) -> TractResult<Option<TypedModelPatch>> {
249        for (ix, o) in node.outputs.iter().enumerate() {
250            if o.successors.len() == 0
251                && !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
252            {
253                let mappings = self
254                    .output_mapping
255                    .iter()
256                    .map(|m| OutputMapping {
257                        scan: m.scan.filter(|(slot, _info)| *slot != ix),
258                        last_value_slot: m.last_value_slot.filter(|s| *s != ix),
259                        full_dim_hint: m.full_dim_hint.clone(),
260                        state: m.state,
261                    })
262                    .collect::<Vec<_>>();
263                let mut op = self.clone();
264                op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
265                let mut patch = TypedModelPatch::default();
266                let inputs = node
267                    .inputs
268                    .iter()
269                    .map(|&i| patch.tap_model(model, i))
270                    .collect::<TractResult<Vec<_>>>()?;
271                let wires = patch.wire_node(&*node.name, op, &inputs)?;
272                for oix in 0..node.outputs.len() {
273                    if oix != ix {
274                        patch.shunt_outside(
275                            model,
276                            OutletId::new(node.id, oix),
277                            wires[oix - (oix > ix) as usize],
278                        )?;
279                    }
280                }
281                return Ok(Some(patch));
282            }
283        }
284        Ok(None)
285    }
286
287    fn declutter_discard_empty_output_mapping_with_body_output(
288        &self,
289        _session: &mut OptimizerSession,
290        model: &TypedModel,
291        node: &TypedNode,
292    ) -> TractResult<Option<TypedModelPatch>> {
293        for (ix, om) in self.output_mapping.iter().enumerate() {
294            if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
295                let mut new_op = self.clone();
296                new_op.output_mapping.remove(ix);
297                new_op.body.outputs.remove(ix);
298                new_op.decluttered = false;
299                return Ok(Some(TypedModelPatch::replace_single_op(
300                    model,
301                    node,
302                    &node.inputs,
303                    new_op,
304                )?));
305            }
306        }
307        Ok(None)
308    }
309
310    fn declutter_pull_batcheable_input(
311        &self,
312        _session: &mut OptimizerSession,
313        model: &TypedModel,
314        node: &TypedNode,
315    ) -> TractResult<Option<TypedModelPatch>> {
316        'candidate: for (slot, input) in self.input_mapping.iter().enumerate() {
317            if let Some(scan_info) = input.as_scan() {
318                let scan_source = self.body.input_outlets()?[slot];
319                let scan_source_node = self.body.node(scan_source.node);
320                for mut succ in &scan_source_node.outputs[0].successors {
321                    for &succ_input in &self.body.node(succ.node).inputs {
322                        if succ_input != scan_source
323                            && self.body.outlet_fact(succ_input)?.konst.is_none()
324                        {
325                            continue 'candidate;
326                        }
327                    }
328                    if self.body.node(succ.node).outputs.len() != 1 {
329                        continue;
330                    }
331                    let mut new_body = self.body.clone();
332                    // insert propagate axis on einsum
333                    if let Some(einsum) = new_body.node(succ.node).op_as::<EinSum>()
334                        && let Some(patch) = einsum
335                            .propagate_axis(
336                                &new_body,
337                                new_body.node(succ.node),
338                                InOut::In(succ.slot),
339                                scan_info.axis,
340                            )
341                            .context("building axis propagating patch")?
342                    {
343                        patch.apply(&mut new_body)?;
344                        new_body.compute_const_facts()?;
345                        // propagate axis injects new nodes at the end. last successor of input
346                        // in new net will be the new succ
347                        let new_body_scan_input = new_body.input_outlets()?[slot];
348                        succ = new_body.node(new_body_scan_input.node).outputs[0]
349                            .successors
350                            .last()
351                            .unwrap();
352                    }
353
354                    let axes_mapping = {
355                        let (input_facts, output_facts) =
356                            new_body.node_facts(new_body.node(succ.node).id)?;
357                        new_body.node(succ.node).op.axes_mapping(&input_facts, &output_facts)?
358                    };
359                    let axis_info = axes_mapping.axis((InOut::In(succ.slot), scan_info.axis))?;
360                    if let &[axis_after] = &*axis_info.outputs[0] {
361                        let mut outside_patch = TypedModelPatch::new(format!(
362                            "Outer patch for input extraction of {}",
363                            new_body.node(succ.node)
364                        ));
365                        let mut patch_inputs = node
366                            .inputs
367                            .iter()
368                            .map(|&i| outside_patch.tap_model(model, i))
369                            .collect::<TractResult<TVec<_>>>()?;
370                        let mut extracted_op_inputs = tvec!();
371                        for (ix, outlet) in new_body.node(succ.node).inputs.iter().enumerate() {
372                            let wire = if ix == succ.slot {
373                                patch_inputs[slot]
374                            } else if let Some(konst) =
375                                new_body.outlet_fact(*outlet)?.konst.as_ref()
376                            {
377                                outside_patch.add_const(
378                                    format!(
379                                        "{}.extracted.{}",
380                                        node.name,
381                                        new_body.node(outlet.node).name
382                                    ),
383                                    konst.clone(),
384                                )?
385                            } else {
386                                unreachable!();
387                            };
388                            extracted_op_inputs.push(wire);
389                        }
390                        let new_input_wire = outside_patch.wire_node(
391                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
392                            new_body.node(succ.node).op.clone(),
393                            &extracted_op_inputs,
394                        )?[0];
395                        patch_inputs.push(new_input_wire);
396                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
397                        let mut new_input_inner_fact = new_input_outer_fact.clone();
398                        new_input_inner_fact.shape.set(axis_after, scan_info.chunk.abs().to_dim());
399
400                        let mut new_body = new_body.clone();
401                        let new_source_wire = new_body.add_source(
402                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
403                            new_input_inner_fact,
404                        )?;
405                        let mut inner_patch = TypedModelPatch::new(format!(
406                            "Inner body patch for extraction of {}",
407                            new_body.node(succ.node)
408                        ));
409                        let new_source_wire_in_patch =
410                            inner_patch.tap_model(&new_body, new_source_wire)?;
411                        inner_patch
412                            .shunt_outside(
413                                &new_body,
414                                OutletId::new(succ.node, 0),
415                                new_source_wire_in_patch,
416                            )
417                            .with_context(|| "patching inner model")?;
418                        inner_patch.apply(&mut new_body)?;
419
420                        let mut input_mapping = self.input_mapping.clone();
421                        input_mapping.push(InputMapping::Scan(ScanInfo {
422                            axis: axis_after,
423                            chunk: scan_info.chunk,
424                        }));
425
426                        let new_op = Self {
427                            input_mapping,
428                            decluttered: false,
429                            body: new_body,
430                            ..self.clone()
431                        };
432                        let output_wires =
433                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
434                        for w in output_wires {
435                            outside_patch
436                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
437                                .with_context(|| "patching outer model")?;
438                        }
439                        return Ok(Some(outside_patch));
440                    }
441                }
442            }
443        }
444        Ok(None)
445    }
446
447    fn declutter_pull_constant_outputs(
448        &self,
449        _session: &mut OptimizerSession,
450        model: &TypedModel,
451        node: &TypedNode,
452    ) -> TractResult<Option<TypedModelPatch>> {
453        for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
454            if let Some(slot) = mapping.last_value_slot
455                && let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone()
456            {
457                let inner_node = self.body.output_outlets()?[model_output_ix].node;
458                let inner_node = self.body.node(inner_node);
459                let mut patch = TypedModelPatch::new(format!("Extract const node {inner_node}"));
460                let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
461                patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
462                return Ok(Some(patch));
463            }
464        }
465        Ok(None)
466    }
467
468    fn declutter_pull_batcheable_output(
469        &self,
470        _session: &mut OptimizerSession,
471        model: &TypedModel,
472        node: &TypedNode,
473    ) -> TractResult<Option<TypedModelPatch>> {
474        for (mapping_ix, mapping) in self.output_mapping.iter().enumerate() {
475            if let Some((_, scan_info)) = mapping.scan {
476                let emitter_outlet = self.body.output_outlets()?[mapping_ix];
477                if self.body.node(emitter_outlet.node).outputs[emitter_outlet.slot].successors.len()
478                    > 0
479                    || self.body.inputs.contains(&emitter_outlet)
480                    || mapping.state
481                    || mapping.scan.map(|(_slot, i)| i.chunk > 1).unwrap_or(true)
482                {
483                    // continue if both last_value and full values are exported
484                    continue;
485                }
486                let mut new_body = self.body.clone();
487                if let Some(einsum) = new_body.node(emitter_outlet.node).op_as::<EinSum>()
488                    && let Some(patch) = einsum
489                        .propagate_axis(
490                            &new_body,
491                            new_body.node(emitter_outlet.node),
492                            InOut::Out(0),
493                            scan_info.axis,
494                        )
495                        .context("building axis propagating patch")?
496                {
497                    patch.apply(&mut new_body)?;
498                    new_body.prop_consts()?;
499                }
500                let emitter_outlet = new_body.output_outlets()?[mapping_ix];
501                let invariants = {
502                    let (input_facts, output_facts) = new_body.node_facts(emitter_outlet.node)?;
503                    new_body
504                        .node(emitter_outlet.node)
505                        .op
506                        .axes_mapping(&input_facts, &output_facts)?
507                };
508                let axis_tracking =
509                    invariants.axis((InOut::Out(emitter_outlet.slot), scan_info.axis))?;
510                rule_if!(axis_tracking.outputs.iter().all(|o| o.len() == 1));
511                let mut new_output_mapping = self.output_mapping.clone();
512                let mut new_scan_outputs = node.outputs.len();
513                let mut outer_slots = vec![];
514
515                // rewire input of the extracted node through the scan outlet boundary
516                for (input_slot, input) in
517                    new_body.node(emitter_outlet.node).inputs.clone().iter().enumerate()
518                {
519                    if new_body.outputs.iter().all(|o| o != input) {
520                        new_output_mapping.push(OutputMapping::default());
521                        new_body.outputs.push(*input);
522                    }
523                    let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
524                    let mapping = &mut new_output_mapping[body_output_id];
525                    let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
526                        if mapping.last_value_slot.is_none() {
527                            mapping.last_value_slot = Some(new_scan_outputs);
528                            new_scan_outputs += 1;
529                        }
530                        mapping.last_value_slot.unwrap()
531                    } else if let &[axis] = &*axis_tracking.inputs[input_slot] {
532                        if mapping.scan.is_none() {
533                            mapping.scan =
534                                Some((new_scan_outputs, ScanInfo { axis, chunk: scan_info.chunk }));
535                            new_scan_outputs += 1;
536                        }
537                        mapping.scan.unwrap().0
538                    } else {
539                        return Ok(None);
540                    };
541                    outer_slots.push(outer_slot);
542                }
543                let mut outside_patch = TypedModelPatch::new(format!(
544                    "Outside patch for output extraction of {}",
545                    new_body.node(emitter_outlet.node)
546                ));
547                let inputs = node
548                    .inputs
549                    .iter()
550                    .map(|&i| outside_patch.tap_model(model, i))
551                    .collect::<TractResult<TVec<_>>>()?;
552                let new_op = Self {
553                    output_mapping: new_output_mapping,
554                    decluttered: false,
555                    body: new_body.clone(), // FIXME maybe remove clone
556                    ..self.clone()
557                };
558                let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
559                let output = mapping.scan.unwrap();
560                let inputs =
561                    outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
562                let wire = outside_patch.wire_node(
563                    &new_body.node(emitter_outlet.node).name,
564                    new_body.node(emitter_outlet.node).op.clone(),
565                    &inputs,
566                )?[0];
567                outside_patch.shunt_outside(model, OutletId::new(node.id, output.0), wire)?;
568                for output_slot in 0..node.outputs.len() {
569                    if output_slot != output.0 {
570                        outside_patch.shunt_outside(
571                            model,
572                            OutletId::new(node.id, output_slot),
573                            OutletId::new(scan_outputs[0].node, output_slot),
574                        )?;
575                    }
576                }
577                return Ok(Some(outside_patch));
578            }
579        }
580        Ok(None)
581    }
582
583    fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
584        let input_state_outlets = self
585            .input_mapping
586            .iter()
587            .zip(self.body.input_outlets()?.iter())
588            .filter(|(m, _)| m.is_state())
589            .map(|(_, o)| o);
590        let output_state_outlets = self
591            .output_mapping
592            .iter()
593            .zip(self.body.output_outlets()?.iter())
594            .filter(|(m, _)| m.state)
595            .map(|(_, o)| o);
596        Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
597    }
598
599    fn body_locked_outlets(&self, node_input_facts: &[&TypedFact]) -> TractResult<TVec<OutletId>> {
600        let input_outlets =
601            self.body.input_outlets()?.iter().enumerate().filter_map(|(slot, o)| {
602                if node_input_facts[slot].konst.is_none() { Some(o) } else { None }
603            });
604        let output_outlets = self
605            .output_mapping
606            .iter()
607            .zip(self.body.output_outlets()?.iter())
608            .filter(|(m, _)| !m.invisible())
609            .map(|(_, o)| o);
610        Ok(input_outlets.chain(output_outlets).cloned().collect())
611    }
612
613    fn try_body_axes_change(
614        &self,
615        change: AxisChange,
616        locked_interface: bool,
617        node_input_facts: &[&TypedFact],
618    ) -> TractResult<Option<AxisChangeConsequence>> {
619        self.body.check_consistency()?;
620        let locked_outlets = self.body_locked_outlets(node_input_facts)?;
621        let mut explored: HashSet<AxisChange> = Default::default();
622        rule_if_some!(
623            (body_patch, body_changed_wires) = crate::optim::change_axes::change_axes(
624                &self.body,
625                &change,
626                if locked_interface { &locked_outlets } else { &[] },
627                &self.body_bounds()?,
628                &mut explored,
629            )?
630        );
631        let mut body = self.body.clone();
632        body_patch.apply(&mut body)?;
633        body.compact()?;
634        let mut wire_changes = tvec!();
635        let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
636        for (slot, m) in input_mapping.iter_mut().enumerate() {
637            if let Some(change) = body_changed_wires
638                .iter()
639                .find(|(iface, _change)| iface == &InOut::In(slot))
640                .map(|pair| pair.1.clone())
641            {
642                wire_changes.push((InOut::In(slot), change.clone()));
643                if let InputMapping::Scan(info) = m {
644                    rule_if_some!(axis = change.transform_axis(info.axis));
645                    info.axis = axis;
646                };
647            }
648        }
649        let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
650        for (ix, m) in output_mapping.iter_mut().enumerate() {
651            if let Some(change) = body_changed_wires
652                .iter()
653                .find(|(iface, _change)| iface == &InOut::Out(ix))
654                .map(|pair| pair.1.clone())
655            {
656                if let Some((slot, info)) = m.scan.as_mut() {
657                    rule_if_some!(new_axis = change.transform_axis(info.axis));
658                    info.axis = new_axis;
659                    wire_changes.push((InOut::Out(*slot), change.clone()));
660                }
661                if let Some(slot) = m.last_value_slot {
662                    wire_changes.push((InOut::Out(slot), change.clone()));
663                }
664            };
665        }
666        body.check_consistency()?;
667        let op = Some(Box::new(Scan {
668            body,
669            input_mapping,
670            output_mapping,
671            decluttered: false,
672            ..self.clone()
673        }) as _);
674        Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
675    }
676}
677
678impl Op for Scan {
679    fn name(&self) -> StaticName {
680        "Scan".into()
681    }
682
683    fn info(&self) -> TractResult<Vec<String>> {
684        let mut lines = vec![];
685        for (ix, im) in self.input_mapping.iter().enumerate() {
686            lines.push(format!("Model input  #{ix}: {im:?}"));
687        }
688        for (ix, om) in self.output_mapping.iter().enumerate() {
689            lines.push(format!("Model output #{ix}: {om:?}"));
690        }
691        lines.push(format!("skip:{} reset_every_turn:{:?}", self.skip, self.reset_every_turn));
692        Ok(lines)
693    }
694
695    fn validation(&self) -> Validation {
696        Validation::Rounding
697    }
698
699    op_as_typed_op!();
700}
701
702impl EvalOp for Scan {
703    fn is_stateless(&self) -> bool {
704        false
705    }
706    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
707        self.to_codegen_op(false)?.state(session, node_id)
708    }
709}
710
711impl TypedOp for Scan {
712    as_op!();
713
714    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
715        anyhow::ensure!(inputs.len() == self.body.inputs.len());
716        anyhow::ensure!(self.input_mapping.len() == self.body.inputs.len());
717        anyhow::ensure!(
718            self.input_mapping.iter().filter(|m| m.is_state()).count()
719                == self.output_mapping.iter().filter(|m| m.state).count()
720        );
721        for (i, o) in
722            self.input_mapping.iter().enumerate().filter(|(_, m)| m.is_state()).map(|(i, _)| i).zip(
723                self.output_mapping.iter().enumerate().filter(|(_, m)| m.state).map(|(o, _)| o),
724            )
725        {
726            let ifact = self.body.outlet_fact(self.body.inputs[i])?;
727            let ofact = self.body.outlet_fact(self.body.outputs[o])?;
728            anyhow::ensure!(
729                ifact == ofact,
730                "inconsistent fact: body input {i} is {ifact:?} and body output {o} is {ofact:?}\n{}",
731                self.body
732            )
733        }
734        let mut outputs = tvec!();
735        let iters = super::iteration_count(&self.input_mapping, inputs).context("No scan input")?;
736        for (ix, output) in self.output_mapping.iter().enumerate() {
737            let fact = self.body.output_fact(ix)?;
738            if let Some((slot, info)) = output.scan {
739                let mut shape = fact.shape.clone();
740                let scanning_dim =
741                    output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
742                shape.set(info.axis, scanning_dim);
743                outputs.push((slot, fact.datum_type.fact(shape)));
744            }
745            if let Some(slot) = output.last_value_slot {
746                outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
747            }
748        }
749        outputs.sort_by_key(|a| a.0);
750        anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
751        let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
752        Ok(outputs)
753    }
754
755    fn axes_mapping(
756        &self,
757        inputs: &[&TypedFact],
758        outputs: &[&TypedFact],
759    ) -> TractResult<AxesMapping> {
760        let mut mappings = vec![];
761        let body_invs = self.body.axes_mapping().with_context(|| "Computing body axes mapping")?;
762        for body_axis in body_invs.iter_all_axes() {
763            let mut info = Axis::new(body_axis.repr, inputs.len(), outputs.len());
764            info.inputs.clone_from(&body_axis.inputs);
765            for (ix, output_mapping) in self.output_mapping.iter().enumerate() {
766                let mut slots = vec![];
767                if let Some((slot, _scan)) = output_mapping.scan {
768                    slots.push(slot);
769                }
770                if let Some(slot) = output_mapping.last_value_slot {
771                    slots.push(slot);
772                }
773                for slot in slots {
774                    info.outputs[slot].clone_from(&body_axis.outputs[ix]);
775                }
776            }
777            if info.inputs.iter().any(|i| i.len() > 0) || info.outputs.iter().any(|i| i.len() > 0) {
778                mappings.push(info);
779            }
780        }
781        AxesMapping::new(inputs.len(), outputs.len(), mappings)
782    }
783
784    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
785        let mut suggestions = tvec!();
786        for (slot, input) in self.input_mapping.iter().enumerate() {
787            if let InputMapping::Scan(info) = input
788                && info.axis != 0
789            {
790                suggestions.push((InOut::In(slot), AxisOp::Move(info.axis, 0)))
791            }
792        }
793        for output in &self.output_mapping {
794            if let Some((slot, scan)) = output.scan
795                && scan.axis != 0
796            {
797                suggestions.push((InOut::Out(slot), AxisOp::Move(scan.axis, 0)))
798            }
799        }
800        Ok(suggestions)
801    }
802
803    fn change_axes(
804        &self,
805        model: &TypedModel,
806        node: &TypedNode,
807        io: InOut,
808        change: &AxisOp,
809    ) -> TractResult<Option<AxisChangeConsequence>> {
810        trace!("Propagating through {node}: {io:?} {change:?}");
811        let body_leading_outlet = match io {
812            InOut::In(ix) => self.body.input_outlets()?[ix],
813            InOut::Out(slot) => {
814                let output = self
815                    .output_mapping
816                    .iter()
817                    .position(|im| {
818                        im.scan.map(|(slot, _i)| slot) == Some(slot)
819                            || im.last_value_slot == Some(slot)
820                    })
821                    .unwrap();
822                self.body.output_outlets()?[output]
823            }
824        };
825        let axis_change = AxisChange { outlet: body_leading_outlet, op: change.clone() };
826        let node_input_facts = model.node_input_facts(node.id)?;
827        let result = self
828            .try_body_axes_change(axis_change, false, &node_input_facts)
829            .with_context(|| "Attemping to run change through scan body".to_string())?;
830        if result.is_some() {
831            trace!("{node} accepted axis change");
832        } else {
833            trace!("{node} rejected axis change");
834        }
835        Ok(result)
836    }
837
838    fn declutter_with_session(
839        &self,
840        session: &mut OptimizerSession,
841        model: &TypedModel,
842        node: &TypedNode,
843    ) -> TractResult<Option<TypedModelPatch>> {
844        macro_rules! pass {
845            ($func:ident) => {
846                if let Some(mut r) = self
847                    .$func(session, model, node)
848                    .with_context(|| format!("{}", stringify!($func)))?
849                {
850                    trace!(stringify!($func));
851                    r.push_context(stringify!($func));
852                    return Ok(Some(r));
853                }
854            };
855        }
856        pass!(declutter_single_loop);
857        pass!(declutter_const_input);
858        pass!(declutter_discard_unused_input);
859        pass!(declutter_discard_useless_outer_output);
860        pass!(declutter_discard_empty_output_mapping_with_body_output);
861        pass!(declutter_body);
862        pass!(declutter_body_axes);
863        pass!(declutter_pull_constant_outputs);
864        pass!(declutter_pull_batcheable_input);
865        pass!(declutter_pull_batcheable_output);
866        Ok(None)
867    }
868
869    fn concretize_dims(
870        &self,
871        _source: &TypedModel,
872        node: &TypedNode,
873        target: &mut TypedModel,
874        mapping: &HashMap<OutletId, OutletId>,
875        values: &SymbolValues,
876    ) -> TractResult<TVec<OutletId>> {
877        let inputs = node.inputs.iter().map(|o| mapping[o]).collect::<TVec<_>>();
878        let op = Self {
879            output_mapping: self
880                .output_mapping
881                .iter()
882                .map(|om| om.concretize_dims(values))
883                .collect::<TractResult<Vec<_>>>()?,
884            body: self.body.concretize_dims(values)?,
885            ..self.clone()
886        };
887        target.wire_node(&node.name, op, &inputs)
888    }
889
890    fn codegen(
891        &self,
892        model: &TypedModel,
893        node: &TypedNode,
894    ) -> TractResult<Option<TypedModelPatch>> {
895        Ok(Some(TypedModelPatch::replace_single_op(
896            model,
897            node,
898            &node.inputs,
899            self.to_codegen_op(true)?,
900        )?))
901    }
902}