1use std::fmt;
2
3#[cfg_attr(
5 feature = "serde-support",
6 derive(serde::Serialize, serde::Deserialize)
7)]
8#[derive(Debug, Clone, PartialEq)]
9#[non_exhaustive]
10pub enum NnError {
11 ShapeMismatch {
13 expected: Vec<usize>,
14 got: Vec<usize>,
15 },
16 NoGradient,
18 InvalidParameter {
20 name: &'static str,
21 reason: &'static str,
22 },
23 EmptyInput,
25 IndexOutOfBounds { index: usize, len: usize },
27 OnnxError(String),
29 SerializeError(String),
31 CoreError(scivex_core::CoreError),
33 #[cfg(feature = "gpu")]
35 GpuError(String),
36}
37
38impl fmt::Display for NnError {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 Self::ShapeMismatch { expected, got } => {
42 write!(f, "shape mismatch: expected {expected:?}, got {got:?}")
43 }
44 Self::NoGradient => write!(f, "gradient is not available"),
45 Self::InvalidParameter { name, reason } => {
46 write!(f, "invalid parameter `{name}`: {reason}")
47 }
48 Self::OnnxError(msg) => write!(f, "onnx: {msg}"),
49 Self::SerializeError(msg) => write!(f, "serialize: {msg}"),
50 Self::EmptyInput => write!(f, "input data is empty"),
51 Self::IndexOutOfBounds { index, len } => {
52 write!(f, "index {index} out of bounds for length {len}")
53 }
54 Self::CoreError(e) => write!(f, "core: {e}"),
55 #[cfg(feature = "gpu")]
56 Self::GpuError(e) => write!(f, "gpu: {e}"),
57 }
58 }
59}
60
61impl std::error::Error for NnError {}
62
63impl From<scivex_core::CoreError> for NnError {
64 fn from(e: scivex_core::CoreError) -> Self {
65 Self::CoreError(e)
66 }
67}
68
/// Converts a `scivex_gpu::GpuError` into `NnError::GpuError`, keeping only
/// the error's `Display` text. Available with the `gpu` feature.
#[cfg(feature = "gpu")]
impl From<scivex_gpu::GpuError> for NnError {
    fn from(err: scivex_gpu::GpuError) -> Self {
        // Only the rendered message is stored, not the original error value.
        let rendered = err.to_string();
        NnError::GpuError(rendered)
    }
}
75
76pub type Result<T> = std::result::Result<T, NnError>;