use tract_nnef::internal::*;
use tract_nnef::tract_core::ops::OpStateFreeze;
/// Registers the `tract_pulse_delay` primitive in `registry`.
///
/// The primitive takes a scalar tensor `input` and three integers (`axis`,
/// `delay`, `overlap`), yields a single scalar tensor `output`, and is
/// deserialized by [`de_delay`].
pub fn register(registry: &mut Registry) {
    let parameters = [
        TypeName::Scalar.tensor().named("input"),
        TypeName::Integer.named("axis"),
        TypeName::Integer.named("delay"),
        TypeName::Integer.named("overlap"),
    ];
    registry.register_primitive(
        "tract_pulse_delay",
        &parameters,
        &[("output", TypeName::Scalar.tensor())],
        de_delay,
    );
}
fn de_delay(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult<Value> {
let wire = invocation.named_arg_as(builder, "input")?;
let axis = invocation.named_arg_as::<i64>(builder, "axis")? as usize;
let delay = invocation.named_arg_as::<i64>(builder, "delay")? as usize;
let overlap = invocation.named_arg_as::<i64>(builder, "overlap")? as usize;
let input_fact = builder.model.outlet_fact(wire)?;
let op = Delay::new_typed(input_fact, axis, delay, overlap);
builder.wire(op, &[wire])
}
/// Mutable runtime state attached to a [`Delay`] node.
#[derive(Debug, Clone)]
pub struct DelayState {
/// Data carried over from one evaluation to the next.
///
/// Starts as `None` (see `Delay::state`) and is allocated by the first call
/// to `eval`, with `delay + overlap` positions along the op axis.
pub buffer: Option<Tensor>,
}
impl DelayState {
/// Fills `output` with one delayed pulse and updates the internal buffer
/// with the content of `input`.
///
/// Along `op.axis`, `output` receives `input_pulse + op.overlap` positions:
/// first data taken from the buffer, then (only when `op.delay` is smaller
/// than the input pulse) the leading part of `input`. The buffer is then
/// updated so that it ends with the most recent input data.
///
/// # Safety
///
/// All copies use `assign_slice_unchecked` and the buffer is reinterpreted
/// as a raw byte slice. NOTE(review): the caller presumably has to guarantee
/// that `input`, `output` and the buffer share the same datum type and the
/// same dimensions outside `op.axis`, that the buffer holds
/// `op.delay + op.overlap` positions along `op.axis`, and that `output`
/// holds `input_pulse + op.overlap` of them — confirm against the contract
/// of `Tensor::assign_slice_unchecked`.
///
/// # Panics
///
/// Panics if `self.buffer` is `None`, or if `op.axis` is not a valid index
/// into `input.shape()`.
pub unsafe fn apply_delay_unchecked(
&mut self,
op: &Delay,
input: &Tensor,
output: &mut Tensor,
) {
// Number of positions the buffer is expected to hold along the axis.
let buffered = op.delay + op.overlap;
let input_pulse = input.shape()[op.axis];
let output_pulse = input_pulse + op.overlap;
let buffer = self.buffer.as_mut().unwrap();
// Step 1: produce the output.
if op.delay < input_pulse {
// The delay is shorter than the pulse: the output starts with buffered
// data and ends with the beginning of the current input.
// Note that `from_buffer` is algebraically `op.delay + op.overlap`.
let from_input = input_pulse - op.delay;
let from_buffer = output_pulse - from_input;
output.assign_slice_unchecked(..from_buffer, buffer, ..from_buffer, op.axis);
output.assign_slice_unchecked(from_buffer.., input, ..from_input, op.axis);
} else {
// The delay covers the whole pulse: the output comes from the buffer only.
output.assign_slice_unchecked(.., buffer, ..output_pulse, op.axis);
};
// Step 2: update the buffer for the next call.
if buffered < input_pulse {
// The input is larger than the buffer: keep only its last `buffered`
// positions.
buffer.assign_slice_unchecked(.., input, (input_pulse - buffered).., op.axis);
} else {
// Shift the buffer content towards the start by `input_pulse` positions,
// then write the whole input at the end.
// `stride` is the byte size of `input_pulse` positions along the axis
// (product of the inner dimensions * element size * input_pulse).
let stride = buffer.shape().iter().skip(op.axis + 1).product::<usize>()
* input.datum_type().size_of()
* input_pulse;
// The rotation is done on the raw bytes of the whole buffer.
// NOTE(review): this looks like it relies on a contiguous layout, with the
// bytes that wrap around landing in the region overwritten just below —
// confirm before changing.
std::slice::from_raw_parts_mut(
buffer.as_ptr_mut_unchecked::<u8>(),
buffer.len() * input.datum_type().size_of(),
)
.rotate_left(stride);
buffer.assign_slice_unchecked((buffered - input_pulse).., input, .., op.axis);
}
}
}
impl OpState for DelayState {
    /// Evaluates one pulse of the [`Delay`] operator.
    ///
    /// Takes exactly one input tensor and returns one tensor whose size along
    /// the op axis is the input size plus `overlap`. The internal buffer is
    /// allocated, zero-filled, on the first call.
    ///
    /// # Errors
    ///
    /// Fails when `op` is not a [`Delay`], when the op axis is out of range
    /// for the input, when the input is not compatible with the buffer
    /// allocated by a previous call, or when a tensor allocation fails.
    fn eval(
        &mut self,
        _state: &mut SessionState,
        op: &dyn Op,
        mut inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let op = op.downcast_ref::<Delay>().ok_or_else(|| format_err!("Wrong Op type"))?;
        if op.axis >= input.rank() {
            return Err(format_err!(
                "Delay axis {} is out of range for an input of rank {}",
                op.axis,
                input.rank()
            ));
        }

        let buffered = op.delay + op.overlap;
        let input_pulse = input.shape()[op.axis];
        let output_pulse = input_pulse + op.overlap;

        let mut buffer_shape: TVec<usize> = input.shape().into();
        buffer_shape[op.axis] = buffered;
        let mut output_shape = buffer_shape.clone();
        output_shape[op.axis] = output_pulse;

        // The copies below are unchecked, so make sure a buffer allocated by a
        // previous call still matches the current input.
        let compatible = self.buffer.as_ref().map_or(true, |buffer| {
            buffer.datum_type() == input.datum_type() && buffer.shape() == &buffer_shape[..]
        });
        if !compatible {
            return Err(format_err!(
                "Delay input ({:?}, shape {:?}) is not compatible with the existing delay buffer",
                input.datum_type(),
                input.shape()
            ));
        }
        if self.buffer.is_none() {
            // The first pulses are read straight out of the buffer: it must hold
            // well-defined data, not uninitialized memory.
            self.buffer = Some(Tensor::zero_dt(input.datum_type(), &buffer_shape)?);
        }

        // SAFETY: the buffer exists and was checked (or just allocated) with the
        // input datum type and `buffered` positions along `op.axis`; `output` has
        // `output_pulse` positions along that axis and the same other dimensions
        // as `input`. `apply_delay_unchecked` overwrites every position of
        // `output`, so the uninitialized allocation is never observed.
        let output = unsafe {
            let mut output = Tensor::uninitialized_dt(input.datum_type(), &output_shape)?;
            self.apply_delay_unchecked(op, &input, &mut output);
            output
        };
        Ok(tvec!(output.into()))
    }
}
/// Stateful operator that delays its input along one axis, optionally
/// extending each output pulse with `overlap` extra positions.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Delay {
/// Datum type of the input fact the op was built from; reported as the
/// type of the buffer in `cost`.
pub datum_type: DatumType,
/// Shape of the input fact with the dimension at `axis` replaced by
/// `delay + overlap`.
pub buffer_shape: TVec<TDim>,
/// Axis along which the delay is applied.
pub axis: usize,
/// Delay, counted in positions along `axis`.
pub delay: usize,
/// Number of positions added to the output along `axis`.
pub overlap: usize,
}
impl_dyn_hash!(Delay);
impl Delay {
    /// Creates a `Delay` matching `input_fact`.
    ///
    /// The datum type is taken from the fact, and the buffer shape is the
    /// fact's shape with the dimension at `axis` set to `delay + overlap`.
    ///
    /// # Panics
    ///
    /// Panics if `axis` is not a valid axis of `input_fact`.
    pub fn new_typed(input_fact: &TypedFact, axis: usize, delay: usize, overlap: usize) -> Delay {
        let datum_type = input_fact.datum_type;
        let buffered = delay + overlap;
        let mut buffer_shape: TVec<TDim> = input_fact.shape.to_tvec();
        buffer_shape[axis] = buffered.to_dim();
        Delay { datum_type, buffer_shape, axis, delay, overlap }
    }
}
impl Op for Delay {
    /// Name of the operator.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("Delay")
    }

    /// Two description lines: the op parameters, then the buffer shape and
    /// datum type.
    fn info(&self) -> TractResult<Vec<String>> {
        let parameters =
            format!("axis: {} delay: {} overlap: {}", self.axis, self.delay, self.overlap);
        let buffer = format!("buffer: {:?} {:?}", self.buffer_shape, self.datum_type);
        Ok(vec![parameters, buffer])
    }

    impl_op_same_as!();
    op_as_typed_op!();
}
impl EvalOp for Delay {
    /// The op keeps a buffer between evaluations, so it is never stateless.
    fn is_stateless(&self) -> bool {
        false
    }

    /// Creates a fresh state with no buffer allocated yet.
    fn state(
        &self,
        _session: &mut SessionState,
        _node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        let fresh: Box<dyn OpState> = Box::new(DelayState { buffer: None });
        Ok(Some(fresh))
    }
}
impl TypedOp for Delay {
as_op!();
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let mut fact = inputs[0].clone();
fact.shape.set(self.axis, fact.shape[self.axis].clone() + self.overlap.to_dim());
Ok(tvec!(fact))
}
fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
Ok(tvec!((Cost::Buffer(self.datum_type), self.buffer_shape.iter().product())))
}
fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
if self.axis != 0 {
Ok(tvec!((InOut::In(0), AxisOp::Move(self.axis, 0))))
} else {
Ok(tvec!())
}
}
fn change_axes(
&self,
model: &TypedModel,
node: &TypedNode,
_io: InOut,
change: &AxisOp,
) -> TractResult<Option<AxisChangeConsequence>> {
if let Some(axis) = change.transform_axis(self.axis) {
if axis != self.axis {
Ok(Some(AxisChangeConsequence::new(
model,
node,
Some(Box::new(Self { axis, ..self.clone() }) as _),
change,
)))
} else {
Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
}
} else {
Ok(None)
}
}
}
/// Frozen counterpart of `DelayState`, produced by `DelayState::freeze` and
/// turned back into a `DelayState` by `unfreeze`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
struct FrozenDelayState {
// Copy of the delay buffer behind an `Arc`; `None` when the state had no
// buffer at freeze time.
buffer: Option<Arc<Tensor>>,
}
impl OpStateFreeze for DelayState {
    /// Captures the state by cloning the buffer, if any, into an `Arc`.
    fn freeze(&self) -> Box<dyn FrozenOpState> {
        let buffer = self.buffer.clone().map(|tensor| tensor.into_arc_tensor());
        Box::new(FrozenDelayState { buffer })
    }
}
impl FrozenOpState for FrozenDelayState {
    /// Rebuilds a live `DelayState` from the frozen buffer, if any.
    fn unfreeze(&self) -> Box<dyn OpState> {
        let buffer = self.buffer.clone().map(|shared| shared.into_tensor());
        Box::new(DelayState { buffer })
    }
}