1use crate::{bail, CpuStorageRef, CpuStorageRefMut, Result};
2use half::{bf16, f16};
3
/// Scalar element types known to this crate.
///
/// Variant declaration order is significant: the derived `PartialOrd`/`Ord`
/// compare variants in the order they are listed here.
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum DType {
    /// 16-bit brain floating point (`half::bf16`).
    BF16,
    /// 16-bit IEEE half precision (`half::f16`).
    F16,
    /// 32-bit IEEE single precision (`f32`).
    F32,
    /// 32-bit signed integer (`i32`).
    I32,
    /// 64-bit signed integer (`i64`).
    I64,
}
12
/// Error returned when a string does not name a known dtype.
///
/// Wraps the rejected input verbatim so it can be echoed back to the user.
#[derive(Debug, Eq, PartialEq)]
pub struct DTypeParseError(String);

impl std::fmt::Display for DTypeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let DTypeParseError(rejected) = self;
        write!(f, "cannot parse '{}' as a dtype", rejected)
    }
}

// Marker impl: the error has no underlying source.
impl std::error::Error for DTypeParseError {}
23
24impl std::str::FromStr for DType {
25 type Err = DTypeParseError;
26 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
27 match s {
28 "i32" => Ok(Self::I32),
29 "i64" => Ok(Self::I64),
30 "bf16" => Ok(Self::BF16),
31 "f16" => Ok(Self::F16),
32 "f32" => Ok(Self::F32),
33 _ => Err(DTypeParseError(s.to_string())),
34 }
35 }
36}
37
38impl DType {
39 pub fn as_str(&self) -> &'static str {
41 match self {
42 Self::I32 => "i32",
43 Self::I64 => "i64",
44 Self::BF16 => "bf16",
45 Self::F16 => "f16",
46 Self::F32 => "f32",
47 }
48 }
49
50 pub fn size_in_bytes(&self) -> usize {
52 match self {
53 Self::I32 => 4,
54 Self::I64 => 8,
55 Self::BF16 => 2,
56 Self::F16 => 2,
57 Self::F32 => 4,
58 }
59 }
60
61 pub fn is_int(&self) -> bool {
62 match self {
63 Self::I32 | Self::I64 => true,
64 Self::BF16 | Self::F16 | Self::F32 => false,
65 }
66 }
67
68 pub fn is_float(&self) -> bool {
69 match self {
70 Self::I32 | Self::I64 => false,
71 Self::BF16 | Self::F16 | Self::F32 => true,
72 }
73 }
74}
75
76pub trait WithDType: Copy + Clone + 'static + num::Zero + num::One {
77 const DTYPE: DType;
78
79 fn to_cpu_storage(data: &[Self]) -> CpuStorageRef<'_>;
80 fn from_cpu_storage(data: CpuStorageRef<'_>) -> Result<&[Self]>;
81 fn to_cpu_storage_mut(data: &mut [Self]) -> CpuStorageRefMut<'_>;
82 fn from_cpu_storage_mut(data: CpuStorageRefMut<'_>) -> Result<&mut [Self]>;
83}
84
/// Implements `WithDType` for the scalar type `$ty`.
///
/// `$dtype` must name a variant that exists under the same identifier in
/// `DType`, `CpuStorageRef` and `CpuStorageRefMut`. The generated methods wrap
/// a slice in that variant, or unwrap it and report a mismatch via `bail!`.
macro_rules! with_dtype {
    ($ty:ty, $dtype:ident) => {
        impl WithDType for $ty {
            const DTYPE: DType = DType::$dtype;

            fn to_cpu_storage(data: &[Self]) -> CpuStorageRef<'_> {
                CpuStorageRef::$dtype(data)
            }

            fn from_cpu_storage(data: CpuStorageRef<'_>) -> Result<&[Self]> {
                match data {
                    CpuStorageRef::$dtype(values) => Ok(values),
                    // Any other variant stores a different element type.
                    other => {
                        bail!(
                            "unexpected dtype, expected {:?}, got {:?}",
                            Self::DTYPE,
                            other.dtype()
                        )
                    }
                }
            }

            fn to_cpu_storage_mut(data: &mut [Self]) -> CpuStorageRefMut<'_> {
                CpuStorageRefMut::$dtype(data)
            }

            fn from_cpu_storage_mut(data: CpuStorageRefMut<'_>) -> Result<&mut [Self]> {
                match data {
                    CpuStorageRefMut::$dtype(values) => Ok(values),
                    // Any other variant stores a different element type.
                    other => {
                        bail!(
                            "unexpected dtype, expected {:?}, got {:?}",
                            Self::DTYPE,
                            other.dtype()
                        )
                    }
                }
            }
        }
    };
}
125with_dtype!(bf16, BF16);
126with_dtype!(f16, F16);
127with_dtype!(f32, F32);
128with_dtype!(i32, I32);
129with_dtype!(i64, I64);