tract_core/optim/
propagate_roi.rs1use crate::internal::*;
2use crate::ops::logic::sym_to_coord_axis;
3use crate::optim::OptimizerSession;
4
/// Optimizer pass that pushes regions of interest (ROIs) from the nodes consuming an outlet
/// back onto the fact of that outlet, iterating until the facts stop changing.
#[derive(Debug, Default, Clone)]
pub struct PropagateRoi;
20
21fn roi_union(a: &TDim, b: &TDim) -> TDim {
23 if a == b {
24 return a.clone();
25 }
26 a.clone() + b.clone() - a.clone() * b.clone()
27}
28
29pub fn bubble_roi(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TVec<Option<TDim>>>> {
36 let output_fact = model.outlet_fact(OutletId::new(node.id, 0))?;
37 rule_if_some!(roi = &output_fact.region_of_interest);
38
39 let input_facts: TVec<&TypedFact> =
40 node.inputs.iter().map(|i| model.outlet_fact(*i)).collect::<TractResult<_>>()?;
41 let output_facts = tvec![output_fact];
42 let inputs_ref: Vec<&TypedFact> = input_facts.iter().copied().collect();
43 let outputs_ref: Vec<&TypedFact> = output_facts.iter().copied().collect();
44 let mapping = node.op.as_typed().unwrap().axes_mapping(&inputs_ref, &outputs_ref)?;
45
46 let roi_coord_syms: Vec<(usize, Symbol)> =
48 roi.symbols().into_iter().filter_map(|s| sym_to_coord_axis(&s).map(|k| (k, s))).collect();
49
50 let remap_for_input = |input_ix: usize| -> Option<TDim> {
51 let mut sub_map: HashMap<Symbol, TDim> = HashMap::new();
52 for (out_pos, sym) in &roi_coord_syms {
53 let logical = mapping
54 .iter_all_axes()
55 .find(|a| a.outputs.first().is_some_and(|o| o.contains(out_pos)))?;
56 if logical.inputs[input_ix].is_empty() {
57 return None;
58 }
59 let in_pos = logical.inputs[input_ix][0];
60 if input_facts[input_ix].shape[in_pos] != output_fact.shape[*out_pos] {
61 return None;
62 }
63 if in_pos != *out_pos {
64 let scope = sym.scope()?;
65 sub_map.insert(sym.clone(), TDim::Sym(scope.coord_sym(in_pos)));
66 }
67 }
68 if sub_map.is_empty() { Some(roi.clone()) } else { roi.substitute_all(&sub_map).ok() }
69 };
70 let result: TVec<Option<TDim>> = (0..node.inputs.len()).map(|ix| remap_for_input(ix)).collect();
71 Ok(Some(result))
72}
73
74pub fn recognise_chunked_band_project(roi: &TDim, p_sym: &Symbol, k_sym: &Symbol) -> Option<TDim> {
96 let TDim::Mul(terms) = roi else { return None };
98 if terms.len() != 2 {
99 return None;
100 }
101 let TDim::Ge(top_l, top_r) = &terms[0] else { return None };
102 let TDim::Ge(bot_l, bot_r) = &terms[1] else { return None };
103
104 let (l_val, a, r_val) = if top_r.as_ref() == bot_l.as_ref() {
108 (top_l.as_ref(), top_r.as_ref(), bot_r.as_ref())
109 } else if top_l.as_ref() == bot_r.as_ref() {
110 (bot_l.as_ref(), top_l.as_ref(), top_r.as_ref())
112 } else {
113 return None;
114 };
115
116 if r_val != &TDim::Val(0) {
118 return None;
119 }
120 let big_l = l_val.to_i64().ok()?;
121 if big_l < 0 {
122 return None;
123 }
124
125 let (a_no_const, c) = split_const(a);
131 let (k, p_num, q_num) = match_diff_of_divs(&a_no_const)?;
132 let derived_inner_offset = (p_num + TDim::Sym(k_sym.clone()) - q_num).reduce();
133 if derived_inner_offset.symbols().contains(p_sym)
134 || derived_inner_offset.symbols().contains(k_sym)
135 {
136 return None;
137 }
138 let actual_offset = (derived_inner_offset + TDim::Val(c * k as i64)).reduce();
139
140 let high = (actual_offset.clone() + TDim::Val(k as i64 - 1)).reduce();
142 let low = (actual_offset - TDim::Val((big_l + 1) * k as i64 - 1)).reduce();
143 Some(
144 TDim::Mul(vec![
145 TDim::Ge(Box::new(high), Box::new(TDim::Sym(k_sym.clone()))),
146 TDim::Ge(Box::new(TDim::Sym(k_sym.clone())), Box::new(low)),
147 ])
148 .reduce(),
149 )
150}
151
152fn split_const(expr: &TDim) -> (TDim, i64) {
157 if let TDim::Add(terms) = expr {
158 let mut c = 0i64;
159 let mut rest: Vec<TDim> = vec![];
160 for t in terms {
161 match t {
162 TDim::Val(v) => c += *v,
163 _ => rest.push(t.clone()),
164 }
165 }
166 let new_expr = if rest.is_empty() {
167 TDim::Val(0)
168 } else if rest.len() == 1 {
169 rest.into_iter().next().unwrap()
170 } else {
171 TDim::Add(rest)
172 };
173 return (new_expr, c);
174 }
175 (expr.clone(), 0)
176}
177
178fn match_diff_of_divs(expr: &TDim) -> Option<(u64, TDim, TDim)> {
182 let TDim::Add(terms) = expr else { return None };
183 if terms.len() != 2 {
184 return None;
185 }
186 let mut pos_div: Option<(TDim, u64)> = None;
187 let mut neg_div: Option<(TDim, u64)> = None;
188 for t in terms {
189 match t {
190 TDim::Div(inner, k) => {
191 pos_div = Some(((**inner).clone(), *k));
192 }
193 TDim::MulInt(-1, inner) => {
194 if let TDim::Div(num, k) = inner.as_ref() {
195 neg_div = Some(((**num).clone(), *k));
196 }
197 }
198 _ => {}
199 }
200 }
201 let (p_expr, k1) = pos_div?;
202 let (q_expr, k2) = neg_div?;
203 if k1 != k2 {
204 return None;
205 }
206 Some((k1, p_expr, q_expr))
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// The band `0 <= p/k - (p + k_ax - offset)/k <= L` must project onto the constant interval
    /// `offset - (L+1)*k + 1 <= k_ax <= offset + k - 1`.
    #[test]
    fn recognise_chunked_band_yields_constant_band() {
        let scope = SymbolScope::default();
        let p = scope.coord_sym(0);
        let k_ax = scope.coord_sym(1);
        let offset = 9i64;
        let k: u64 = 14;
        let big_l = 5i64;

        let p_dim = TDim::Sym(p.clone());
        let k_dim = TDim::Sym(k_ax.clone());
        let shifted = p_dim.clone() + k_dim.clone() - TDim::Val(offset);
        let a = (TDim::Div(Box::new(p_dim), k) - TDim::Div(Box::new(shifted), k)).reduce();
        let band = TDim::Mul(vec![
            TDim::Ge(Box::new(TDim::Val(big_l)), Box::new(a.clone())),
            TDim::Ge(Box::new(a), Box::new(TDim::Val(0))),
        ])
        .reduce();
        eprintln!("input band: {band}");

        let projected =
            recognise_chunked_band_project(&band, &p, &k_ax).expect("recogniser should match");
        eprintln!("projected: {projected}");

        let high_expected = offset + k as i64 - 1;
        let low_expected = offset - (big_l + 1) * k as i64 + 1;
        let TDim::Mul(terms) = &projected else { panic!("expected Mul") };
        assert_eq!(terms.len(), 2);

        let bounds: Vec<(&TDim, &TDim)> = terms
            .iter()
            .map(|term| match term {
                TDim::Ge(l, r) => (&**l, &**r),
                _ => panic!("expected Ge inside Mul"),
            })
            .collect();
        let saw_high =
            bounds.iter().any(|&(l, r)| *l == TDim::Val(high_expected) && *r == k_dim);
        let saw_low = bounds.iter().any(|&(l, r)| *l == k_dim && *r == TDim::Val(low_expected));
        assert!(saw_high, "missing Ge(high={high_expected}, r); got: {projected}");
        assert!(saw_low, "missing Ge(r, low={low_expected}); got: {projected}");
    }
}
262
263impl super::TypedPass for PropagateRoi {
264 fn reset(&mut self) -> TractResult<()> {
265 Ok(())
266 }
267
268 fn next(
269 &mut self,
270 _session: &mut OptimizerSession,
271 _model: &TypedModel,
272 ) -> TractResult<Option<TypedModelPatch>> {
273 Ok(None)
274 }
275
276 fn run_direct(&mut self, model: &mut TypedModel) -> TractResult<bool> {
277 let order = model.eval_order()?;
278 let mut any_changed = false;
279
280 loop {
281 let mut changed = false;
282 let mut demands: HashMap<OutletId, Option<TDim>> = HashMap::new();
283
284 for &node_id in &order {
285 let node = &model.nodes()[node_id];
286 let Some(input_rois) = node.op.as_typed().unwrap().input_roi(model, node)? else {
287 continue;
288 };
289 for (ix, roi) in input_rois.into_iter().enumerate() {
290 let outlet = node.inputs[ix];
291 match (demands.get(&outlet), &roi) {
292 (_, None) => {
293 demands.insert(outlet, None);
294 }
295 (Option::None, Some(roi)) => {
296 demands.insert(outlet, Some(roi.clone()));
297 }
298 (Some(None), Some(_)) => {}
299 (Some(Some(existing)), Some(new)) => {
300 demands.insert(outlet, Some(roi_union(existing, new)));
301 }
302 }
303 }
304 }
305
306 for (outlet, demand) in demands {
308 if let Some(roi) = demand {
309 let roi = roi.simplify();
310 if roi == TDim::Val(1) {
312 continue;
313 }
314 let fact = &mut model.nodes_mut()[outlet.node].outputs[outlet.slot].fact;
315 if fact.region_of_interest.as_ref() != Some(&roi) {
316 fact.region_of_interest = Some(roi);
317 changed = true;
318 }
319 }
320 }
321
322 any_changed |= changed;
323 if !changed {
324 break;
325 }
326 }
327
328 Ok(any_changed)
329 }
330}