torsh_python/
dtype.rs

1//! Data type handling for Python bindings
2
3use crate::error::PyResult;
4use pyo3::prelude::*;
5use torsh_core::dtype::DType;
6
7/// Python wrapper for ToRSh data types
8#[pyclass(name = "dtype")]
9#[derive(Clone, Debug)]
10pub struct PyDType {
11    pub(crate) dtype: DType,
12}
13
14#[pymethods]
15impl PyDType {
16    #[new]
17    fn new(name: &str) -> PyResult<Self> {
18        let dtype = match name {
19            "float32" | "f32" => DType::F32,
20            "float64" | "f64" => DType::F64,
21            "int8" | "i8" => DType::I8,
22            "int16" | "i16" => DType::I16,
23            "int32" | "i32" => DType::I32,
24            "int64" | "i64" => DType::I64,
25            "uint8" | "u8" => DType::U8,
26            "uint16" | "u16" => {
27                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
28                    "uint16/u16 data type is not supported in ToRSh",
29                ))
30            }
31            "uint32" | "u32" => DType::U32,
32            "uint64" | "u64" => DType::U64,
33            "bool" => DType::Bool,
34            _ => {
35                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
36                    "Unknown dtype: {}",
37                    name
38                )))
39            }
40        };
41        Ok(Self { dtype })
42    }
43
44    fn __str__(&self) -> String {
45        match self.dtype {
46            DType::F32 => "torch.float32".to_string(),
47            DType::F64 => "torch.float64".to_string(),
48            DType::I8 => "torch.int8".to_string(),
49            DType::I16 => "torch.int16".to_string(),
50            DType::I32 => "torch.int32".to_string(),
51            DType::I64 => "torch.int64".to_string(),
52            DType::U8 => "torch.uint8".to_string(),
53            DType::U32 => "torch.uint32".to_string(),
54            DType::U64 => "torch.uint64".to_string(),
55            DType::Bool => "torch.bool".to_string(),
56            DType::F16 => "torch.float16".to_string(),
57            DType::BF16 => "torch.bfloat16".to_string(),
58            DType::C64 => "torch.complex64".to_string(),
59            DType::C128 => "torch.complex128".to_string(),
60            DType::QInt8 => "torch.qint8".to_string(),
61            _ => format!("torch.{:?}", self.dtype).to_lowercase(),
62        }
63    }
64
65    fn __repr__(&self) -> String {
66        self.__str__()
67    }
68
69    fn __eq__(&self, other: &PyDType) -> bool {
70        self.dtype == other.dtype
71    }
72
73    fn __hash__(&self) -> u64 {
74        use std::collections::hash_map::DefaultHasher;
75        use std::hash::{Hash, Hasher};
76        let mut hasher = DefaultHasher::new();
77        self.dtype.hash(&mut hasher);
78        hasher.finish()
79    }
80
81    #[getter]
82    fn name(&self) -> String {
83        match self.dtype {
84            DType::F32 => "float32".to_string(),
85            DType::F64 => "float64".to_string(),
86            DType::I8 => "int8".to_string(),
87            DType::I16 => "int16".to_string(),
88            DType::I32 => "int32".to_string(),
89            DType::I64 => "int64".to_string(),
90            DType::U8 => "uint8".to_string(),
91            DType::U32 => "uint32".to_string(),
92            DType::U64 => "uint64".to_string(),
93            DType::Bool => "bool".to_string(),
94            DType::F16 => "float16".to_string(),
95            DType::BF16 => "bfloat16".to_string(),
96            DType::C64 => "complex64".to_string(),
97            DType::C128 => "complex128".to_string(),
98            DType::QInt8 => "qint8".to_string(),
99            _ => format!("{:?}", self.dtype).to_lowercase(),
100        }
101    }
102
103    #[getter]
104    fn itemsize(&self) -> usize {
105        match self.dtype {
106            DType::F32 => 4,
107            DType::F64 => 8,
108            DType::I8 => 1,
109            DType::I16 => 2,
110            DType::I32 => 4,
111            DType::I64 => 8,
112            DType::U8 => 1,
113            DType::U32 => 4,
114            DType::U64 => 8,
115            DType::Bool => 1,
116            DType::F16 => 2,
117            DType::BF16 => 2,
118            DType::C64 => 8,
119            DType::C128 => 16,
120            DType::QInt8 => 1,
121            _ => 4, // Default to 4 bytes for unknown types
122        }
123    }
124
125    #[getter]
126    fn is_floating_point(&self) -> bool {
127        matches!(
128            self.dtype,
129            DType::F16 | DType::F32 | DType::F64 | DType::BF16
130        )
131    }
132
133    #[getter]
134    fn is_signed(&self) -> bool {
135        matches!(
136            self.dtype,
137            DType::F16
138                | DType::F32
139                | DType::F64
140                | DType::BF16
141                | DType::I8
142                | DType::I16
143                | DType::I32
144                | DType::I64
145                | DType::QInt8
146        )
147    }
148}
149
150impl From<DType> for PyDType {
151    fn from(dtype: DType) -> Self {
152        Self { dtype }
153    }
154}
155
156impl From<PyDType> for DType {
157    fn from(py_dtype: PyDType) -> Self {
158        py_dtype.dtype
159    }
160}
161
162impl std::fmt::Display for PyDType {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        write!(f, "{}", self.__str__())
165    }
166}
167
168/// Register dtype constants with the module
169pub fn register_dtype_constants(m: &Bound<'_, PyModule>) -> PyResult<()> {
170    // Create dtype constants similar to PyTorch
171    m.add("float32", PyDType { dtype: DType::F32 })?;
172    m.add("float64", PyDType { dtype: DType::F64 })?;
173    m.add("int8", PyDType { dtype: DType::I8 })?;
174    m.add("int16", PyDType { dtype: DType::I16 })?;
175    m.add("int32", PyDType { dtype: DType::I32 })?;
176    m.add("int64", PyDType { dtype: DType::I64 })?;
177    m.add("uint8", PyDType { dtype: DType::U8 })?;
178    m.add("uint32", PyDType { dtype: DType::U32 })?;
179    m.add("uint64", PyDType { dtype: DType::U64 })?;
180    m.add("bool", PyDType { dtype: DType::Bool })?;
181
182    // PyTorch-style aliases
183    m.add("float", PyDType { dtype: DType::F32 })?;
184    m.add("double", PyDType { dtype: DType::F64 })?;
185    m.add("long", PyDType { dtype: DType::I64 })?;
186    m.add("int", PyDType { dtype: DType::I32 })?;
187    m.add("short", PyDType { dtype: DType::I16 })?;
188    m.add("char", PyDType { dtype: DType::I8 })?;
189    m.add("byte", PyDType { dtype: DType::U8 })?;
190
191    Ok(())
192}