ug/
dtype.rs

1use crate::{bail, CpuStorageRef, CpuStorageRefMut, Result};
2use half::{bf16, f16};
3
/// Element type of the values stored in a buffer.
///
/// The declaration order of the variants is significant: the derived
/// `PartialOrd`/`Ord` implementations compare variants by their position here.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum DType {
    /// 16-bit brain floating point, backed by `half::bf16`.
    BF16,
    /// 16-bit IEEE 754 half-precision float, backed by `half::f16`.
    F16,
    /// 32-bit IEEE 754 single-precision float.
    F32,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
}
12
/// Error produced when a string is not the name of any known [`DType`].
///
/// The rejected input is kept verbatim so it can be echoed in the message.
#[derive(Debug, PartialEq, Eq)]
pub struct DTypeParseError(String);

impl std::error::Error for DTypeParseError {}

impl std::fmt::Display for DTypeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(input) = self;
        write!(f, "cannot parse '{}' as a dtype", input)
    }
}
23
24impl std::str::FromStr for DType {
25    type Err = DTypeParseError;
26    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
27        match s {
28            "i32" => Ok(Self::I32),
29            "i64" => Ok(Self::I64),
30            "bf16" => Ok(Self::BF16),
31            "f16" => Ok(Self::F16),
32            "f32" => Ok(Self::F32),
33            _ => Err(DTypeParseError(s.to_string())),
34        }
35    }
36}
37
38impl DType {
39    /// String representation for dtypes.
40    pub fn as_str(&self) -> &'static str {
41        match self {
42            Self::I32 => "i32",
43            Self::I64 => "i64",
44            Self::BF16 => "bf16",
45            Self::F16 => "f16",
46            Self::F32 => "f32",
47        }
48    }
49
50    /// The size used by each element in bytes, i.e. 4 for `F32`.
51    pub fn size_in_bytes(&self) -> usize {
52        match self {
53            Self::I32 => 4,
54            Self::I64 => 8,
55            Self::BF16 => 2,
56            Self::F16 => 2,
57            Self::F32 => 4,
58        }
59    }
60
61    pub fn is_int(&self) -> bool {
62        match self {
63            Self::I32 | Self::I64 => true,
64            Self::BF16 | Self::F16 | Self::F32 => false,
65        }
66    }
67
68    pub fn is_float(&self) -> bool {
69        match self {
70            Self::I32 | Self::I64 => false,
71            Self::BF16 | Self::F16 | Self::F32 => true,
72        }
73    }
74}
75
76pub trait WithDType: Copy + Clone + 'static + num::Zero + num::One {
77    const DTYPE: DType;
78
79    fn to_cpu_storage(data: &[Self]) -> CpuStorageRef<'_>;
80    fn from_cpu_storage(data: CpuStorageRef<'_>) -> Result<&[Self]>;
81    fn to_cpu_storage_mut(data: &mut [Self]) -> CpuStorageRefMut<'_>;
82    fn from_cpu_storage_mut(data: CpuStorageRefMut<'_>) -> Result<&mut [Self]>;
83}
84
85macro_rules! with_dtype {
86    ($ty:ty, $dtype:ident) => {
87        impl WithDType for $ty {
88            const DTYPE: DType = DType::$dtype;
89
90            fn to_cpu_storage_mut(data: &mut [Self]) -> CpuStorageRefMut<'_> {
91                CpuStorageRefMut::$dtype(data)
92            }
93
94            fn from_cpu_storage_mut(data: CpuStorageRefMut<'_>) -> Result<&mut [Self]> {
95                match data {
96                    CpuStorageRefMut::$dtype(data) => Ok(data),
97                    _ => {
98                        bail!(
99                            "unexpected dtype, expected {:?}, got {:?}",
100                            Self::DTYPE,
101                            data.dtype()
102                        )
103                    }
104                }
105            }
106            fn to_cpu_storage(data: &[Self]) -> CpuStorageRef<'_> {
107                CpuStorageRef::$dtype(data)
108            }
109
110            fn from_cpu_storage(data: CpuStorageRef<'_>) -> Result<&[Self]> {
111                match data {
112                    CpuStorageRef::$dtype(data) => Ok(data),
113                    _ => {
114                        bail!(
115                            "unexpected dtype, expected {:?}, got {:?}",
116                            Self::DTYPE,
117                            data.dtype()
118                        )
119                    }
120                }
121            }
122        }
123    };
124}
125with_dtype!(bf16, BF16);
126with_dtype!(f16, F16);
127with_dtype!(f32, F32);
128with_dtype!(i32, I32);
129with_dtype!(i64, I64);