// singe_cuda/data_type.rs

1use std::fmt::Debug;
2
3use num_enum::{IntoPrimitive, TryFromPrimitive};
4use singe_core::{impl_enum_conversion, impl_enum_display};
5use singe_cuda_sys::library_types::cudaDataType_t;
6
7use crate::types::{
8    Complex32, Complex64, bf16, f4e2m1, f6e2m3, f6e3m2, f8e4m3, f8e5m2, f8ue8m0, f16,
9};
10
11/// Rust wrapper for CUDA's data type enum.
12#[non_exhaustive]
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
14#[repr(u32)]
15pub enum DataType {
16    /// 16-bit real half precision floating-point (IEEE 754-2008 binary16).
17    F16 = cudaDataType_t::CUDA_R_16F as _,
18    /// 32-bit complex (2x16-bit half precision floats).
19    ComplexF16 = cudaDataType_t::CUDA_C_16F as _,
20    /// 16-bit real bfloat16 floating-point.
21    Bf16 = cudaDataType_t::CUDA_R_16BF as _,
22    /// 32-bit complex (2x16-bit bfloat16 floats).
23    ComplexBf16 = cudaDataType_t::CUDA_C_16BF as _,
24    /// 32-bit real single precision floating-point (IEEE 754 binary32).
25    F32 = cudaDataType_t::CUDA_R_32F as _,
26    /// 64-bit complex (2x32-bit single precision floats).
27    ComplexF32 = cudaDataType_t::CUDA_C_32F as _,
28    /// 64-bit real double precision floating-point (IEEE 754 binary64).
29    F64 = cudaDataType_t::CUDA_R_64F as _,
30    /// 128-bit complex (2x64-bit double precision floats).
31    ComplexF64 = cudaDataType_t::CUDA_C_64F as _,
32    /// 4-bit real signed integer.
33    I4 = cudaDataType_t::CUDA_R_4I as _,
34    /// 8-bit complex (2x4-bit signed integers).
35    ComplexI4 = cudaDataType_t::CUDA_C_4I as _,
36    /// 4-bit real unsigned integer.
37    U4 = cudaDataType_t::CUDA_R_4U as _,
38    /// 8-bit complex (2x4-bit unsigned integers).
39    ComplexU4 = cudaDataType_t::CUDA_C_4U as _,
40    /// 8-bit real signed integer.
41    I8 = cudaDataType_t::CUDA_R_8I as _,
42    /// 16-bit complex (2x8-bit signed integers).
43    ComplexI8 = cudaDataType_t::CUDA_C_8I as _,
44    /// 8-bit real unsigned integer.
45    U8 = cudaDataType_t::CUDA_R_8U as _,
46    /// 16-bit complex (2x8-bit unsigned integers).
47    ComplexU8 = cudaDataType_t::CUDA_C_8U as _,
48    /// 16-bit real signed integer.
49    I16 = cudaDataType_t::CUDA_R_16I as _,
50    /// 32-bit complex (2x16-bit signed integers).
51    ComplexI16 = cudaDataType_t::CUDA_C_16I as _,
52    /// 16-bit real unsigned integer.
53    U16 = cudaDataType_t::CUDA_R_16U as _,
54    /// 32-bit complex (2x16-bit unsigned integers).
55    ComplexU16 = cudaDataType_t::CUDA_C_16U as _,
56    /// 32-bit real signed integer.
57    I32 = cudaDataType_t::CUDA_R_32I as _,
58    /// 64-bit complex (2x32-bit signed integers).
59    ComplexI32 = cudaDataType_t::CUDA_C_32I as _,
60    /// 32-bit real unsigned integer.
61    U32 = cudaDataType_t::CUDA_R_32U as _,
62    /// 64-bit complex (2x32-bit unsigned integers).
63    ComplexU32 = cudaDataType_t::CUDA_C_32U as _,
64    /// 64-bit real signed integer.
65    I64 = cudaDataType_t::CUDA_R_64I as _,
66    /// 128-bit complex (2x64-bit signed integers).
67    ComplexI64 = cudaDataType_t::CUDA_C_64I as _,
68    /// 64-bit real unsigned integer.
69    U64 = cudaDataType_t::CUDA_R_64U as _,
70    /// 128-bit complex (2x64-bit unsigned integers).
71    ComplexU64 = cudaDataType_t::CUDA_C_64U as _,
72    /// 8-bit real floating point in E4M3 format.
73    F8E4M3 = cudaDataType_t::CUDA_R_8F_E4M3 as _,
74    /// 8-bit real floating point in E5M2 format.
75    F8E5M2 = cudaDataType_t::CUDA_R_8F_E5M2 as _,
76    /// 8-bit real floating point in E8M0 format (unsigned exponent, zero mantissa bits).
77    F8UE8M0 = cudaDataType_t::CUDA_R_8F_UE8M0 as _,
78    /// 6-bit real floating point in E2M3 format (2-bit exponent, 3-bit mantissa).
79    F6E2M3 = cudaDataType_t::CUDA_R_6F_E2M3 as _,
80    /// 6-bit real floating point in E3M2 format (3-bit exponent, 2-bit mantissa).
81    F6E3M2 = cudaDataType_t::CUDA_R_6F_E3M2 as _,
82    /// 4-bit real floating point in E2M1 format (2-bit exponent, 1-bit mantissa).
83    F4E2M1 = cudaDataType_t::CUDA_R_4F_E2M1 as _,
84}
85
86impl_enum_conversion!(DataType, cudaDataType_t);
87
88impl DataType {
89    pub const fn size_of(self) -> usize {
90        match self {
91            Self::F16 | Self::Bf16 | Self::I16 | Self::U16 => 2,
92            Self::ComplexF16
93            | Self::ComplexBf16
94            | Self::F32
95            | Self::I32
96            | Self::U32
97            | Self::I8
98            | Self::U8 => 4,
99            Self::ComplexF32 | Self::F64 | Self::I64 | Self::U64 => 8,
100            Self::ComplexF64 => 16,
101            Self::I4 | Self::U4 | Self::F4E2M1 => 1,
102            Self::ComplexI4 | Self::ComplexU4 => 1,
103            Self::F8E4M3 | Self::F8E5M2 | Self::F8UE8M0 => 1,
104            Self::F6E2M3 | Self::F6E3M2 => 1,
105            Self::ComplexI8 | Self::ComplexU8 => 2,
106            Self::ComplexI16 | Self::ComplexU16 => 4,
107            Self::ComplexI32 | Self::ComplexU32 => 8,
108            Self::ComplexI64 | Self::ComplexU64 => 16,
109        }
110    }
111}
112
113impl_enum_display!(DataType, {
114    Self::F16 => "CUDA_R_16F",
115    Self::ComplexF16 => "CUDA_C_16F",
116    Self::Bf16 => "CUDA_R_16BF",
117    Self::ComplexBf16 => "CUDA_C_16BF",
118    Self::F32 => "CUDA_R_32F",
119    Self::ComplexF32 => "CUDA_C_32F",
120    Self::F64 => "CUDA_R_64F",
121    Self::ComplexF64 => "CUDA_C_64F",
122    Self::I4 => "CUDA_R_4I",
123    Self::ComplexI4 => "CUDA_C_4I",
124    Self::U4 => "CUDA_R_4U",
125    Self::ComplexU4 => "CUDA_C_4U",
126    Self::I8 => "CUDA_R_8I",
127    Self::ComplexI8 => "CUDA_C_8I",
128    Self::U8 => "CUDA_R_8U",
129    Self::ComplexU8 => "CUDA_C_8U",
130    Self::I16 => "CUDA_R_16I",
131    Self::ComplexI16 => "CUDA_C_16I",
132    Self::U16 => "CUDA_R_16U",
133    Self::ComplexU16 => "CUDA_C_16U",
134    Self::I32 => "CUDA_R_32I",
135    Self::ComplexI32 => "CUDA_C_32I",
136    Self::U32 => "CUDA_R_32U",
137    Self::ComplexU32 => "CUDA_C_32U",
138    Self::I64 => "CUDA_R_64I",
139    Self::ComplexI64 => "CUDA_C_64I",
140    Self::U64 => "CUDA_R_64U",
141    Self::ComplexU64 => "CUDA_C_64U",
142    Self::F8E4M3 => "CUDA_R_8F_E4M3",
143    Self::F8E5M2 => "CUDA_R_8F_E5M2",
144    Self::F8UE8M0 => "CUDA_R_8F_UE8M0",
145    Self::F6E2M3 => "CUDA_R_6F_E2M3",
146    Self::F6E3M2 => "CUDA_R_6F_E3M2",
147    Self::F4E2M1 => "CUDA_R_4F_E2M1",
148});
149
150pub trait DataTypeLike: Clone + Copy + Default + Debug + 'static {
151    fn data_type() -> DataType;
152
153    fn is_complex() -> bool;
154
155    fn rust_type_name() -> &'static str;
156}
157
/// Implements [`DataTypeLike`] for one Rust type.
///
/// Arguments, in order: the Rust type, the name of the [`DataType`] variant
/// it maps to, and an expression giving the result of `is_complex()`.
macro_rules! impl_data_type {
    ($elem:ty, $variant:ident, $complex:expr) => {
        impl DataTypeLike for $elem {
            fn data_type() -> DataType {
                DataType::$variant
            }

            fn is_complex() -> bool {
                $complex
            }

            fn rust_type_name() -> &'static str {
                // The name is the type exactly as spelled at the invocation.
                // TODO: `type_name::<T>()`?
                stringify!($elem)
            }
        }
    };
}
176
177impl_data_type!(f32, F32, false);
178impl_data_type!(f64, F64, false);
179impl_data_type!(f16, F16, false);
180impl_data_type!(bf16, Bf16, false);
181impl_data_type!(f8e4m3, F8E4M3, false);
182impl_data_type!(f8e5m2, F8E5M2, false);
183impl_data_type!(f8ue8m0, F8UE8M0, false);
184impl_data_type!(f6e2m3, F6E2M3, false);
185impl_data_type!(f6e3m2, F6E3M2, false);
186impl_data_type!(f4e2m1, F4E2M1, false);
187impl_data_type!(i8, I8, false);
188impl_data_type!(u8, U8, false);
189impl_data_type!(i32, I32, false);
190impl_data_type!(u32, U32, false);
191impl_data_type!(Complex32, ComplexF32, true);
192impl_data_type!(Complex64, ComplexF64, true);
193
#[cfg(test)]
mod tests {
    use super::{
        DataType, DataTypeLike, bf16, f4e2m1, f6e2m3, f6e3m2, f8e4m3, f8e5m2, f8ue8m0, f16,
    };

    /// Looks up the `DataType` tag for `T` through the trait.
    fn tag_of<T: DataTypeLike>() -> DataType {
        T::data_type()
    }

    /// Smoke test: every low-precision type is nameable from this module and
    /// can be constructed; the values themselves are not inspected.
    #[test]
    fn test_low_precision_module_reexports_expected_types() {
        // The 16-bit types are built from an `f32`.
        let _half = f16::from_f32(1.0);
        let _brain = bf16::from_f32(1.0);

        // The narrower types are built from a raw bit pattern.
        let _e4m3 = f8e4m3::from_bits(0);
        let _e5m2 = f8e5m2::from_bits(0);
        let _ue8m0 = f8ue8m0::from_bits(0);
        let _e2m3 = f6e2m3::from_bits(0);
        let _e3m2 = f6e3m2::from_bits(0);
        let _e2m1 = f4e2m1::from_bits(0);
    }

    /// Each low-precision storage type must report its own `DataType` variant.
    #[test]
    fn test_data_type_like_maps_low_precision_storage_types() {
        assert_eq!(tag_of::<f16>(), DataType::F16);
        assert_eq!(tag_of::<bf16>(), DataType::Bf16);
        assert_eq!(tag_of::<f8e4m3>(), DataType::F8E4M3);
        assert_eq!(tag_of::<f8e5m2>(), DataType::F8E5M2);
        assert_eq!(tag_of::<f8ue8m0>(), DataType::F8UE8M0);
        assert_eq!(tag_of::<f6e2m3>(), DataType::F6E2M3);
        assert_eq!(tag_of::<f6e3m2>(), DataType::F6E3M2);
        assert_eq!(tag_of::<f4e2m1>(), DataType::F4E2M1);
    }
}