slop_tensor/
dimensions.rs1use arrayvec::ArrayVec;
2use itertools::Itertools;
3use serde::{Deserialize, Deserializer, Serialize, Serializer};
4use thiserror::Error;
5
/// Upper bound on the number of axes a [`Dimensions`] value can describe.
///
/// This is the inline capacity of the `ArrayVec`s that store sizes and strides,
/// so shapes are held without heap allocation.
const MAX_DIMENSIONS: usize = 3;
7
8#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
9#[repr(C)]
10pub struct Dimensions {
11 sizes: ArrayVec<usize, MAX_DIMENSIONS>,
12 strides: ArrayVec<usize, MAX_DIMENSIONS>,
13}
14
15#[derive(Debug, Clone, Copy, Error)]
16pub enum DimensionsError {
17 #[error("Too many dimensions {0}, maximum number allowed is {MAX_DIMENSIONS}")]
18 TooManyDimensions(usize),
19 #[error("total number of elements must match, expected {0}, got {1}")]
20 NumElementsMismatch(usize, usize),
21}
22
23impl Dimensions {
24 fn new(sizes: ArrayVec<usize, MAX_DIMENSIONS>) -> Self {
25 let mut strides = ArrayVec::new();
26 let mut stride = 1;
27 for size in sizes.iter().rev() {
28 strides.push(stride);
29 stride *= size;
30 }
31 strides.reverse();
32 Self { sizes, strides }
33 }
34
35 #[inline]
36 pub fn total_len(&self) -> usize {
37 self.sizes.iter().product()
38 }
39
40 #[inline]
41 pub(crate) fn compatible(&self, other: &Dimensions) -> Result<(), DimensionsError> {
42 if self.total_len() != other.total_len() {
43 return Err(DimensionsError::NumElementsMismatch(self.total_len(), other.total_len()));
44 }
45 Ok(())
46 }
47
48 #[inline]
49 pub fn sizes(&self) -> &[usize] {
50 &self.sizes
51 }
52
53 pub(crate) fn sizes_mut(&mut self) -> &mut ArrayVec<usize, MAX_DIMENSIONS> {
54 &mut self.sizes
55 }
56
57 pub(crate) fn strides_mut(&mut self) -> &mut ArrayVec<usize, MAX_DIMENSIONS> {
58 &mut self.strides
59 }
60
61 #[inline]
62 pub fn strides(&self) -> &[usize] {
63 &self.strides
64 }
65
66 #[inline]
67 pub(crate) fn index_map(&self, index: impl AsRef<[usize]>) -> usize {
68 index.as_ref().iter().zip_eq(self.strides.iter()).map(|(i, s)| i * s).sum()
69 }
70}
71
72impl TryFrom<&[usize]> for Dimensions {
73 type Error = DimensionsError;
74
75 fn try_from(value: &[usize]) -> Result<Self, Self::Error> {
76 let sizes = ArrayVec::try_from(value)
77 .map_err(|_| DimensionsError::TooManyDimensions(value.len()))?;
78 Ok(Self::new(sizes))
79 }
80}
81
82impl TryFrom<Vec<usize>> for Dimensions {
83 type Error = DimensionsError;
84
85 fn try_from(value: Vec<usize>) -> Result<Self, Self::Error> {
86 let sizes = ArrayVec::try_from(value.as_slice())
87 .map_err(|_| DimensionsError::TooManyDimensions(value.len()))?;
88 Ok(Self::new(sizes))
89 }
90}
91
92impl<const N: usize> TryFrom<[usize; N]> for Dimensions {
93 type Error = DimensionsError;
94
95 fn try_from(value: [usize; N]) -> Result<Self, Self::Error> {
96 let sizes = ArrayVec::try_from(value.as_slice())
97 .map_err(|_| DimensionsError::TooManyDimensions(value.len()))?;
98 Ok(Self::new(sizes))
99 }
100}
101
102impl FromIterator<usize> for Dimensions {
103 #[inline]
104 fn from_iter<T: IntoIterator<Item = usize>>(iter: T) -> Self {
105 let sizes = ArrayVec::from_iter(iter);
106 Self::new(sizes)
107 }
108}
109
110impl Serialize for Dimensions {
111 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
112 self.sizes.serialize(serializer)
113 }
114}
115
116impl<'de> Deserialize<'de> for Dimensions {
117 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
118 let sizes = Vec::deserialize(deserializer)?;
119 Ok(Self::try_from(sizes).expect("invalid dimension length"))
120 }
121}