// tract_core/ops/scan/optimized.rs

1use crate::ops::OpStateFreeze;
2
3use super::*;
4use tract_data::internal::*;
5
6#[derive(Debug, Clone, new)]
7pub struct ScanOpParams {
8    pub skip: usize,
9    pub reset_every_turn: bool,
10    pub plan: Arc<TypedSimplePlan<TypedModel>>,
11    pub input_mapping: Vec<InputMapping>,
12    pub output_mapping: Vec<OutputMapping<TDim>>,
13}
14
15#[derive(Debug, Clone, new)]
16pub struct OptScan(Arc<ScanOpParams>);
17
18impl std::ops::Deref for OptScan {
19    type Target = ScanOpParams;
20    fn deref(&self) -> &ScanOpParams {
21        &self.0
22    }
23}
24
25impl OptScan {
26    pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
27        super::iteration_count(&self.input_mapping, inputs)
28    }
29}
30
31impl Op for OptScan {
32    fn name(&self) -> Cow<str> {
33        "Scan".into()
34    }
35
36    fn info(&self) -> TractResult<Vec<String>> {
37        let mut lines = vec![];
38        for (ix, im) in self.input_mapping.iter().enumerate() {
39            lines.push(format!("Model input  #{ix}: {im:?}"));
40        }
41        for (ix, om) in self.output_mapping.iter().enumerate() {
42            lines.push(format!("Model output #{ix}: {om:?}"));
43        }
44        Ok(lines)
45    }
46
47    op_as_typed_op!();
48}
49
50impl EvalOp for OptScan {
51    fn is_stateless(&self) -> bool {
52        false
53    }
54
55    fn state(
56        &self,
57        _session: &mut SessionState,
58        _node_id: usize,
59    ) -> TractResult<Option<Box<dyn OpState>>> {
60        Ok(Some(Box::new(State {
61            position: 0,
62            hidden_state: tvec!(),
63            model_state: TypedSimpleState::new(Arc::clone(&self.plan))?,
64            op: Arc::clone(&self.0),
65        })))
66    }
67}
68
69#[derive(Clone, Debug)]
70pub struct State {
71    op: Arc<ScanOpParams>,
72    position: usize,
73    hidden_state: TVec<TValue>,
74    pub model_state: TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
75}
76
77#[derive(Debug, Clone)]
78struct FrozenState {
79    op: Arc<ScanOpParams>,
80    position: usize,
81    hidden_state: TVec<Tensor>,
82    model_state: TypedFrozenSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
83}
84
85impl OpStateFreeze for State {
86    fn freeze(&self) -> Box<dyn FrozenOpState> {
87        Box::new(FrozenState {
88            op: self.op.clone(),
89            position: self.position,
90            hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tensor()).collect(),
91            model_state: self.model_state.freeze(),
92        })
93    }
94}
95
96impl FrozenOpState for FrozenState {
97    fn unfreeze(&self) -> Box<dyn OpState> {
98        Box::new(State {
99            op: self.op.clone(),
100            position: self.position,
101            hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tvalue()).collect(),
102            model_state: self.model_state.unfreeze(),
103        })
104    }
105}
106
107impl State {
108    pub fn iteration_count(&self, inputs: &TVec<TValue>) -> usize {
109        let (slot, info) = self
110            .op
111            .input_mapping
112            .iter()
113            .enumerate()
114            .find_map(|(ix, it)| it.as_scan().map(|scan| (ix, scan)))
115            .unwrap();
116        inputs[slot].shape()[info.axis].divceil(info.chunk.unsigned_abs())
117    }
118
119    pub(super) fn slice_input(
120        input: &Tensor,
121        axis: usize,
122        chunk_ix: usize,
123        chunk_dim: isize,
124    ) -> TractResult<Tensor> {
125        unsafe {
126            let full_len = input.shape()[axis];
127            let mut shape: TVec<usize> = input.shape().into();
128            shape[axis] = chunk_dim.unsigned_abs();
129            let mut t = Tensor::uninitialized_dt(input.datum_type(), &shape)?;
130            if chunk_dim < 0 {
131                let chunk_dim = (-chunk_dim) as usize;
132                for i in 0..chunk_dim {
133                    if chunk_dim * chunk_ix + i < full_len {
134                        let dst_ix = chunk_dim - i - 1;
135                        let src_ix = full_len - 1 - (chunk_ix * chunk_dim + i);
136                        t.assign_slice_unchecked(dst_ix..=dst_ix, input, src_ix..=src_ix, axis);
137                    }
138                }
139            } else if (chunk_ix + 1) * chunk_dim as usize > full_len {
140                let chunk_dim = chunk_dim as usize;
141                let remain = full_len - chunk_ix * chunk_dim;
142                let mut shape: TVec<usize> = input.shape().into();
143                shape[axis] = chunk_dim;
144                t.assign_slice_unchecked(..remain, input, chunk_ix * chunk_dim.., axis);
145            } else {
146                let start = chunk_dim as usize * chunk_ix;
147                let end = start + chunk_dim as usize;
148                t.assign_slice_unchecked(.., input, start..end, axis);
149            }
150            Ok(t)
151        }
152    }
153
154    pub(super) fn assign_output(
155        output: &mut Tensor,
156        axis: usize,
157        element_value: &Tensor,
158        i: usize,
159        backward: bool,
160    ) {
161        let full_len = output.shape()[axis];
162        let offset = if backward {
163            full_len - 1 - i * element_value.shape()[axis]
164        } else {
165            i * element_value.shape()[axis]
166        };
167        let count = element_value.shape()[axis].min(output.shape()[axis] - offset);
168        unsafe {
169            output.assign_slice_unchecked(offset..offset + count, element_value, ..count, axis)
170        };
171    }
172}
173
174impl OpState for State {
175    fn eval(
176        &mut self,
177        session: &mut SessionState,
178        _op: &dyn Op,
179        inputs: TVec<TValue>,
180    ) -> TractResult<TVec<TValue>> {
181        let iters = self.iteration_count(&inputs);
182
183        let State { op, ref mut hidden_state, ref mut position, ref mut model_state } = self;
184
185        // initialize state at first pass, or when forced
186        if op.reset_every_turn {
187            hidden_state.clear()
188        }
189        if hidden_state.len() == 0 {
190            for (slot, input) in op.input_mapping.iter().enumerate() {
191                if input.is_state() {
192                    hidden_state.push(inputs[slot].clone());
193                }
194            }
195        }
196
197        let mut outputs = tvec!();
198        for (ix, output) in op.output_mapping.iter().enumerate() {
199            if let Some((slot, info)) = output.scan {
200                let fact = op.plan.model().output_fact(ix)?;
201                let mut shape: TVec<usize> =
202                    fact.shape.eval_to_usize(&session.resolved_symbols)?.into_owned();
203                let scanning_dim = output
204                    .full_dim_hint
205                    .as_ref()
206                    .and_then(|d| d.to_usize().ok())
207                    .unwrap_or(shape[info.axis] * iters);
208                shape[info.axis] = scanning_dim;
209                let t = unsafe { Tensor::uninitialized_dt(fact.datum_type, &shape)? };
210                outputs.push((slot, t));
211            }
212            if let Some(slot) = output.last_value_slot {
213                outputs.push((slot, Tensor::default()));
214            }
215        }
216        outputs.sort_by_key(|a| a.0);
217        let mut outputs: TVec<Tensor> = outputs.into_iter().map(|(_slot, v)| v).collect();
218
219        for i in 0..iters {
220            *position += 1;
221            if *position <= op.skip {
222                continue;
223            }
224            hidden_state.reverse();
225
226            let iter_inputs: TVec<TValue> = op
227                .input_mapping
228                .iter()
229                .enumerate()
230                .map(|(slot, m)| {
231                    Ok(match m {
232                        InputMapping::State => Some(hidden_state.pop().unwrap()),
233                        InputMapping::Scan(info) => Some(
234                            Self::slice_input(&inputs[slot], info.axis, i, info.chunk)?
235                                .into_tvalue(),
236                        ),
237                        InputMapping::Full => Some(inputs[slot].clone()),
238                    })
239                })
240                .collect::<TractResult<Vec<_>>>()?
241                .into_iter()
242                .flatten()
243                .collect();
244
245            trace!("iter_inputs #{}: {:?}", i, iter_inputs);
246            let iter_outputs =
247                model_state.run(iter_inputs).with_context(|| "Evaluating inner body")?;
248            trace!("iter_outputs #{}: {:?}", i, iter_outputs);
249
250            for (v, mapping) in iter_outputs.into_iter().zip(&op.output_mapping) {
251                if let Some((slot, info)) = mapping.scan {
252                    Self::assign_output(&mut outputs[slot], info.axis, &v, i, info.chunk < 0);
253                }
254                if i == iters - 1 {
255                    if let Some(slot) = mapping.last_value_slot {
256                        outputs[slot] = v.clone().into_tensor();
257                    }
258                }
259                if mapping.state {
260                    hidden_state.push(v);
261                }
262            }
263        }
264
265        Ok(outputs.into_iter().map(|t| t.into_tvalue()).collect())
266    }
267}
268
269impl TypedOp for OptScan {
270    as_op!();
271
272    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
273        let mut outputs = tvec!();
274        let iters = super::iteration_count(&self.input_mapping, inputs).unwrap();
275        for (ix, output) in self.output_mapping.iter().enumerate() {
276            let fact = self.plan.model().output_fact(ix)?;
277            if let Some(slot) = output.last_value_slot {
278                outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
279            }
280            if let Some((slot, info)) = output.scan {
281                let mut shape = fact.shape.clone();
282                let scanning_dim =
283                    output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
284                shape.set(info.axis, scanning_dim);
285                outputs.push((slot, fact.datum_type.fact(shape)));
286            }
287        }
288        outputs.sort_by_key(|a| a.0);
289        let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
290        Ok(outputs)
291    }
292}