1use std::{fmt, ops::Deref};
2
3mod removed_axis;
4pub(crate) use removed_axis::RemovedAxis;
5
6mod strides;
7pub use strides::Strides;
8
/// Newtype wrapper around an axis (dimension) number.
///
/// Dereferences to the wrapped `usize`, so it can be used wherever the raw number is needed.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Axis(pub usize);
12
13impl Deref for Axis {
14 type Target = usize;
15
16 fn deref(&self) -> &Self::Target {
17 &self.0
18 }
19}
20
/// Sizes of each dimension of an n-dimensional array, outermost axis first.
///
/// Flat offsets are interpreted in row-major order (the last axis varies fastest).
#[derive(Clone, Debug)]
#[derive(PartialEq, Eq, Hash)]
pub struct Shape(pub Vec<usize>);
24
25impl Shape {
26 pub fn dimensions(&self) -> usize {
28 self.0.len()
29 }
30
31 pub fn elements(&self) -> usize {
33 self.iter().product()
34 }
35
36 pub(crate) fn index_from_flat_unchecked(&self, mut flat: usize) -> Vec<usize> {
37 let mut n = self.elements();
38 let mut index = vec![0; self.len()];
39 for (i, v) in self.iter().enumerate() {
40 n /= v;
41 index[i] = flat / n;
42 flat %= n;
43 }
44 index
45 }
46
47 pub(crate) fn index_sum_from_flat_unchecked(&self, mut flat: usize) -> usize {
48 let mut n = self.elements();
49 let mut sum = 0;
50 for v in self.iter() {
51 n /= v;
52 sum += flat / n;
53 flat %= n;
54 }
55 sum
56 }
57
58 pub(crate) fn remove_axis(&self, axis: Axis) -> RemovedAxis<Self> {
59 RemovedAxis::new(self, axis)
60 }
61
62 pub(crate) fn strides(&self) -> Strides {
63 let mut strides = vec![1; self.len()];
64
65 for (i, v) in self.iter().enumerate().skip(1).rev() {
66 strides.iter_mut().take(i).for_each(|stride| *stride *= v)
67 }
68
69 Strides(strides)
70 }
71}
72
73impl AsRef<[usize]> for Shape {
74 fn as_ref(&self) -> &[usize] {
75 self
76 }
77}
78
79impl Deref for Shape {
80 type Target = [usize];
81
82 fn deref(&self) -> &Self::Target {
83 &self.0
84 }
85}
86
87impl From<Vec<usize>> for Shape {
88 fn from(shape: Vec<usize>) -> Self {
89 Self(shape)
90 }
91}
92
93impl<const N: usize> From<[usize; N]> for Shape {
94 fn from(shape: [usize; N]) -> Self {
95 Self(shape.to_vec())
96 }
97}
98
99impl From<usize> for Shape {
100 fn from(shape: usize) -> Self {
101 Self(vec![shape])
102 }
103}
104
105impl fmt::Display for Shape {
106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107 write!(f, "{}", self[0])?;
108 for v in self.iter().skip(1) {
109 write!(f, "/{v}")?;
110 }
111 Ok(())
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_index_from_flat_unchecked() {
        let shape = Shape(vec![3, 3, 4]);

        assert_eq!(shape.index_from_flat_unchecked(0), vec![0, 0, 0]);
        assert_eq!(shape.index_from_flat_unchecked(1), vec![0, 0, 1]);
        assert_eq!(shape.index_from_flat_unchecked(3), vec![0, 0, 3]);
        assert_eq!(shape.index_from_flat_unchecked(4), vec![0, 1, 0]);
        assert_eq!(shape.index_from_flat_unchecked(35), vec![2, 2, 3]);
    }

    #[test]
    fn test_index_from_flat_unchecked_scalar() {
        // A zero-dimensional shape has one element, addressed by the empty index.
        assert_eq!(
            Shape(vec![]).index_from_flat_unchecked(0),
            Vec::<usize>::new()
        );
    }

    #[test]
    fn test_index_sum_from_flat_unchecked() {
        let shape = Shape(vec![3, 3, 4]);

        // Must agree with summing the full index for every valid offset.
        for flat in 0..shape.elements() {
            let expected: usize = shape.index_from_flat_unchecked(flat).iter().sum();
            assert_eq!(shape.index_sum_from_flat_unchecked(flat), expected);
        }
        assert_eq!(shape.index_sum_from_flat_unchecked(35), 7);
    }

    #[test]
    fn test_dimensions_and_elements() {
        let shape = Shape(vec![3, 3, 4]);

        assert_eq!(shape.dimensions(), 3);
        assert_eq!(shape.elements(), 36);
        assert_eq!(Shape(vec![]).dimensions(), 0);
        assert_eq!(Shape(vec![]).elements(), 1);
    }

    #[test]
    fn test_strides() {
        let shape = Shape(vec![6, 3, 7]);
        let strides = shape.strides();

        assert_eq!(strides, Strides(vec![21, 7, 1]));
    }

    #[test]
    fn test_strides_edge_cases() {
        assert_eq!(Shape(vec![]).strides(), Strides(vec![]));
        assert_eq!(Shape(vec![5]).strides(), Strides(vec![1]));
        // A zero-sized axis zeroes the strides of every axis before it.
        assert_eq!(Shape(vec![2, 0, 3]).strides(), Strides(vec![0, 3, 1]));
    }

    #[test]
    fn test_display() {
        assert_eq!(Shape(vec![3, 3, 4]).to_string(), "3/3/4");
        assert_eq!(Shape(vec![7]).to_string(), "7");
    }

    #[test]
    fn test_from_conversions() {
        assert_eq!(Shape::from(vec![2usize, 3]), Shape(vec![2, 3]));
        assert_eq!(Shape::from([2usize, 3]), Shape(vec![2, 3]));
        assert_eq!(Shape::from(4usize), Shape(vec![4]));
    }
}