1use crate::ops::OpStateFreeze;
2
3use super::*;
4use tract_data::internal::*;
5
/// Parameters shared between an [`OptScan`] operator and the evaluation states it creates.
#[derive(Debug, Clone, new)]
pub struct ScanOpParams {
    /// Iterations whose running position (kept in the evaluation state and incremented once per
    /// iteration) is less than or equal to this value are skipped instead of running the body.
    pub skip: usize,
    /// When true, the hidden state is cleared at the start of every evaluation, so it gets
    /// re-seeded from the state inputs each time.
    pub reset_every_turn: bool,
    /// Execution plan of the inner (body) model that is run once per iteration.
    pub plan: Arc<TypedSimplePlan<TypedModel>>,
    /// How each outer input is fed to the body; indexed by outer input slot.
    pub input_mapping: Vec<InputMapping>,
    /// How each body output is routed; indexed by body output.
    pub output_mapping: Vec<OutputMapping<TDim>>,
}
14
/// Scan operator wrapping its parameters in an `Arc`, so that the evaluation states it creates
/// can share them without copying.
#[derive(Debug, Clone, new)]
pub struct OptScan(Arc<ScanOpParams>);
17
18impl std::ops::Deref for OptScan {
19 type Target = ScanOpParams;
20 fn deref(&self) -> &ScanOpParams {
21 &self.0
22 }
23}
24
25impl OptScan {
26 pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
27 super::iteration_count(&self.input_mapping, inputs)
28 }
29}
30
31impl Op for OptScan {
32 fn name(&self) -> StaticName {
33 "Scan".into()
34 }
35
36 fn info(&self) -> TractResult<Vec<String>> {
37 let mut lines = vec![];
38 for (ix, im) in self.input_mapping.iter().enumerate() {
39 lines.push(format!("Model input #{ix}: {im:?}"));
40 }
41 for (ix, om) in self.output_mapping.iter().enumerate() {
42 lines.push(format!("Model output #{ix}: {om:?}"));
43 }
44 Ok(lines)
45 }
46
47 op_as_typed_op!();
48}
49
50impl EvalOp for OptScan {
51 fn is_stateless(&self) -> bool {
52 false
53 }
54
55 fn state(
56 &self,
57 _session: &mut SessionState,
58 _node_id: usize,
59 ) -> TractResult<Option<Box<dyn OpState>>> {
60 Ok(Some(Box::new(State {
61 position: 0,
62 hidden_state: tvec!(),
63 model_state: TypedSimpleState::new(Arc::clone(&self.plan))?,
64 op: Arc::clone(&self.0),
65 })))
66 }
67}
68
/// Mutable evaluation state of an [`OptScan`] operator.
#[derive(Clone, Debug)]
pub struct State {
    /// Shared operator parameters.
    op: Arc<ScanOpParams>,
    /// Number of iterations started so far; incremented once per iteration and compared
    /// against the op's `skip`.
    position: usize,
    /// Values carried from one iteration to the next; seeded from the state inputs when empty.
    hidden_state: TVec<TValue>,
    /// Execution state of the inner (body) plan.
    pub model_state: TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
}
76
/// Frozen counterpart of [`State`]: same fields, with the hidden state held as plain tensors
/// and the inner plan state held in its frozen form.
#[derive(Debug, Clone)]
struct FrozenState {
    /// Shared operator parameters.
    op: Arc<ScanOpParams>,
    /// Iteration position captured at freeze time.
    position: usize,
    /// Hidden state values captured at freeze time.
    hidden_state: TVec<Tensor>,
    /// Frozen execution state of the inner (body) plan.
    model_state: TypedFrozenSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
}
84
85impl OpStateFreeze for State {
86 fn freeze(&self) -> Box<dyn FrozenOpState> {
87 Box::new(FrozenState {
88 op: self.op.clone(),
89 position: self.position,
90 hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tensor()).collect(),
91 model_state: self.model_state.freeze(),
92 })
93 }
94}
95
96impl FrozenOpState for FrozenState {
97 fn unfreeze(&self) -> Box<dyn OpState> {
98 Box::new(State {
99 op: self.op.clone(),
100 position: self.position,
101 hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tvalue()).collect(),
102 model_state: self.model_state.unfreeze(),
103 })
104 }
105}
106
107impl State {
108 pub fn iteration_count(&self, inputs: &TVec<TValue>) -> usize {
109 let (slot, info) = self
110 .op
111 .input_mapping
112 .iter()
113 .enumerate()
114 .find_map(|(ix, it)| it.as_scan().map(|scan| (ix, scan)))
115 .unwrap();
116 inputs[slot].shape()[info.axis].divceil(info.chunk.unsigned_abs())
117 }
118
119 pub(super) fn slice_input(
120 input: &Tensor,
121 axis: usize,
122 chunk_ix: usize,
123 chunk_dim: isize,
124 ) -> TractResult<Tensor> {
125 unsafe {
126 let full_len = input.shape()[axis];
127 let mut shape: TVec<usize> = input.shape().into();
128 shape[axis] = chunk_dim.unsigned_abs();
129 let mut t = Tensor::uninitialized_dt(input.datum_type(), &shape)?;
130 if chunk_dim < 0 {
131 let chunk_dim = (-chunk_dim) as usize;
132 for i in 0..chunk_dim {
133 if chunk_dim * chunk_ix + i < full_len {
134 let dst_ix = chunk_dim - i - 1;
135 let src_ix = full_len - 1 - (chunk_ix * chunk_dim + i);
136 t.assign_slice_unchecked(dst_ix..=dst_ix, input, src_ix..=src_ix, axis);
137 }
138 }
139 } else if (chunk_ix + 1) * chunk_dim as usize > full_len {
140 let chunk_dim = chunk_dim as usize;
141 let remain = full_len - chunk_ix * chunk_dim;
142 let mut shape: TVec<usize> = input.shape().into();
143 shape[axis] = chunk_dim;
144 t.assign_slice_unchecked(..remain, input, chunk_ix * chunk_dim.., axis);
145 } else {
146 let start = chunk_dim as usize * chunk_ix;
147 let end = start + chunk_dim as usize;
148 t.assign_slice_unchecked(.., input, start..end, axis);
149 }
150 Ok(t)
151 }
152 }
153
154 pub(super) fn assign_output(
155 output: &mut Tensor,
156 axis: usize,
157 element_value: &Tensor,
158 i: usize,
159 backward: bool,
160 ) {
161 let full_len = output.shape()[axis];
162 let offset = if backward {
163 full_len - 1 - i * element_value.shape()[axis]
164 } else {
165 i * element_value.shape()[axis]
166 };
167 let count = element_value.shape()[axis].min(output.shape()[axis] - offset);
168 unsafe {
169 output.assign_slice_unchecked(offset..offset + count, element_value, ..count, axis)
170 };
171 }
172}
173
174impl OpState for State {
175 fn eval(
176 &mut self,
177 session: &mut SessionState,
178 _op: &dyn Op,
179 inputs: TVec<TValue>,
180 ) -> TractResult<TVec<TValue>> {
181 let iters = self.iteration_count(&inputs);
182
183 let &mut State { ref op, ref mut hidden_state, ref mut position, ref mut model_state } =
184 self;
185
186 if op.reset_every_turn {
188 hidden_state.clear()
189 }
190 if hidden_state.len() == 0 {
191 for (slot, input) in op.input_mapping.iter().enumerate() {
192 if input.is_state() {
193 hidden_state.push(inputs[slot].clone());
194 }
195 }
196 }
197
198 let mut outputs = tvec!();
199 for (ix, output) in op.output_mapping.iter().enumerate() {
200 if let Some((slot, info)) = output.scan {
201 let fact = op.plan.model().output_fact(ix)?;
202 let mut shape: TVec<usize> =
203 fact.shape.eval_to_usize(&session.resolved_symbols)?.into_owned();
204 let scanning_dim = output
205 .full_dim_hint
206 .as_ref()
207 .and_then(|d| d.to_usize().ok())
208 .unwrap_or(shape[info.axis] * iters);
209 shape[info.axis] = scanning_dim;
210 let t = unsafe { Tensor::uninitialized_dt(fact.datum_type, &shape)? };
211 outputs.push((slot, t));
212 }
213 if let Some(slot) = output.last_value_slot {
214 outputs.push((slot, Tensor::default()));
215 }
216 }
217 outputs.sort_by_key(|a| a.0);
218 let mut outputs: TVec<Tensor> = outputs.into_iter().map(|(_slot, v)| v).collect();
219
220 for i in 0..iters {
221 *position += 1;
222 if *position <= op.skip {
223 continue;
224 }
225 hidden_state.reverse();
226
227 let iter_inputs: TVec<TValue> = op
228 .input_mapping
229 .iter()
230 .enumerate()
231 .map(|(slot, m)| {
232 Ok(match m {
233 InputMapping::State => Some(hidden_state.pop().unwrap()),
234 InputMapping::Scan(info) => Some(
235 Self::slice_input(&inputs[slot], info.axis, i, info.chunk)?
236 .into_tvalue(),
237 ),
238 InputMapping::Full => Some(inputs[slot].clone()),
239 })
240 })
241 .collect::<TractResult<Vec<_>>>()?
242 .into_iter()
243 .flatten()
244 .collect();
245
246 trace!("iter_inputs #{i}: {iter_inputs:?}");
247 let iter_outputs =
248 model_state.run(iter_inputs).with_context(|| "Evaluating inner body")?;
249 trace!("iter_outputs #{i}: {iter_outputs:?}");
250
251 for (v, mapping) in iter_outputs.into_iter().zip(&op.output_mapping) {
252 if let Some((slot, info)) = mapping.scan {
253 Self::assign_output(&mut outputs[slot], info.axis, &v, i, info.chunk < 0);
254 }
255 if i == iters - 1 {
256 if let Some(slot) = mapping.last_value_slot {
257 outputs[slot] = v.clone().into_tensor();
258 }
259 }
260 if mapping.state {
261 hidden_state.push(v);
262 }
263 }
264 }
265
266 Ok(outputs.into_iter().map(|t| t.into_tvalue()).collect())
267 }
268}
269
270impl TypedOp for OptScan {
271 as_op!();
272
273 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
274 let mut outputs = tvec!();
275 let iters = super::iteration_count(&self.input_mapping, inputs).unwrap();
276 for (ix, output) in self.output_mapping.iter().enumerate() {
277 let fact = self.plan.model().output_fact(ix)?;
278 if let Some(slot) = output.last_value_slot {
279 outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
280 }
281 if let Some((slot, info)) = output.scan {
282 let mut shape = fact.shape.clone();
283 let scanning_dim =
284 output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
285 shape.set(info.axis, scanning_dim);
286 outputs.push((slot, fact.datum_type.fact(shape)));
287 }
288 }
289 outputs.sort_by_key(|a| a.0);
290 let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
291 Ok(outputs)
292 }
293
294 fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(StaticName, TDim)> {
295 vec![(
296 "loop".into(),
297 super::iteration_count(&self.input_mapping, inputs).unwrap_or_else(|| 1.to_dim()),
298 )]
299 }
300}