1use std::fmt::Debug;
2
3use num_enum::{IntoPrimitive, TryFromPrimitive};
4use singe_core::{impl_enum_conversion, impl_enum_display};
5use singe_cuda_sys::library_types::cudaDataType_t;
6
7use crate::types::{
8 Complex32, Complex64, bf16, f4e2m1, f6e2m3, f6e3m2, f8e4m3, f8e5m2, f8ue8m0, f16,
9};
10
/// Element data types, mirroring the `cudaDataType_t` constants from
/// `singe_cuda_sys::library_types`.
///
/// Each variant's discriminant is the numeric value of the `cudaDataType_t`
/// constant named in its documentation, and the enum is `#[repr(u32)]`, so the
/// derived `TryFromPrimitive` / `IntoPrimitive` impls convert to and from the raw
/// `u32` value.
///
/// The enum is `#[non_exhaustive]`: matches written outside this crate need a
/// wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, TryFromPrimitive, IntoPrimitive)]
#[repr(u32)]
pub enum DataType {
    /// Real 16-bit float (`CUDA_R_16F`).
    F16 = cudaDataType_t::CUDA_R_16F as _,
    /// Complex number with 16-bit float components (`CUDA_C_16F`).
    ComplexF16 = cudaDataType_t::CUDA_C_16F as _,
    /// Real 16-bit bfloat (`CUDA_R_16BF`).
    Bf16 = cudaDataType_t::CUDA_R_16BF as _,
    /// Complex number with 16-bit bfloat components (`CUDA_C_16BF`).
    ComplexBf16 = cudaDataType_t::CUDA_C_16BF as _,
    /// Real 32-bit float (`CUDA_R_32F`).
    F32 = cudaDataType_t::CUDA_R_32F as _,
    /// Complex number with 32-bit float components (`CUDA_C_32F`).
    ComplexF32 = cudaDataType_t::CUDA_C_32F as _,
    /// Real 64-bit float (`CUDA_R_64F`).
    F64 = cudaDataType_t::CUDA_R_64F as _,
    /// Complex number with 64-bit float components (`CUDA_C_64F`).
    ComplexF64 = cudaDataType_t::CUDA_C_64F as _,
    /// Real 4-bit signed integer (`CUDA_R_4I`).
    I4 = cudaDataType_t::CUDA_R_4I as _,
    /// Complex number with 4-bit signed integer components (`CUDA_C_4I`).
    ComplexI4 = cudaDataType_t::CUDA_C_4I as _,
    /// Real 4-bit unsigned integer (`CUDA_R_4U`).
    U4 = cudaDataType_t::CUDA_R_4U as _,
    /// Complex number with 4-bit unsigned integer components (`CUDA_C_4U`).
    ComplexU4 = cudaDataType_t::CUDA_C_4U as _,
    /// Real 8-bit signed integer (`CUDA_R_8I`).
    I8 = cudaDataType_t::CUDA_R_8I as _,
    /// Complex number with 8-bit signed integer components (`CUDA_C_8I`).
    ComplexI8 = cudaDataType_t::CUDA_C_8I as _,
    /// Real 8-bit unsigned integer (`CUDA_R_8U`).
    U8 = cudaDataType_t::CUDA_R_8U as _,
    /// Complex number with 8-bit unsigned integer components (`CUDA_C_8U`).
    ComplexU8 = cudaDataType_t::CUDA_C_8U as _,
    /// Real 16-bit signed integer (`CUDA_R_16I`).
    I16 = cudaDataType_t::CUDA_R_16I as _,
    /// Complex number with 16-bit signed integer components (`CUDA_C_16I`).
    ComplexI16 = cudaDataType_t::CUDA_C_16I as _,
    /// Real 16-bit unsigned integer (`CUDA_R_16U`).
    U16 = cudaDataType_t::CUDA_R_16U as _,
    /// Complex number with 16-bit unsigned integer components (`CUDA_C_16U`).
    ComplexU16 = cudaDataType_t::CUDA_C_16U as _,
    /// Real 32-bit signed integer (`CUDA_R_32I`).
    I32 = cudaDataType_t::CUDA_R_32I as _,
    /// Complex number with 32-bit signed integer components (`CUDA_C_32I`).
    ComplexI32 = cudaDataType_t::CUDA_C_32I as _,
    /// Real 32-bit unsigned integer (`CUDA_R_32U`).
    U32 = cudaDataType_t::CUDA_R_32U as _,
    /// Complex number with 32-bit unsigned integer components (`CUDA_C_32U`).
    ComplexU32 = cudaDataType_t::CUDA_C_32U as _,
    /// Real 64-bit signed integer (`CUDA_R_64I`).
    I64 = cudaDataType_t::CUDA_R_64I as _,
    /// Complex number with 64-bit signed integer components (`CUDA_C_64I`).
    ComplexI64 = cudaDataType_t::CUDA_C_64I as _,
    /// Real 64-bit unsigned integer (`CUDA_R_64U`).
    U64 = cudaDataType_t::CUDA_R_64U as _,
    /// Complex number with 64-bit unsigned integer components (`CUDA_C_64U`).
    ComplexU64 = cudaDataType_t::CUDA_C_64U as _,
    /// Real 8-bit float in the E4M3 format (`CUDA_R_8F_E4M3`).
    F8E4M3 = cudaDataType_t::CUDA_R_8F_E4M3 as _,
    /// Real 8-bit float in the E5M2 format (`CUDA_R_8F_E5M2`).
    F8E5M2 = cudaDataType_t::CUDA_R_8F_E5M2 as _,
    /// Real 8-bit float in the UE8M0 format (`CUDA_R_8F_UE8M0`).
    F8UE8M0 = cudaDataType_t::CUDA_R_8F_UE8M0 as _,
    /// Real 6-bit float in the E2M3 format (`CUDA_R_6F_E2M3`).
    F6E2M3 = cudaDataType_t::CUDA_R_6F_E2M3 as _,
    /// Real 6-bit float in the E3M2 format (`CUDA_R_6F_E3M2`).
    F6E3M2 = cudaDataType_t::CUDA_R_6F_E3M2 as _,
    /// Real 4-bit float in the E2M1 format (`CUDA_R_4F_E2M1`).
    F4E2M1 = cudaDataType_t::CUDA_R_4F_E2M1 as _,
}
85
// Ties `DataType` to the raw `cudaDataType_t` FFI enum through `singe_core`'s
// `impl_enum_conversion!`. The exact impls generated are defined by that macro
// (presumably conversions in both directions — confirm in `singe_core`).
impl_enum_conversion!(DataType, cudaDataType_t);
87
88impl DataType {
89 pub const fn size_of(self) -> usize {
90 match self {
91 Self::F16 | Self::Bf16 | Self::I16 | Self::U16 => 2,
92 Self::ComplexF16
93 | Self::ComplexBf16
94 | Self::F32
95 | Self::I32
96 | Self::U32
97 | Self::I8
98 | Self::U8 => 4,
99 Self::ComplexF32 | Self::F64 | Self::I64 | Self::U64 => 8,
100 Self::ComplexF64 => 16,
101 Self::I4 | Self::U4 | Self::F4E2M1 => 1,
102 Self::ComplexI4 | Self::ComplexU4 => 1,
103 Self::F8E4M3 | Self::F8E5M2 | Self::F8UE8M0 => 1,
104 Self::F6E2M3 | Self::F6E3M2 => 1,
105 Self::ComplexI8 | Self::ComplexU8 => 2,
106 Self::ComplexI16 | Self::ComplexU16 => 4,
107 Self::ComplexI32 | Self::ComplexU32 => 8,
108 Self::ComplexI64 | Self::ComplexU64 => 16,
109 }
110 }
111}
112
// Textual name for every `DataType` variant, handed to `singe_core`'s
// `impl_enum_display!` (presumably to implement `Display` — confirm in
// `singe_core`). Each string is the name of the `cudaDataType_t` constant that
// the variant's discriminant is taken from.
impl_enum_display!(DataType, {
    Self::F16 => "CUDA_R_16F",
    Self::ComplexF16 => "CUDA_C_16F",
    Self::Bf16 => "CUDA_R_16BF",
    Self::ComplexBf16 => "CUDA_C_16BF",
    Self::F32 => "CUDA_R_32F",
    Self::ComplexF32 => "CUDA_C_32F",
    Self::F64 => "CUDA_R_64F",
    Self::ComplexF64 => "CUDA_C_64F",
    Self::I4 => "CUDA_R_4I",
    Self::ComplexI4 => "CUDA_C_4I",
    Self::U4 => "CUDA_R_4U",
    Self::ComplexU4 => "CUDA_C_4U",
    Self::I8 => "CUDA_R_8I",
    Self::ComplexI8 => "CUDA_C_8I",
    Self::U8 => "CUDA_R_8U",
    Self::ComplexU8 => "CUDA_C_8U",
    Self::I16 => "CUDA_R_16I",
    Self::ComplexI16 => "CUDA_C_16I",
    Self::U16 => "CUDA_R_16U",
    Self::ComplexU16 => "CUDA_C_16U",
    Self::I32 => "CUDA_R_32I",
    Self::ComplexI32 => "CUDA_C_32I",
    Self::U32 => "CUDA_R_32U",
    Self::ComplexU32 => "CUDA_C_32U",
    Self::I64 => "CUDA_R_64I",
    Self::ComplexI64 => "CUDA_C_64I",
    Self::U64 => "CUDA_R_64U",
    Self::ComplexU64 => "CUDA_C_64U",
    Self::F8E4M3 => "CUDA_R_8F_E4M3",
    Self::F8E5M2 => "CUDA_R_8F_E5M2",
    Self::F8UE8M0 => "CUDA_R_8F_UE8M0",
    Self::F6E2M3 => "CUDA_R_6F_E2M3",
    Self::F6E3M2 => "CUDA_R_6F_E3M2",
    Self::F4E2M1 => "CUDA_R_4F_E2M1",
});
149
/// A Rust element type that has an associated [`DataType`].
///
/// All methods are associated functions (no `self`), so they are queried on
/// the type itself, e.g. `f32::data_type()`. The implementations in this file
/// are generated by the `impl_data_type!` macro.
pub trait DataTypeLike: Clone + Copy + Default + Debug + 'static {
    /// Returns the [`DataType`] variant associated with the implementing type.
    fn data_type() -> DataType;

    /// Returns `true` when the implementing type is a complex-valued element
    /// type, `false` otherwise.
    fn is_complex() -> bool;

    /// Returns the name of the implementing Rust type as a static string.
    fn rust_type_name() -> &'static str;
}
157
// Generates a `DataTypeLike` impl for a single Rust element type.
//
// Arguments, in order:
//   1. the Rust type receiving the impl,
//   2. the `DataType` variant returned by `data_type()`,
//   3. the boolean expression returned by `is_complex()`.
//
// `rust_type_name()` yields the first argument exactly as it was spelled at the
// invocation site (via `stringify!`).
macro_rules! impl_data_type {
    ($rust_ty:ty, $variant:ident, $complex:expr) => {
        impl DataTypeLike for $rust_ty {
            fn data_type() -> DataType {
                DataType::$variant
            }

            fn is_complex() -> bool {
                $complex
            }

            fn rust_type_name() -> &'static str {
                stringify!($rust_ty)
            }
        }
    };
}
176
// `DataTypeLike` impls for the supported element types. Each line gives the
// Rust type, its `DataType` variant, and the value `is_complex()` returns.

// Floating-point types, including the low-precision types from `crate::types`.
impl_data_type!(f32, F32, false);
impl_data_type!(f64, F64, false);
impl_data_type!(f16, F16, false);
impl_data_type!(bf16, Bf16, false);
impl_data_type!(f8e4m3, F8E4M3, false);
impl_data_type!(f8e5m2, F8E5M2, false);
impl_data_type!(f8ue8m0, F8UE8M0, false);
impl_data_type!(f6e2m3, F6E2M3, false);
impl_data_type!(f6e3m2, F6E3M2, false);
impl_data_type!(f4e2m1, F4E2M1, false);
// Integer types.
impl_data_type!(i8, I8, false);
impl_data_type!(u8, U8, false);
impl_data_type!(i32, I32, false);
impl_data_type!(u32, U32, false);
// Complex types.
impl_data_type!(Complex32, ComplexF32, true);
impl_data_type!(Complex64, ComplexF64, true);
193
#[cfg(test)]
mod tests {
    use super::{
        DataType, DataTypeLike, bf16, f4e2m1, f6e2m3, f6e3m2, f8e4m3, f8e5m2, f8ue8m0, f16,
    };

    // Compile-level smoke test: every low-precision type must be reachable
    // through `super` and expose the constructor used below.
    #[test]
    fn test_low_precision_module_reexports_expected_types() {
        let _half = f16::from_f32(1.0);
        let _brain = bf16::from_f32(1.0);
        let _fp8_e4m3 = f8e4m3::from_bits(0);
        let _fp8_e5m2 = f8e5m2::from_bits(0);
        let _fp8_ue8m0 = f8ue8m0::from_bits(0);
        let _fp6_e2m3 = f6e2m3::from_bits(0);
        let _fp6_e3m2 = f6e3m2::from_bits(0);
        let _fp4_e2m1 = f4e2m1::from_bits(0);
    }

    // Each low-precision type must report its matching `DataType` variant.
    #[test]
    fn test_data_type_like_maps_low_precision_storage_types() {
        let cases = [
            ("f16", f16::data_type(), DataType::F16),
            ("bf16", bf16::data_type(), DataType::Bf16),
            ("f8e4m3", f8e4m3::data_type(), DataType::F8E4M3),
            ("f8e5m2", f8e5m2::data_type(), DataType::F8E5M2),
            ("f8ue8m0", f8ue8m0::data_type(), DataType::F8UE8M0),
            ("f6e2m3", f6e2m3::data_type(), DataType::F6E2M3),
            ("f6e3m2", f6e3m2::data_type(), DataType::F6E3M2),
            ("f4e2m1", f4e2m1::data_type(), DataType::F4E2M1),
        ];

        for &(type_name, actual, expected) in &cases {
            assert_eq!(actual, expected, "unexpected DataType for {}", type_name);
        }
    }
}