1use crate::error::PyResult;
4use pyo3::prelude::*;
5use torsh_core::dtype::DType;
6
7#[pyclass(name = "dtype")]
9#[derive(Clone, Debug)]
10pub struct PyDType {
11 pub(crate) dtype: DType,
12}
13
14#[pymethods]
15impl PyDType {
16 #[new]
17 fn new(name: &str) -> PyResult<Self> {
18 let dtype = match name {
19 "float32" | "f32" => DType::F32,
20 "float64" | "f64" => DType::F64,
21 "int8" | "i8" => DType::I8,
22 "int16" | "i16" => DType::I16,
23 "int32" | "i32" => DType::I32,
24 "int64" | "i64" => DType::I64,
25 "uint8" | "u8" => DType::U8,
26 "uint16" | "u16" => {
27 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
28 "uint16/u16 data type is not supported in ToRSh",
29 ))
30 }
31 "uint32" | "u32" => DType::U32,
32 "uint64" | "u64" => DType::U64,
33 "bool" => DType::Bool,
34 _ => {
35 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
36 "Unknown dtype: {}",
37 name
38 )))
39 }
40 };
41 Ok(Self { dtype })
42 }
43
44 fn __str__(&self) -> String {
45 match self.dtype {
46 DType::F32 => "torch.float32".to_string(),
47 DType::F64 => "torch.float64".to_string(),
48 DType::I8 => "torch.int8".to_string(),
49 DType::I16 => "torch.int16".to_string(),
50 DType::I32 => "torch.int32".to_string(),
51 DType::I64 => "torch.int64".to_string(),
52 DType::U8 => "torch.uint8".to_string(),
53 DType::U32 => "torch.uint32".to_string(),
54 DType::U64 => "torch.uint64".to_string(),
55 DType::Bool => "torch.bool".to_string(),
56 DType::F16 => "torch.float16".to_string(),
57 DType::BF16 => "torch.bfloat16".to_string(),
58 DType::C64 => "torch.complex64".to_string(),
59 DType::C128 => "torch.complex128".to_string(),
60 DType::QInt8 => "torch.qint8".to_string(),
61 _ => format!("torch.{:?}", self.dtype).to_lowercase(),
62 }
63 }
64
65 fn __repr__(&self) -> String {
66 self.__str__()
67 }
68
69 fn __eq__(&self, other: &PyDType) -> bool {
70 self.dtype == other.dtype
71 }
72
73 fn __hash__(&self) -> u64 {
74 use std::collections::hash_map::DefaultHasher;
75 use std::hash::{Hash, Hasher};
76 let mut hasher = DefaultHasher::new();
77 self.dtype.hash(&mut hasher);
78 hasher.finish()
79 }
80
81 #[getter]
82 fn name(&self) -> String {
83 match self.dtype {
84 DType::F32 => "float32".to_string(),
85 DType::F64 => "float64".to_string(),
86 DType::I8 => "int8".to_string(),
87 DType::I16 => "int16".to_string(),
88 DType::I32 => "int32".to_string(),
89 DType::I64 => "int64".to_string(),
90 DType::U8 => "uint8".to_string(),
91 DType::U32 => "uint32".to_string(),
92 DType::U64 => "uint64".to_string(),
93 DType::Bool => "bool".to_string(),
94 DType::F16 => "float16".to_string(),
95 DType::BF16 => "bfloat16".to_string(),
96 DType::C64 => "complex64".to_string(),
97 DType::C128 => "complex128".to_string(),
98 DType::QInt8 => "qint8".to_string(),
99 _ => format!("{:?}", self.dtype).to_lowercase(),
100 }
101 }
102
103 #[getter]
104 fn itemsize(&self) -> usize {
105 match self.dtype {
106 DType::F32 => 4,
107 DType::F64 => 8,
108 DType::I8 => 1,
109 DType::I16 => 2,
110 DType::I32 => 4,
111 DType::I64 => 8,
112 DType::U8 => 1,
113 DType::U32 => 4,
114 DType::U64 => 8,
115 DType::Bool => 1,
116 DType::F16 => 2,
117 DType::BF16 => 2,
118 DType::C64 => 8,
119 DType::C128 => 16,
120 DType::QInt8 => 1,
121 _ => 4, }
123 }
124
125 #[getter]
126 fn is_floating_point(&self) -> bool {
127 matches!(
128 self.dtype,
129 DType::F16 | DType::F32 | DType::F64 | DType::BF16
130 )
131 }
132
133 #[getter]
134 fn is_signed(&self) -> bool {
135 matches!(
136 self.dtype,
137 DType::F16
138 | DType::F32
139 | DType::F64
140 | DType::BF16
141 | DType::I8
142 | DType::I16
143 | DType::I32
144 | DType::I64
145 | DType::QInt8
146 )
147 }
148
149 #[getter]
162 fn is_complex(&self) -> bool {
163 matches!(self.dtype, DType::C64 | DType::C128)
164 }
165
166 #[getter]
182 fn is_integer(&self) -> bool {
183 matches!(
184 self.dtype,
185 DType::I8 | DType::I16 | DType::I32 | DType::I64 | DType::U8 | DType::U32 | DType::U64
186 )
187 }
188
189 #[getter]
202 fn numpy_dtype(&self) -> String {
203 match self.dtype {
204 DType::F32 => "float32".to_string(),
205 DType::F64 => "float64".to_string(),
206 DType::F16 => "float16".to_string(),
207 DType::I8 => "int8".to_string(),
208 DType::I16 => "int16".to_string(),
209 DType::I32 => "int32".to_string(),
210 DType::I64 => "int64".to_string(),
211 DType::U8 => "uint8".to_string(),
212 DType::U32 => "uint32".to_string(),
213 DType::U64 => "uint64".to_string(),
214 DType::Bool => "bool".to_string(),
215 DType::C64 => "complex64".to_string(),
216 DType::C128 => "complex128".to_string(),
217 _ => format!("{:?}", self.dtype).to_lowercase(),
218 }
219 }
220
221 fn can_cast(&self, other: &PyDType) -> bool {
243 if self.dtype == other.dtype {
245 return true;
246 }
247
248 match (self.dtype, other.dtype) {
250 (DType::I8, DType::I16 | DType::I32 | DType::I64) => true,
252 (DType::I16, DType::I32 | DType::I64) => true,
253 (DType::I32, DType::I64) => true,
254
255 (DType::U8, DType::U32 | DType::U64) => true,
257 (DType::U32, DType::U64) => true,
258
259 (DType::I8 | DType::I16 | DType::I32, DType::F32 | DType::F64) => true,
261 (DType::I64, DType::F64) => true,
262 (DType::U8 | DType::U32, DType::F32 | DType::F64) => true,
263
264 (DType::F16, DType::F32 | DType::F64) => true,
266 (DType::F32, DType::F64) => true,
267
268 (DType::Bool, DType::I8 | DType::I16 | DType::I32 | DType::I64) => true,
270 (DType::Bool, DType::U8 | DType::U32 | DType::U64) => true,
271 (DType::Bool, DType::F16 | DType::F32 | DType::F64) => true,
272
273 (DType::F32, DType::C64 | DType::C128) => true,
275 (DType::F64, DType::C128) => true,
276
277 (DType::C64, DType::C128) => true,
279
280 _ => false,
282 }
283 }
284}
285
286impl From<DType> for PyDType {
287 fn from(dtype: DType) -> Self {
288 Self { dtype }
289 }
290}
291
292impl From<PyDType> for DType {
293 fn from(py_dtype: PyDType) -> Self {
294 py_dtype.dtype
295 }
296}
297
298impl std::fmt::Display for PyDType {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 write!(f, "{}", self.__str__())
301 }
302}
303
304pub fn register_dtype_constants(m: &Bound<'_, PyModule>) -> PyResult<()> {
311 use pyo3::wrap_pyfunction;
312
313 m.add("float32", PyDType { dtype: DType::F32 })?;
315 m.add("float64", PyDType { dtype: DType::F64 })?;
316 m.add("int8", PyDType { dtype: DType::I8 })?;
317 m.add("int16", PyDType { dtype: DType::I16 })?;
318 m.add("int32", PyDType { dtype: DType::I32 })?;
319 m.add("int64", PyDType { dtype: DType::I64 })?;
320 m.add("uint8", PyDType { dtype: DType::U8 })?;
321 m.add("uint32", PyDType { dtype: DType::U32 })?;
322 m.add("uint64", PyDType { dtype: DType::U64 })?;
323 m.add("bool", PyDType { dtype: DType::Bool })?;
324
325 m.add("float", PyDType { dtype: DType::F32 })?;
327 m.add("double", PyDType { dtype: DType::F64 })?;
328 m.add("long", PyDType { dtype: DType::I64 })?;
329 m.add("int", PyDType { dtype: DType::I32 })?;
330 m.add("short", PyDType { dtype: DType::I16 })?;
331 m.add("char", PyDType { dtype: DType::I8 })?;
332 m.add("byte", PyDType { dtype: DType::U8 })?;
333
334 #[pyfunction]
356 fn promote_types(dtype1: &PyDType, dtype2: &PyDType) -> PyDType {
357 use DType::*;
358
359 if dtype1.dtype == dtype2.dtype {
361 return dtype1.clone();
362 }
363
364 let promoted = match (dtype1.dtype, dtype2.dtype) {
366 (Bool, other) | (other, Bool) => other,
368
369 (C128, _) | (_, C128) => C128,
371 (C64, _) | (_, C64) => C64,
372
373 (F64, _) | (_, F64) => F64,
375 (F32, _) | (_, F32) => F32,
376 (F16, _) | (_, F16) => F16,
377
378 (I64, I8 | I16 | I32 | U8 | U32 | U64) | (I8 | I16 | I32 | U8 | U32 | U64, I64) => I64,
380 (I32, I8 | I16 | U8) | (I8 | I16 | U8, I32) => I32,
381 (I16, I8 | U8) | (I8 | U8, I16) => I16,
382
383 (U64, U8 | U32) | (U8 | U32, U64) => U64,
385 (U32, U8) | (U8, U32) => U32,
386
387 (a, b) => {
389 let size_a = dtype1.itemsize();
390 let size_b = dtype2.itemsize();
391 if size_a >= size_b {
392 a
393 } else {
394 b
395 }
396 }
397 };
398
399 PyDType { dtype: promoted }
400 }
401
402 #[pyfunction]
420 fn result_type(dtype1: &PyDType, dtype2: &PyDType) -> PyDType {
421 promote_types(dtype1, dtype2)
424 }
425
426 #[pyfunction]
444 fn can_operate(_dtype1: &PyDType, _dtype2: &PyDType) -> bool {
445 true
448 }
449
450 m.add_function(wrap_pyfunction!(promote_types, m)?)?;
451 m.add_function(wrap_pyfunction!(result_type, m)?)?;
452 m.add_function(wrap_pyfunction!(can_operate, m)?)?;
453
454 Ok(())
455}