use crate::internal::*;
use ndarray::*;
use std::convert::TryFrom;
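
/// How the values of the added border are generated.
///
/// * `Constant`: fill with a single value.
/// * `Reflect`: mirror the input around its edges (the edge value itself is not repeated).
/// * `Edge`: repeat the first/last value of each axis.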
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PadMode {
    Constant(Arc<Tensor>),
    Reflect,
    Edge,
}
impl Default for PadMode {
    fn default() -> PadMode {
        PadMode::Constant(Arc::new(0.0f32.into()))
    }
}
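
/// Pads a tensor with a border along every axis.
///
/// `pads[axis]` is the number of values to add (before, after) that axis, and
/// `mode` selects how the border values are produced.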
#[derive(Debug, Clone, new, Default, Hash)]
pub struct Pad {
    pub pads: Vec<(usize, usize)>,
    pub mode: PadMode,
}
impl_dyn_hash!(Pad);
impl Pad {
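    /// Monomorphic implementation, instantiated for each supported datum type
    /// by `dispatch_numbers!` in `eval`.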
    fn eval_t<T>(&self, input_tensor: TValue) -> TractResult<TValue>
    where
        T: Copy + Datum,
    {
        let input = input_tensor.to_array_view::<T>()?;
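        // Work out the padded shape and the value used to pre-fill the output:
        // the provided constant for Constant mode, or a throw-away default for
        // Reflect/Edge (those borders are overwritten below).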
        let output_shape: Vec<usize> =
            input.shape().iter().zip(self.pads.iter()).map(|(&d, &(a, b))| d + a + b).collect();
        let element = match &self.mode {
            PadMode::Constant(f) => *f.to_scalar::<T>()?,
            _ => T::default(),
        };
        let mut output = ArrayD::<T>::from_elem(output_shape, element);
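        // Copy the input into the interior of the output: slice away `a` rows
        // at the start and `b` rows at the end of every axis.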
        let slice_spec: Vec<SliceInfoElem> = self
            .pads
            .iter()
            .map(|&(a, b)| SliceInfoElem::Slice {
                start: a as isize,
                end: if b != 0 { Some(-(b as isize)) } else { None },
                step: 1,
            })
            .collect();
        let slice_info = SliceInfo::<_, IxDyn, IxDyn>::try_from(slice_spec).unwrap();
        output.slice_mut(slice_info.as_ref()).assign(&input);
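        // For Reflect and Edge, fill each border by copying rows from the data
        // region, one axis at a time.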
        if self.mode == PadMode::Reflect || self.mode == PadMode::Edge {
            for (ax, &(bef, aft)) in self.pads.iter().enumerate() {
                let axis = Axis(ax);
                let dim = output.shape()[ax];
                // Fill the `bef` rows before the data region.
                {
                    let (mut pad, data) = output.view_mut().split_at(axis, bef);
                    for i in 0..bef {
                        let mut target = pad.slice_axis_mut(axis, Slice::from(i..i + 1));
                        let source_slice = match self.mode {
                            PadMode::Edge => 0,
                            PadMode::Reflect => bef - i,
                            _ => panic!(),
                        };
                        let source =
                            data.slice_axis(axis, Slice::from(source_slice..source_slice + 1));
                        target.assign(&source);
                    }
                }
                // Fill the `aft` rows after the data region.
                {
                    let (data, mut pad) = output.view_mut().split_at(axis, dim - aft);
                    for i in 0..aft {
                        let mut target = pad.slice_axis_mut(axis, Slice::from(i..i + 1));
                        let source_slice = match self.mode {
                            PadMode::Edge => dim - aft - 1,
                            PadMode::Reflect => dim - aft - 2 - i,
                            _ => panic!(),
                        };
                        let source =
                            data.slice_axis(axis, Slice::from(source_slice..source_slice + 1));
                        target.assign(&source);
                    }
                }
            }
        }
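        // Preserve the input's exact datum type: dispatch may have run the
        // computation through a compatible representation (e.g. a quantized
        // type through its storage integer type).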
        let mut output = output.into_tensor();
        unsafe { output.set_datum_type(input_tensor.datum_type()) }
        Ok(output.into_tvalue())
    }
}
impl Op for Pad {
    fn name(&self) -> Cow<str> {
        "Pad".into()
    }
    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!("Mode: {:?}, pads: {:?}", self.mode, self.pads)])
    }
    op_as_typed_op!();
}
impl EvalOp for Pad {
    fn is_stateless(&self) -> bool {
        true
    }
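    // Dispatch on the input's datum type to the generic `eval_t` above.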
    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        Ok(tvec!(dispatch_numbers!(Self::eval_t(input.datum_type())(self, input))?))
    }
}
impl TypedOp for Pad {
    as_op!();
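    // The output keeps the input's type; each dimension grows by `before + after`.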
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let mut fact = inputs[0].clone();
        if self.pads.len() != fact.rank() {
            bail!("Inconsistent pad: input of rank {}, pads are: {:?}", fact.rank(), self.pads);
        }
        for (ix, (b, e)) in self.pads.iter().enumerate() {
            fact.shape.set(ix, fact.shape[ix].clone() + *b + *e);
        }
        Ok(tvec!(fact))
    }
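    // Axes that receive no padding are left untouched by the op and can be
    // tracked through it.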
    fn invariants(
        &self,
        _inputs: &[&TypedFact],
        _outputs: &[&TypedFact],
    ) -> TractResult<Invariants> {
        let mut inv = Invariants::none();
        for (axis, pads) in self.pads.iter().enumerate() {
            if pads == &(0, 0) {
                inv.axes.push(AxisInfo::simple(axis))
            }
        }
        Ok(inv)
    }
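    // Track axis changes on the input by editing `pads` to match: removing an
    // axis is accepted only if it was unpadded, adding one inserts a (0, 0) entry.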
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let mut new_op = self.clone();
        if let (InOut::In(0), AxisOp::Rm(ix)) = (io, change) {
            if new_op.pads.remove(*ix) == (0, 0) {
                return Ok(Some(AxisChangeConsequence::new(
                    model,
                    node,
                    Some(Box::new(new_op)),
                    change,
                )));
            }
        }
        if let (InOut::In(0), AxisOp::Add(ix)) = (io, change) {
            new_op.pads.insert(*ix, (0, 0));
            return Ok(Some(AxisChangeConsequence::new(
                model,
                node,
                Some(Box::new(new_op)),
                change,
            )));
        }
        Ok(None)
    }
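    // A Pad with all-zero paddings is a no-op: bypass it entirely.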
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.pads.iter().all(|p| p.0 == 0 && p.1 == 0) {
            TypedModelPatch::shunt_one_op(model, node)
        } else {
            Ok(None)
        }
    }
}
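
// A minimal usage sketch (not from the original source): it assumes the
// `tensor1` constructor and the `into_tvalue()` conversion are reachable
// through the crate's internal prelude, and checks Constant and Edge padding
// on a small 1-D tensor.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::internal::*;

    #[test]
    fn constant_pad_1d() -> TractResult<()> {
        // Pad one value before and two after, filling with 9.0.
        let op = Pad::new(vec![(1, 2)], PadMode::Constant(Arc::new(9.0f32.into())));
        let output = op.eval(tvec!(tensor1(&[1.0f32, 2.0, 3.0]).into_tvalue()))?;
        assert_eq!(&*output[0], &tensor1(&[9.0f32, 1.0, 2.0, 3.0, 9.0, 9.0]));
        Ok(())
    }

    #[test]
    fn edge_pad_1d() -> TractResult<()> {
        // Pad two values before and one after, repeating the edge values.
        let op = Pad::new(vec![(2, 1)], PadMode::Edge);
        let output = op.eval(tvec!(tensor1(&[1.0f32, 2.0, 3.0]).into_tvalue()))?;
        assert_eq!(&*output[0], &tensor1(&[1.0f32, 1.0, 1.0, 2.0, 3.0, 3.0]));
        Ok(())
    }
}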