// sfs_core/array/shape.rs

use std::{fmt, ops::Deref};

mod removed_axis;
pub(crate) use removed_axis::RemovedAxis;

mod strides;
pub use strides::Strides;

/// Newtype wrapping the index of a single axis of an N-dimensional array.
///
/// The raw index is available through the public field or by dereferencing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Axis(pub usize);

13impl Deref for Axis {
14    type Target = usize;
15
16    fn deref(&self) -> &Self::Target {
17        &self.0
18    }
19}
20
/// The per-axis lengths of an N-dimensional array, one entry per dimension.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Shape(pub Vec<usize>);

25impl Shape {
26    /// The number of dimensions of an array with the corresponding shape.
27    pub fn dimensions(&self) -> usize {
28        self.0.len()
29    }
30
31    /// The number of elements of an array with the corresponding shape.
32    pub fn elements(&self) -> usize {
33        self.iter().product()
34    }
35
36    pub(crate) fn index_from_flat_unchecked(&self, mut flat: usize) -> Vec<usize> {
37        let mut n = self.elements();
38        let mut index = vec![0; self.len()];
39        for (i, v) in self.iter().enumerate() {
40            n /= v;
41            index[i] = flat / n;
42            flat %= n;
43        }
44        index
45    }
46
47    pub(crate) fn index_sum_from_flat_unchecked(&self, mut flat: usize) -> usize {
48        let mut n = self.elements();
49        let mut sum = 0;
50        for v in self.iter() {
51            n /= v;
52            sum += flat / n;
53            flat %= n;
54        }
55        sum
56    }
57
58    pub(crate) fn remove_axis(&self, axis: Axis) -> RemovedAxis<Self> {
59        RemovedAxis::new(self, axis)
60    }
61
62    pub(crate) fn strides(&self) -> Strides {
63        let mut strides = vec![1; self.len()];
64
65        for (i, v) in self.iter().enumerate().skip(1).rev() {
66            strides.iter_mut().take(i).for_each(|stride| *stride *= v)
67        }
68
69        Strides(strides)
70    }
71}
72
73impl AsRef<[usize]> for Shape {
74    fn as_ref(&self) -> &[usize] {
75        self
76    }
77}
78
79impl Deref for Shape {
80    type Target = [usize];
81
82    fn deref(&self) -> &Self::Target {
83        &self.0
84    }
85}
86
87impl From<Vec<usize>> for Shape {
88    fn from(shape: Vec<usize>) -> Self {
89        Self(shape)
90    }
91}
92
93impl<const N: usize> From<[usize; N]> for Shape {
94    fn from(shape: [usize; N]) -> Self {
95        Self(shape.to_vec())
96    }
97}
98
99impl From<usize> for Shape {
100    fn from(shape: usize) -> Self {
101        Self(vec![shape])
102    }
103}
104
105impl fmt::Display for Shape {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        write!(f, "{}", self[0])?;
108        for v in self.iter().skip(1) {
109            write!(f, "/{v}")?;
110        }
111        Ok(())
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_elements_and_dimensions() {
        let shape = Shape(vec![3, 3, 4]);

        assert_eq!(shape.elements(), 36);
        assert_eq!(shape.dimensions(), 3);
    }

    #[test]
    fn test_index_from_flat_unchecked() {
        let shape = Shape(vec![3, 3, 4]);

        assert_eq!(shape.index_from_flat_unchecked(0), vec![0, 0, 0]);
        assert_eq!(shape.index_from_flat_unchecked(1), vec![0, 0, 1]);
        assert_eq!(shape.index_from_flat_unchecked(3), vec![0, 0, 3]);
        assert_eq!(shape.index_from_flat_unchecked(4), vec![0, 1, 0]);
        assert_eq!(shape.index_from_flat_unchecked(35), vec![2, 2, 3]);
    }

    #[test]
    fn test_index_sum_from_flat_unchecked() {
        let shape = Shape(vec![3, 3, 4]);

        assert_eq!(shape.index_sum_from_flat_unchecked(0), 0);
        assert_eq!(shape.index_sum_from_flat_unchecked(5), 2);
        assert_eq!(shape.index_sum_from_flat_unchecked(35), 7);

        // Must agree with summing the full index for every valid flat position.
        for flat in 0..shape.elements() {
            let expected: usize = shape.index_from_flat_unchecked(flat).iter().sum();
            assert_eq!(shape.index_sum_from_flat_unchecked(flat), expected);
        }
    }

    #[test]
    fn test_strides() {
        let shape = Shape(vec![6, 3, 7]);
        let strides = shape.strides();

        assert_eq!(strides, Strides(vec![21, 7, 1]));
    }

    #[test]
    fn test_strides_edge_cases() {
        assert_eq!(Shape(vec![5]).strides(), Strides(vec![1]));
        assert_eq!(Shape(vec![]).strides(), Strides(vec![]));
    }

    #[test]
    fn test_strides_invert_flat_index() {
        let shape = Shape(vec![2, 3, 4]);
        let Strides(strides) = shape.strides();

        // Dotting an index with the strides must recover the flat position it came from.
        for flat in 0..shape.elements() {
            let index = shape.index_from_flat_unchecked(flat);
            let back: usize = index.iter().zip(&strides).map(|(i, s)| i * s).sum();
            assert_eq!(back, flat);
        }
    }

    #[test]
    fn test_display() {
        assert_eq!(Shape(vec![3, 3, 4]).to_string(), "3/3/4");
        assert_eq!(Shape::from(7).to_string(), "7");
    }
}