1use std::fmt;
2
/// The extents of an n-dimensional array, one entry per axis.
///
/// Axes are listed outermost first; the empty list is the rank-0 (scalar)
/// shape.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Shape(Vec<usize>);
20
21impl Shape {
22 pub fn new(dims: Vec<usize>) -> Self {
24 Shape(dims)
25 }
26
27 pub fn dims(&self) -> &[usize] {
29 &self.0
30 }
31
32 pub fn rank(&self) -> usize {
34 self.0.len()
35 }
36
37 pub fn elem_count(&self) -> usize {
40 self.0.iter().product::<usize>().max(1)
41 }
42
43 pub fn stride_contiguous(&self) -> Vec<usize> {
52 let mut strides = vec![0usize; self.rank()];
53 if self.rank() > 0 {
54 strides[self.rank() - 1] = 1;
55 for i in (0..self.rank() - 1).rev() {
56 strides[i] = strides[i + 1] * self.0[i + 1];
57 }
58 }
59 strides
60 }
61
62 pub fn dim(&self, d: usize) -> crate::Result<usize> {
64 self.0.get(d).copied().ok_or(crate::Error::DimOutOfRange {
65 dim: d,
66 rank: self.rank(),
67 })
68 }
69
70 pub fn broadcast_shape(lhs: &Shape, rhs: &Shape) -> crate::Result<Shape> {
85 let l = lhs.dims();
86 let r = rhs.dims();
87 let max_rank = l.len().max(r.len());
88 let mut result = Vec::with_capacity(max_rank);
89
90 for i in 0..max_rank {
91 let ld = if i < l.len() { l[l.len() - 1 - i] } else { 1 };
93 let rd = if i < r.len() { r[r.len() - 1 - i] } else { 1 };
94
95 if ld == rd {
96 result.push(ld);
97 } else if ld == 1 {
98 result.push(rd);
99 } else if rd == 1 {
100 result.push(ld);
101 } else {
102 return Err(crate::Error::msg(format!(
103 "shapes {:?} and {:?} are not broadcast-compatible (dim {} from right: {} vs {})",
104 l, r, i, ld, rd
105 )));
106 }
107 }
108
109 result.reverse(); Ok(Shape::new(result))
111 }
112
113 pub fn broadcast_strides(&self, target: &Shape) -> Vec<usize> {
119 let self_dims = self.dims();
120 let target_dims = target.dims();
121 let self_strides = self.stride_contiguous();
122
123 let mut result = vec![0usize; target_dims.len()];
124 let offset = target_dims.len() - self_dims.len();
125
126 for i in 0..self_dims.len() {
127 if self_dims[i] == target_dims[i + offset] {
128 result[i + offset] = self_strides[i];
129 } else {
130 result[i + offset] = 0;
132 }
133 }
134 result
136 }
137}
138
139impl fmt::Display for Shape {
140 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141 write!(f, "[")?;
142 for (i, d) in self.0.iter().enumerate() {
143 if i > 0 {
144 write!(f, ", ")?;
145 }
146 write!(f, "{}", d)?;
147 }
148 write!(f, "]")
149 }
150}
151
152impl From<()> for Shape {
156 fn from(_: ()) -> Self {
158 Shape(vec![])
159 }
160}
161
162impl From<usize> for Shape {
163 fn from(d: usize) -> Self {
165 Shape(vec![d])
166 }
167}
168
169impl From<(usize,)> for Shape {
170 fn from((d0,): (usize,)) -> Self {
171 Shape(vec![d0])
172 }
173}
174
175impl From<(usize, usize)> for Shape {
176 fn from((d0, d1): (usize, usize)) -> Self {
177 Shape(vec![d0, d1])
178 }
179}
180
181impl From<(usize, usize, usize)> for Shape {
182 fn from((d0, d1, d2): (usize, usize, usize)) -> Self {
183 Shape(vec![d0, d1, d2])
184 }
185}
186
187impl From<(usize, usize, usize, usize)> for Shape {
188 fn from((d0, d1, d2, d3): (usize, usize, usize, usize)) -> Self {
189 Shape(vec![d0, d1, d2, d3])
190 }
191}
192
193impl From<Vec<usize>> for Shape {
194 fn from(v: Vec<usize>) -> Self {
195 Shape(v)
196 }
197}
198
199impl From<&[usize]> for Shape {
200 fn from(s: &[usize]) -> Self {
201 Shape(s.to_vec())
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scalar_shape() {
        let s = Shape::from(());
        assert_eq!(s.rank(), 0);
        assert_eq!(s.elem_count(), 1);
        // Explicit element type so the comparison never needs inference help.
        assert_eq!(s.stride_contiguous(), Vec::<usize>::new());
    }

    #[test]
    fn test_vector_shape() {
        let s = Shape::from(5);
        assert_eq!(s.rank(), 1);
        assert_eq!(s.elem_count(), 5);
        assert_eq!(s.stride_contiguous(), vec![1]);
    }

    #[test]
    fn test_matrix_shape() {
        let s = Shape::from((3, 4));
        assert_eq!(s.rank(), 2);
        assert_eq!(s.elem_count(), 12);
        assert_eq!(s.stride_contiguous(), vec![4, 1]);
    }

    #[test]
    fn test_3d_strides() {
        let s = Shape::from((2, 3, 4));
        assert_eq!(s.stride_contiguous(), vec![12, 4, 1]);
        assert_eq!(s.elem_count(), 24);
    }

    #[test]
    fn test_display() {
        let s = Shape::from((3, 4));
        assert_eq!(format!("{}", s), "[3, 4]");
    }

    // `.ok()` / `.is_err()` keep these tests independent of `Error: Debug`.
    #[test]
    fn test_dim_lookup() {
        let s = Shape::from((3, 4));
        assert_eq!(s.dim(0).ok(), Some(3));
        assert_eq!(s.dim(1).ok(), Some(4));
        assert!(s.dim(2).is_err());
    }

    #[test]
    fn test_broadcast_shape() {
        let a = Shape::from((3, 1));
        let b = Shape::from(4);
        assert_eq!(Shape::broadcast_shape(&a, &b).ok(), Some(Shape::from((3, 4))));
        assert_eq!(Shape::broadcast_shape(&b, &a).ok(), Some(Shape::from((3, 4))));
        assert!(Shape::broadcast_shape(&Shape::from((2, 3)), &Shape::from((4, 3))).is_err());
    }

    #[test]
    fn test_broadcast_strides() {
        let s = Shape::from((3, 1));
        assert_eq!(s.broadcast_strides(&Shape::from((2, 3, 4))), vec![0, 1, 0]);
        assert_eq!(Shape::from(4).broadcast_strides(&Shape::from((3, 4))), vec![0, 1]);
    }
}