Skip to main content

tract_core/optim/
propagate_roi.rs

use crate::internal::*;
use crate::ops::logic::sym_to_coord_axis;
use crate::optim::OptimizerSession;

/// Backward optimizer pass that spreads `region_of_interest` annotations from
/// consumers towards producers by asking every node for `TypedOp::input_roi`.
///
/// An op can either *introduce* an ROI or *bubble* one:
/// - Introduce: Iff, for example, reads its mask's `uniform_tdim` and creates an
///   ROI on its scores input.
/// - Bubble: element-wise ops, for example, pass their output ROI through to
///   their inputs.
///
/// When the consumers of one wire ask for different ROIs, their demands are
/// combined as a boolean OR written arithmetically (inclusion–exclusion):
/// `a ∨ b = a + b - a * b`. A consumer that answers `None` for a wire needs
/// every position of it, so that wire gets no ROI.
///
/// Introduced ROIs can unlock further bubbling, so the pass is meant to be
/// repeated until the model stops changing. NOTE(review): a single
/// `run_direct` call does one sweep and reports whether it changed anything.
/// Confirm that the optimizer re-runs it until it returns `false`.
#[derive(Default, Clone, Debug)]
pub struct PropagateRoi;

21/// Merge two ROI expressions via boolean OR: `a ∨ b = a + b - a * b`.
22fn roi_union(a: &TDim, b: &TDim) -> TDim {
23    if a == b {
24        return a.clone();
25    }
26    a.clone() + b.clone() - a.clone() * b.clone()
27}
28
29/// Bubble output ROI to inputs using the op's axes_mapping.
30///
31/// For each input, builds a coordinate substitution map from the axes mapping:
32/// each output axis that appears in this input gets 🎯{out_pos} → 🎯{in_pos}.
33/// If any ROI coordinate symbol has no corresponding input axis (contracted,
34/// broadcast from dim=1, or absent), returns None for that input.
35pub fn bubble_roi(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TVec<Option<TDim>>>> {
36    let output_fact = model.outlet_fact(OutletId::new(node.id, 0))?;
37    let Some(roi) = &output_fact.region_of_interest else { return Ok(None) };
38
39    let input_facts: TVec<&TypedFact> =
40        node.inputs.iter().map(|i| model.outlet_fact(*i)).collect::<TractResult<_>>()?;
41    let output_facts = tvec![output_fact];
42    let inputs_ref: Vec<&TypedFact> = input_facts.iter().copied().collect();
43    let outputs_ref: Vec<&TypedFact> = output_facts.iter().copied().collect();
44    let mapping = node.op.as_typed().unwrap().axes_mapping(&inputs_ref, &outputs_ref)?;
45
46    // Collect ROI coordinate symbols and their output axis positions.
47    let roi_coord_syms: Vec<(usize, Symbol)> =
48        roi.symbols().into_iter().filter_map(|s| sym_to_coord_axis(&s).map(|k| (k, s))).collect();
49
50    let remap_for_input = |input_ix: usize| -> Option<TDim> {
51        let mut sub_map: HashMap<Symbol, TDim> = HashMap::new();
52        for (out_pos, sym) in &roi_coord_syms {
53            let logical = mapping
54                .iter_all_axes()
55                .find(|a| a.outputs.first().is_some_and(|o| o.contains(out_pos)))?;
56            if logical.inputs[input_ix].is_empty() {
57                return None;
58            }
59            let in_pos = logical.inputs[input_ix][0];
60            if input_facts[input_ix].shape[in_pos] != output_fact.shape[*out_pos] {
61                return None;
62            }
63            if in_pos != *out_pos {
64                let scope = sym.scope()?;
65                sub_map.insert(sym.clone(), TDim::Sym(scope.coord_sym(in_pos)));
66            }
67        }
68        if sub_map.is_empty() { Some(roi.clone()) } else { roi.substitute_all(&sub_map).ok() }
69    };
70    let result: TVec<Option<TDim>> = (0..node.inputs.len()).map(|ix| remap_for_input(ix)).collect();
71    Ok(Some(result))
72}
73
74impl super::TypedPass for PropagateRoi {
75    fn reset(&mut self) -> TractResult<()> {
76        Ok(())
77    }
78
79    fn next(
80        &mut self,
81        _session: &mut OptimizerSession,
82        _model: &TypedModel,
83    ) -> TractResult<Option<TypedModelPatch>> {
84        Ok(None)
85    }
86
87    fn run_direct(&mut self, model: &mut TypedModel) -> TractResult<bool> {
88        let order = model.eval_order()?;
89        let mut changed = false;
90
91        // Collect ROI demands from all nodes.
92        let mut demands: HashMap<OutletId, Option<TDim>> = HashMap::new();
93
94        for &node_id in &order {
95            let node = &model.nodes()[node_id];
96            let Some(input_rois) = node.op.as_typed().unwrap().input_roi(model, node)? else {
97                continue;
98            };
99            for (ix, roi) in input_rois.into_iter().enumerate() {
100                let outlet = node.inputs[ix];
101                match (demands.get(&outlet), &roi) {
102                    (_, None) => {
103                        demands.insert(outlet, None);
104                    }
105                    (Option::None, Some(roi)) => {
106                        demands.insert(outlet, Some(roi.clone()));
107                    }
108                    (Some(None), Some(_)) => {}
109                    (Some(Some(existing)), Some(new)) => {
110                        demands.insert(outlet, Some(roi_union(existing, new)));
111                    }
112                }
113            }
114        }
115
116        // Apply demands to model facts.
117        for (outlet, demand) in demands {
118            if let Some(roi) = demand {
119                let fact = &mut model.nodes_mut()[outlet.node].outputs[outlet.slot].fact;
120                if fact.region_of_interest.as_ref() != Some(&roi) {
121                    fact.region_of_interest = Some(roi);
122                    changed = true;
123                }
124            }
125        }
126
127        Ok(changed)
128    }
129}