#[allow(clippy::module_inception)]
mod conv;
mod depth_wise;
mod im2col;
mod lazy_im2col;
mod q_sum_b;
use crate::internal::*;
pub use self::im2col::Im2Col;
pub(crate) use self::q_sum_b::QSumB;
pub use self::conv::Conv;
/// Axis layout of a convolution kernel tensor.
///
/// Each variant spells out its axis order from outermost to innermost:
/// `O` is the output-channel axis, `I` the input-channel axis, and `HW`
/// stands for the spatial axes (there are `rank - 2` of them, so possibly
/// more or fewer than two).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum KernelFormat {
    /// Output channels, then input channels, then spatial axes (the default).
    #[default]
    OIHW,
    /// Spatial axes, then input channels, then output channels.
    HWIO,
    /// Output channels, then spatial axes, then input channels.
    OHWI,
}
impl KernelFormat {
    /// Index of the first spatial axis in a kernel shape laid out in this format.
    pub fn h_axis(&self) -> usize {
        match self {
            Self::HWIO => 0,
            Self::OHWI => 1,
            Self::OIHW => 2,
        }
    }

    /// Sub-slice of `full_shape` holding the spatial axes: the `rank - 2`
    /// entries starting at [`Self::h_axis`].
    pub fn spatial_shape<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
        let from_h = &full_shape[self.h_axis()..];
        let spatial_rank = full_shape.len() - 2;
        &from_h[..spatial_rank]
    }

    /// Shorthand for [`Self::spatial_shape`].
    pub fn hw<'a, D>(&self, full_shape: &'a [D]) -> &'a [D] {
        self.spatial_shape(full_shape)
    }

    /// The dimension stored on the kernel's `I` axis, as positioned by the format name.
    pub fn i<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
        let rank = full_shape.len();
        let axis = match self {
            Self::OIHW => 1,
            Self::HWIO => rank - 2,
            Self::OHWI => rank - 1,
        };
        &full_shape[axis]
    }

    /// Index of the `O` axis: innermost for HWIO, outermost otherwise.
    pub fn o_axis<D>(&self, full_shape: &[D]) -> usize {
        match self {
            Self::HWIO => full_shape.len() - 1,
            Self::OIHW | Self::OHWI => 0,
        }
    }

    /// The dimension stored on the kernel's `O` axis.
    pub fn o<'a, D>(&self, full_shape: &'a [D]) -> &'a D {
        let axis = self.o_axis(full_shape);
        &full_shape[axis]
    }

    /// Total number of input channels implied by a kernel of shape `full_kernel_shape`.
    ///
    /// For OIHW the stored `I` dimension is multiplied by `group` (it is treated as a
    /// per-group count); HWIO and OHWI return their `I` dimension unchanged, borrowed.
    pub fn input_channels<'s, D: DimLike>(
        &self,
        full_kernel_shape: &'s [D],
        group: usize,
    ) -> Cow<'s, D> {
        let stored = self.i(full_kernel_shape);
        match self {
            Self::HWIO | Self::OHWI => Cow::Borrowed(stored),
            Self::OIHW => Cow::Owned(stored.clone() * group),
        }
    }

    /// Total number of output channels implied by a kernel of shape `full_kernel_shape`.
    ///
    /// For OIHW the stored `O` dimension is returned unchanged, borrowed; HWIO and OHWI
    /// multiply their stored `O` dimension by `group` (it is treated as a per-group count).
    pub fn output_channels<'s, D: DimLike>(
        &self,
        full_kernel_shape: &'s [D],
        group: usize,
    ) -> Cow<'s, D> {
        let stored = self.o(full_kernel_shape);
        match self {
            Self::OIHW => Cow::Borrowed(stored),
            Self::HWIO | Self::OHWI => Cow::Owned(stored.clone() * group),
        }
    }

    /// Axis operations rearranging a kernel of this format into the layout
    /// `[group, o, i, spatial...]`.
    ///
    /// The axis that carries the full (non per-group) channel count is split into
    /// `(group, count / group)`: `O` for OIHW, `I` for HWIO and OHWI. The new `group`
    /// axis and the channel axes are then moved to the front.
    pub fn kernel_as_group_o_i_h_w_ops(
        &self,
        full_shape: &[impl DimLike],
        group: usize,
    ) -> TVec<AxisOp> {
        let geo_rank = full_shape.len() - 2;
        // Replaces the single axis `axis` of size `whole` by two axes `(group, whole / group)`.
        let split_by_group = |axis: usize, whole: TDim| {
            AxisOp::Reshape(axis, tvec!(whole.clone()), tvec!(group.to_dim(), whole / group))
        };
        match self {
            // [o, i, spatial..] -> [g, o/g, i, spatial..]
            Self::OIHW => tvec!(split_by_group(0, self.o(full_shape).to_dim())),
            Self::HWIO => tvec!(
                // [spatial.., i, o] -> [spatial.., g, i/g, o]
                split_by_group(geo_rank, self.i(full_shape).to_dim()),
                // -> [g, spatial.., i/g, o]
                AxisOp::Move(geo_rank, 0),
                // -> [g, o, spatial.., i/g]
                AxisOp::Move(geo_rank + 2, 1),
                // -> [g, o, i/g, spatial..]
                AxisOp::Move(geo_rank + 2, 2)
            ),
            Self::OHWI => tvec!(
                // [o, spatial.., i] -> [o, spatial.., g, i/g]
                split_by_group(geo_rank + 1, self.i(full_shape).to_dim()),
                // -> [g, o, spatial.., i/g]
                AxisOp::Move(geo_rank + 1, 0),
                // -> [g, o, i/g, spatial..]
                AxisOp::Move(geo_rank + 2, 2)
            ),
        }
    }

    /// Like [`Self::kernel_as_group_o_i_h_w_ops`], followed (when there is more than one
    /// spatial axis) by a reshape flattening the spatial axes, which start at axis 3,
    /// into a single axis.
    pub fn kernel_as_group_o_i_hw_ops(
        &self,
        full_shape: &[impl DimLike],
        group: usize,
    ) -> TVec<AxisOp> {
        let mut ops = self.kernel_as_group_o_i_h_w_ops(full_shape, group);
        let spatial = self.hw(full_shape);
        if spatial.len() > 1 {
            let spatial_dims: TVec<TDim> = spatial.iter().map(|d| d.to_dim()).collect();
            let spatial_volume = spatial_dims.iter().cloned().product::<TDim>();
            ops.push(AxisOp::Reshape(3, spatial_dims, tvec!(spatial_volume)));
        }
        ops
    }

    /// Like [`Self::kernel_as_group_o_i_hw_ops`], followed by a reshape merging axis 2
    /// (per-group input channels) with the flattened spatial axis, giving `[group, o, i*hw]`.
    ///
    /// NOTE(review): this appends a reshape over axes 2 and 3 unconditionally, so it
    /// presumably assumes the kernel has at least one spatial axis — confirm for rank-2 kernels.
    pub fn kernel_as_group_o_ihw_ops(
        &self,
        full_shape: &[impl DimLike],
        group: usize,
    ) -> TVec<AxisOp> {
        let per_group_i = (self.input_channels(full_shape, group).into_owned() / group).to_dim();
        let spatial_volume: TDim = self.hw(full_shape).iter().map(|d| d.to_dim()).product();
        let mut ops = self.kernel_as_group_o_i_hw_ops(full_shape, group);
        let merged = per_group_i.clone() * spatial_volume.clone();
        ops.push(AxisOp::Reshape(2, tvec!(per_group_i, spatial_volume), tvec!(merged)));
        ops
    }

    /// Returns a copy of `kernel` rearranged by [`Self::kernel_as_group_o_i_hw_ops`].
    ///
    /// # Errors
    /// Propagates any error returned by `AxisOp::change_tensor`.
    pub fn kernel_as_group_o_i_hw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
        let ops = self.kernel_as_group_o_i_hw_ops(kernel.shape(), group);
        let mut reshaped = kernel.clone();
        for op in ops.iter() {
            op.change_tensor(&mut reshaped, false)?;
        }
        Ok(reshaped)
    }

    /// Returns a copy of `kernel` in `[group, o, i*hw]` layout: the result of
    /// [`Self::kernel_as_group_o_i_hw`] with axis 2 collapsed into axis 3.
    ///
    /// # Errors
    /// Propagates any error from [`Self::kernel_as_group_o_i_hw`].
    pub fn kernel_as_group_o_ihw(&self, kernel: &Tensor, group: usize) -> TractResult<Tensor> {
        self.kernel_as_group_o_i_hw(kernel, group).map(|g_o_i_hw| g_o_i_hw.collapse_axis_with_next(2))
    }
}