Skip to main content

rust_ai_core/
error.rs

1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tyler Zervas
3
4//! Unified error types for the rust-ai ecosystem.
5//!
6//! This module provides common error types that are shared across all rust-ai crates.
7//! Each crate can extend these with domain-specific variants while maintaining
8//! compatibility for error conversion.
9//!
10//! ## Error Hierarchy
11//!
12//! ```text
13//! CoreError
14//! ├── InvalidConfig       - Configuration validation failures
15//! ├── ShapeMismatch       - Tensor shape incompatibilities
16//! ├── DimensionMismatch   - Dimension count mismatches
17//! ├── DeviceNotAvailable  - Requested device unavailable
18//! ├── DeviceMismatch      - Tensors on different devices
19//! ├── OutOfMemory         - GPU/CPU memory exhausted
20//! ├── KernelError         - GPU kernel launch/execution failure
21//! ├── NotImplemented      - Feature not yet implemented
22//! ├── Io                  - File/network I/O errors
23//! └── Candle              - Underlying Candle errors
24//! ```
25//!
26//! ## Crate-Specific Errors
27//!
28//! Crates should define their own error types that wrap `CoreError`:
29//!
30//! ```rust
31//! use rust_ai_core::CoreError;
32//! use thiserror::Error;
33//!
34//! #[derive(Error, Debug)]
35//! pub enum MyError {
36//!     #[error("adapter not found: {0}")]
37//!     AdapterNotFound(String),
38//!     
39//!     #[error(transparent)]
40//!     Core(#[from] CoreError),
41//! }
42//! ```
43
44use thiserror::Error;
45
46/// Result type alias for rust-ai-core operations.
47pub type Result<T> = std::result::Result<T, CoreError>;
48
49/// Core errors shared across the rust-ai ecosystem.
50///
51/// These errors represent common failure modes that can occur in any crate.
52/// Domain-specific errors should wrap these variants.
53#[derive(Error, Debug)]
54#[non_exhaustive]
55pub enum CoreError {
56    /// Invalid configuration parameter.
57    ///
58    /// Raised when a configuration value is out of bounds, incompatible,
59    /// or otherwise invalid.
60    #[error("invalid configuration: {0}")]
61    InvalidConfig(String),
62
63    /// Tensor shape mismatch.
64    ///
65    /// Raised when an operation expects tensors of specific shapes but
66    /// receives tensors with incompatible shapes.
67    #[error("shape mismatch: expected {expected:?}, got {actual:?}")]
68    ShapeMismatch {
69        /// Expected shape.
70        expected: Vec<usize>,
71        /// Actual shape received.
72        actual: Vec<usize>,
73    },
74
75    /// Dimension count mismatch.
76    ///
77    /// Raised when tensors have different numbers of dimensions.
78    #[error("dimension mismatch: {message}")]
79    DimensionMismatch {
80        /// Descriptive error message.
81        message: String,
82    },
83
84    /// Requested device not available.
85    ///
86    /// Raised when attempting to use a device (e.g., CUDA:1) that doesn't
87    /// exist or isn't accessible.
88    #[error("device not available: {device}")]
89    DeviceNotAvailable {
90        /// Description of the unavailable device.
91        device: String,
92    },
93
94    /// Device mismatch between tensors.
95    ///
96    /// Raised when an operation requires tensors on the same device but
97    /// they reside on different devices.
98    #[error("device mismatch: tensors must be on the same device")]
99    DeviceMismatch,
100
101    /// Out of memory.
102    ///
103    /// Raised when GPU or CPU memory allocation fails.
104    #[error("out of memory: {message}")]
105    OutOfMemory {
106        /// Descriptive error message.
107        message: String,
108    },
109
110    /// GPU kernel error.
111    ///
112    /// Raised when a CUDA/CubeCL kernel fails to launch or execute.
113    #[error("kernel error: {message}")]
114    KernelError {
115        /// Descriptive error message.
116        message: String,
117    },
118
119    /// Feature not implemented.
120    ///
121    /// Raised when a requested feature or code path is not yet implemented.
122    #[error("not implemented: {feature}")]
123    NotImplemented {
124        /// Description of the unimplemented feature.
125        feature: String,
126    },
127
128    /// I/O error.
129    ///
130    /// Raised for file operations, network errors, serialization failures, etc.
131    #[error("I/O error: {0}")]
132    Io(String),
133
134    /// Underlying Candle error.
135    ///
136    /// Wraps errors from the Candle tensor library.
137    #[error("candle error: {0}")]
138    Candle(#[from] candle_core::Error),
139}
140
141impl CoreError {
142    /// Create an invalid configuration error.
143    pub fn invalid_config(msg: impl Into<String>) -> Self {
144        Self::InvalidConfig(msg.into())
145    }
146
147    /// Create a shape mismatch error.
148    pub fn shape_mismatch(expected: impl Into<Vec<usize>>, actual: impl Into<Vec<usize>>) -> Self {
149        Self::ShapeMismatch {
150            expected: expected.into(),
151            actual: actual.into(),
152        }
153    }
154
155    /// Create a dimension mismatch error.
156    pub fn dim_mismatch(msg: impl Into<String>) -> Self {
157        Self::DimensionMismatch {
158            message: msg.into(),
159        }
160    }
161
162    /// Create a device not available error.
163    pub fn device_not_available(device: impl Into<String>) -> Self {
164        Self::DeviceNotAvailable {
165            device: device.into(),
166        }
167    }
168
169    /// Create an out of memory error.
170    pub fn oom(msg: impl Into<String>) -> Self {
171        Self::OutOfMemory {
172            message: msg.into(),
173        }
174    }
175
176    /// Create a kernel error.
177    pub fn kernel(msg: impl Into<String>) -> Self {
178        Self::KernelError {
179            message: msg.into(),
180        }
181    }
182
183    /// Create a not implemented error.
184    pub fn not_implemented(feature: impl Into<String>) -> Self {
185        Self::NotImplemented {
186            feature: feature.into(),
187        }
188    }
189
190    /// Create an I/O error.
191    pub fn io(msg: impl Into<String>) -> Self {
192        Self::Io(msg.into())
193    }
194}
195
196impl From<std::io::Error> for CoreError {
197    fn from(err: std::io::Error) -> Self {
198        Self::Io(err.to_string())
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::CoreError;
    use std::io;

    /// Displayed errors carry the variant prefix and the supplied details.
    #[test]
    fn test_error_display() {
        let rendered = CoreError::invalid_config("rank must be positive").to_string();
        assert_eq!(rendered, "invalid configuration: rank must be positive");

        let rendered = CoreError::shape_mismatch(vec![2, 3], vec![3, 2]).to_string();
        assert!(rendered.contains("shape mismatch"));

        let rendered = CoreError::device_not_available("CUDA:5").to_string();
        assert!(rendered.contains("CUDA:5"));
    }

    /// A `std::io::Error` converts into the `Io` variant.
    #[test]
    fn test_error_conversion() {
        let not_found = io::Error::new(io::ErrorKind::NotFound, "file not found");
        let converted = CoreError::from(not_found);
        assert!(matches!(converted, CoreError::Io(_)));
    }
}