1use crate::ops::OpStateFreeze;
2
3use super::*;
4use tract_data::internal::*;
5
6#[derive(Debug, Clone, new)]
7pub struct ScanOpParams {
8 pub skip: usize,
9 pub reset_every_turn: bool,
10 pub plan: Arc<TypedSimplePlan<TypedModel>>,
11 pub input_mapping: Vec<InputMapping>,
12 pub output_mapping: Vec<OutputMapping<TDim>>,
13}
14
15#[derive(Debug, Clone, new)]
16pub struct OptScan(Arc<ScanOpParams>);
17
18impl std::ops::Deref for OptScan {
19 type Target = ScanOpParams;
20 fn deref(&self) -> &ScanOpParams {
21 &self.0
22 }
23}
24
25impl OptScan {
26 pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
27 super::iteration_count(&self.input_mapping, inputs)
28 }
29}
30
31impl Op for OptScan {
32 fn name(&self) -> Cow<str> {
33 "Scan".into()
34 }
35
36 fn info(&self) -> TractResult<Vec<String>> {
37 let mut lines = vec![];
38 for (ix, im) in self.input_mapping.iter().enumerate() {
39 lines.push(format!("Model input #{ix}: {im:?}"));
40 }
41 for (ix, om) in self.output_mapping.iter().enumerate() {
42 lines.push(format!("Model output #{ix}: {om:?}"));
43 }
44 Ok(lines)
45 }
46
47 op_as_typed_op!();
48}
49
50impl EvalOp for OptScan {
51 fn is_stateless(&self) -> bool {
52 false
53 }
54
55 fn state(
56 &self,
57 _session: &mut SessionState,
58 _node_id: usize,
59 ) -> TractResult<Option<Box<dyn OpState>>> {
60 Ok(Some(Box::new(State {
61 position: 0,
62 hidden_state: tvec!(),
63 model_state: TypedSimpleState::new(Arc::clone(&self.plan))?,
64 op: Arc::clone(&self.0),
65 })))
66 }
67}
68
69#[derive(Clone, Debug)]
70pub struct State {
71 op: Arc<ScanOpParams>,
72 position: usize,
73 hidden_state: TVec<TValue>,
74 pub model_state: TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
75}
76
77#[derive(Debug, Clone)]
78struct FrozenState {
79 op: Arc<ScanOpParams>,
80 position: usize,
81 hidden_state: TVec<Tensor>,
82 model_state: TypedFrozenSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
83}
84
85impl OpStateFreeze for State {
86 fn freeze(&self) -> Box<dyn FrozenOpState> {
87 Box::new(FrozenState {
88 op: self.op.clone(),
89 position: self.position,
90 hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tensor()).collect(),
91 model_state: self.model_state.freeze(),
92 })
93 }
94}
95
96impl FrozenOpState for FrozenState {
97 fn unfreeze(&self) -> Box<dyn OpState> {
98 Box::new(State {
99 op: self.op.clone(),
100 position: self.position,
101 hidden_state: self.hidden_state.iter().map(|t| t.clone().into_tvalue()).collect(),
102 model_state: self.model_state.unfreeze(),
103 })
104 }
105}
106
107impl State {
108 pub fn iteration_count(&self, inputs: &TVec<TValue>) -> usize {
109 let (slot, info) = self
110 .op
111 .input_mapping
112 .iter()
113 .enumerate()
114 .find_map(|(ix, it)| it.as_scan().map(|scan| (ix, scan)))
115 .unwrap();
116 inputs[slot].shape()[info.axis].divceil(info.chunk.unsigned_abs())
117 }
118
119 pub(super) fn slice_input(
120 input: &Tensor,
121 axis: usize,
122 chunk_ix: usize,
123 chunk_dim: isize,
124 ) -> TractResult<Tensor> {
125 unsafe {
126 let full_len = input.shape()[axis];
127 let mut shape: TVec<usize> = input.shape().into();
128 shape[axis] = chunk_dim.unsigned_abs();
129 let mut t = Tensor::uninitialized_dt(input.datum_type(), &shape)?;
130 if chunk_dim < 0 {
131 let chunk_dim = (-chunk_dim) as usize;
132 for i in 0..chunk_dim {
133 if chunk_dim * chunk_ix + i < full_len {
134 let dst_ix = chunk_dim - i - 1;
135 let src_ix = full_len - 1 - (chunk_ix * chunk_dim + i);
136 t.assign_slice_unchecked(dst_ix..=dst_ix, input, src_ix..=src_ix, axis);
137 }
138 }
139 } else if (chunk_ix + 1) * chunk_dim as usize > full_len {
140 let chunk_dim = chunk_dim as usize;
141 let remain = full_len - chunk_ix * chunk_dim;
142 let mut shape: TVec<usize> = input.shape().into();
143 shape[axis] = chunk_dim;
144 t.assign_slice_unchecked(..remain, input, chunk_ix * chunk_dim.., axis);
145 } else {
146 let start = chunk_dim as usize * chunk_ix;
147 let end = start + chunk_dim as usize;
148 t.assign_slice_unchecked(.., input, start..end, axis);
149 }
150 Ok(t)
151 }
152 }
153
154 pub(super) fn assign_output(
155 output: &mut Tensor,
156 axis: usize,
157 element_value: &Tensor,
158 i: usize,
159 backward: bool,
160 ) {
161 let full_len = output.shape()[axis];
162 let offset = if backward {
163 full_len - 1 - i * element_value.shape()[axis]
164 } else {
165 i * element_value.shape()[axis]
166 };
167 let count = element_value.shape()[axis].min(output.shape()[axis] - offset);
168 unsafe {
169 output.assign_slice_unchecked(offset..offset + count, element_value, ..count, axis)
170 };
171 }
172}
173
174impl OpState for State {
175 fn eval(
176 &mut self,
177 session: &mut SessionState,
178 _op: &dyn Op,
179 inputs: TVec<TValue>,
180 ) -> TractResult<TVec<TValue>> {
181 let iters = self.iteration_count(&inputs);
182
183 let State { op, ref mut hidden_state, ref mut position, ref mut model_state } = self;
184
185 if op.reset_every_turn {
187 hidden_state.clear()
188 }
189 if hidden_state.len() == 0 {
190 for (slot, input) in op.input_mapping.iter().enumerate() {
191 if input.is_state() {
192 hidden_state.push(inputs[slot].clone());
193 }
194 }
195 }
196
197 let mut outputs = tvec!();
198 for (ix, output) in op.output_mapping.iter().enumerate() {
199 if let Some((slot, info)) = output.scan {
200 let fact = op.plan.model().output_fact(ix)?;
201 let mut shape: TVec<usize> =
202 fact.shape.eval_to_usize(&session.resolved_symbols)?.into_owned();
203 let scanning_dim = output
204 .full_dim_hint
205 .as_ref()
206 .and_then(|d| d.to_usize().ok())
207 .unwrap_or(shape[info.axis] * iters);
208 shape[info.axis] = scanning_dim;
209 let t = unsafe { Tensor::uninitialized_dt(fact.datum_type, &shape)? };
210 outputs.push((slot, t));
211 }
212 if let Some(slot) = output.last_value_slot {
213 outputs.push((slot, Tensor::default()));
214 }
215 }
216 outputs.sort_by_key(|a| a.0);
217 let mut outputs: TVec<Tensor> = outputs.into_iter().map(|(_slot, v)| v).collect();
218
219 for i in 0..iters {
220 *position += 1;
221 if *position <= op.skip {
222 continue;
223 }
224 hidden_state.reverse();
225
226 let iter_inputs: TVec<TValue> = op
227 .input_mapping
228 .iter()
229 .enumerate()
230 .map(|(slot, m)| {
231 Ok(match m {
232 InputMapping::State => Some(hidden_state.pop().unwrap()),
233 InputMapping::Scan(info) => Some(
234 Self::slice_input(&inputs[slot], info.axis, i, info.chunk)?
235 .into_tvalue(),
236 ),
237 InputMapping::Full => Some(inputs[slot].clone()),
238 })
239 })
240 .collect::<TractResult<Vec<_>>>()?
241 .into_iter()
242 .flatten()
243 .collect();
244
245 trace!("iter_inputs #{}: {:?}", i, iter_inputs);
246 let iter_outputs =
247 model_state.run(iter_inputs).with_context(|| "Evaluating inner body")?;
248 trace!("iter_outputs #{}: {:?}", i, iter_outputs);
249
250 for (v, mapping) in iter_outputs.into_iter().zip(&op.output_mapping) {
251 if let Some((slot, info)) = mapping.scan {
252 Self::assign_output(&mut outputs[slot], info.axis, &v, i, info.chunk < 0);
253 }
254 if i == iters - 1 {
255 if let Some(slot) = mapping.last_value_slot {
256 outputs[slot] = v.clone().into_tensor();
257 }
258 }
259 if mapping.state {
260 hidden_state.push(v);
261 }
262 }
263 }
264
265 Ok(outputs.into_iter().map(|t| t.into_tvalue()).collect())
266 }
267}
268
269impl TypedOp for OptScan {
270 as_op!();
271
272 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
273 let mut outputs = tvec!();
274 let iters = super::iteration_count(&self.input_mapping, inputs).unwrap();
275 for (ix, output) in self.output_mapping.iter().enumerate() {
276 let fact = self.plan.model().output_fact(ix)?;
277 if let Some(slot) = output.last_value_slot {
278 outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
279 }
280 if let Some((slot, info)) = output.scan {
281 let mut shape = fact.shape.clone();
282 let scanning_dim =
283 output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
284 shape.set(info.axis, scanning_dim);
285 outputs.push((slot, fact.datum_type.fact(shape)));
286 }
287 }
288 outputs.sort_by_key(|a| a.0);
289 let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
290 Ok(outputs)
291 }
292}