// sapient_core/shape.rs

//! `Shape` — tensor shape utilities and broadcasting rules.

use crate::error::{Result, SapientError};
use serde::{Deserialize, Serialize};

6/// Newtype around a dimension vector that carries shape utilities.
7#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
8pub struct Shape(pub Vec<usize>);
9
10impl Shape {
11    /// Construct from any iterator of `usize`.
12    pub fn new(dims: impl IntoIterator<Item = usize>) -> Self {
13        Self(dims.into_iter().collect())
14    }
15
16    /// Number of dimensions (rank).
17    #[inline]
18    pub fn ndim(&self) -> usize {
19        self.0.len()
20    }
21
22    /// Total number of elements (product of all dims).
23    #[inline]
24    pub fn numel(&self) -> usize {
25        self.0.iter().product()
26    }
27
28    /// Dimension slice.
29    #[inline]
30    pub fn dims(&self) -> &[usize] {
31        &self.0
32    }
33
34    /// Row-major (C-contiguous) strides.
35    pub fn strides(&self) -> Vec<usize> {
36        let n = self.ndim();
37        if n == 0 {
38            return vec![];
39        }
40        let mut strides = vec![1usize; n];
41        for i in (0..n - 1).rev() {
42            strides[i] = strides[i + 1] * self.0[i + 1];
43        }
44        strides
45    }
46
47    /// Scalar (0-dimensional) shape.
48    pub fn scalar() -> Self {
49        Self(vec![])
50    }
51
52    /// Whether this is a scalar.
53    pub fn is_scalar(&self) -> bool {
54        self.0.is_empty()
55    }
56
57    /// Reshape — ensures the total numel is unchanged.
58    pub fn reshape(&self, new_dims: impl IntoIterator<Item = usize>) -> Result<Shape> {
59        let new_shape = Shape::new(new_dims);
60        if new_shape.numel() != self.numel() {
61            return Err(SapientError::ShapeMismatch {
62                expected: self.0.clone(),
63                got: new_shape.0.clone(),
64            });
65        }
66        Ok(new_shape)
67    }
68
69    /// Compute the broadcast output shape of `self` and `other` (NumPy rules).
70    pub fn broadcast_with(&self, other: &Shape) -> Result<Shape> {
71        let (a, b) = (&self.0, &other.0);
72        let len = a.len().max(b.len());
73        let mut out = vec![0usize; len];
74        for i in 0..len {
75            let ai = if i < len - a.len() {
76                1
77            } else {
78                a[i - (len - a.len())]
79            };
80            let bi = if i < len - b.len() {
81                1
82            } else {
83                b[i - (len - b.len())]
84            };
85            if ai == bi {
86                out[i] = ai;
87            } else if ai == 1 {
88                out[i] = bi;
89            } else if bi == 1 {
90                out[i] = ai;
91            } else {
92                return Err(SapientError::BroadcastError {
93                    lhs: self.0.clone(),
94                    rhs: other.0.clone(),
95                });
96            }
97        }
98        Ok(Shape(out))
99    }
100
101    /// Insert a new axis of size 1 at `axis` (like `np.expand_dims`).
102    pub fn expand_dims(&self, axis: usize) -> Result<Shape> {
103        if axis > self.ndim() {
104            return Err(SapientError::internal(format!(
105                "expand_dims: axis {axis} out of range for rank {}",
106                self.ndim()
107            )));
108        }
109        let mut dims = self.0.clone();
110        dims.insert(axis, 1);
111        Ok(Shape(dims))
112    }
113
114    /// Remove all dimensions of size 1 (like `np.squeeze`).
115    pub fn squeeze(&self) -> Shape {
116        Shape(self.0.iter().copied().filter(|&d| d != 1).collect())
117    }
118
119    /// Validate that every dim is > 0.
120    pub fn validate(&self) -> Result<()> {
121        for (i, &d) in self.0.iter().enumerate() {
122            if d == 0 {
123                return Err(SapientError::InvalidGraph(format!(
124                    "Shape has zero dimension at axis {i}"
125                )));
126            }
127        }
128        Ok(())
129    }
130
131    /// Contiguous byte offset for a multi-index into row-major storage.
132    pub fn flat_index(&self, idx: &[usize]) -> Result<usize> {
133        if idx.len() != self.ndim() {
134            return Err(SapientError::RankMismatch {
135                expected: self.ndim(),
136                got: idx.len(),
137            });
138        }
139        let strides = self.strides();
140        let mut offset = 0;
141        for (i, (&ix, &st)) in idx.iter().zip(strides.iter()).enumerate() {
142            if ix >= self.0[i] {
143                return Err(SapientError::internal(format!(
144                    "Index {ix} out of bounds for dim {i} (size {})",
145                    self.0[i]
146                )));
147            }
148            offset += ix * st;
149        }
150        Ok(offset)
151    }
152}
153
154impl std::fmt::Display for Shape {
155    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156        write!(f, "[")?;
157        for (i, d) in self.0.iter().enumerate() {
158            if i > 0 {
159                write!(f, ", ")?;
160            }
161            write!(f, "{d}")?;
162        }
163        write!(f, "]")
164    }
165}
166
167impl From<Vec<usize>> for Shape {
168    fn from(v: Vec<usize>) -> Self {
169        Self(v)
170    }
171}
172
173impl From<&[usize]> for Shape {
174    fn from(s: &[usize]) -> Self {
175        Self(s.to_vec())
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn numel() {
        assert_eq!(Shape::new([2, 3, 4]).numel(), 24);
        assert_eq!(Shape::scalar().numel(), 1);
    }

    #[test]
    fn strides_row_major() {
        let s = Shape::new([2, 3, 4]);
        assert_eq!(s.strides(), vec![12, 4, 1]);
    }

    #[test]
    fn strides_low_rank() {
        assert!(Shape::scalar().strides().is_empty());
        assert_eq!(Shape::new([5]).strides(), vec![1]);
    }

    #[test]
    fn broadcast() {
        let a = Shape::new([1, 3]);
        let b = Shape::new([2, 3]);
        assert_eq!(a.broadcast_with(&b).unwrap(), Shape::new([2, 3]));
    }

    #[test]
    fn broadcast_rank_mismatch() {
        // The shorter shape is aligned on trailing axes and padded with 1s.
        let a = Shape::new([3]);
        let b = Shape::new([2, 1]);
        assert_eq!(a.broadcast_with(&b).unwrap(), Shape::new([2, 3]));
        assert_eq!(b.broadcast_with(&a).unwrap(), Shape::new([2, 3]));
        assert_eq!(Shape::scalar().broadcast_with(&b).unwrap(), b);
    }

    #[test]
    fn broadcast_fail() {
        let a = Shape::new([2, 3]);
        let b = Shape::new([2, 4]);
        assert!(a.broadcast_with(&b).is_err());
    }

    #[test]
    fn reshape() {
        let s = Shape::new([2, 3]);
        let r = s.reshape([6]).unwrap();
        assert_eq!(r, Shape::new([6]));
    }

    #[test]
    fn reshape_fail() {
        assert!(Shape::new([2, 3]).reshape([4]).is_err());
    }

    #[test]
    fn expand_dims() {
        let s = Shape::new([2, 3]);
        assert_eq!(s.expand_dims(0).unwrap(), Shape::new([1, 2, 3]));
        assert_eq!(s.expand_dims(2).unwrap(), Shape::new([2, 3, 1]));
        assert!(s.expand_dims(3).is_err());
    }

    #[test]
    fn squeeze() {
        assert_eq!(Shape::new([1, 2, 1, 3]).squeeze(), Shape::new([2, 3]));
        assert!(Shape::new([1, 1]).squeeze().is_scalar());
    }

    #[test]
    fn validate() {
        assert!(Shape::new([2, 3]).validate().is_ok());
        assert!(Shape::new([2, 0]).validate().is_err());
    }

    #[test]
    fn flat_index() {
        let s = Shape::new([2, 3, 4]);
        assert_eq!(s.flat_index(&[1, 2, 3]).unwrap(), 12 + 8 + 3);
    }

    #[test]
    fn flat_index_errors() {
        let s = Shape::new([2, 3, 4]);
        assert!(matches!(
            s.flat_index(&[1, 2]),
            Err(SapientError::RankMismatch { expected: 3, got: 2 })
        ));
        assert!(s.flat_index(&[0, 3, 0]).is_err());
    }

    #[test]
    fn display() {
        assert_eq!(Shape::new([2, 3]).to_string(), "[2, 3]");
        assert_eq!(Shape::scalar().to_string(), "[]");
    }
}