Skip to main content

tract_core/ops/scan/
decluttered.rs

1use std::collections::HashSet;
2
3use crate::ops::einsum::EinSum;
4use crate::ops::konst::Const;
5use crate::optim::OptimizerSession;
6
7use super::optimized::{OptScan, ScanOpParams};
8use tract_data::internal::*;
9use tract_data::itertools::izip;
10
11use super::*;
12
/// Typed-model form of the scan operator: a `body` model plus the mappings
/// describing how the outer node's inputs and outputs relate to the body's.
/// `to_codegen_op` turns it into an `OptScan`.
#[derive(Debug, Clone, Default)]
pub struct Scan {
    /// Forwarded unchanged to `ScanOpParams::new` by `to_codegen_op`.
    /// NOTE(review): presumably a number of leading iterations to skip —
    /// confirm against `ScanOpParams`.
    pub skip: usize,
    /// When set, the state carried between calls is cleared every turn, i.e.
    /// State inputs are re-seeded from the outer inputs on each call rather
    /// than only on the first one (see the notes in `declutter_single_loop`).
    pub reset_every_turn: bool,
    /// True iff the caller manages State inputs externally — they supply a
    /// fresh value every run (typically reading the Scan's last_value_slot
    /// output and feeding it back into the next call's State input). This
    /// is set explicitly at construction time, e.g. by the ONNX LSTM/GRU/
    /// RNN importer when both `initial_h` and `Y_h` are exposed (parakeet
    /// decoder). When true, declutter_single_loop can safely inline a
    /// stateful single-iter Scan because the caller's per-call value
    /// reaches the body directly. When false (default), tract carries
    /// State across calls internally and inlining would break recurrence
    /// (issue #2157).
    ///
    /// Note that `declutter_single_loop` also inlines, whatever the value of
    /// this flag, when `reset_every_turn` is set, when there is no State
    /// input, or when every state output has a last-value slot that reaches
    /// a model output.
    pub external_state: bool,
    /// Model evaluated by the scan. `new` runs `check_consistency` on it.
    pub body: TypedModel,
    /// Set to true by `declutter_body` once the body went through the
    /// optimizer session; `declutter_body` is a no-op while it is set.
    /// Rewrites that alter the body or its outputs reset it to false.
    pub decluttered: bool,
    /// One entry per body input, in body input order (`new` checks that the
    /// lengths match).
    pub input_mapping: Vec<InputMapping>,
    /// One entry per body output, in body output order (`new` checks that
    /// the lengths match).
    pub output_mapping: Vec<OutputMapping<TDim>>,
}
33
/// Equality on `Scan` is deliberately degenerate: no two values ever compare
/// equal, not even a value with itself.
impl PartialEq for Scan {
    /// Always returns `false`; `_other` is ignored.
    fn eq(&self, _other: &Self) -> bool {
        false
    }
}
// NOTE(review): `Eq` documents reflexivity (`a == a`) as a requirement, which
// the always-false `eq` above does not satisfy. Presumably `Eq` is only here
// to satisfy a trait bound elsewhere — confirm before relying on `Scan` as a
// key in hashed or ordered collections.
impl Eq for Scan {}
40
41impl Scan {
42    pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<OptScan> {
43        let mut model = self.body.clone();
44        if optimize_inner {
45            model = model.into_optimized()?;
46        }
47        let plan = SimplePlan::new(model)?;
48
49        Ok(OptScan::new(Arc::new(ScanOpParams::new(
50            self.skip,
51            self.reset_every_turn,
52            plan,
53            self.input_mapping.clone(),
54            self.output_mapping.clone(),
55        ))))
56    }
57
58    pub fn new(
59        body: TypedModel,
60        input_mapping: Vec<InputMapping>,
61        output_mapping: Vec<OutputMapping<TDim>>,
62        skip: usize,
63    ) -> TractResult<Scan> {
64        body.check_consistency()?;
65        ensure!(input_mapping.len() == body.input_outlets()?.len());
66        ensure!(output_mapping.len() == body.output_outlets()?.len());
67        Ok(Scan {
68            skip,
69            reset_every_turn: false,
70            external_state: false,
71            body,
72            decluttered: false,
73            input_mapping,
74            output_mapping,
75        })
76    }
77
78    pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
79        self.to_codegen_op(false).unwrap().iteration_count(inputs)
80    }
81
82    fn declutter_body(
83        &self,
84        session: &mut OptimizerSession,
85        model: &TypedModel,
86        node: &TypedNode,
87    ) -> TractResult<Option<TypedModelPatch>> {
88        rule_if!(!self.decluttered);
89        let mut new = self.clone();
90        let mut body = self.body.clone();
91        session.optimize(&mut body)?;
92        new.body = body;
93        new.decluttered = true;
94        Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
95    }
96
97    fn declutter_single_loop(
98        &self,
99        _session: &mut OptimizerSession,
100        model: &TypedModel,
101        node: &TypedNode,
102    ) -> TractResult<Option<TypedModelPatch>> {
103        let inputs = model.node_input_facts(node.id)?;
104        let iters =
105            super::iteration_count(&self.input_mapping, &inputs).context("No scan input")?;
106        rule_if!(iters.is_one());
107        // Inlining wires the body's State input directly from the outer
108        // initial-state input on every call. At runtime (optimized.rs
109        // OpState::eval), the State input is only seeded from inputs[slot]
110        // on the first call or when reset_every_turn is set; otherwise the
111        // body input is fed from the carried hidden_state.
112        //
113        // Inlining is safe iff the caller does not rely on tract's across-call
114        // carry. We accept it when:
115        //   - reset_every_turn (carry is cleared every turn anyway), or
116        //   - external_state (explicitly asserted, e.g. force_scan_external_state), or
117        //   - there is no State at all, or
118        //   - the caller manages the state: every recurrent state has a
119        //     last-value output that reaches a model output, so the caller can
120        //     read the updated state and feed it back across calls (e.g. DTLN /
121        //     parakeet decoders). A pulse model with tract-managed state (e.g.
122        //     DFN3 GRU) does NOT export its state, so it is not inlined.
123        // See issue #2157.
124        let has_state = self.input_mapping.iter().any(InputMapping::is_state);
125        let state_outputs: Vec<_> = self.output_mapping.iter().filter(|m| m.state).collect();
126        let state_exported = has_state
127            && !state_outputs.is_empty()
128            && state_outputs.iter().all(|m| {
129                m.last_value_slot.is_some_and(|slot| {
130                    Self::outlet_reaches_model_output(model, OutletId::new(node.id, slot))
131                })
132            });
133        rule_if!(self.reset_every_turn || self.external_state || !has_state || state_exported);
134        let mut patch = TypedModelPatch::new("Inline single loop scan");
135        patch.model = self.body.clone();
136        for (outer_wire, inner_wire) in izip!(&node.inputs, &self.body.inputs) {
137            patch.taps.insert(*inner_wire, *outer_wire);
138        }
139        for (inner_wire, mapping) in izip!(&self.body.outputs, &self.output_mapping) {
140            if let Some((slot, _)) = mapping.scan {
141                patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
142            }
143            if let Some(slot) = mapping.last_value_slot {
144                patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
145            }
146        }
147        Ok(Some(patch))
148    }
149
150    /// True if `start` (an outlet of this Scan, e.g. a state's last-value
151    /// output) is, or transitively feeds, a model output — i.e. the caller can
152    /// observe the updated state and thread it back across calls. Used to tell a
153    /// caller-managed-state model (safe to inline a single-iteration Scan) from
154    /// a pulse model whose state tract carries internally (must not inline).
155    fn outlet_reaches_model_output(model: &TypedModel, start: OutletId) -> bool {
156        let outputs: std::collections::HashSet<OutletId> = model.outputs.iter().copied().collect();
157        let mut seen: std::collections::HashSet<OutletId> = Default::default();
158        let mut stack = vec![start];
159        while let Some(o) = stack.pop() {
160            if outputs.contains(&o) {
161                return true;
162            }
163            if !seen.insert(o) {
164                continue;
165            }
166            for succ in &model.node(o.node).outputs[o.slot].successors {
167                for slot in 0..model.node(succ.node).outputs.len() {
168                    stack.push(OutletId::new(succ.node, slot));
169                }
170            }
171        }
172        false
173    }
174
175    fn declutter_body_axes(
176        &self,
177        _session: &mut OptimizerSession,
178        model: &TypedModel,
179        node: &TypedNode,
180    ) -> TractResult<Option<TypedModelPatch>> {
181        let mut suggestions = vec![];
182        for n in self.body.eval_order()? {
183            let node = self.body.node(n);
184            for suggestion in node.op.suggested_axis_changes()? {
185                let outlet = suggestion.0.as_outlet(node);
186                suggestions.push(AxisChange { outlet, op: suggestion.1 })
187            }
188            for (slot, fact) in node.outputs.iter().enumerate() {
189                for (ix, dim) in fact.fact.shape.iter().enumerate() {
190                    if dim.is_one() {
191                        suggestions.push(AxisChange {
192                            outlet: OutletId::new(n, slot),
193                            op: AxisOp::Rm(ix),
194                        });
195                    }
196                }
197            }
198        }
199        let node_input_facts = model.node_input_facts(node.id)?;
200        for suggestion in suggestions.into_iter() {
201            if let Some(conseq) = self.try_body_axes_change(suggestion, true, &node_input_facts)? {
202                let mut patch = TypedModelPatch::default();
203                let mut inputs = tvec!();
204                for outlet in &node.inputs {
205                    inputs.push(patch.tap_model(model, *outlet)?);
206                }
207                for change in conseq.wire_changes {
208                    if let InOut::In(i) = change.0 {
209                        let mut value = patch
210                            .outlet_fact(inputs[i])?
211                            .konst
212                            .clone()
213                            .context("Will only reshape constants")?
214                            .into_tensor();
215                        change.1.change_tensor(&mut value, false)?;
216                        let konst_name = patch.node(inputs[i].node).name.clone();
217                        inputs[i] = patch.add_const(konst_name, value)?;
218                    }
219                }
220                let wires = patch.wire_node(
221                    &node.name,
222                    conseq.substitute_op.unwrap_or_else(|| Box::new(self.clone())),
223                    &inputs,
224                )?;
225                for (ix, new) in wires.into_iter().enumerate() {
226                    patch.shunt_outside(model, OutletId::new(node.id, ix), new)?;
227                }
228                return Ok(Some(patch));
229            }
230        }
231        Ok(None)
232    }
233
234    fn remove_outer_output_from_mappings(
235        mappings: &[OutputMapping<TDim>],
236        discarded: usize,
237    ) -> Vec<OutputMapping<TDim>> {
238        mappings
239            .iter()
240            .map(|m| OutputMapping {
241                scan: m.scan.map(|(slot, info)| (slot - (slot > discarded) as usize, info)),
242                last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
243                full_dim_hint: m.full_dim_hint.clone(),
244                state: m.state,
245            })
246            .collect()
247    }
248
249    fn declutter_const_input(
250        &self,
251        _session: &mut OptimizerSession,
252        model: &TypedModel,
253        node: &TypedNode,
254    ) -> TractResult<Option<TypedModelPatch>> {
255        let inputs = model.node_input_facts(node.id)?;
256        for (slot, mapping) in self.input_mapping.iter().enumerate() {
257            if let InputMapping::Full = mapping
258                && let Some(konst) = inputs[slot].konst.as_ref()
259            {
260                let mut op = self.clone();
261                let src = op.body.inputs[slot];
262                op.body.inputs.remove(slot);
263                op.body.nodes[src.node].inputs.clear();
264                op.body.nodes[src.node].op = Box::new(Const::new(konst.clone())?);
265                op.input_mapping.remove(slot);
266                let mut inputs = node.inputs.clone();
267                inputs.remove(slot);
268                return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
269            }
270        }
271        Ok(None)
272    }
273
274    fn declutter_discard_unused_input(
275        &self,
276        _session: &mut OptimizerSession,
277        model: &TypedModel,
278        node: &TypedNode,
279    ) -> TractResult<Option<TypedModelPatch>> {
280        for (slot, input) in self.body.input_outlets()?.iter().enumerate() {
281            let source_node = self.body.node(input.node);
282            if source_node.outputs[0].successors.len() == 0
283                && !self.body.output_outlets()?.contains(input)
284            {
285                let mut new_inputs = node.inputs.clone();
286                new_inputs.remove(slot);
287                let mut new_mappings: Vec<_> = self.input_mapping.clone();
288                new_mappings.remove(slot);
289                let mut model_inputs = self.body.input_outlets()?.to_vec();
290                model_inputs.remove(slot);
291                let mut body = self.body.clone();
292                let mut patch = TypedModelPatch::default();
293                patch.obliterate(source_node.id)?;
294                patch.apply(&mut body)?;
295                body.set_input_outlets(&model_inputs)?;
296                body.declutter()?;
297                let op =
298                    Self { body, input_mapping: new_mappings, decluttered: true, ..self.clone() };
299                return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
300            }
301        }
302        Ok(None)
303    }
304
305    fn declutter_discard_useless_outer_output(
306        &self,
307        _session: &mut OptimizerSession,
308        model: &TypedModel,
309        node: &TypedNode,
310    ) -> TractResult<Option<TypedModelPatch>> {
311        for (ix, o) in node.outputs.iter().enumerate() {
312            if o.successors.len() == 0
313                && !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
314            {
315                let mappings = self
316                    .output_mapping
317                    .iter()
318                    .map(|m| OutputMapping {
319                        scan: m.scan.filter(|(slot, _info)| *slot != ix),
320                        last_value_slot: m.last_value_slot.filter(|s| *s != ix),
321                        full_dim_hint: m.full_dim_hint.clone(),
322                        state: m.state,
323                    })
324                    .collect::<Vec<_>>();
325                let mut op = self.clone();
326                op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
327                let mut patch = TypedModelPatch::default();
328                let inputs = node
329                    .inputs
330                    .iter()
331                    .map(|&i| patch.tap_model(model, i))
332                    .collect::<TractResult<Vec<_>>>()?;
333                let wires = patch.wire_node(&*node.name, op, &inputs)?;
334                for oix in 0..node.outputs.len() {
335                    if oix != ix {
336                        patch.shunt_outside(
337                            model,
338                            OutletId::new(node.id, oix),
339                            wires[oix - (oix > ix) as usize],
340                        )?;
341                    }
342                }
343                return Ok(Some(patch));
344            }
345        }
346        Ok(None)
347    }
348
349    fn declutter_discard_empty_output_mapping_with_body_output(
350        &self,
351        _session: &mut OptimizerSession,
352        model: &TypedModel,
353        node: &TypedNode,
354    ) -> TractResult<Option<TypedModelPatch>> {
355        for (ix, om) in self.output_mapping.iter().enumerate() {
356            if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
357                let mut new_op = self.clone();
358                new_op.output_mapping.remove(ix);
359                new_op.body.outputs.remove(ix);
360                new_op.decluttered = false;
361                return Ok(Some(TypedModelPatch::replace_single_op(
362                    model,
363                    node,
364                    &node.inputs,
365                    new_op,
366                )?));
367            }
368        }
369        Ok(None)
370    }
371
372    fn declutter_pull_batcheable_input(
373        &self,
374        _session: &mut OptimizerSession,
375        model: &TypedModel,
376        node: &TypedNode,
377    ) -> TractResult<Option<TypedModelPatch>> {
378        'candidate: for (slot, input) in self.input_mapping.iter().enumerate() {
379            if let Some(scan_info) = input.as_scan() {
380                let scan_source = self.body.input_outlets()?[slot];
381                let scan_source_node = self.body.node(scan_source.node);
382                for mut succ in &scan_source_node.outputs[0].successors {
383                    for &succ_input in &self.body.node(succ.node).inputs {
384                        if succ_input != scan_source
385                            && self.body.outlet_fact(succ_input)?.konst.is_none()
386                        {
387                            continue 'candidate;
388                        }
389                    }
390                    if self.body.node(succ.node).outputs.len() != 1 {
391                        continue;
392                    }
393                    let mut new_body = self.body.clone();
394                    // insert propagate axis on einsum
395                    if let Some(einsum) = new_body.node(succ.node).op_as::<EinSum>()
396                        && let Some(patch) = einsum
397                            .propagate_axis(
398                                &new_body,
399                                new_body.node(succ.node),
400                                InOut::In(succ.slot),
401                                scan_info.axis,
402                            )
403                            .context("building axis propagating patch")?
404                    {
405                        patch.apply(&mut new_body)?;
406                        new_body.compute_const_facts()?;
407                        // propagate axis injects new nodes at the end. last successor of input
408                        // in new net will be the new succ
409                        let new_body_scan_input = new_body.input_outlets()?[slot];
410                        succ = new_body.node(new_body_scan_input.node).outputs[0]
411                            .successors
412                            .last()
413                            .unwrap();
414                    }
415
416                    let axes_mapping = {
417                        let (input_facts, output_facts) =
418                            new_body.node_facts(new_body.node(succ.node).id)?;
419                        new_body.node(succ.node).op.axes_mapping(&input_facts, &output_facts)?
420                    };
421                    let axis_info = axes_mapping.axis((InOut::In(succ.slot), scan_info.axis))?;
422                    if let &[axis_after] = &*axis_info.outputs[0] {
423                        let mut outside_patch = TypedModelPatch::new(format!(
424                            "Outer patch for input extraction of {}",
425                            new_body.node(succ.node)
426                        ));
427                        let mut patch_inputs = node
428                            .inputs
429                            .iter()
430                            .map(|&i| outside_patch.tap_model(model, i))
431                            .collect::<TractResult<TVec<_>>>()?;
432                        let mut extracted_op_inputs = tvec!();
433                        for (ix, outlet) in new_body.node(succ.node).inputs.iter().enumerate() {
434                            let wire = if ix == succ.slot {
435                                patch_inputs[slot]
436                            } else if let Some(konst) =
437                                new_body.outlet_fact(*outlet)?.konst.as_ref()
438                            {
439                                outside_patch.add_const(
440                                    format!(
441                                        "{}.extracted.{}",
442                                        node.name,
443                                        new_body.node(outlet.node).name
444                                    ),
445                                    konst.clone(),
446                                )?
447                            } else {
448                                unreachable!();
449                            };
450                            extracted_op_inputs.push(wire);
451                        }
452                        let new_input_wire = outside_patch.wire_node(
453                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
454                            new_body.node(succ.node).op.clone(),
455                            &extracted_op_inputs,
456                        )?[0];
457                        patch_inputs.push(new_input_wire);
458                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
459                        let mut new_input_inner_fact = new_input_outer_fact.clone();
460                        new_input_inner_fact.shape.set(axis_after, scan_info.chunk.abs().to_dim());
461
462                        let mut new_body = new_body.clone();
463                        let new_source_wire = new_body.add_source(
464                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
465                            new_input_inner_fact,
466                        )?;
467                        let mut inner_patch = TypedModelPatch::new(format!(
468                            "Inner body patch for extraction of {}",
469                            new_body.node(succ.node)
470                        ));
471                        let new_source_wire_in_patch =
472                            inner_patch.tap_model(&new_body, new_source_wire)?;
473                        inner_patch
474                            .shunt_outside(
475                                &new_body,
476                                OutletId::new(succ.node, 0),
477                                new_source_wire_in_patch,
478                            )
479                            .with_context(|| "patching inner model")?;
480                        inner_patch.apply(&mut new_body)?;
481
482                        let mut input_mapping = self.input_mapping.clone();
483                        input_mapping.push(InputMapping::Scan(ScanInfo {
484                            axis: axis_after,
485                            chunk: scan_info.chunk,
486                        }));
487
488                        let new_op = Self {
489                            input_mapping,
490                            decluttered: false,
491                            body: new_body,
492                            ..self.clone()
493                        };
494                        let output_wires =
495                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
496                        for w in output_wires {
497                            outside_patch
498                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
499                                .with_context(|| "patching outer model")?;
500                        }
501                        return Ok(Some(outside_patch));
502                    }
503                }
504            }
505        }
506        Ok(None)
507    }
508
509    fn declutter_pull_constant_outputs(
510        &self,
511        _session: &mut OptimizerSession,
512        model: &TypedModel,
513        node: &TypedNode,
514    ) -> TractResult<Option<TypedModelPatch>> {
515        for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
516            if let Some(slot) = mapping.last_value_slot
517                && let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone()
518            {
519                let inner_node = self.body.output_outlets()?[model_output_ix].node;
520                let inner_node = self.body.node(inner_node);
521                let mut patch = TypedModelPatch::new(format!("Extract const node {inner_node}"));
522                let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
523                patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
524                return Ok(Some(patch));
525            }
526        }
527        Ok(None)
528    }
529
530    fn declutter_pull_batcheable_output(
531        &self,
532        _session: &mut OptimizerSession,
533        model: &TypedModel,
534        node: &TypedNode,
535    ) -> TractResult<Option<TypedModelPatch>> {
536        for (mapping_ix, mapping) in self.output_mapping.iter().enumerate() {
537            if let Some((_, scan_info)) = mapping.scan {
538                let emitter_outlet = self.body.output_outlets()?[mapping_ix];
539                if self.body.node(emitter_outlet.node).outputs[emitter_outlet.slot].successors.len()
540                    > 0
541                    || self.body.inputs.contains(&emitter_outlet)
542                    || mapping.state
543                    || mapping.scan.map(|(_slot, i)| i.chunk > 1).unwrap_or(true)
544                {
545                    // continue if both last_value and full values are exported
546                    continue;
547                }
548                let mut new_body = self.body.clone();
549                if let Some(einsum) = new_body.node(emitter_outlet.node).op_as::<EinSum>()
550                    && let Some(patch) = einsum
551                        .propagate_axis(
552                            &new_body,
553                            new_body.node(emitter_outlet.node),
554                            InOut::Out(0),
555                            scan_info.axis,
556                        )
557                        .context("building axis propagating patch")?
558                {
559                    patch.apply(&mut new_body)?;
560                    new_body.prop_consts()?;
561                }
562                let emitter_outlet = new_body.output_outlets()?[mapping_ix];
563                let invariants = {
564                    let (input_facts, output_facts) = new_body.node_facts(emitter_outlet.node)?;
565                    new_body
566                        .node(emitter_outlet.node)
567                        .op
568                        .axes_mapping(&input_facts, &output_facts)?
569                };
570                let axis_tracking =
571                    invariants.axis((InOut::Out(emitter_outlet.slot), scan_info.axis))?;
572                rule_if!(axis_tracking.outputs.iter().all(|o| o.len() == 1));
573                let mut new_output_mapping = self.output_mapping.clone();
574                let mut new_scan_outputs = node.outputs.len();
575                let mut outer_slots = vec![];
576
577                // rewire input of the extracted node through the scan outlet boundary
578                for (input_slot, input) in
579                    new_body.node(emitter_outlet.node).inputs.clone().iter().enumerate()
580                {
581                    if new_body.outputs.iter().all(|o| o != input) {
582                        new_output_mapping.push(OutputMapping::default());
583                        new_body.outputs.push(*input);
584                    }
585                    let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
586                    let mapping = &mut new_output_mapping[body_output_id];
587                    let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
588                        if mapping.last_value_slot.is_none() {
589                            mapping.last_value_slot = Some(new_scan_outputs);
590                            new_scan_outputs += 1;
591                        }
592                        mapping.last_value_slot.unwrap()
593                    } else if let &[axis] = &*axis_tracking.inputs[input_slot] {
594                        if mapping.scan.is_none() {
595                            mapping.scan =
596                                Some((new_scan_outputs, ScanInfo { axis, chunk: scan_info.chunk }));
597                            new_scan_outputs += 1;
598                        }
599                        mapping.scan.unwrap().0
600                    } else {
601                        return Ok(None);
602                    };
603                    outer_slots.push(outer_slot);
604                }
605                let mut outside_patch = TypedModelPatch::new(format!(
606                    "Outside patch for output extraction of {}",
607                    new_body.node(emitter_outlet.node)
608                ));
609                let inputs = node
610                    .inputs
611                    .iter()
612                    .map(|&i| outside_patch.tap_model(model, i))
613                    .collect::<TractResult<TVec<_>>>()?;
614                let new_op = Self {
615                    output_mapping: new_output_mapping,
616                    decluttered: false,
617                    body: new_body.clone(), // FIXME maybe remove clone
618                    ..self.clone()
619                };
620                let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
621                let output = mapping.scan.unwrap();
622                let inputs =
623                    outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
624                let wire = outside_patch.wire_node(
625                    &new_body.node(emitter_outlet.node).name,
626                    new_body.node(emitter_outlet.node).op.clone(),
627                    &inputs,
628                )?[0];
629                outside_patch.shunt_outside(model, OutletId::new(node.id, output.0), wire)?;
630                for output_slot in 0..node.outputs.len() {
631                    if output_slot != output.0 {
632                        outside_patch.shunt_outside(
633                            model,
634                            OutletId::new(node.id, output_slot),
635                            OutletId::new(scan_outputs[0].node, output_slot),
636                        )?;
637                    }
638                }
639                return Ok(Some(outside_patch));
640            }
641        }
642        Ok(None)
643    }
644
645    fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
646        let input_state_outlets = self
647            .input_mapping
648            .iter()
649            .zip(self.body.input_outlets()?.iter())
650            .filter(|(m, _)| m.is_state())
651            .map(|(_, o)| o);
652        let output_state_outlets = self
653            .output_mapping
654            .iter()
655            .zip(self.body.output_outlets()?.iter())
656            .filter(|(m, _)| m.state)
657            .map(|(_, o)| o);
658        Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
659    }
660
661    fn body_locked_outlets(&self, node_input_facts: &[&TypedFact]) -> TractResult<TVec<OutletId>> {
662        let input_outlets =
663            self.body.input_outlets()?.iter().enumerate().filter_map(|(slot, o)| {
664                if node_input_facts[slot].konst.is_none() { Some(o) } else { None }
665            });
666        let output_outlets = self
667            .output_mapping
668            .iter()
669            .zip(self.body.output_outlets()?.iter())
670            .filter(|(m, _)| !m.invisible())
671            .map(|(_, o)| o);
672        Ok(input_outlets.chain(output_outlets).cloned().collect())
673    }
674
675    fn try_body_axes_change(
676        &self,
677        change: AxisChange,
678        locked_interface: bool,
679        node_input_facts: &[&TypedFact],
680    ) -> TractResult<Option<AxisChangeConsequence>> {
681        self.body.check_consistency()?;
682        let locked_outlets = self.body_locked_outlets(node_input_facts)?;
683        let mut explored: HashSet<AxisChange> = Default::default();
684        rule_if_some!(
685            (body_patch, body_changed_wires) = crate::optim::change_axes::change_axes(
686                &self.body,
687                &change,
688                if locked_interface { &locked_outlets } else { &[] },
689                &self.body_bounds()?,
690                &mut explored,
691            )?
692        );
693        let mut body = self.body.clone();
694        body_patch.apply(&mut body)?;
695        body.compact()?;
696        let mut wire_changes = tvec!();
697        let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
698        for (slot, m) in input_mapping.iter_mut().enumerate() {
699            if let Some(change) = body_changed_wires
700                .iter()
701                .find(|(iface, _change)| iface == &InOut::In(slot))
702                .map(|pair| pair.1.clone())
703            {
704                wire_changes.push((InOut::In(slot), change.clone()));
705                if let InputMapping::Scan(info) = m {
706                    rule_if_some!(axis = change.transform_axis(info.axis));
707                    info.axis = axis;
708                };
709            }
710        }
711        let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
712        for (ix, m) in output_mapping.iter_mut().enumerate() {
713            if let Some(change) = body_changed_wires
714                .iter()
715                .find(|(iface, _change)| iface == &InOut::Out(ix))
716                .map(|pair| pair.1.clone())
717            {
718                if let Some((slot, info)) = m.scan.as_mut() {
719                    rule_if_some!(new_axis = change.transform_axis(info.axis));
720                    info.axis = new_axis;
721                    wire_changes.push((InOut::Out(*slot), change.clone()));
722                }
723                if let Some(slot) = m.last_value_slot {
724                    wire_changes.push((InOut::Out(slot), change.clone()));
725                }
726            };
727        }
728        body.check_consistency()?;
729        let op = Some(Box::new(Scan {
730            body,
731            input_mapping,
732            output_mapping,
733            decluttered: false,
734            ..self.clone()
735        }) as _);
736        Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
737    }
738}
739
740impl Op for Scan {
741    fn name(&self) -> StaticName {
742        "Scan".into()
743    }
744
745    fn info(&self) -> TractResult<Vec<String>> {
746        let mut lines = vec![];
747        for (ix, im) in self.input_mapping.iter().enumerate() {
748            lines.push(format!("Model input  #{ix}: {im:?}"));
749        }
750        for (ix, om) in self.output_mapping.iter().enumerate() {
751            lines.push(format!("Model output #{ix}: {om:?}"));
752        }
753        lines.push(format!("skip:{} reset_every_turn:{:?}", self.skip, self.reset_every_turn));
754        Ok(lines)
755    }
756
757    fn validation(&self) -> Validation {
758        Validation::Rounding
759    }
760
761    op_as_typed_op!();
762}
763
764impl EvalOp for Scan {
765    fn is_stateless(&self) -> bool {
766        false
767    }
768    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
769        self.to_codegen_op(false)?.state(session, node_id)
770    }
771}
772
773impl TypedOp for Scan {
774    as_op!();
775
776    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
777        anyhow::ensure!(inputs.len() == self.body.inputs.len());
778        anyhow::ensure!(self.input_mapping.len() == self.body.inputs.len());
779        anyhow::ensure!(
780            self.input_mapping.iter().filter(|m| m.is_state()).count()
781                == self.output_mapping.iter().filter(|m| m.state).count()
782        );
783        for (i, o) in
784            self.input_mapping.iter().enumerate().filter(|(_, m)| m.is_state()).map(|(i, _)| i).zip(
785                self.output_mapping.iter().enumerate().filter(|(_, m)| m.state).map(|(o, _)| o),
786            )
787        {
788            let ifact = self.body.outlet_fact(self.body.inputs[i])?;
789            let ofact = self.body.outlet_fact(self.body.outputs[o])?;
790            anyhow::ensure!(
791                ifact == ofact,
792                "inconsistent fact: body input {i} is {ifact:?} and body output {o} is {ofact:?}\n{}",
793                self.body
794            )
795        }
796        let mut outputs = tvec!();
797        let iters = super::iteration_count(&self.input_mapping, inputs).context("No scan input")?;
798        for (ix, output) in self.output_mapping.iter().enumerate() {
799            let fact = self.body.output_fact(ix)?;
800            if let Some((slot, info)) = output.scan {
801                let mut shape = fact.shape.clone();
802                let scanning_dim =
803                    output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
804                shape.set(info.axis, scanning_dim);
805                outputs.push((slot, fact.datum_type.fact(shape)));
806            }
807            if let Some(slot) = output.last_value_slot {
808                outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
809            }
810        }
811        outputs.sort_by_key(|a| a.0);
812        anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
813        let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
814        Ok(outputs)
815    }
816
817    fn axes_mapping(
818        &self,
819        inputs: &[&TypedFact],
820        outputs: &[&TypedFact],
821    ) -> TractResult<AxesMapping> {
822        let mut mappings = vec![];
823        let body_invs = self.body.axes_mapping().with_context(|| "Computing body axes mapping")?;
824        for body_axis in body_invs.iter_all_axes() {
825            let mut info = Axis::new(body_axis.repr, inputs.len(), outputs.len());
826            info.inputs.clone_from(&body_axis.inputs);
827            for (ix, output_mapping) in self.output_mapping.iter().enumerate() {
828                let mut slots = vec![];
829                if let Some((slot, _scan)) = output_mapping.scan {
830                    slots.push(slot);
831                }
832                if let Some(slot) = output_mapping.last_value_slot {
833                    slots.push(slot);
834                }
835                for slot in slots {
836                    info.outputs[slot].clone_from(&body_axis.outputs[ix]);
837                }
838            }
839            if info.inputs.iter().any(|i| i.len() > 0) || info.outputs.iter().any(|i| i.len() > 0) {
840                mappings.push(info);
841            }
842        }
843        AxesMapping::new(inputs.len(), outputs.len(), mappings)
844    }
845
846    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
847        let mut suggestions = tvec!();
848        for (slot, input) in self.input_mapping.iter().enumerate() {
849            if let InputMapping::Scan(info) = input
850                && info.axis != 0
851            {
852                suggestions.push((InOut::In(slot), AxisOp::Move(info.axis, 0)))
853            }
854        }
855        for output in &self.output_mapping {
856            if let Some((slot, scan)) = output.scan
857                && scan.axis != 0
858            {
859                suggestions.push((InOut::Out(slot), AxisOp::Move(scan.axis, 0)))
860            }
861        }
862        Ok(suggestions)
863    }
864
865    fn change_axes(
866        &self,
867        model: &TypedModel,
868        node: &TypedNode,
869        io: InOut,
870        change: &AxisOp,
871    ) -> TractResult<Option<AxisChangeConsequence>> {
872        trace!("Propagating through {node}: {io:?} {change:?}");
873        let body_leading_outlet = match io {
874            InOut::In(ix) => self.body.input_outlets()?[ix],
875            InOut::Out(slot) => {
876                let output = self
877                    .output_mapping
878                    .iter()
879                    .position(|im| {
880                        im.scan.map(|(slot, _i)| slot) == Some(slot)
881                            || im.last_value_slot == Some(slot)
882                    })
883                    .unwrap();
884                self.body.output_outlets()?[output]
885            }
886        };
887        let axis_change = AxisChange { outlet: body_leading_outlet, op: change.clone() };
888        let node_input_facts = model.node_input_facts(node.id)?;
889        let result = self
890            .try_body_axes_change(axis_change, false, &node_input_facts)
891            .with_context(|| "Attemping to run change through scan body".to_string())?;
892        if result.is_some() {
893            trace!("{node} accepted axis change");
894        } else {
895            trace!("{node} rejected axis change");
896        }
897        Ok(result)
898    }
899
900    fn declutter_with_session(
901        &self,
902        session: &mut OptimizerSession,
903        model: &TypedModel,
904        node: &TypedNode,
905    ) -> TractResult<Option<TypedModelPatch>> {
906        macro_rules! pass {
907            ($func:ident) => {
908                if let Some(mut r) = self
909                    .$func(session, model, node)
910                    .with_context(|| format!("{}", stringify!($func)))?
911                {
912                    trace!(stringify!($func));
913                    r.push_context(stringify!($func));
914                    return Ok(Some(r));
915                }
916            };
917        }
918        pass!(declutter_single_loop);
919        pass!(declutter_const_input);
920        pass!(declutter_discard_unused_input);
921        pass!(declutter_discard_useless_outer_output);
922        pass!(declutter_discard_empty_output_mapping_with_body_output);
923        pass!(declutter_body);
924        pass!(declutter_body_axes);
925        pass!(declutter_pull_constant_outputs);
926        pass!(declutter_pull_batcheable_input);
927        pass!(declutter_pull_batcheable_output);
928        Ok(None)
929    }
930
931    fn set_symbols(
932        &self,
933        _source: &TypedModel,
934        node: &TypedNode,
935        target: &mut TypedModel,
936        mapping: &HashMap<OutletId, OutletId>,
937        subs: &HashMap<Symbol, TDim>,
938    ) -> TractResult<TVec<OutletId>> {
939        let inputs = node.inputs.iter().map(|o| mapping[o]).collect::<TVec<_>>();
940        let op = Self {
941            output_mapping: self
942                .output_mapping
943                .iter()
944                .map(|om| om.set_symbols(subs))
945                .collect::<TractResult<Vec<_>>>()?,
946            body: self.body.set_symbols(subs)?,
947            ..self.clone()
948        };
949        target.wire_node(&node.name, op, &inputs)
950    }
951
952    fn codegen(
953        &self,
954        model: &TypedModel,
955        node: &TypedNode,
956    ) -> TractResult<Option<TypedModelPatch>> {
957        Ok(Some(TypedModelPatch::replace_single_op(
958            model,
959            node,
960            &node.inputs,
961            self.to_codegen_op(true)?,
962        )?))
963    }
964}