Skip to main content

sp1_gpu_cudart/
error.rs

1use core::fmt;
2use std::ffi::CStr;
3
4use sp1_gpu_sys::runtime::{
5    CudaRustError, CUDA_ERROR_NOT_READY_SLOP, CUDA_OUT_OF_MEMORY, CUDA_SUCCESS_CSL,
6};
7use thiserror::Error;
8
9#[derive(Clone, Copy, PartialEq, Eq)]
10pub struct OtherError(CudaRustError);
11
12#[derive(Clone, Debug, Copy, PartialEq, Eq, Error)]
13pub enum CudaError {
14    #[error("out of GPU memory")]
15    OutOfMemory,
16    #[error("not ready")]
17    NotReady,
18    #[error("other CUDA error: {0}")]
19    Other(#[from] OtherError),
20}
21
22unsafe impl Send for CudaError {}
23unsafe impl Sync for CudaError {}
24
25impl CudaError {
26    /// Get a result from a [CudaRustError].
27    ///
28    /// The [CudaRustError] is the FFI representation of the cuda runtime result enum which could
29    /// signal a success or an error. In case of success, this method returns `Ok(())`. In case of
30    /// an error, this method returns an error of the appropriate type.
31    #[inline]
32    pub fn result_from_ffi(maybe_error: CudaRustError) -> Result<(), Self> {
33        // # Safety
34        // These constants are well defined in the sys crate.
35        unsafe {
36            match maybe_error {
37                e if e == CUDA_SUCCESS_CSL => Ok(()),
38                e if e == CUDA_OUT_OF_MEMORY => Err(Self::OutOfMemory),
39                e if e == CUDA_ERROR_NOT_READY_SLOP => Err(Self::NotReady),
40                _ => Err(Self::Other(OtherError(maybe_error))),
41            }
42        }
43    }
44}
45
46impl fmt::Debug for OtherError {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        // # Safety
49        // This is safe because the error came from a well formed CudaRustError type.
50        let message = unsafe { CStr::from_ptr(self.0.message).to_str().map_err(|_| fmt::Error)? };
51        write!(f, "CudaRustError: {message}")
52    }
53}
54
55impl fmt::Display for OtherError {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        write!(f, "{self:?}")
58    }
59}
60
61impl core::error::Error for OtherError {}