1use crate::error::PyResult;
4use pyo3::prelude::*;
5use torsh_core::dtype::DType;
6
7#[pyclass(name = "dtype")]
9#[derive(Clone, Debug)]
10pub struct PyDType {
11 pub(crate) dtype: DType,
12}
13
14#[pymethods]
15impl PyDType {
16 #[new]
17 fn new(name: &str) -> PyResult<Self> {
18 let dtype = match name {
19 "float32" | "f32" => DType::F32,
20 "float64" | "f64" => DType::F64,
21 "int8" | "i8" => DType::I8,
22 "int16" | "i16" => DType::I16,
23 "int32" | "i32" => DType::I32,
24 "int64" | "i64" => DType::I64,
25 "uint8" | "u8" => DType::U8,
26 "uint16" | "u16" => {
27 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
28 "uint16/u16 data type is not supported in ToRSh",
29 ))
30 }
31 "uint32" | "u32" => DType::U32,
32 "uint64" | "u64" => DType::U64,
33 "bool" => DType::Bool,
34 _ => {
35 return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
36 "Unknown dtype: {}",
37 name
38 )))
39 }
40 };
41 Ok(Self { dtype })
42 }
43
44 fn __str__(&self) -> String {
45 match self.dtype {
46 DType::F32 => "torch.float32".to_string(),
47 DType::F64 => "torch.float64".to_string(),
48 DType::I8 => "torch.int8".to_string(),
49 DType::I16 => "torch.int16".to_string(),
50 DType::I32 => "torch.int32".to_string(),
51 DType::I64 => "torch.int64".to_string(),
52 DType::U8 => "torch.uint8".to_string(),
53 DType::U32 => "torch.uint32".to_string(),
54 DType::U64 => "torch.uint64".to_string(),
55 DType::Bool => "torch.bool".to_string(),
56 DType::F16 => "torch.float16".to_string(),
57 DType::BF16 => "torch.bfloat16".to_string(),
58 DType::C64 => "torch.complex64".to_string(),
59 DType::C128 => "torch.complex128".to_string(),
60 DType::QInt8 => "torch.qint8".to_string(),
61 _ => format!("torch.{:?}", self.dtype).to_lowercase(),
62 }
63 }
64
65 fn __repr__(&self) -> String {
66 self.__str__()
67 }
68
69 fn __eq__(&self, other: &PyDType) -> bool {
70 self.dtype == other.dtype
71 }
72
73 fn __hash__(&self) -> u64 {
74 use std::collections::hash_map::DefaultHasher;
75 use std::hash::{Hash, Hasher};
76 let mut hasher = DefaultHasher::new();
77 self.dtype.hash(&mut hasher);
78 hasher.finish()
79 }
80
81 #[getter]
82 fn name(&self) -> String {
83 match self.dtype {
84 DType::F32 => "float32".to_string(),
85 DType::F64 => "float64".to_string(),
86 DType::I8 => "int8".to_string(),
87 DType::I16 => "int16".to_string(),
88 DType::I32 => "int32".to_string(),
89 DType::I64 => "int64".to_string(),
90 DType::U8 => "uint8".to_string(),
91 DType::U32 => "uint32".to_string(),
92 DType::U64 => "uint64".to_string(),
93 DType::Bool => "bool".to_string(),
94 DType::F16 => "float16".to_string(),
95 DType::BF16 => "bfloat16".to_string(),
96 DType::C64 => "complex64".to_string(),
97 DType::C128 => "complex128".to_string(),
98 DType::QInt8 => "qint8".to_string(),
99 _ => format!("{:?}", self.dtype).to_lowercase(),
100 }
101 }
102
103 #[getter]
104 fn itemsize(&self) -> usize {
105 match self.dtype {
106 DType::F32 => 4,
107 DType::F64 => 8,
108 DType::I8 => 1,
109 DType::I16 => 2,
110 DType::I32 => 4,
111 DType::I64 => 8,
112 DType::U8 => 1,
113 DType::U32 => 4,
114 DType::U64 => 8,
115 DType::Bool => 1,
116 DType::F16 => 2,
117 DType::BF16 => 2,
118 DType::C64 => 8,
119 DType::C128 => 16,
120 DType::QInt8 => 1,
121 _ => 4, }
123 }
124
125 #[getter]
126 fn is_floating_point(&self) -> bool {
127 matches!(
128 self.dtype,
129 DType::F16 | DType::F32 | DType::F64 | DType::BF16
130 )
131 }
132
133 #[getter]
134 fn is_signed(&self) -> bool {
135 matches!(
136 self.dtype,
137 DType::F16
138 | DType::F32
139 | DType::F64
140 | DType::BF16
141 | DType::I8
142 | DType::I16
143 | DType::I32
144 | DType::I64
145 | DType::QInt8
146 )
147 }
148}
149
150impl From<DType> for PyDType {
151 fn from(dtype: DType) -> Self {
152 Self { dtype }
153 }
154}
155
156impl From<PyDType> for DType {
157 fn from(py_dtype: PyDType) -> Self {
158 py_dtype.dtype
159 }
160}
161
162impl std::fmt::Display for PyDType {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 write!(f, "{}", self.__str__())
165 }
166}
167
168pub fn register_dtype_constants(m: &Bound<'_, PyModule>) -> PyResult<()> {
170 m.add("float32", PyDType { dtype: DType::F32 })?;
172 m.add("float64", PyDType { dtype: DType::F64 })?;
173 m.add("int8", PyDType { dtype: DType::I8 })?;
174 m.add("int16", PyDType { dtype: DType::I16 })?;
175 m.add("int32", PyDType { dtype: DType::I32 })?;
176 m.add("int64", PyDType { dtype: DType::I64 })?;
177 m.add("uint8", PyDType { dtype: DType::U8 })?;
178 m.add("uint32", PyDType { dtype: DType::U32 })?;
179 m.add("uint64", PyDType { dtype: DType::U64 })?;
180 m.add("bool", PyDType { dtype: DType::Bool })?;
181
182 m.add("float", PyDType { dtype: DType::F32 })?;
184 m.add("double", PyDType { dtype: DType::F64 })?;
185 m.add("long", PyDType { dtype: DType::I64 })?;
186 m.add("int", PyDType { dtype: DType::I32 })?;
187 m.add("short", PyDType { dtype: DType::I16 })?;
188 m.add("char", PyDType { dtype: DType::I8 })?;
189 m.add("byte", PyDType { dtype: DType::U8 })?;
190
191 Ok(())
192}