torsh_python/
dtype.rs

1//! Data type handling for Python bindings
2
3use crate::error::PyResult;
4use pyo3::prelude::*;
5use torsh_core::dtype::DType;
6
7/// Python wrapper for ToRSh data types
8#[pyclass(name = "dtype")]
9#[derive(Clone, Debug)]
10pub struct PyDType {
11    pub(crate) dtype: DType,
12}
13
14#[pymethods]
15impl PyDType {
16    #[new]
17    fn new(name: &str) -> PyResult<Self> {
18        let dtype = match name {
19            "float32" | "f32" => DType::F32,
20            "float64" | "f64" => DType::F64,
21            "int8" | "i8" => DType::I8,
22            "int16" | "i16" => DType::I16,
23            "int32" | "i32" => DType::I32,
24            "int64" | "i64" => DType::I64,
25            "uint8" | "u8" => DType::U8,
26            "uint16" | "u16" => {
27                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
28                    "uint16/u16 data type is not supported in ToRSh",
29                ))
30            }
31            "uint32" | "u32" => DType::U32,
32            "uint64" | "u64" => DType::U64,
33            "bool" => DType::Bool,
34            _ => {
35                return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
36                    "Unknown dtype: {}",
37                    name
38                )))
39            }
40        };
41        Ok(Self { dtype })
42    }
43
44    fn __str__(&self) -> String {
45        match self.dtype {
46            DType::F32 => "torch.float32".to_string(),
47            DType::F64 => "torch.float64".to_string(),
48            DType::I8 => "torch.int8".to_string(),
49            DType::I16 => "torch.int16".to_string(),
50            DType::I32 => "torch.int32".to_string(),
51            DType::I64 => "torch.int64".to_string(),
52            DType::U8 => "torch.uint8".to_string(),
53            DType::U32 => "torch.uint32".to_string(),
54            DType::U64 => "torch.uint64".to_string(),
55            DType::Bool => "torch.bool".to_string(),
56            DType::F16 => "torch.float16".to_string(),
57            DType::BF16 => "torch.bfloat16".to_string(),
58            DType::C64 => "torch.complex64".to_string(),
59            DType::C128 => "torch.complex128".to_string(),
60            DType::QInt8 => "torch.qint8".to_string(),
61            _ => format!("torch.{:?}", self.dtype).to_lowercase(),
62        }
63    }
64
65    fn __repr__(&self) -> String {
66        self.__str__()
67    }
68
69    fn __eq__(&self, other: &PyDType) -> bool {
70        self.dtype == other.dtype
71    }
72
73    fn __hash__(&self) -> u64 {
74        use std::collections::hash_map::DefaultHasher;
75        use std::hash::{Hash, Hasher};
76        let mut hasher = DefaultHasher::new();
77        self.dtype.hash(&mut hasher);
78        hasher.finish()
79    }
80
81    #[getter]
82    fn name(&self) -> String {
83        match self.dtype {
84            DType::F32 => "float32".to_string(),
85            DType::F64 => "float64".to_string(),
86            DType::I8 => "int8".to_string(),
87            DType::I16 => "int16".to_string(),
88            DType::I32 => "int32".to_string(),
89            DType::I64 => "int64".to_string(),
90            DType::U8 => "uint8".to_string(),
91            DType::U32 => "uint32".to_string(),
92            DType::U64 => "uint64".to_string(),
93            DType::Bool => "bool".to_string(),
94            DType::F16 => "float16".to_string(),
95            DType::BF16 => "bfloat16".to_string(),
96            DType::C64 => "complex64".to_string(),
97            DType::C128 => "complex128".to_string(),
98            DType::QInt8 => "qint8".to_string(),
99            _ => format!("{:?}", self.dtype).to_lowercase(),
100        }
101    }
102
103    #[getter]
104    fn itemsize(&self) -> usize {
105        match self.dtype {
106            DType::F32 => 4,
107            DType::F64 => 8,
108            DType::I8 => 1,
109            DType::I16 => 2,
110            DType::I32 => 4,
111            DType::I64 => 8,
112            DType::U8 => 1,
113            DType::U32 => 4,
114            DType::U64 => 8,
115            DType::Bool => 1,
116            DType::F16 => 2,
117            DType::BF16 => 2,
118            DType::C64 => 8,
119            DType::C128 => 16,
120            DType::QInt8 => 1,
121            _ => 4, // Default to 4 bytes for unknown types
122        }
123    }
124
125    #[getter]
126    fn is_floating_point(&self) -> bool {
127        matches!(
128            self.dtype,
129            DType::F16 | DType::F32 | DType::F64 | DType::BF16
130        )
131    }
132
133    #[getter]
134    fn is_signed(&self) -> bool {
135        matches!(
136            self.dtype,
137            DType::F16
138                | DType::F32
139                | DType::F64
140                | DType::BF16
141                | DType::I8
142                | DType::I16
143                | DType::I32
144                | DType::I64
145                | DType::QInt8
146        )
147    }
148
149    /// Check if this is a complex dtype.
150    ///
151    /// # Returns
152    ///
153    /// True if dtype is complex (complex64, complex128), False otherwise
154    ///
155    /// # Examples
156    ///
157    /// ```python
158    /// float32 = torsh.PyDType("float32")
159    /// print(float32.is_complex)  # False
160    /// ```
161    #[getter]
162    fn is_complex(&self) -> bool {
163        matches!(self.dtype, DType::C64 | DType::C128)
164    }
165
166    /// Check if this is an integer dtype.
167    ///
168    /// # Returns
169    ///
170    /// True if dtype is integer (int8, int16, int32, int64, uint8, etc.), False otherwise
171    ///
172    /// # Examples
173    ///
174    /// ```python
175    /// int32 = torsh.PyDType("int32")
176    /// print(int32.is_integer)  # True
177    ///
178    /// float32 = torsh.PyDType("float32")
179    /// print(float32.is_integer)  # False
180    /// ```
181    #[getter]
182    fn is_integer(&self) -> bool {
183        matches!(
184            self.dtype,
185            DType::I8 | DType::I16 | DType::I32 | DType::I64 | DType::U8 | DType::U32 | DType::U64
186        )
187    }
188
189    /// Get the NumPy-compatible dtype string.
190    ///
191    /// # Returns
192    ///
193    /// NumPy dtype string (e.g., 'float32', 'int64')
194    ///
195    /// # Examples
196    ///
197    /// ```python
198    /// dtype = torsh.PyDType("float32")
199    /// print(dtype.numpy_dtype)  # 'float32'
200    /// ```
201    #[getter]
202    fn numpy_dtype(&self) -> String {
203        match self.dtype {
204            DType::F32 => "float32".to_string(),
205            DType::F64 => "float64".to_string(),
206            DType::F16 => "float16".to_string(),
207            DType::I8 => "int8".to_string(),
208            DType::I16 => "int16".to_string(),
209            DType::I32 => "int32".to_string(),
210            DType::I64 => "int64".to_string(),
211            DType::U8 => "uint8".to_string(),
212            DType::U32 => "uint32".to_string(),
213            DType::U64 => "uint64".to_string(),
214            DType::Bool => "bool".to_string(),
215            DType::C64 => "complex64".to_string(),
216            DType::C128 => "complex128".to_string(),
217            _ => format!("{:?}", self.dtype).to_lowercase(),
218        }
219    }
220
221    /// Check if this dtype can be safely cast to another dtype.
222    ///
223    /// # Arguments
224    ///
225    /// * `other` - Target dtype to check casting compatibility
226    ///
227    /// # Returns
228    ///
229    /// True if safe cast is possible, False otherwise
230    ///
231    /// # Examples
232    ///
233    /// ```python
234    /// int32 = torsh.PyDType("int32")
235    /// int64 = torsh.PyDType("int64")
236    /// float32 = torsh.PyDType("float32")
237    ///
238    /// print(int32.can_cast(int64))    # True (widening)
239    /// print(int64.can_cast(int32))    # False (narrowing)
240    /// print(int32.can_cast(float32))  # True (int to float)
241    /// ```
242    fn can_cast(&self, other: &PyDType) -> bool {
243        // Same type is always safe
244        if self.dtype == other.dtype {
245            return true;
246        }
247
248        // Casting rules based on type promotion
249        match (self.dtype, other.dtype) {
250            // Integer widening is safe
251            (DType::I8, DType::I16 | DType::I32 | DType::I64) => true,
252            (DType::I16, DType::I32 | DType::I64) => true,
253            (DType::I32, DType::I64) => true,
254
255            // Unsigned integer widening is safe
256            (DType::U8, DType::U32 | DType::U64) => true,
257            (DType::U32, DType::U64) => true,
258
259            // Integer to float is generally safe (may lose precision for large integers)
260            (DType::I8 | DType::I16 | DType::I32, DType::F32 | DType::F64) => true,
261            (DType::I64, DType::F64) => true,
262            (DType::U8 | DType::U32, DType::F32 | DType::F64) => true,
263
264            // Float widening is safe
265            (DType::F16, DType::F32 | DType::F64) => true,
266            (DType::F32, DType::F64) => true,
267
268            // Bool can be cast to any numeric type
269            (DType::Bool, DType::I8 | DType::I16 | DType::I32 | DType::I64) => true,
270            (DType::Bool, DType::U8 | DType::U32 | DType::U64) => true,
271            (DType::Bool, DType::F16 | DType::F32 | DType::F64) => true,
272
273            // Float to complex
274            (DType::F32, DType::C64 | DType::C128) => true,
275            (DType::F64, DType::C128) => true,
276
277            // Complex widening
278            (DType::C64, DType::C128) => true,
279
280            // Everything else is not safe
281            _ => false,
282        }
283    }
284}
285
286impl From<DType> for PyDType {
287    fn from(dtype: DType) -> Self {
288        Self { dtype }
289    }
290}
291
292impl From<PyDType> for DType {
293    fn from(py_dtype: PyDType) -> Self {
294        py_dtype.dtype
295    }
296}
297
298impl std::fmt::Display for PyDType {
299    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300        write!(f, "{}", self.__str__())
301    }
302}
303
304/// Register dtype constants and utility functions with the module.
305///
306/// This function adds:
307/// - Dtype constants (float32, int64, etc.)
308/// - PyTorch-style aliases (float, double, long, etc.)
309/// - Utility functions for dtype operations
310pub fn register_dtype_constants(m: &Bound<'_, PyModule>) -> PyResult<()> {
311    use pyo3::wrap_pyfunction;
312
313    // Create dtype constants similar to PyTorch
314    m.add("float32", PyDType { dtype: DType::F32 })?;
315    m.add("float64", PyDType { dtype: DType::F64 })?;
316    m.add("int8", PyDType { dtype: DType::I8 })?;
317    m.add("int16", PyDType { dtype: DType::I16 })?;
318    m.add("int32", PyDType { dtype: DType::I32 })?;
319    m.add("int64", PyDType { dtype: DType::I64 })?;
320    m.add("uint8", PyDType { dtype: DType::U8 })?;
321    m.add("uint32", PyDType { dtype: DType::U32 })?;
322    m.add("uint64", PyDType { dtype: DType::U64 })?;
323    m.add("bool", PyDType { dtype: DType::Bool })?;
324
325    // PyTorch-style aliases
326    m.add("float", PyDType { dtype: DType::F32 })?;
327    m.add("double", PyDType { dtype: DType::F64 })?;
328    m.add("long", PyDType { dtype: DType::I64 })?;
329    m.add("int", PyDType { dtype: DType::I32 })?;
330    m.add("short", PyDType { dtype: DType::I16 })?;
331    m.add("char", PyDType { dtype: DType::I8 })?;
332    m.add("byte", PyDType { dtype: DType::U8 })?;
333
334    // Utility functions
335    /// Promote two dtypes to a common dtype for operations.
336    ///
337    /// # Arguments
338    ///
339    /// * `dtype1` - First dtype
340    /// * `dtype2` - Second dtype
341    ///
342    /// # Returns
343    ///
344    /// Promoted dtype that can safely represent both inputs
345    ///
346    /// # Examples
347    ///
348    /// ```python
349    /// result = torsh.promote_types(torsh.int32, torsh.float32)
350    /// print(result)  # float32
351    ///
352    /// result = torsh.promote_types(torsh.int32, torsh.int64)
353    /// print(result)  # int64
354    /// ```
355    #[pyfunction]
356    fn promote_types(dtype1: &PyDType, dtype2: &PyDType) -> PyDType {
357        use DType::*;
358
359        // If same type, return it
360        if dtype1.dtype == dtype2.dtype {
361            return dtype1.clone();
362        }
363
364        // Type promotion rules (similar to NumPy/PyTorch)
365        let promoted = match (dtype1.dtype, dtype2.dtype) {
366            // Bool promotes to anything else
367            (Bool, other) | (other, Bool) => other,
368
369            // Complex types take precedence
370            (C128, _) | (_, C128) => C128,
371            (C64, _) | (_, C64) => C64,
372
373            // Float promotion
374            (F64, _) | (_, F64) => F64,
375            (F32, _) | (_, F32) => F32,
376            (F16, _) | (_, F16) => F16,
377
378            // Integer promotion - signed takes precedence, larger size wins
379            (I64, I8 | I16 | I32 | U8 | U32 | U64) | (I8 | I16 | I32 | U8 | U32 | U64, I64) => I64,
380            (I32, I8 | I16 | U8) | (I8 | I16 | U8, I32) => I32,
381            (I16, I8 | U8) | (I8 | U8, I16) => I16,
382
383            // Unsigned integer promotion
384            (U64, U8 | U32) | (U8 | U32, U64) => U64,
385            (U32, U8) | (U8, U32) => U32,
386
387            // Default to the larger type
388            (a, b) => {
389                let size_a = dtype1.itemsize();
390                let size_b = dtype2.itemsize();
391                if size_a >= size_b {
392                    a
393                } else {
394                    b
395                }
396            }
397        };
398
399        PyDType { dtype: promoted }
400    }
401
402    /// Get the result dtype for a binary operation between two dtypes.
403    ///
404    /// # Arguments
405    ///
406    /// * `dtype1` - First operand dtype
407    /// * `dtype2` - Second operand dtype
408    ///
409    /// # Returns
410    ///
411    /// Result dtype for the operation
412    ///
413    /// # Examples
414    ///
415    /// ```python
416    /// result = torsh.result_type(torsh.int32, torsh.float32)
417    /// print(result)  # float32
418    /// ```
419    #[pyfunction]
420    fn result_type(dtype1: &PyDType, dtype2: &PyDType) -> PyDType {
421        // For now, result_type is the same as promote_types
422        // In the future, this could have different rules for specific operations
423        promote_types(dtype1, dtype2)
424    }
425
426    /// Check if two dtypes are compatible for operations.
427    ///
428    /// # Arguments
429    ///
430    /// * `dtype1` - First dtype
431    /// * `dtype2` - Second dtype
432    ///
433    /// # Returns
434    ///
435    /// True if dtypes can be used together in operations
436    ///
437    /// # Examples
438    ///
439    /// ```python
440    /// print(torsh.can_operate(torsh.int32, torsh.float32))  # True
441    /// print(torsh.can_operate(torsh.bool, torsh.int32))     # True
442    /// ```
443    #[pyfunction]
444    fn can_operate(_dtype1: &PyDType, _dtype2: &PyDType) -> bool {
445        // Most dtypes can operate together (via promotion)
446        // Only complex and non-numeric types might be incompatible
447        true
448    }
449
450    m.add_function(wrap_pyfunction!(promote_types, m)?)?;
451    m.add_function(wrap_pyfunction!(result_type, m)?)?;
452    m.add_function(wrap_pyfunction!(can_operate, m)?)?;
453
454    Ok(())
455}