use ndarray::*;
use super::*;
use tract_data::internal::*;
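/// Parameters of a compiled (LIR) scan: the plan for the inner body, plus the
/// mappings wiring the outer inputs and outputs to the body's own.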
#[derive(Debug, Clone, new, Hash)]
pub struct LirScanOpParams {
pub skip: usize,
pub plan: Arc<TypedSimplePlan<TypedModel>>,
pub input_mapping: Vec<InputMapping>,
pub output_mapping: Vec<OutputMapping<TDim>>,
}
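/// The LIR scan operator. It wraps its parameters in an `Arc` so cloning the
/// op (and spawning per-session states) stays cheap.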
#[derive(Debug, Clone, new, Hash)]
pub struct LirScan(Arc<LirScanOpParams>);
impl std::ops::Deref for LirScan {
type Target = LirScanOpParams;
fn deref(&self) -> &LirScanOpParams {
&self.0
}
}
impl_dyn_hash!(LirScan);
impl LirScan {
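    /// Derives the number of loop iterations from the first scanned input:
    /// the length of its scanned axis divided by the chunk size.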
    pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
        let (outside_slot, axis, chunk) = self
            .input_mapping
            .iter()
            .find_map(|it| match it {
                InputMapping::Scan { axis, slot, chunk } => Some((*slot, *axis, *chunk)),
                _ => None,
            })?;
        let outside_dim = inputs[outside_slot].shape[axis].clone();
        Some(outside_dim / chunk)
    }
}
impl Op for LirScan {
fn name(&self) -> Cow<str> {
"Scan".into()
}
fn info(&self) -> TractResult<Vec<String>> {
let mut lines = vec![];
for (ix, im) in self.input_mapping.iter().enumerate() {
lines.push(format!("Model input #{}: {:?}", ix, im));
}
for (ix, om) in self.output_mapping.iter().enumerate() {
lines.push(format!("Model output #{}: {:?}", ix, om));
}
Ok(lines)
}
op_core_lir!();
op_as_typed_op!();
}
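// Evaluation goes through `State` below: scan is stateful, tracking the loop
// position, the hidden state tensors and the inner plan's session state.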
impl EvalOp for LirScan {
fn is_stateless(&self) -> bool {
false
}
fn state(
&self,
_session: &mut SessionState,
_node_id: usize,
) -> TractResult<Option<Box<dyn OpState>>> {
Ok(Some(Box::new(State {
mutable: MutableState {
position: 0,
hidden_state: tvec!(),
model_state: TypedSimpleState::new(Arc::clone(&self.plan))?,
},
op: Arc::clone(&self.0),
})))
}
}
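/// Per-session state of a scan node.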
#[derive(Clone, Debug)]
struct State {
op: Arc<LirScanOpParams>,
mutable: MutableState,
}
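/// The part of the state that mutates across iterations.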
#[derive(Clone, Debug)]
struct MutableState {
position: usize,
hidden_state: TVec<Tensor>,
model_state: TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
}
impl MutableState {
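    /// Extracts the chunk of `input` to feed to the body at iteration
    /// `chunk_ix`. A negative `chunk_dim` means the axis is scanned
    /// backward; incomplete chunks are padded with `T::default()`.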
pub(super) fn slice_input_t<T: Datum>(
&self,
input: &Tensor,
axis: usize,
chunk_ix: usize,
chunk_dim: isize,
) -> TractResult<Tensor> {
let view = input.to_array_view::<T>()?;
let full_len = view.shape()[axis];
if chunk_dim < 0 {
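            // Backward scan: fill the chunk from the end of the input,
            // reversing element order; out-of-range positions keep the
            // `T::default()` padding.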
let mut shape: TVec<usize> = view.shape().into();
let chunk_dim = (-chunk_dim) as usize;
shape[axis] = chunk_dim;
let mut t = ArrayD::<T>::default(&*shape);
for i in 0..chunk_dim {
if chunk_dim * chunk_ix + i < full_len {
t.index_axis_mut(Axis(axis), chunk_dim - i - 1).assign(
&view.index_axis(Axis(axis), full_len - 1 - (chunk_ix * chunk_dim + i)),
);
}
}
Ok(t.into_tensor())
} else if (chunk_ix + 1) * chunk_dim as usize > full_len {
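            // Last chunk overruns the input: copy what remains, pad the rest.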
let chunk_dim = chunk_dim as usize;
let remain = full_len - chunk_ix * chunk_dim;
let mut shape: TVec<usize> = view.shape().into();
shape[axis] = chunk_dim;
let mut t = ArrayD::<T>::default(&*shape);
t.slice_axis_mut(Axis(axis), (0..remain).into())
.assign(&view.slice_axis(Axis(axis), (chunk_ix * chunk_dim..).into()));
Ok(t.into_tensor())
} else {
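            // Nominal case: slice a full chunk out of the input.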
let chunk_dim = chunk_dim as usize;
Ok(view
.slice_axis(Axis(axis), (chunk_ix * chunk_dim..(chunk_ix + 1) * chunk_dim).into())
.to_owned()
.into_tensor())
}
}
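    /// Allocates an uninitialized tensor for a full scan output, to be filled
    /// chunk by chunk by `assign_output_t`.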
pub(super) fn alloc_output_t<T: Datum + Default>(
&self,
shape: &[usize],
) -> TractResult<Tensor> {
        unsafe { Tensor::uninitialized::<T>(shape) }
}
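    /// Writes one body output chunk into the aggregated `output` at iteration
    /// `i`, from the end of the axis when `backward` is set. Note: the
    /// backward offset computation appears to assume chunks of length one.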
pub(super) fn assign_output_t<T: Datum + Default>(
&self,
output: &mut Tensor,
axis: usize,
element_value: &Tensor,
i: usize,
backward: bool,
) -> TractResult<()> {
let mut view = output.to_array_view_mut::<T>()?;
let full_len = view.shape()[axis];
let element = element_value.to_array_view::<T>()?;
let offset = if backward {
full_len - 1 - i * element_value.shape()[axis]
} else {
i * element_value.shape()[axis]
};
let count = element_value.shape()[axis].min(view.shape()[axis] - offset);
view.slice_axis_mut(Axis(axis), (offset..offset + count).into())
.assign(&element.slice_axis(Axis(axis), (..count).into()));
Ok(())
}
}
impl OpState for State {
fn eval(
&mut self,
_session: &mut SessionState,
_op: &dyn Op,
inputs: TVec<Arc<Tensor>>,
) -> TractResult<TVec<Arc<Tensor>>> {
let State { op, ref mut mutable } = self;
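        // First call in the session: materialize the hidden state from its
        // initializers (an outer input or an embedded value).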
        if mutable.hidden_state.is_empty() {
for input in &op.input_mapping {
if let InputMapping::State { initializer } = input {
mutable.hidden_state.push(match initializer {
StateInitializer::FromInput(slot) => (*inputs[*slot]).to_owned(),
StateInitializer::Value(v) => (**v).to_owned(),
});
}
}
}
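        // The iteration count comes from the first scanned input: scanned
        // axis length divided by chunk size, rounded up.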
        let iters = {
            let (outside_slot, axis, chunk) = op
                .input_mapping
                .iter()
                .find_map(|it| match it {
                    InputMapping::Scan { axis, slot, chunk } => Some((*slot, *axis, *chunk)),
                    _ => None,
                })
                .unwrap();
            inputs[outside_slot].shape()[axis].div_ceil(chunk.abs() as usize)
};
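        // Pre-allocate the aggregated outputs. Last-value outputs start as
        // placeholders, overwritten on the final iteration.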
let mut outputs = tvec!();
for (ix, output) in op.output_mapping.iter().enumerate() {
if let Some(slot) = output.full_slot {
let fact = op.plan.model().output_fact(ix)?;
let mut shape: TVec<usize> = fact.shape.as_finite().unwrap().into();
let scanning_dim = output
.full_dim_hint
.as_ref()
.and_then(|d| d.to_usize().ok())
.unwrap_or(shape[output.axis] * iters);
shape[output.axis] = scanning_dim;
let t = dispatch_datum!(MutableState::alloc_output_t(fact.datum_type)(
mutable, &*shape
))?;
outputs.push((slot, t));
}
if let Some(slot) = output.last_value_slot {
outputs.push((slot, Tensor::default()));
}
}
outputs.sort_by_key(|a| a.0);
let mut outputs: TVec<Tensor> = outputs.into_iter().map(|(_slot, v)| v).collect();
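        // Main loop: slice the inputs, run the body once per chunk, scatter
        // its outputs.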
for i in 0..iters {
mutable.position += 1;
if mutable.position <= op.skip {
continue;
}
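            // States were pushed in model order; reverse so pop() below
            // yields them in that same order.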
mutable.hidden_state.reverse();
            let iter_inputs: TVec<Tensor> = op
                .input_mapping
                .iter()
                .map(|m| {
                    Ok(match m {
                        InputMapping::State { .. } => mutable.hidden_state.pop().unwrap(),
                        InputMapping::Scan { slot, axis, chunk } => dispatch_datum!(
                            MutableState::slice_input_t(inputs[*slot].datum_type())(
                                mutable,
                                inputs[*slot].as_ref(),
                                *axis,
                                i,
                                *chunk
                            )
                        )?,
                        InputMapping::Full { slot } => inputs[*slot].clone().into_tensor(),
                    })
                })
                .collect::<TractResult<TVec<_>>>()?;
trace!("iter_inputs: {:?}", iter_inputs);
let iter_outputs =
mutable.model_state.run(iter_inputs).with_context(|| "Evaluating inner body")?;
trace!("iter_outputs: {:?}", iter_outputs);
for (v, mapping) in iter_outputs.into_iter().zip(&op.output_mapping) {
if let Some(slot) = mapping.full_slot {
dispatch_datum!(MutableState::assign_output_t(v.datum_type())(
mutable,
&mut outputs[slot],
mapping.axis,
v.as_ref(),
i,
mapping.chunk < 0
))?;
}
if i == iters - 1 {
if let Some(slot) = mapping.last_value_slot {
outputs[slot] = v.clone().into_tensor();
}
}
if mapping.state {
mutable.hidden_state.push(v.into_tensor());
}
}
}
Ok(outputs.into_iter().map(Arc::new).collect())
}
}
impl TypedOp for LirScan {
as_op!();
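    /// Mirrors eval: full outputs get their scanned axis scaled by the
    /// iteration count (or the provided hint), last-value outputs keep the
    /// body's fact unchanged.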
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let mut outputs = tvec!();
        let iters = {
            let (outside_slot, axis, chunk) = self
                .input_mapping
                .iter()
                .find_map(|it| match it {
                    InputMapping::Scan { axis, slot, chunk } => Some((*slot, *axis, *chunk)),
                    _ => None,
                })
                .unwrap();
            inputs[outside_slot].shape[axis].clone().div_ceil(chunk.abs() as _)
        };
for (ix, output) in self.output_mapping.iter().enumerate() {
let fact = self.plan.model().output_fact(ix)?;
if let Some(slot) = output.last_value_slot {
outputs.push((slot, TypedFact::dt_shape(fact.datum_type, fact.shape.clone())?));
}
if let Some(slot) = output.full_slot {
let mut shape = fact.shape.clone();
let scanning_dim = output
.full_dim_hint
.clone()
.unwrap_or(shape[output.axis].maybe_mul(&iters)?);
shape[output.axis] = scanning_dim;
outputs.push((slot, TypedFact::dt_shape(fact.datum_type, shape)?));
}
}
outputs.sort_by_key(|a| a.0);
let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
Ok(outputs)
}
}