1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use crate::{
    dim::{Dim0, Dim1, Dim2, Dim3, Dim4},
    index::SliceTrait,
    shape_stride::ShapeStride,
};

use super::slice_dim::SliceDim;

/// Slice descriptor for the zero-dimensional case.
///
/// The type has no fields: its `SliceTrait` implementation works on `Dim0`
/// and needs no per-axis slice information.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Slice0D {}

/// Slicing for the zero-dimensional case: there is no axis to restrict, so the
/// shape, the stride and the offset all pass through unchanged.
impl SliceTrait for Slice0D {
    type Dim = Dim0;

    /// Returns `shape` and `stride` unchanged, wrapped in a [`ShapeStride`].
    fn sliced_shape_stride(&self, shape: Self::Dim, stride: Self::Dim) -> ShapeStride<Self::Dim> {
        ShapeStride::new(shape, stride)
    }

    /// Returns `original_offset` unchanged.
    ///
    /// The fixed-rank slice types in this file return the displacement of the
    /// slice start *plus* `original_offset`. With no axes that displacement is
    /// zero, so the result is simply `original_offset`; returning a constant
    /// `0` would discard the offset of the data being sliced.
    fn sliced_offset(&self, _stride: Self::Dim, original_offset: usize) -> usize {
        original_offset
    }
}

/// Generates a fixed-rank slice type together with its [`SliceTrait`] impl.
///
/// * `$impl_name` - name of the generated struct.
/// * `$num_item` - number of axes, i.e. the length of the `SliceDim` array.
/// * `$dim_ty` - dimension type used as `SliceTrait::Dim`; it is indexed by
///   axis for both reading and writing.
macro_rules! impl_slice_ty {
    ($impl_name:ident, $num_item:expr, $dim_ty:ty) => {
        /// Holds one `SliceDim` per axis.
        #[derive(Debug, Clone, Copy, PartialEq)]
        pub struct $impl_name {
            pub index: [SliceDim; $num_item],
        }

        impl $impl_name {
            /// Wraps the given per-axis slice descriptions.
            pub fn new(index: [SliceDim; $num_item]) -> Self {
                $impl_name { index }
            }

            /// Borrows the per-axis slice descriptions.
            pub fn index(&self) -> &[SliceDim; $num_item] {
                &self.index
            }
        }

        impl SliceTrait for $impl_name {
            type Dim = $dim_ty;

            /// Builds the shape/stride pair of the sliced view.
            ///
            /// For every axis the extent is replaced by `SliceDim::new_dim`
            /// of the original extent and the stride by
            /// `SliceDim::new_stride` of the original stride.
            fn sliced_shape_stride(
                &self,
                shape: Self::Dim,
                stride: Self::Dim,
            ) -> ShapeStride<Self::Dim> {
                // The arguments are owned, so they can be updated in place;
                // each slot is read before it is overwritten.
                let mut sliced_shape = shape;
                let mut sliced_stride = stride;

                for (axis, dim) in self.index.iter().enumerate() {
                    sliced_shape[axis] = dim.new_dim(sliced_shape[axis]);
                    sliced_stride[axis] = dim.new_stride(sliced_stride[axis]);
                }

                ShapeStride::new(sliced_shape, sliced_stride)
            }

            /// Returns `original_offset` plus the sum over all axes of
            /// `start * stride[axis]`, where a missing `start` counts as 0.
            fn sliced_offset(&self, stride: Self::Dim, original_offset: usize) -> usize {
                let displacement: usize = self
                    .index
                    .iter()
                    .enumerate()
                    .map(|(axis, dim)| dim.start.unwrap_or(0) * stride[axis])
                    .sum();

                displacement + original_offset
            }
        }
    };
}
// Instantiate the fixed-rank slice types. Each invocation defines a struct
// holding an array of N `SliceDim`s and implements `SliceTrait` for it with
// the listed dimension type as `SliceTrait::Dim`.
impl_slice_ty!(Slice1D, 1, Dim1);
impl_slice_ty!(Slice2D, 2, Dim2);
impl_slice_ty!(Slice3D, 3, Dim3);
impl_slice_ty!(Slice4D, 4, Dim4);

/// Unit tests for the fixed-rank slice types, built through the `slice!` macro.
#[cfg(test)]
mod static_dim_slice {
    use crate::dim::{Dim1, Dim2, Dim3};
    use crate::index::SliceTrait;
    use crate::slice;

    /// An open range with step 2 over 6 elements gives 3 elements and doubles
    /// the stride.
    #[test]
    fn sliced_1d() {
        let shape = Dim1::new([6]);
        let stride = Dim1::new([1]);
        let slice = slice!(..;2);

        let stride_shape = slice.sliced_shape_stride(shape, stride);

        assert_eq!(stride_shape.shape(), Dim1::new([3]));
        assert_eq!(stride_shape.stride(), Dim1::new([2]));
    }

    /// Each axis is sliced independently of the other.
    #[test]
    fn test_sliced_shape_stride_2d() {
        let original_shape = Dim2::new([10, 20]);
        let original_stride = Dim2::new([1, 10]);
        let slice = slice!(1..5;2, 3..10;1);
        let new = slice.sliced_shape_stride(original_shape, original_stride);

        assert_eq!(new.shape(), Dim2::new([2, 7]));
        assert_eq!(new.stride(), Dim2::new([2, 10]));
    }

    /// Same as the 2-D case, with a third axis that has an open start.
    #[test]
    fn test_sliced_shape_stride_3d() {
        let original_shape = Dim3::new([10, 20, 30]);
        let original_stride = Dim3::new([1, 10, 200]);
        let slice = slice!(1..5;2, 3..10;1, ..15;3);
        let new = slice.sliced_shape_stride(original_shape, original_stride);

        assert_eq!(new.shape(), Dim3::new([2, 7, 5]));
        assert_eq!(new.stride(), Dim3::new([2, 10, 600]));
    }

    /// With a zero original offset the result is the sum of
    /// `start * stride` over the axes: 1 * 10 + 3 * 1.
    #[test]
    fn test_sliced_offset_2d() {
        let stride = Dim2::new([10, 1]);
        let slice = slice!(1..5;2, 3..10;1);
        let offset = slice.sliced_offset(stride, 0);

        assert_eq!(offset, 13);
    }

    /// A non-zero original offset is added on top of the slice displacement:
    /// 1 * 10 + 3 * 1 + 5.
    #[test]
    fn test_sliced_offset_2d_with_original_offset() {
        let stride = Dim2::new([10, 1]);
        let slice = slice!(1..5;2, 3..10;1);
        let offset = slice.sliced_offset(stride, 5);

        assert_eq!(offset, 18);
    }
}