Skip to main content

tract_core/optim/
propagate_roi.rs

1use crate::internal::*;
2use crate::ops::logic::sym_to_coord_axis;
3use crate::optim::OptimizerSession;
4
/// Backward pass that propagates `region_of_interest` annotations by
/// calling `TypedOp::input_roi` on each node.
///
/// Ops can **introduce** ROIs (e.g. Iff reads its mask's uniform_tdim and
/// creates a ROI on the scores input) or **bubble** them (e.g. element-wise
/// ops pass an output ROI through to their inputs).
///
/// When multiple consumers of a wire produce different ROIs, they are merged
/// via boolean OR using De Morgan: `a ∨ b = ¬(¬a ∧ ¬b) = a + b − a·b`.
/// If any consumer returns `None` for a wire (needs all positions), that wire
/// gets no ROI from that sweep.  A node whose `input_roi` answers `None` as a
/// whole contributes no demand on any of its inputs.
///
/// The pass iterates to fixpoint: introductions may enable further bubbling.
#[derive(Clone, Debug, Default)]
pub struct PropagateRoi;
20
21/// Merge two ROI expressions via boolean OR: `a โˆจ b = a + b - a * b`.
22fn roi_union(a: &TDim, b: &TDim) -> TDim {
23    if a == b {
24        return a.clone();
25    }
26    a.clone() + b.clone() - a.clone() * b.clone()
27}
28
29/// Bubble output ROI to inputs using the op's axes_mapping.
30///
31/// For each input, builds a coordinate substitution map from the axes mapping:
32/// each output axis that appears in this input gets ๐ŸŽฏ{out_pos} โ†’ ๐ŸŽฏ{in_pos}.
33/// If any ROI coordinate symbol has no corresponding input axis (contracted,
34/// broadcast from dim=1, or absent), returns None for that input.
35pub fn bubble_roi(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TVec<Option<TDim>>>> {
36    let output_fact = model.outlet_fact(OutletId::new(node.id, 0))?;
37    rule_if_some!(roi = &output_fact.region_of_interest);
38
39    let input_facts: TVec<&TypedFact> =
40        node.inputs.iter().map(|i| model.outlet_fact(*i)).collect::<TractResult<_>>()?;
41    let output_facts = tvec![output_fact];
42    let inputs_ref: Vec<&TypedFact> = input_facts.iter().copied().collect();
43    let outputs_ref: Vec<&TypedFact> = output_facts.iter().copied().collect();
44    let mapping = node.op.as_typed().unwrap().axes_mapping(&inputs_ref, &outputs_ref)?;
45
46    // Collect ROI coordinate symbols and their output axis positions.
47    let roi_coord_syms: Vec<(usize, Symbol)> =
48        roi.symbols().into_iter().filter_map(|s| sym_to_coord_axis(&s).map(|k| (k, s))).collect();
49
50    let remap_for_input = |input_ix: usize| -> Option<TDim> {
51        let mut sub_map: HashMap<Symbol, TDim> = HashMap::new();
52        for (out_pos, sym) in &roi_coord_syms {
53            let logical = mapping
54                .iter_all_axes()
55                .find(|a| a.outputs.first().is_some_and(|o| o.contains(out_pos)))?;
56            if logical.inputs[input_ix].is_empty() {
57                return None;
58            }
59            let in_pos = logical.inputs[input_ix][0];
60            if input_facts[input_ix].shape[in_pos] != output_fact.shape[*out_pos] {
61                return None;
62            }
63            if in_pos != *out_pos {
64                let scope = sym.scope()?;
65                sub_map.insert(sym.clone(), TDim::Sym(scope.coord_sym(in_pos)));
66            }
67        }
68        if sub_map.is_empty() { Some(roi.clone()) } else { roi.substitute_all(&sub_map).ok() }
69    };
70    let result: TVec<Option<TDim>> = (0..node.inputs.len()).map(|ix| remap_for_input(ix)).collect();
71    Ok(Some(result))
72}
73
74/// Recognise a chunked-band predicate on output coords `(p, k_axis)` of the
75/// shape produced by `DiagGather::input_roi`'s `c โ†’ r + q โˆ’ offset`
76/// substitution applied to a `Mul(Ge(L, q/k โˆ’ c/k), Ge(q/k โˆ’ c/k, 0))` band,
77/// and return the projected band on `k_axis` after existentially
78/// quantifying `p` over its dim bound.
79///
80/// Specifically, recognises:
81///
82///   `Mul(Ge(L_val, A), Ge(A, 0))`
83///
84/// where `A = โŒŠp/kโŒ‹ โˆ’ โŒŠ(p + k_axis โˆ’ offset)/kโŒ‹` with `p` the projected
85/// coord symbol (e.g. query) and `k_axis` the kept coord symbol (e.g.
86/// rel-pos index).  Closed-form projection: as `p` varies, `A` takes
87/// values in `{โˆ’โŒˆ(k_axis โˆ’ offset)/kโŒ‰, โˆ’โŒŠ(k_axis โˆ’ offset)/kโŒ‹}`, so the
88/// existential `0 โ‰ค A โ‰ค L_val` is satisfiable iff
89///
90///   `k_axis โˆˆ [offset โˆ’ (L_val + 1)ยทk + 1, offset + (k โˆ’ 1)]`
91///
92/// โ€” a constant band of width `(L_val + 2)ยทk โˆ’ 1`.
93///
94/// Returns `None` if the pattern doesn't match.
95pub fn recognise_chunked_band_project(roi: &TDim, p_sym: &Symbol, k_sym: &Symbol) -> Option<TDim> {
96    // Match Mul(Ge(L, A), Ge(A, R)).
97    let TDim::Mul(terms) = roi else { return None };
98    if terms.len() != 2 {
99        return None;
100    }
101    let TDim::Ge(top_l, top_r) = &terms[0] else { return None };
102    let TDim::Ge(bot_l, bot_r) = &terms[1] else { return None };
103
104    // Identify which orientation: top = Ge(L, A) and bot = Ge(A, R)?
105    // We need the same `A` to appear as second arg of first and first arg
106    // of second.
107    let (l_val, a, r_val) = if top_r.as_ref() == bot_l.as_ref() {
108        (top_l.as_ref(), top_r.as_ref(), bot_r.as_ref())
109    } else if top_l.as_ref() == bot_r.as_ref() {
110        // Reverse: top is Ge(A, L'), bot is Ge(R', A) โ€” swap roles.
111        (bot_l.as_ref(), top_l.as_ref(), top_r.as_ref())
112    } else {
113        return None;
114    };
115
116    // R side must be 0 (the band is 0 โ‰ค X โ‰ค L).
117    if r_val != &TDim::Val(0) {
118        return None;
119    }
120    let big_l = l_val.to_i64().ok()?;
121    if big_l < 0 {
122        return None;
123    }
124
125    // `A` may have a constant offset c factored out by the simplifier (e.g.
126    // when the original offset isn't a multiple of k, the simplifier
127    // rewrites `(p+r-9)/k` as `(p+r+5)/k - 1` for k=14).  Peel c off so
128    // we can match the inner diff-of-divs, then re-fold cยทk into the
129    // recovered offset.
130    let (a_no_const, c) = split_const(a);
131    let (k, p_num, q_num) = match_diff_of_divs(&a_no_const)?;
132    let derived_inner_offset = (p_num + TDim::Sym(k_sym.clone()) - q_num).reduce();
133    if derived_inner_offset.symbols().contains(p_sym)
134        || derived_inner_offset.symbols().contains(k_sym)
135    {
136        return None;
137    }
138    let actual_offset = (derived_inner_offset + TDim::Val(c * k as i64)).reduce();
139
140    // The projected band on k_sym: [offset โˆ’ (L+1)ยทk + 1, offset + (k โˆ’ 1)].
141    let high = (actual_offset.clone() + TDim::Val(k as i64 - 1)).reduce();
142    let low = (actual_offset - TDim::Val((big_l + 1) * k as i64 - 1)).reduce();
143    Some(
144        TDim::Mul(vec![
145            TDim::Ge(Box::new(high), Box::new(TDim::Sym(k_sym.clone()))),
146            TDim::Ge(Box::new(TDim::Sym(k_sym.clone())), Box::new(low)),
147        ])
148        .reduce(),
149    )
150}
151
152/// Split `expr` into `(expr_without_constant, constant_part)`.  If `expr`
153/// is `Add([...constants..., ...non-constants...])`, sum up the constant
154/// terms and return the non-constant remainder.  Otherwise returns
155/// `(expr, 0)`.
156fn split_const(expr: &TDim) -> (TDim, i64) {
157    if let TDim::Add(terms) = expr {
158        let mut c = 0i64;
159        let mut rest: Vec<TDim> = vec![];
160        for t in terms {
161            match t {
162                TDim::Val(v) => c += *v,
163                _ => rest.push(t.clone()),
164            }
165        }
166        let new_expr = if rest.is_empty() {
167            TDim::Val(0)
168        } else if rest.len() == 1 {
169            rest.into_iter().next().unwrap()
170        } else {
171            TDim::Add(rest)
172        };
173        return (new_expr, c);
174    }
175    (expr.clone(), 0)
176}
177
178/// If `expr` matches `Div(p_expr, k) โˆ’ Div(q_expr, k)` (in either order),
179/// returns `(k, p_expr, q_expr)` where `p_expr` is the numerator with the
180/// positive coefficient.
181fn match_diff_of_divs(expr: &TDim) -> Option<(u64, TDim, TDim)> {
182    let TDim::Add(terms) = expr else { return None };
183    if terms.len() != 2 {
184        return None;
185    }
186    let mut pos_div: Option<(TDim, u64)> = None;
187    let mut neg_div: Option<(TDim, u64)> = None;
188    for t in terms {
189        match t {
190            TDim::Div(inner, k) => {
191                pos_div = Some(((**inner).clone(), *k));
192            }
193            TDim::MulInt(-1, inner) => {
194                if let TDim::Div(num, k) = inner.as_ref() {
195                    neg_div = Some(((**num).clone(), *k));
196                }
197            }
198            _ => {}
199        }
200    }
201    let (p_expr, k1) = pos_div?;
202    let (q_expr, k2) = neg_div?;
203    if k1 != k2 {
204        return None;
205    }
206    Some((k1, p_expr, q_expr))
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// Closed-form recognition: the chunked-band predicate left behind by the
    /// DiagGather substitution `c → r + q − offset` must project `q` away and
    /// leave a constant band on `r` of width `(L+2)·k − 1` around `offset`.
    #[test]
    fn recognise_chunked_band_yields_constant_band() {
        let scope = SymbolScope::default();
        // Coordinate 0 stands for the query `q` (projected away), coordinate 1
        // for the rel-pos index `r` (kept).
        let (p, k_ax) = (scope.coord_sym(0), scope.coord_sym(1));
        let offset: i64 = 9;
        let k: u64 = 14;
        let big_l: i64 = 5;

        let q = TDim::Sym(p.clone());
        let r = TDim::Sym(k_ax.clone());

        // A = p/k − (p + k_ax − offset)/k
        let shifted = q.clone() + r.clone() - TDim::Val(offset);
        let a = (TDim::Div(Box::new(q), k) - TDim::Div(Box::new(shifted), k)).reduce();
        let band = TDim::Mul(vec![
            TDim::Ge(Box::new(TDim::Val(big_l)), Box::new(a.clone())),
            TDim::Ge(Box::new(a), Box::new(TDim::Val(0))),
        ])
        .reduce();
        eprintln!("input band: {band}");

        let projected =
            recognise_chunked_band_project(&band, &p, &k_ax).expect("recogniser should match");
        eprintln!("projected: {projected}");

        // Expected: r ∈ [offset − (L+1)·k + 1, offset + (k − 1)]
        //         = [9 − 84 + 1, 9 + 13] = [-74, 22] (width 97)
        let high_expected = offset + k as i64 - 1;
        let low_expected = offset - (big_l + 1) * k as i64 + 1;

        let TDim::Mul(terms) = &projected else { panic!("expected Mul") };
        assert_eq!(terms.len(), 2);
        // Factor order is not part of the contract: collect every `Ge(lhs, rhs)`
        // and look for `Ge(high, r)` (r ≤ high) and `Ge(r, low)` (r ≥ low).
        let bounds: Vec<(&TDim, &TDim)> = terms
            .iter()
            .map(|t| match t {
                TDim::Ge(lhs, rhs) => (&**lhs, &**rhs),
                _ => panic!("expected Ge inside Mul"),
            })
            .collect();
        let high = TDim::Val(high_expected);
        let low = TDim::Val(low_expected);
        assert!(
            bounds.contains(&(&high, &r)),
            "missing Ge(high={high_expected}, r); got: {projected}"
        );
        assert!(
            bounds.contains(&(&r, &low)),
            "missing Ge(r, low={low_expected}); got: {projected}"
        );
    }
}
262
263impl super::TypedPass for PropagateRoi {
264    fn reset(&mut self) -> TractResult<()> {
265        Ok(())
266    }
267
268    fn next(
269        &mut self,
270        _session: &mut OptimizerSession,
271        _model: &TypedModel,
272    ) -> TractResult<Option<TypedModelPatch>> {
273        Ok(None)
274    }
275
276    fn run_direct(&mut self, model: &mut TypedModel) -> TractResult<bool> {
277        let order = model.eval_order()?;
278        let mut any_changed = false;
279
280        loop {
281            let mut changed = false;
282            let mut demands: HashMap<OutletId, Option<TDim>> = HashMap::new();
283
284            for &node_id in &order {
285                let node = &model.nodes()[node_id];
286                let Some(input_rois) = node.op.as_typed().unwrap().input_roi(model, node)? else {
287                    continue;
288                };
289                for (ix, roi) in input_rois.into_iter().enumerate() {
290                    let outlet = node.inputs[ix];
291                    match (demands.get(&outlet), &roi) {
292                        (_, None) => {
293                            demands.insert(outlet, None);
294                        }
295                        (Option::None, Some(roi)) => {
296                            demands.insert(outlet, Some(roi.clone()));
297                        }
298                        (Some(None), Some(_)) => {}
299                        (Some(Some(existing)), Some(new)) => {
300                            demands.insert(outlet, Some(roi_union(existing, new)));
301                        }
302                    }
303                }
304            }
305
306            // Apply demands to model facts.
307            for (outlet, demand) in demands {
308                if let Some(roi) = demand {
309                    let roi = roi.simplify();
310                    // ROI of 1 means "all positions matter" โ€” equivalent to None.
311                    if roi == TDim::Val(1) {
312                        continue;
313                    }
314                    let fact = &mut model.nodes_mut()[outlet.node].outputs[outlet.slot].fact;
315                    if fact.region_of_interest.as_ref() != Some(&roi) {
316                        fact.region_of_interest = Some(roi);
317                        changed = true;
318                    }
319                }
320            }
321
322            any_changed |= changed;
323            if !changed {
324                break;
325            }
326        }
327
328        Ok(any_changed)
329    }
330}