tract_core/optim/
propagate_roi.rs1use crate::internal::*;
2use crate::ops::logic::sym_to_coord_axis;
3use crate::optim::OptimizerSession;
4
/// Optimizer pass that propagates `region_of_interest` annotations from the consumers of each
/// outlet back onto that outlet's fact. The pass itself carries no state.
#[derive(Debug, Default, Clone)]
pub struct PropagateRoi;
20
21fn roi_union(a: &TDim, b: &TDim) -> TDim {
23 if a == b {
24 return a.clone();
25 }
26 a.clone() + b.clone() - a.clone() * b.clone()
27}
28
29pub fn bubble_roi(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TVec<Option<TDim>>>> {
36 let output_fact = model.outlet_fact(OutletId::new(node.id, 0))?;
37 let Some(roi) = &output_fact.region_of_interest else { return Ok(None) };
38
39 let input_facts: TVec<&TypedFact> =
40 node.inputs.iter().map(|i| model.outlet_fact(*i)).collect::<TractResult<_>>()?;
41 let output_facts = tvec![output_fact];
42 let inputs_ref: Vec<&TypedFact> = input_facts.iter().copied().collect();
43 let outputs_ref: Vec<&TypedFact> = output_facts.iter().copied().collect();
44 let mapping = node.op.as_typed().unwrap().axes_mapping(&inputs_ref, &outputs_ref)?;
45
46 let roi_coord_syms: Vec<(usize, Symbol)> =
48 roi.symbols().into_iter().filter_map(|s| sym_to_coord_axis(&s).map(|k| (k, s))).collect();
49
50 let remap_for_input = |input_ix: usize| -> Option<TDim> {
51 let mut sub_map: HashMap<Symbol, TDim> = HashMap::new();
52 for (out_pos, sym) in &roi_coord_syms {
53 let logical = mapping
54 .iter_all_axes()
55 .find(|a| a.outputs.first().is_some_and(|o| o.contains(out_pos)))?;
56 if logical.inputs[input_ix].is_empty() {
57 return None;
58 }
59 let in_pos = logical.inputs[input_ix][0];
60 if input_facts[input_ix].shape[in_pos] != output_fact.shape[*out_pos] {
61 return None;
62 }
63 if in_pos != *out_pos {
64 let scope = sym.scope()?;
65 sub_map.insert(sym.clone(), TDim::Sym(scope.coord_sym(in_pos)));
66 }
67 }
68 if sub_map.is_empty() { Some(roi.clone()) } else { roi.substitute_all(&sub_map).ok() }
69 };
70 let result: TVec<Option<TDim>> = (0..node.inputs.len()).map(|ix| remap_for_input(ix)).collect();
71 Ok(Some(result))
72}
73
74impl super::TypedPass for PropagateRoi {
75 fn reset(&mut self) -> TractResult<()> {
76 Ok(())
77 }
78
79 fn next(
80 &mut self,
81 _session: &mut OptimizerSession,
82 _model: &TypedModel,
83 ) -> TractResult<Option<TypedModelPatch>> {
84 Ok(None)
85 }
86
87 fn run_direct(&mut self, model: &mut TypedModel) -> TractResult<bool> {
88 let order = model.eval_order()?;
89 let mut changed = false;
90
91 let mut demands: HashMap<OutletId, Option<TDim>> = HashMap::new();
93
94 for &node_id in &order {
95 let node = &model.nodes()[node_id];
96 let Some(input_rois) = node.op.as_typed().unwrap().input_roi(model, node)? else {
97 continue;
98 };
99 for (ix, roi) in input_rois.into_iter().enumerate() {
100 let outlet = node.inputs[ix];
101 match (demands.get(&outlet), &roi) {
102 (_, None) => {
103 demands.insert(outlet, None);
104 }
105 (Option::None, Some(roi)) => {
106 demands.insert(outlet, Some(roi.clone()));
107 }
108 (Some(None), Some(_)) => {}
109 (Some(Some(existing)), Some(new)) => {
110 demands.insert(outlet, Some(roi_union(existing, new)));
111 }
112 }
113 }
114 }
115
116 for (outlet, demand) in demands {
118 if let Some(roi) = demand {
119 let fact = &mut model.nodes_mut()[outlet.node].outputs[outlet.slot].fact;
120 if fact.region_of_interest.as_ref() != Some(&roi) {
121 fact.region_of_interest = Some(roi);
122 changed = true;
123 }
124 }
125 }
126
127 Ok(changed)
128 }
129}