use ndarray::*;
use num_traits::Zero;
pub mod raw;
pub mod unary;
use tract_core::ops::prelude::*;
pub fn space_to_batch_nd(pb: &crate::tfpb::node_def::NodeDef) -> TractResult<Box<Op>> {
    let datum_type = pb.get_attr_datum_type("T")?;
    Ok(Box::new(raw::SpaceToBatch::new(datum_type)))
}
pub fn batch_to_space_nd(pb: &crate::tfpb::node_def::NodeDef) -> TractResult<Box<Op>> {
    let datum_type = pb.get_attr_datum_type("T")?;
    Ok(Box::new(raw::BatchToSpace::new(datum_type)))
}
/// Moves blocks of the spatial axes of `input` into its leading (batch) axis.
///
/// Axis 0 of `input` is treated as the batch axis and axes
/// `1..=block_shape.len()` as the spatial axes; any remaining axes are carried
/// through unchanged.
///
/// The computation runs in three steps:
/// 1. row `i` of `paddings`, read as `[before, after]`, zero-pads axis `i + 1`
///    on both sides;
/// 2. every (padded) spatial axis of length `d` is split into
///    `(d / block, block)`;
/// 3. the block axes are moved in front of the batch axis and the result is
///    flattened, so the output batch is `product(block_shape) * batch`.
///
/// # Errors
///
/// Returns an error (an `ndarray::ShapeError` converted into the crate error)
/// instead of panicking when:
/// * `input` has fewer than `block_shape.len() + 1` axes,
/// * `block_shape` contains a zero or negative entry,
/// * a row of `paddings` has fewer than two columns or a negative entry,
/// * a non-zero padding targets an axis `input` does not have.
///
/// Errors from the input conversion, and from the reshape when a padded
/// spatial axis is not a multiple of its block size, are propagated as well.
fn space_to_batch<T: Copy + Datum + Zero>(
    input: SharedTensor,
    block_shape: &ArrayView1<i32>,
    paddings: &ArrayView2<i32>,
) -> TractResult<SharedTensor> {
    // Invalid operands are reported with ndarray's shape error: the `?`
    // operators below already rely on its conversion into the returned error.
    let invalid = || ::ndarray::ShapeError::from_kind(::ndarray::ErrorKind::IncompatibleShape);

    let mut data = input.to_array::<T>()?;
    let spatial_rank = block_shape.len();
    // A zero block would divide by zero below; a negative one would wrap
    // around when cast to `usize`.
    if data.ndim() < spatial_rank + 1 || block_shape.iter().any(|&b| b <= 0) {
        return Err(invalid().into());
    }

    // Step 1: zero-pad the spatial axes.
    for (ix, pad) in paddings.outer_iter().enumerate() {
        let axis = ix + 1;
        if pad.len() < 2 || pad[0] < 0 || pad[1] < 0 {
            return Err(invalid().into());
        }
        let (before, after) = (pad[0] as usize, pad[1] as usize);
        if (before != 0 || after != 0) && axis >= data.ndim() {
            return Err(invalid().into());
        }
        if before != 0 {
            let mut pad_shape = data.shape().to_vec();
            pad_shape[axis] = before;
            let zeros = ::ndarray::ArrayD::<T>::zeros(pad_shape);
            let padded = ::ndarray::stack(::ndarray::Axis(axis), &[zeros.view(), data.view()])?;
            data = padded;
        }
        if after != 0 {
            let mut pad_shape = data.shape().to_vec();
            pad_shape[axis] = after;
            let zeros = ::ndarray::ArrayD::<T>::zeros(pad_shape);
            let padded = ::ndarray::stack(::ndarray::Axis(axis), &[data.view(), zeros.view()])?;
            data = padded;
        }
    }

    // Step 2: compute the split shape [batch, d1/b1, b1, d2/b2, b2, .., rest]
    // and the output shape [block_size * batch, d1/b1, d2/b2, .., rest].
    let batch = data.shape()[0];
    let block_dims: Vec<usize> = block_shape.iter().map(|&b| b as usize).collect();
    let block_size: usize = block_dims.iter().product();
    let mut reshaped = vec![batch];
    let mut final_shape = vec![block_size * batch];
    for (m, &block) in block_dims.iter().enumerate() {
        let blocks_along_axis = data.shape()[m + 1] / block;
        reshaped.push(blocks_along_axis);
        reshaped.push(block);
        final_shape.push(blocks_along_axis);
    }
    let trailing = data.shape()[spatial_rank + 1..].to_vec();
    reshaped.extend(&trailing);
    final_shape.extend(&trailing);
    // Fails here when a padded axis is not a multiple of its block size.
    let data = data.into_shape(reshaped)?;

    // Step 3: reorder to [b1, b2, .., batch, d1/b1, d2/b2, .., rest].
    let mut permuted_axes: Vec<usize> = (0..spatial_rank).map(|m| 2 * m + 2).collect();
    permuted_axes.push(0);
    permuted_axes.extend((0..spatial_rank).map(|m| 2 * m + 1));
    permuted_axes.extend((2 * spatial_rank + 1)..data.ndim());
    let data = data.permuted_axes(permuted_axes);

    // Walk the permuted view in logical order to materialise the new layout,
    // then give it the flattened output shape.
    let data: Vec<T> = data.iter().cloned().collect();
    let data = ::ndarray::ArrayD::from_shape_vec(final_shape, data)?;
    Ok(data.into())
}
/// Moves blocks stored in the leading (batch) axis of `input` back into its
/// spatial axes, then crops those axes.
///
/// Axis 0 of `input` is treated as the batch axis and axes
/// `1..=block_shape.len()` as the spatial axes; any remaining axes are carried
/// through unchanged.
///
/// The computation runs in three steps:
/// 1. the batch axis is split into `[b1, .., bM, batch / product(block_shape)]`;
/// 2. each block axis is interleaved behind its spatial axis and merged with
///    it, giving spatial lengths `d_i * b_i`;
/// 3. row `i` of `crops`, read as `[start, end]`, removes `start` leading and
///    `end` trailing entries from axis `i + 1`.
///
/// # Errors
///
/// Returns an error (an `ndarray::ShapeError` converted into the crate error)
/// instead of panicking when:
/// * `input` has fewer than `block_shape.len() + 1` axes,
/// * `block_shape` contains a zero or negative entry,
/// * a row of `crops` has fewer than two columns or a negative entry,
/// * a non-zero crop targets an axis the data does not have, or removes more
///   entries than that axis holds.
///
/// Errors from the input conversion, and from the reshape when the batch size
/// is not a multiple of `product(block_shape)`, are propagated as well.
fn batch_to_space<T: Copy + Datum + Zero>(
    input: SharedTensor,
    block_shape: &ArrayView1<i32>,
    crops: &ArrayView2<i32>,
) -> TractResult<SharedTensor> {
    // Invalid operands are reported with ndarray's shape errors: the `?`
    // operators below already rely on their conversion into the returned error.
    let invalid = || ::ndarray::ShapeError::from_kind(::ndarray::ErrorKind::IncompatibleShape);
    let out_of_bounds = || ::ndarray::ShapeError::from_kind(::ndarray::ErrorKind::OutOfBounds);

    let data = input.to_array::<T>()?;
    let input_shape = data.shape().to_vec();
    let spatial_rank = block_shape.len();
    // A zero block would divide by zero below; a negative one would wrap
    // around when cast to `usize`.
    if input_shape.len() < spatial_rank + 1 || block_shape.iter().any(|&b| b <= 0) {
        return Err(invalid().into());
    }
    let block_dims: Vec<usize> = block_shape.iter().map(|&b| b as usize).collect();
    let block_size: usize = block_dims.iter().product();
    let batches = input_shape[0] / block_size;

    // Step 1: split the batch axis into [b1, .., bM, batches, d1, .., rest].
    let mut unflatten_blocked_shape = block_dims.clone();
    unflatten_blocked_shape.push(batches);
    unflatten_blocked_shape.extend(&input_shape[1..]);
    // Fails here when the batch size is not a multiple of `block_size`.
    let data = data.into_shape(&*unflatten_blocked_shape)?;

    // Step 2: reorder to [batches, d1, b1, d2, b2, .., rest] and merge each
    // (d_i, b_i) pair into a single axis of length d_i * b_i.
    let mut permuted_axes = vec![spatial_rank];
    let mut padded_shape = vec![batches];
    for (i, &block) in block_dims.iter().enumerate() {
        permuted_axes.push(spatial_rank + 1 + i);
        permuted_axes.push(i);
        padded_shape.push(block * input_shape[i + 1]);
    }
    permuted_axes.extend((1 + 2 * spatial_rank)..data.ndim());
    padded_shape.extend(&input_shape[1 + spatial_rank..]);
    let data = data.permuted_axes(permuted_axes);
    // Walk the permuted view in logical order to materialise the new layout.
    let data: Vec<T> = data.iter().cloned().collect();
    let mut data = ::ndarray::ArrayD::from_shape_vec(padded_shape, data)?;

    // Step 3: crop the spatial axes.
    for (i, crop) in crops.outer_iter().enumerate() {
        if crop.len() < 2 {
            return Err(invalid().into());
        }
        if crop[0] != 0 || crop[1] != 0 {
            let axis = i + 1;
            if crop[0] < 0 || crop[1] < 0 || axis >= data.ndim() {
                return Err(out_of_bounds().into());
            }
            let (start, trailing) = (crop[0] as usize, crop[1] as usize);
            let len = data.shape()[axis];
            // Both values come from non-negative `i32`s, so the sum cannot overflow.
            if start + trailing > len {
                return Err(out_of_bounds().into());
            }
            let range = start..(len - trailing);
            // A single `to_owned` copies the cropped view into a fresh array.
            data = data.slice_axis(Axis(axis), range.into()).to_owned();
        }
    }
    Ok(data.into())
}
// Unit tests for the space-to-batch / batch-to-space operators.
//
// The `*_nd_N` tests evaluate the operators on small rank-4 `i32` tensors with
// a `[2, 2]` block shape; the `batch_to_space_nd_N` cases use the outputs of
// the matching `space_to_batch_nd_N` cases as inputs and expect the original
// tensors back. The `*_infer_N` tests exercise shape inference only.
#[cfg(test)]
mod tests {
    // NOTE(review): no non-snake-case identifier is visible in this module, so
    // this allow may be removable -- confirm before dropping it.
    #![allow(non_snake_case)]
    use super::raw::{BatchToSpace, SpaceToBatch};
    use super::*;
    use tract_core::ops::InferenceOp;
    use tract_core::tensor::arr4;
    
    /// 1x2x2x1 input, no padding: each of the four values lands in its own
    /// batch entry (output 4x1x1x1).
    #[test]
    fn space_to_batch_nd_1() {
        assert_eq!(
            SpaceToBatch::new(i32::datum_type())
                .eval(tvec![
                    arr4(&[[[[1i32], [2]], [[3], [4]]]]).into(),
                    arr1(&[2, 2]).into(),
                    arr2(&[[0, 0], [0, 0]]).into(),
                ])
                .unwrap(),
            tvec![arr4(&[[[[1i32]]], [[[2]]], [[[3]]], [[[4]]]]).into()],
        )
    }
    /// 1x2x2x3 input, no padding: the trailing axis of length 3 is carried
    /// through untouched (output 4x1x1x3).
    #[test]
    fn space_to_batch_nd_2() {
        assert_eq!(
            SpaceToBatch::new(i32::datum_type())
                .eval(tvec![
                    arr4(&[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]).into(),
                    arr1(&[2, 2]).into(),
                    arr2(&[[0, 0], [0, 0]]).into(),
                ])
                .unwrap(),
            tvec![arr4(&[
                [[[1i32, 2, 3]]],
                [[[4, 5, 6]]],
                [[[7, 8, 9]]],
                [[[10, 11, 12]]],
            ])
            .into(),],
        )
    }
    /// 1x4x4x1 input, no padding: every batch entry of the 4x2x2x1 output
    /// gathers the values sharing the same offset inside their 2x2 block.
    #[test]
    fn space_to_batch_nd_3() {
        assert_eq!(
            SpaceToBatch::new(i32::datum_type())
                .eval(tvec![
                    arr4(&[[
                        [[1], [2], [3], [4]],
                        [[5], [6], [7], [8]],
                        [[9], [10], [11], [12]],
                        [[13], [14], [15], [16]],
                    ]])
                    .into(),
                    arr1(&[2, 2]).into(),
                    arr2(&[[0, 0], [0, 0]]).into(),
                ])
                .unwrap(),
            tvec![arr4(&[
                [[[1], [3]], [[9], [11]]],
                [[[2], [4]], [[10], [12]]],
                [[[5], [7]], [[13], [15]]],
                [[[6], [8]], [[14], [16]]],
            ])
            .into(),],
        )
    }
    /// 2x2x4x1 input with two leading zeros padded onto the second spatial
    /// axis (paddings `[[0, 0], [2, 0]]`), giving an 8x1x3x1 output whose
    /// first column is the padding.
    #[test]
    fn space_to_batch_nd_4() {
        assert_eq!(
            SpaceToBatch::new(i32::datum_type())
                .eval(tvec![
                    arr4(&[
                        [[[1], [2], [3], [4]], [[5], [6], [7], [8]]],
                        [[[9], [10], [11], [12]], [[13], [14], [15], [16]]],
                    ])
                    .into(),
                    arr1(&[2, 2]).into(),
                    arr2(&[[0, 0], [2, 0]]).into(),
                ])
                .unwrap(),
            tvec![arr4(&[
                [[[0], [1], [3]]],
                [[[0], [9], [11]]],
                [[[0], [2], [4]]],
                [[[0], [10], [12]]],
                [[[0], [5], [7]]],
                [[[0], [13], [15]]],
                [[[0], [6], [8]]],
                [[[0], [14], [16]]],
            ])
            .into(),],
        )
    }
    /// Shape inference with fully known dimensions: a (1, 4, 16) input with
    /// block shape `[2]` and zero padding is inferred as (2, 2, 16).
    #[test]
    fn space_to_batch_nd_infer_1() {
        let op = SpaceToBatch::new(f32::datum_type());
        let data = TensorFact::dt_shape(DatumType::F32, shapefact!(1, 4, 16));
        let block_shape = TensorFact::from(Tensor::from(arr1(&[2])));
        let paddings = TensorFact::from(Tensor::from(arr2(&[[0.to_dim(), 0.to_dim()]])));
        // No constraint is placed on the output fact.
        let any = TensorFact::default();
        let (_, outputs) = op
            .infer_facts(tvec!(&data, &block_shape, &paddings), tvec!(&any))
            .unwrap();
        assert_eq!(
            outputs[0],
            TensorFact::dt_shape(DatumType::F32, shapefact!(2, 2, 16))
        );
    }
    /// Shape inference with a symbolic dimension: the spatial axis is
    /// `TDim::s() - 4` and the trailing padding is `TDim::s() % 2`, so the
    /// inferred spatial dimension is the symbolic `(s + s % 2 - 4) / 2`.
    #[test]
    fn space_to_batch_nd_infer_2() {
        let op = SpaceToBatch::new(f32::datum_type());
        let data = TensorFact::dt_shape(DatumType::F32, shapefact!(1, (TDim::s() - 4), 16));
        let block_shape = TensorFact::from(Tensor::from(arr1(&[2])));
        let paddings = TensorFact::from(Tensor::from(arr2(&[[0.to_dim(), (TDim::s() % 2)]])));
        // No constraint is placed on the output fact.
        let any = TensorFact::default();
        let (_, outputs) = op
            .infer_facts(tvec!(&data, &block_shape, &paddings), tvec!(&any))
            .unwrap();
        assert_eq!(
            outputs[0],
            TensorFact::dt_shape(
                DatumType::F32,
                shapefact!(2, ((TDim::s() + TDim::s() % 2 - 4) / 2), 16)
            )
        );
    }
    /// Inverse of `space_to_batch_nd_1`: 4x1x1x1 input, no crop, 1x2x2x1 output.
    #[test]
    fn batch_to_space_nd_1() {
        assert_eq!(
            BatchToSpace::new(i32::datum_type())
                .eval(tvec![
                    arr4(&[[[[1]]], [[[2]]], [[[3]]], [[[4]]]]).into(),
                    arr1(&[2, 2]).into(),
                    arr2(&[[0, 0], [0, 0]]).into(),
                ])
                .unwrap(),
            tvec![arr4(&[[[[1], [2]], [[3], [4]]]]).into()]
        )
    }
    /// Inverse of `space_to_batch_nd_2`: 4x1x1x3 input, no crop, 1x2x2x3 output.
    #[test]
    fn batch_to_space_nd_2() {
        assert_eq!(
            BatchToSpace::new(i32::datum_type())
                .eval(tvec![
                    arr4(&[
                        [[[1i32, 2, 3]]],
                        [[[4, 5, 6]]],
                        [[[7, 8, 9]]],
                        [[[10, 11, 12]]],
                    ])
                    .into(),
                    arr1(&[2, 2]).into(),
                    arr2(&[[0, 0], [0, 0]]).into(),
                ])
                .unwrap(),
            tvec![arr4(&[[[[1i32, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]).into()]
        )
    }
    /// Inverse of `space_to_batch_nd_3`: 4x2x2x1 input, no crop, 1x4x4x1 output.
    #[test]
    fn batch_to_space_nd_3() {
        assert_eq!(
            BatchToSpace::new(i32::datum_type())
                .eval(tvec![
                    arr4(&[
                        [[[1i32], [3]], [[9], [11]]],
                        [[[2], [4]], [[10], [12]]],
                        [[[5], [7]], [[13], [15]]],
                        [[[6], [8]], [[14], [16]]],
                    ])
                    .into(),
                    arr1(&[2, 2]).into(),
                    arr2(&[[0, 0], [0, 0]]).into(),
                ])
                .unwrap(),
            tvec![arr4(&[[
                [[1i32], [2], [3], [4]],
                [[5], [6], [7], [8]],
                [[9], [10], [11], [12]],
                [[13], [14], [15], [16]],
            ]])
            .into(),]
        )
    }
    /// Inverse of `space_to_batch_nd_4`: crops `[[0, 0], [2, 0]]` strip the
    /// two leading zero entries from the second spatial axis, giving the
    /// 2x2x4x1 tensor back.
    #[test]
    fn batch_to_space_nd_4() {
        assert_eq!(
            BatchToSpace::new(i32::datum_type())
                .eval(tvec![
                    arr4(&[
                        [[[0i32], [1], [3]]],
                        [[[0], [9], [11]]],
                        [[[0], [2], [4]]],
                        [[[0], [10], [12]]],
                        [[[0], [5], [7]]],
                        [[[0], [13], [15]]],
                        [[[0], [6], [8]]],
                        [[[0], [14], [16]]],
                    ])
                    .into(),
                    arr1(&[2, 2]).into(),
                    arr2(&[[0, 0], [2, 0]]).into(),
                ])
                .unwrap(),
            tvec![arr4(&[
                [[[1], [2], [3], [4]], [[5], [6], [7], [8]]],
                [[[9], [10], [11], [12]], [[13], [14], [15], [16]]],
            ])
            .into(),]
        )
    }
}