zenu_matrix/
shape_stride.rs

1use std::fmt::Debug;
2
3use crate::dim::{default_stride, into_dyn, DimDyn, DimTrait};
4
5#[derive(Clone, Debug, Copy, PartialEq)]
6pub struct ShapeStride<D: DimTrait> {
7    shape: D,
8    stride: D,
9}
10
11impl<D: DimTrait> ShapeStride<D> {
12    pub fn new(shape: D, stride: D) -> Self {
13        Self { shape, stride }
14    }
15
16    pub fn shape(&self) -> D {
17        self.shape
18    }
19
20    pub fn stride(&self) -> D {
21        self.stride
22    }
23
24    #[must_use]
25    pub fn sort_by_stride(&self) -> Self {
26        let mut indeies = (0..self.stride.len()).collect::<Vec<_>>();
27        indeies.sort_by(|&a, &b| self.stride[b].cmp(&self.stride[a]));
28
29        let shape = indeies.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();
30        let stride = indeies.iter().map(|&i| self.stride[i]).collect::<Vec<_>>();
31
32        let mut new_shape = self.shape();
33        let mut new_stride = self.stride();
34
35        for i in 0..self.stride.len() {
36            new_shape[i] = shape[i];
37            new_stride[i] = stride[i];
38        }
39
40        Self::new(new_shape, new_stride)
41    }
42
43    #[expect(clippy::missing_panics_doc)]
44    pub fn min_stride(&self) -> usize {
45        let slice = self.stride.slice();
46        *slice.iter().min().unwrap()
47    }
48
49    /// この`ShapeStride`が連続しているかどうかを判定する
50    /// transposeされていた場合は並び替えを行い、
51    /// そのストライドが、`default_stride`のn倍になっているかどうかを判定する
52    pub fn is_contiguous(&self) -> bool {
53        let sorted = self.sort_by_stride();
54
55        let default_stride = default_stride(sorted.shape());
56
57        let n = default_stride[0] / sorted.stride[0];
58
59        let is_zero = default_stride[0] % sorted.stride[0] == 0;
60        if !is_zero {
61            return false;
62        }
63
64        let mut default_stride = default_stride;
65        for i in 0..default_stride.len() {
66            default_stride[i] *= n;
67        }
68
69        default_stride == sorted.stride
70    }
71
72    /// 転置は最後の次元と最後から2番目の次元を入れ替えることで表現される
73    pub fn is_transposed(&self) -> bool {
74        let last = self.stride()[self.stride().len() - 1];
75        let last_2 = self.stride()[self.stride().len() - 2];
76
77        last > last_2
78    }
79
80    #[must_use]
81    pub fn transpose(&self) -> Self {
82        let mut shape = self.shape();
83        let mut stride = self.stride();
84
85        let num_dim = shape.len();
86
87        // 入れ替える
88        let last = shape[shape.len() - 1];
89        let last_2 = shape[shape.len() - 2];
90
91        shape[num_dim - 1] = last_2;
92        shape[num_dim - 2] = last;
93
94        let last = stride[stride.len() - 1];
95        let last_2 = stride[stride.len() - 2];
96
97        stride[num_dim - 1] = last_2;
98        stride[num_dim - 2] = last;
99
100        Self::new(shape, stride)
101    }
102
103    pub fn is_default_stride(&self) -> bool {
104        if self.shape().len() == 1 {
105            return true;
106        }
107        default_stride(self.shape()) == self.stride()
108    }
109
110    /// `shpae` `stride`が転置されている場合、
111    /// 転置を元に戻した場合`default_stride`になっているかどうかを判定する
112    pub fn is_transposed_default_stride(&self) -> bool {
113        self.transpose().is_default_stride()
114    }
115
116    pub(crate) fn into_dyn(self) -> ShapeStride<DimDyn> {
117        let shape = into_dyn(self.shape);
118        let stride = into_dyn(self.stride);
119        ShapeStride::new(shape, stride)
120    }
121
122    pub(crate) fn transpose_by_index(&self, index: &[usize]) -> Self {
123        let mut shape = self.shape();
124        let mut stride = self.stride();
125
126        let num_dim = shape.len();
127
128        for i in 0..num_dim {
129            shape[i] = self.shape()[index[i]];
130            stride[i] = self.stride()[index[i]];
131        }
132
133        Self::new(shape, stride)
134    }
135
136    pub(crate) fn swap_index(self, a: usize, b: usize) -> Self {
137        if a == b {
138            return self;
139        }
140        assert!(
141            (a < self.shape().len()) && (b < self.shape().len()),
142            "Index out of bounds"
143        );
144        let mut shape = self.shape();
145        let mut stride = self.stride();
146
147        let tmp_shape = shape[a];
148        let tmp_stride = stride[a];
149
150        shape[a] = shape[b];
151        stride[a] = stride[b];
152
153        shape[b] = tmp_shape;
154        stride[b] = tmp_stride;
155
156        Self::new(shape, stride)
157    }
158}
159
160impl ShapeStride<DimDyn> {
161    #[must_use]
162    pub fn get_dim_by_offset(&self, offset: usize) -> DimDyn {
163        let mut offset = offset;
164        let mut dim = DimDyn::default();
165        for i in 0..self.shape.len() {
166            dim.push_dim(offset / self.stride[i]);
167            offset %= self.stride[i];
168        }
169        dim
170    }
171
172    #[must_use]
173    pub fn add_axis(self, axis: usize) -> Self {
174        if self.shape().is_empty() {
175            return ShapeStride::new(DimDyn::from([1]), DimDyn::from([1]));
176        }
177        let mut shape = DimDyn::default();
178        let mut stride = DimDyn::default();
179
180        for i in 0..self.shape.len() {
181            if i == axis {
182                shape.push_dim(1);
183                stride.push_dim(self.stride[i]);
184            }
185            shape.push_dim(self.shape[i]);
186            stride.push_dim(self.stride[i]);
187        }
188        if axis == self.shape.len() {
189            shape.push_dim(1);
190            stride.push_dim(1);
191        }
192        ShapeStride::new(shape, stride)
193    }
194}
195
#[cfg(test)]
mod shape_stride_test {
    use super::*;
    use crate::dim::{default_stride, Dim2, Dim4};

    /// A 2-D layout built with its default strides must not be reported as
    /// transposed.
    #[test]
    fn is_transposed_false() {
        let shape: Dim2 = [2, 3].into();
        let layout = ShapeStride::new(shape, default_stride(shape));

        assert!(!layout.is_transposed());
    }

    /// Shape `[2, 3, 5, 4]` with strides `[60, 20, 1, 5]` is a `[2, 3, 4, 5]`
    /// layout with strides `[60, 20, 5, 1]` whose last two axes were swapped,
    /// so it must be reported as transposed.
    #[test]
    fn is_transposed_true() {
        let shape: Dim4 = [2, 3, 5, 4].into();
        let stride: Dim4 = [60, 20, 1, 5].into();

        assert!(ShapeStride::new(shape, stride).is_transposed());
    }
}