Skip to main content

rvf_types/
data_type.rs

//! Vector data type discriminator.

/// Numeric encoding used for the elements of a vector.
///
/// The type is `#[repr(u8)]`. Each variant carries an explicit discriminant,
/// and `DataType::try_from(u8)` maps that same byte back to the variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[repr(u8)]
pub enum DataType {
    /// IEEE 754 single precision (32-bit float).
    F32 = 0,
    /// IEEE 754 half precision (16-bit float).
    F16 = 1,
    /// bfloat16: the 16-bit "brain" floating-point format.
    BF16 = 2,
    /// Signed 8-bit integer, as used for scalar-quantized data.
    I8 = 3,
    /// Unsigned 8-bit integer.
    U8 = 4,
    /// 4-bit integer; two elements are packed into each byte.
    I4 = 5,
    /// Single-bit binary value; eight elements are packed into each byte.
    Binary = 6,
    /// Codes produced by product quantization.
    PQ = 7,
    /// Encoding that is described elsewhere (refer to QUANT_SEG).
    Custom = 8,
}
28impl DataType {
29    /// Returns the number of bits per element, or `None` for variable-width types.
30    pub const fn bits_per_element(self) -> Option<u32> {
31        match self {
32            Self::F32 => Some(32),
33            Self::F16 => Some(16),
34            Self::BF16 => Some(16),
35            Self::I8 => Some(8),
36            Self::U8 => Some(8),
37            Self::I4 => Some(4),
38            Self::Binary => Some(1),
39            Self::PQ | Self::Custom => None,
40        }
41    }
42}
44impl TryFrom<u8> for DataType {
45    type Error = u8;
46
47    fn try_from(value: u8) -> Result<Self, Self::Error> {
48        match value {
49            0 => Ok(Self::F32),
50            1 => Ok(Self::F16),
51            2 => Ok(Self::BF16),
52            3 => Ok(Self::I8),
53            4 => Ok(Self::U8),
54            5 => Ok(Self::I4),
55            6 => Ok(Self::Binary),
56            7 => Ok(Self::PQ),
57            8 => Ok(Self::Custom),
58            other => Err(other),
59        }
60    }
61}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with the discriminant byte it must encode to.
    const ALL: [(DataType, u8); 9] = [
        (DataType::F32, 0),
        (DataType::F16, 1),
        (DataType::BF16, 2),
        (DataType::I8, 3),
        (DataType::U8, 4),
        (DataType::I4, 5),
        (DataType::Binary, 6),
        (DataType::PQ, 7),
        (DataType::Custom, 8),
    ];

    #[test]
    fn round_trip() {
        for raw in 0..=8u8 {
            let dt = DataType::try_from(raw).unwrap();
            assert_eq!(dt as u8, raw);
        }
    }

    #[test]
    fn discriminants_match_variants() {
        // Pins each variant to its byte, so reordering or renumbering the
        // enum is caught even if the decoding stays internally consistent.
        for &(dt, raw) in ALL.iter() {
            assert_eq!(dt as u8, raw);
            assert_eq!(DataType::try_from(raw), Ok(dt));
        }
    }

    #[test]
    fn invalid_value() {
        // Every byte past the last discriminant must be rejected and echoed back.
        for raw in 9..=u8::MAX {
            assert_eq!(DataType::try_from(raw), Err(raw));
        }
    }

    #[test]
    fn bits_per_element() {
        assert_eq!(DataType::F32.bits_per_element(), Some(32));
        assert_eq!(DataType::F16.bits_per_element(), Some(16));
        assert_eq!(DataType::BF16.bits_per_element(), Some(16));
        assert_eq!(DataType::I8.bits_per_element(), Some(8));
        assert_eq!(DataType::U8.bits_per_element(), Some(8));
        assert_eq!(DataType::I4.bits_per_element(), Some(4));
        assert_eq!(DataType::Binary.bits_per_element(), Some(1));
        assert_eq!(DataType::PQ.bits_per_element(), None);
        assert_eq!(DataType::Custom.bits_per_element(), None);
    }
}