Skip to main content

slop_tensor/
dimensions.rs

1use arrayvec::ArrayVec;
2use itertools::Itertools;
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use thiserror::Error;
5
6const MAX_DIMENSIONS: usize = 3;
7
8#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
9#[repr(C)]
10pub struct Dimensions {
11    sizes: ArrayVec<usize, MAX_DIMENSIONS>,
12    strides: ArrayVec<usize, MAX_DIMENSIONS>,
13}
14
15#[derive(Debug, Clone, Copy, Error)]
16pub enum DimensionsError {
17    #[error("Too many dimensions {0}, maximum number allowed is {MAX_DIMENSIONS}")]
18    TooManyDimensions(usize),
19    #[error("total number of elements must match, expected {0}, got {1}")]
20    NumElementsMismatch(usize, usize),
21}
22
23impl Dimensions {
24    fn new(sizes: ArrayVec<usize, MAX_DIMENSIONS>) -> Self {
25        let mut strides = ArrayVec::new();
26        let mut stride = 1;
27        for size in sizes.iter().rev() {
28            strides.push(stride);
29            stride *= size;
30        }
31        strides.reverse();
32        Self { sizes, strides }
33    }
34
35    #[inline]
36    pub fn total_len(&self) -> usize {
37        self.sizes.iter().product()
38    }
39
40    #[inline]
41    pub(crate) fn compatible(&self, other: &Dimensions) -> Result<(), DimensionsError> {
42        if self.total_len() != other.total_len() {
43            return Err(DimensionsError::NumElementsMismatch(self.total_len(), other.total_len()));
44        }
45        Ok(())
46    }
47
48    #[inline]
49    pub fn sizes(&self) -> &[usize] {
50        &self.sizes
51    }
52
53    pub(crate) fn sizes_mut(&mut self) -> &mut ArrayVec<usize, MAX_DIMENSIONS> {
54        &mut self.sizes
55    }
56
57    pub(crate) fn strides_mut(&mut self) -> &mut ArrayVec<usize, MAX_DIMENSIONS> {
58        &mut self.strides
59    }
60
61    #[inline]
62    pub fn strides(&self) -> &[usize] {
63        &self.strides
64    }
65
66    #[inline]
67    pub(crate) fn index_map(&self, index: impl AsRef<[usize]>) -> usize {
68        index.as_ref().iter().zip_eq(self.strides.iter()).map(|(i, s)| i * s).sum()
69    }
70}
71
72impl TryFrom<&[usize]> for Dimensions {
73    type Error = DimensionsError;
74
75    fn try_from(value: &[usize]) -> Result<Self, Self::Error> {
76        let sizes = ArrayVec::try_from(value)
77            .map_err(|_| DimensionsError::TooManyDimensions(value.len()))?;
78        Ok(Self::new(sizes))
79    }
80}
81
82impl TryFrom<Vec<usize>> for Dimensions {
83    type Error = DimensionsError;
84
85    fn try_from(value: Vec<usize>) -> Result<Self, Self::Error> {
86        let sizes = ArrayVec::try_from(value.as_slice())
87            .map_err(|_| DimensionsError::TooManyDimensions(value.len()))?;
88        Ok(Self::new(sizes))
89    }
90}
91
92impl<const N: usize> TryFrom<[usize; N]> for Dimensions {
93    type Error = DimensionsError;
94
95    fn try_from(value: [usize; N]) -> Result<Self, Self::Error> {
96        let sizes = ArrayVec::try_from(value.as_slice())
97            .map_err(|_| DimensionsError::TooManyDimensions(value.len()))?;
98        Ok(Self::new(sizes))
99    }
100}
101
102impl FromIterator<usize> for Dimensions {
103    #[inline]
104    fn from_iter<T: IntoIterator<Item = usize>>(iter: T) -> Self {
105        let sizes = ArrayVec::from_iter(iter);
106        Self::new(sizes)
107    }
108}
109
110impl Serialize for Dimensions {
111    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
112        self.sizes.serialize(serializer)
113    }
114}
115
116impl<'de> Deserialize<'de> for Dimensions {
117    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
118        let sizes = Vec::deserialize(deserializer)?;
119        Ok(Self::try_from(sizes).expect("invalid dimension length"))
120    }
121}