Skip to main content

tract_core/ops/scan/
decluttered.rs

1use std::collections::HashSet;
2
3use crate::ops::einsum::EinSum;
4use crate::ops::konst::Const;
5use crate::optim::OptimizerSession;
6
7use super::optimized::{OptScan, ScanOpParams};
8use tract_data::internal::*;
9use tract_data::itertools::izip;
10
11use super::*;
12
/// Loop operator at the decluttered (pre-codegen) level.
///
/// `body` is the inner model; `input_mapping` / `output_mapping` describe how each body
/// input and output relates to the outer node's inputs and outputs. `to_codegen_op`
/// lowers it to an `OptScan` for execution.
#[derive(Debug, Clone, Default)]
pub struct Scan {
    // NOTE(review): only forwarded to `ScanOpParams` in this file; presumably a number of
    // initial iterations to skip — confirm against optimized.rs.
    pub skip: usize,
    // Forwarded to `ScanOpParams`; `Scan::new` initialises it to `false`.
    pub reset_every_turn: bool,
    /// The inner model, executed through a `SimplePlan` once lowered by `to_codegen_op`.
    pub body: TypedModel,
    /// `true` once `body` has been optimized by the body declutter rule; rules that rewrite
    /// the body may reset it so that the body gets decluttered again.
    pub decluttered: bool,
    /// One entry per body input (checked by `Scan::new`).
    pub input_mapping: Vec<InputMapping>,
    /// One entry per body output (checked by `Scan::new`).
    pub output_mapping: Vec<OutputMapping<TDim>>,
}
22
23impl Scan {
24    pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<OptScan> {
25        let mut model = self.body.clone();
26        if optimize_inner {
27            model = model.into_optimized()?;
28        }
29        let plan = SimplePlan::new(model)?;
30
31        Ok(OptScan::new(Arc::new(ScanOpParams::new(
32            self.skip,
33            self.reset_every_turn,
34            plan,
35            self.input_mapping.clone(),
36            self.output_mapping.clone(),
37        ))))
38    }
39
40    pub fn new(
41        body: TypedModel,
42        input_mapping: Vec<InputMapping>,
43        output_mapping: Vec<OutputMapping<TDim>>,
44        skip: usize,
45    ) -> TractResult<Scan> {
46        body.check_consistency()?;
47        ensure!(input_mapping.len() == body.input_outlets()?.len());
48        ensure!(output_mapping.len() == body.output_outlets()?.len());
49        Ok(Scan {
50            skip,
51            reset_every_turn: false,
52            body,
53            decluttered: false,
54            input_mapping,
55            output_mapping,
56        })
57    }
58
59    pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
60        self.to_codegen_op(false).unwrap().iteration_count(inputs)
61    }
62
63    fn declutter_body(
64        &self,
65        session: &mut OptimizerSession,
66        model: &TypedModel,
67        node: &TypedNode,
68    ) -> TractResult<Option<TypedModelPatch>> {
69        rule_if!(!self.decluttered);
70        let mut new = self.clone();
71        let mut body = self.body.clone();
72        session.optimize(&mut body)?;
73        new.body = body;
74        new.decluttered = true;
75        Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
76    }
77
78    fn declutter_single_loop(
79        &self,
80        _session: &mut OptimizerSession,
81        model: &TypedModel,
82        node: &TypedNode,
83    ) -> TractResult<Option<TypedModelPatch>> {
84        let inputs = model.node_input_facts(node.id)?;
85        let iters =
86            super::iteration_count(&self.input_mapping, &inputs).context("No scan input")?;
87        if !iters.is_one() {
88            return Ok(None);
89        }
90        let mut patch = TypedModelPatch::new("Inline single loop scan");
91        patch.model = self.body.clone();
92        for (outer_wire, inner_wire) in izip!(&node.inputs, &self.body.inputs) {
93            patch.taps.insert(*inner_wire, *outer_wire);
94        }
95        for (inner_wire, mapping) in izip!(&self.body.outputs, &self.output_mapping) {
96            if let Some((slot, _)) = mapping.scan {
97                patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
98            }
99            if let Some(slot) = mapping.last_value_slot {
100                patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
101            }
102        }
103        Ok(Some(patch))
104    }
105
    /// Try to simplify the body by changing axes inside it.
    ///
    /// Candidate changes are gathered over the body in evaluation order:
    /// - the axis changes suggested by each body op;
    /// - the removal of every size-1 axis on every body outlet.
    ///
    /// The first candidate that `try_body_axes_change` accepts (with the interface locked)
    /// is applied. Outer inputs whose wire must change are rebuilt as reshaped constants, and
    /// the node is re-wired with the substitute op (or a clone of `self`).
    ///
    /// # Errors
    /// Fails with "Will only reshape constants" if an outer input needing an axis change is
    /// not a known constant.
    fn declutter_body_axes(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut suggestions = vec![];
        for n in self.body.eval_order()? {
            // NB: shadows the outer `node` (the scan node) for the rest of this loop body.
            let node = self.body.node(n);
            for suggestion in node.op.suggested_axis_changes()? {
                let outlet = suggestion.0.as_outlet(node);
                suggestions.push(AxisChange { outlet, op: suggestion.1 })
            }
            // Every size-1 axis is a candidate for removal.
            for (slot, fact) in node.outputs.iter().enumerate() {
                for (ix, dim) in fact.fact.shape.iter().enumerate() {
                    if dim.is_one() {
                        suggestions.push(AxisChange {
                            outlet: OutletId::new(n, slot),
                            op: AxisOp::Rm(ix),
                        });
                    }
                }
            }
        }
        let node_input_facts = model.node_input_facts(node.id)?;
        for suggestion in suggestions.into_iter() {
            if let Some(conseq) = self.try_body_axes_change(suggestion, true, &node_input_facts)? {
                let mut patch = TypedModelPatch::default();
                let mut inputs = tvec!();
                for outlet in &node.inputs {
                    inputs.push(patch.tap_model(model, *outlet)?);
                }
                // Only outer inputs are rewritten here, by applying the change to their
                // constant value. NOTE(review): changes reported on outer outputs
                // (`InOut::Out`) are ignored — presumably impossible with a locked
                // interface; confirm.
                for change in conseq.wire_changes {
                    if let InOut::In(i) = change.0 {
                        let mut value = patch
                            .outlet_fact(inputs[i])?
                            .konst
                            .clone()
                            .context("Will only reshape constants")?
                            .into_tensor();
                        change.1.change_tensor(&mut value, false)?;
                        let konst_name = patch.node(inputs[i].node).name.clone();
                        inputs[i] = patch.add_const(konst_name, value)?;
                    }
                }
                let wires = patch.wire_node(
                    &node.name,
                    conseq.substitute_op.unwrap_or_else(|| Box::new(self.clone())),
                    &inputs,
                )?;
                for (ix, new) in wires.into_iter().enumerate() {
                    patch.shunt_outside(model, OutletId::new(node.id, ix), new)?;
                }
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
164
165    fn remove_outer_output_from_mappings(
166        mappings: &[OutputMapping<TDim>],
167        discarded: usize,
168    ) -> Vec<OutputMapping<TDim>> {
169        mappings
170            .iter()
171            .map(|m| OutputMapping {
172                scan: m.scan.map(|(slot, info)| (slot - (slot > discarded) as usize, info)),
173                last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
174                full_dim_hint: m.full_dim_hint.clone(),
175                state: m.state,
176            })
177            .collect()
178    }
179
180    fn declutter_const_input(
181        &self,
182        _session: &mut OptimizerSession,
183        model: &TypedModel,
184        node: &TypedNode,
185    ) -> TractResult<Option<TypedModelPatch>> {
186        let inputs = model.node_input_facts(node.id)?;
187        for (slot, mapping) in self.input_mapping.iter().enumerate() {
188            if let InputMapping::Full = mapping {
189                if let Some(konst) = inputs[slot].konst.as_ref() {
190                    let mut op = self.clone();
191                    let src = op.body.inputs[slot];
192                    op.body.inputs.remove(slot);
193                    op.body.nodes[src.node].inputs.clear();
194                    op.body.nodes[src.node].op = Box::new(Const::new(konst.clone())?);
195                    op.input_mapping.remove(slot);
196                    let mut inputs = node.inputs.clone();
197                    inputs.remove(slot);
198                    return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
199                }
200            }
201        }
202        Ok(None)
203    }
204
205    fn declutter_discard_unused_input(
206        &self,
207        _session: &mut OptimizerSession,
208        model: &TypedModel,
209        node: &TypedNode,
210    ) -> TractResult<Option<TypedModelPatch>> {
211        for (slot, input) in self.body.input_outlets()?.iter().enumerate() {
212            let source_node = self.body.node(input.node);
213            if source_node.outputs[0].successors.len() == 0
214                && !self.body.output_outlets()?.contains(input)
215            {
216                let mut new_inputs = node.inputs.clone();
217                new_inputs.remove(slot);
218                let mut new_mappings: Vec<_> = self.input_mapping.clone();
219                new_mappings.remove(slot);
220                let mut model_inputs = self.body.input_outlets()?.to_vec();
221                model_inputs.remove(slot);
222                let mut body = self.body.clone();
223                let mut patch = TypedModelPatch::default();
224                patch.obliterate(source_node.id)?;
225                patch.apply(&mut body)?;
226                body.set_input_outlets(&model_inputs)?;
227                body.declutter()?;
228                let op =
229                    Self { body, input_mapping: new_mappings, decluttered: true, ..self.clone() };
230                return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
231            }
232        }
233        Ok(None)
234    }
235
236    fn declutter_discard_useless_outer_output(
237        &self,
238        _session: &mut OptimizerSession,
239        model: &TypedModel,
240        node: &TypedNode,
241    ) -> TractResult<Option<TypedModelPatch>> {
242        for (ix, o) in node.outputs.iter().enumerate() {
243            if o.successors.len() == 0
244                && !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
245            {
246                let mappings = self
247                    .output_mapping
248                    .iter()
249                    .map(|m| OutputMapping {
250                        scan: m.scan.filter(|(slot, _info)| *slot != ix),
251                        last_value_slot: m.last_value_slot.filter(|s| *s != ix),
252                        full_dim_hint: m.full_dim_hint.clone(),
253                        state: m.state,
254                    })
255                    .collect::<Vec<_>>();
256                let mut op = self.clone();
257                op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
258                let mut patch = TypedModelPatch::default();
259                let inputs = node
260                    .inputs
261                    .iter()
262                    .map(|&i| patch.tap_model(model, i))
263                    .collect::<TractResult<Vec<_>>>()?;
264                let wires = patch.wire_node(&*node.name, op, &inputs)?;
265                for oix in 0..node.outputs.len() {
266                    if oix != ix {
267                        patch.shunt_outside(
268                            model,
269                            OutletId::new(node.id, oix),
270                            wires[oix - (oix > ix) as usize],
271                        )?;
272                    }
273                }
274                return Ok(Some(patch));
275            }
276        }
277        Ok(None)
278    }
279
280    fn declutter_discard_empty_output_mapping_with_body_output(
281        &self,
282        _session: &mut OptimizerSession,
283        model: &TypedModel,
284        node: &TypedNode,
285    ) -> TractResult<Option<TypedModelPatch>> {
286        for (ix, om) in self.output_mapping.iter().enumerate() {
287            if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
288                let mut new_op = self.clone();
289                new_op.output_mapping.remove(ix);
290                new_op.body.outputs.remove(ix);
291                new_op.decluttered = false;
292                return Ok(Some(TypedModelPatch::replace_single_op(
293                    model,
294                    node,
295                    &node.inputs,
296                    new_op,
297                )?));
298            }
299        }
300        Ok(None)
301    }
302
    /// Pull an op fed by a scanned input out of the loop, so it runs once on the whole input.
    ///
    /// For each scanned input, the consumers of its body source are visited in order. A
    /// consumer is extracted when it has a single output and the scan axis maps to exactly
    /// one axis of that output. Extraction does three things:
    /// - the op is wired once outside the scan on the whole outer input, its other
    ///   (constant) inputs being copied out as `Const` nodes;
    /// - its result becomes a new scanned input of the scan;
    /// - inside the body, the consumer's output is replaced by the matching new source.
    ///
    /// `EinSum` consumers first get the scan axis propagated with `EinSum::propagate_axis`.
    ///
    /// The returned op is flagged for re-decluttering. Returns `None` when nothing applies.
    fn declutter_pull_batcheable_input(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        'candidate: for (slot, input) in self.input_mapping.iter().enumerate() {
            if let Some(scan_info) = input.as_scan() {
                let scan_source = self.body.input_outlets()?[slot];
                let scan_source_node = self.body.node(scan_source.node);
                for mut succ in &scan_source_node.outputs[0].successors {
                    // A non-constant extra input on any visited consumer disqualifies this
                    // scanned input altogether (not just this consumer).
                    for &succ_input in &self.body.node(succ.node).inputs {
                        if succ_input != scan_source
                            && self.body.outlet_fact(succ_input)?.konst.is_none()
                        {
                            continue 'candidate;
                        }
                    }
                    if self.body.node(succ.node).outputs.len() != 1 {
                        continue;
                    }
                    let mut new_body = self.body.clone();
                    // insert propagate axis on einsum
                    if let Some(einsum) = new_body.node(succ.node).op_as::<EinSum>() {
                        if let Some(patch) = einsum
                            .propagate_axis(
                                &new_body,
                                new_body.node(succ.node),
                                InOut::In(succ.slot),
                                scan_info.axis,
                            )
                            .context("building axis propagating patch")?
                        {
                            patch.apply(&mut new_body)?;
                            new_body.compute_const_facts()?;
                            // propagate axis injects new nodes at the end. last successor of input
                            // in new net will be the new succ
                            let new_body_scan_input = new_body.input_outlets()?[slot];
                            succ = new_body.node(new_body_scan_input.node).outputs[0]
                                .successors
                                .last()
                                .unwrap();
                        }
                    }

                    // Track where the scan axis of the consumer's input lands on its output.
                    let axes_mapping = {
                        let (input_facts, output_facts) =
                            new_body.node_facts(new_body.node(succ.node).id)?;
                        new_body.node(succ.node).op.axes_mapping(&input_facts, &output_facts)?
                    };
                    let axis_info = axes_mapping.axis((InOut::In(succ.slot), scan_info.axis))?;
                    // Only extract when it maps to exactly one output axis.
                    if let &[axis_after] = &*axis_info.outputs[0] {
                        let mut outside_patch = TypedModelPatch::new(format!(
                            "Outer patch for input extraction of {}",
                            new_body.node(succ.node)
                        ));
                        let mut patch_inputs = node
                            .inputs
                            .iter()
                            .map(|&i| outside_patch.tap_model(model, i))
                            .collect::<TractResult<TVec<_>>>()?;
                        let mut extracted_op_inputs = tvec!();
                        for (ix, outlet) in new_body.node(succ.node).inputs.iter().enumerate() {
                            let wire = if ix == succ.slot {
                                // The scanned input itself: use the whole outer input.
                                patch_inputs[slot]
                            } else if let Some(konst) =
                                new_body.outlet_fact(*outlet)?.konst.as_ref()
                            {
                                // Other inputs are constants: copy them out of the body.
                                outside_patch.add_const(
                                    format!(
                                        "{}.extracted.{}",
                                        node.name,
                                        new_body.node(outlet.node).name
                                    ),
                                    konst.clone(),
                                )?
                            } else {
                                // NOTE(review): relies on the constant-input check at the top
                                // of the loop, which was made on the pre-propagation
                                // consumer; after EinSum propagation `succ` is a new node —
                                // confirm its extra inputs are always constant here.
                                unreachable!();
                            };
                            extracted_op_inputs.push(wire);
                        }
                        let new_input_wire = outside_patch.wire_node(
                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
                            new_body.node(succ.node).op.clone(),
                            &extracted_op_inputs,
                        )?[0];
                        // The scan gets one extra outer input: the extracted op's result.
                        patch_inputs.push(new_input_wire);
                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                        // Inside the body, the new input has |chunk| elements along the axis the
                        // scan axis was mapped to.
                        let mut new_input_inner_fact = new_input_outer_fact.clone();
                        new_input_inner_fact.shape.set(axis_after, scan_info.chunk.abs().to_dim());

                        // Presumably a fresh copy so it can be mutated while `succ` may still
                        // borrow the previous `new_body` (see the EinSum branch above).
                        let mut new_body = new_body.clone();
                        let new_source_wire = new_body.add_source(
                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
                            new_input_inner_fact,
                        )?;
                        // Replace the consumer's output in the body by the new source.
                        let mut inner_patch = TypedModelPatch::new(format!(
                            "Inner body patch for extraction of {}",
                            new_body.node(succ.node)
                        ));
                        let new_source_wire_in_patch =
                            inner_patch.tap_model(&new_body, new_source_wire)?;
                        inner_patch
                            .shunt_outside(
                                &new_body,
                                OutletId::new(succ.node, 0),
                                new_source_wire_in_patch,
                            )
                            .with_context(|| "patching inner model")?;
                        inner_patch.apply(&mut new_body)?;

                        // The new input is scanned with the same chunk, along the mapped axis.
                        let mut input_mapping = self.input_mapping.clone();
                        input_mapping.push(InputMapping::Scan(ScanInfo {
                            axis: axis_after,
                            chunk: scan_info.chunk,
                        }));

                        let new_op = Self {
                            input_mapping,
                            decluttered: false,
                            body: new_body,
                            ..self.clone()
                        };
                        let output_wires =
                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                        // Outer outputs keep their slots on the new scan node.
                        for w in output_wires {
                            outside_patch
                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                                .with_context(|| "patching outer model")?;
                        }
                        return Ok(Some(outside_patch));
                    }
                }
            }
        }
        Ok(None)
    }
440
441    fn declutter_pull_constant_outputs(
442        &self,
443        _session: &mut OptimizerSession,
444        model: &TypedModel,
445        node: &TypedNode,
446    ) -> TractResult<Option<TypedModelPatch>> {
447        for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
448            if let Some(slot) = mapping.last_value_slot {
449                if let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone() {
450                    let inner_node = self.body.output_outlets()?[model_output_ix].node;
451                    let inner_node = self.body.node(inner_node);
452                    let mut patch =
453                        TypedModelPatch::new(format!("Extract const node {inner_node}"));
454                    let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
455                    patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
456                    return Ok(Some(patch));
457                }
458            }
459        }
460        Ok(None)
461    }
462
    /// Pull the op producing a scanned output out of the loop.
    ///
    /// Applies to a scan output that is not state, whose chunk is not greater than 1, and
    /// whose producing body node is neither used elsewhere in the body nor a body input.
    /// The rewrite goes as follows:
    /// - the producer's inputs are exported from the body instead: constant ones as
    ///   last-value outputs, the others as scan outputs along the axis tracked through the
    ///   producer;
    /// - the producer op is wired once outside the scan on these new outer outputs;
    /// - the original outer output is rewired to that result.
    ///
    /// `EinSum` producers first get the scan axis propagated with `EinSum::propagate_axis`.
    ///
    /// The new op is flagged for re-decluttering. The now-unused original output is left in
    /// place. Returns `None` when nothing applies.
    fn declutter_pull_batcheable_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (mapping_ix, mapping) in self.output_mapping.iter().enumerate() {
            if let Some((_, scan_info)) = mapping.scan {
                let emitter_outlet = self.body.output_outlets()?[mapping_ix];
                if self.body.node(emitter_outlet.node).outputs[emitter_outlet.slot].successors.len()
                    > 0
                    || self.body.inputs.contains(&emitter_outlet)
                    || mapping.state
                    || mapping.scan.map(|(_slot, i)| i.chunk > 1).unwrap_or(true)
                {
                    // skip when the producer is still used inside the body, when the output
                    // is a raw body input, when it carries state, or when it is scanned by
                    // chunks larger than 1 (`scan` is always `Some` here, so the
                    // `unwrap_or(true)` fallback never applies)
                    continue;
                }
                let mut new_body = self.body.clone();
                if let Some(einsum) = new_body.node(emitter_outlet.node).op_as::<EinSum>() {
                    if let Some(patch) = einsum
                        .propagate_axis(
                            &new_body,
                            new_body.node(emitter_outlet.node),
                            InOut::Out(0),
                            scan_info.axis,
                        )
                        .context("building axis propagating patch")?
                    {
                        patch.apply(&mut new_body)?;
                        new_body.prop_consts()?;
                    }
                }
                // The propagation may have changed the producer: look it up again.
                let emitter_outlet = new_body.output_outlets()?[mapping_ix];
                let invariants = {
                    let (input_facts, output_facts) = new_body.node_facts(emitter_outlet.node)?;
                    new_body
                        .node(emitter_outlet.node)
                        .op
                        .axes_mapping(&input_facts, &output_facts)?
                };
                let axis_tracking =
                    invariants.axis((InOut::Out(emitter_outlet.slot), scan_info.axis))?;
                // NB: presumably bails out of the whole rule (not just this mapping) when an
                // output does not track exactly one axis.
                rule_if!(axis_tracking.outputs.iter().all(|o| o.len() == 1));
                let mut new_output_mapping = self.output_mapping.clone();
                // New outer outputs are appended after the existing ones.
                let mut new_scan_outputs = node.outputs.len();
                let mut outer_slots = vec![];

                // rewire input of the extracted node through the scan outlet boundary
                for (input_slot, input) in
                    new_body.node(emitter_outlet.node).inputs.clone().iter().enumerate()
                {
                    if new_body.outputs.iter().all(|o| o != input) {
                        new_output_mapping.push(OutputMapping::default());
                        new_body.outputs.push(*input);
                    }
                    let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
                    let mapping = &mut new_output_mapping[body_output_id];
                    // Constants are exported as last values; other inputs are scanned along the
                    // single input axis the scan axis maps to, or the rule is abandoned.
                    let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                        if mapping.last_value_slot.is_none() {
                            mapping.last_value_slot = Some(new_scan_outputs);
                            new_scan_outputs += 1;
                        }
                        mapping.last_value_slot.unwrap()
                    } else if let &[axis] = &*axis_tracking.inputs[input_slot] {
                        if mapping.scan.is_none() {
                            mapping.scan =
                                Some((new_scan_outputs, ScanInfo { axis, chunk: scan_info.chunk }));
                            new_scan_outputs += 1;
                        }
                        mapping.scan.unwrap().0
                    } else {
                        return Ok(None);
                    };
                    outer_slots.push(outer_slot);
                }
                let mut outside_patch = TypedModelPatch::new(format!(
                    "Outside patch for output extraction of {}",
                    new_body.node(emitter_outlet.node)
                ));
                let inputs = node
                    .inputs
                    .iter()
                    .map(|&i| outside_patch.tap_model(model, i))
                    .collect::<TractResult<TVec<_>>>()?;
                let new_op = Self {
                    output_mapping: new_output_mapping,
                    decluttered: false,
                    body: new_body.clone(), // FIXME maybe remove clone
                    ..self.clone()
                };
                let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
                // `mapping` is the loop's original mapping again (the `&mut` one was scoped to
                // the loop above).
                let output = mapping.scan.unwrap();
                let inputs =
                    outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
                let wire = outside_patch.wire_node(
                    &new_body.node(emitter_outlet.node).name,
                    new_body.node(emitter_outlet.node).op.clone(),
                    &inputs,
                )?[0];
                outside_patch.shunt_outside(model, OutletId::new(node.id, output.0), wire)?;
                // Other outer outputs keep their slot on the new scan node.
                for output_slot in 0..node.outputs.len() {
                    if output_slot != output.0 {
                        outside_patch.shunt_outside(
                            model,
                            OutletId::new(node.id, output_slot),
                            OutletId::new(scan_outputs[0].node, output_slot),
                        )?;
                    }
                }
                return Ok(Some(outside_patch));
            }
        }
        Ok(None)
    }
578
579    fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
580        let input_state_outlets = self
581            .input_mapping
582            .iter()
583            .zip(self.body.input_outlets()?.iter())
584            .filter(|(m, _)| m.is_state())
585            .map(|(_, o)| o);
586        let output_state_outlets = self
587            .output_mapping
588            .iter()
589            .zip(self.body.output_outlets()?.iter())
590            .filter(|(m, _)| m.state)
591            .map(|(_, o)| o);
592        Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
593    }
594
595    fn body_locked_outlets(&self, node_input_facts: &[&TypedFact]) -> TractResult<TVec<OutletId>> {
596        let input_outlets =
597            self.body.input_outlets()?.iter().enumerate().filter_map(|(slot, o)| {
598                if node_input_facts[slot].konst.is_none() { Some(o) } else { None }
599            });
600        let output_outlets = self
601            .output_mapping
602            .iter()
603            .zip(self.body.output_outlets()?.iter())
604            .filter(|(m, _)| !m.invisible())
605            .map(|(_, o)| o);
606        Ok(input_outlets.chain(output_outlets).cloned().collect())
607    }
608
    /// Work out the consequences of applying `change` inside the body.
    ///
    /// `change_axes` is run on the body, bounded by the state input/output pairs. When
    /// `locked_interface` is set, it is also forbidden from touching the locked outlets
    /// (body inputs fed by non-constant outer inputs, and visible body outputs).
    ///
    /// On success:
    /// - the patched and compacted body is wrapped in a substitute `Scan` (flagged for
    ///   re-decluttering) whose scan axes are remapped;
    /// - the axis changes needed on the outer node's inputs and outputs are reported in
    ///   `wire_changes`.
    ///
    /// Returns `None` (through `rule_if_some!`) when `change_axes` yields nothing, or when a
    /// scanned axis cannot be transformed.
    fn try_body_axes_change(
        &self,
        change: AxisChange,
        locked_interface: bool,
        node_input_facts: &[&TypedFact],
    ) -> TractResult<Option<AxisChangeConsequence>> {
        self.body.check_consistency()?;
        let locked_outlets = self.body_locked_outlets(node_input_facts)?;
        let mut explored: HashSet<AxisChange> = Default::default();
        rule_if_some!(
            (body_patch, body_changed_wires) = crate::optim::change_axes::change_axes(
                &self.body,
                &change,
                if locked_interface { &locked_outlets } else { &[] },
                &self.body_bounds()?,
                &mut explored,
            )?
        );
        let mut body = self.body.clone();
        body_patch.apply(&mut body)?;
        body.compact()?;
        let mut wire_changes = tvec!();
        // Body input `slot` is outer input `slot`: report its change as is, and remap the
        // scan axis of scanned inputs.
        let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
        for (slot, m) in input_mapping.iter_mut().enumerate() {
            if let Some(change) = body_changed_wires
                .iter()
                .find(|(iface, _change)| iface == &InOut::In(slot))
                .map(|pair| pair.1.clone())
            {
                wire_changes.push((InOut::In(slot), change.clone()));
                if let InputMapping::Scan(info) = m {
                    rule_if_some!(axis = change.transform_axis(info.axis));
                    info.axis = axis;
                };
            }
        }
        // Body output `ix` reaches the outer node through its scan and/or last-value slots.
        // NOTE(review): a change on a state-only output is not reported outward —
        // presumably prevented by the state bounds; confirm.
        let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
        for (ix, m) in output_mapping.iter_mut().enumerate() {
            if let Some(change) = body_changed_wires
                .iter()
                .find(|(iface, _change)| iface == &InOut::Out(ix))
                .map(|pair| pair.1.clone())
            {
                if let Some((slot, info)) = m.scan.as_mut() {
                    rule_if_some!(new_axis = change.transform_axis(info.axis));
                    info.axis = new_axis;
                    wire_changes.push((InOut::Out(*slot), change.clone()));
                }
                if let Some(slot) = m.last_value_slot {
                    wire_changes.push((InOut::Out(slot), change.clone()));
                }
            };
        }
        body.check_consistency()?;
        let op = Some(Box::new(Scan {
            body,
            input_mapping,
            output_mapping,
            decluttered: false,
            ..self.clone()
        }) as _);
        Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
    }
672}
673
674impl Op for Scan {
675    fn name(&self) -> StaticName {
676        "Scan".into()
677    }
678
679    fn info(&self) -> TractResult<Vec<String>> {
680        let mut lines = vec![];
681        for (ix, im) in self.input_mapping.iter().enumerate() {
682            lines.push(format!("Model input  #{ix}: {im:?}"));
683        }
684        for (ix, om) in self.output_mapping.iter().enumerate() {
685            lines.push(format!("Model output #{ix}: {om:?}"));
686        }
687        lines.push(format!("skip:{} reset_every_turn:{:?}", self.skip, self.reset_every_turn));
688        Ok(lines)
689    }
690
691    fn validation(&self) -> Validation {
692        Validation::Rounding
693    }
694
695    op_as_typed_op!();
696}
697
698impl EvalOp for Scan {
699    fn is_stateless(&self) -> bool {
700        false
701    }
702    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
703        self.to_codegen_op(false)?.state(session, node_id)
704    }
705}
706
707impl TypedOp for Scan {
708    as_op!();
709
710    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
711        anyhow::ensure!(inputs.len() == self.body.inputs.len());
712        anyhow::ensure!(self.input_mapping.len() == self.body.inputs.len());
713        anyhow::ensure!(
714            self.input_mapping.iter().filter(|m| m.is_state()).count()
715                == self.output_mapping.iter().filter(|m| m.state).count()
716        );
717        for (i, o) in
718            self.input_mapping.iter().enumerate().filter(|(_, m)| m.is_state()).map(|(i, _)| i).zip(
719                self.output_mapping.iter().enumerate().filter(|(_, m)| m.state).map(|(o, _)| o),
720            )
721        {
722            let ifact = self.body.outlet_fact(self.body.inputs[i])?;
723            let ofact = self.body.outlet_fact(self.body.outputs[o])?;
724            anyhow::ensure!(
725                ifact == ofact,
726                "inconsistent fact: body input {i} is {ifact:?} and body output {o} is {ofact:?}\n{}",
727                self.body
728            )
729        }
730        let mut outputs = tvec!();
731        let iters = super::iteration_count(&self.input_mapping, inputs).context("No scan input")?;
732        for (ix, output) in self.output_mapping.iter().enumerate() {
733            let fact = self.body.output_fact(ix)?;
734            if let Some((slot, info)) = output.scan {
735                let mut shape = fact.shape.clone();
736                let scanning_dim =
737                    output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
738                shape.set(info.axis, scanning_dim);
739                outputs.push((slot, fact.datum_type.fact(shape)));
740            }
741            if let Some(slot) = output.last_value_slot {
742                outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
743            }
744        }
745        outputs.sort_by_key(|a| a.0);
746        anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
747        let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
748        Ok(outputs)
749    }
750
751    fn axes_mapping(
752        &self,
753        inputs: &[&TypedFact],
754        outputs: &[&TypedFact],
755    ) -> TractResult<AxesMapping> {
756        let mut mappings = vec![];
757        let body_invs = self.body.axes_mapping().with_context(|| "Computing body axes mapping")?;
758        for body_axis in body_invs.iter_all_axes() {
759            let mut info = Axis::new(body_axis.repr, inputs.len(), outputs.len());
760            info.inputs.clone_from(&body_axis.inputs);
761            for (ix, output_mapping) in self.output_mapping.iter().enumerate() {
762                let mut slots = vec![];
763                if let Some((slot, _scan)) = output_mapping.scan {
764                    slots.push(slot);
765                }
766                if let Some(slot) = output_mapping.last_value_slot {
767                    slots.push(slot);
768                }
769                for slot in slots {
770                    info.outputs[slot].clone_from(&body_axis.outputs[ix]);
771                }
772            }
773            if info.inputs.iter().any(|i| i.len() > 0) || info.outputs.iter().any(|i| i.len() > 0) {
774                mappings.push(info);
775            }
776        }
777        AxesMapping::new(inputs.len(), outputs.len(), mappings)
778    }
779
780    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
781        let mut suggestions = tvec!();
782        for (slot, input) in self.input_mapping.iter().enumerate() {
783            if let InputMapping::Scan(info) = input {
784                if info.axis != 0 {
785                    suggestions.push((InOut::In(slot), AxisOp::Move(info.axis, 0)))
786                }
787            }
788        }
789        for output in &self.output_mapping {
790            if let Some((slot, scan)) = output.scan {
791                if scan.axis != 0 {
792                    suggestions.push((InOut::Out(slot), AxisOp::Move(scan.axis, 0)))
793                }
794            }
795        }
796        Ok(suggestions)
797    }
798
799    fn change_axes(
800        &self,
801        model: &TypedModel,
802        node: &TypedNode,
803        io: InOut,
804        change: &AxisOp,
805    ) -> TractResult<Option<AxisChangeConsequence>> {
806        trace!("Propagating through {node}: {io:?} {change:?}");
807        let body_leading_outlet = match io {
808            InOut::In(ix) => self.body.input_outlets()?[ix],
809            InOut::Out(slot) => {
810                let output = self
811                    .output_mapping
812                    .iter()
813                    .position(|im| {
814                        im.scan.map(|(slot, _i)| slot) == Some(slot)
815                            || im.last_value_slot == Some(slot)
816                    })
817                    .unwrap();
818                self.body.output_outlets()?[output]
819            }
820        };
821        let axis_change = AxisChange { outlet: body_leading_outlet, op: change.clone() };
822        let node_input_facts = model.node_input_facts(node.id)?;
823        let result = self
824            .try_body_axes_change(axis_change, false, &node_input_facts)
825            .with_context(|| "Attemping to run change through scan body".to_string())?;
826        if result.is_some() {
827            trace!("{node} accepted axis change");
828        } else {
829            trace!("{node} rejected axis change");
830        }
831        Ok(result)
832    }
833
834    fn declutter_with_session(
835        &self,
836        session: &mut OptimizerSession,
837        model: &TypedModel,
838        node: &TypedNode,
839    ) -> TractResult<Option<TypedModelPatch>> {
840        macro_rules! pass {
841            ($func:ident) => {
842                if let Some(mut r) = self
843                    .$func(session, model, node)
844                    .with_context(|| format!("{}", stringify!($func)))?
845                {
846                    trace!(stringify!($func));
847                    r.push_context(stringify!($func));
848                    return Ok(Some(r));
849                }
850            };
851        }
852        pass!(declutter_single_loop);
853        pass!(declutter_const_input);
854        pass!(declutter_discard_unused_input);
855        pass!(declutter_discard_useless_outer_output);
856        pass!(declutter_discard_empty_output_mapping_with_body_output);
857        pass!(declutter_body);
858        pass!(declutter_body_axes);
859        pass!(declutter_pull_constant_outputs);
860        pass!(declutter_pull_batcheable_input);
861        pass!(declutter_pull_batcheable_output);
862        Ok(None)
863    }
864
865    fn concretize_dims(
866        &self,
867        _source: &TypedModel,
868        node: &TypedNode,
869        target: &mut TypedModel,
870        mapping: &HashMap<OutletId, OutletId>,
871        values: &SymbolValues,
872    ) -> TractResult<TVec<OutletId>> {
873        let inputs = node.inputs.iter().map(|o| mapping[o]).collect::<TVec<_>>();
874        let op = Self {
875            output_mapping: self
876                .output_mapping
877                .iter()
878                .map(|om| om.concretize_dims(values))
879                .collect::<TractResult<Vec<_>>>()?,
880            body: self.body.concretize_dims(values)?,
881            ..self.clone()
882        };
883        target.wire_node(&node.name, op, &inputs)
884    }
885
886    fn codegen(
887        &self,
888        model: &TypedModel,
889        node: &TypedNode,
890    ) -> TractResult<Option<TypedModelPatch>> {
891        Ok(Some(TypedModelPatch::replace_single_op(
892            model,
893            node,
894            &node.inputs,
895            self.to_codegen_op(true)?,
896        )?))
897    }
898}