slop_tensor/
dimensions.rs1use core::fmt;
2
3use arrayvec::ArrayVec;
4use itertools::Itertools;
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use thiserror::Error;
7
8const MAX_DIMENSIONS: usize = 3;
9
10#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
11#[repr(C)]
12pub struct Dimensions {
13 sizes: ArrayVec<usize, MAX_DIMENSIONS>,
14 strides: ArrayVec<usize, MAX_DIMENSIONS>,
15}
16
17impl fmt::Display for Dimensions {
18 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19 write!(f, "Dimensions({})", self.sizes.iter().join(", "))
20 }
21}
22
23#[derive(Debug, Clone, Copy, Error)]
24pub enum DimensionsError {
25 #[error("Too many dimensions {0}, maximum number allowed is {MAX_DIMENSIONS}")]
26 TooManyDimensions(usize),
27 #[error("total number of elements must match, expected {0}, got {1}")]
28 NumElementsMismatch(usize, usize),
29}
30
31impl Dimensions {
32 fn new(sizes: ArrayVec<usize, MAX_DIMENSIONS>) -> Self {
33 let mut strides = ArrayVec::new();
34 let mut stride = 1;
35 for size in sizes.iter().rev() {
36 strides.push(stride);
37 stride *= size;
38 }
39 strides.reverse();
40 Self { sizes, strides }
41 }
42
43 #[inline]
44 pub fn total_len(&self) -> usize {
45 self.sizes.iter().product()
46 }
47
48 #[inline]
49 pub(crate) fn compatible(&self, other: &Dimensions) -> Result<(), DimensionsError> {
50 if self.total_len() != other.total_len() {
51 return Err(DimensionsError::NumElementsMismatch(self.total_len(), other.total_len()));
52 }
53 Ok(())
54 }
55
56 #[inline]
57 pub fn sizes(&self) -> &[usize] {
58 &self.sizes
59 }
60
61 pub(crate) fn sizes_mut(&mut self) -> &mut ArrayVec<usize, MAX_DIMENSIONS> {
62 &mut self.sizes
63 }
64
65 pub(crate) fn strides_mut(&mut self) -> &mut ArrayVec<usize, MAX_DIMENSIONS> {
66 &mut self.strides
67 }
68
69 #[inline]
70 pub fn strides(&self) -> &[usize] {
71 &self.strides
72 }
73
74 #[inline]
79 pub(crate) fn index_map(&self, index: impl AsRef<[usize]>) -> usize {
80 #[inline(never)]
83 #[cold]
84 #[track_caller]
85 fn index_length_mismatch(buffer_index: &[usize], dimensions: &Dimensions) -> ! {
86 panic!(
87 "Index tuple {buffer_index:?} has length {} which is out of bounds for dimensions
88 {dimensions} of length {}",
89 buffer_index.len(),
90 dimensions.sizes().len()
91 );
92 }
93
94 #[inline(never)]
97 #[cold]
98 #[track_caller]
99 fn index_out_of_bounds_fail(buffer_index: &[usize], dimensions: &Dimensions) -> ! {
100 panic!("Index {buffer_index:?} is out of bounds for dimensions {dimensions}",);
101 }
102
103 if index.as_ref().len() != self.sizes.len() {
104 index_length_mismatch(index.as_ref(), self);
105 }
106
107 let mut buffer_index = 0;
108 for ((idx, stride), len) in
109 index.as_ref().iter().zip_eq(self.strides.iter()).zip_eq(self.sizes.iter())
110 {
111 if *idx >= *len {
112 index_out_of_bounds_fail(index.as_ref(), self);
113 }
114 buffer_index += idx * stride;
115 }
116
117 buffer_index
118 }
119}
120
121impl TryFrom<&[usize]> for Dimensions {
122 type Error = DimensionsError;
123
124 fn try_from(value: &[usize]) -> Result<Self, Self::Error> {
125 let sizes = ArrayVec::try_from(value)
126 .map_err(|_| DimensionsError::TooManyDimensions(value.len()))?;
127 Ok(Self::new(sizes))
128 }
129}
130
131impl TryFrom<Vec<usize>> for Dimensions {
132 type Error = DimensionsError;
133
134 fn try_from(value: Vec<usize>) -> Result<Self, Self::Error> {
135 let sizes = ArrayVec::try_from(value.as_slice())
136 .map_err(|_| DimensionsError::TooManyDimensions(value.len()))?;
137 Ok(Self::new(sizes))
138 }
139}
140
141impl<const N: usize> TryFrom<[usize; N]> for Dimensions {
142 type Error = DimensionsError;
143
144 fn try_from(value: [usize; N]) -> Result<Self, Self::Error> {
145 let sizes = ArrayVec::try_from(value.as_slice())
146 .map_err(|_| DimensionsError::TooManyDimensions(value.len()))?;
147 Ok(Self::new(sizes))
148 }
149}
150
151impl FromIterator<usize> for Dimensions {
152 #[inline]
153 fn from_iter<T: IntoIterator<Item = usize>>(iter: T) -> Self {
154 let sizes = ArrayVec::from_iter(iter);
155 Self::new(sizes)
156 }
157}
158
159impl Serialize for Dimensions {
160 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
161 self.sizes.serialize(serializer)
162 }
163}
164
165impl<'de> Deserialize<'de> for Dimensions {
166 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
167 let sizes = Vec::deserialize(deserializer)?;
168 Ok(Self::try_from(sizes).expect("invalid dimension length"))
169 }
170}