1use crate::error::PyResult;
4use pyo3::prelude::*;
5use torsh_core::device::DeviceType;
6
7#[pyclass(name = "device")]
9#[derive(Clone, Debug)]
10pub struct PyDevice {
11 pub(crate) device: DeviceType,
12}
13
14#[pymethods]
15impl PyDevice {
16 #[new]
17 fn new(device: &Bound<'_, PyAny>) -> PyResult<Self> {
18 let device_type = if let Ok(s) = device.extract::<String>() {
19 match s.as_str() {
20 "cpu" => DeviceType::Cpu,
21 "cuda" | "cuda:0" => DeviceType::Cuda(0),
22 "metal" | "metal:0" => DeviceType::Metal(0),
23 s if s.starts_with("cuda:") => {
24 let id: usize = s[5..].parse().map_err(|_| {
25 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
26 "Invalid CUDA device ID: {}",
27 &s[5..]
28 ))
29 })?;
30 DeviceType::Cuda(id)
31 }
32 s if s.starts_with("metal:") => {
33 let id: usize = s[6..].parse().map_err(|_| {
34 PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
35 "Invalid Metal device ID: {}",
36 &s[6..]
37 ))
38 })?;
39 DeviceType::Metal(id)
40 }
41 _ => {
42 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
43 "Unknown device: {}",
44 s
45 )))
46 }
47 }
48 } else if let Ok(i) = device.extract::<i32>() {
49 if i < 0 {
51 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
52 "Device ID must be non-negative",
53 ));
54 }
55 DeviceType::Cuda(i as usize)
56 } else {
57 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
58 "Device must be a string or integer",
59 ));
60 };
61
62 Ok(Self {
63 device: device_type,
64 })
65 }
66
67 fn __str__(&self) -> String {
68 match self.device {
69 DeviceType::Cpu => "cpu".to_string(),
70 DeviceType::Cuda(id) => format!("cuda:{}", id),
71 DeviceType::Metal(id) => format!("metal:{}", id),
72 DeviceType::Wgpu(id) => format!("wgpu:{}", id),
73 }
74 }
75
76 fn __repr__(&self) -> String {
77 match self.index() {
78 Some(idx) => format!("device(type='{}', index={})", self.type_(), idx),
79 None => format!("device(type='{}')", self.type_()),
80 }
81 }
82
83 fn __eq__(&self, other: &PyDevice) -> bool {
84 self.device == other.device
85 }
86
87 fn __hash__(&self) -> u64 {
88 use std::collections::hash_map::DefaultHasher;
89 use std::hash::{Hash, Hasher};
90 let mut hasher = DefaultHasher::new();
91 self.device.hash(&mut hasher);
92 hasher.finish()
93 }
94
95 #[getter]
96 fn type_(&self) -> String {
97 match self.device {
98 DeviceType::Cpu => "cpu".to_string(),
99 DeviceType::Cuda(_) => "cuda".to_string(),
100 DeviceType::Metal(_) => "metal".to_string(),
101 DeviceType::Wgpu(_) => "wgpu".to_string(),
102 }
103 }
104
105 #[getter]
106 fn index(&self) -> Option<u32> {
107 match self.device {
108 DeviceType::Cpu => None,
109 DeviceType::Cuda(id) => Some(id as u32),
110 DeviceType::Metal(id) => Some(id as u32),
111 DeviceType::Wgpu(id) => Some(id as u32),
112 }
113 }
114}
115
116impl From<DeviceType> for PyDevice {
117 fn from(device: DeviceType) -> Self {
118 Self { device }
119 }
120}
121
122impl From<PyDevice> for DeviceType {
123 fn from(py_device: PyDevice) -> Self {
124 py_device.device
125 }
126}
127
128impl std::fmt::Display for PyDevice {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 write!(f, "{}", self.__str__())
131 }
132}
133
134pub fn parse_device(device: Option<&Bound<'_, PyAny>>) -> PyResult<DeviceType> {
136 match device {
137 Some(d) => Ok(PyDevice::new(d)?.device),
138 None => Ok(DeviceType::Cpu), }
140}
141
142pub fn register_device_constants(m: &Bound<'_, PyModule>) -> PyResult<()> {
144 use pyo3::wrap_pyfunction;
145
146 m.add(
148 "cpu",
149 PyDevice {
150 device: DeviceType::Cpu,
151 },
152 )?;
153
154 #[pyfunction]
156 fn device_count() -> u32 {
157 1
159 }
160
161 #[pyfunction]
162 fn is_available() -> bool {
163 true
164 }
165
166 #[pyfunction]
167 fn cuda_is_available() -> bool {
168 false
170 }
171
172 #[pyfunction]
173 fn mps_is_available() -> bool {
174 false
176 }
177
178 #[pyfunction]
179 fn get_device_name(device: Option<PyDevice>) -> String {
180 match device {
181 Some(d) => d.__str__(),
182 None => "cpu".to_string(),
183 }
184 }
185
186 m.add_function(wrap_pyfunction!(device_count, m)?)?;
187 m.add_function(wrap_pyfunction!(is_available, m)?)?;
188 m.add_function(wrap_pyfunction!(cuda_is_available, m)?)?;
189 m.add_function(wrap_pyfunction!(mps_is_available, m)?)?;
190 m.add_function(wrap_pyfunction!(get_device_name, m)?)?;
191
192 Ok(())
193}