Skip to main content

slop_tensor/
dimensions.rs

1use core::fmt;
2
3use arrayvec::ArrayVec;
4use itertools::Itertools;
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use thiserror::Error;
7
/// Upper bound on the number of axes a `Dimensions` value can describe.
///
/// Both `sizes` and `strides` are stored in `ArrayVec`s with exactly this capacity.
const MAX_DIMENSIONS: usize = 3;
9
10#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
11#[repr(C)]
12pub struct Dimensions {
13    sizes: ArrayVec<usize, MAX_DIMENSIONS>,
14    strides: ArrayVec<usize, MAX_DIMENSIONS>,
15}
16
17impl fmt::Display for Dimensions {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        write!(f, "Dimensions({})", self.sizes.iter().join(", "))
20    }
21}
22
23#[derive(Debug, Clone, Copy, Error)]
24pub enum DimensionsError {
25    #[error("Too many dimensions {0}, maximum number allowed is {MAX_DIMENSIONS}")]
26    TooManyDimensions(usize),
27    #[error("total number of elements must match, expected {0}, got {1}")]
28    NumElementsMismatch(usize, usize),
29}
30
31impl Dimensions {
32    fn new(sizes: ArrayVec<usize, MAX_DIMENSIONS>) -> Self {
33        let mut strides = ArrayVec::new();
34        let mut stride = 1;
35        for size in sizes.iter().rev() {
36            strides.push(stride);
37            stride *= size;
38        }
39        strides.reverse();
40        Self { sizes, strides }
41    }
42
43    #[inline]
44    pub fn total_len(&self) -> usize {
45        self.sizes.iter().product()
46    }
47
48    #[inline]
49    pub(crate) fn compatible(&self, other: &Dimensions) -> Result<(), DimensionsError> {
50        if self.total_len() != other.total_len() {
51            return Err(DimensionsError::NumElementsMismatch(self.total_len(), other.total_len()));
52        }
53        Ok(())
54    }
55
56    #[inline]
57    pub fn sizes(&self) -> &[usize] {
58        &self.sizes
59    }
60
61    pub(crate) fn sizes_mut(&mut self) -> &mut ArrayVec<usize, MAX_DIMENSIONS> {
62        &mut self.sizes
63    }
64
65    pub(crate) fn strides_mut(&mut self) -> &mut ArrayVec<usize, MAX_DIMENSIONS> {
66        &mut self.strides
67    }
68
69    #[inline]
70    pub fn strides(&self) -> &[usize] {
71        &self.strides
72    }
73
74    /// Maps a multi-dimensional index to a single-dimensional buffer index.
75    ///
76    /// Panics if the index is out of bounds, or the length of the index does not match the number
77    /// of dimensions.
78    #[inline]
79    pub(crate) fn index_map(&self, index: impl AsRef<[usize]>) -> usize {
80        // The panic code path was put into a cold function to not bloat the
81        // call site.
82        #[inline(never)]
83        #[cold]
84        #[track_caller]
85        fn index_length_mismatch(buffer_index: &[usize], dimensions: &Dimensions) -> ! {
86            panic!(
87                "Index tuple {buffer_index:?} has length {} which is out of bounds for dimensions 
88                {dimensions} of length {}",
89                buffer_index.len(),
90                dimensions.sizes().len()
91            );
92        }
93
94        // The panic code path was put into a cold function to not bloat the
95        // call site.
96        #[inline(never)]
97        #[cold]
98        #[track_caller]
99        fn index_out_of_bounds_fail(buffer_index: &[usize], dimensions: &Dimensions) -> ! {
100            panic!("Index {buffer_index:?} is out of bounds for dimensions {dimensions}",);
101        }
102
103        if index.as_ref().len() != self.sizes.len() {
104            index_length_mismatch(index.as_ref(), self);
105        }
106
107        let mut buffer_index = 0;
108        for ((idx, stride), len) in
109            index.as_ref().iter().zip_eq(self.strides.iter()).zip_eq(self.sizes.iter())
110        {
111            if *idx >= *len {
112                index_out_of_bounds_fail(index.as_ref(), self);
113            }
114            buffer_index += idx * stride;
115        }
116
117        buffer_index
118    }
119}
120
121impl TryFrom<&[usize]> for Dimensions {
122    type Error = DimensionsError;
123
124    fn try_from(value: &[usize]) -> Result<Self, Self::Error> {
125        let sizes = ArrayVec::try_from(value)
126            .map_err(|_| DimensionsError::TooManyDimensions(value.len()))?;
127        Ok(Self::new(sizes))
128    }
129}
130
131impl TryFrom<Vec<usize>> for Dimensions {
132    type Error = DimensionsError;
133
134    fn try_from(value: Vec<usize>) -> Result<Self, Self::Error> {
135        let sizes = ArrayVec::try_from(value.as_slice())
136            .map_err(|_| DimensionsError::TooManyDimensions(value.len()))?;
137        Ok(Self::new(sizes))
138    }
139}
140
141impl<const N: usize> TryFrom<[usize; N]> for Dimensions {
142    type Error = DimensionsError;
143
144    fn try_from(value: [usize; N]) -> Result<Self, Self::Error> {
145        let sizes = ArrayVec::try_from(value.as_slice())
146            .map_err(|_| DimensionsError::TooManyDimensions(value.len()))?;
147        Ok(Self::new(sizes))
148    }
149}
150
151impl FromIterator<usize> for Dimensions {
152    #[inline]
153    fn from_iter<T: IntoIterator<Item = usize>>(iter: T) -> Self {
154        let sizes = ArrayVec::from_iter(iter);
155        Self::new(sizes)
156    }
157}
158
159impl Serialize for Dimensions {
160    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
161        self.sizes.serialize(serializer)
162    }
163}
164
165impl<'de> Deserialize<'de> for Dimensions {
166    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
167        let sizes = Vec::deserialize(deserializer)?;
168        Ok(Self::try_from(sizes).expect("invalid dimension length"))
169    }
170}