zenu_matrix/
shape_stride.rs1use std::fmt::Debug;
2
3use crate::dim::{default_stride, into_dyn, DimDyn, DimTrait};
4
5#[derive(Clone, Debug, Copy, PartialEq)]
6pub struct ShapeStride<D: DimTrait> {
7 shape: D,
8 stride: D,
9}
10
11impl<D: DimTrait> ShapeStride<D> {
12 pub fn new(shape: D, stride: D) -> Self {
13 Self { shape, stride }
14 }
15
16 pub fn shape(&self) -> D {
17 self.shape
18 }
19
20 pub fn stride(&self) -> D {
21 self.stride
22 }
23
24 #[must_use]
25 pub fn sort_by_stride(&self) -> Self {
26 let mut indeies = (0..self.stride.len()).collect::<Vec<_>>();
27 indeies.sort_by(|&a, &b| self.stride[b].cmp(&self.stride[a]));
28
29 let shape = indeies.iter().map(|&i| self.shape[i]).collect::<Vec<_>>();
30 let stride = indeies.iter().map(|&i| self.stride[i]).collect::<Vec<_>>();
31
32 let mut new_shape = self.shape();
33 let mut new_stride = self.stride();
34
35 for i in 0..self.stride.len() {
36 new_shape[i] = shape[i];
37 new_stride[i] = stride[i];
38 }
39
40 Self::new(new_shape, new_stride)
41 }
42
43 #[expect(clippy::missing_panics_doc)]
44 pub fn min_stride(&self) -> usize {
45 let slice = self.stride.slice();
46 *slice.iter().min().unwrap()
47 }
48
49 pub fn is_contiguous(&self) -> bool {
53 let sorted = self.sort_by_stride();
54
55 let default_stride = default_stride(sorted.shape());
56
57 let n = default_stride[0] / sorted.stride[0];
58
59 let is_zero = default_stride[0] % sorted.stride[0] == 0;
60 if !is_zero {
61 return false;
62 }
63
64 let mut default_stride = default_stride;
65 for i in 0..default_stride.len() {
66 default_stride[i] *= n;
67 }
68
69 default_stride == sorted.stride
70 }
71
72 pub fn is_transposed(&self) -> bool {
74 let last = self.stride()[self.stride().len() - 1];
75 let last_2 = self.stride()[self.stride().len() - 2];
76
77 last > last_2
78 }
79
80 #[must_use]
81 pub fn transpose(&self) -> Self {
82 let mut shape = self.shape();
83 let mut stride = self.stride();
84
85 let num_dim = shape.len();
86
87 let last = shape[shape.len() - 1];
89 let last_2 = shape[shape.len() - 2];
90
91 shape[num_dim - 1] = last_2;
92 shape[num_dim - 2] = last;
93
94 let last = stride[stride.len() - 1];
95 let last_2 = stride[stride.len() - 2];
96
97 stride[num_dim - 1] = last_2;
98 stride[num_dim - 2] = last;
99
100 Self::new(shape, stride)
101 }
102
103 pub fn is_default_stride(&self) -> bool {
104 if self.shape().len() == 1 {
105 return true;
106 }
107 default_stride(self.shape()) == self.stride()
108 }
109
110 pub fn is_transposed_default_stride(&self) -> bool {
113 self.transpose().is_default_stride()
114 }
115
116 pub(crate) fn into_dyn(self) -> ShapeStride<DimDyn> {
117 let shape = into_dyn(self.shape);
118 let stride = into_dyn(self.stride);
119 ShapeStride::new(shape, stride)
120 }
121
122 pub(crate) fn transpose_by_index(&self, index: &[usize]) -> Self {
123 let mut shape = self.shape();
124 let mut stride = self.stride();
125
126 let num_dim = shape.len();
127
128 for i in 0..num_dim {
129 shape[i] = self.shape()[index[i]];
130 stride[i] = self.stride()[index[i]];
131 }
132
133 Self::new(shape, stride)
134 }
135
136 pub(crate) fn swap_index(self, a: usize, b: usize) -> Self {
137 if a == b {
138 return self;
139 }
140 assert!(
141 (a < self.shape().len()) && (b < self.shape().len()),
142 "Index out of bounds"
143 );
144 let mut shape = self.shape();
145 let mut stride = self.stride();
146
147 let tmp_shape = shape[a];
148 let tmp_stride = stride[a];
149
150 shape[a] = shape[b];
151 stride[a] = stride[b];
152
153 shape[b] = tmp_shape;
154 stride[b] = tmp_stride;
155
156 Self::new(shape, stride)
157 }
158}
159
160impl ShapeStride<DimDyn> {
161 #[must_use]
162 pub fn get_dim_by_offset(&self, offset: usize) -> DimDyn {
163 let mut offset = offset;
164 let mut dim = DimDyn::default();
165 for i in 0..self.shape.len() {
166 dim.push_dim(offset / self.stride[i]);
167 offset %= self.stride[i];
168 }
169 dim
170 }
171
172 #[must_use]
173 pub fn add_axis(self, axis: usize) -> Self {
174 if self.shape().is_empty() {
175 return ShapeStride::new(DimDyn::from([1]), DimDyn::from([1]));
176 }
177 let mut shape = DimDyn::default();
178 let mut stride = DimDyn::default();
179
180 for i in 0..self.shape.len() {
181 if i == axis {
182 shape.push_dim(1);
183 stride.push_dim(self.stride[i]);
184 }
185 shape.push_dim(self.shape[i]);
186 stride.push_dim(self.stride[i]);
187 }
188 if axis == self.shape.len() {
189 shape.push_dim(1);
190 stride.push_dim(1);
191 }
192 ShapeStride::new(shape, stride)
193 }
194}
195
#[cfg(test)]
mod shape_stride_test {
    use super::*;
    use crate::dim::{default_stride, Dim2, Dim4};

    /// A matrix laid out with its default strides is not transposed.
    #[test]
    fn is_transposed_false() {
        let shape: Dim2 = [2, 3].into();
        let layout = ShapeStride::new(shape, default_stride(shape));

        assert!(!layout.is_transposed());
    }

    /// Exchanging the strides of the two innermost axes marks the layout as
    /// transposed.
    #[test]
    fn is_transposed_true() {
        let shape: Dim4 = [2, 3, 5, 4].into();
        let stride: Dim4 = [60, 20, 1, 5].into();
        let layout = ShapeStride::new(shape, stride);

        assert!(layout.is_transposed());
    }
}