1use crate::error::{Result, SapientError};
4use serde::{Deserialize, Serialize};
5
/// An ordered list of dimension sizes, one `usize` per axis.
///
/// An empty list represents a scalar (see [`Shape::scalar`]), which is also
/// what the derived `Default` produces. The inner `Vec` is public, so the
/// dimensions can be read or modified directly without going through methods.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub struct Shape(pub Vec<usize>);
9
10impl Shape {
11 pub fn new(dims: impl IntoIterator<Item = usize>) -> Self {
13 Self(dims.into_iter().collect())
14 }
15
16 #[inline]
18 pub fn ndim(&self) -> usize {
19 self.0.len()
20 }
21
22 #[inline]
24 pub fn numel(&self) -> usize {
25 self.0.iter().product()
26 }
27
28 #[inline]
30 pub fn dims(&self) -> &[usize] {
31 &self.0
32 }
33
34 pub fn strides(&self) -> Vec<usize> {
36 let n = self.ndim();
37 if n == 0 {
38 return vec![];
39 }
40 let mut strides = vec![1usize; n];
41 for i in (0..n - 1).rev() {
42 strides[i] = strides[i + 1] * self.0[i + 1];
43 }
44 strides
45 }
46
47 pub fn scalar() -> Self {
49 Self(vec![])
50 }
51
52 pub fn is_scalar(&self) -> bool {
54 self.0.is_empty()
55 }
56
57 pub fn reshape(&self, new_dims: impl IntoIterator<Item = usize>) -> Result<Shape> {
59 let new_shape = Shape::new(new_dims);
60 if new_shape.numel() != self.numel() {
61 return Err(SapientError::ShapeMismatch {
62 expected: self.0.clone(),
63 got: new_shape.0.clone(),
64 });
65 }
66 Ok(new_shape)
67 }
68
69 pub fn broadcast_with(&self, other: &Shape) -> Result<Shape> {
71 let (a, b) = (&self.0, &other.0);
72 let len = a.len().max(b.len());
73 let mut out = vec![0usize; len];
74 for i in 0..len {
75 let ai = if i < len - a.len() {
76 1
77 } else {
78 a[i - (len - a.len())]
79 };
80 let bi = if i < len - b.len() {
81 1
82 } else {
83 b[i - (len - b.len())]
84 };
85 if ai == bi {
86 out[i] = ai;
87 } else if ai == 1 {
88 out[i] = bi;
89 } else if bi == 1 {
90 out[i] = ai;
91 } else {
92 return Err(SapientError::BroadcastError {
93 lhs: self.0.clone(),
94 rhs: other.0.clone(),
95 });
96 }
97 }
98 Ok(Shape(out))
99 }
100
101 pub fn expand_dims(&self, axis: usize) -> Result<Shape> {
103 if axis > self.ndim() {
104 return Err(SapientError::internal(format!(
105 "expand_dims: axis {axis} out of range for rank {}",
106 self.ndim()
107 )));
108 }
109 let mut dims = self.0.clone();
110 dims.insert(axis, 1);
111 Ok(Shape(dims))
112 }
113
114 pub fn squeeze(&self) -> Shape {
116 Shape(self.0.iter().copied().filter(|&d| d != 1).collect())
117 }
118
119 pub fn validate(&self) -> Result<()> {
121 for (i, &d) in self.0.iter().enumerate() {
122 if d == 0 {
123 return Err(SapientError::InvalidGraph(format!(
124 "Shape has zero dimension at axis {i}"
125 )));
126 }
127 }
128 Ok(())
129 }
130
131 pub fn flat_index(&self, idx: &[usize]) -> Result<usize> {
133 if idx.len() != self.ndim() {
134 return Err(SapientError::RankMismatch {
135 expected: self.ndim(),
136 got: idx.len(),
137 });
138 }
139 let strides = self.strides();
140 let mut offset = 0;
141 for (i, (&ix, &st)) in idx.iter().zip(strides.iter()).enumerate() {
142 if ix >= self.0[i] {
143 return Err(SapientError::internal(format!(
144 "Index {ix} out of bounds for dim {i} (size {})",
145 self.0[i]
146 )));
147 }
148 offset += ix * st;
149 }
150 Ok(offset)
151 }
152}
153
154impl std::fmt::Display for Shape {
155 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156 write!(f, "[")?;
157 for (i, d) in self.0.iter().enumerate() {
158 if i > 0 {
159 write!(f, ", ")?;
160 }
161 write!(f, "{d}")?;
162 }
163 write!(f, "]")
164 }
165}
166
167impl From<Vec<usize>> for Shape {
168 fn from(v: Vec<usize>) -> Self {
169 Self(v)
170 }
171}
172
173impl From<&[usize]> for Shape {
174 fn from(s: &[usize]) -> Self {
175 Self(s.to_vec())
176 }
177}
178
#[cfg(test)]
mod tests {
    //! Unit tests for `Shape`: element counts, strides, broadcasting,
    //! reshaping, axis manipulation, validation, indexing and conversions.

    use super::*;

    #[test]
    fn numel() {
        assert_eq!(Shape::new([2, 3, 4]).numel(), 24);
        assert_eq!(Shape::scalar().numel(), 1);
        // Any zero-sized axis makes the whole shape empty.
        assert_eq!(Shape::new([2, 0, 4]).numel(), 0);
    }

    #[test]
    fn rank_and_dims() {
        let s = Shape::new([2, 3, 4]);
        assert_eq!(s.ndim(), 3);
        assert_eq!(s.dims(), &[2, 3, 4]);
        assert!(!s.is_scalar());
        assert_eq!(Shape::scalar().ndim(), 0);
        assert!(Shape::scalar().is_scalar());
        assert_eq!(Shape::default(), Shape::scalar());
    }

    #[test]
    fn strides_row_major() {
        let s = Shape::new([2, 3, 4]);
        assert_eq!(s.strides(), vec![12, 4, 1]);
    }

    #[test]
    fn strides_low_rank() {
        assert!(Shape::scalar().strides().is_empty());
        assert_eq!(Shape::new([5]).strides(), vec![1]);
    }

    #[test]
    fn broadcast() {
        let a = Shape::new([1, 3]);
        let b = Shape::new([2, 3]);
        assert_eq!(a.broadcast_with(&b).unwrap(), Shape::new([2, 3]));
    }

    #[test]
    fn broadcast_pads_leading_axes() {
        assert_eq!(
            Shape::new([3]).broadcast_with(&Shape::new([2, 1, 3])).unwrap(),
            Shape::new([2, 1, 3])
        );
        assert_eq!(
            Shape::new([4, 1]).broadcast_with(&Shape::new([3])).unwrap(),
            Shape::new([4, 3])
        );
        assert_eq!(
            Shape::scalar().broadcast_with(&Shape::new([2, 3])).unwrap(),
            Shape::new([2, 3])
        );
    }

    #[test]
    fn broadcast_fail() {
        let a = Shape::new([2, 3]);
        let b = Shape::new([2, 4]);
        assert!(matches!(
            a.broadcast_with(&b),
            Err(SapientError::BroadcastError { .. })
        ));
    }

    #[test]
    fn reshape() {
        let s = Shape::new([2, 3]);
        let r = s.reshape([6]).unwrap();
        assert_eq!(r, Shape::new([6]));
        assert_eq!(Shape::scalar().reshape([1, 1]).unwrap(), Shape::new([1, 1]));
    }

    #[test]
    fn reshape_fail() {
        let s = Shape::new([2, 3]);
        assert!(matches!(
            s.reshape([4]),
            Err(SapientError::ShapeMismatch { .. })
        ));
    }

    #[test]
    fn expand_dims() {
        let s = Shape::new([2, 3]);
        assert_eq!(s.expand_dims(0).unwrap(), Shape::new([1, 2, 3]));
        // An axis equal to the rank appends at the end.
        assert_eq!(s.expand_dims(2).unwrap(), Shape::new([2, 3, 1]));
        assert!(s.expand_dims(3).is_err());
    }

    #[test]
    fn squeeze() {
        assert_eq!(Shape::new([1, 2, 1, 3]).squeeze(), Shape::new([2, 3]));
        assert!(Shape::new([1, 1]).squeeze().is_scalar());
    }

    #[test]
    fn validate() {
        assert!(Shape::new([2, 3]).validate().is_ok());
        assert!(Shape::scalar().validate().is_ok());
        assert!(matches!(
            Shape::new([2, 0, 3]).validate(),
            Err(SapientError::InvalidGraph(_))
        ));
    }

    #[test]
    fn flat_index() {
        let s = Shape::new([2, 3, 4]);
        assert_eq!(s.flat_index(&[1, 2, 3]).unwrap(), 12 + 8 + 3);
        assert_eq!(s.flat_index(&[0, 0, 0]).unwrap(), 0);
        assert_eq!(Shape::scalar().flat_index(&[]).unwrap(), 0);
    }

    #[test]
    fn flat_index_fail() {
        let s = Shape::new([2, 3, 4]);
        assert!(matches!(
            s.flat_index(&[1, 2]),
            Err(SapientError::RankMismatch { expected: 3, got: 2 })
        ));
        // Second index equals the axis size, i.e. one past the end.
        assert!(s.flat_index(&[1, 3, 0]).is_err());
    }

    #[test]
    fn display() {
        assert_eq!(Shape::new([2, 3, 4]).to_string(), "[2, 3, 4]");
        assert_eq!(Shape::scalar().to_string(), "[]");
    }

    #[test]
    fn conversions() {
        assert_eq!(Shape::from(vec![2usize, 3]), Shape::new([2, 3]));
        let dims = [4usize, 5];
        assert_eq!(Shape::from(&dims[..]), Shape::new([4, 5]));
    }
}