torsh_python/
error.rs

//! Error handling for the ToRSh Python bindings
2
3use pyo3::exceptions::{PyIndexError, PyOSError, PyRuntimeError, PyValueError};
4use pyo3::prelude::*;
5use torsh_core::error::TorshError;
6
7/// Python exception wrapper for ToRSh errors
8#[pyclass]
9pub struct TorshPyError {
10    message: String,
11}
12
13#[pymethods]
14impl TorshPyError {
15    #[new]
16    fn new(message: String) -> Self {
17        Self { message }
18    }
19
20    fn __str__(&self) -> String {
21        self.message.clone()
22    }
23
24    fn __repr__(&self) -> String {
25        format!("TorshError('{}')", self.message)
26    }
27}
28
29/// Convert ToRSh errors to Python exceptions
30pub fn torsh_error_to_py_err(err: TorshError) -> PyErr {
31    match err {
32        // Modular error variants
33        TorshError::Shape(shape_err) => {
34            PyValueError::new_err(format!("Shape error: {}", shape_err))
35        }
36        TorshError::Index(index_err) => {
37            PyIndexError::new_err(format!("Index error: {}", index_err))
38        }
39        TorshError::General(general_err) => {
40            PyRuntimeError::new_err(format!("General error: {}", general_err))
41        }
42
43        // Error with context
44        TorshError::WithContext { message, .. } => {
45            PyRuntimeError::new_err(format!("ToRSh error: {}", message))
46        }
47
48        // Legacy compatibility variants
49        TorshError::ShapeMismatch { expected, got } => PyValueError::new_err(format!(
50            "Shape mismatch: expected {:?}, got {:?}",
51            expected, got
52        )),
53        TorshError::BroadcastError { shape1, shape2 } => PyValueError::new_err(format!(
54            "Broadcasting error: incompatible shapes {:?} and {:?}",
55            shape1, shape2
56        )),
57        TorshError::IndexOutOfBounds { index, size } => {
58            PyIndexError::new_err(format!("Index {} out of bounds for size {}", index, size))
59        }
60        TorshError::InvalidArgument(msg) => {
61            PyValueError::new_err(format!("Invalid argument: {}", msg))
62        }
63        TorshError::IoError(msg) => PyOSError::new_err(format!("IO error: {}", msg)),
64        TorshError::DeviceMismatch => {
65            PyOSError::new_err("Device mismatch: tensors must be on the same device")
66        }
67        TorshError::NotImplemented(msg) => {
68            PyRuntimeError::new_err(format!("Not implemented: {}", msg))
69        }
70        TorshError::SynchronizationError(msg) => {
71            PyRuntimeError::new_err(format!("Synchronization error: {}", msg))
72        }
73        TorshError::AllocationError(msg) => {
74            PyRuntimeError::new_err(format!("Memory allocation failed: {}", msg))
75        }
76        TorshError::InvalidOperation(msg) => {
77            PyRuntimeError::new_err(format!("Invalid operation: {}", msg))
78        }
79        TorshError::ConversionError(msg) => {
80            PyValueError::new_err(format!("Numeric conversion error: {}", msg))
81        }
82        TorshError::BackendError(msg) => PyRuntimeError::new_err(format!("Backend error: {}", msg)),
83        TorshError::InvalidShape(msg) => PyValueError::new_err(format!("Invalid shape: {}", msg)),
84        TorshError::RuntimeError(msg) => PyRuntimeError::new_err(format!("Runtime error: {}", msg)),
85
86        // Additional missing variants - handle all with catch-all
87        _ => PyRuntimeError::new_err(format!("ToRSh error: {}", err)),
88    }
89}
90
91/// Result type for Python operations
92pub type PyResult<T> = Result<T, PyErr>;
93
94/// Convert ToRSh Result to Python Result
95pub fn to_py_result<T>(result: torsh_core::error::Result<T>) -> PyResult<T> {
96    result.map_err(torsh_error_to_py_err)
97}
98
/// Wrap a fallible ToRSh expression so its error becomes a Python exception.
///
/// Expands to a call to `$crate::error::to_py_result` with the given
/// expression as its single argument.
#[macro_export]
macro_rules! py_result {
    ($fallible:expr) => (
        $crate::error::to_py_result($fallible)
    );
}
106
107/// Register error types with Python module
108pub fn register_error_types(m: &Bound<'_, PyModule>) -> PyResult<()> {
109    m.add("TorshError", m.py().get_type::<TorshPyError>())?;
110    Ok(())
111}