/// Tag identifying how the elements of a buffer are encoded.
///
/// The enum is `#[repr(u8)]` and every variant carries an explicit
/// discriminant, so `variant as u8` yields the raw code listed on each
/// variant and `DataType::try_from(raw)` maps that code back to the variant.
/// Do not renumber existing variants: the `TryFrom<u8>` impl hard-codes the
/// same numbers.
///
/// NOTE(review): the precise encoding behind each name (in particular what
/// `PQ` and `Custom` payloads look like) is not defined in this file; confirm
/// against the code that consumes these tags.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum DataType {
    /// Raw code `0`; 32 bits per element.
    F32 = 0,
    /// Raw code `1`; 16 bits per element.
    F16 = 1,
    /// Raw code `2`; 16 bits per element.
    BF16 = 2,
    /// Raw code `3`; 8 bits per element.
    I8 = 3,
    /// Raw code `4`; 8 bits per element.
    U8 = 4,
    /// Raw code `5`; 4 bits per element.
    I4 = 5,
    /// Raw code `6`; 1 bit per element.
    Binary = 6,
    /// Raw code `7`; no fixed per-element width (`bits_per_element` is `None`).
    PQ = 7,
    /// Raw code `8`; no fixed per-element width (`bits_per_element` is `None`).
    Custom = 8,
}
27
28impl DataType {
29 pub const fn bits_per_element(self) -> Option<u32> {
31 match self {
32 Self::F32 => Some(32),
33 Self::F16 => Some(16),
34 Self::BF16 => Some(16),
35 Self::I8 => Some(8),
36 Self::U8 => Some(8),
37 Self::I4 => Some(4),
38 Self::Binary => Some(1),
39 Self::PQ | Self::Custom => None,
40 }
41 }
42}
43
44impl TryFrom<u8> for DataType {
45 type Error = u8;
46
47 fn try_from(value: u8) -> Result<Self, Self::Error> {
48 match value {
49 0 => Ok(Self::F32),
50 1 => Ok(Self::F16),
51 2 => Ok(Self::BF16),
52 3 => Ok(Self::I8),
53 4 => Ok(Self::U8),
54 5 => Ok(Self::I4),
55 6 => Ok(Self::Binary),
56 7 => Ok(Self::PQ),
57 8 => Ok(Self::Custom),
58 other => Err(other),
59 }
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Every assigned raw code decodes to a variant whose discriminant is that code.
    #[test]
    fn round_trip() {
        for raw in 0..=8u8 {
            let dt = DataType::try_from(raw).expect("codes 0..=8 are all assigned");
            assert_eq!(dt as u8, raw);
        }
    }

    /// Every unassigned code is rejected and handed back unchanged.
    #[test]
    fn invalid_value() {
        // Sweep the whole unassigned range rather than spot-checking 9 and 255.
        for raw in 9..=u8::MAX {
            assert_eq!(DataType::try_from(raw), Err(raw));
        }
    }

    /// `bits_per_element` is pinned for all nine variants, not just a sample.
    #[test]
    fn bits_per_element() {
        let expected = [
            (DataType::F32, Some(32)),
            (DataType::F16, Some(16)),
            (DataType::BF16, Some(16)),
            (DataType::I8, Some(8)),
            (DataType::U8, Some(8)),
            (DataType::I4, Some(4)),
            (DataType::Binary, Some(1)),
            (DataType::PQ, None),
            (DataType::Custom, None),
        ];
        for &(dt, bits) in &expected {
            assert_eq!(dt.bits_per_element(), bits, "wrong width for {:?}", dt);
        }
    }
}