1extern crate alloc;
2use crate::axes::Axes;
3use alloc::boxed::Box;
4use alloc::vec::Vec;
5use core::ops::Range;
6
/// Converts a possibly negative index into a `usize` position for a shape
/// with `rank` dimensions.
///
/// Indices in `0..=rank` are returned unchanged; `rank` itself is accepted so
/// the function can also translate exclusive range ends. Every other index is
/// shifted by `rank` and reduced modulo `rank`, so `-1` maps to `rank - 1`.
///
/// # Panics
///
/// Panics (remainder by zero) when `rank` is 0 and `index` is not 0.
fn to_usize_idx(index: i64, rank: usize) -> usize {
    let signed_rank = rank as i64;
    if (0..=signed_rank).contains(&index) {
        return index as usize;
    }
    // Out-of-range (typically negative) indices wrap around the rank.
    let shifted = index + signed_rank;
    shifted as usize % rank
}
14
/// A shape: an ordered list of dimension sizes stored as a boxed slice.
///
/// Equality and ordering are derived from the underlying slice, so shapes
/// compare element by element (lexicographically).
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Shape(Box<[usize]>);
18
19impl core::fmt::Display for Shape {
20 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
21 f.write_fmt(format_args!("{:?}", self.0))
22 }
23}
24
25impl Shape {
26 #[must_use]
28 pub const fn rank(&self) -> usize {
29 self.0.len()
30 }
31
32 #[must_use]
35 pub fn numel(&self) -> usize {
36 self.0.iter().product()
37 }
38
39 #[must_use]
41 pub fn iter(&self) -> impl DoubleEndedIterator<Item = &usize> + ExactSizeIterator {
42 self.into_iter()
43 }
44
45 #[must_use]
47 pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut usize> + ExactSizeIterator {
48 self.into_iter()
49 }
50
51 #[must_use]
53 pub fn strides(&self) -> Shape {
54 let mut a = 1;
55 Shape(
56 self.0
57 .iter()
58 .rev()
59 .map(|d| {
60 let t = a;
61 a *= d;
62 t
63 })
64 .collect::<Vec<usize>>()
65 .into_iter()
66 .rev()
67 .collect(),
68 )
69 }
70
71 #[must_use]
75 pub fn permute(&self, axes: &Axes) -> Self {
76 Self(axes.into_iter().map(|axis| self.0[*axis]).collect())
78 }
79
80 #[must_use]
82 pub fn expand_axes(&self, shape: &Shape) -> Axes {
83 let mut vec = self.0.to_vec();
84 while vec.len() < shape.rank() {
85 vec.insert(0, 1);
86 }
87 Axes(
88 vec.into_iter()
89 .zip(shape)
90 .enumerate()
91 .filter_map(|(a, (d, e))| if d == *e { None } else { Some(a) })
92 .collect(),
93 )
94 }
95
96 pub(crate) fn expand_strides(&self, shape: &Shape, mut old_strides: Shape) -> Shape {
97 let mut vec = self.0.to_vec();
98 while vec.len() < shape.rank() {
99 vec.insert(0, 1);
100 old_strides.0 = [0]
101 .into_iter()
102 .chain(old_strides.0.iter().copied())
103 .collect();
104 }
105 let old_shape: Shape = vec.into();
106 Shape(
107 old_shape
108 .into_iter()
109 .zip(shape)
110 .zip(&old_strides)
111 .map(|((od, nd), st)| if od == nd { *st } else { 0 })
112 .collect(),
113 )
114 }
115
116 #[cfg(feature = "std")]
117 pub(crate) fn safetensors(&self) -> alloc::string::String {
118 let mut res = alloc::format!("{:?}", self.0);
119 res.retain(|c| !c.is_whitespace());
120 res
121 }
122
123 #[cfg(feature = "std")]
124 pub(crate) fn from_safetensors(shape: &str) -> Result<Shape, crate::error::ZyxError> {
125 Ok(Shape(
126 shape
127 .split(',')
128 .map(|d| {
129 d.parse::<usize>().map_err(|err| {
130 crate::error::ZyxError::ParseError(alloc::format!(
131 "Cannot parse safetensors shape: {err}"
132 ))
133 })
134 })
135 .collect::<Result<Box<[usize]>, crate::error::ZyxError>>()?,
136 ))
137 }
138
139 #[must_use]
141 pub fn reduce(self, axes: &Axes) -> Shape {
142 let mut shape = self;
143 for a in axes.iter() {
144 shape.0[*a] = 1;
145 }
146 shape
147 }
148
149 #[must_use]
151 pub fn pad(mut self, padding: &[(i64, i64)]) -> Shape {
152 for (i, d) in self.iter_mut().rev().enumerate() {
153 if let Some((left, right)) = padding.get(i) {
154 *d = (*d as i64 + left + right) as usize;
155 } else {
156 break;
157 }
158 }
159 self
160 }
161
162 #[must_use]
164 pub fn vi64(&self) -> Vec<i64> {
165 self.0.iter().map(|x| *x as i64).collect()
166 }
167}
168
169impl core::ops::Index<i32> for Shape {
170 type Output = usize;
171 fn index(&self, index: i32) -> &Self::Output {
172 self.0.get(to_usize_idx(index as i64, self.rank())).unwrap()
173 }
174}
175
176impl core::ops::Index<i64> for Shape {
177 type Output = usize;
178 fn index(&self, index: i64) -> &Self::Output {
179 self.0.get(to_usize_idx(index, self.rank())).unwrap()
180 }
181}
182
183impl core::ops::Index<usize> for Shape {
184 type Output = usize;
185 fn index(&self, index: usize) -> &Self::Output {
186 self.0.get(index).unwrap()
187 }
188}
189
190impl core::ops::Index<Range<i64>> for Shape {
191 type Output = [usize];
192 fn index(&self, index: Range<i64>) -> &Self::Output {
193 let rank = self.rank();
194 self.0
195 .get(to_usize_idx(index.start, rank)..to_usize_idx(index.end, rank))
196 .unwrap()
197 }
198}
199
200impl From<Shape> for Vec<usize> {
201 fn from(val: Shape) -> Self {
202 val.0.into()
203 }
204}
205
206impl From<&Shape> for Shape {
207 fn from(sh: &Shape) -> Self {
208 sh.clone()
209 }
210}
211
212impl From<Box<[usize]>> for Shape {
213 fn from(value: Box<[usize]>) -> Self {
214 Shape(value)
215 }
216}
217
218impl From<Vec<usize>> for Shape {
219 fn from(value: Vec<usize>) -> Self {
220 Shape(value.iter().copied().collect())
221 }
222}
223
224impl From<&[usize]> for Shape {
225 fn from(value: &[usize]) -> Self {
226 Shape(value.iter().copied().collect())
227 }
228}
229
230impl From<usize> for Shape {
231 fn from(value: usize) -> Self {
232 Shape(Box::new([value]))
233 }
234}
235
236impl<const N: usize> From<[usize; N]> for Shape {
237 fn from(value: [usize; N]) -> Self {
238 Shape(value.into_iter().collect())
239 }
240}
241
242impl<'a> IntoIterator for &'a Shape {
243 type Item = &'a usize;
244 type IntoIter = <&'a [usize] as IntoIterator>::IntoIter;
245 fn into_iter(self) -> Self::IntoIter {
246 self.0.iter()
247 }
248}
249
250impl<'a> IntoIterator for &'a mut Shape {
251 type Item = &'a mut usize;
252 type IntoIter = <&'a mut [usize] as IntoIterator>::IntoIter;
253 fn into_iter(self) -> Self::IntoIter {
254 self.0.iter_mut()
255 }
256}
257
258impl PartialEq<[usize]> for Shape {
259 fn eq(&self, other: &[usize]) -> bool {
260 self.rank() == other.len() && self.iter().zip(other).all(|(x, y)| x == y)
261 }
262}
263
264impl<const RANK: usize> PartialEq<[usize; RANK]> for Shape {
265 fn eq(&self, other: &[usize; RANK]) -> bool {
266 self.rank() == RANK && self.iter().zip(other).all(|(x, y)| x == y)
267 }
268}
269
270impl AsRef<[usize]> for Shape {
271 fn as_ref(&self) -> &[usize] {
272 &self.0
273 }
274}