1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
use std::fmt::Debug;

use crate::dim::{default_stride, into_dyn, DimDyn, DimTrait};

/// A `shape` paired with a `stride`, describing how a multi-dimensional view is laid out.
///
/// `shape[i]` is the extent of dimension `i` and `stride[i]` is the linear-offset step taken
/// when the index along dimension `i` increases by one.
#[derive(Clone, Debug, Copy, PartialEq)]
pub struct ShapeStride<D: DimTrait> {
    shape: D,
    stride: D,
}

impl<D: DimTrait> ShapeStride<D> {
    /// Creates a `ShapeStride` from its parts. No validation is performed.
    pub fn new(shape: D, stride: D) -> Self {
        Self { shape, stride }
    }

    /// Returns a copy of the shape.
    pub fn shape(&self) -> D {
        self.shape
    }

    /// Returns a copy of the stride.
    pub fn stride(&self) -> D {
        self.stride
    }

    /// Returns a copy with the dimensions reordered so that strides are in descending order.
    ///
    /// Each shape entry moves together with its stride. The sort is stable, so dimensions
    /// with equal strides keep their original relative order.
    pub fn sort_by_stride(&self) -> Self {
        let mut order: Vec<usize> = (0..self.stride.len()).collect();
        order.sort_by(|&a, &b| self.stride[b].cmp(&self.stride[a]));

        // Start from copies so the result keeps the dimension type `D`, then overwrite each
        // position with the permuted entry.
        let mut shape = self.shape;
        let mut stride = self.stride;
        for (dst, &src) in order.iter().enumerate() {
            shape[dst] = self.shape[src];
            stride[dst] = self.stride[src];
        }

        Self::new(shape, stride)
    }

    /// Returns the smallest stride.
    ///
    /// # Panics
    /// Panics if the layout has no dimensions.
    pub fn min_stride(&self) -> usize {
        let slice = self.stride.slice();
        *slice
            .iter()
            .min()
            .expect("min_stride requires at least one dimension")
    }

    /// Returns `true` if this layout is contiguous.
    ///
    /// Dimensions are first sorted by descending stride, so transposed layouts are handled.
    /// The layout then counts as contiguous when every sorted stride equals `n` times the
    /// default stride of the sorted shape, for one positive integer `n`.
    /// `n == 1` is the plain dense layout; `n > 1` is a layout that steps through memory
    /// with a fixed increment.
    ///
    /// Degenerate layouts fall back to an exact comparison with the default stride instead
    /// of panicking. These are: no dimensions, a zero stride (broadcast), and a zero default
    /// stride (zero-size trailing extent).
    pub fn is_contiguous(&self) -> bool {
        let sorted = self.sort_by_stride();
        let default_stride = default_stride(sorted.shape());

        if sorted.stride.len() == 0 {
            return default_stride == sorted.stride;
        }

        // The largest stride fixes the candidate multiplier: n = stride / default_stride.
        let largest = sorted.stride[0];
        let default_largest = default_stride[0];
        if default_largest == 0 || largest == 0 || largest % default_largest != 0 {
            return default_stride == sorted.stride;
        }
        let n = largest / default_largest;

        let mut scaled = default_stride;
        for i in 0..scaled.len() {
            scaled[i] *= n;
        }

        scaled == sorted.stride
    }

    /// Returns `true` if the last stride is larger than the second-to-last one.
    ///
    /// A transpose is represented by swapping the last two dimensions, so this detects it.
    /// Layouts with fewer than two dimensions are never considered transposed.
    pub fn is_transposed(&self) -> bool {
        let num_dim = self.stride.len();
        if num_dim < 2 {
            return false;
        }

        self.stride[num_dim - 1] > self.stride[num_dim - 2]
    }

    /// Swaps the last two dimensions (both shape and stride), which is how a transpose is
    /// represented.
    ///
    /// Layouts with fewer than two dimensions have nothing to swap and are returned as-is.
    pub fn transpose(&self) -> Self {
        let mut shape = self.shape;
        let mut stride = self.stride;

        let num_dim = shape.len();
        if num_dim < 2 {
            return Self::new(shape, stride);
        }
        let (second_last, last) = (num_dim - 2, num_dim - 1);

        let tmp = shape[last];
        shape[last] = shape[second_last];
        shape[second_last] = tmp;

        let tmp = stride[last];
        stride[last] = stride[second_last];
        stride[second_last] = tmp;

        Self::new(shape, stride)
    }

    /// Returns `true` if the stride equals the default stride of the shape.
    pub fn is_default_stride(&self) -> bool {
        default_stride(self.shape()) == self.stride()
    }

    /// Returns `true` if swapping the last two dimensions yields the default stride.
    ///
    /// This is how a transposed view of a default-stride layout is recognised.
    pub fn is_transposed_default_stride(&self) -> bool {
        self.transpose().is_default_stride()
    }

    /// Converts both shape and stride to `DimDyn` with `crate::dim::into_dyn`.
    pub(crate) fn into_dyn(self) -> ShapeStride<DimDyn> {
        let shape = into_dyn(self.shape);
        let stride = into_dyn(self.stride);
        ShapeStride::new(shape, stride)
    }
}

impl ShapeStride<DimDyn> {
    /// Converts a linear `offset` into a per-dimension index.
    ///
    /// Dimensions are visited from index 0 upwards. For each one, the index is
    /// `remaining / stride[i]`, and the remainder carries over to the next dimension. This
    /// is only an exact inverse of offset computation when strides are in descending order,
    /// as with a default (row-major) stride.
    /// NOTE(review): confirm callers only pass such layouts.
    ///
    /// A dimension with stride 0 (e.g. a broadcast dimension) does not move through memory.
    /// Its index is reported as 0 and the remaining offset is left untouched, rather than
    /// dividing by zero.
    pub fn get_dim_by_offset(&self, offset: usize) -> DimDyn {
        let mut remaining = offset;
        let mut dim = DimDyn::default();
        for i in 0..self.shape.len() {
            let step = self.stride[i];
            if step == 0 {
                dim.push_dim(0);
                continue;
            }
            dim.push_dim(remaining / step);
            remaining %= step;
        }
        dim
    }
}

#[cfg(test)]
mod shape_stride {
    use super::*;
    use crate::dim::{default_stride, Dim2, Dim4};

    /// Shape `[2, 3, 5, 4]` with strides `[60, 20, 1, 5]`: the default layout of
    /// `[2, 3, 4, 5]` with its last two dimensions swapped.
    fn transposed_4d() -> ShapeStride<Dim4> {
        let shape: Dim4 = [2, 3, 5, 4].into();
        let stride: Dim4 = [60, 20, 1, 5].into();
        ShapeStride::new(shape, stride)
    }

    /// A default-stride layout is not transposed.
    #[test]
    fn is_transposed_false() {
        let shape: Dim2 = [2, 3].into();
        let stride = default_stride(shape);

        let shape_stride = ShapeStride::new(shape, stride);

        assert!(!shape_stride.is_transposed());
    }

    #[test]
    fn is_transposed_true() {
        assert!(transposed_4d().is_transposed());
    }

    /// Swapping the last two dimensions back yields the default layout.
    #[test]
    fn transpose_restores_default_stride() {
        let shape_stride = transposed_4d();

        assert!(shape_stride.is_transposed_default_stride());
        assert!(!shape_stride.transpose().is_transposed());
    }

    #[test]
    fn default_and_transposed_layouts_are_contiguous() {
        let shape: Dim2 = [2, 3].into();

        assert!(ShapeStride::new(shape, default_stride(shape)).is_contiguous());
        assert!(transposed_4d().is_contiguous());
    }
}