1use pyo3::exceptions::{PyIndexError, PyOSError, PyRuntimeError, PyValueError};
4use pyo3::prelude::*;
5use torsh_core::error::TorshError;
6
7#[pyclass]
9pub struct TorshPyError {
10 message: String,
11}
12
13#[pymethods]
14impl TorshPyError {
15 #[new]
16 fn new(message: String) -> Self {
17 Self { message }
18 }
19
20 fn __str__(&self) -> String {
21 self.message.clone()
22 }
23
24 fn __repr__(&self) -> String {
25 format!("TorshError('{}')", self.message)
26 }
27}
28
29pub fn torsh_error_to_py_err(err: TorshError) -> PyErr {
31 match err {
32 TorshError::Shape(shape_err) => {
34 PyValueError::new_err(format!("Shape error: {}", shape_err))
35 }
36 TorshError::Index(index_err) => {
37 PyIndexError::new_err(format!("Index error: {}", index_err))
38 }
39 TorshError::General(general_err) => {
40 PyRuntimeError::new_err(format!("General error: {}", general_err))
41 }
42
43 TorshError::WithContext { message, .. } => {
45 PyRuntimeError::new_err(format!("ToRSh error: {}", message))
46 }
47
48 TorshError::ShapeMismatch { expected, got } => PyValueError::new_err(format!(
50 "Shape mismatch: expected {:?}, got {:?}",
51 expected, got
52 )),
53 TorshError::BroadcastError { shape1, shape2 } => PyValueError::new_err(format!(
54 "Broadcasting error: incompatible shapes {:?} and {:?}",
55 shape1, shape2
56 )),
57 TorshError::IndexOutOfBounds { index, size } => {
58 PyIndexError::new_err(format!("Index {} out of bounds for size {}", index, size))
59 }
60 TorshError::InvalidArgument(msg) => {
61 PyValueError::new_err(format!("Invalid argument: {}", msg))
62 }
63 TorshError::IoError(msg) => PyOSError::new_err(format!("IO error: {}", msg)),
64 TorshError::DeviceMismatch => {
65 PyOSError::new_err("Device mismatch: tensors must be on the same device")
66 }
67 TorshError::NotImplemented(msg) => {
68 PyRuntimeError::new_err(format!("Not implemented: {}", msg))
69 }
70 TorshError::SynchronizationError(msg) => {
71 PyRuntimeError::new_err(format!("Synchronization error: {}", msg))
72 }
73 TorshError::AllocationError(msg) => {
74 PyRuntimeError::new_err(format!("Memory allocation failed: {}", msg))
75 }
76 TorshError::InvalidOperation(msg) => {
77 PyRuntimeError::new_err(format!("Invalid operation: {}", msg))
78 }
79 TorshError::ConversionError(msg) => {
80 PyValueError::new_err(format!("Numeric conversion error: {}", msg))
81 }
82 TorshError::BackendError(msg) => PyRuntimeError::new_err(format!("Backend error: {}", msg)),
83 TorshError::InvalidShape(msg) => PyValueError::new_err(format!("Invalid shape: {}", msg)),
84 TorshError::RuntimeError(msg) => PyRuntimeError::new_err(format!("Runtime error: {}", msg)),
85
86 _ => PyRuntimeError::new_err(format!("ToRSh error: {}", err)),
88 }
89}
90
91pub type PyResult<T> = Result<T, PyErr>;
93
94pub fn to_py_result<T>(result: torsh_core::error::Result<T>) -> PyResult<T> {
96 result.map_err(torsh_error_to_py_err)
97}
98
/// Converts a `torsh_core` result expression into a `PyResult`.
///
/// Expands to `$crate::error::to_py_result(<expr>)`. The path assumes this
/// file is the `error` module at the crate root (TODO confirm).
#[macro_export]
macro_rules! py_result {
    ($result:expr) => {
        $crate::error::to_py_result($result)
    };
}
106
107pub fn register_error_types(m: &Bound<'_, PyModule>) -> PyResult<()> {
109 m.add("TorshError", m.py().get_type::<TorshPyError>())?;
110 Ok(())
111}