1use thiserror::Error;
2
3#[derive(Debug, Clone, PartialEq, Eq, Error)]
5pub enum TensorError {
6 #[error("tensor size overflow for shape {shape:?}")]
7 SizeOverflow { shape: Vec<usize> },
8 #[error(
9 "tensor shape {shape:?} expects {} elements, got {data_len}",
10 shape_element_count(shape).unwrap_or(0)
11 )]
12 SizeMismatch { shape: Vec<usize>, data_len: usize },
13 #[error("shape mismatch: left={left:?}, right={right:?}")]
14 ShapeMismatch { left: Vec<usize>, right: Vec<usize> },
15 #[error("broadcast mismatch: left={left:?}, right={right:?}")]
16 BroadcastIncompatible { left: Vec<usize>, right: Vec<usize> },
17 #[error("cannot reshape from {from:?} to {to:?} due to size mismatch")]
18 ReshapeSizeMismatch { from: Vec<usize>, to: Vec<usize> },
19 #[error("axis {axis} is out of range for rank {rank}")]
20 InvalidAxis { axis: usize, rank: usize },
21 #[error("index rank mismatch: expected {expected}, got {got}")]
22 InvalidIndexRank { expected: usize, got: usize },
23 #[error("index out of bounds at axis {axis}: index={index}, dim={dim}")]
24 IndexOutOfBounds {
25 axis: usize,
26 index: usize,
27 dim: usize,
28 },
29 #[error("dtype mismatch: expected {expected:?}, got {got:?}")]
30 DTypeMismatch { expected: DType, got: DType },
31 #[error("unsupported operation: {msg}")]
32 UnsupportedOperation { msg: String },
33}
34
/// Element data type identifiers (the type carried by dtype-mismatch errors).
///
/// Fieldless and `Copy`, so values can be passed around and compared freely.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DType {
    /// 32-bit floating point; presumably IEEE 754 binary32 — confirm against storage code.
    F32,
    /// 16-bit floating point; presumably IEEE 754 binary16 (half precision).
    F16,
    /// 16-bit floating point; presumably the bfloat16 "brain float" format.
    BF16,
}
45
/// Returns the number of elements in a tensor with the given `shape`.
///
/// An empty shape (a scalar) has one element, and any zero-sized dimension makes the
/// count zero.
///
/// Returns `None` if the product of the *non-zero* dimensions overflows `usize`, even when
/// a zero-sized dimension would make the true count zero. Checking only the non-zero
/// dimensions makes the result independent of dimension order: both `[0, usize::MAX, 2]`
/// and `[usize::MAX, 2, 0]` are rejected. It also guarantees that whenever this returns
/// `Some`, the product of any subset of the dimensions fits in `usize`.
fn shape_element_count(shape: &[usize]) -> Option<usize> {
    let non_zero_product = shape
        .iter()
        .filter(|&&dim| dim != 0)
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))?;
    Some(if shape.contains(&0) { 0 } else { non_zero_product })
}