tritonserver_rs/
error.rs

1use std::{
2    error::Error as ErrorExt,
3    ffi::{CStr, CString},
4    fmt, io,
5    mem::transmute,
6};
7
8use crate::sys;
9
/// Placeholder text handed back in place of a C string that is not valid UTF-8
/// (used as the fallback of `CStr::to_str` in this module).
pub(crate) const CSTR_CONVERT_ERROR_PLUG: &str = "INVALID UTF-8 STRING";
11
/// Triton server error codes.
///
/// Each variant's discriminant is taken directly from the matching
/// `TRITONSERVER_ERROR_*` constant in the `sys` bindings, so a variant can be
/// cast to `u32` and passed to the C API.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ErrorCode {
    /// Mirrors `TRITONSERVER_ERROR_UNKNOWN`.
    Unknown = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_UNKNOWN,
    /// Mirrors `TRITONSERVER_ERROR_INTERNAL`.
    Internal = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_INTERNAL,
    /// Mirrors `TRITONSERVER_ERROR_NOT_FOUND`.
    NotFound = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_NOT_FOUND,
    /// Mirrors `TRITONSERVER_ERROR_INVALID_ARG`.
    InvalidArg = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_INVALID_ARG,
    /// Mirrors `TRITONSERVER_ERROR_UNAVAILABLE`.
    Unavailable = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_UNAVAILABLE,
    /// Mirrors `TRITONSERVER_ERROR_UNSUPPORTED`.
    Unsupported = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_UNSUPPORTED,
    /// Mirrors `TRITONSERVER_ERROR_ALREADY_EXISTS`.
    ///
    // NOTE(review): the variant name is missing an `E` ("AlreadyExists").
    // It is part of the public API, so renaming it is a breaking change.
    Alreadyxists = sys::TRITONSERVER_errorcode_enum_TRITONSERVER_ERROR_ALREADY_EXISTS,
}
24
/// Triton server error.
///
/// Thin wrapper around a raw `TRITONSERVER_Error` handle.
pub struct Error {
    /// Raw handle to the underlying Triton error object.
    pub(crate) ptr: *mut sys::TRITONSERVER_Error,
    /// When `true`, dropping this value calls `TRITONSERVER_ErrorDelete` on
    /// `ptr`; when `false`, the handle is left untouched on drop.
    pub(crate) owned: bool,
}
30
/// Sharing across threads is considered safe because the `owned` flag, which
/// decides whether the handle is freed on drop, never changes after
/// construction: the field is crate-private and no public method mutates it.
///
/// NOTE(review): this also presumes the underlying Triton error object may be
/// read and deleted from any thread — confirm against the Triton C API.
unsafe impl Send for Error {}
unsafe impl Sync for Error {}
35
36impl Error {
37    /// Create new custom error.
38    pub fn new<S: AsRef<str>>(code: ErrorCode, message: S) -> Self {
39        let message = CString::new(message.as_ref()).expect("CString::new failed");
40        unsafe {
41            let this = sys::TRITONSERVER_ErrorNew(code as u32, message.as_ptr());
42            assert!(!this.is_null());
43            this.into()
44        }
45    }
46
47    /// Return ErrorCode of the error.
48    pub fn code(&self) -> ErrorCode {
49        unsafe { transmute(sys::TRITONSERVER_ErrorCode(self.ptr)) }
50    }
51
52    /// Return string representation of the ErrorCode.
53    pub fn name(&self) -> &str {
54        let ptr = unsafe { sys::TRITONSERVER_ErrorCodeString(self.ptr) };
55        if ptr.is_null() {
56            "NULL"
57        } else {
58            unsafe { CStr::from_ptr(ptr) }
59                .to_str()
60                .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
61        }
62    }
63
64    /// Return error description.
65    pub fn message(&self) -> &str {
66        let ptr = unsafe { sys::TRITONSERVER_ErrorMessage(self.ptr) };
67        if ptr.is_null() {
68            "NULL"
69        } else {
70            unsafe { CStr::from_ptr(ptr) }
71                .to_str()
72                .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
73        }
74    }
75
76    #[cfg(not(feature = "gpu"))]
77    pub(crate) fn wrong_type(mem_type: crate::memory::MemoryType) -> Self {
78        Self::new(
79            ErrorCode::InvalidArg,
80            format!("Got {mem_type:?} with gpu feature disabled"),
81        )
82    }
83}
84
85impl fmt::Debug for Error {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        write!(f, "{}: {}", self.name(), self.message())
88    }
89}
90
91impl fmt::Display for Error {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        write!(f, "{}: {}", self.name(), self.message())
94    }
95}
96
97impl From<*mut sys::TRITONSERVER_Error> for Error {
98    fn from(ptr: *mut sys::TRITONSERVER_Error) -> Self {
99        Error { ptr, owned: true }
100    }
101}
102
// Marks `Error` as a standard error type (`ErrorExt` is `std::error::Error`);
// all trait methods keep their default implementations.
impl ErrorExt for Error {}
104
105impl Drop for Error {
106    fn drop(&mut self) {
107        if self.owned && !self.ptr.is_null() {
108            unsafe { sys::TRITONSERVER_ErrorDelete(self.ptr) };
109        }
110    }
111}
112
113impl From<Error> for io::Error {
114    fn from(err: Error) -> Self {
115        io::Error::new(io::ErrorKind::Other, err.to_string())
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built error must report the code and message it was given.
    #[test]
    fn create() {
        let expected_code = ErrorCode::Unknown;
        let expected_message = "some error";

        let built = Error::new(expected_code, expected_message);

        assert_eq!(built.code(), expected_code);
        assert_eq!(built.message(), expected_message);
    }
}