tract_core/ops/scan/
decluttered.rs

1use std::collections::HashSet;
2
3use crate::ops::einsum::EinSum;
4use crate::ops::konst::Const;
5use crate::optim::OptimizerSession;
6
7use super::optimized::{OptScan, ScanOpParams};
8use tract_data::internal::*;
9
10use super::*;
11
12#[derive(Debug, Clone, Default)]
13pub struct Scan {
14    pub skip: usize,
15    pub reset_every_turn: bool,
16    pub body: TypedModel,
17    pub decluttered: bool,
18    pub input_mapping: Vec<InputMapping>,
19    pub output_mapping: Vec<OutputMapping<TDim>>,
20}
21
22impl Scan {
23    pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<OptScan> {
24        let mut model = self.body.clone();
25        if optimize_inner {
26            model = model.into_optimized()?;
27        }
28        let plan = SimplePlan::new(model)?;
29
30        Ok(OptScan::new(Arc::new(ScanOpParams::new(
31            self.skip,
32            self.reset_every_turn,
33            Arc::new(plan),
34            self.input_mapping.clone(),
35            self.output_mapping.clone(),
36        ))))
37    }
38
39    pub fn new(
40        body: TypedModel,
41        input_mapping: Vec<InputMapping>,
42        output_mapping: Vec<OutputMapping<TDim>>,
43        skip: usize,
44    ) -> TractResult<Scan> {
45        body.check_consistency()?;
46        ensure!(input_mapping.len() == body.input_outlets()?.len());
47        ensure!(output_mapping.len() == body.output_outlets()?.len());
48        Ok(Scan {
49            skip,
50            reset_every_turn: false,
51            body,
52            decluttered: false,
53            input_mapping,
54            output_mapping,
55        })
56    }
57
58    pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
59        self.to_codegen_op(false).unwrap().iteration_count(inputs)
60    }
61
62    fn declutter_body(
63        &self,
64        session: &mut OptimizerSession,
65        model: &TypedModel,
66        node: &TypedNode,
67    ) -> TractResult<Option<TypedModelPatch>> {
68        if !self.decluttered {
69            let mut new = self.clone();
70            let mut body = self.body.clone();
71            session.optimize(&mut body)?;
72            new.body = body;
73            new.decluttered = true;
74            Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
75        } else {
76            Ok(None)
77        }
78    }
79
80    fn declutter_body_axes(
81        &self,
82        _session: &mut OptimizerSession,
83        model: &TypedModel,
84        node: &TypedNode,
85    ) -> TractResult<Option<TypedModelPatch>> {
86        let mut suggestions = vec![];
87        for n in self.body.eval_order()? {
88            let node = self.body.node(n);
89            for suggestion in node.op.suggested_axis_changes()? {
90                let outlet = suggestion.0.as_outlet(node);
91                suggestions.push(AxisChange { outlet, op: suggestion.1 })
92            }
93            for (slot, fact) in node.outputs.iter().enumerate() {
94                for (ix, dim) in fact.fact.shape.iter().enumerate() {
95                    if dim.is_one() {
96                        suggestions.push(AxisChange {
97                            outlet: OutletId::new(n, slot),
98                            op: AxisOp::Rm(ix),
99                        });
100                    }
101                }
102            }
103        }
104        let node_input_facts = model.node_input_facts(node.id)?;
105        for suggestion in suggestions.into_iter() {
106            if let Some(conseq) = self.try_body_axes_change(suggestion, true, &node_input_facts)? {
107                let mut patch = TypedModelPatch::default();
108                let mut inputs = tvec!();
109                for outlet in &node.inputs {
110                    inputs.push(patch.tap_model(model, *outlet)?);
111                }
112                for change in conseq.wire_changes {
113                    if let InOut::In(i) = change.0 {
114                        let mut value = patch
115                            .outlet_fact(inputs[i])?
116                            .konst
117                            .clone()
118                            .context("Will only reshape constants")?
119                            .into_tensor();
120                        change.1.change_tensor(&mut value, false)?;
121                        let konst_name = patch.node(inputs[i].node).name.clone();
122                        inputs[i] = patch.add_const(konst_name, value)?;
123                    }
124                }
125                let wires = patch.wire_node(
126                    &node.name,
127                    conseq.substitute_op.unwrap_or_else(|| Box::new(self.clone())),
128                    &inputs,
129                )?;
130                for (ix, new) in wires.into_iter().enumerate() {
131                    patch.shunt_outside(model, OutletId::new(node.id, ix), new)?;
132                }
133                return Ok(Some(patch));
134            }
135        }
136        Ok(None)
137    }
138
139    fn remove_outer_output_from_mappings(
140        mappings: &[OutputMapping<TDim>],
141        discarded: usize,
142    ) -> Vec<OutputMapping<TDim>> {
143        mappings
144            .iter()
145            .map(|m| OutputMapping {
146                scan: m.scan.map(|(slot, info)| (slot - (slot > discarded) as usize, info)),
147                last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
148                full_dim_hint: m.full_dim_hint.clone(),
149                state: m.state,
150            })
151            .collect()
152    }
153
154    fn declutter_const_input(
155        &self,
156        _session: &mut OptimizerSession,
157        model: &TypedModel,
158        node: &TypedNode,
159    ) -> TractResult<Option<TypedModelPatch>> {
160        let inputs = model.node_input_facts(node.id)?;
161        for (slot, mapping) in self.input_mapping.iter().enumerate() {
162            if let InputMapping::Full = mapping {
163                if let Some(konst) = inputs[slot].konst.as_ref() {
164                    let mut op = self.clone();
165                    let src = op.body.inputs[slot];
166                    op.body.inputs.remove(slot);
167                    op.body.nodes[src.node].inputs.clear();
168                    op.body.nodes[src.node].op = Box::new(Const::new(konst.clone()));
169                    op.input_mapping.remove(slot);
170                    let mut inputs = node.inputs.clone();
171                    inputs.remove(slot);
172                    return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
173                }
174            }
175        }
176        Ok(None)
177    }
178
179    fn declutter_discard_unused_input(
180        &self,
181        _session: &mut OptimizerSession,
182        model: &TypedModel,
183        node: &TypedNode,
184    ) -> TractResult<Option<TypedModelPatch>> {
185        for (slot, input) in self.body.input_outlets()?.iter().enumerate() {
186            let source_node = self.body.node(input.node);
187            if source_node.outputs[0].successors.len() == 0
188                && !self.body.output_outlets()?.contains(input)
189            {
190                let mut new_inputs = node.inputs.clone();
191                new_inputs.remove(slot);
192                let mut new_mappings: Vec<_> = self.input_mapping.clone();
193                new_mappings.remove(slot);
194                let mut model_inputs = self.body.input_outlets()?.to_vec();
195                model_inputs.remove(slot);
196                let mut body = self.body.clone();
197                let mut patch = TypedModelPatch::default();
198                patch.obliterate(source_node.id)?;
199                patch.apply(&mut body)?;
200                body.set_input_outlets(&model_inputs)?;
201                body.declutter()?;
202                let op =
203                    Self { body, input_mapping: new_mappings, decluttered: true, ..self.clone() };
204                return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
205            }
206        }
207        Ok(None)
208    }
209
210    fn declutter_discard_useless_outer_output(
211        &self,
212        _session: &mut OptimizerSession,
213        model: &TypedModel,
214        node: &TypedNode,
215    ) -> TractResult<Option<TypedModelPatch>> {
216        for (ix, o) in node.outputs.iter().enumerate() {
217            if o.successors.len() == 0
218                && !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
219            {
220                let mappings = self
221                    .output_mapping
222                    .iter()
223                    .map(|m| OutputMapping {
224                        scan: m.scan.filter(|(slot, _info)| *slot != ix),
225                        last_value_slot: m.last_value_slot.filter(|s| *s != ix),
226                        full_dim_hint: m.full_dim_hint.clone(),
227                        state: m.state,
228                    })
229                    .collect::<Vec<_>>();
230                let mut op = self.clone();
231                op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
232                let mut patch = TypedModelPatch::default();
233                let inputs = node
234                    .inputs
235                    .iter()
236                    .map(|&i| patch.tap_model(model, i))
237                    .collect::<TractResult<Vec<_>>>()?;
238                let wires = patch.wire_node(&*node.name, op, &inputs)?;
239                for oix in 0..node.outputs.len() {
240                    if oix != ix {
241                        patch.shunt_outside(
242                            model,
243                            OutletId::new(node.id, oix),
244                            wires[oix - (oix > ix) as usize],
245                        )?;
246                    }
247                }
248                return Ok(Some(patch));
249            }
250        }
251        Ok(None)
252    }
253
254    fn declutter_discard_empty_output_mapping_with_body_output(
255        &self,
256        _session: &mut OptimizerSession,
257        model: &TypedModel,
258        node: &TypedNode,
259    ) -> TractResult<Option<TypedModelPatch>> {
260        for (ix, om) in self.output_mapping.iter().enumerate() {
261            if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
262                let mut new_op = self.clone();
263                new_op.output_mapping.remove(ix);
264                new_op.body.outputs.remove(ix);
265                new_op.decluttered = false;
266                return Ok(Some(TypedModelPatch::replace_single_op(
267                    model,
268                    node,
269                    &node.inputs,
270                    new_op,
271                )?));
272            }
273        }
274        Ok(None)
275    }
276
277    fn declutter_pull_batcheable_input(
278        &self,
279        _session: &mut OptimizerSession,
280        model: &TypedModel,
281        node: &TypedNode,
282    ) -> TractResult<Option<TypedModelPatch>> {
283        'candidate: for (slot, input) in self.input_mapping.iter().enumerate() {
284            if let Some(scan_info) = input.as_scan() {
285                let scan_source = self.body.input_outlets()?[slot];
286                let scan_source_node = self.body.node(scan_source.node);
287                for mut succ in &scan_source_node.outputs[0].successors {
288                    for &succ_input in &self.body.node(succ.node).inputs {
289                        if succ_input != scan_source
290                            && self.body.outlet_fact(succ_input)?.konst.is_none()
291                        {
292                            continue 'candidate;
293                        }
294                    }
295                    if self.body.node(succ.node).outputs.len() != 1 {
296                        continue;
297                    }
298                    let mut new_body = self.body.clone();
299                    // insert propagate axis on einsum
300                    if let Some(einsum) = new_body.node(succ.node).op_as::<EinSum>() {
301                        if let Some(patch) = einsum
302                            .propagate_axis(
303                                &new_body,
304                                new_body.node(succ.node),
305                                InOut::In(succ.slot),
306                                scan_info.axis,
307                            )
308                            .context("building axis propagating patch")?
309                        {
310                            patch.apply(&mut new_body)?;
311                            new_body.compute_const_facts()?;
312                            // propagate axis injects new nodes at the end. last successor of input
313                            // in new net will be the new succ
314                            let new_body_scan_input = new_body.input_outlets()?[slot];
315                            succ = new_body.node(new_body_scan_input.node).outputs[0]
316                                .successors
317                                .last()
318                                .unwrap();
319                        }
320                    }
321
322                    let axes_mapping = {
323                        let (input_facts, output_facts) =
324                            new_body.node_facts(new_body.node(succ.node).id)?;
325                        new_body.node(succ.node).op.axes_mapping(&input_facts, &output_facts)?
326                    };
327                    let axis_info = axes_mapping.axis((InOut::In(succ.slot), scan_info.axis))?;
328                    if let &[axis_after] = &*axis_info.outputs[0] {
329                        let mut outside_patch = TypedModelPatch::new(format!(
330                            "Outer patch for input extraction of {}",
331                            new_body.node(succ.node)
332                        ));
333                        let mut patch_inputs = node
334                            .inputs
335                            .iter()
336                            .map(|&i| outside_patch.tap_model(model, i))
337                            .collect::<TractResult<TVec<_>>>()?;
338                        let mut extracted_op_inputs = tvec!();
339                        for (ix, outlet) in new_body.node(succ.node).inputs.iter().enumerate() {
340                            let wire = if ix == succ.slot {
341                                patch_inputs[slot]
342                            } else if let Some(konst) =
343                                new_body.outlet_fact(*outlet)?.konst.as_ref()
344                            {
345                                outside_patch.add_const(
346                                    format!(
347                                        "{}.extracted.{}",
348                                        node.name,
349                                        new_body.node(outlet.node).name
350                                    ),
351                                    konst.clone(),
352                                )?
353                            } else {
354                                unreachable!();
355                            };
356                            extracted_op_inputs.push(wire);
357                        }
358                        let new_input_wire = outside_patch.wire_node(
359                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
360                            new_body.node(succ.node).op.clone(),
361                            &extracted_op_inputs,
362                        )?[0];
363                        patch_inputs.push(new_input_wire);
364                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
365                        let mut new_input_inner_fact = new_input_outer_fact.clone();
366                        new_input_inner_fact.shape.set(axis_after, scan_info.chunk.abs().to_dim());
367
368                        let mut new_body = new_body.clone();
369                        let new_source_wire = new_body.add_source(
370                            format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
371                            new_input_inner_fact,
372                        )?;
373                        let mut inner_patch = TypedModelPatch::new(format!(
374                            "Inner body patch for extraction of {}",
375                            new_body.node(succ.node)
376                        ));
377                        let new_source_wire_in_patch =
378                            inner_patch.tap_model(&new_body, new_source_wire)?;
379                        inner_patch
380                            .shunt_outside(
381                                &new_body,
382                                OutletId::new(succ.node, 0),
383                                new_source_wire_in_patch,
384                            )
385                            .with_context(|| "patching inner model")?;
386                        inner_patch.apply(&mut new_body)?;
387
388                        let mut input_mapping = self.input_mapping.clone();
389                        input_mapping.push(InputMapping::Scan(ScanInfo {
390                            axis: axis_after,
391                            chunk: scan_info.chunk,
392                        }));
393
394                        let new_op = Self {
395                            input_mapping,
396                            decluttered: false,
397                            body: new_body,
398                            ..self.clone()
399                        };
400                        let output_wires =
401                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
402                        for w in output_wires {
403                            outside_patch
404                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
405                                .with_context(|| "patching outer model")?;
406                        }
407                        return Ok(Some(outside_patch));
408                    }
409                }
410            }
411        }
412        Ok(None)
413    }
414
415    fn declutter_pull_constant_outputs(
416        &self,
417        _session: &mut OptimizerSession,
418        model: &TypedModel,
419        node: &TypedNode,
420    ) -> TractResult<Option<TypedModelPatch>> {
421        for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
422            if let Some(slot) = mapping.last_value_slot {
423                if let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone() {
424                    let inner_node = self.body.output_outlets()?[model_output_ix].node;
425                    let inner_node = self.body.node(inner_node);
426                    let mut patch =
427                        TypedModelPatch::new(format!("Extract const node {inner_node}"));
428                    let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
429                    patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
430                    return Ok(Some(patch));
431                }
432            }
433        }
434        Ok(None)
435    }
436
437    fn declutter_pull_batcheable_output(
438        &self,
439        _session: &mut OptimizerSession,
440        model: &TypedModel,
441        node: &TypedNode,
442    ) -> TractResult<Option<TypedModelPatch>> {
443        for (mapping_ix, mapping) in self.output_mapping.iter().enumerate() {
444            if let Some((_, scan_info)) = mapping.scan {
445                let emitter_outlet = self.body.output_outlets()?[mapping_ix];
446                if self.body.node(emitter_outlet.node).outputs[emitter_outlet.slot].successors.len()
447                    > 0
448                    || self.body.inputs.contains(&emitter_outlet)
449                    || mapping.state
450                    || mapping.scan.map(|(_slot, i)| i.chunk > 1).unwrap_or(true)
451                {
452                    // continue if both last_value and full values are exported
453                    continue;
454                }
455                let mut new_body = self.body.clone();
456                if let Some(einsum) = new_body.node(emitter_outlet.node).op_as::<EinSum>() {
457                    if let Some(patch) = einsum
458                        .propagate_axis(
459                            &new_body,
460                            new_body.node(emitter_outlet.node),
461                            InOut::Out(0),
462                            scan_info.axis,
463                        )
464                        .context("building axis propagating patch")?
465                    {
466                        patch.apply(&mut new_body)?;
467                        new_body.prop_consts()?;
468                    }
469                }
470                let emitter_outlet = new_body.output_outlets()?[mapping_ix];
471                let invariants = {
472                    let (input_facts, output_facts) = new_body.node_facts(emitter_outlet.node)?;
473                    new_body
474                        .node(emitter_outlet.node)
475                        .op
476                        .axes_mapping(&input_facts, &output_facts)?
477                };
478                let axis_tracking =
479                    invariants.axis((InOut::Out(emitter_outlet.slot), scan_info.axis))?;
480                if axis_tracking.outputs.iter().any(|o| o.len() > 1) {
481                    return Ok(None);
482                }
483                let mut new_output_mapping = self.output_mapping.clone();
484                let mut new_scan_outputs = node.outputs.len();
485                let mut outer_slots = vec![];
486
487                // rewire input of the extracted node through the scan outlet boundary
488                for (input_slot, input) in
489                    new_body.node(emitter_outlet.node).inputs.clone().iter().enumerate()
490                {
491                    if new_body.outputs.iter().all(|o| o != input) {
492                        new_output_mapping.push(OutputMapping::default());
493                        new_body.outputs.push(*input);
494                    }
495                    let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
496                    let mapping = &mut new_output_mapping[body_output_id];
497                    let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
498                        if mapping.last_value_slot.is_none() {
499                            mapping.last_value_slot = Some(new_scan_outputs);
500                            new_scan_outputs += 1;
501                        }
502                        mapping.last_value_slot.unwrap()
503                    } else if let &[axis] = &*axis_tracking.inputs[input_slot] {
504                        if mapping.scan.is_none() {
505                            mapping.scan =
506                                Some((new_scan_outputs, ScanInfo { axis, chunk: scan_info.chunk }));
507                            new_scan_outputs += 1;
508                        }
509                        mapping.scan.unwrap().0
510                    } else {
511                        return Ok(None);
512                    };
513                    outer_slots.push(outer_slot);
514                }
515                let mut outside_patch = TypedModelPatch::new(format!(
516                    "Outside patch for output extraction of {}",
517                    new_body.node(emitter_outlet.node)
518                ));
519                let inputs = node
520                    .inputs
521                    .iter()
522                    .map(|&i| outside_patch.tap_model(model, i))
523                    .collect::<TractResult<TVec<_>>>()?;
524                let new_op = Self {
525                    output_mapping: new_output_mapping,
526                    decluttered: false,
527                    body: new_body.clone(), // FIXME maybe remove clone
528                    ..self.clone()
529                };
530                let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
531                let output = mapping.scan.unwrap();
532                let inputs =
533                    outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
534                let wire = outside_patch.wire_node(
535                    &new_body.node(emitter_outlet.node).name,
536                    new_body.node(emitter_outlet.node).op.clone(),
537                    &inputs,
538                )?[0];
539                outside_patch.shunt_outside(model, OutletId::new(node.id, output.0), wire)?;
540                for output_slot in 0..node.outputs.len() {
541                    if output_slot != output.0 {
542                        outside_patch.shunt_outside(
543                            model,
544                            OutletId::new(node.id, output_slot),
545                            OutletId::new(scan_outputs[0].node, output_slot),
546                        )?;
547                    }
548                }
549                return Ok(Some(outside_patch));
550            }
551        }
552        Ok(None)
553    }
554
555    fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
556        let input_state_outlets = self
557            .input_mapping
558            .iter()
559            .zip(self.body.input_outlets()?.iter())
560            .filter(|(m, _)| m.is_state())
561            .map(|(_, o)| o);
562        let output_state_outlets = self
563            .output_mapping
564            .iter()
565            .zip(self.body.output_outlets()?.iter())
566            .filter(|(m, _)| m.state)
567            .map(|(_, o)| o);
568        Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
569    }
570
571    fn body_locked_outlets(&self, node_input_facts: &[&TypedFact]) -> TractResult<TVec<OutletId>> {
572        let input_outlets =
573            self.body.input_outlets()?.iter().enumerate().filter_map(|(slot, o)| {
574                if node_input_facts[slot].konst.is_none() {
575                    Some(o)
576                } else {
577                    None
578                }
579            });
580        let output_outlets = self
581            .output_mapping
582            .iter()
583            .zip(self.body.output_outlets()?.iter())
584            .filter(|(m, _)| !m.invisible())
585            .map(|(_, o)| o);
586        Ok(input_outlets.chain(output_outlets).cloned().collect())
587    }
588
589    fn try_body_axes_change(
590        &self,
591        change: AxisChange,
592        locked_interface: bool,
593        node_input_facts: &[&TypedFact],
594    ) -> TractResult<Option<AxisChangeConsequence>> {
595        self.body.check_consistency()?;
596        let locked_outlets = self.body_locked_outlets(node_input_facts)?;
597        let mut explored: HashSet<AxisChange> = Default::default();
598        let (body_patch, body_changed_wires) = if let Some(changes) =
599            crate::optim::change_axes::change_axes(
600                &self.body,
601                &change,
602                if locked_interface { &locked_outlets } else { &[] },
603                &self.body_bounds()?,
604                &mut explored
605            )? {
606            changes
607        } else {
608            return Ok(None);
609        };
610        let mut body = self.body.clone();
611        body_patch.apply(&mut body)?;
612        body.compact()?;
613        let mut wire_changes = tvec!();
614        let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
615        for (slot, m) in input_mapping.iter_mut().enumerate() {
616            if let Some(change) = body_changed_wires
617                .iter()
618                .find(|(iface, _change)| iface == &InOut::In(slot))
619                .map(|pair| pair.1.clone())
620            {
621                wire_changes.push((InOut::In(slot), change.clone()));
622                if let InputMapping::Scan(info) = m {
623                    if let Some(axis) = change.transform_axis(info.axis) {
624                        info.axis = axis;
625                    } else {
626                        return Ok(None);
627                    };
628                };
629            }
630        }
631        let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
632        for (ix, m) in output_mapping.iter_mut().enumerate() {
633            if let Some(change) = body_changed_wires
634                .iter()
635                .find(|(iface, _change)| iface == &InOut::Out(ix))
636                .map(|pair| pair.1.clone())
637            {
638                if let Some((slot, info)) = m.scan.as_mut() {
639                    if let Some(new_axis) = change.transform_axis(info.axis) {
640                        info.axis = new_axis;
641                    } else {
642                        return Ok(None);
643                    }
644                    wire_changes.push((InOut::Out(*slot), change.clone()));
645                }
646                if let Some(slot) = m.last_value_slot {
647                    wire_changes.push((InOut::Out(slot), change.clone()));
648                }
649            };
650        }
651        body.check_consistency()?;
652        let op = Some(Box::new(Scan {
653            body,
654            input_mapping,
655            output_mapping,
656            decluttered: false,
657            ..self.clone()
658        }) as _);
659        Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
660    }
661}
662
663impl Op for Scan {
664    fn name(&self) -> Cow<str> {
665        "Scan".into()
666    }
667
668    fn info(&self) -> TractResult<Vec<String>> {
669        let mut lines = vec![];
670        for (ix, im) in self.input_mapping.iter().enumerate() {
671            lines.push(format!("Model input  #{ix}: {im:?}"));
672        }
673        for (ix, om) in self.output_mapping.iter().enumerate() {
674            lines.push(format!("Model output #{ix}: {om:?}"));
675        }
676        lines.push(format!("skip:{} reset_every_turn:{:?}", self.skip, self.reset_every_turn));
677        Ok(lines)
678    }
679
680    fn validation(&self) -> Validation {
681        Validation::Rounding
682    }
683
684    op_as_typed_op!();
685}
686
687impl EvalOp for Scan {
688    fn is_stateless(&self) -> bool {
689        false
690    }
691    fn state(
692        &self,
693        session: &mut SessionState,
694        node_id: usize,
695    ) -> TractResult<Option<Box<dyn OpState>>> {
696        self.to_codegen_op(false)?.state(session, node_id)
697    }
698}
699
700impl TypedOp for Scan {
701    as_op!();
702
703    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
704        anyhow::ensure!(inputs.len() == self.body.inputs.len());
705        anyhow::ensure!(self.input_mapping.len() == self.body.inputs.len());
706        anyhow::ensure!(
707            self.input_mapping.iter().filter(|m| m.is_state()).count()
708                == self.output_mapping.iter().filter(|m| m.state).count()
709        );
710        for (i, o) in
711            self.input_mapping.iter().enumerate().filter(|(_, m)| m.is_state()).map(|(i, _)| i).zip(
712                self.output_mapping.iter().enumerate().filter(|(_, m)| m.state).map(|(o, _)| o),
713            )
714        {
715            let ifact = self.body.outlet_fact(self.body.inputs[i])?;
716            let ofact = self.body.outlet_fact(self.body.outputs[o])?;
717            anyhow::ensure!(ifact == ofact,
718                "inconsistent state shape: body input {i} is {ifact:?} and body output {o} is {ofact:?}",
719            )
720        }
721        let mut outputs = tvec!();
722        let iters = super::iteration_count(&self.input_mapping, inputs).context("No scan input")?;
723        for (ix, output) in self.output_mapping.iter().enumerate() {
724            let fact = self.body.output_fact(ix)?;
725            if let Some((slot, info)) = output.scan {
726                let mut shape = fact.shape.clone();
727                let scanning_dim =
728                    output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
729                shape.set(info.axis, scanning_dim);
730                outputs.push((slot, fact.datum_type.fact(shape)));
731            }
732            if let Some(slot) = output.last_value_slot {
733                outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
734            }
735        }
736        outputs.sort_by_key(|a| a.0);
737        anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
738        let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
739        Ok(outputs)
740    }
741
742    fn axes_mapping(
743        &self,
744        inputs: &[&TypedFact],
745        outputs: &[&TypedFact],
746    ) -> TractResult<AxesMapping> {
747        let mut mappings = vec![];
748        let body_invs = self.body.axes_mapping().with_context(|| "Computing body axes mapping")?;
749        for body_axis in body_invs.iter_all_axes() {
750            let mut info = Axis::new(body_axis.repr, inputs.len(), outputs.len());
751            info.inputs.clone_from(&body_axis.inputs);
752            for (ix, output_mapping) in self.output_mapping.iter().enumerate() {
753                let mut slots = vec![];
754                if let Some((slot, _scan)) = output_mapping.scan {
755                    slots.push(slot);
756                }
757                if let Some(slot) = output_mapping.last_value_slot {
758                    slots.push(slot);
759                }
760                for slot in slots {
761                    info.outputs[slot].clone_from(&body_axis.outputs[ix]);
762                }
763            }
764            if info.inputs.iter().any(|i| i.len() > 0) || info.outputs.iter().any(|i| i.len() > 0) {
765                mappings.push(info);
766            }
767        }
768        AxesMapping::new(inputs.len(), outputs.len(), mappings)
769    }
770
771    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
772        let mut suggestions = tvec!();
773        for (slot, input) in self.input_mapping.iter().enumerate() {
774            if let InputMapping::Scan(info) = input {
775                if info.axis != 0 {
776                    suggestions.push((InOut::In(slot), AxisOp::Move(info.axis, 0)))
777                }
778            }
779        }
780        for output in &self.output_mapping {
781            if let Some((slot, scan)) = output.scan {
782                if scan.axis != 0 {
783                    suggestions.push((InOut::Out(slot), AxisOp::Move(scan.axis, 0)))
784                }
785            }
786        }
787        Ok(suggestions)
788    }
789
790    fn change_axes(
791        &self,
792        model: &TypedModel,
793        node: &TypedNode,
794        io: InOut,
795        change: &AxisOp,
796    ) -> TractResult<Option<AxisChangeConsequence>> {
797        trace!("Propagating through {}: {:?} {:?}", node, io, change);
798        let body_leading_outlet = match io {
799            InOut::In(ix) => self.body.input_outlets()?[ix],
800            InOut::Out(slot) => {
801                let output = self
802                    .output_mapping
803                    .iter()
804                    .position(|im| {
805                        im.scan.map(|(slot, _i)| slot) == Some(slot)
806                            || im.last_value_slot == Some(slot)
807                    })
808                    .unwrap();
809                self.body.output_outlets()?[output]
810            }
811        };
812        let axis_change = AxisChange { outlet: body_leading_outlet, op: change.clone() };
813        let node_input_facts = model.node_input_facts(node.id)?;
814        let result = self
815            .try_body_axes_change(axis_change, false, &node_input_facts)
816            .with_context(|| "Attemping to run change through scan body".to_string())?;
817        if result.is_some() {
818            trace!("{} accepted axis change", node);
819        } else {
820            trace!("{} rejected axis change", node);
821        }
822        Ok(result)
823    }
824
825    fn declutter_with_session(
826        &self,
827        session: &mut OptimizerSession,
828        model: &TypedModel,
829        node: &TypedNode,
830    ) -> TractResult<Option<TypedModelPatch>> {
831        macro_rules! pass {
832            ($func:ident) => {
833                if let Some(mut r) = self
834                    .$func(session, model, node)
835                    .with_context(|| format!("{}", stringify!($func)))?
836                {
837                    trace!(stringify!($func));
838                    r.push_context(stringify!($func));
839                    return Ok(Some(r));
840                }
841            };
842        }
843        pass!(declutter_const_input);
844        pass!(declutter_discard_unused_input);
845        pass!(declutter_discard_useless_outer_output);
846        pass!(declutter_discard_empty_output_mapping_with_body_output);
847        pass!(declutter_body);
848        pass!(declutter_body_axes);
849        pass!(declutter_pull_constant_outputs);
850        pass!(declutter_pull_batcheable_input);
851        pass!(declutter_pull_batcheable_output);
852        Ok(None)
853    }
854
855    fn concretize_dims(
856        &self,
857        _source: &TypedModel,
858        node: &TypedNode,
859        target: &mut TypedModel,
860        mapping: &HashMap<OutletId, OutletId>,
861        values: &SymbolValues,
862    ) -> TractResult<TVec<OutletId>> {
863        let inputs = node.inputs.iter().map(|o| mapping[o]).collect::<TVec<_>>();
864        let op = Self {
865            output_mapping: self
866                .output_mapping
867                .iter()
868                .map(|om| om.concretize_dims(values))
869                .collect::<TractResult<Vec<_>>>()?,
870            body: self.body.concretize_dims(values)?,
871            ..self.clone()
872        };
873        target.wire_node(&node.name, op, &inputs)
874    }
875
876    fn codegen(
877        &self,
878        model: &TypedModel,
879        node: &TypedNode,
880    ) -> TractResult<Option<TypedModelPatch>> {
881        Ok(Some(TypedModelPatch::replace_single_op(
882            model,
883            node,
884            &node.inputs,
885            self.to_codegen_op(true)?,
886        )?))
887    }
888}