torsh_python/
device.rs

1//! Device handling for Python bindings
2
3use crate::error::PyResult;
4use pyo3::prelude::*;
5use torsh_core::device::DeviceType;
6
7/// Python wrapper for ToRSh devices
8#[pyclass(name = "device")]
9#[derive(Clone, Debug)]
10pub struct PyDevice {
11    pub(crate) device: DeviceType,
12}
13
14#[pymethods]
15impl PyDevice {
16    #[new]
17    fn new(device: &Bound<'_, PyAny>) -> PyResult<Self> {
18        let device_type = if let Ok(s) = device.extract::<String>() {
19            match s.as_str() {
20                "cpu" => DeviceType::Cpu,
21                "cuda" | "cuda:0" => DeviceType::Cuda(0),
22                "metal" | "metal:0" => DeviceType::Metal(0),
23                s if s.starts_with("cuda:") => {
24                    let id: usize = s[5..].parse().map_err(|_| {
25                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
26                            "Invalid CUDA device ID: {}",
27                            &s[5..]
28                        ))
29                    })?;
30                    DeviceType::Cuda(id)
31                }
32                s if s.starts_with("metal:") => {
33                    let id: usize = s[6..].parse().map_err(|_| {
34                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
35                            "Invalid Metal device ID: {}",
36                            &s[6..]
37                        ))
38                    })?;
39                    DeviceType::Metal(id)
40                }
41                _ => {
42                    return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
43                        "Unknown device: {}",
44                        s
45                    )))
46                }
47            }
48        } else if let Ok(i) = device.extract::<i32>() {
49            // Accept integer for CUDA device ID
50            if i < 0 {
51                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
52                    "Device ID must be non-negative",
53                ));
54            }
55            DeviceType::Cuda(i as usize)
56        } else {
57            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
58                "Device must be a string or integer",
59            ));
60        };
61
62        Ok(Self {
63            device: device_type,
64        })
65    }
66
67    fn __str__(&self) -> String {
68        match self.device {
69            DeviceType::Cpu => "cpu".to_string(),
70            DeviceType::Cuda(id) => format!("cuda:{}", id),
71            DeviceType::Metal(id) => format!("metal:{}", id),
72            DeviceType::Wgpu(id) => format!("wgpu:{}", id),
73        }
74    }
75
76    fn __repr__(&self) -> String {
77        match self.index() {
78            Some(idx) => format!("device(type='{}', index={})", self.type_(), idx),
79            None => format!("device(type='{}')", self.type_()),
80        }
81    }
82
83    fn __eq__(&self, other: &PyDevice) -> bool {
84        self.device == other.device
85    }
86
87    fn __hash__(&self) -> u64 {
88        use std::collections::hash_map::DefaultHasher;
89        use std::hash::{Hash, Hasher};
90        let mut hasher = DefaultHasher::new();
91        self.device.hash(&mut hasher);
92        hasher.finish()
93    }
94
95    #[getter]
96    fn type_(&self) -> String {
97        match self.device {
98            DeviceType::Cpu => "cpu".to_string(),
99            DeviceType::Cuda(_) => "cuda".to_string(),
100            DeviceType::Metal(_) => "metal".to_string(),
101            DeviceType::Wgpu(_) => "wgpu".to_string(),
102        }
103    }
104
105    #[getter]
106    fn index(&self) -> Option<u32> {
107        match self.device {
108            DeviceType::Cpu => None,
109            DeviceType::Cuda(id) => Some(id as u32),
110            DeviceType::Metal(id) => Some(id as u32),
111            DeviceType::Wgpu(id) => Some(id as u32),
112        }
113    }
114}
115
116impl From<DeviceType> for PyDevice {
117    fn from(device: DeviceType) -> Self {
118        Self { device }
119    }
120}
121
122impl From<PyDevice> for DeviceType {
123    fn from(py_device: PyDevice) -> Self {
124        py_device.device
125    }
126}
127
128impl std::fmt::Display for PyDevice {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        write!(f, "{}", self.__str__())
131    }
132}
133
134/// Helper function to parse device from Python arguments
135pub fn parse_device(device: Option<&Bound<'_, PyAny>>) -> PyResult<DeviceType> {
136    match device {
137        Some(d) => Ok(PyDevice::new(d)?.device),
138        None => Ok(DeviceType::Cpu), // Default to CPU
139    }
140}
141
142/// Register device constants and functions
143pub fn register_device_constants(m: &Bound<'_, PyModule>) -> PyResult<()> {
144    use pyo3::wrap_pyfunction;
145
146    // Create device constants
147    m.add(
148        "cpu",
149        PyDevice {
150            device: DeviceType::Cpu,
151        },
152    )?;
153
154    // Add device utility functions
155    #[pyfunction]
156    fn device_count() -> u32 {
157        // For now, return 1 (would need proper device discovery)
158        1
159    }
160
161    #[pyfunction]
162    fn is_available() -> bool {
163        true
164    }
165
166    #[pyfunction]
167    fn cuda_is_available() -> bool {
168        // Would need proper CUDA detection
169        false
170    }
171
172    #[pyfunction]
173    fn mps_is_available() -> bool {
174        // Metal Performance Shaders availability
175        false
176    }
177
178    #[pyfunction]
179    fn get_device_name(device: Option<PyDevice>) -> String {
180        match device {
181            Some(d) => d.__str__(),
182            None => "cpu".to_string(),
183        }
184    }
185
186    m.add_function(wrap_pyfunction!(device_count, m)?)?;
187    m.add_function(wrap_pyfunction!(is_available, m)?)?;
188    m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
189    m.add_function(wrap_pyfunction!(mps_is_available, m)?)?;
190    m.add_function(wrap_pyfunction!(get_device_name, m)?)?;
191
192    Ok(())
193}