zyx_core/
shape.rs

1extern crate alloc;
2use crate::axes::Axes;
3use alloc::boxed::Box;
4use alloc::vec::Vec;
5use core::ops::Range;
6
/// Converts a possibly negative `index` into a position for a shape with `rank` dimensions.
///
/// Indices in `0..=rank` are returned unchanged. `rank` itself is accepted so that the
/// function can also translate the exclusive end of a range. Any other index is wrapped
/// modulo `rank`, so `-1` maps to `rank - 1`, `-rank` maps to `0`, and so on.
///
/// # Panics
/// Panics if `rank` is zero and `index` is not zero (remainder by zero).
fn to_usize_idx(index: i64, rank: usize) -> usize {
    let rank = rank as i64;
    if (0..=rank).contains(&index) {
        index as usize
    } else {
        // `rem_euclid` always yields a value in `0..rank`, so the cast cannot wrap even when
        // `index` lies below `-rank` (casting a negative sum straight to `usize` would
        // produce a platform-dependent garbage position).
        index.rem_euclid(rank) as usize
    }
}
14
/// Shape of a tensor, stored as a boxed slice holding the size of each dimension.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Shape(Box<[usize]>);

/// Displays the shape using the `Debug` formatting of its dimensions, e.g. `[2, 3]`.
impl core::fmt::Display for Shape {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "{:?}", self.0)
    }
}
24
25impl Shape {
26    /// Get shape's rank
27    #[must_use]
28    pub const fn rank(&self) -> usize {
29        self.0.len()
30    }
31
32    /// Get number of elements in tensor with this shape
33    /// (a product of it's dimensions).
34    #[must_use]
35    pub fn numel(&self) -> usize {
36        self.0.iter().product()
37    }
38
39    /// Iter
40    #[must_use]
41    pub fn iter(&self) -> impl DoubleEndedIterator<Item = &usize> + ExactSizeIterator {
42        self.into_iter()
43    }
44
45    /// Iter mut
46    #[must_use]
47    pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = &mut usize> + ExactSizeIterator {
48        self.into_iter()
49    }
50
51    /// Get shape's strides
52    #[must_use]
53    pub fn strides(&self) -> Shape {
54        let mut a = 1;
55        Shape(
56            self.0
57                .iter()
58                .rev()
59                .map(|d| {
60                    let t = a;
61                    a *= d;
62                    t
63                })
64                .collect::<Vec<usize>>()
65                .into_iter()
66                .rev()
67                .collect(),
68        )
69    }
70
71    /// Permute shape's dimensions with axes
72    /// # Panics
73    /// Panics if axes is incorrect.
74    #[must_use]
75    pub fn permute(&self, axes: &Axes) -> Self {
76        //std::println!("self: {self}, axes: {axes:?}");
77        Self(axes.into_iter().map(|axis| self.0[*axis]).collect())
78    }
79
80    /// Get axes along which self was expanded to shape
81    #[must_use]
82    pub fn expand_axes(&self, shape: &Shape) -> Axes {
83        let mut vec = self.0.to_vec();
84        while vec.len() < shape.rank() {
85            vec.insert(0, 1);
86        }
87        Axes(
88            vec.into_iter()
89                .zip(shape)
90                .enumerate()
91                .filter_map(|(a, (d, e))| if d == *e { None } else { Some(a) })
92                .collect(),
93        )
94    }
95
96    pub(crate) fn expand_strides(&self, shape: &Shape, mut old_strides: Shape) -> Shape {
97        let mut vec = self.0.to_vec();
98        while vec.len() < shape.rank() {
99            vec.insert(0, 1);
100            old_strides.0 = [0]
101                .into_iter()
102                .chain(old_strides.0.iter().copied())
103                .collect();
104        }
105        let old_shape: Shape = vec.into();
106        Shape(
107            old_shape
108                .into_iter()
109                .zip(shape)
110                .zip(&old_strides)
111                .map(|((od, nd), st)| if od == nd { *st } else { 0 })
112                .collect(),
113        )
114    }
115
116    #[cfg(feature = "std")]
117    pub(crate) fn safetensors(&self) -> alloc::string::String {
118        let mut res = alloc::format!("{:?}", self.0);
119        res.retain(|c| !c.is_whitespace());
120        res
121    }
122
123    #[cfg(feature = "std")]
124    pub(crate) fn from_safetensors(shape: &str) -> Result<Shape, crate::error::ZyxError> {
125        Ok(Shape(
126            shape
127                .split(',')
128                .map(|d| {
129                    d.parse::<usize>().map_err(|err| {
130                        crate::error::ZyxError::ParseError(alloc::format!(
131                            "Cannot parse safetensors shape: {err}"
132                        ))
133                    })
134                })
135                .collect::<Result<Box<[usize]>, crate::error::ZyxError>>()?,
136        ))
137    }
138
139    /// Reduce self along axes
140    #[must_use]
141    pub fn reduce(self, axes: &Axes) -> Shape {
142        let mut shape = self;
143        for a in axes.iter() {
144            shape.0[*a] = 1;
145        }
146        shape
147    }
148
149    /// Pad self with padding
150    #[must_use]
151    pub fn pad(mut self, padding: &[(i64, i64)]) -> Shape {
152        for (i, d) in self.iter_mut().rev().enumerate() {
153            if let Some((left, right)) = padding.get(i) {
154                *d = (*d as i64 + left + right) as usize;
155            } else {
156                break;
157            }
158        }
159        self
160    }
161
162    /// Get self as vector i64
163    #[must_use]
164    pub fn vi64(&self) -> Vec<i64> {
165        self.0.iter().map(|x| *x as i64).collect()
166    }
167}
168
169impl core::ops::Index<i32> for Shape {
170    type Output = usize;
171    fn index(&self, index: i32) -> &Self::Output {
172        self.0.get(to_usize_idx(index as i64, self.rank())).unwrap()
173    }
174}
175
176impl core::ops::Index<i64> for Shape {
177    type Output = usize;
178    fn index(&self, index: i64) -> &Self::Output {
179        self.0.get(to_usize_idx(index, self.rank())).unwrap()
180    }
181}
182
183impl core::ops::Index<usize> for Shape {
184    type Output = usize;
185    fn index(&self, index: usize) -> &Self::Output {
186        self.0.get(index).unwrap()
187    }
188}
189
190impl core::ops::Index<Range<i64>> for Shape {
191    type Output = [usize];
192    fn index(&self, index: Range<i64>) -> &Self::Output {
193        let rank = self.rank();
194        self.0
195            .get(to_usize_idx(index.start, rank)..to_usize_idx(index.end, rank))
196            .unwrap()
197    }
198}
199
200impl From<Shape> for Vec<usize> {
201    fn from(val: Shape) -> Self {
202        val.0.into()
203    }
204}
205
206impl From<&Shape> for Shape {
207    fn from(sh: &Shape) -> Self {
208        sh.clone()
209    }
210}
211
212impl From<Box<[usize]>> for Shape {
213    fn from(value: Box<[usize]>) -> Self {
214        Shape(value)
215    }
216}
217
218impl From<Vec<usize>> for Shape {
219    fn from(value: Vec<usize>) -> Self {
220        Shape(value.iter().copied().collect())
221    }
222}
223
224impl From<&[usize]> for Shape {
225    fn from(value: &[usize]) -> Self {
226        Shape(value.iter().copied().collect())
227    }
228}
229
230impl From<usize> for Shape {
231    fn from(value: usize) -> Self {
232        Shape(Box::new([value]))
233    }
234}
235
236impl<const N: usize> From<[usize; N]> for Shape {
237    fn from(value: [usize; N]) -> Self {
238        Shape(value.into_iter().collect())
239    }
240}
241
242impl<'a> IntoIterator for &'a Shape {
243    type Item = &'a usize;
244    type IntoIter = <&'a [usize] as IntoIterator>::IntoIter;
245    fn into_iter(self) -> Self::IntoIter {
246        self.0.iter()
247    }
248}
249
250impl<'a> IntoIterator for &'a mut Shape {
251    type Item = &'a mut usize;
252    type IntoIter = <&'a mut [usize] as IntoIterator>::IntoIter;
253    fn into_iter(self) -> Self::IntoIter {
254        self.0.iter_mut()
255    }
256}
257
258impl PartialEq<[usize]> for Shape {
259    fn eq(&self, other: &[usize]) -> bool {
260        self.rank() == other.len() && self.iter().zip(other).all(|(x, y)| x == y)
261    }
262}
263
264impl<const RANK: usize> PartialEq<[usize; RANK]> for Shape {
265    fn eq(&self, other: &[usize; RANK]) -> bool {
266        self.rank() == RANK && self.iter().zip(other).all(|(x, y)| x == y)
267    }
268}
269
270impl AsRef<[usize]> for Shape {
271    fn as_ref(&self) -> &[usize] {
272        &self.0
273    }
274}