1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
/// Maximum number of axes a tensor may have.
///
/// `TensorShape::new` and `TensorRank::new` reject any larger rank with
/// `TensorShapeError::RankTooLarge`.
pub const MAX_TENSOR_RANK: usize = 64;
8
9pub mod prelude {
10 pub use crate::{
11 MAX_TENSOR_RANK, TensorAxis, TensorDType, TensorDeviceKind, TensorDim, TensorLayout,
12 TensorMemoryFormat, TensorRank, TensorShape, TensorShapeError,
13 };
14}
15
/// The per-axis extents of a tensor; the number of entries is the rank.
///
/// The field is private, so values are built through `TensorShape::new`
/// (which enforces the rank limit) or `TensorShape::scalar`. The derived
/// ordering compares the extent lists lexicographically.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TensorShape {
    // One extent per axis, in axis order.
    dims: Vec<usize>,
}
20
21impl TensorShape {
22 pub fn new(dims: impl Into<Vec<usize>>) -> Result<Self, TensorShapeError> {
23 let dims = dims.into();
24 if dims.len() > MAX_TENSOR_RANK {
25 return Err(TensorShapeError::RankTooLarge {
26 rank: dims.len(),
27 max: MAX_TENSOR_RANK,
28 });
29 }
30
31 Ok(Self { dims })
32 }
33
34 pub fn scalar() -> Self {
35 Self { dims: Vec::new() }
36 }
37
38 pub fn rank(&self) -> usize {
39 self.dims.len()
40 }
41
42 pub fn dims(&self) -> &[usize] {
43 &self.dims
44 }
45
46 pub fn num_elements(&self) -> Option<usize> {
47 self.dims
48 .iter()
49 .copied()
50 .try_fold(1_usize, usize::checked_mul)
51 }
52
53 pub fn is_scalar(&self) -> bool {
54 self.rank() == 0
55 }
56
57 pub fn is_vector(&self) -> bool {
58 self.rank() == 1
59 }
60
61 pub fn is_matrix(&self) -> bool {
62 self.rank() == 2
63 }
64}
65
/// A `usize` newtype holding a single dimension value.
///
/// No validation is performed, so any `usize` (including zero) is accepted.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TensorDim(usize);

impl TensorDim {
    /// Wraps a raw dimension value.
    pub const fn new(value: usize) -> Self {
        Self(value)
    }

    /// Returns the raw dimension value.
    pub const fn get(self) -> usize {
        self.0
    }
}

impl fmt::Display for TensorDim {
    /// Formats the value exactly like the underlying `usize`, including any
    /// width, fill, alignment or sign flags on the formatter.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Name the trait explicitly. This file imports only the `fmt` module,
        // not `Display`/`Debug`, so a method call `self.0.fmt(..)` depends on
        // which formatting traits happen to be in scope. It fails to resolve
        // with none in scope and is ambiguous with several.
        fmt::Display::fmt(&self.0, formatter)
    }
}
84
85#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
86pub struct TensorRank(usize);
87
88impl TensorRank {
89 pub fn new(value: usize) -> Result<Self, TensorShapeError> {
90 if value > MAX_TENSOR_RANK {
91 Err(TensorShapeError::RankTooLarge {
92 rank: value,
93 max: MAX_TENSOR_RANK,
94 })
95 } else {
96 Ok(Self(value))
97 }
98 }
99
100 pub const fn get(self) -> usize {
101 self.0
102 }
103}
104
105#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
106pub struct TensorAxis(usize);
107
108impl TensorAxis {
109 pub fn new(value: usize, rank: TensorRank) -> Result<Self, TensorShapeError> {
110 if value < rank.get() {
111 Ok(Self(value))
112 } else {
113 Err(TensorShapeError::AxisOutOfBounds {
114 axis: value,
115 rank: rank.get(),
116 })
117 }
118 }
119
120 pub const fn unchecked(value: usize) -> Self {
121 Self(value)
122 }
123
124 pub const fn index(self) -> usize {
125 self.0
126 }
127}
128
/// Declares a fieldless metadata enum together with its canonical labels.
///
/// For `Name { Variant => "label", ... }` this generates:
/// - the enum, deriving `Clone`, `Copy`, `Debug`, `Eq`, `Hash`, `Ord`,
///   `PartialEq` and `PartialOrd` (declaration order is the sort order);
/// - `Name::ALL`, every variant in declaration order;
/// - `as_str`, returning the variant's label;
/// - `Display`, writing the label while honouring width/fill/alignment flags;
/// - `FromStr`, which normalises input with `normalized_label` and then
///   matches the labels exactly. It fails with `TensorShapeError::EmptyLabel`
///   or `TensorShapeError::UnknownLabel`.
macro_rules! tensor_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($variant),+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical label for this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl ::core::fmt::Display for $name {
            fn fmt(
                &self,
                formatter: &mut ::core::fmt::Formatter<'_>,
            ) -> ::core::fmt::Result {
                // `pad` (unlike `write_str`) applies width, fill, alignment and
                // precision, matching how `str` itself is displayed.
                formatter.pad(self.as_str())
            }
        }

        impl ::core::str::FromStr for $name {
            type Err = $crate::TensorShapeError;

            fn from_str(value: &str) -> ::core::result::Result<Self, Self::Err> {
                match $crate::normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err($crate::TensorShapeError::UnknownLabel),
                }
            }
        }
    };
}
162
// Element data types.
//
// Each `Variant => "label"` pair fixes the variant's canonical label. `as_str`
// and `Display` emit it, and `FromStr` accepts it after `normalized_label` has
// trimmed the input, ASCII-lowercased it and mapped `_`/space to `-`.
// `Unknown` is an ordinary variant spelled "unknown". Unrecognised input does
// NOT fall back to it; parsing fails with `TensorShapeError::UnknownLabel`.
// The generated enum derives `Ord`, so this declaration order is also the
// comparison order, and reordering variants changes sorting.
tensor_enum!(TensorDType {
    Bool => "bool",
    Int8 => "int8",
    Int16 => "int16",
    Int32 => "int32",
    Int64 => "int64",
    Uint8 => "uint8",
    Uint16 => "uint16",
    Uint32 => "uint32",
    Uint64 => "uint64",
    Float16 => "float16",
    BFloat16 => "bfloat16",
    Float32 => "float32",
    Float64 => "float64",
    Complex64 => "complex64",
    Complex128 => "complex128",
    String => "string",
    Unknown => "unknown",
});
182
// Storage layout labels. Multi-word labels use `-`. Because parsing maps
// `_` and spaces to `-`, "block sparse" and "block_sparse" both parse as
// `BlockSparse`. Declaration order is the derived `Ord` order.
tensor_enum!(TensorLayout {
    Dense => "dense",
    Sparse => "sparse",
    Ragged => "ragged",
    Quantized => "quantized",
    BlockSparse => "block-sparse",
    Unknown => "unknown",
});
191
// Device-kind labels. Parsing lowercases ASCII, so "GPU" and "gpu" are
// equivalent. Unrecognised names are an error rather than `Unknown`.
// Declaration order is the derived `Ord` order.
tensor_enum!(TensorDeviceKind {
    Cpu => "cpu",
    Gpu => "gpu",
    Tpu => "tpu",
    Npu => "npu",
    Metal => "metal",
    Vulkan => "vulkan",
    Wasm => "wasm",
    Unknown => "unknown",
});
202
// Memory-format labels. "channels_last" and "channels last" both parse as
// `ChannelsLast` because `_` and spaces normalise to `-`. Declaration order is
// the derived `Ord` order.
tensor_enum!(TensorMemoryFormat {
    Contiguous => "contiguous",
    ChannelsLast => "channels-last",
    ChannelsFirst => "channels-first",
    Sparse => "sparse",
    Unknown => "unknown",
});
210
/// Errors produced while building shapes, ranks or axes, or parsing labels.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorShapeError {
    /// A label was empty or contained only whitespace.
    EmptyLabel,
    /// A label matched no known variant after normalisation.
    UnknownLabel,
    /// A rank exceeded the supported maximum `max`.
    RankTooLarge { rank: usize, max: usize },
    /// An axis index was not below the rank it was checked against.
    AxisOutOfBounds { axis: usize, rank: usize },
}

impl fmt::Display for TensorShapeError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::RankTooLarge { rank, max } => {
                write!(formatter, "tensor rank {rank} exceeds maximum rank {max}")
            }
            Self::AxisOutOfBounds { axis, rank } => {
                write!(formatter, "tensor axis {axis} is outside rank {rank}")
            }
            Self::EmptyLabel => formatter.write_str("tensor metadata label cannot be empty"),
            Self::UnknownLabel => formatter.write_str("unknown tensor metadata label"),
        }
    }
}

/// Relies on the default `Error` methods; there is no underlying `source`.
impl Error for TensorShapeError {}
235
236fn normalized_label(value: &str) -> Result<String, TensorShapeError> {
237 let trimmed = value.trim();
238 if trimmed.is_empty() {
239 Err(TensorShapeError::EmptyLabel)
240 } else {
241 Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::{
        TensorAxis, TensorDType, TensorDeviceKind, TensorLayout, TensorMemoryFormat, TensorRank,
        TensorShape, TensorShapeError, MAX_TENSOR_RANK,
    };

    #[test]
    fn models_tensor_shapes() -> Result<(), TensorShapeError> {
        let scalar = TensorShape::scalar();
        let vector = TensorShape::new([3])?;
        let matrix = TensorShape::new([2, 3])?;
        let tensor = TensorShape::new([2, 3, 4])?;

        assert!(scalar.is_scalar());
        assert!(vector.is_vector());
        assert!(matrix.is_matrix());
        assert_eq!(tensor.rank(), 3);
        assert_eq!(tensor.dims(), &[2, 3, 4]);
        assert_eq!(tensor.num_elements(), Some(24));
        Ok(())
    }

    #[test]
    fn protects_element_count_overflow() -> Result<(), TensorShapeError> {
        let shape = TensorShape::new([usize::MAX, 2])?;

        assert_eq!(shape.num_elements(), None);
        Ok(())
    }

    #[test]
    fn zero_sized_axes_have_no_elements() -> Result<(), TensorShapeError> {
        let shape = TensorShape::new([4, 0, 3])?;

        assert_eq!(shape.num_elements(), Some(0));
        assert_eq!(TensorShape::scalar().num_elements(), Some(1));
        Ok(())
    }

    #[test]
    fn rejects_ranks_above_maximum() -> Result<(), TensorShapeError> {
        let at_limit = TensorShape::new(vec![1_usize; MAX_TENSOR_RANK])?;
        assert_eq!(at_limit.rank(), MAX_TENSOR_RANK);
        assert_eq!(TensorRank::new(MAX_TENSOR_RANK)?.get(), MAX_TENSOR_RANK);

        let expected = TensorShapeError::RankTooLarge {
            rank: MAX_TENSOR_RANK + 1,
            max: MAX_TENSOR_RANK,
        };
        assert_eq!(
            TensorShape::new(vec![1_usize; MAX_TENSOR_RANK + 1]),
            Err(expected)
        );
        assert_eq!(TensorRank::new(MAX_TENSOR_RANK + 1), Err(expected));
        Ok(())
    }

    #[test]
    fn validates_rank_and_axis() -> Result<(), TensorShapeError> {
        let rank = TensorRank::new(3)?;
        let axis = TensorAxis::new(2, rank)?;

        assert_eq!(rank.get(), 3);
        assert_eq!(axis.index(), 2);
        assert_eq!(
            TensorAxis::new(3, rank),
            Err(TensorShapeError::AxisOutOfBounds { axis: 3, rank: 3 })
        );
        Ok(())
    }

    #[test]
    fn rank_zero_has_no_valid_axes() -> Result<(), TensorShapeError> {
        let rank = TensorRank::new(0)?;

        assert_eq!(
            TensorAxis::new(0, rank),
            Err(TensorShapeError::AxisOutOfBounds { axis: 0, rank: 0 })
        );
        Ok(())
    }

    #[test]
    fn displays_and_parses_tensor_enums() -> Result<(), TensorShapeError> {
        assert_eq!("float32".parse::<TensorDType>()?, TensorDType::Float32);
        assert_eq!(
            "block sparse".parse::<TensorLayout>()?,
            TensorLayout::BlockSparse
        );
        assert_eq!("gpu".parse::<TensorDeviceKind>()?, TensorDeviceKind::Gpu);
        assert_eq!(
            "channels_last".parse::<TensorMemoryFormat>()?,
            TensorMemoryFormat::ChannelsLast
        );
        Ok(())
    }

    #[test]
    fn labels_round_trip_through_display() -> Result<(), TensorShapeError> {
        for dtype in [
            TensorDType::Bool,
            TensorDType::BFloat16,
            TensorDType::Complex128,
            TensorDType::Unknown,
        ] {
            assert_eq!(dtype.to_string().parse::<TensorDType>()?, dtype);
        }
        for format in [
            TensorMemoryFormat::ChannelsFirst,
            TensorMemoryFormat::Sparse,
        ] {
            assert_eq!(format.as_str().parse::<TensorMemoryFormat>()?, format);
        }
        assert_eq!(TensorLayout::BlockSparse.to_string(), "block-sparse");
        assert_eq!("  CPU ".parse::<TensorDeviceKind>()?, TensorDeviceKind::Cpu);
        Ok(())
    }

    #[test]
    fn rejects_empty_and_unknown_labels() {
        assert_eq!("".parse::<TensorDType>(), Err(TensorShapeError::EmptyLabel));
        assert_eq!("   ".parse::<TensorLayout>(), Err(TensorShapeError::EmptyLabel));
        assert_eq!(
            "float33".parse::<TensorDType>(),
            Err(TensorShapeError::UnknownLabel)
        );
        assert_eq!(
            "quantum".parse::<TensorDeviceKind>(),
            Err(TensorShapeError::UnknownLabel)
        );
    }
}