Skip to main content

tract_core/ops/scan/
optimized.rs

1use crate::ops::OpStateFreeze;
2
3use super::*;
4use tract_data::internal::*;
5
/// Parameters of an [`OptScan`], shared (through an `Arc`) between the op and the runtime
/// states it spawns.
#[derive(Debug, Clone, new)]
pub struct ScanOpParams {
    /// Number of leading iterations, counted across turns, for which the body is not run.
    pub skip: usize,
    /// When set, the hidden state is cleared and re-seeded from the op's state inputs at the
    /// start of every turn instead of only on the first one.
    pub reset_every_turn: bool,
    /// Execution plan of the loop body, run once per iteration.
    pub plan: Arc<TypedSimplePlan>,
    /// How each op input is fed to the body: sliced in chunks (`Scan`), passed whole (`Full`),
    /// or used as the initial value of a hidden state (`State`).
    pub input_mapping: Vec<InputMapping>,
    /// How each body output is exposed: as a scanned output, as the value of the last
    /// iteration, and/or as hidden state for the next iteration.
    pub output_mapping: Vec<OutputMapping<TDim>>,
}
14
/// Scan loop op that runs a pre-built execution plan of its body once per chunk of its
/// scanned inputs.
///
/// The parameters sit behind an `Arc`: cloning the op shares them, and runtime states keep a
/// handle on the same parameters.
#[derive(Debug, Clone, new)]
pub struct OptScan(Arc<ScanOpParams>);
17
18impl std::ops::Deref for OptScan {
19    type Target = ScanOpParams;
20    fn deref(&self) -> &ScanOpParams {
21        &self.0
22    }
23}
24
25impl PartialEq for OptScan {
26    fn eq(&self, _other: &Self) -> bool {
27        false
28    }
29}
30impl Eq for OptScan {}
31
32impl OptScan {
33    pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
34        super::iteration_count(&self.input_mapping, inputs)
35    }
36}
37
38impl Op for OptScan {
39    fn name(&self) -> StaticName {
40        "Scan".into()
41    }
42
43    fn info(&self) -> TractResult<Vec<String>> {
44        let mut lines = vec![];
45        for (ix, im) in self.input_mapping.iter().enumerate() {
46            lines.push(format!("Model input  #{ix}: {im:?}"));
47        }
48        for (ix, om) in self.output_mapping.iter().enumerate() {
49            lines.push(format!("Model output #{ix}: {om:?}"));
50        }
51        Ok(lines)
52    }
53
54    op_as_typed_op!();
55}
56
57impl EvalOp for OptScan {
58    fn is_stateless(&self) -> bool {
59        false
60    }
61
62    fn state(
63        &self,
64        _session: &TurnState,
65        _node_id: usize,
66    ) -> TractResult<Option<Box<dyn OpState>>> {
67        Ok(Some(Box::new(State {
68            position: 0,
69            hidden_state: tvec!(),
70            model_state: self.plan.spawn()?,
71            op: Arc::clone(&self.0),
72        })))
73    }
74}
75
/// Runtime state of an [`OptScan`], created by its `EvalOp::state`.
#[derive(Clone, Debug)]
pub struct State {
    /// Parameters shared with the op.
    op: Arc<ScanOpParams>,
    /// Number of iterations reached so far, across turns, including skipped ones; compared
    /// against `op.skip`.
    position: usize,
    /// Hidden state values carried from one iteration (and one turn) to the next, in
    /// input-mapping order.
    hidden_state: TVec<TValue>,
    /// Runtime state of the body plan, reused for every iteration.
    pub model_state: TypedSimpleState,
}
83
/// Frozen snapshot of a [`State`], produced by `OpStateFreeze`.
///
/// Identical to `State` except that the hidden state values are stored as owned `Tensor`s
/// rather than `TValue`s, and the body state is in its frozen form.
#[derive(Debug, Clone)]
struct FrozenState {
    op: Arc<ScanOpParams>,
    position: usize,
    hidden_state: TVec<Tensor>,
    model_state: TypedFrozenSimpleState,
}
91
92impl OpStateFreeze for State {
93    fn freeze(&self) -> Box<dyn FrozenOpState> {
94        Box::new(FrozenState {
95            op: self.op.clone(),
96            position: self.position,
97            hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tensor()).collect(),
98            model_state: self.model_state.freeze(),
99        })
100    }
101
102    fn freeze_into(self: Box<Self>) -> Box<dyn FrozenOpState> {
103        Box::new(FrozenState {
104            op: self.op,
105            position: self.position,
106            hidden_state: self.hidden_state.into_iter().map(|t| t.into_tensor()).collect(),
107            model_state: self.model_state.freeze_into(),
108        })
109    }
110}
111
112impl FrozenOpState for FrozenState {
113    fn unfreeze(&self) -> Box<dyn OpState> {
114        Box::new(State {
115            op: self.op.clone(),
116            position: self.position,
117            hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tvalue()).collect(),
118            model_state: self.model_state.unfreeze(),
119        })
120    }
121}
122
123impl State {
124    pub fn iteration_count(&self, inputs: &TVec<TValue>) -> usize {
125        let (slot, info) = self
126            .op
127            .input_mapping
128            .iter()
129            .enumerate()
130            .find_map(|(ix, it)| it.as_scan().map(|scan| (ix, scan)))
131            .unwrap();
132        inputs[slot].shape()[info.axis].divceil(info.chunk.unsigned_abs())
133    }
134
135    pub(super) fn slice_input(
136        input: &Tensor,
137        axis: usize,
138        chunk_ix: usize,
139        chunk_dim: isize,
140    ) -> TractResult<Tensor> {
141        unsafe {
142            let full_len = input.shape()[axis];
143            let mut shape: TVec<usize> = input.shape().into();
144            shape[axis] = chunk_dim.unsigned_abs();
145            let mut t = Tensor::uninitialized_dt(input.datum_type(), &shape)?;
146            if chunk_dim < 0 {
147                let chunk_dim = (-chunk_dim) as usize;
148                for i in 0..chunk_dim {
149                    if chunk_dim * chunk_ix + i < full_len {
150                        let dst_ix = chunk_dim - i - 1;
151                        let src_ix = full_len - 1 - (chunk_ix * chunk_dim + i);
152                        t.assign_slice_unchecked(dst_ix..=dst_ix, input, src_ix..=src_ix, axis);
153                    }
154                }
155            } else if (chunk_ix + 1) * chunk_dim as usize > full_len {
156                let chunk_dim = chunk_dim as usize;
157                let remain = full_len - chunk_ix * chunk_dim;
158                let mut shape: TVec<usize> = input.shape().into();
159                shape[axis] = chunk_dim;
160                t.assign_slice_unchecked(..remain, input, chunk_ix * chunk_dim.., axis);
161            } else {
162                let start = chunk_dim as usize * chunk_ix;
163                let end = start + chunk_dim as usize;
164                t.assign_slice_unchecked(.., input, start..end, axis);
165            }
166            Ok(t)
167        }
168    }
169
170    pub(super) fn assign_output(
171        output: &mut Tensor,
172        axis: usize,
173        element_value: &Tensor,
174        i: usize,
175        backward: bool,
176    ) {
177        let full_len = output.shape()[axis];
178        let offset = if backward {
179            full_len - 1 - i * element_value.shape()[axis]
180        } else {
181            i * element_value.shape()[axis]
182        };
183        let count = element_value.shape()[axis].min(output.shape()[axis] - offset);
184        unsafe {
185            output.assign_slice_unchecked(offset..offset + count, element_value, ..count, axis)
186        };
187    }
188}
189
190impl OpState for State {
191    fn eval(
192        &mut self,
193        session: &mut TurnState,
194        _op: &dyn Op,
195        inputs: TVec<TValue>,
196    ) -> TractResult<TVec<TValue>> {
197        let iters = self.iteration_count(&inputs);
198
199        let &mut State { ref op, ref mut hidden_state, ref mut position, ref mut model_state } =
200            self;
201
202        // initialize state at first pass, or when forced
203        if op.reset_every_turn {
204            hidden_state.clear()
205        }
206        if hidden_state.len() == 0 {
207            for (slot, input) in op.input_mapping.iter().enumerate() {
208                if input.is_state() {
209                    hidden_state.push(inputs[slot].clone());
210                }
211            }
212        }
213
214        let mut outputs = tvec!();
215        for (ix, output) in op.output_mapping.iter().enumerate() {
216            if let Some((slot, info)) = output.scan {
217                let fact = op.plan.model().output_fact(ix)?;
218                let mut shape: TVec<usize> =
219                    fact.shape.eval_to_usize(&session.resolved_symbols)?.into_owned();
220                let scanning_dim = output
221                    .full_dim_hint
222                    .as_ref()
223                    .and_then(|d| d.to_usize().ok())
224                    .unwrap_or(shape[info.axis] * iters);
225                shape[info.axis] = scanning_dim;
226                let t = unsafe { Tensor::uninitialized_dt(fact.datum_type, &shape)? };
227                outputs.push((slot, t));
228            }
229            if let Some(slot) = output.last_value_slot {
230                outputs.push((slot, Tensor::default()));
231            }
232        }
233        outputs.sort_by_key(|a| a.0);
234        let mut outputs: TVec<Tensor> = outputs.into_iter().map(|(_slot, v)| v).collect();
235
236        for i in 0..iters {
237            *position += 1;
238            if *position <= op.skip {
239                continue;
240            }
241            hidden_state.reverse();
242
243            let iter_inputs: TVec<TValue> = op
244                .input_mapping
245                .iter()
246                .enumerate()
247                .map(|(slot, m)| {
248                    Ok(match m {
249                        InputMapping::State => Some(hidden_state.pop().unwrap()),
250                        InputMapping::Scan(info) => Some(
251                            Self::slice_input(&inputs[slot], info.axis, i, info.chunk)?
252                                .into_tvalue(),
253                        ),
254                        InputMapping::Full => Some(inputs[slot].clone()),
255                    })
256                })
257                .collect::<TractResult<Vec<_>>>()?
258                .into_iter()
259                .flatten()
260                .collect();
261
262            trace!("iter_inputs #{i}: {iter_inputs:?}");
263            let iter_outputs =
264                model_state.run(iter_inputs).with_context(|| "Evaluating inner body")?;
265            trace!("iter_outputs #{i}: {iter_outputs:?}");
266
267            for (v, mapping) in iter_outputs.into_iter().zip(&op.output_mapping) {
268                if let Some((slot, info)) = mapping.scan {
269                    Self::assign_output(&mut outputs[slot], info.axis, &v, i, info.chunk < 0);
270                }
271                if i == iters - 1
272                    && let Some(slot) = mapping.last_value_slot
273                {
274                    outputs[slot] = v.clone().into_tensor();
275                }
276                if mapping.state {
277                    hidden_state.push(v);
278                }
279            }
280        }
281
282        Ok(outputs.into_iter().map(|t| t.into_tvalue()).collect())
283    }
284}
285
286impl TypedOp for OptScan {
287    as_op!();
288
289    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
290        let mut outputs = tvec!();
291        let iters = super::iteration_count(&self.input_mapping, inputs).unwrap();
292        for (ix, output) in self.output_mapping.iter().enumerate() {
293            let fact = self.plan.model().output_fact(ix)?;
294            if let Some(slot) = output.last_value_slot {
295                outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
296            }
297            if let Some((slot, info)) = output.scan {
298                let mut shape = fact.shape.clone();
299                let scanning_dim =
300                    output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
301                shape.set(info.axis, scanning_dim);
302                outputs.push((slot, fact.datum_type.fact(shape)));
303            }
304        }
305        outputs.sort_by_key(|a| a.0);
306        let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
307        Ok(outputs)
308    }
309
310    fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(StaticName, TDim)> {
311        vec![(
312            "loop".into(),
313            super::iteration_count(&self.input_mapping, inputs).unwrap_or_else(|| 1.to_dim()),
314        )]
315    }
316}