1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
use super::slice_dim::SliceDim;
use crate::{dim::DimDyn, index::SliceTrait, shape_stride::ShapeStride};

#[derive(Clone, Debug, Copy, PartialEq)]

pub struct Slice {
    pub index: [SliceDim; 4],
    pub len: usize,
}

impl SliceTrait for Slice {
    type Dim = DimDyn;

    fn sliced_shape_stride(&self, shape: Self::Dim, stride: Self::Dim) -> ShapeStride<Self::Dim> {
        let mut new_shape = DimDyn::default();
        let mut new_stride = DimDyn::default();

        for i in 0..self.len {
            match self.index[i].new_dim(shape[i]) {
                0 => continue,
                new_dim => {
                    new_shape.push_dim(new_dim);
                    new_stride.push_dim(self.index[i].new_stride(stride[i]));
                }
            }
        }

        ShapeStride::new(new_shape, new_stride)
    }

    fn sliced_offset(&self, stride: Self::Dim, original_offset: usize) -> usize {
        let mut offset = 0;

        for i in 0..self.len {
            let start = self.index[i].start.unwrap_or(0);
            offset += start * stride[i];
        }

        offset + original_offset
    }
}

impl From<&[SliceDim]> for Slice {
    fn from(s: &[SliceDim]) -> Self {
        if s.len() > 4 {
            panic!("too many slice dimensions");
        } else if s.len() == 1 {
            Slice {
                index: [
                    s[0],
                    SliceDim::default(),
                    SliceDim::default(),
                    SliceDim::default(),
                ],
                len: 1,
            }
        } else if s.len() == 2 {
            Slice {
                index: [s[0], s[1], SliceDim::default(), SliceDim::default()],
                len: 2,
            }
        } else if s.len() == 3 {
            Slice {
                index: [s[0], s[1], s[2], SliceDim::default()],
                len: 3,
            }
        } else {
            Slice {
                index: [s[0], s[1], s[2], s[3]],
                len: 4,
            }
        }
    }
}

#[cfg(test)]
mod slice_dyn_slice {
    use crate::{dim::DimDyn, index::SliceTrait, slice_dynamic};

    #[test]
    fn dyn_slice() {
        let shape = DimDyn::new(&[2, 3, 4]);
        let stride = DimDyn::new(&[12, 4, 1]);
        let slice = slice_dynamic!(.., 1, 1..2);
        let shape_stride = slice.sliced_shape_stride(shape, stride);
        let result_shape = dbg!(shape_stride.shape());
        let result_stride = dbg!(shape_stride.stride());
        assert_eq!(result_shape, DimDyn::new(&[2, 1]));
        assert_eq!(result_stride, DimDyn::new(&[12, 1]));
    }
}