tch_plus/
error.rs

use std::ffi::NulError;
use std::io;
use std::num::ParseIntError;

use thiserror::Error;
use zip::result::ZipError;

8/// Main library error type.
9#[derive(Error, Debug)]
10pub enum TchError {
11    /// Conversion error.
12    #[error("conversion error: {0}")]
13    Convert(String),
14
15    /// Invalid file format.
16    #[error("invalid file format: {0}")]
17    FileFormat(String),
18
19    /// Missing tensor with name.
20    #[error("cannot find the tensor named {0} in {1}")]
21    TensorNameNotFound(String, String),
22
23    /// I/O error.
24    #[error(transparent)]
25    Io(#[from] io::Error),
26
27    /// Tensor kind error.
28    #[error("tensor kind error: {0}")]
29    Kind(String),
30
31    /// Missing image.
32    #[error("no image found in {0}")]
33    MissingImage(String),
34
35    /// Null pointer.
36    #[error(transparent)]
37    Nul(#[from] NulError),
38
39    /// Integer parse error.
40    #[error(transparent)]
41    ParseInt(#[from] ParseIntError),
42
43    /// Invalid shape.
44    #[error("invalid shape: {0}")]
45    Shape(String),
46
47    /// Unknown kind
48    #[error("unknown kind: {0}")]
49    UnknownKind(libc::c_int),
50
51    /// Errors returned by the Torch C++ API.
52    #[error("Internal torch error: {0}")]
53    Torch(String),
54
55    /// Zip file format error.
56    #[error(transparent)]
57    Zip(#[from] ZipError),
58
59    #[error(transparent)]
60    NdArray(#[from] ndarray::ShapeError),
61
62    /// Errors returned by the safetensors library.
63    #[error("safetensors error {path}: {err}")]
64    SafeTensorError { path: String, err: safetensors::SafeTensorError },
65}
66
67impl TchError {
68    pub fn path_context(&self, path_name: &str) -> Self {
69        match self {
70            TchError::Torch(error) => TchError::Torch(format!("{path_name}: {error}")),
71            _ => unimplemented!(),
72        }
73    }
74}