// shrew_core/shape.rs
use std::fmt;

// Shape — N-dimensional shape representation
//
// A Shape describes the size of each dimension of a tensor.
// For example:
//   - Scalar: Shape([])          — 0 dimensions, 1 element
//   - Vector: Shape([5])         — 1 dimension, 5 elements
//   - Matrix: Shape([3, 4])      — 2 dimensions, 12 elements
//   - Batch:  Shape([2, 3, 4])   — 3 dimensions, 24 elements
//
// The shape is fundamental because it determines:
//   1. How many elements are in the tensor (product of all dims)
//   2. The default (contiguous/row-major) strides for memory layout
//   3. Whether two tensors are compatible for operations (broadcasting rules)
/// N-dimensional shape of a tensor.
///
/// Wraps the per-dimension sizes as a `Vec<usize>`; an empty vector is a
/// scalar shape. Derives `Eq`/`Hash` so shapes can be used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape(Vec<usize>);
20
21impl Shape {
22    /// Create a new shape from a vector of dimension sizes.
23    pub fn new(dims: Vec<usize>) -> Self {
24        Shape(dims)
25    }
26
27    /// The dimension sizes as a slice.
28    pub fn dims(&self) -> &[usize] {
29        &self.0
30    }
31
32    /// Number of dimensions (0 for scalar, 1 for vector, 2 for matrix, etc.).
33    pub fn rank(&self) -> usize {
34        self.0.len()
35    }
36
37    /// Total number of elements (product of all dimensions).
38    /// A scalar shape [] has 1 element.
39    pub fn elem_count(&self) -> usize {
40        self.0.iter().product::<usize>().max(1)
41    }
42
43    /// Compute the contiguous (row-major / C-order) strides for this shape.
44    ///
45    /// For shape [2, 3, 4], strides are [12, 4, 1]:
46    ///   - Moving 1 step in dim 0 jumps 12 elements (3*4)
47    ///   - Moving 1 step in dim 1 jumps 4 elements
48    ///   - Moving 1 step in dim 2 jumps 1 element
49    ///
50    /// This is how row-major memory works: the last dimension is contiguous.
51    pub fn stride_contiguous(&self) -> Vec<usize> {
52        let mut strides = vec![0usize; self.rank()];
53        if self.rank() > 0 {
54            strides[self.rank() - 1] = 1;
55            for i in (0..self.rank() - 1).rev() {
56                strides[i] = strides[i + 1] * self.0[i + 1];
57            }
58        }
59        strides
60    }
61
62    /// Size of a specific dimension.
63    pub fn dim(&self, d: usize) -> crate::Result<usize> {
64        self.0.get(d).copied().ok_or(crate::Error::DimOutOfRange {
65            dim: d,
66            rank: self.rank(),
67        })
68    }
69
70    // Broadcasting
71
72    /// Compute the broadcast output shape from two input shapes.
73    ///
74    /// NumPy-style broadcasting rules:
75    ///   1. Align shapes from the right (trailing dimensions).
76    ///   2. Dimensions are compatible if they are equal or one of them is 1.
77    ///   3. Missing leading dimensions are treated as 1.
78    ///
79    /// Examples:
80    ///   [3, 4] and [4]     → [3, 4]   (expand [4] to [1, 4] then broadcast dim 0)
81    ///   [2, 1] and [1, 3]  → [2, 3]
82    ///   [5, 3, 1] and [3, 4] → [5, 3, 4]
83    ///   [3] and [4]        → Error (3 ≠ 4 and neither is 1)
84    pub fn broadcast_shape(lhs: &Shape, rhs: &Shape) -> crate::Result<Shape> {
85        let l = lhs.dims();
86        let r = rhs.dims();
87        let max_rank = l.len().max(r.len());
88        let mut result = Vec::with_capacity(max_rank);
89
90        for i in 0..max_rank {
91            // Index from the right: l.len()-1-i walks backwards. If i >= len, treat as 1.
92            let ld = if i < l.len() { l[l.len() - 1 - i] } else { 1 };
93            let rd = if i < r.len() { r[r.len() - 1 - i] } else { 1 };
94
95            if ld == rd {
96                result.push(ld);
97            } else if ld == 1 {
98                result.push(rd);
99            } else if rd == 1 {
100                result.push(ld);
101            } else {
102                return Err(crate::Error::msg(format!(
103                    "shapes {:?} and {:?} are not broadcast-compatible (dim {} from right: {} vs {})",
104                    l, r, i, ld, rd
105                )));
106            }
107        }
108
109        result.reverse(); // We built it from the right
110        Ok(Shape::new(result))
111    }
112
113    /// Return the broadcast strides for this shape to match a target broadcast shape.
114    ///
115    /// For each dimension where self.dim[i] == 1 and target.dim[i] > 1,
116    /// the stride is set to 0 (repeating the single element).
117    /// For missing leading dimensions (self has fewer dims), stride is also 0.
118    pub fn broadcast_strides(&self, target: &Shape) -> Vec<usize> {
119        let self_dims = self.dims();
120        let target_dims = target.dims();
121        let self_strides = self.stride_contiguous();
122
123        let mut result = vec![0usize; target_dims.len()];
124        let offset = target_dims.len() - self_dims.len();
125
126        for i in 0..self_dims.len() {
127            if self_dims[i] == target_dims[i + offset] {
128                result[i + offset] = self_strides[i];
129            } else {
130                // self_dims[i] must be 1 → stride 0 (broadcast)
131                result[i + offset] = 0;
132            }
133        }
134        // Leading dimensions (offset region) are already 0 (broadcast)
135        result
136    }
137}
138
139impl fmt::Display for Shape {
140    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141        write!(f, "[")?;
142        for (i, d) in self.0.iter().enumerate() {
143            if i > 0 {
144                write!(f, ", ")?;
145            }
146            write!(f, "{}", d)?;
147        }
148        write!(f, "]")
149    }
150}
151
152// Convenient From implementations
153// These let you write: Shape::from((3, 4)) instead of Shape::new(vec![3, 4])
154
155impl From<()> for Shape {
156    /// Scalar shape (0 dimensions).
157    fn from(_: ()) -> Self {
158        Shape(vec![])
159    }
160}
161
162impl From<usize> for Shape {
163    /// 1-D shape.
164    fn from(d: usize) -> Self {
165        Shape(vec![d])
166    }
167}
168
169impl From<(usize,)> for Shape {
170    fn from((d0,): (usize,)) -> Self {
171        Shape(vec![d0])
172    }
173}
174
175impl From<(usize, usize)> for Shape {
176    fn from((d0, d1): (usize, usize)) -> Self {
177        Shape(vec![d0, d1])
178    }
179}
180
181impl From<(usize, usize, usize)> for Shape {
182    fn from((d0, d1, d2): (usize, usize, usize)) -> Self {
183        Shape(vec![d0, d1, d2])
184    }
185}
186
187impl From<(usize, usize, usize, usize)> for Shape {
188    fn from((d0, d1, d2, d3): (usize, usize, usize, usize)) -> Self {
189        Shape(vec![d0, d1, d2, d3])
190    }
191}
192
193impl From<Vec<usize>> for Shape {
194    fn from(v: Vec<usize>) -> Self {
195        Shape(v)
196    }
197}
198
199impl From<&[usize]> for Shape {
200    fn from(s: &[usize]) -> Self {
201        Shape(s.to_vec())
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_shape() {
        let s = Shape::from(());
        assert_eq!(s.rank(), 0);
        assert_eq!(s.elem_count(), 1);
        assert_eq!(s.stride_contiguous(), vec![]);
    }

    #[test]
    fn test_vector_shape() {
        let s = Shape::from(5);
        assert_eq!(s.rank(), 1);
        assert_eq!(s.elem_count(), 5);
        assert_eq!(s.stride_contiguous(), vec![1]);
    }

    #[test]
    fn test_matrix_shape() {
        let s = Shape::from((3, 4));
        assert_eq!(s.rank(), 2);
        assert_eq!(s.elem_count(), 12);
        // Row-major: stride for dim0 = 4, stride for dim1 = 1
        assert_eq!(s.stride_contiguous(), vec![4, 1]);
    }

    #[test]
    fn test_3d_strides() {
        let s = Shape::from((2, 3, 4));
        // [2,3,4]: strides = [3*4, 4, 1] = [12, 4, 1]
        assert_eq!(s.stride_contiguous(), vec![12, 4, 1]);
        assert_eq!(s.elem_count(), 24);
    }

    #[test]
    fn test_dim_accessor() {
        let s = Shape::from((3, 4));
        assert_eq!(s.dim(0).unwrap(), 3);
        assert_eq!(s.dim(1).unwrap(), 4);
        // Out-of-range dimension is an error, not a panic.
        assert!(s.dim(2).is_err());
    }

    #[test]
    fn test_broadcast_shape() {
        // [3, 4] ⊕ [4] → [3, 4]: missing leading dim treated as 1.
        let out = Shape::broadcast_shape(&Shape::from((3, 4)), &Shape::from(4)).unwrap();
        assert_eq!(out, Shape::from((3, 4)));

        // [2, 1] ⊕ [1, 3] → [2, 3]: size-1 dims expand on both sides.
        let out = Shape::broadcast_shape(&Shape::from((2, 1)), &Shape::from((1, 3))).unwrap();
        assert_eq!(out, Shape::from((2, 3)));

        // [5, 3, 1] ⊕ [3, 4] → [5, 3, 4]: mixed rank and expansion.
        let out = Shape::broadcast_shape(&Shape::from((5, 3, 1)), &Shape::from((3, 4))).unwrap();
        assert_eq!(out, Shape::from((5, 3, 4)));
    }

    #[test]
    fn test_broadcast_incompatible() {
        // 3 ≠ 4 and neither is 1 → not broadcast-compatible.
        assert!(Shape::broadcast_shape(&Shape::from(3), &Shape::from(4)).is_err());
    }

    #[test]
    fn test_broadcast_strides() {
        // [4] into [3, 4]: missing leading dim repeats → stride 0.
        assert_eq!(Shape::from(4).broadcast_strides(&Shape::from((3, 4))), vec![0, 1]);

        // [2, 1] into [2, 3]: the size-1 dim repeats → stride 0.
        assert_eq!(
            Shape::from((2, 1)).broadcast_strides(&Shape::from((2, 3))),
            vec![1, 0]
        );
    }

    #[test]
    fn test_display() {
        let s = Shape::from((3, 4));
        assert_eq!(format!("{}", s), "[3, 4]");
        assert_eq!(format!("{}", Shape::from(())), "[]");
    }
}