1use std::fmt;
5
6pub type RusTorchResult<T> = crate::error::RusTorchResult<T>;
9
10#[derive(Debug)]
13pub enum RusTorchError {
14 TensorError(TensorError),
17
18 GpuError(GpuError),
21
22 DistributedError(DistributedError),
25
26 NeuralNetworkError(NeuralNetworkError),
29
30 OptimizationError(OptimizationError),
33
34 DataError(DataError),
37
38 MemoryError(MemoryError),
41
42 IoError(std::io::Error),
45
46 Generic(String),
49}
50
/// Errors raised by tensor construction, indexing, and shape-dependent operations.
///
/// `PartialEq`/`Eq` are derived (all payloads support them) so callers and tests
/// can compare errors directly instead of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorError {
    /// The tensor's shape differs from the one the operation expected.
    ShapeMismatch {
        /// Shape the operation expected.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        actual: Vec<usize>,
    },
    /// The shapes of two operands are incompatible.
    DimensionMismatch {
        /// Shape of the left-hand operand.
        lhs: Vec<usize>,
        /// Shape of the right-hand operand.
        rhs: Vec<usize>,
    },
    /// The tensor has fewer dimensions than the operation requires.
    InsufficientDimensions {
        /// Number of dimensions required.
        required: usize,
        /// Number of dimensions actually present.
        actual: usize,
    },
    /// The given shape is not valid.
    InvalidShape(Vec<usize>),
    /// The given index is not valid.
    InvalidIndex(Vec<usize>),
    /// The requested operation is not valid; the string describes it.
    InvalidOperation(String),
    /// An operation was attempted on an empty tensor.
    EmptyTensor,
    /// A data-type problem, described by the string.
    DataTypeError(String),
}
101
/// Errors raised by GPU device discovery, memory management, and kernel execution.
///
/// `PartialEq`/`Eq` are derived so errors can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GpuError {
    /// No GPU device with the given id was found.
    DeviceNotFound(usize),
    /// The named device is not supported.
    DeviceNotSupported(String),
    /// A GPU allocation of the given number of bytes failed.
    MemoryAllocationFailed(usize),
    /// A host/device memory transfer failed; the string describes why.
    MemoryTransferFailed(String),
    /// A kernel failed to compile; the string carries the details.
    KernelCompilationFailed(String),
    /// A kernel failed while executing; the string carries the details.
    KernelExecutionFailed(String),
    /// A GPU context could not be created; the string carries the details.
    ContextCreationFailed(String),
    /// The named device is invalid.
    InvalidDevice(String),
    /// The GPU ran out of memory.
    OutOfMemory,
    /// An error reported by the GPU driver.
    DriverError(String),
}
137
/// Errors raised by distributed backends, process groups, and inter-node communication.
///
/// `PartialEq`/`Eq` are derived so errors can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DistributedError {
    /// The named distributed backend is not supported.
    BackendNotSupported(String),
    /// A communication step failed; the string carries the details.
    CommunicationFailed(String),
    /// A process-group operation failed.
    ProcessGroupError(String),
    /// Synchronization between participants failed.
    SynchronizationFailed(String),
    /// Connecting to a node failed.
    NodeConnectionFailed(String),
    /// The given rank is invalid.
    InvalidRank(i32),
    /// The given world size is invalid.
    InvalidWorldSize(i32),
    /// An operation timed out.
    TimeoutError(String),
    /// A network-level failure.
    NetworkError(String),
}
170
/// Errors raised by neural-network layers, activations, losses, and models.
///
/// `PartialEq`/`Eq` are derived so errors can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NeuralNetworkError {
    /// A layer-level failure.
    LayerError(String),
    /// An activation-function failure.
    ActivationError(String),
    /// A loss-function failure.
    LossError(String),
    /// A failure during the forward pass.
    ForwardPassError(String),
    /// A failure during the backward pass.
    BackwardPassError(String),
    /// A problem with model parameters.
    ParameterError(String),
    /// A model-level failure.
    ModelError(String),
}
197
/// Errors raised by optimizers, learning-rate schedulers, and gradient handling.
///
/// `PartialEq`/`Eq` are derived so errors can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OptimizationError {
    /// An optimizer failure.
    OptimizerError(String),
    /// A learning-rate scheduler failure.
    SchedulerError(String),
    /// A gradient-related failure.
    GradientError(String),
    /// A convergence failure.
    ConvergenceError(String),
    /// An invalid or problematic learning rate.
    LearningRateError(String),
}
218
/// Errors raised by datasets, data loaders, batching, transforms, and file access.
///
/// `PartialEq`/`Eq` are derived so errors can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DataError {
    /// A dataset failure.
    DatasetError(String),
    /// A data-loader failure.
    DataLoaderError(String),
    /// A batching failure.
    BatchError(String),
    /// A data-transform failure.
    TransformError(String),
    /// A file-related failure.
    FileError(String),
}
239
/// Errors raised by memory allocation, pooling, and pointer handling.
///
/// `PartialEq`/`Eq` are derived so errors can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MemoryError {
    /// An allocation of the given number of bytes failed.
    AllocationFailed(usize),
    /// A deallocation failed; the string carries the details.
    DeallocationFailed(String),
    /// An alignment problem; the value is an alignment in bytes.
    AlignmentError(usize),
    /// The memory pool has no capacity left.
    PoolExhausted,
    /// A pointer was found to be invalid.
    InvalidPointer,
    /// A memory leak was detected; the string carries the details.
    MemoryLeak(String),
}
263
264impl fmt::Display for RusTorchError {
265 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266 match self {
267 RusTorchError::TensorError(e) => write!(f, "Tensor error: {}", e),
268 RusTorchError::GpuError(e) => write!(f, "GPU error: {}", e),
269 RusTorchError::DistributedError(e) => write!(f, "Distributed error: {}", e),
270 RusTorchError::NeuralNetworkError(e) => write!(f, "Neural network error: {}", e),
271 RusTorchError::OptimizationError(e) => write!(f, "Optimization error: {}", e),
272 RusTorchError::DataError(e) => write!(f, "Data error: {}", e),
273 RusTorchError::MemoryError(e) => write!(f, "Memory error: {}", e),
274 RusTorchError::IoError(e) => write!(f, "I/O error: {}", e),
275 RusTorchError::Generic(msg) => write!(f, "Error: {}", msg),
276 }
277 }
278}
279
280impl fmt::Display for TensorError {
281 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282 match self {
283 TensorError::ShapeMismatch { expected, actual } => {
284 write!(
285 f,
286 "Shape mismatch: expected {:?}, got {:?}",
287 expected, actual
288 )
289 }
290 TensorError::DimensionMismatch { lhs, rhs } => {
291 write!(f, "Dimension mismatch: {:?} vs {:?}", lhs, rhs)
292 }
293 TensorError::InsufficientDimensions { required, actual } => {
294 write!(
295 f,
296 "Insufficient dimensions: required {}, got {}",
297 required, actual
298 )
299 }
300 TensorError::InvalidShape(shape) => write!(f, "Invalid shape: {:?}", shape),
301 TensorError::InvalidIndex(index) => write!(f, "Invalid index: {:?}", index),
302 TensorError::InvalidOperation(op) => write!(f, "Invalid operation: {}", op),
303 TensorError::EmptyTensor => write!(f, "Operation on empty tensor"),
304 TensorError::DataTypeError(msg) => write!(f, "Data type error: {}", msg),
305 }
306 }
307}
308
309impl fmt::Display for GpuError {
310 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
311 match self {
312 GpuError::DeviceNotFound(id) => write!(f, "GPU device {} not found", id),
313 GpuError::DeviceNotSupported(device) => {
314 write!(f, "GPU device not supported: {}", device)
315 }
316 GpuError::MemoryAllocationFailed(size) => {
317 write!(f, "GPU memory allocation failed: {} bytes", size)
318 }
319 GpuError::MemoryTransferFailed(msg) => write!(f, "GPU memory transfer failed: {}", msg),
320 GpuError::KernelCompilationFailed(msg) => {
321 write!(f, "GPU kernel compilation failed: {}", msg)
322 }
323 GpuError::KernelExecutionFailed(msg) => {
324 write!(f, "GPU kernel execution failed: {}", msg)
325 }
326 GpuError::ContextCreationFailed(msg) => {
327 write!(f, "GPU context creation failed: {}", msg)
328 }
329 GpuError::InvalidDevice(device) => write!(f, "Invalid GPU device: {}", device),
330 GpuError::OutOfMemory => write!(f, "GPU out of memory"),
331 GpuError::DriverError(msg) => write!(f, "GPU driver error: {}", msg),
332 }
333 }
334}
335
336impl fmt::Display for DistributedError {
337 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
338 match self {
339 DistributedError::BackendNotSupported(backend) => {
340 write!(f, "Distributed backend not supported: {}", backend)
341 }
342 DistributedError::CommunicationFailed(msg) => {
343 write!(f, "Distributed communication failed: {}", msg)
344 }
345 DistributedError::ProcessGroupError(msg) => {
346 write!(f, "Process group error: {}", msg)
347 }
348 DistributedError::SynchronizationFailed(msg) => {
349 write!(f, "Synchronization failed: {}", msg)
350 }
351 DistributedError::NodeConnectionFailed(msg) => {
352 write!(f, "Node connection failed: {}", msg)
353 }
354 DistributedError::InvalidRank(rank) => write!(f, "Invalid rank: {}", rank),
355 DistributedError::InvalidWorldSize(size) => write!(f, "Invalid world size: {}", size),
356 DistributedError::TimeoutError(msg) => write!(f, "Timeout error: {}", msg),
357 DistributedError::NetworkError(msg) => write!(f, "Network error: {}", msg),
358 }
359 }
360}
361
362impl fmt::Display for NeuralNetworkError {
363 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
364 match self {
365 NeuralNetworkError::LayerError(msg) => write!(f, "Layer error: {}", msg),
366 NeuralNetworkError::ActivationError(msg) => write!(f, "Activation error: {}", msg),
367 NeuralNetworkError::LossError(msg) => write!(f, "Loss error: {}", msg),
368 NeuralNetworkError::ForwardPassError(msg) => write!(f, "Forward pass error: {}", msg),
369 NeuralNetworkError::BackwardPassError(msg) => write!(f, "Backward pass error: {}", msg),
370 NeuralNetworkError::ParameterError(msg) => write!(f, "Parameter error: {}", msg),
371 NeuralNetworkError::ModelError(msg) => write!(f, "Model error: {}", msg),
372 }
373 }
374}
375
376impl fmt::Display for OptimizationError {
377 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
378 match self {
379 OptimizationError::OptimizerError(msg) => write!(f, "Optimizer error: {}", msg),
380 OptimizationError::SchedulerError(msg) => write!(f, "Scheduler error: {}", msg),
381 OptimizationError::GradientError(msg) => write!(f, "Gradient error: {}", msg),
382 OptimizationError::ConvergenceError(msg) => write!(f, "Convergence error: {}", msg),
383 OptimizationError::LearningRateError(msg) => write!(f, "Learning rate error: {}", msg),
384 }
385 }
386}
387
388impl fmt::Display for DataError {
389 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
390 match self {
391 DataError::DatasetError(msg) => write!(f, "Dataset error: {}", msg),
392 DataError::DataLoaderError(msg) => write!(f, "DataLoader error: {}", msg),
393 DataError::BatchError(msg) => write!(f, "Batch error: {}", msg),
394 DataError::TransformError(msg) => write!(f, "Transform error: {}", msg),
395 DataError::FileError(msg) => write!(f, "File error: {}", msg),
396 }
397 }
398}
399
400impl fmt::Display for MemoryError {
401 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
402 match self {
403 MemoryError::AllocationFailed(size) => {
404 write!(f, "Memory allocation failed: {} bytes", size)
405 }
406 MemoryError::DeallocationFailed(msg) => {
407 write!(f, "Memory deallocation failed: {}", msg)
408 }
409 MemoryError::AlignmentError(alignment) => {
410 write!(f, "Memory alignment error: {} bytes", alignment)
411 }
412 MemoryError::PoolExhausted => write!(f, "Memory pool exhausted"),
413 MemoryError::InvalidPointer => write!(f, "Invalid memory pointer"),
414 MemoryError::MemoryLeak(msg) => write!(f, "Memory leak detected: {}", msg),
415 }
416 }
417}
418
419impl std::error::Error for RusTorchError {}
420impl std::error::Error for TensorError {}
421impl std::error::Error for GpuError {}
422impl std::error::Error for DistributedError {}
423impl std::error::Error for NeuralNetworkError {}
424impl std::error::Error for OptimizationError {}
425impl std::error::Error for DataError {}
426impl std::error::Error for MemoryError {}
427
428impl From<std::io::Error> for RusTorchError {
430 fn from(err: std::io::Error) -> Self {
431 RusTorchError::IoError(err)
432 }
433}
434
435impl From<TensorError> for RusTorchError {
436 fn from(err: TensorError) -> Self {
437 RusTorchError::TensorError(err)
438 }
439}
440
441impl From<GpuError> for RusTorchError {
442 fn from(err: GpuError) -> Self {
443 RusTorchError::GpuError(err)
444 }
445}
446
447impl From<DistributedError> for RusTorchError {
448 fn from(err: DistributedError) -> Self {
449 RusTorchError::DistributedError(err)
450 }
451}
452
453impl From<NeuralNetworkError> for RusTorchError {
454 fn from(err: NeuralNetworkError) -> Self {
455 RusTorchError::NeuralNetworkError(err)
456 }
457}
458
459impl From<OptimizationError> for RusTorchError {
460 fn from(err: OptimizationError) -> Self {
461 RusTorchError::OptimizationError(err)
462 }
463}
464
465impl From<DataError> for RusTorchError {
466 fn from(err: DataError) -> Self {
467 RusTorchError::DataError(err)
468 }
469}
470
471impl From<MemoryError> for RusTorchError {
472 fn from(err: MemoryError) -> Self {
473 RusTorchError::MemoryError(err)
474 }
475}
476
/// Builds a `RusTorchError::TensorError(TensorError::…)` value.
///
/// The expansion names `RusTorchError` and `TensorError` unqualified, so both must
/// be in scope at the call site.
///
/// Supported forms:
/// - `tensor_error!(EmptyTensor)` — unit variant;
/// - `tensor_error!(InvalidShape(vec![0]))` — tuple variant;
/// - `tensor_error!(ShapeMismatch { expected: a, actual: b })` — struct variant,
///   explicit fields;
/// - `tensor_error!(ShapeMismatch, expected, actual)` — struct variant using field
///   init shorthand (locals named like the fields).
///
/// The shorthand arm takes `ident`s rather than `expr`s: an `expr` fragment cannot
/// be spliced into a struct literal as a shorthand field, so the previous form could
/// never compile.
#[macro_export]
macro_rules! tensor_error {
    ($variant:ident) => {
        RusTorchError::TensorError(TensorError::$variant)
    };
    ($variant:ident { $($field:ident : $value:expr),+ $(,)? }) => {
        RusTorchError::TensorError(TensorError::$variant { $($field: $value),+ })
    };
    ($variant:ident ( $($arg:expr),+ $(,)? )) => {
        RusTorchError::TensorError(TensorError::$variant($($arg),+))
    };
    ($variant:ident, $($field:ident),+ $(,)?) => {
        RusTorchError::TensorError(TensorError::$variant { $($field),+ })
    };
}
492
/// Builds a `RusTorchError::GpuError(GpuError::…)` value.
///
/// `gpu_error!(Variant, payload)` builds a single-field tuple variant and
/// `gpu_error!(Variant)` builds a unit variant. `RusTorchError` and `GpuError` must
/// be in scope at the call site.
#[macro_export]
macro_rules! gpu_error {
    ($kind:ident, $payload:expr) => {
        RusTorchError::GpuError(GpuError::$kind($payload))
    };
    ($kind:ident) => {
        RusTorchError::GpuError(GpuError::$kind)
    };
}
504
/// Builds a `RusTorchError::DistributedError(DistributedError::Variant(payload))`
/// value from a single-field tuple variant.
///
/// `RusTorchError` and `DistributedError` must be in scope at the call site.
#[macro_export]
macro_rules! distributed_error {
    ($kind:ident, $payload:expr) => {
        RusTorchError::DistributedError(DistributedError::$kind($payload))
    };
}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Nested `Display` output keeps the inner error's text.
    #[test]
    fn test_error_display() {
        let tensor_err = RusTorchError::TensorError(TensorError::EmptyTensor);
        assert!(tensor_err.to_string().contains("empty tensor"));
        assert_eq!(tensor_err.to_string(), "Tensor error: Operation on empty tensor");

        let gpu_err = RusTorchError::GpuError(GpuError::OutOfMemory);
        assert!(gpu_err.to_string().contains("out of memory"));
    }

    /// `From` impls route each error into the matching wrapper variant.
    ///
    /// `matches!` only yields a `bool`; it must be wrapped in `assert!` or the check
    /// is silently discarded.
    #[test]
    fn test_error_conversion() {
        let tensor_err = TensorError::EmptyTensor;
        let rustorch_err: RusTorchError = tensor_err.into();
        assert!(matches!(
            rustorch_err,
            RusTorchError::TensorError(TensorError::EmptyTensor)
        ));

        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "missing");
        let rustorch_err: RusTorchError = io_err.into();
        assert!(matches!(rustorch_err, RusTorchError::IoError(_)));
    }

    /// The exported constructor macros build the expected variants.
    #[test]
    fn test_error_macros() {
        let err = tensor_error!(EmptyTensor);
        assert!(matches!(err, RusTorchError::TensorError(TensorError::EmptyTensor)));

        let err = gpu_error!(OutOfMemory);
        assert!(matches!(err, RusTorchError::GpuError(GpuError::OutOfMemory)));

        let err = gpu_error!(DeviceNotFound, 2);
        assert!(matches!(err, RusTorchError::GpuError(GpuError::DeviceNotFound(2))));

        let err = distributed_error!(InvalidRank, -1);
        assert!(matches!(
            err,
            RusTorchError::DistributedError(DistributedError::InvalidRank(-1))
        ));
    }
}
542}