torsh_python/device.rs
1//! Device handling for Python bindings
2//!
3//! This module provides PyO3 bindings for ToRSh device types, allowing Python code
4//! to specify and manage computational devices (CPU, CUDA, Metal, etc.).
5//!
6//! # Examples
7//!
8//! ```python
9//! import torsh
10//!
11//! # Create devices
12//! cpu = torsh.PyDevice("cpu")
13//! cuda = torsh.PyDevice("cuda:0")
14//! metal = torsh.PyDevice("metal:0")
15//!
16//! # Check device properties
17//! print(cpu.type) # "cpu"
18//! print(cuda.index) # 0
19//! ```
20
21use crate::error::PyResult;
22use pyo3::prelude::*;
23use torsh_core::device::DeviceType;
24
25/// Python wrapper for ToRSh devices
26///
27/// Represents a computational device where tensors can be allocated and operations executed.
28/// Supports CPU, CUDA (NVIDIA GPUs), Metal (Apple Silicon), and WGPU devices.
29///
30/// # Examples
31///
32/// ```python
33/// # Create CPU device
34/// cpu = torsh.PyDevice("cpu")
35///
36/// # Create CUDA device (default index 0)
37/// cuda = torsh.PyDevice("cuda")
38///
39/// # Create CUDA device with specific index
40/// cuda1 = torsh.PyDevice("cuda:1")
41///
42/// # Create from integer (defaults to CUDA)
43/// cuda2 = torsh.PyDevice(2) # cuda:2
44///
45/// # Check device properties
46/// print(cpu.type) # "cpu"
47/// print(cuda1.type) # "cuda"
48/// print(cuda1.index) # 1
49/// ```
50#[pyclass(name = "device")]
51#[derive(Clone, Debug)]
52pub struct PyDevice {
53 pub(crate) device: DeviceType,
54}
55
56#[pymethods]
57impl PyDevice {
58 /// Create a new device from a string or integer specification.
59 ///
60 /// # Arguments
61 ///
62 /// * `device` - Device specification as string ("cpu", "cuda", "cuda:0", "metal:0")
63 /// or integer (for CUDA device index)
64 ///
65 /// # Returns
66 ///
67 /// New PyDevice instance
68 ///
69 /// # Errors
70 ///
71 /// Returns ValueError if:
72 /// - Device string is not recognized
73 /// - Device index is invalid (negative or malformed)
74 /// - Input type is not string or integer
75 ///
76 /// # Examples
77 ///
78 /// ```python
79 /// cpu = torsh.PyDevice("cpu")
80 /// cuda = torsh.PyDevice("cuda:0")
81 /// cuda_from_int = torsh.PyDevice(1) # cuda:1
82 /// ```
83 #[new]
84 fn new(device: &Bound<'_, PyAny>) -> PyResult<Self> {
85 let device_type = if let Ok(s) = device.extract::<String>() {
86 match s.as_str() {
87 "cpu" => DeviceType::Cpu,
88 "cuda" | "cuda:0" => DeviceType::Cuda(0),
89 "metal" | "metal:0" => DeviceType::Metal(0),
90 s if s.starts_with("cuda:") => {
91 let id: usize = s[5..].parse().map_err(|_| {
92 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
93 "Invalid CUDA device ID: {}",
94 &s[5..]
95 ))
96 })?;
97 DeviceType::Cuda(id)
98 }
99 s if s.starts_with("metal:") => {
100 let id: usize = s[6..].parse().map_err(|_| {
101 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
102 "Invalid Metal device ID: {}",
103 &s[6..]
104 ))
105 })?;
106 DeviceType::Metal(id)
107 }
108 _ => {
109 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
110 "Unknown device: {}",
111 s
112 )))
113 }
114 }
115 } else if let Ok(i) = device.extract::<i32>() {
116 // Accept integer for CUDA device ID
117 if i < 0 {
118 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
119 "Device ID must be non-negative",
120 ));
121 }
122 DeviceType::Cuda(i as usize)
123 } else {
124 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
125 "Device must be a string or integer",
126 ));
127 };
128
129 Ok(Self {
130 device: device_type,
131 })
132 }
133
134 fn __str__(&self) -> String {
135 match self.device {
136 DeviceType::Cpu => "cpu".to_string(),
137 DeviceType::Cuda(id) => format!("cuda:{}", id),
138 DeviceType::Metal(id) => format!("metal:{}", id),
139 DeviceType::Wgpu(id) => format!("wgpu:{}", id),
140 }
141 }
142
143 fn __repr__(&self) -> String {
144 match self.index() {
145 Some(idx) => format!("device(type='{}', index={})", self.type_(), idx),
146 None => format!("device(type='{}')", self.type_()),
147 }
148 }
149
150 fn __eq__(&self, other: &PyDevice) -> bool {
151 self.device == other.device
152 }
153
154 fn __hash__(&self) -> u64 {
155 use std::collections::hash_map::DefaultHasher;
156 use std::hash::{Hash, Hasher};
157 let mut hasher = DefaultHasher::new();
158 self.device.hash(&mut hasher);
159 hasher.finish()
160 }
161
162 /// Get the type of this device (cpu, cuda, metal, wgpu).
163 ///
164 /// # Returns
165 ///
166 /// String representing the device type
167 ///
168 /// # Examples
169 ///
170 /// ```python
171 /// cpu = torsh.PyDevice("cpu")
172 /// print(cpu.type) # "cpu"
173 ///
174 /// cuda = torsh.PyDevice("cuda:3")
175 /// print(cuda.type) # "cuda"
176 /// ```
177 #[getter]
178 #[pyo3(name = "type")]
179 fn type_(&self) -> String {
180 match self.device {
181 DeviceType::Cpu => "cpu".to_string(),
182 DeviceType::Cuda(_) => "cuda".to_string(),
183 DeviceType::Metal(_) => "metal".to_string(),
184 DeviceType::Wgpu(_) => "wgpu".to_string(),
185 }
186 }
187
188 /// Get the index of this device (for multi-device systems).
189 ///
190 /// # Returns
191 ///
192 /// Device index (0-based) for CUDA/Metal/WGPU devices, None for CPU
193 ///
194 /// # Examples
195 ///
196 /// ```python
197 /// cpu = torsh.PyDevice("cpu")
198 /// print(cpu.index) # None
199 ///
200 /// cuda = torsh.PyDevice("cuda:2")
201 /// print(cuda.index) # 2
202 /// ```
203 #[getter]
204 fn index(&self) -> Option<u32> {
205 match self.device {
206 DeviceType::Cpu => None,
207 DeviceType::Cuda(id) => Some(id as u32),
208 DeviceType::Metal(id) => Some(id as u32),
209 DeviceType::Wgpu(id) => Some(id as u32),
210 }
211 }
212}
213
214impl From<DeviceType> for PyDevice {
215 fn from(device: DeviceType) -> Self {
216 Self { device }
217 }
218}
219
220impl From<PyDevice> for DeviceType {
221 fn from(py_device: PyDevice) -> Self {
222 py_device.device
223 }
224}
225
226impl std::fmt::Display for PyDevice {
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 write!(f, "{}", self.__str__())
229 }
230}
231
232/// Helper function to parse device from Python arguments
233pub fn parse_device(device: Option<&Bound<'_, PyAny>>) -> PyResult<DeviceType> {
234 match device {
235 Some(d) => Ok(PyDevice::new(d)?.device),
236 None => Ok(DeviceType::Cpu), // Default to CPU
237 }
238}
239
240/// Register device constants and utility functions with the Python module.
241///
242/// This function adds:
243/// - Device constants (cpu, etc.)
244/// - Device utility functions (device_count, is_available, etc.)
245///
246/// # Arguments
247///
248/// * `m` - Python module to register functions with
249///
250/// # Returns
251///
252/// PyResult<()> indicating success or failure
253pub fn register_device_constants(m: &Bound<'_, PyModule>) -> PyResult<()> {
254 use pyo3::wrap_pyfunction;
255
256 // Create device constants
257 m.add(
258 "cpu",
259 PyDevice {
260 device: DeviceType::Cpu,
261 },
262 )?;
263
264 /// Get the number of available devices.
265 ///
266 /// # Returns
267 ///
268 /// Number of available compute devices
269 ///
270 /// # Note
271 ///
272 /// Currently returns 1 (CPU). Proper device discovery will be added in future versions.
273 #[pyfunction]
274 fn device_count() -> u32 {
275 // For now, return 1 (would need proper device discovery)
276 1
277 }
278
279 #[pyfunction]
280 fn is_available() -> bool {
281 true
282 }
283
284 #[pyfunction]
285 fn cuda_is_available() -> bool {
286 // Would need proper CUDA detection
287 false
288 }
289
290 #[pyfunction]
291 fn mps_is_available() -> bool {
292 // Metal Performance Shaders availability
293 false
294 }
295
296 #[pyfunction]
297 fn get_device_name(device: Option<PyDevice>) -> String {
298 match device {
299 Some(d) => d.__str__(),
300 None => "cpu".to_string(),
301 }
302 }
303
304 m.add_function(wrap_pyfunction!(device_count, m)?)?;
305 m.add_function(wrap_pyfunction!(is_available, m)?)?;
306 m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
307 m.add_function(wrap_pyfunction!(mps_is_available, m)?)?;
308 m.add_function(wrap_pyfunction!(get_device_name, m)?)?;
309
310 Ok(())
311}