// use_ml_tensor/lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
/// Largest tensor rank (number of dimensions) accepted by this crate.
///
/// [`TensorShape::new`] and [`TensorRank::new`] reject any rank above this
/// value with [`TensorShapeError::RankTooLarge`].
pub const MAX_TENSOR_RANK: usize = 64;
8
9pub mod prelude {
10    pub use crate::{
11        MAX_TENSOR_RANK, TensorAxis, TensorDType, TensorDeviceKind, TensorDim, TensorLayout,
12        TensorMemoryFormat, TensorRank, TensorShape, TensorShapeError,
13    };
14}
15
16#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
17pub struct TensorShape {
18    dims: Vec<usize>,
19}
20
21impl TensorShape {
22    pub fn new(dims: impl Into<Vec<usize>>) -> Result<Self, TensorShapeError> {
23        let dims = dims.into();
24        if dims.len() > MAX_TENSOR_RANK {
25            return Err(TensorShapeError::RankTooLarge {
26                rank: dims.len(),
27                max: MAX_TENSOR_RANK,
28            });
29        }
30
31        Ok(Self { dims })
32    }
33
34    pub fn scalar() -> Self {
35        Self { dims: Vec::new() }
36    }
37
38    pub fn rank(&self) -> usize {
39        self.dims.len()
40    }
41
42    pub fn dims(&self) -> &[usize] {
43        &self.dims
44    }
45
46    pub fn num_elements(&self) -> Option<usize> {
47        self.dims
48            .iter()
49            .copied()
50            .try_fold(1_usize, usize::checked_mul)
51    }
52
53    pub fn is_scalar(&self) -> bool {
54        self.rank() == 0
55    }
56
57    pub fn is_vector(&self) -> bool {
58        self.rank() == 1
59    }
60
61    pub fn is_matrix(&self) -> bool {
62        self.rank() == 2
63    }
64}
65
/// A single dimension extent wrapped in its own type.
///
/// No validation is performed; any `usize` is accepted.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TensorDim(usize);

impl TensorDim {
    /// Wraps a raw extent.
    pub const fn new(value: usize) -> Self {
        TensorDim(value)
    }

    /// Returns the wrapped extent.
    pub const fn get(self) -> usize {
        let TensorDim(extent) = self;
        extent
    }
}

impl fmt::Display for TensorDim {
    /// Formats exactly like the inner `usize`, including width/fill/alignment flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
84
85#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
86pub struct TensorRank(usize);
87
88impl TensorRank {
89    pub fn new(value: usize) -> Result<Self, TensorShapeError> {
90        if value > MAX_TENSOR_RANK {
91            Err(TensorShapeError::RankTooLarge {
92                rank: value,
93                max: MAX_TENSOR_RANK,
94            })
95        } else {
96            Ok(Self(value))
97        }
98    }
99
100    pub const fn get(self) -> usize {
101        self.0
102    }
103}
104
105#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
106pub struct TensorAxis(usize);
107
108impl TensorAxis {
109    pub fn new(value: usize, rank: TensorRank) -> Result<Self, TensorShapeError> {
110        if value < rank.get() {
111            Ok(Self(value))
112        } else {
113            Err(TensorShapeError::AxisOutOfBounds {
114                axis: value,
115                rank: rank.get(),
116            })
117        }
118    }
119
120    pub const fn unchecked(value: usize) -> Self {
121        Self(value)
122    }
123
124    pub const fn index(self) -> usize {
125        self.0
126    }
127}
128
/// Generates a label-backed tensor metadata enum.
///
/// `tensor_enum!(Name { Variant => "label", ... })` expands to:
/// * `pub enum Name` with the listed variants, deriving `Clone`, `Copy`,
///   `Debug`, `Eq`, `Hash`, `Ord`, `PartialEq` and `PartialOrd` (so the
///   declaration order fixes the ordering);
/// * `Name::as_str`, returning each variant's label;
/// * `Display`, writing the label;
/// * `FromStr`, matching input against the labels after `normalized_label`.
macro_rules! tensor_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        /// A tensor metadata enum whose variants each carry a string label,
        /// returned by `as_str`, written by `Display` and accepted by `FromStr`.
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $(
                #[doc = concat!("Labelled `", $label, "`.")]
                $variant
            ),+
        }

        impl $name {
            /// Returns the canonical label for this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                // `pad` (unlike `write_str`) honours width, fill, alignment and
                // precision, so `{:<10}` lines up the same way it does for
                // `TensorDim` and plain numbers. Plain `{}` output is unchanged.
                formatter.pad(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = TensorShapeError;

            /// Parses a label after normalizing it with `normalized_label`
            /// (trimmed, ASCII-lowercased, separators mapped to `-`).
            ///
            /// # Errors
            ///
            /// Returns `TensorShapeError::EmptyLabel` for blank input and
            /// `TensorShapeError::UnknownLabel` when no label matches.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(TensorShapeError::UnknownLabel),
                }
            }
        }
    };
}
162
163tensor_enum!(TensorDType {
164    Bool => "bool",
165    Int8 => "int8",
166    Int16 => "int16",
167    Int32 => "int32",
168    Int64 => "int64",
169    Uint8 => "uint8",
170    Uint16 => "uint16",
171    Uint32 => "uint32",
172    Uint64 => "uint64",
173    Float16 => "float16",
174    BFloat16 => "bfloat16",
175    Float32 => "float32",
176    Float64 => "float64",
177    Complex64 => "complex64",
178    Complex128 => "complex128",
179    String => "string",
180    Unknown => "unknown",
181});
182
183tensor_enum!(TensorLayout {
184    Dense => "dense",
185    Sparse => "sparse",
186    Ragged => "ragged",
187    Quantized => "quantized",
188    BlockSparse => "block-sparse",
189    Unknown => "unknown",
190});
191
192tensor_enum!(TensorDeviceKind {
193    Cpu => "cpu",
194    Gpu => "gpu",
195    Tpu => "tpu",
196    Npu => "npu",
197    Metal => "metal",
198    Vulkan => "vulkan",
199    Wasm => "wasm",
200    Unknown => "unknown",
201});
202
203tensor_enum!(TensorMemoryFormat {
204    Contiguous => "contiguous",
205    ChannelsLast => "channels-last",
206    ChannelsFirst => "channels-first",
207    Sparse => "sparse",
208    Unknown => "unknown",
209});
210
/// Errors raised while validating shapes, ranks and axes, or while parsing
/// metadata labels.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TensorShapeError {
    /// A label was empty or contained only whitespace.
    EmptyLabel,
    /// A label did not match any known variant.
    UnknownLabel,
    /// A rank exceeded the supported maximum.
    RankTooLarge {
        /// The rejected rank.
        rank: usize,
        /// The largest rank allowed.
        max: usize,
    },
    /// An axis index was not strictly smaller than the rank.
    AxisOutOfBounds {
        /// The rejected axis index.
        axis: usize,
        /// The rank the axis was checked against.
        rank: usize,
    },
}

impl fmt::Display for TensorShapeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::RankTooLarge { rank, max } => {
                write!(f, "tensor rank {rank} exceeds maximum rank {max}")
            },
            Self::AxisOutOfBounds { axis, rank } => {
                write!(f, "tensor axis {axis} is outside rank {rank}")
            },
            Self::EmptyLabel => f.write_str("tensor metadata label cannot be empty"),
            Self::UnknownLabel => f.write_str("unknown tensor metadata label"),
        }
    }
}

impl Error for TensorShapeError {}
235
236fn normalized_label(value: &str) -> Result<String, TensorShapeError> {
237    let trimmed = value.trim();
238    if trimmed.is_empty() {
239        Err(TensorShapeError::EmptyLabel)
240    } else {
241        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
242    }
243}
244
#[cfg(test)]
mod tests {
    // Unit tests for shape construction, rank/axis validation, label parsing
    // and error rendering.
    use super::{
        MAX_TENSOR_RANK, TensorAxis, TensorDType, TensorDeviceKind, TensorDim, TensorLayout,
        TensorMemoryFormat, TensorRank, TensorShape, TensorShapeError,
    };

    #[test]
    fn models_tensor_shapes() -> Result<(), TensorShapeError> {
        let scalar = TensorShape::scalar();
        let vector = TensorShape::new([3])?;
        let matrix = TensorShape::new([2, 3])?;
        let tensor = TensorShape::new([2, 3, 4])?;

        assert!(scalar.is_scalar());
        assert!(vector.is_vector());
        assert!(matrix.is_matrix());
        assert_eq!(tensor.rank(), 3);
        assert_eq!(tensor.dims(), &[2, 3, 4]);
        assert_eq!(tensor.num_elements(), Some(24));
        Ok(())
    }

    #[test]
    fn protects_element_count_overflow() -> Result<(), TensorShapeError> {
        let shape = TensorShape::new([usize::MAX, 2])?;

        assert_eq!(shape.num_elements(), None);
        Ok(())
    }

    #[test]
    fn counts_elements_of_scalar_and_zero_sized_shapes() -> Result<(), TensorShapeError> {
        // The empty product is 1; any zero extent empties the tensor.
        assert_eq!(TensorShape::scalar().num_elements(), Some(1));
        assert_eq!(TensorShape::new([4, 0, 7])?.num_elements(), Some(0));
        Ok(())
    }

    #[test]
    fn rejects_ranks_above_maximum() {
        let over = MAX_TENSOR_RANK + 1;
        let expected = TensorShapeError::RankTooLarge {
            rank: over,
            max: MAX_TENSOR_RANK,
        };

        assert!(TensorShape::new(vec![1; MAX_TENSOR_RANK]).is_ok());
        assert_eq!(TensorShape::new(vec![1; over]), Err(expected));
        assert!(TensorRank::new(MAX_TENSOR_RANK).is_ok());
        assert_eq!(TensorRank::new(over), Err(expected));
    }

    #[test]
    fn validates_rank_and_axis() -> Result<(), TensorShapeError> {
        let rank = TensorRank::new(3)?;
        let axis = TensorAxis::new(2, rank)?;

        assert_eq!(rank.get(), 3);
        assert_eq!(axis.index(), 2);
        assert_eq!(
            TensorAxis::new(3, rank),
            Err(TensorShapeError::AxisOutOfBounds { axis: 3, rank: 3 })
        );
        Ok(())
    }

    #[test]
    fn unchecked_axis_skips_validation() {
        assert_eq!(TensorAxis::unchecked(99).index(), 99);
    }

    #[test]
    fn wraps_dimension_extents() {
        let dim = TensorDim::new(5);

        assert_eq!(dim.get(), 5);
        assert_eq!(dim.to_string(), "5");
    }

    #[test]
    fn displays_and_parses_tensor_enums() -> Result<(), TensorShapeError> {
        assert_eq!("float32".parse::<TensorDType>()?, TensorDType::Float32);
        assert_eq!(
            "block sparse".parse::<TensorLayout>()?,
            TensorLayout::BlockSparse
        );
        assert_eq!("gpu".parse::<TensorDeviceKind>()?, TensorDeviceKind::Gpu);
        assert_eq!(
            "channels_last".parse::<TensorMemoryFormat>()?,
            TensorMemoryFormat::ChannelsLast
        );
        Ok(())
    }

    #[test]
    fn display_round_trips_through_from_str() -> Result<(), TensorShapeError> {
        for dtype in [
            TensorDType::Bool,
            TensorDType::BFloat16,
            TensorDType::Complex128,
            TensorDType::Unknown,
        ] {
            assert_eq!(dtype.to_string().parse::<TensorDType>()?, dtype);
        }
        assert_eq!(TensorLayout::BlockSparse.to_string(), "block-sparse");
        assert_eq!(
            TensorMemoryFormat::ChannelsFirst
                .to_string()
                .parse::<TensorMemoryFormat>()?,
            TensorMemoryFormat::ChannelsFirst
        );
        Ok(())
    }

    #[test]
    fn rejects_empty_and_unknown_labels() {
        assert_eq!("".parse::<TensorDType>(), Err(TensorShapeError::EmptyLabel));
        assert_eq!("   ".parse::<TensorLayout>(), Err(TensorShapeError::EmptyLabel));
        assert_eq!(
            "float128".parse::<TensorDType>(),
            Err(TensorShapeError::UnknownLabel)
        );
        assert_eq!(
            "  CPU  ".parse::<TensorDeviceKind>(),
            Ok(TensorDeviceKind::Cpu)
        );
    }

    #[test]
    fn describes_errors() {
        assert_eq!(
            TensorShapeError::RankTooLarge { rank: 65, max: 64 }.to_string(),
            "tensor rank 65 exceeds maximum rank 64"
        );
        assert_eq!(
            TensorShapeError::AxisOutOfBounds { axis: 3, rank: 3 }.to_string(),
            "tensor axis 3 is outside rank 3"
        );
    }
}