use ndarray::*;
use crate::ops::prelude::*;
use insideout::InsideOut;
use super::conv_gemm::ConvGemm;
use super::im2col::Im2Col;
use super::Conv;
use crate::ops::nn::conv::KernelFormat;
use crate::ops::nn::{DataFormat, PaddingSpec, Patch};
use std::sync::Arc;
use tract_linalg::MatMul;
/// A convolution whose kernel (and optional bias) are known constant
/// tensors, ready to be lowered to an im2col + GEMM pipeline.
#[derive(Debug, Clone)]
pub struct ConvUnary {
// Data layout of the input/output tensors (e.g. NHWC).
pub data_fmt: DataFormat,
// Kernel tensor layout (OIHW or HWIO).
pub kernel_fmt: KernelFormat,
// Padding specification for the spatial axes.
pub padding: PaddingSpec,
// Dilation factor, one entry per spatial axis.
pub dilations: TVec<usize>,
// Stride, one entry per spatial axis.
pub strides: TVec<usize>,
// The constant convolution kernel.
pub kernel: Tensor,
// Optional constant bias, broadcast along the output channel axis.
pub bias: Option<Tensor>,
// Full shape (N, C and spatial axes) of the expected input.
pub full_input_shape: TVec<TDim>,
// Full shape of the produced output.
pub full_output_shape: TVec<TDim>,
// Number of convolution groups (1 for a regular convolution).
pub group: usize,
}
impl ConvUnary {
    /// Builds a `ConvUnary` from a generic `Conv` op once the kernel and
    /// optional bias are known as constant tensors.
    ///
    /// `dilations` and `strides` left unset on `conv` default to 1 on
    /// every spatial axis.
    pub fn new(
        conv: &Conv,
        full_input_shape: &[TDim],
        full_output_shape: &[TDim],
        kernel: Tensor,
        bias: Option<Tensor>,
        group: usize,
    ) -> TractResult<ConvUnary> {
        // Full shape is N and C axes plus one axis per spatial dimension.
        let spatial_rank = full_input_shape.len() - 2;
        let dilations = conv
            .dilations
            .as_ref()
            .map(|a| TVec::from(&**a))
            .unwrap_or(tvec!(1; spatial_rank));
        let strides = conv
            .strides
            .as_ref()
            .map(|a| TVec::from(&**a))
            .unwrap_or(tvec!(1; spatial_rank));
        let unary = ConvUnary {
            data_fmt: conv.data_fmt,
            kernel_fmt: conv.kernel_fmt,
            padding: conv.padding.clone(),
            dilations,
            strides,
            kernel,
            bias,
            full_input_shape: full_input_shape.into(),
            full_output_shape: full_output_shape.into(),
            group,
        };
        Ok(unary)
    }
    /// Lowers the convolution, for a concrete input shape, into an
    /// `Im2Col` op (packs input patches into the B matrix of a GEMM)
    /// followed by a `ConvGemm` op (multiplies the pre-packed kernel
    /// against it, adding the optional bias).
    fn to_im2col_pair<T>(&self, input_full_shape: &[usize]) -> TractResult<(Im2Col<T>, ConvGemm<T>)>
    where
        T: Datum + Clone + ndarray::LinalgScalar + std::ops::AddAssign<T> + PartialEq,
    {
        trace!("input {:?} {:?}", self.data_fmt, input_full_shape);
        trace!("kernl {:?} {:?}", self.kernel_fmt, self.kernel.shape());
        // Where the output channel count lives depends on kernel layout.
        let output_channels: usize = match self.kernel_fmt {
            KernelFormat::OIHW => self.kernel.shape()[0],
            KernelFormat::HWIO => *self.kernel.shape().last().unwrap(),
        };
        let kernel_spatial_shape =
            &self.kernel.shape()[self.kernel_fmt.h_axis()..][..(input_full_shape.len() - 2)];
        trace!("kernel spatial shape {:?}", kernel_spatial_shape);
        let patch = Patch::new(
            self.data_fmt,
            self.dilations.clone(),
            kernel_spatial_shape.into(),
            &self.padding,
            self.strides.clone(),
            input_full_shape.into(),
        );
        let shape: TVec<usize> = patch.output_full_shape(output_channels);
        let kernel = self.kernel.to_array_view::<T>()?;
        // GEMM dimensions per group: A (kernel) is m×k, B (patches) is k×n.
        let m = output_channels / self.group;
        let k = kernel.len() / output_channels;
        let n = patch
            .output_spatial_shape
            .iter()
            .cloned()
            .product::<usize>();
        let mm: Arc<MatMul<T>> = T::packed_mat_mul(m, k, n)
            .ok_or_else(|| {
                // Fixed typo: "perfom" -> "perform".
                format!(
                    "Can not perform convolution on {:?} (not a linear algebra type)",
                    T::datum_type()
                )
            })?
            .into();
        let packed_b_len = mm.packed_b_len();
        trace!(
            "Gemm iters={} m={} k={} n={}",
            patch.input_shape.n_dim() * self.group,
            m,
            k,
            n
        );
        // Reshape the kernel to (output_channels, k). HWIO is first
        // permuted to put O and I in front so both layouts end up in the
        // same O-major order.
        let kernel_reshaped = (output_channels, k);
        let kernel: Array2<T> = match self.kernel_fmt {
            KernelFormat::HWIO => {
                let mut permutation: Vec<usize> = vec![kernel.ndim() - 1, kernel.ndim() - 2];
                permutation.extend(0..(kernel.ndim() - 2));
                let permuted = kernel.permuted_axes(permutation);
                Array2::<T>::from_shape_vec(
                    kernel_reshaped,
                    permuted.iter().cloned().collect::<Vec<_>>(),
                )?
            }
            KernelFormat::OIHW => kernel.into_shape(kernel_reshaped)?.to_owned(),
        };
        // Pre-pack one A panel per group in the layout the matmul expects.
        let mut packed_kernels: Vec<Tensor> = vec![];
        let co_per_group = output_channels / self.group;
        for g in 0..self.group {
            let subkernel =
                kernel.slice_axis(Axis(0), (co_per_group * g..co_per_group * (g + 1)).into());
            // SAFETY: pack_a below fully initializes the buffer.
            let mut packed = unsafe {
                Tensor::uninitialized_aligned::<T>(&[mm.packed_a_len()], mm.packed_a_alignment())?
            };
            mm.pack_a(
                packed.as_slice_mut()?.as_mut_ptr(),
                subkernel.as_ptr(),
                subkernel.strides()[0],
                subkernel.strides()[1],
            );
            packed_kernels.push(packed);
        }
        // Reshape the bias so it broadcasts over the output: every axis is
        // 1 except the channel axis (axis 1).
        let bias: Option<ArrayD<T>> = self
            .bias
            .as_ref()
            .map(|bias| -> TractResult<_> {
                let mut bias_shape: Vec<usize> = ::std::iter::repeat(1).take(shape.len()).collect();
                bias_shape[1] = output_channels;
                Ok(bias
                    .to_array_view::<T>()?
                    .into_shape(&*bias_shape)?
                    .to_owned())
            })
            .inside_out()?;
        let im2col = Im2Col::new(patch.clone(), m, k, n, self.group, packed_b_len, mm.clone());
        trace!("im2col: {:?}", im2col);
        let conv_gemm = ConvGemm::new(
            patch,
            shape,
            m,
            k,
            n,
            self.kernel_fmt,
            packed_kernels,
            bias,
            self.group,
            mm.clone(),
        );
        trace!("cvgemm: {:?}", conv_gemm);
        Ok((im2col, conv_gemm))
    }
    /// Same as `to_im2col_pair`, but type-erases both ops behind `Box<Op>`
    /// so they can be wired into the graph.
    fn to_boxed_im2col_pair<T>(&self, input_full_shape: &[usize]) -> TractResult<(Box<Op>, Box<Op>)>
    where
        T: Datum + Clone + ::ndarray::LinalgScalar + ::std::ops::AddAssign<T> + PartialEq,
    {
        let (op1, op2) = self.to_im2col_pair::<T>(input_full_shape)?;
        Ok((Box::new(op1), Box::new(op2)))
    }
    /// Eager evaluation: builds the im2col/gemm pair for the input's
    /// concrete shape, then runs both steps.
    fn eval_t<T>(&self, mut inputs: TVec<SharedTensor>) -> TractResult<TVec<SharedTensor>>
    where
        T: Datum + Clone + ::ndarray::LinalgScalar + ::std::ops::AddAssign<T> + PartialEq,
    {
        let input = args_1!(inputs);
        let (im2col, conv_gemm) = self.to_im2col_pair::<T>(input.shape())?;
        let mega = im2col.im2col(&input.to_array_view()?)?;
        let output = conv_gemm.conv_gemm(&mega.to_array_view::<T>()?.into_dimensionality()?)?;
        Ok(tvec!(output.into()))
    }
    /// Tries to drop input axis `axis`. Returns a rewritten op when the
    /// axis is a spatial axis the convolution does not actually use
    /// (kernel dim 1, stride 1, dilation 1, no padding on it), `None`
    /// otherwise.
    pub fn rm_dummy_axis(&self, axis: usize) -> TractResult<Option<ConvUnary>> {
        let shape = self.data_fmt.shape(&self.full_input_shape);
        // Only spatial axes can be removed: reject N and C axes...
        if axis < shape.h_axis() {
            return Ok(None);
        }
        let geo_axis = axis - shape.h_axis();
        // ...and anything beyond the spatial rank.
        if geo_axis >= shape.hw_rank() {
            return Ok(None);
        }
        // The axis must be trivial for the convolution itself.
        if self.dilations[geo_axis] != 1
            || self.strides[geo_axis] != 1
            || !self.padding.valid_dim(geo_axis)
        {
            return Ok(None);
        }
        let kernel_spatial_shape =
            &self.kernel.shape()[self.kernel_fmt.h_axis()..][..shape.hw_rank()];
        if kernel_spatial_shape[geo_axis] != 1 {
            return Ok(None);
        }
        // Copies `input`, dropping its `nth` element.
        fn copy_rm_nth<D: DimLike>(input: &[D], nth: usize) -> TVec<D> {
            input
                .iter()
                .enumerate()
                .filter(|&(ax, _)| ax != nth)
                .map(|(_, &d)| d)
                .collect()
        }
        // Note: dropped a redundant `.clone()` on the `&[usize]` returned
        // by `shape()` (it only cloned the reference).
        let kernel_shape: TVec<usize> = copy_rm_nth(
            self.kernel.shape(),
            geo_axis + self.kernel_fmt.h_axis(),
        );
        let kernel = self.kernel.clone().into_shape(&kernel_shape)?;
        let new_op = ConvUnary {
            data_fmt: self.data_fmt,
            kernel_fmt: self.kernel_fmt,
            padding: self.padding.rm_axis(geo_axis),
            dilations: copy_rm_nth(&self.dilations, geo_axis),
            strides: copy_rm_nth(&self.strides, geo_axis),
            kernel,
            bias: self.bias.clone(),
            full_input_shape: copy_rm_nth(&self.full_input_shape, axis),
            full_output_shape: copy_rm_nth(&self.full_output_shape, axis),
            group: self.group,
        };
        Ok(Some(new_op))
    }
}
impl Op for ConvUnary {
    fn name(&self) -> Cow<str> {
        "ConvUnary".into()
    }
    /// Rewrites this op into simpler/faster ops when possible:
    /// * a 1x1, stride-1, dilation-1, group-1, bias-less, unpadded
    ///   HWIO/NHWC convolution degenerates to a plain matrix product;
    /// * otherwise, with a concrete non-streaming input shape and datum
    ///   type, lower to an Im2Col + ConvGemm pair.
    fn reduce(
        &self,
        inputs: TVec<&TensorFact>,
        _outputs: TVec<&TensorFact>,
        phase: ReductionPhase,
    ) -> TractResult<Option<ReducedOpRewire>> {
        // No rewiring during normalization; only in later phases.
        if phase == ReductionPhase::Normalize {
            return Ok(None);
        }
        let spatial_rank = self.full_input_shape.len() - 2;
        let kernel_spatial_shape = &self.kernel.shape()[self.kernel_fmt.h_axis()..][..spatial_rank];
        // All spatial kernel dims 1, no stride/dilation/group/bias/padding:
        // the convolution is a pure channel mixing.
        if kernel_spatial_shape.iter().product::<usize>() == 1
            && self.dilations.iter().all(|&x| x == 1)
            && self.strides.iter().all(|&x| x == 1)
            && self.group == 1
            && self.bias.is_none()
            && (0..spatial_rank).all(|ax| self.padding.valid_dim(ax))
        {
            if self.kernel_fmt == KernelFormat::HWIO && self.data_fmt == DataFormat::NHWC {
                use crate::ops::math::mat_mul::MatMulUnaryA;
                // Drop the (all-1) spatial dims, keeping the I x O matrix.
                let kernel_shape = &self.kernel.shape()[spatial_rank..];
                let kernel = self.kernel.clone().into_shape(&kernel_shape)?;
                return Ok(Some(ReducedOpRewire::unary(MatMulUnaryA::new(kernel))));
            }
        } else if let (Some(shape), Some(dt)) = (
            inputs[0].shape.concretize(),
            inputs[0].datum_type.concretize(),
        ) {
            // Concrete shape and type known: precompute the im2col/gemm
            // pair, but only for non-streaming inputs.
            if inputs[0].stream_info()?.is_none() {
                let shape: Vec<usize> = shape
                    .iter()
                    .map(|d| d.to_integer().unwrap() as usize)
                    .collect();
                let (op1, op2) = dispatch_floatlike!(Self::to_boxed_im2col_pair(dt)(self, &shape))?;
                return Ok(Some(ReducedOpRewire {
                    ops: vec![op1, op2],
                    rewired: tvec!(0),
                }));
            }
        }
        Ok(None)
    }
    /// Converts the op for pulsed (streaming) evaluation. Pulsing over the
    /// batch axis keeps the op as-is; pulsing over a spatial axis inserts
    /// a Delay op so each pulse sees the overlap it needs; pulsing over
    /// the channel axis is not supported.
    fn pulsify(
        &self,
        mut inputs: TVec<&PulsedTensorFact>,
    ) -> TractResult<Vec<crate::pulse::PulsifiedOp>> {
        let input = args_1!(inputs);
        let shape = self.data_fmt.shape(&input.shape);
        if input.axis == shape.n_axis() {
            // Pulsing over the batch axis: only the output shape along that
            // axis needs pinning to the pulse length.
            let mut op = self.clone();
            op.full_output_shape[input.axis] = input.pulse().to_dim();
            let mut fact = input.clone();
            fact.shape = op
                .full_output_shape
                .iter()
                .enumerate()
                .map(|(ax, &d)| {
                    if ax == input.axis {
                        input.pulse()
                    } else {
                        d.to_integer().unwrap() as usize
                    }
                })
                .collect();
            Ok(vec![PulsifiedOp::new(Box::new(op), tvec!(fact))])
        } else if input.axis == shape.c_axis() {
            bail!("Can not pulsify convolution alongs the input channel axis");
        } else {
            let spatial_rank = self.full_input_shape.len() - 2;
            let geo_axis = input.axis - shape.h_axis();
            let kernel_spatial_shape =
                &self.kernel.shape()[self.kernel_fmt.h_axis()..][..spatial_rank];
            // Extra input steps a pulse must carry over from its
            // predecessor. NOTE(review): multiplies stride *and* dilation
            // — confirm this formula against the patch geometry.
            let kernel_len = (kernel_spatial_shape[geo_axis] - 1)
                * self.strides[geo_axis]
                * self.dilations[geo_axis];
            // The Delay op widens each pulse by kernel_len steps.
            let mut augmented_fact = input.clone();
            augmented_fact.shape[augmented_fact.axis] += kernel_len;
            augmented_fact.delay += kernel_len;
            // The convolution itself then runs on the widened pulse.
            let mut conv_op = self.clone();
            conv_op.full_input_shape[input.axis] = augmented_fact.pulse().to_dim();
            conv_op.full_output_shape[input.axis] =
                (augmented_fact.pulse() - kernel_len / self.strides[geo_axis]).to_dim();
            let mut conv_fact = input.clone();
            conv_fact.shape = self
                .full_output_shape
                .iter()
                .enumerate()
                .map(|(ax, &d)| {
                    if ax == input.axis {
                        input.pulse() / self.strides[geo_axis]
                    } else {
                        d.to_integer().unwrap() as usize
                    }
                })
                .collect();
            // Account for the warm-up: output lags by kernel_len steps and
            // the overall stream is that much shorter.
            conv_fact.delay += kernel_len;
            conv_fact.dim -= kernel_len.to_dim();
            let memory = PulsifiedOp::new(
                Box::new(crate::pulse::delay::Delay::new(
                    input.clone(),
                    0,
                    kernel_len,
                )),
                tvec!(augmented_fact),
            );
            let conv = PulsifiedOp::new(Box::new(conv_op), tvec!(conv_fact));
            Ok(vec![memory, conv])
        }
    }
}
impl StatelessOp for ConvUnary {
    /// Eager evaluation: dispatches on the input's (float-like) datum type
    /// and runs the generic `eval_t` implementation.
    fn eval(&self, inputs: TVec<SharedTensor>) -> TractResult<TVec<SharedTensor>> {
        let datum_type = inputs[0].datum_type();
        dispatch_floatlike!(Self::eval_t(datum_type)(self, inputs))
    }
}
impl InferenceRulesOp for ConvUnary {
    /// Shape/type inference: exactly one input and one output, same datum
    /// type on both, and shapes pinned to the full input/output shapes
    /// recorded at construction time.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p SharedTensorsProxy,
        outputs: &'p SharedTensorsProxy,
    ) -> InferenceResult {
        s.equals(&inputs.len, 1)?;
        s.equals(&outputs.len, 1)?;
        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
        s.equals(&inputs[0].shape, self.full_input_shape.clone())?;
        s.equals(&outputs[0].shape, self.full_output_shape.clone())?;
        Ok(())
    }
}