Skip to main content

torsh_python/
device.rs

//! Device handling for Python bindings
//!
//! This module provides PyO3 bindings for ToRSh device types, allowing Python code
//! to specify and manage computational devices (CPU, CUDA, Metal, etc.).
//!
//! The wrapper type is named `PyDevice` on the Rust side, but it is exported to
//! Python under the class name `device` (see the `#[pyclass(name = "device")]`
//! attribute on the struct).
//!
//! # Examples
//!
//! ```python
//! import torsh
//!
//! # Create devices
//! cpu = torsh.device("cpu")
//! cuda = torsh.device("cuda:0")
//! metal = torsh.device("metal:0")
//!
//! # Check device properties
//! print(cpu.type)    # "cpu"
//! print(cuda.index)  # 0
//! ```
20
21use crate::error::PyResult;
22use pyo3::prelude::*;
23use torsh_core::device::DeviceType;
24
25/// Python wrapper for ToRSh devices
26///
27/// Represents a computational device where tensors can be allocated and operations executed.
28/// Supports CPU, CUDA (NVIDIA GPUs), Metal (Apple Silicon), and WGPU devices.
29///
30/// # Examples
31///
32/// ```python
33/// # Create CPU device
34/// cpu = torsh.PyDevice("cpu")
35///
36/// # Create CUDA device (default index 0)
37/// cuda = torsh.PyDevice("cuda")
38///
39/// # Create CUDA device with specific index
40/// cuda1 = torsh.PyDevice("cuda:1")
41///
42/// # Create from integer (defaults to CUDA)
43/// cuda2 = torsh.PyDevice(2)  # cuda:2
44///
45/// # Check device properties
46/// print(cpu.type)     # "cpu"
47/// print(cuda1.type)   # "cuda"
48/// print(cuda1.index)  # 1
49/// ```
50#[pyclass(name = "device")]
51#[derive(Clone, Debug)]
52pub struct PyDevice {
53    pub(crate) device: DeviceType,
54}
55
56#[pymethods]
57impl PyDevice {
58    /// Create a new device from a string or integer specification.
59    ///
60    /// # Arguments
61    ///
62    /// * `device` - Device specification as string ("cpu", "cuda", "cuda:0", "metal:0")
63    ///              or integer (for CUDA device index)
64    ///
65    /// # Returns
66    ///
67    /// New PyDevice instance
68    ///
69    /// # Errors
70    ///
71    /// Returns ValueError if:
72    /// - Device string is not recognized
73    /// - Device index is invalid (negative or malformed)
74    /// - Input type is not string or integer
75    ///
76    /// # Examples
77    ///
78    /// ```python
79    /// cpu = torsh.PyDevice("cpu")
80    /// cuda = torsh.PyDevice("cuda:0")
81    /// cuda_from_int = torsh.PyDevice(1)  # cuda:1
82    /// ```
83    #[new]
84    fn new(device: &Bound<'_, PyAny>) -> PyResult<Self> {
85        let device_type = if let Ok(s) = device.extract::<String>() {
86            match s.as_str() {
87                "cpu" => DeviceType::Cpu,
88                "cuda" | "cuda:0" => DeviceType::Cuda(0),
89                "metal" | "metal:0" => DeviceType::Metal(0),
90                s if s.starts_with("cuda:") => {
91                    let id: usize = s[5..].parse().map_err(|_| {
92                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
93                            "Invalid CUDA device ID: {}",
94                            &s[5..]
95                        ))
96                    })?;
97                    DeviceType::Cuda(id)
98                }
99                s if s.starts_with("metal:") => {
100                    let id: usize = s[6..].parse().map_err(|_| {
101                        PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
102                            "Invalid Metal device ID: {}",
103                            &s[6..]
104                        ))
105                    })?;
106                    DeviceType::Metal(id)
107                }
108                _ => {
109                    return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
110                        "Unknown device: {}",
111                        s
112                    )))
113                }
114            }
115        } else if let Ok(i) = device.extract::<i32>() {
116            // Accept integer for CUDA device ID
117            if i < 0 {
118                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
119                    "Device ID must be non-negative",
120                ));
121            }
122            DeviceType::Cuda(i as usize)
123        } else {
124            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
125                "Device must be a string or integer",
126            ));
127        };
128
129        Ok(Self {
130            device: device_type,
131        })
132    }
133
134    fn __str__(&self) -> String {
135        match self.device {
136            DeviceType::Cpu => "cpu".to_string(),
137            DeviceType::Cuda(id) => format!("cuda:{}", id),
138            DeviceType::Metal(id) => format!("metal:{}", id),
139            DeviceType::Wgpu(id) => format!("wgpu:{}", id),
140        }
141    }
142
143    fn __repr__(&self) -> String {
144        match self.index() {
145            Some(idx) => format!("device(type='{}', index={})", self.type_(), idx),
146            None => format!("device(type='{}')", self.type_()),
147        }
148    }
149
150    fn __eq__(&self, other: &PyDevice) -> bool {
151        self.device == other.device
152    }
153
154    fn __hash__(&self) -> u64 {
155        use std::collections::hash_map::DefaultHasher;
156        use std::hash::{Hash, Hasher};
157        let mut hasher = DefaultHasher::new();
158        self.device.hash(&mut hasher);
159        hasher.finish()
160    }
161
162    /// Get the type of this device (cpu, cuda, metal, wgpu).
163    ///
164    /// # Returns
165    ///
166    /// String representing the device type
167    ///
168    /// # Examples
169    ///
170    /// ```python
171    /// cpu = torsh.PyDevice("cpu")
172    /// print(cpu.type)  # "cpu"
173    ///
174    /// cuda = torsh.PyDevice("cuda:3")
175    /// print(cuda.type)  # "cuda"
176    /// ```
177    #[getter]
178    #[pyo3(name = "type")]
179    fn type_(&self) -> String {
180        match self.device {
181            DeviceType::Cpu => "cpu".to_string(),
182            DeviceType::Cuda(_) => "cuda".to_string(),
183            DeviceType::Metal(_) => "metal".to_string(),
184            DeviceType::Wgpu(_) => "wgpu".to_string(),
185        }
186    }
187
188    /// Get the index of this device (for multi-device systems).
189    ///
190    /// # Returns
191    ///
192    /// Device index (0-based) for CUDA/Metal/WGPU devices, None for CPU
193    ///
194    /// # Examples
195    ///
196    /// ```python
197    /// cpu = torsh.PyDevice("cpu")
198    /// print(cpu.index)  # None
199    ///
200    /// cuda = torsh.PyDevice("cuda:2")
201    /// print(cuda.index)  # 2
202    /// ```
203    #[getter]
204    fn index(&self) -> Option<u32> {
205        match self.device {
206            DeviceType::Cpu => None,
207            DeviceType::Cuda(id) => Some(id as u32),
208            DeviceType::Metal(id) => Some(id as u32),
209            DeviceType::Wgpu(id) => Some(id as u32),
210        }
211    }
212}
213
214impl From<DeviceType> for PyDevice {
215    fn from(device: DeviceType) -> Self {
216        Self { device }
217    }
218}
219
220impl From<PyDevice> for DeviceType {
221    fn from(py_device: PyDevice) -> Self {
222        py_device.device
223    }
224}
225
226impl std::fmt::Display for PyDevice {
227    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228        write!(f, "{}", self.__str__())
229    }
230}
231
232/// Helper function to parse device from Python arguments
233pub fn parse_device(device: Option<&Bound<'_, PyAny>>) -> PyResult<DeviceType> {
234    match device {
235        Some(d) => Ok(PyDevice::new(d)?.device),
236        None => Ok(DeviceType::Cpu), // Default to CPU
237    }
238}
239
240/// Register device constants and utility functions with the Python module.
241///
242/// This function adds:
243/// - Device constants (cpu, etc.)
244/// - Device utility functions (device_count, is_available, etc.)
245///
246/// # Arguments
247///
248/// * `m` - Python module to register functions with
249///
250/// # Returns
251///
252/// PyResult<()> indicating success or failure
253pub fn register_device_constants(m: &Bound<'_, PyModule>) -> PyResult<()> {
254    use pyo3::wrap_pyfunction;
255
256    // Create device constants
257    m.add(
258        "cpu",
259        PyDevice {
260            device: DeviceType::Cpu,
261        },
262    )?;
263
264    /// Get the number of available devices.
265    ///
266    /// # Returns
267    ///
268    /// Number of available compute devices
269    ///
270    /// # Note
271    ///
272    /// Currently returns 1 (CPU). Proper device discovery will be added in future versions.
273    #[pyfunction]
274    fn device_count() -> u32 {
275        // For now, return 1 (would need proper device discovery)
276        1
277    }
278
279    #[pyfunction]
280    fn is_available() -> bool {
281        true
282    }
283
284    #[pyfunction]
285    fn cuda_is_available() -> bool {
286        // Would need proper CUDA detection
287        false
288    }
289
290    #[pyfunction]
291    fn mps_is_available() -> bool {
292        // Metal Performance Shaders availability
293        false
294    }
295
296    #[pyfunction]
297    fn get_device_name(device: Option<PyDevice>) -> String {
298        match device {
299            Some(d) => d.__str__(),
300            None => "cpu".to_string(),
301        }
302    }
303
304    m.add_function(wrap_pyfunction!(device_count, m)?)?;
305    m.add_function(wrap_pyfunction!(is_available, m)?)?;
306    m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
307    m.add_function(wrap_pyfunction!(mps_is_available, m)?)?;
308    m.add_function(wrap_pyfunction!(get_device_name, m)?)?;
309
310    Ok(())
311}