rust_ai_core/error.rs
1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tyler Zervas
3
4//! Unified error types for the rust-ai ecosystem.
5//!
6//! This module provides common error types that are shared across all rust-ai crates.
7//! Each crate can wrap these in its own error type with domain-specific variants,
8//! while still converting from `CoreError` through `From` (and therefore `?`).
9//!
10//! ## Error Hierarchy
11//!
12//! ```text
13//! CoreError
14//! ├── InvalidConfig - Configuration validation failures
15//! ├── ShapeMismatch - Tensor shape incompatibilities
16//! ├── DimensionMismatch - Dimension count mismatches
17//! ├── DeviceNotAvailable - Requested device unavailable
18//! ├── DeviceMismatch - Tensors on different devices
19//! ├── OutOfMemory - GPU/CPU memory exhausted
20//! ├── KernelError - GPU kernel launch/execution failure
21//! ├── NotImplemented - Feature not yet implemented
22//! ├── Io - File/network I/O errors
23//! └── Candle - Underlying Candle errors
24//! ```
25//!
26//! ## Crate-Specific Errors
27//!
28//! Crates should define their own error types that wrap `CoreError`:
29//!
30//! ```rust
31//! use rust_ai_core::CoreError;
32//! use thiserror::Error;
33//!
34//! #[derive(Error, Debug)]
35//! pub enum MyError {
36//! #[error("adapter not found: {0}")]
37//! AdapterNotFound(String),
38//!
39//! #[error(transparent)]
40//! Core(#[from] CoreError),
41//! }
42//! ```
43
44use thiserror::Error;
45
46/// Result type alias for rust-ai-core operations.
47pub type Result<T> = std::result::Result<T, CoreError>;
48
49/// Core errors shared across the rust-ai ecosystem.
50///
51/// These errors represent common failure modes that can occur in any crate.
52/// Domain-specific errors should wrap these variants.
53#[derive(Error, Debug)]
54#[non_exhaustive]
55pub enum CoreError {
56 /// Invalid configuration parameter.
57 ///
58 /// Raised when a configuration value is out of bounds, incompatible,
59 /// or otherwise invalid.
60 #[error("invalid configuration: {0}")]
61 InvalidConfig(String),
62
63 /// Tensor shape mismatch.
64 ///
65 /// Raised when an operation expects tensors of specific shapes but
66 /// receives tensors with incompatible shapes.
67 #[error("shape mismatch: expected {expected:?}, got {actual:?}")]
68 ShapeMismatch {
69 /// Expected shape.
70 expected: Vec<usize>,
71 /// Actual shape received.
72 actual: Vec<usize>,
73 },
74
75 /// Dimension count mismatch.
76 ///
77 /// Raised when tensors have different numbers of dimensions.
78 #[error("dimension mismatch: {message}")]
79 DimensionMismatch {
80 /// Descriptive error message.
81 message: String,
82 },
83
84 /// Requested device not available.
85 ///
86 /// Raised when attempting to use a device (e.g., CUDA:1) that doesn't
87 /// exist or isn't accessible.
88 #[error("device not available: {device}")]
89 DeviceNotAvailable {
90 /// Description of the unavailable device.
91 device: String,
92 },
93
94 /// Device mismatch between tensors.
95 ///
96 /// Raised when an operation requires tensors on the same device but
97 /// they reside on different devices.
98 #[error("device mismatch: tensors must be on the same device")]
99 DeviceMismatch,
100
101 /// Out of memory.
102 ///
103 /// Raised when GPU or CPU memory allocation fails.
104 #[error("out of memory: {message}")]
105 OutOfMemory {
106 /// Descriptive error message.
107 message: String,
108 },
109
110 /// GPU kernel error.
111 ///
112 /// Raised when a CUDA/CubeCL kernel fails to launch or execute.
113 #[error("kernel error: {message}")]
114 KernelError {
115 /// Descriptive error message.
116 message: String,
117 },
118
119 /// Feature not implemented.
120 ///
121 /// Raised when a requested feature or code path is not yet implemented.
122 #[error("not implemented: {feature}")]
123 NotImplemented {
124 /// Description of the unimplemented feature.
125 feature: String,
126 },
127
128 /// I/O error.
129 ///
130 /// Raised for file operations, network errors, serialization failures, etc.
131 #[error("I/O error: {0}")]
132 Io(String),
133
134 /// Underlying Candle error.
135 ///
136 /// Wraps errors from the Candle tensor library.
137 #[error("candle error: {0}")]
138 Candle(#[from] candle_core::Error),
139}
140
141impl CoreError {
142 /// Create an invalid configuration error.
143 pub fn invalid_config(msg: impl Into<String>) -> Self {
144 Self::InvalidConfig(msg.into())
145 }
146
147 /// Create a shape mismatch error.
148 pub fn shape_mismatch(expected: impl Into<Vec<usize>>, actual: impl Into<Vec<usize>>) -> Self {
149 Self::ShapeMismatch {
150 expected: expected.into(),
151 actual: actual.into(),
152 }
153 }
154
155 /// Create a dimension mismatch error.
156 pub fn dim_mismatch(msg: impl Into<String>) -> Self {
157 Self::DimensionMismatch {
158 message: msg.into(),
159 }
160 }
161
162 /// Create a device not available error.
163 pub fn device_not_available(device: impl Into<String>) -> Self {
164 Self::DeviceNotAvailable {
165 device: device.into(),
166 }
167 }
168
169 /// Create an out of memory error.
170 pub fn oom(msg: impl Into<String>) -> Self {
171 Self::OutOfMemory {
172 message: msg.into(),
173 }
174 }
175
176 /// Create a kernel error.
177 pub fn kernel(msg: impl Into<String>) -> Self {
178 Self::KernelError {
179 message: msg.into(),
180 }
181 }
182
183 /// Create a not implemented error.
184 pub fn not_implemented(feature: impl Into<String>) -> Self {
185 Self::NotImplemented {
186 feature: feature.into(),
187 }
188 }
189
190 /// Create an I/O error.
191 pub fn io(msg: impl Into<String>) -> Self {
192 Self::Io(msg.into())
193 }
194}
195
196impl From<std::io::Error> for CoreError {
197 fn from(err: std::io::Error) -> Self {
198 Self::Io(err.to_string())
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the exact `Display` text of the constructor-built variants.
    #[test]
    fn test_error_display() {
        let err = CoreError::invalid_config("rank must be positive");
        assert_eq!(
            err.to_string(),
            "invalid configuration: rank must be positive"
        );

        // Shapes are rendered with `Debug`, so the exact bracketed form is part
        // of the user-visible message.
        let err = CoreError::shape_mismatch(vec![2, 3], vec![3, 2]);
        assert_eq!(err.to_string(), "shape mismatch: expected [2, 3], got [3, 2]");

        let err = CoreError::device_not_available("CUDA:5");
        assert_eq!(err.to_string(), "device not available: CUDA:5");
    }

    /// Covers the variants that the original display test did not exercise.
    #[test]
    fn test_display_remaining_variants() {
        assert_eq!(
            CoreError::dim_mismatch("rank 2 vs 3").to_string(),
            "dimension mismatch: rank 2 vs 3"
        );
        assert_eq!(
            CoreError::DeviceMismatch.to_string(),
            "device mismatch: tensors must be on the same device"
        );
        assert_eq!(
            CoreError::oom("4 GiB requested").to_string(),
            "out of memory: 4 GiB requested"
        );
        assert_eq!(
            CoreError::kernel("launch failed").to_string(),
            "kernel error: launch failed"
        );
        assert_eq!(
            CoreError::not_implemented("fp8 matmul").to_string(),
            "not implemented: fp8 matmul"
        );
        assert_eq!(CoreError::io("disk full").to_string(), "I/O error: disk full");
    }

    /// `From<std::io::Error>` must yield `Io` and keep the original message.
    #[test]
    fn test_error_conversion() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let core_err: CoreError = io_err.into();
        assert!(matches!(core_err, CoreError::Io(ref msg) if msg == "file not found"));
        assert_eq!(core_err.to_string(), "I/O error: file not found");
    }

    /// `?` on a `std::io::Result` inside a function returning this module's
    /// `Result` must go through the `From` conversion.
    #[test]
    fn test_question_mark_converts_io_error() {
        fn fails() -> Result<()> {
            let io_result: std::io::Result<()> =
                Err(std::io::Error::new(std::io::ErrorKind::Other, "boom"));
            io_result?;
            Ok(())
        }
        assert!(matches!(fails(), Err(CoreError::Io(ref msg)) if msg == "boom"));
    }
}