tract_hir/ops/scan.rs

1use crate::infer::*;
2use crate::internal::*;
3
4pub use tract_core::ops::scan::Scan;
5use tract_core::ops::scan::ScanInfo;
6pub use tract_core::ops::scan::{InputMapping, OutputMapping};
7
8#[derive(Debug, Clone, new, Default)]
9pub struct InferenceScan {
10    pub body: InferenceModel,
11    pub input_mapping: Vec<InputMapping>,
12    pub output_mapping: Vec<OutputMapping<TDim>>,
13    pub clean_scan_counts: bool,
14    pub iter_count_fact: GenericFactoid<TDim>,
15}
16
17impl Op for InferenceScan {
18    fn name(&self) -> StaticName {
19        "Scan".into()
20    }
21
22    fn info(&self) -> TractResult<Vec<String>> {
23        let mut lines = vec![];
24        for (ix, im) in self.input_mapping.iter().enumerate() {
25            lines.push(format!("Model input  #{ix}: {im:?}"));
26        }
27        for (ix, om) in self.output_mapping.iter().enumerate() {
28            lines.push(format!("Model output #{ix}: {om:?}"));
29        }
30        Ok(lines)
31    }
32
33    not_a_typed_op!();
34}
35
36impl EvalOp for InferenceScan {
37    fn is_stateless(&self) -> bool {
38        false
39    }
40
41    fn state(
42        &self,
43        session: &mut SessionState,
44        node_id: usize,
45    ) -> TractResult<Option<Box<dyn OpState>>> {
46        self.to_mir_scan()?.state(session, node_id)
47    }
48}
49
50impl InferenceScan {
51    pub(super) fn to_mir_scan(&self) -> TractResult<Box<Scan>> {
52        let typed_model = self.body.clone().into_typed()?;
53        let input_mapping = self
54            .input_mapping
55            .iter()
56            .enumerate()
57            .map(|(ix, im)| {
58                Ok(match im {
59                    InputMapping::Scan(info) => InputMapping::Scan(ScanInfo {
60                        chunk: typed_model.input_fact(ix)?.shape[info.axis].to_isize()?,
61                        ..*info
62                    }),
63                    other => other.clone(),
64                })
65            })
66            .collect::<TractResult<_>>()?;
67        let output_mapping = self
68            .output_mapping
69            .iter()
70            .enumerate()
71            .map(|(ix, im)| {
72                let scan = if let Some((slot, scan)) = im.scan {
73                    Some((
74                        slot,
75                        ScanInfo {
76                            chunk: typed_model.input_fact(ix)?.shape[scan.axis].to_isize()?,
77                            ..scan
78                        },
79                    ))
80                } else {
81                    None
82                };
83                Ok(OutputMapping {
84                    state: im.state,
85                    scan,
86                    full_dim_hint: im.full_dim_hint.clone(),
87                    last_value_slot: im.last_value_slot,
88                })
89            })
90            .collect::<TractResult<_>>()?;
91        Ok(Box::new(Scan::new(typed_model, input_mapping, output_mapping, 0)?))
92    }
93
94    fn unify_scanning_tensor_fact(
95        outer: &mut InferenceFact,
96        inner: &mut InferenceFact,
97        outer_scan_axis: usize,
98    ) -> TractResult<bool> {
99        let mut changed = outer.datum_type.unify_with_mut(&mut inner.datum_type)?;
100        let rank = outer
101            .shape
102            .rank()
103            .concretize()
104            .or_else(|| inner.shape.rank().concretize())
105            .map(|r| r as usize);
106        if let Some(rank) = rank {
107            if outer.shape.unify_with(&ShapeFactoid::closed(tvec!(GenericFactoid::Any; rank)))? {
108                changed = true;
109            }
110            if inner.shape.unify_with(&ShapeFactoid::closed(tvec!(GenericFactoid::Any; rank)))? {
111                changed = true;
112            }
113            for axis in 0..rank {
114                if axis != outer_scan_axis {
115                    let value = outer
116                        .shape
117                        .dim(axis)
118                        .unwrap()
119                        .concretize()
120                        .or_else(|| inner.shape.dim(axis).unwrap().concretize());
121                    if let Some(value) = value {
122                        if outer.shape.set_dim(axis, value.clone()) {
123                            changed = true
124                        }
125                        if inner.shape.set_dim(axis, value) {
126                            changed = true
127                        }
128                    }
129                }
130            }
131        }
132        Ok(changed)
133    }
134
135    fn unify_facts(
136        &mut self,
137        inputs: &mut [InferenceFact],
138        outputs: &mut [InferenceFact],
139    ) -> TractResult<bool> {
140        let mut changed = false;
141        let hidden_state_len = self.input_mapping.iter().filter(|m| m.is_state()).count();
142        #[allow(clippy::needless_range_loop)]
143        for state_ix in 0..hidden_state_len {
144            trace!("Unify hidden state #{state_ix}");
145            let inner_model_output_ix = self
146                .output_mapping
147                .iter()
148                .enumerate()
149                .filter(|(_ix, map)| map.state)
150                .nth(state_ix)
151                .unwrap()
152                .0;
153            let mut facts = self.body.outlets_fact_mut(&[
154                self.body.input_outlets()?[state_ix],
155                self.body.output_outlets()?[inner_model_output_ix],
156            ])?;
157            facts.push(&mut inputs[state_ix]);
158            if Factoid::unify_all(
159                &mut facts.iter_mut().map(|f| &mut f.datum_type).collect::<TVec<_>>(),
160            )? {
161                changed = true;
162            }
163            if Factoid::unify_all(&mut facts.iter_mut().map(|f| &mut f.shape).collect::<TVec<_>>())?
164            {
165                changed = true;
166            }
167        }
168        for (slot, i) in self.input_mapping.iter().enumerate() {
169            match i {
170                InputMapping::State => {}
171                InputMapping::Full => {
172                    if inputs[slot].unify_with_mut(self.body.input_fact_mut(slot)?)? {
173                        changed = true;
174                    }
175                }
176                InputMapping::Scan(scan) => {
177                    let incoming = &mut inputs[slot];
178                    let inner = self.body.input_fact_mut(slot)?;
179                    if Self::unify_scanning_tensor_fact(incoming, inner, scan.axis)? {
180                        changed = true;
181                    };
182                    if self.clean_scan_counts {
183                        if incoming.shape.ensure_rank_at_least(scan.axis) {
184                            changed = true;
185                        }
186                        let value =
187                            self.iter_count_fact.unify(&incoming.shape.dim(scan.axis).unwrap())?;
188                        if self.iter_count_fact != value {
189                            changed = true;
190                            self.iter_count_fact = value.clone();
191                        }
192                        if incoming.shape.dim(scan.axis).unwrap() != value {
193                            changed = true;
194                            incoming.shape.set_dim(scan.axis, value.concretize().unwrap());
195                        }
196                    }
197                }
198            }
199        }
200        for (ix, i) in self.output_mapping.iter().enumerate() {
201            if let Some((slot, scan)) = i.scan {
202                let outgoing = &mut outputs[slot];
203                let inner = self.body.output_fact_mut(ix)?;
204                if Self::unify_scanning_tensor_fact(outgoing, inner, scan.axis)? {
205                    changed = true
206                }
207                if self.clean_scan_counts {
208                    if outgoing.shape.ensure_rank_at_least(scan.axis) {
209                        changed = true;
210                    }
211                    let value =
212                        self.iter_count_fact.unify(&outgoing.shape.dim(scan.axis).unwrap())?;
213                    if self.iter_count_fact != value {
214                        changed = true;
215                        self.iter_count_fact = value.clone();
216                    }
217                    if outgoing.shape.dim(scan.axis).unwrap() != value {
218                        changed = true;
219                        outgoing.shape.set_dim(scan.axis, value.concretize().unwrap());
220                    }
221                }
222            }
223            if let Some(slot) = i.last_value_slot {
224                if outputs[slot].unify_with(self.body.output_fact_mut(ix)?)? {
225                    changed = true;
226                }
227            }
228        }
229        Ok(changed)
230    }
231}
232
233impl InferenceOp for InferenceScan {
234    fn infer_facts(
235        &mut self,
236        inputs: TVec<&InferenceFact>,
237        outputs: TVec<&InferenceFact>,
238        _observed: TVec<&InferenceFact>,
239    ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
240        let body_inputs = self.body.input_outlets()?.len();
241        let body_outputs = self.body.output_outlets()?.len();
242        let expected_op_inputs = self.input_mapping.len();
243        let expected_op_outputs = self
244            .output_mapping
245            .iter()
246            .filter_map(|om| om.last_value_slot)
247            .chain(self.output_mapping.iter().filter_map(|om| om.scan.map(|si| si.0)))
248            .max()
249            .context("No output slot found")?
250            + 1;
251        if inputs.len() != expected_op_inputs {
252            bail!("Scan receives {} inputs, mappings expects {}", inputs.len(), expected_op_inputs)
253        }
254        if body_inputs != self.input_mapping.len() {
255            bail!(
256                "Scan body expect {} inputs, mappings expects {}",
257                body_inputs,
258                self.input_mapping.len()
259            )
260        }
261        if outputs.len() != expected_op_outputs {
262            bail!("Scan has {} outputs, mappings expects {}", outputs.len(), expected_op_outputs);
263        }
264        if body_outputs != self.output_mapping.len() {
265            bail!(
266                "Scan body expect {} outputs, mappings expects {}",
267                body_outputs,
268                self.output_mapping.len()
269            )
270        }
271        let mut inputs: TVec<InferenceFact> = inputs.into_iter().cloned().collect();
272        let mut outputs: TVec<InferenceFact> = outputs.into_iter().cloned().collect();
273        loop {
274            trace!("Unify inner and outer interface");
275            let mut changed = self.unify_facts(&mut inputs, &mut outputs)?;
276            trace!("iters: {:?} changed: {:?}", self.iter_count_fact, changed);
277            for (ix, input) in self.body.input_outlets()?.iter().enumerate() {
278                trace!("  Input inner model: {} {:?} {:?}", ix, input, self.body.input_fact(ix));
279            }
280            for (ix, output) in self.body.output_outlets()?.iter().enumerate() {
281                trace!("  Output inner model: {} {:?} {:?}", ix, output, self.body.output_fact(ix));
282            }
283            trace!("Inner model analyse");
284            if self.body.analyse(false).context("analysing inner model")? {
285                changed = true;
286            }
287            if !changed {
288                break;
289            }
290            trace!("Finished inner model analyse");
291        }
292        Ok((inputs, outputs, tvec!()))
293    }
294
295    fn to_typed(
296        &self,
297        _source: &InferenceModel,
298        node: &InferenceNode,
299        target: &mut TypedModel,
300        mapping: &HashMap<OutletId, OutletId>,
301    ) -> TractResult<TVec<OutletId>> {
302        let inputs = node.inputs.iter().map(|m| mapping[m]).collect::<TVec<_>>();
303        target.wire_node(&*node.name, self.to_mir_scan()? as Box<dyn TypedOp>, &inputs)
304    }
305
306    fn nboutputs(&self) -> TractResult<usize> {
307        Ok(self.output_mapping.iter().filter(|om| !om.invisible()).count())
308    }
309
310    as_op!();
311}