1use crate::ops::OpStateFreeze;
2
3use super::*;
4use tract_data::internal::*;
5
/// Parameters shared by an [`OptScan`] operator and the run states spawned from it.
#[derive(Debug, Clone, new)]
pub struct ScanOpParams {
    /// Number of leading iterations that are counted but not evaluated: the state's position
    /// counter has to exceed this value before the body is run.
    pub skip: usize,
    /// When true, the carried hidden state is cleared at the start of every `eval` call.
    pub reset_every_turn: bool,
    /// Plan of the loop body; each operator state spawns its own run state from it.
    pub plan: Arc<TypedSimplePlan>,
    /// How each outer input is handed to the body, one entry per outer input slot.
    pub input_mapping: Vec<InputMapping>,
    /// How each body output is routed to the outer outputs and/or the hidden state,
    /// one entry per body output.
    pub output_mapping: Vec<OutputMapping<TDim>>,
}
14
/// Scan operator holding its [`ScanOpParams`] behind an `Arc`, so that cloning the operator
/// or spawning a state shares the parameters instead of copying them.
#[derive(Debug, Clone, new)]
pub struct OptScan(Arc<ScanOpParams>);
17
18impl std::ops::Deref for OptScan {
19 type Target = ScanOpParams;
20 fn deref(&self) -> &ScanOpParams {
21 &self.0
22 }
23}
24
25impl PartialEq for OptScan {
26 fn eq(&self, _other: &Self) -> bool {
27 false
28 }
29}
30impl Eq for OptScan {}
31
32impl OptScan {
33 pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
34 super::iteration_count(&self.input_mapping, inputs)
35 }
36}
37
38impl Op for OptScan {
39 fn name(&self) -> StaticName {
40 "Scan".into()
41 }
42
43 fn info(&self) -> TractResult<Vec<String>> {
44 let mut lines = vec![];
45 for (ix, im) in self.input_mapping.iter().enumerate() {
46 lines.push(format!("Model input #{ix}: {im:?}"));
47 }
48 for (ix, om) in self.output_mapping.iter().enumerate() {
49 lines.push(format!("Model output #{ix}: {om:?}"));
50 }
51 Ok(lines)
52 }
53
54 op_as_typed_op!();
55}
56
57impl EvalOp for OptScan {
58 fn is_stateless(&self) -> bool {
59 false
60 }
61
62 fn state(
63 &self,
64 _session: &TurnState,
65 _node_id: usize,
66 ) -> TractResult<Option<Box<dyn OpState>>> {
67 Ok(Some(Box::new(State {
68 position: 0,
69 hidden_state: tvec!(),
70 model_state: self.plan.spawn()?,
71 op: Arc::clone(&self.0),
72 })))
73 }
74}
75
/// Mutable run state of a scan node.
#[derive(Clone, Debug)]
pub struct State {
    /// Parameters shared with the operator this state was spawned from.
    op: Arc<ScanOpParams>,
    /// Number of iterations started so far; incremented once per iteration and compared
    /// against `skip` to decide whether the body runs.
    position: usize,
    /// Values carried from one iteration to the next; seeded from the state-mapped inputs
    /// whenever it is empty at the start of an evaluation.
    hidden_state: TVec<TValue>,
    /// Run state of the body plan, used to execute each iteration.
    pub model_state: TypedSimpleState,
}
83
/// Frozen counterpart of [`State`]: same fields, with the hidden state held as plain
/// `Tensor`s and the body run state in its frozen form.
#[derive(Debug, Clone)]
struct FrozenState {
    /// Parameters shared with the originating operator.
    op: Arc<ScanOpParams>,
    /// Iteration counter captured at freeze time.
    position: usize,
    /// Hidden state values converted to tensors.
    hidden_state: TVec<Tensor>,
    /// Frozen run state of the body plan.
    model_state: TypedFrozenSimpleState,
}
91
92impl OpStateFreeze for State {
93 fn freeze(&self) -> Box<dyn FrozenOpState> {
94 Box::new(FrozenState {
95 op: self.op.clone(),
96 position: self.position,
97 hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tensor()).collect(),
98 model_state: self.model_state.freeze(),
99 })
100 }
101
102 fn freeze_into(self: Box<Self>) -> Box<dyn FrozenOpState> {
103 Box::new(FrozenState {
104 op: self.op,
105 position: self.position,
106 hidden_state: self.hidden_state.into_iter().map(|t| t.into_tensor()).collect(),
107 model_state: self.model_state.freeze_into(),
108 })
109 }
110}
111
112impl FrozenOpState for FrozenState {
113 fn unfreeze(&self) -> Box<dyn OpState> {
114 Box::new(State {
115 op: self.op.clone(),
116 position: self.position,
117 hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tvalue()).collect(),
118 model_state: self.model_state.unfreeze(),
119 })
120 }
121}
122
123impl State {
124 pub fn iteration_count(&self, inputs: &TVec<TValue>) -> usize {
125 let (slot, info) = self
126 .op
127 .input_mapping
128 .iter()
129 .enumerate()
130 .find_map(|(ix, it)| it.as_scan().map(|scan| (ix, scan)))
131 .unwrap();
132 inputs[slot].shape()[info.axis].divceil(info.chunk.unsigned_abs())
133 }
134
135 pub(super) fn slice_input(
136 input: &Tensor,
137 axis: usize,
138 chunk_ix: usize,
139 chunk_dim: isize,
140 ) -> TractResult<Tensor> {
141 unsafe {
142 let full_len = input.shape()[axis];
143 let mut shape: TVec<usize> = input.shape().into();
144 shape[axis] = chunk_dim.unsigned_abs();
145 let mut t = Tensor::uninitialized_dt(input.datum_type(), &shape)?;
146 if chunk_dim < 0 {
147 let chunk_dim = (-chunk_dim) as usize;
148 for i in 0..chunk_dim {
149 if chunk_dim * chunk_ix + i < full_len {
150 let dst_ix = chunk_dim - i - 1;
151 let src_ix = full_len - 1 - (chunk_ix * chunk_dim + i);
152 t.assign_slice_unchecked(dst_ix..=dst_ix, input, src_ix..=src_ix, axis);
153 }
154 }
155 } else if (chunk_ix + 1) * chunk_dim as usize > full_len {
156 let chunk_dim = chunk_dim as usize;
157 let remain = full_len - chunk_ix * chunk_dim;
158 let mut shape: TVec<usize> = input.shape().into();
159 shape[axis] = chunk_dim;
160 t.assign_slice_unchecked(..remain, input, chunk_ix * chunk_dim.., axis);
161 } else {
162 let start = chunk_dim as usize * chunk_ix;
163 let end = start + chunk_dim as usize;
164 t.assign_slice_unchecked(.., input, start..end, axis);
165 }
166 Ok(t)
167 }
168 }
169
170 pub(super) fn assign_output(
171 output: &mut Tensor,
172 axis: usize,
173 element_value: &Tensor,
174 i: usize,
175 backward: bool,
176 ) {
177 let full_len = output.shape()[axis];
178 let offset = if backward {
179 full_len - 1 - i * element_value.shape()[axis]
180 } else {
181 i * element_value.shape()[axis]
182 };
183 let count = element_value.shape()[axis].min(output.shape()[axis] - offset);
184 unsafe {
185 output.assign_slice_unchecked(offset..offset + count, element_value, ..count, axis)
186 };
187 }
188}
189
190impl OpState for State {
191 fn eval(
192 &mut self,
193 session: &mut TurnState,
194 _op: &dyn Op,
195 inputs: TVec<TValue>,
196 ) -> TractResult<TVec<TValue>> {
197 let iters = self.iteration_count(&inputs);
198
199 let &mut State { ref op, ref mut hidden_state, ref mut position, ref mut model_state } =
200 self;
201
202 if op.reset_every_turn {
204 hidden_state.clear()
205 }
206 if hidden_state.len() == 0 {
207 for (slot, input) in op.input_mapping.iter().enumerate() {
208 if input.is_state() {
209 hidden_state.push(inputs[slot].clone());
210 }
211 }
212 }
213
214 let mut outputs = tvec!();
215 for (ix, output) in op.output_mapping.iter().enumerate() {
216 if let Some((slot, info)) = output.scan {
217 let fact = op.plan.model().output_fact(ix)?;
218 let mut shape: TVec<usize> =
219 fact.shape.eval_to_usize(&session.resolved_symbols)?.into_owned();
220 let scanning_dim = output
221 .full_dim_hint
222 .as_ref()
223 .and_then(|d| d.to_usize().ok())
224 .unwrap_or(shape[info.axis] * iters);
225 shape[info.axis] = scanning_dim;
226 let t = unsafe { Tensor::uninitialized_dt(fact.datum_type, &shape)? };
227 outputs.push((slot, t));
228 }
229 if let Some(slot) = output.last_value_slot {
230 outputs.push((slot, Tensor::default()));
231 }
232 }
233 outputs.sort_by_key(|a| a.0);
234 let mut outputs: TVec<Tensor> = outputs.into_iter().map(|(_slot, v)| v).collect();
235
236 for i in 0..iters {
237 *position += 1;
238 if *position <= op.skip {
239 continue;
240 }
241 hidden_state.reverse();
242
243 let iter_inputs: TVec<TValue> = op
244 .input_mapping
245 .iter()
246 .enumerate()
247 .map(|(slot, m)| {
248 Ok(match m {
249 InputMapping::State => Some(hidden_state.pop().unwrap()),
250 InputMapping::Scan(info) => Some(
251 Self::slice_input(&inputs[slot], info.axis, i, info.chunk)?
252 .into_tvalue(),
253 ),
254 InputMapping::Full => Some(inputs[slot].clone()),
255 })
256 })
257 .collect::<TractResult<Vec<_>>>()?
258 .into_iter()
259 .flatten()
260 .collect();
261
262 trace!("iter_inputs #{i}: {iter_inputs:?}");
263 let iter_outputs =
264 model_state.run(iter_inputs).with_context(|| "Evaluating inner body")?;
265 trace!("iter_outputs #{i}: {iter_outputs:?}");
266
267 for (v, mapping) in iter_outputs.into_iter().zip(&op.output_mapping) {
268 if let Some((slot, info)) = mapping.scan {
269 Self::assign_output(&mut outputs[slot], info.axis, &v, i, info.chunk < 0);
270 }
271 if i == iters - 1
272 && let Some(slot) = mapping.last_value_slot
273 {
274 outputs[slot] = v.clone().into_tensor();
275 }
276 if mapping.state {
277 hidden_state.push(v);
278 }
279 }
280 }
281
282 Ok(outputs.into_iter().map(|t| t.into_tvalue()).collect())
283 }
284}
285
286impl TypedOp for OptScan {
287 as_op!();
288
289 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
290 let mut outputs = tvec!();
291 let iters = super::iteration_count(&self.input_mapping, inputs).unwrap();
292 for (ix, output) in self.output_mapping.iter().enumerate() {
293 let fact = self.plan.model().output_fact(ix)?;
294 if let Some(slot) = output.last_value_slot {
295 outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
296 }
297 if let Some((slot, info)) = output.scan {
298 let mut shape = fact.shape.clone();
299 let scanning_dim =
300 output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
301 shape.set(info.axis, scanning_dim);
302 outputs.push((slot, fact.datum_type.fact(shape)));
303 }
304 }
305 outputs.sort_by_key(|a| a.0);
306 let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
307 Ok(outputs)
308 }
309
310 fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(StaticName, TDim)> {
311 vec![(
312 "loop".into(),
313 super::iteration_count(&self.input_mapping, inputs).unwrap_or_else(|| 1.to_dim()),
314 )]
315 }
316}