use crate::internal::*;
use ndarray::prelude::*;
use crate::ops::cnn::pools::{ConcretePoolGeometry, PoolGeometry, PoolSpec};
/// Max pooling operator for the typed model.
///
/// Only the pooling specification is stored; `eval` builds a [`LirMaxPool`] from the actual
/// input shape through `to_lir`, which computes the pooling geometry.
#[derive(Debug, Clone, new, Hash)]
pub struct MaxPool {
    /// Pooling specification; also used to compute output facts and the pooling geometry.
    pub pool_spec: PoolSpec,
    /// When `Some(dt)`, the op has a second output of datum type `dt` holding, per output
    /// position, an index derived from the offset of the selected maximum
    /// (see `LirMaxPool::eval_t`).
    pub with_index_outputs: Option<DatumType>,
}
impl_dyn_hash!(MaxPool);
impl Op for MaxPool {
fn name(&self) -> Cow<str> {
"MaxPool".into()
}
fn info(&self) -> TractResult<Vec<String>> {
Ok(self.pool_spec.info())
}
op_as_typed_op!();
}
impl EvalOp for MaxPool {
    /// Declares the op stateless.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Builds a `LirMaxPool` for the concrete shape of the first input, then evaluates it
    /// on the same inputs.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input_shape = inputs[0].shape().iter().map(|d| d.to_dim()).collect::<TVec<TDim>>();
        let lowered = self.to_lir(&input_shape)?;
        lowered.eval(inputs)
    }
}
impl TypedOp for MaxPool {
    /// Output 0 is described by `PoolSpec::output_facts`. With index outputs enabled, a copy
    /// of output 0's fact is pushed and `facts[1]` is retyped to the requested index datum
    /// type (this presumes `PoolSpec::output_facts` yields exactly one fact).
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let mut facts = self.pool_spec.output_facts(inputs)?;
        if let Some(idt) = self.with_index_outputs {
            facts.push(facts[0].clone());
            facts[1].datum_type = idt;
        }
        Ok(facts)
    }

    /// Drops the index output when nobody uses it.
    ///
    /// If the op was built with index outputs but outlet 1 has no successor and is not a
    /// model output, the node is replaced by an equivalent `MaxPool` that only produces
    /// the pooled values; its output 0 is shunted in place of the original output 0.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // `&&` short-circuits, so `output_outlets()` is only queried when the first two
        // conditions already hold.
        let index_output_unused = self.with_index_outputs.is_some()
            && node.outputs[1].successors.is_empty()
            && !model.output_outlets()?.contains(&OutletId::new(node.id, 1));
        if !index_output_unused {
            return Ok(None);
        }
        let op = Self { with_index_outputs: None, ..self.clone() };
        let mut patch = TypedModelPatch::default();
        let mut wire = patch.tap_model(model, node.inputs[0])?;
        wire = patch.wire_node(&node.name, op, &[wire])?[0];
        patch.shunt_outside(model, node.id.into(), wire)?;
        Ok(Some(patch))
    }

    as_op!();
}
impl MaxPool {
    /// Builds the `LirMaxPool` counterpart of this op, computing the pooling geometry for
    /// `input_shape` with `PoolSpec::compute_geo`.
    fn to_lir(&self, input_shape: &[TDim]) -> TractResult<LirMaxPool> {
        let geometry = self.pool_spec.compute_geo(input_shape)?;
        let lowered = LirMaxPool {
            pool_spec: self.pool_spec.clone(),
            with_index_outputs: self.with_index_outputs,
            geometry,
        };
        Ok(lowered)
    }
}
/// Max pooling operator carrying a precomputed pooling geometry.
///
/// At each evaluation, `geometry` is turned into a `ConcretePoolGeometry` for the actual
/// input shape before the typed kernel runs.
#[derive(Debug, Clone, new, Hash)]
pub struct LirMaxPool {
    /// Pooling specification; used for output facts and `info`.
    pub pool_spec: PoolSpec,
    /// When `Some(dt)`, a second output of datum type `dt` holds the per-position index of
    /// the selected maximum (see `eval_t`).
    pub with_index_outputs: Option<DatumType>,
    /// Geometry computed from the pool spec; made concrete against the input shape in `eval`.
    pub geometry: PoolGeometry,
}
impl_dyn_hash!(LirMaxPool);
impl Op for LirMaxPool {
fn name(&self) -> Cow<str> {
"LirMaxPool".into()
}
fn info(&self) -> TractResult<Vec<String>> {
Ok(self.pool_spec.info())
}
op_as_typed_op!();
}
impl EvalOp for LirMaxPool {
    /// Declares the op stateless.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Takes the single input, resolves the pooling geometry against its shape, and runs
    /// the typed kernel `eval_t` selected by the input's datum type.
    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let concrete_geo = self.geometry.to_concrete(input.shape())?;
        dispatch_numbers!(Self::eval_t(input.datum_type())(self, &*input, concrete_geo.as_ref()))
    }
}
impl TypedOp for LirMaxPool {
    /// Output 0 is described by `PoolSpec::output_facts`. With index outputs enabled, a copy
    /// of output 0's fact is pushed and `facts[1]` is retyped to the requested index datum
    /// type.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let mut facts = self.pool_spec.output_facts(inputs)?;
        match self.with_index_outputs {
            None => Ok(facts),
            Some(index_dt) => {
                let first = facts[0].clone();
                facts.push(first);
                facts[1].datum_type = index_dt;
                Ok(facts)
            }
        }
    }

    as_op!();
}
impl LirMaxPool {
    /// Typed max-pooling kernel; `eval` picks `T` from the input datum type through
    /// `dispatch_numbers!`.
    ///
    /// For every batch entry, channel and output position visited by `geo.patch`, writes the
    /// largest input value found at the visitor's valid offsets. The first occurrence wins on
    /// ties, and values unordered with themselves (float NaN) are skipped. When
    /// `with_index_outputs` is `Some(dt)`, a second output holds the winning offset divided
    /// by the patch's `output_inner_stride`, cast to `dt`. A window without any valid
    /// (non-NaN) value yields `T::min_value()` and index 0.
    ///
    /// # Errors
    /// Fails if `input` cannot be viewed as a `T` array, or if casting the indices to the
    /// requested datum type fails.
    fn eval_t<T: Datum + Copy + num_traits::Bounded + PartialOrd>(
        &self,
        input: &Tensor,
        geo: &ConcretePoolGeometry,
    ) -> TractResult<TVec<TValue>> {
        let input_dt = input.datum_type();
        let input: ArrayViewD<T> = input.to_array_view()?;
        let input_ptr = input.as_ptr();
        // Fully initialised buffers: producing uninitialised `T`/`i32` values via
        // `assume_init` is undefined behaviour even if every element is overwritten later.
        let mut values = ArrayD::<T>::from_elem(&*geo.output_shape.shape, T::min_value());
        let mut indices =
            self.with_index_outputs.map(|_| ArrayD::<i32>::zeros(&*geo.output_shape.shape));
        let batch = *geo.input_shape.n().unwrap_or(&1);
        let n_stride_i = geo.input_shape.n_stride().unwrap_or(&0);
        let n_stride_o = geo.output_shape.n_stride().unwrap_or(&0);
        let index_divisor = geo.patch.spec.output_inner_stride as i32;
        // SAFETY: `geo` was made concrete from this input's shape (see `eval`), and the raw
        // offsets combine the visitor's offsets with batch/channel strides of that same
        // geometry. This assumes the patch visitor only yields input offsets inside `input`
        // and output offsets inside `values`/`indices`, both shaped `geo.output_shape`.
        unsafe {
            geo.patch.visit_output(|visitor| {
                for n in 0..batch {
                    let input_offset = n * n_stride_i;
                    let output_offset = n * n_stride_o;
                    for c in 0..*geo.input_shape.c() {
                        let input_offset = input_offset + geo.input_shape.c_stride() * c;
                        let output_offset = output_offset + geo.output_shape.c_stride() * c;
                        let window = visitor
                            .valid_offsets()
                            .map(|v| (v, *input_ptr.offset(v + input_offset as isize)));
                        // Seeded from the first real value rather than a `min_value()`
                        // sentinel, so windows whose maximum is <= `T::min_value()` (e.g.
                        // all `-inf`, or all-zero `u8`) report the true value and offset.
                        let (max_offset, max_value) =
                            Self::first_max(window).unwrap_or((0, T::min_value()));
                        let out = output_offset as isize + visitor.output_offset;
                        *values.as_mut_ptr().offset(out) = max_value;
                        if let Some(indices) = indices.as_mut() {
                            *indices.as_mut_ptr().offset(out) = max_offset as i32 / index_divisor;
                        }
                    }
                }
            });
        }
        let mut values = values.into_tensor();
        // SAFETY: `T` was selected from `input_dt` by `dispatch_numbers!` in `eval`, so the
        // element layout matches; re-applying `input_dt` presumably restores datum-type
        // details (e.g. quantization) that `into_tensor` does not carry.
        unsafe { values.set_datum_type(input_dt) };
        // `indices` is `Some` exactly when `with_index_outputs` is.
        match (self.with_index_outputs, indices) {
            (Some(dt), Some(indices)) => Ok(tvec!(
                values.into_tvalue(),
                indices.into_tensor().cast_to_dt(dt)?.into_owned().into_tvalue()
            )),
            _ => Ok(tvec!(values.into_tvalue())),
        }
    }

    /// Returns the `(offset, value)` pair holding the largest value among `candidates`,
    /// keeping the first one on ties. Values unordered with themselves (float NaN) are
    /// ignored. Returns `None` when no orderable candidate remains.
    fn first_max<T: Copy + PartialOrd>(
        candidates: impl Iterator<Item = (isize, T)>,
    ) -> Option<(isize, T)> {
        candidates
            .filter(|&(_, value)| value.partial_cmp(&value).is_some())
            .fold(None, |best: Option<(isize, T)>, candidate| match best {
                Some((_, best_value)) if best_value >= candidate.1 => best,
                _ => Some(candidate),
            })
    }
}