Skip to main content

tract_core/optim/
propagate_uniform_tdim.rs

1use crate::internal::*;
2use crate::optim::OptimizerSession;
3
/// Forward pass that refreshes `TypedFact::uniform_tdim` annotations by
/// re-running each op's `output_facts` against its current input facts.
///
/// The default declutter pipeline computes a node's `uniform_tdim` once at
/// load time, then reuses the cached fact.  Some declutter rewrites
/// (notably `Iff` folding when the condition is provably constant) shunt a
/// node's input edge without re-running the consumer's `output_facts` —
/// the consumer's cached fact then references the stale upstream fact and
/// loses any newly-available `uniform_tdim` annotation.  Every wire
/// downstream of the shunt then sees `uniform_tdim = None`, and passes
/// like `FoldUniformMask` or Blockify section detection silently miss it.
///
/// This pass walks the model in evaluation (topological) order, calls
/// `output_facts` fresh on each node, and copies the recomputed
/// `uniform_tdim` over the cached one when it differs.  Other fact fields
/// are untouched (the existing declutter loop is responsible for them).
/// Nodes whose `output_facts` fails are skipped and keep their cached facts.
///
/// Sweeps repeat until one changes nothing.  Because each node is visited
/// after all of its inputs, a refreshed annotation already reaches every
/// downstream node within the same sweep; the following sweep normally just
/// confirms convergence.
#[derive(Clone, Debug, Default)]
pub struct PropagateUniformTdim;
24
25impl super::TypedPass for PropagateUniformTdim {
26    fn reset(&mut self) -> TractResult<()> {
27        Ok(())
28    }
29
30    fn next(
31        &mut self,
32        _session: &mut OptimizerSession,
33        _model: &TypedModel,
34    ) -> TractResult<Option<TypedModelPatch>> {
35        Ok(None)
36    }
37
38    fn run_direct(&mut self, model: &mut TypedModel) -> TractResult<bool> {
39        let order = model.eval_order()?;
40        let mut any_changed = false;
41        loop {
42            let mut changed = false;
43            for &node_id in &order {
44                let typed_op = match model.nodes()[node_id].op.as_typed() {
45                    Some(op) => op,
46                    None => continue,
47                };
48                let input_facts: TVec<TypedFact> = model.nodes()[node_id]
49                    .inputs
50                    .iter()
51                    .map(|i| model.outlet_fact(*i).cloned())
52                    .collect::<TractResult<_>>()?;
53                let input_refs: TVec<&TypedFact> = input_facts.iter().collect();
54                let new_facts = match typed_op.output_facts(&input_refs) {
55                    Ok(f) => f,
56                    Err(_) => continue,
57                };
58                for (slot, new_fact) in new_facts.iter().enumerate() {
59                    let current_uniform_tdim =
60                        model.nodes()[node_id].outputs[slot].fact.uniform_tdim.clone();
61                    if current_uniform_tdim != new_fact.uniform_tdim {
62                        model.nodes_mut()[node_id].outputs[slot].fact.uniform_tdim =
63                            new_fact.uniform_tdim.clone();
64                        changed = true;
65                    }
66                }
67            }
68            if !changed {
69                break;
70            }
71            any_changed = true;
72        }
73        Ok(any_changed)
74    }
75}