tract_core/optim/propagate_uniform_tdim.rs
1use crate::internal::*;
2use crate::optim::OptimizerSession;
3
/// Refreshes the `TypedFact::uniform_tdim` annotation of every wire by asking
/// each op to recompute its output facts from its current input facts.
///
/// Declutter computes a node's `uniform_tdim` once, when the model is loaded,
/// and keeps reusing the cached fact afterwards. Some declutter rewrites —
/// notably folding an `Iff` whose condition is provably constant — rewire a
/// node's input edge without calling the consumer's `output_facts` again. The
/// consumer is then left with a fact derived from the stale upstream fact and
/// misses any `uniform_tdim` annotation that has since become available. Every
/// wire downstream of such a shunt reports `uniform_tdim = None`, and passes
/// such as `FoldUniformMask` or Blockify section detection silently fail to
/// fire.
///
/// The pass visits nodes in topological order, re-invokes `output_facts` on
/// each one, and overwrites the cached `uniform_tdim` whenever the freshly
/// computed value disagrees. No other field of the fact is modified; keeping
/// those current is the job of the regular declutter loop. Sweeps repeat until
/// one finishes without changing anything, since a refreshed annotation
/// upstream can enable further refreshes downstream.
#[derive(Debug, Default, Clone)]
pub struct PropagateUniformTdim;
24
25impl super::TypedPass for PropagateUniformTdim {
26 fn reset(&mut self) -> TractResult<()> {
27 Ok(())
28 }
29
30 fn next(
31 &mut self,
32 _session: &mut OptimizerSession,
33 _model: &TypedModel,
34 ) -> TractResult<Option<TypedModelPatch>> {
35 Ok(None)
36 }
37
38 fn run_direct(&mut self, model: &mut TypedModel) -> TractResult<bool> {
39 let order = model.eval_order()?;
40 let mut any_changed = false;
41 loop {
42 let mut changed = false;
43 for &node_id in &order {
44 let typed_op = match model.nodes()[node_id].op.as_typed() {
45 Some(op) => op,
46 None => continue,
47 };
48 let input_facts: TVec<TypedFact> = model.nodes()[node_id]
49 .inputs
50 .iter()
51 .map(|i| model.outlet_fact(*i).cloned())
52 .collect::<TractResult<_>>()?;
53 let input_refs: TVec<&TypedFact> = input_facts.iter().collect();
54 let new_facts = match typed_op.output_facts(&input_refs) {
55 Ok(f) => f,
56 Err(_) => continue,
57 };
58 for (slot, new_fact) in new_facts.iter().enumerate() {
59 let current_uniform_tdim =
60 model.nodes()[node_id].outputs[slot].fact.uniform_tdim.clone();
61 if current_uniform_tdim != new_fact.uniform_tdim {
62 model.nodes_mut()[node_id].outputs[slot].fact.uniform_tdim =
63 new_fact.uniform_tdim.clone();
64 changed = true;
65 }
66 }
67 }
68 if !changed {
69 break;
70 }
71 any_changed = true;
72 }
73 Ok(any_changed)
74 }
75}