Skip to main content

tract_core/ops/scan/decluttered.rs

1use std::collections::HashSet;
2
3use crate::ops::einsum::EinSum;
4use crate::ops::konst::Const;
5use crate::optim::OptimizerSession;
6
7use super::optimized::{OptScan, ScanOpParams};
8use tract_data::internal::*;
9use tract_data::itertools::izip;
10
11use super::*;
12
/// Declutter-level scan operator: runs `body` repeatedly, feeding each body
/// input according to `input_mapping` (chunks of a scanned outer input,
/// the whole outer value, or carried state) and exporting body outputs
/// according to `output_mapping`. It is lowered to an `OptScan` for
/// execution by `to_codegen_op`.
#[derive(Debug, Clone, Default)]
pub struct Scan {
    /// Forwarded unchanged to `ScanOpParams` when lowering; its runtime
    /// meaning is defined by the optimized scan (NOTE(review): presumably a
    /// number of leading iterations whose outputs are dropped — confirm).
    pub skip: usize,
    /// When set, State inputs are re-seeded from their outer initial value on
    /// every call instead of carrying the previous call's final state.
    pub reset_every_turn: bool,
    /// True iff the caller manages State inputs externally — they supply a
    /// fresh value every run (typically reading the Scan's last_value_slot
    /// output and feeding it back into the next call's State input). This
    /// is set explicitly at construction time, e.g. by the ONNX LSTM/GRU/
    /// RNN importer when both `initial_h` and `Y_h` are exposed (parakeet
    /// decoder). When true, declutter_single_loop can safely inline a
    /// stateful single-iter Scan because the caller's per-call value
    /// reaches the body directly. When false (default), tract carries
    /// State across calls internally and inlining would break recurrence
    /// (issue #2157).
    pub external_state: bool,
    /// Inner model run once per iteration. It has exactly one input per
    /// `input_mapping` entry and one output per `output_mapping` entry
    /// (checked in `new`).
    pub body: TypedModel,
    /// Set once `declutter_body` has run the optimizer over `body`. Some rules
    /// that rewrite the body reset it so the body is decluttered again.
    pub decluttered: bool,
    /// How each body input is fed, indexed like the body inputs and the
    /// node's outer inputs.
    pub input_mapping: Vec<InputMapping>,
    /// How each body output is exported (scan output slot, last-value slot,
    /// state), indexed like the body outputs.
    pub output_mapping: Vec<OutputMapping<TDim>>,
}
33
34impl PartialEq for Scan {
35    fn eq(&self, _other: &Self) -> bool {
36        false
37    }
38}
39impl Eq for Scan {}
40
41impl Scan {
42    pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<OptScan> {
43        let mut model = self.body.clone();
44        if optimize_inner {
45            model = model.into_optimized()?;
46        }
47        let plan = SimplePlan::new(model)?;
48
49        Ok(OptScan::new(Arc::new(ScanOpParams::new(
50            self.skip,
51            self.reset_every_turn,
52            plan,
53            self.input_mapping.clone(),
54            self.output_mapping.clone(),
55        ))))
56    }
57
58    pub fn new(
59        body: TypedModel,
60        input_mapping: Vec<InputMapping>,
61        output_mapping: Vec<OutputMapping<TDim>>,
62        skip: usize,
63    ) -> TractResult<Scan> {
64        body.check_consistency()?;
65        ensure!(input_mapping.len() == body.input_outlets()?.len());
66        ensure!(output_mapping.len() == body.output_outlets()?.len());
67        Ok(Scan {
68            skip,
69            reset_every_turn: false,
70            external_state: false,
71            body,
72            decluttered: false,
73            input_mapping,
74            output_mapping,
75        })
76    }
77
78    pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
79        self.to_codegen_op(false).unwrap().iteration_count(inputs)
80    }
81
82    fn declutter_body(
83        &self,
84        session: &mut OptimizerSession,
85        model: &TypedModel,
86        node: &TypedNode,
87    ) -> TractResult<Option<TypedModelPatch>> {
88        rule_if!(!self.decluttered);
89        let mut new = self.clone();
90        let mut body = self.body.clone();
91        session.optimize(&mut body)?;
92        new.body = body;
93        new.decluttered = true;
94        Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
95    }
96
97    fn declutter_single_loop(
98        &self,
99        _session: &mut OptimizerSession,
100        model: &TypedModel,
101        node: &TypedNode,
102    ) -> TractResult<Option<TypedModelPatch>> {
103        let inputs = model.node_input_facts(node.id)?;
104        let iters =
105            super::iteration_count(&self.input_mapping, &inputs).context("No scan input")?;
106        rule_if!(iters.is_one());
107        // Inlining wires the body's State input directly from the outer
108        // initial-state input on every call. At runtime (optimized.rs
109        // OpState::eval), the State input is only seeded from inputs[slot]
110        // on the first call or when reset_every_turn is set; otherwise the
111        // body input is fed from the carried hidden_state.
112        //
113        // Inlining is therefore safe iff the caller does not rely on
114        // tract's across-call carry, which we encode as one of:
115        //   - reset_every_turn (carry is cleared every turn anyway), or
116        //   - external_state (caller plumbs the State input every call,
117        //     e.g. parakeet decoder).
118        // Otherwise (internal-managed state, e.g. DFN3 GRU under pulse),
119        // bail. See issue #2157.
120        rule_if!(
121            self.reset_every_turn
122                || self.external_state
123                || !self.input_mapping.iter().any(InputMapping::is_state)
124        );
125        let mut patch = TypedModelPatch::new("Inline single loop scan");
126        patch.model = self.body.clone();
127        for (outer_wire, inner_wire) in izip!(&node.inputs, &self.body.inputs) {
128            patch.taps.insert(*inner_wire, *outer_wire);
129        }
130        for (inner_wire, mapping) in izip!(&self.body.outputs, &self.output_mapping) {
131            if let Some((slot, _)) = mapping.scan {
132                patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
133            }
134            if let Some(slot) = mapping.last_value_slot {
135                patch.shunt_outside(model, (node.id, slot).into(), *inner_wire)?;
136            }
137        }
138        Ok(Some(patch))
139    }
140
    /// Tries to simplify the body's axis layout without changing the scan's
    /// non-constant interface.
    ///
    /// Candidate changes are gathered from every body node, in eval order:
    /// - the axis changes its op suggests, and
    /// - removal of every size-1 axis on its outputs.
    ///
    /// The first candidate accepted by `try_body_axes_change` (with the
    /// interface locked) is applied. Outer constant inputs affected by the
    /// change are transformed outside the scan and re-added as constants.
    /// Returns `None` when no candidate applies.
    fn declutter_body_axes(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut suggestions = vec![];
        for n in self.body.eval_order()? {
            // NB: shadows the outer `node` (the scan node) inside this loop only.
            let node = self.body.node(n);
            for suggestion in node.op.suggested_axis_changes()? {
                let outlet = suggestion.0.as_outlet(node);
                suggestions.push(AxisChange { outlet, op: suggestion.1 })
            }
            // Every axis of size 1 is a removal candidate.
            for (slot, fact) in node.outputs.iter().enumerate() {
                for (ix, dim) in fact.fact.shape.iter().enumerate() {
                    if dim.is_one() {
                        suggestions.push(AxisChange {
                            outlet: OutletId::new(n, slot),
                            op: AxisOp::Rm(ix),
                        });
                    }
                }
            }
        }
        let node_input_facts = model.node_input_facts(node.id)?;
        for suggestion in suggestions.into_iter() {
            if let Some(conseq) = self.try_body_axes_change(suggestion, true, &node_input_facts)? {
                let mut patch = TypedModelPatch::default();
                let mut inputs = tvec!();
                for outlet in &node.inputs {
                    inputs.push(patch.tap_model(model, *outlet)?);
                }
                // With the interface locked, `body_locked_outlets` leaves only
                // inputs fed by outer constants unlocked, so input-side changes
                // are applied to those constants directly. Output-side changes
                // are not acted on here (visible outputs are locked, so none
                // are presumably expected).
                for change in conseq.wire_changes {
                    if let InOut::In(i) = change.0 {
                        let mut value = patch
                            .outlet_fact(inputs[i])?
                            .konst
                            .clone()
                            .context("Will only reshape constants")?
                            .into_tensor();
                        change.1.change_tensor(&mut value, false)?;
                        let konst_name = patch.node(inputs[i].node).name.clone();
                        inputs[i] = patch.add_const(konst_name, value)?;
                    }
                }
                // Re-wire the (possibly substituted) scan and redirect all of
                // its outer outputs to the new node.
                let wires = patch.wire_node(
                    &node.name,
                    conseq.substitute_op.unwrap_or_else(|| Box::new(self.clone())),
                    &inputs,
                )?;
                for (ix, new) in wires.into_iter().enumerate() {
                    patch.shunt_outside(model, OutletId::new(node.id, ix), new)?;
                }
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
199
200    fn remove_outer_output_from_mappings(
201        mappings: &[OutputMapping<TDim>],
202        discarded: usize,
203    ) -> Vec<OutputMapping<TDim>> {
204        mappings
205            .iter()
206            .map(|m| OutputMapping {
207                scan: m.scan.map(|(slot, info)| (slot - (slot > discarded) as usize, info)),
208                last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
209                full_dim_hint: m.full_dim_hint.clone(),
210                state: m.state,
211            })
212            .collect()
213    }
214
215    fn declutter_const_input(
216        &self,
217        _session: &mut OptimizerSession,
218        model: &TypedModel,
219        node: &TypedNode,
220    ) -> TractResult<Option<TypedModelPatch>> {
221        let inputs = model.node_input_facts(node.id)?;
222        for (slot, mapping) in self.input_mapping.iter().enumerate() {
223            if let InputMapping::Full = mapping
224                && let Some(konst) = inputs[slot].konst.as_ref()
225            {
226                let mut op = self.clone();
227                let src = op.body.inputs[slot];
228                op.body.inputs.remove(slot);
229                op.body.nodes[src.node].inputs.clear();
230                op.body.nodes[src.node].op = Box::new(Const::new(konst.clone())?);
231                op.input_mapping.remove(slot);
232                let mut inputs = node.inputs.clone();
233                inputs.remove(slot);
234                return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
235            }
236        }
237        Ok(None)
238    }
239
240    fn declutter_discard_unused_input(
241        &self,
242        _session: &mut OptimizerSession,
243        model: &TypedModel,
244        node: &TypedNode,
245    ) -> TractResult<Option<TypedModelPatch>> {
246        for (slot, input) in self.body.input_outlets()?.iter().enumerate() {
247            let source_node = self.body.node(input.node);
248            if source_node.outputs[0].successors.len() == 0
249                && !self.body.output_outlets()?.contains(input)
250            {
251                let mut new_inputs = node.inputs.clone();
252                new_inputs.remove(slot);
253                let mut new_mappings: Vec<_> = self.input_mapping.clone();
254                new_mappings.remove(slot);
255                let mut model_inputs = self.body.input_outlets()?.to_vec();
256                model_inputs.remove(slot);
257                let mut body = self.body.clone();
258                let mut patch = TypedModelPatch::default();
259                patch.obliterate(source_node.id)?;
260                patch.apply(&mut body)?;
261                body.set_input_outlets(&model_inputs)?;
262                body.declutter()?;
263                let op =
264                    Self { body, input_mapping: new_mappings, decluttered: true, ..self.clone() };
265                return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
266            }
267        }
268        Ok(None)
269    }
270
271    fn declutter_discard_useless_outer_output(
272        &self,
273        _session: &mut OptimizerSession,
274        model: &TypedModel,
275        node: &TypedNode,
276    ) -> TractResult<Option<TypedModelPatch>> {
277        for (ix, o) in node.outputs.iter().enumerate() {
278            if o.successors.len() == 0
279                && !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
280            {
281                let mappings = self
282                    .output_mapping
283                    .iter()
284                    .map(|m| OutputMapping {
285                        scan: m.scan.filter(|(slot, _info)| *slot != ix),
286                        last_value_slot: m.last_value_slot.filter(|s| *s != ix),
287                        full_dim_hint: m.full_dim_hint.clone(),
288                        state: m.state,
289                    })
290                    .collect::<Vec<_>>();
291                let mut op = self.clone();
292                op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
293                let mut patch = TypedModelPatch::default();
294                let inputs = node
295                    .inputs
296                    .iter()
297                    .map(|&i| patch.tap_model(model, i))
298                    .collect::<TractResult<Vec<_>>>()?;
299                let wires = patch.wire_node(&*node.name, op, &inputs)?;
300                for oix in 0..node.outputs.len() {
301                    if oix != ix {
302                        patch.shunt_outside(
303                            model,
304                            OutletId::new(node.id, oix),
305                            wires[oix - (oix > ix) as usize],
306                        )?;
307                    }
308                }
309                return Ok(Some(patch));
310            }
311        }
312        Ok(None)
313    }
314
315    fn declutter_discard_empty_output_mapping_with_body_output(
316        &self,
317        _session: &mut OptimizerSession,
318        model: &TypedModel,
319        node: &TypedNode,
320    ) -> TractResult<Option<TypedModelPatch>> {
321        for (ix, om) in self.output_mapping.iter().enumerate() {
322            if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
323                let mut new_op = self.clone();
324                new_op.output_mapping.remove(ix);
325                new_op.body.outputs.remove(ix);
326                new_op.decluttered = false;
327                return Ok(Some(TypedModelPatch::replace_single_op(
328                    model,
329                    node,
330                    &node.inputs,
331                    new_op,
332                )?));
333            }
334        }
335        Ok(None)
336    }
337
338    fn declutter_pull_batcheable_input(
339        &self,
340        _session: &mut OptimizerSession,
341        model: &TypedModel,
342        node: &TypedNode,
343    ) -> TractResult<Option<TypedModelPatch>> {
344        'candidate: for (slot, input) in self.input_mapping.iter().enumerate() {
345            if let Some(scan_info) = input.as_scan() {
346                let scan_source = self.body.input_outlets()?[slot];
347                let scan_source_node = self.body.node(scan_source.node);
348                for mut succ in &scan_source_node.outputs[0].successors {
349                    for &succ_input in &self.body.node(succ.node).inputs {
350                        if succ_input != scan_source
351                            && self.body.outlet_fact(succ_input)?.konst.is_none()
352                        {
353                            continue 'candidate;
354                        }
355                    }
356                    if self.body.node(succ.node).outputs.len() != 1 {
357                        continue;
358                    }
359                    let mut new_body = self.body.clone();
360                    // insert propagate axis on einsum
361                    if let Some(einsum) = new_body.node(succ.node).op_as::<EinSum>()
362                        && let Some(patch) = einsum
363                            .propagate_axis(
364                                &new_body,
365                                new_body.node(succ.node),
366                                InOut::In(succ.slot),
367                                scan_info.axis,
368                            )
369                            .context("building axis propagating patch")?
370                    {
371                        patch.apply(&mut new_body)?;
372                        new_body.compute_const_facts()?;
373                        // propagate axis injects new nodes at the end. last successor of input
374                        // in new net will be the new succ
375                        let new_body_scan_input = new_body.input_outlets()?[slot];
376                        succ = new_body.node(new_body_scan_input.node).outputs[0]
377                            .successors
378                            .last()
379                            .unwrap();
380                    }
381
382                    let axes_mapping = {
383                        let (input_facts, output_facts) =
384                            new_body.node_facts(new_body.node(succ.node).id)?;
385                        new_body.node(succ.node).op.axes_mapping(&input_facts, &output_facts)?
386                    };
387                    let axis_info = axes_mapping.axis((InOut::In(succ.slot), scan_info.axis))?;
388                    if let &[axis_after] = &*axis_info.outputs[0] {
389                        let mut outside_patch = TypedModelPatch::new(format!(
390                            "Outer patch for input extraction of {}",
391                            new_body.node(succ.node)
392                        ));
393                        let mut patch_inputs = node
394                            .inputs
395                            .iter()
396                            .map(|&i| outside_patch.tap_model(model, i))
397                            .collect::<TractResult<TVec<_>>>()?;
398                        let mut extracted_op_inputs = tvec!();
399                        for (ix, outlet) in new_body.node(succ.node).inputs.iter().enumerate() {
400                            let wire = if ix == succ.slot {
401                                patch_inputs[slot]
402                            } else if let Some(konst) =
403                                new_body.outlet_fact(*outlet)?.konst.as_ref()
404                            {
405                                outside_patch.add_const(
406                                    format!(
407                                        "{}.extracted.{}",
408                                        node.name,
409                                        new_body.node(outlet.node).name
410                                    ),
411                                    konst.clone(),
412                                )?
413                            } else {
414                                unreachable!();
415                            };
416                            extracted_op_inputs.push(wire);
417                        }
418                        let new_input_wire = outside_patch.wire_node(
419                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
420                            new_body.node(succ.node).op.clone(),
421                            &extracted_op_inputs,
422                        )?[0];
423                        patch_inputs.push(new_input_wire);
424                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
425                        let mut new_input_inner_fact = new_input_outer_fact.clone();
426                        new_input_inner_fact.shape.set(axis_after, scan_info.chunk.abs().to_dim());
427
428                        let mut new_body = new_body.clone();
429                        let new_source_wire = new_body.add_source(
430                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
431                            new_input_inner_fact,
432                        )?;
433                        let mut inner_patch = TypedModelPatch::new(format!(
434                            "Inner body patch for extraction of {}",
435                            new_body.node(succ.node)
436                        ));
437                        let new_source_wire_in_patch =
438                            inner_patch.tap_model(&new_body, new_source_wire)?;
439                        inner_patch
440                            .shunt_outside(
441                                &new_body,
442                                OutletId::new(succ.node, 0),
443                                new_source_wire_in_patch,
444                            )
445                            .with_context(|| "patching inner model")?;
446                        inner_patch.apply(&mut new_body)?;
447
448                        let mut input_mapping = self.input_mapping.clone();
449                        input_mapping.push(InputMapping::Scan(ScanInfo {
450                            axis: axis_after,
451                            chunk: scan_info.chunk,
452                        }));
453
454                        let new_op = Self {
455                            input_mapping,
456                            decluttered: false,
457                            body: new_body,
458                            ..self.clone()
459                        };
460                        let output_wires =
461                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
462                        for w in output_wires {
463                            outside_patch
464                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
465                                .with_context(|| "patching outer model")?;
466                        }
467                        return Ok(Some(outside_patch));
468                    }
469                }
470            }
471        }
472        Ok(None)
473    }
474
475    fn declutter_pull_constant_outputs(
476        &self,
477        _session: &mut OptimizerSession,
478        model: &TypedModel,
479        node: &TypedNode,
480    ) -> TractResult<Option<TypedModelPatch>> {
481        for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
482            if let Some(slot) = mapping.last_value_slot
483                && let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone()
484            {
485                let inner_node = self.body.output_outlets()?[model_output_ix].node;
486                let inner_node = self.body.node(inner_node);
487                let mut patch = TypedModelPatch::new(format!("Extract const node {inner_node}"));
488                let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
489                patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
490                return Ok(Some(patch));
491            }
492        }
493        Ok(None)
494    }
495
496    fn declutter_pull_batcheable_output(
497        &self,
498        _session: &mut OptimizerSession,
499        model: &TypedModel,
500        node: &TypedNode,
501    ) -> TractResult<Option<TypedModelPatch>> {
502        for (mapping_ix, mapping) in self.output_mapping.iter().enumerate() {
503            if let Some((_, scan_info)) = mapping.scan {
504                let emitter_outlet = self.body.output_outlets()?[mapping_ix];
505                if self.body.node(emitter_outlet.node).outputs[emitter_outlet.slot].successors.len()
506                    > 0
507                    || self.body.inputs.contains(&emitter_outlet)
508                    || mapping.state
509                    || mapping.scan.map(|(_slot, i)| i.chunk > 1).unwrap_or(true)
510                {
511                    // continue if both last_value and full values are exported
512                    continue;
513                }
514                let mut new_body = self.body.clone();
515                if let Some(einsum) = new_body.node(emitter_outlet.node).op_as::<EinSum>()
516                    && let Some(patch) = einsum
517                        .propagate_axis(
518                            &new_body,
519                            new_body.node(emitter_outlet.node),
520                            InOut::Out(0),
521                            scan_info.axis,
522                        )
523                        .context("building axis propagating patch")?
524                {
525                    patch.apply(&mut new_body)?;
526                    new_body.prop_consts()?;
527                }
528                let emitter_outlet = new_body.output_outlets()?[mapping_ix];
529                let invariants = {
530                    let (input_facts, output_facts) = new_body.node_facts(emitter_outlet.node)?;
531                    new_body
532                        .node(emitter_outlet.node)
533                        .op
534                        .axes_mapping(&input_facts, &output_facts)?
535                };
536                let axis_tracking =
537                    invariants.axis((InOut::Out(emitter_outlet.slot), scan_info.axis))?;
538                rule_if!(axis_tracking.outputs.iter().all(|o| o.len() == 1));
539                let mut new_output_mapping = self.output_mapping.clone();
540                let mut new_scan_outputs = node.outputs.len();
541                let mut outer_slots = vec![];
542
543                // rewire input of the extracted node through the scan outlet boundary
544                for (input_slot, input) in
545                    new_body.node(emitter_outlet.node).inputs.clone().iter().enumerate()
546                {
547                    if new_body.outputs.iter().all(|o| o != input) {
548                        new_output_mapping.push(OutputMapping::default());
549                        new_body.outputs.push(*input);
550                    }
551                    let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
552                    let mapping = &mut new_output_mapping[body_output_id];
553                    let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
554                        if mapping.last_value_slot.is_none() {
555                            mapping.last_value_slot = Some(new_scan_outputs);
556                            new_scan_outputs += 1;
557                        }
558                        mapping.last_value_slot.unwrap()
559                    } else if let &[axis] = &*axis_tracking.inputs[input_slot] {
560                        if mapping.scan.is_none() {
561                            mapping.scan =
562                                Some((new_scan_outputs, ScanInfo { axis, chunk: scan_info.chunk }));
563                            new_scan_outputs += 1;
564                        }
565                        mapping.scan.unwrap().0
566                    } else {
567                        return Ok(None);
568                    };
569                    outer_slots.push(outer_slot);
570                }
571                let mut outside_patch = TypedModelPatch::new(format!(
572                    "Outside patch for output extraction of {}",
573                    new_body.node(emitter_outlet.node)
574                ));
575                let inputs = node
576                    .inputs
577                    .iter()
578                    .map(|&i| outside_patch.tap_model(model, i))
579                    .collect::<TractResult<TVec<_>>>()?;
580                let new_op = Self {
581                    output_mapping: new_output_mapping,
582                    decluttered: false,
583                    body: new_body.clone(), // FIXME maybe remove clone
584                    ..self.clone()
585                };
586                let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
587                let output = mapping.scan.unwrap();
588                let inputs =
589                    outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
590                let wire = outside_patch.wire_node(
591                    &new_body.node(emitter_outlet.node).name,
592                    new_body.node(emitter_outlet.node).op.clone(),
593                    &inputs,
594                )?[0];
595                outside_patch.shunt_outside(model, OutletId::new(node.id, output.0), wire)?;
596                for output_slot in 0..node.outputs.len() {
597                    if output_slot != output.0 {
598                        outside_patch.shunt_outside(
599                            model,
600                            OutletId::new(node.id, output_slot),
601                            OutletId::new(scan_outputs[0].node, output_slot),
602                        )?;
603                    }
604                }
605                return Ok(Some(outside_patch));
606            }
607        }
608        Ok(None)
609    }
610
611    fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
612        let input_state_outlets = self
613            .input_mapping
614            .iter()
615            .zip(self.body.input_outlets()?.iter())
616            .filter(|(m, _)| m.is_state())
617            .map(|(_, o)| o);
618        let output_state_outlets = self
619            .output_mapping
620            .iter()
621            .zip(self.body.output_outlets()?.iter())
622            .filter(|(m, _)| m.state)
623            .map(|(_, o)| o);
624        Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
625    }
626
627    fn body_locked_outlets(&self, node_input_facts: &[&TypedFact]) -> TractResult<TVec<OutletId>> {
628        let input_outlets =
629            self.body.input_outlets()?.iter().enumerate().filter_map(|(slot, o)| {
630                if node_input_facts[slot].konst.is_none() { Some(o) } else { None }
631            });
632        let output_outlets = self
633            .output_mapping
634            .iter()
635            .zip(self.body.output_outlets()?.iter())
636            .filter(|(m, _)| !m.invisible())
637            .map(|(_, o)| o);
638        Ok(input_outlets.chain(output_outlets).cloned().collect())
639    }
640
641    fn try_body_axes_change(
642        &self,
643        change: AxisChange,
644        locked_interface: bool,
645        node_input_facts: &[&TypedFact],
646    ) -> TractResult<Option<AxisChangeConsequence>> {
647        self.body.check_consistency()?;
648        let locked_outlets = self.body_locked_outlets(node_input_facts)?;
649        let mut explored: HashSet<AxisChange> = Default::default();
650        rule_if_some!(
651            (body_patch, body_changed_wires) = crate::optim::change_axes::change_axes(
652                &self.body,
653                &change,
654                if locked_interface { &locked_outlets } else { &[] },
655                &self.body_bounds()?,
656                &mut explored,
657            )?
658        );
659        let mut body = self.body.clone();
660        body_patch.apply(&mut body)?;
661        body.compact()?;
662        let mut wire_changes = tvec!();
663        let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
664        for (slot, m) in input_mapping.iter_mut().enumerate() {
665            if let Some(change) = body_changed_wires
666                .iter()
667                .find(|(iface, _change)| iface == &InOut::In(slot))
668                .map(|pair| pair.1.clone())
669            {
670                wire_changes.push((InOut::In(slot), change.clone()));
671                if let InputMapping::Scan(info) = m {
672                    rule_if_some!(axis = change.transform_axis(info.axis));
673                    info.axis = axis;
674                };
675            }
676        }
677        let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
678        for (ix, m) in output_mapping.iter_mut().enumerate() {
679            if let Some(change) = body_changed_wires
680                .iter()
681                .find(|(iface, _change)| iface == &InOut::Out(ix))
682                .map(|pair| pair.1.clone())
683            {
684                if let Some((slot, info)) = m.scan.as_mut() {
685                    rule_if_some!(new_axis = change.transform_axis(info.axis));
686                    info.axis = new_axis;
687                    wire_changes.push((InOut::Out(*slot), change.clone()));
688                }
689                if let Some(slot) = m.last_value_slot {
690                    wire_changes.push((InOut::Out(slot), change.clone()));
691                }
692            };
693        }
694        body.check_consistency()?;
695        let op = Some(Box::new(Scan {
696            body,
697            input_mapping,
698            output_mapping,
699            decluttered: false,
700            ..self.clone()
701        }) as _);
702        Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
703    }
704}
705
706impl Op for Scan {
707    fn name(&self) -> StaticName {
708        "Scan".into()
709    }
710
711    fn info(&self) -> TractResult<Vec<String>> {
712        let mut lines = vec![];
713        for (ix, im) in self.input_mapping.iter().enumerate() {
714            lines.push(format!("Model input  #{ix}: {im:?}"));
715        }
716        for (ix, om) in self.output_mapping.iter().enumerate() {
717            lines.push(format!("Model output #{ix}: {om:?}"));
718        }
719        lines.push(format!("skip:{} reset_every_turn:{:?}", self.skip, self.reset_every_turn));
720        Ok(lines)
721    }
722
723    fn validation(&self) -> Validation {
724        Validation::Rounding
725    }
726
727    op_as_typed_op!();
728}
729
730impl EvalOp for Scan {
731    fn is_stateless(&self) -> bool {
732        false
733    }
734    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
735        self.to_codegen_op(false)?.state(session, node_id)
736    }
737}
738
739impl TypedOp for Scan {
740    as_op!();
741
742    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
743        anyhow::ensure!(inputs.len() == self.body.inputs.len());
744        anyhow::ensure!(self.input_mapping.len() == self.body.inputs.len());
745        anyhow::ensure!(
746            self.input_mapping.iter().filter(|m| m.is_state()).count()
747                == self.output_mapping.iter().filter(|m| m.state).count()
748        );
749        for (i, o) in
750            self.input_mapping.iter().enumerate().filter(|(_, m)| m.is_state()).map(|(i, _)| i).zip(
751                self.output_mapping.iter().enumerate().filter(|(_, m)| m.state).map(|(o, _)| o),
752            )
753        {
754            let ifact = self.body.outlet_fact(self.body.inputs[i])?;
755            let ofact = self.body.outlet_fact(self.body.outputs[o])?;
756            anyhow::ensure!(
757                ifact == ofact,
758                "inconsistent fact: body input {i} is {ifact:?} and body output {o} is {ofact:?}\n{}",
759                self.body
760            )
761        }
762        let mut outputs = tvec!();
763        let iters = super::iteration_count(&self.input_mapping, inputs).context("No scan input")?;
764        for (ix, output) in self.output_mapping.iter().enumerate() {
765            let fact = self.body.output_fact(ix)?;
766            if let Some((slot, info)) = output.scan {
767                let mut shape = fact.shape.clone();
768                let scanning_dim =
769                    output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
770                shape.set(info.axis, scanning_dim);
771                outputs.push((slot, fact.datum_type.fact(shape)));
772            }
773            if let Some(slot) = output.last_value_slot {
774                outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
775            }
776        }
777        outputs.sort_by_key(|a| a.0);
778        anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
779        let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
780        Ok(outputs)
781    }
782
783    fn axes_mapping(
784        &self,
785        inputs: &[&TypedFact],
786        outputs: &[&TypedFact],
787    ) -> TractResult<AxesMapping> {
788        let mut mappings = vec![];
789        let body_invs = self.body.axes_mapping().with_context(|| "Computing body axes mapping")?;
790        for body_axis in body_invs.iter_all_axes() {
791            let mut info = Axis::new(body_axis.repr, inputs.len(), outputs.len());
792            info.inputs.clone_from(&body_axis.inputs);
793            for (ix, output_mapping) in self.output_mapping.iter().enumerate() {
794                let mut slots = vec![];
795                if let Some((slot, _scan)) = output_mapping.scan {
796                    slots.push(slot);
797                }
798                if let Some(slot) = output_mapping.last_value_slot {
799                    slots.push(slot);
800                }
801                for slot in slots {
802                    info.outputs[slot].clone_from(&body_axis.outputs[ix]);
803                }
804            }
805            if info.inputs.iter().any(|i| i.len() > 0) || info.outputs.iter().any(|i| i.len() > 0) {
806                mappings.push(info);
807            }
808        }
809        AxesMapping::new(inputs.len(), outputs.len(), mappings)
810    }
811
812    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
813        let mut suggestions = tvec!();
814        for (slot, input) in self.input_mapping.iter().enumerate() {
815            if let InputMapping::Scan(info) = input
816                && info.axis != 0
817            {
818                suggestions.push((InOut::In(slot), AxisOp::Move(info.axis, 0)))
819            }
820        }
821        for output in &self.output_mapping {
822            if let Some((slot, scan)) = output.scan
823                && scan.axis != 0
824            {
825                suggestions.push((InOut::Out(slot), AxisOp::Move(scan.axis, 0)))
826            }
827        }
828        Ok(suggestions)
829    }
830
831    fn change_axes(
832        &self,
833        model: &TypedModel,
834        node: &TypedNode,
835        io: InOut,
836        change: &AxisOp,
837    ) -> TractResult<Option<AxisChangeConsequence>> {
838        trace!("Propagating through {node}: {io:?} {change:?}");
839        let body_leading_outlet = match io {
840            InOut::In(ix) => self.body.input_outlets()?[ix],
841            InOut::Out(slot) => {
842                let output = self
843                    .output_mapping
844                    .iter()
845                    .position(|im| {
846                        im.scan.map(|(slot, _i)| slot) == Some(slot)
847                            || im.last_value_slot == Some(slot)
848                    })
849                    .unwrap();
850                self.body.output_outlets()?[output]
851            }
852        };
853        let axis_change = AxisChange { outlet: body_leading_outlet, op: change.clone() };
854        let node_input_facts = model.node_input_facts(node.id)?;
855        let result = self
856            .try_body_axes_change(axis_change, false, &node_input_facts)
857            .with_context(|| "Attemping to run change through scan body".to_string())?;
858        if result.is_some() {
859            trace!("{node} accepted axis change");
860        } else {
861            trace!("{node} rejected axis change");
862        }
863        Ok(result)
864    }
865
866    fn declutter_with_session(
867        &self,
868        session: &mut OptimizerSession,
869        model: &TypedModel,
870        node: &TypedNode,
871    ) -> TractResult<Option<TypedModelPatch>> {
872        macro_rules! pass {
873            ($func:ident) => {
874                if let Some(mut r) = self
875                    .$func(session, model, node)
876                    .with_context(|| format!("{}", stringify!($func)))?
877                {
878                    trace!(stringify!($func));
879                    r.push_context(stringify!($func));
880                    return Ok(Some(r));
881                }
882            };
883        }
884        pass!(declutter_single_loop);
885        pass!(declutter_const_input);
886        pass!(declutter_discard_unused_input);
887        pass!(declutter_discard_useless_outer_output);
888        pass!(declutter_discard_empty_output_mapping_with_body_output);
889        pass!(declutter_body);
890        pass!(declutter_body_axes);
891        pass!(declutter_pull_constant_outputs);
892        pass!(declutter_pull_batcheable_input);
893        pass!(declutter_pull_batcheable_output);
894        Ok(None)
895    }
896
897    fn set_symbols(
898        &self,
899        _source: &TypedModel,
900        node: &TypedNode,
901        target: &mut TypedModel,
902        mapping: &HashMap<OutletId, OutletId>,
903        subs: &HashMap<Symbol, TDim>,
904    ) -> TractResult<TVec<OutletId>> {
905        let inputs = node.inputs.iter().map(|o| mapping[o]).collect::<TVec<_>>();
906        let op = Self {
907            output_mapping: self
908                .output_mapping
909                .iter()
910                .map(|om| om.set_symbols(subs))
911                .collect::<TractResult<Vec<_>>>()?,
912            body: self.body.set_symbols(subs)?,
913            ..self.clone()
914        };
915        target.wire_node(&node.name, op, &inputs)
916    }
917
918    fn codegen(
919        &self,
920        model: &TypedModel,
921        node: &TypedNode,
922    ) -> TractResult<Option<TypedModelPatch>> {
923        Ok(Some(TypedModelPatch::replace_single_op(
924            model,
925            node,
926            &node.inputs,
927            self.to_codegen_op(true)?,
928        )?))
929    }
930}