1use crate::infer::*;
2use crate::internal::*;
3
4pub use tract_core::ops::scan::Scan;
5use tract_core::ops::scan::ScanInfo;
6pub use tract_core::ops::scan::{InputMapping, OutputMapping};
7
/// Inference-level (hir) Scan operator.
///
/// The body is kept as an `InferenceModel` so that its facts can be refined together with
/// the outer node's facts during analysis; it is lowered to tract-core's `Scan` by
/// `to_mir_scan`.
#[derive(Debug, Clone, new, Default)]
pub struct InferenceScan {
    /// Inner (body) model; its facts are unified with the outer node's facts during inference.
    pub body: InferenceModel,
    /// One entry per outer input, in the same order as the body inputs.
    pub input_mapping: Vec<InputMapping>,
    /// One entry per body output, telling which outer output slot(s) (scan and/or last
    /// value) it feeds, and whether it is a hidden-state output.
    pub output_mapping: Vec<OutputMapping<TDim>>,
    /// When set, the scanned-axis dimension of every scanning outer input and output is
    /// unified with `iter_count_fact`.
    pub clean_scan_counts: bool,
    /// Shared factoid for the outer scanned-axis length; only updated when
    /// `clean_scan_counts` is set. Presumably the iteration count — TODO confirm.
    pub iter_count_fact: GenericFactoid<TDim>,
}
16
17impl Op for InferenceScan {
18 fn name(&self) -> StaticName {
19 "Scan".into()
20 }
21
22 fn info(&self) -> TractResult<Vec<String>> {
23 let mut lines = vec![];
24 for (ix, im) in self.input_mapping.iter().enumerate() {
25 lines.push(format!("Model input #{ix}: {im:?}"));
26 }
27 for (ix, om) in self.output_mapping.iter().enumerate() {
28 lines.push(format!("Model output #{ix}: {om:?}"));
29 }
30 Ok(lines)
31 }
32
33 not_a_typed_op!();
34}
35
36impl EvalOp for InferenceScan {
37 fn is_stateless(&self) -> bool {
38 false
39 }
40
41 fn state(
42 &self,
43 session: &mut SessionState,
44 node_id: usize,
45 ) -> TractResult<Option<Box<dyn OpState>>> {
46 self.to_mir_scan()?.state(session, node_id)
47 }
48}
49
50impl InferenceScan {
51 pub(super) fn to_mir_scan(&self) -> TractResult<Box<Scan>> {
52 let typed_model = self.body.clone().into_typed()?;
53 let input_mapping = self
54 .input_mapping
55 .iter()
56 .enumerate()
57 .map(|(ix, im)| {
58 Ok(match im {
59 InputMapping::Scan(info) => InputMapping::Scan(ScanInfo {
60 chunk: typed_model.input_fact(ix)?.shape[info.axis].to_isize()?,
61 ..*info
62 }),
63 other => other.clone(),
64 })
65 })
66 .collect::<TractResult<_>>()?;
67 let output_mapping = self
68 .output_mapping
69 .iter()
70 .enumerate()
71 .map(|(ix, im)| {
72 let scan = if let Some((slot, scan)) = im.scan {
73 Some((
74 slot,
75 ScanInfo {
76 chunk: typed_model.input_fact(ix)?.shape[scan.axis].to_isize()?,
77 ..scan
78 },
79 ))
80 } else {
81 None
82 };
83 Ok(OutputMapping {
84 state: im.state,
85 scan,
86 full_dim_hint: im.full_dim_hint.clone(),
87 last_value_slot: im.last_value_slot,
88 })
89 })
90 .collect::<TractResult<_>>()?;
91 Ok(Box::new(Scan::new(typed_model, input_mapping, output_mapping, 0)?))
92 }
93
94 fn unify_scanning_tensor_fact(
95 outer: &mut InferenceFact,
96 inner: &mut InferenceFact,
97 outer_scan_axis: usize,
98 ) -> TractResult<bool> {
99 let mut changed = outer.datum_type.unify_with_mut(&mut inner.datum_type)?;
100 let rank = outer
101 .shape
102 .rank()
103 .concretize()
104 .or_else(|| inner.shape.rank().concretize())
105 .map(|r| r as usize);
106 if let Some(rank) = rank {
107 if outer.shape.unify_with(&ShapeFactoid::closed(tvec!(GenericFactoid::Any; rank)))? {
108 changed = true;
109 }
110 if inner.shape.unify_with(&ShapeFactoid::closed(tvec!(GenericFactoid::Any; rank)))? {
111 changed = true;
112 }
113 for axis in 0..rank {
114 if axis != outer_scan_axis {
115 let value = outer
116 .shape
117 .dim(axis)
118 .unwrap()
119 .concretize()
120 .or_else(|| inner.shape.dim(axis).unwrap().concretize());
121 if let Some(value) = value {
122 if outer.shape.set_dim(axis, value.clone()) {
123 changed = true
124 }
125 if inner.shape.set_dim(axis, value) {
126 changed = true
127 }
128 }
129 }
130 }
131 }
132 Ok(changed)
133 }
134
135 fn unify_facts(
136 &mut self,
137 inputs: &mut [InferenceFact],
138 outputs: &mut [InferenceFact],
139 ) -> TractResult<bool> {
140 let mut changed = false;
141 let hidden_state_len = self.input_mapping.iter().filter(|m| m.is_state()).count();
142 #[allow(clippy::needless_range_loop)]
143 for state_ix in 0..hidden_state_len {
144 trace!("Unify hidden state #{state_ix}");
145 let inner_model_output_ix = self
146 .output_mapping
147 .iter()
148 .enumerate()
149 .filter(|(_ix, map)| map.state)
150 .nth(state_ix)
151 .unwrap()
152 .0;
153 let mut facts = self.body.outlets_fact_mut(&[
154 self.body.input_outlets()?[state_ix],
155 self.body.output_outlets()?[inner_model_output_ix],
156 ])?;
157 facts.push(&mut inputs[state_ix]);
158 if Factoid::unify_all(
159 &mut facts.iter_mut().map(|f| &mut f.datum_type).collect::<TVec<_>>(),
160 )? {
161 changed = true;
162 }
163 if Factoid::unify_all(&mut facts.iter_mut().map(|f| &mut f.shape).collect::<TVec<_>>())?
164 {
165 changed = true;
166 }
167 }
168 for (slot, i) in self.input_mapping.iter().enumerate() {
169 match i {
170 InputMapping::State => {}
171 InputMapping::Full => {
172 if inputs[slot].unify_with_mut(self.body.input_fact_mut(slot)?)? {
173 changed = true;
174 }
175 }
176 InputMapping::Scan(scan) => {
177 let incoming = &mut inputs[slot];
178 let inner = self.body.input_fact_mut(slot)?;
179 if Self::unify_scanning_tensor_fact(incoming, inner, scan.axis)? {
180 changed = true;
181 };
182 if self.clean_scan_counts {
183 if incoming.shape.ensure_rank_at_least(scan.axis) {
184 changed = true;
185 }
186 let value =
187 self.iter_count_fact.unify(&incoming.shape.dim(scan.axis).unwrap())?;
188 if self.iter_count_fact != value {
189 changed = true;
190 self.iter_count_fact = value.clone();
191 }
192 if incoming.shape.dim(scan.axis).unwrap() != value {
193 changed = true;
194 incoming.shape.set_dim(scan.axis, value.concretize().unwrap());
195 }
196 }
197 }
198 }
199 }
200 for (ix, i) in self.output_mapping.iter().enumerate() {
201 if let Some((slot, scan)) = i.scan {
202 let outgoing = &mut outputs[slot];
203 let inner = self.body.output_fact_mut(ix)?;
204 if Self::unify_scanning_tensor_fact(outgoing, inner, scan.axis)? {
205 changed = true
206 }
207 if self.clean_scan_counts {
208 if outgoing.shape.ensure_rank_at_least(scan.axis) {
209 changed = true;
210 }
211 let value =
212 self.iter_count_fact.unify(&outgoing.shape.dim(scan.axis).unwrap())?;
213 if self.iter_count_fact != value {
214 changed = true;
215 self.iter_count_fact = value.clone();
216 }
217 if outgoing.shape.dim(scan.axis).unwrap() != value {
218 changed = true;
219 outgoing.shape.set_dim(scan.axis, value.concretize().unwrap());
220 }
221 }
222 }
223 if let Some(slot) = i.last_value_slot {
224 if outputs[slot].unify_with(self.body.output_fact_mut(ix)?)? {
225 changed = true;
226 }
227 }
228 }
229 Ok(changed)
230 }
231}
232
233impl InferenceOp for InferenceScan {
234 fn infer_facts(
235 &mut self,
236 inputs: TVec<&InferenceFact>,
237 outputs: TVec<&InferenceFact>,
238 _observed: TVec<&InferenceFact>,
239 ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
240 let body_inputs = self.body.input_outlets()?.len();
241 let body_outputs = self.body.output_outlets()?.len();
242 let expected_op_inputs = self.input_mapping.len();
243 let expected_op_outputs = self
244 .output_mapping
245 .iter()
246 .filter_map(|om| om.last_value_slot)
247 .chain(self.output_mapping.iter().filter_map(|om| om.scan.map(|si| si.0)))
248 .max()
249 .context("No output slot found")?
250 + 1;
251 if inputs.len() != expected_op_inputs {
252 bail!("Scan receives {} inputs, mappings expects {}", inputs.len(), expected_op_inputs)
253 }
254 if body_inputs != self.input_mapping.len() {
255 bail!(
256 "Scan body expect {} inputs, mappings expects {}",
257 body_inputs,
258 self.input_mapping.len()
259 )
260 }
261 if outputs.len() != expected_op_outputs {
262 bail!("Scan has {} outputs, mappings expects {}", outputs.len(), expected_op_outputs);
263 }
264 if body_outputs != self.output_mapping.len() {
265 bail!(
266 "Scan body expect {} outputs, mappings expects {}",
267 body_outputs,
268 self.output_mapping.len()
269 )
270 }
271 let mut inputs: TVec<InferenceFact> = inputs.into_iter().cloned().collect();
272 let mut outputs: TVec<InferenceFact> = outputs.into_iter().cloned().collect();
273 loop {
274 trace!("Unify inner and outer interface");
275 let mut changed = self.unify_facts(&mut inputs, &mut outputs)?;
276 trace!("iters: {:?} changed: {:?}", self.iter_count_fact, changed);
277 for (ix, input) in self.body.input_outlets()?.iter().enumerate() {
278 trace!(" Input inner model: {} {:?} {:?}", ix, input, self.body.input_fact(ix));
279 }
280 for (ix, output) in self.body.output_outlets()?.iter().enumerate() {
281 trace!(" Output inner model: {} {:?} {:?}", ix, output, self.body.output_fact(ix));
282 }
283 trace!("Inner model analyse");
284 if self.body.analyse(false).context("analysing inner model")? {
285 changed = true;
286 }
287 if !changed {
288 break;
289 }
290 trace!("Finished inner model analyse");
291 }
292 Ok((inputs, outputs, tvec!()))
293 }
294
295 fn to_typed(
296 &self,
297 _source: &InferenceModel,
298 node: &InferenceNode,
299 target: &mut TypedModel,
300 mapping: &HashMap<OutletId, OutletId>,
301 ) -> TractResult<TVec<OutletId>> {
302 let inputs = node.inputs.iter().map(|m| mapping[m]).collect::<TVec<_>>();
303 target.wire_node(&*node.name, self.to_mir_scan()? as Box<dyn TypedOp>, &inputs)
304 }
305
306 fn nboutputs(&self) -> TractResult<usize> {
307 Ok(self.output_mapping.iter().filter(|om| !om.invisible()).count())
308 }
309
310 as_op!();
311}