// scirs2_linalg/quantization/vector.rs

//! Vector quantization types and implementations
//!
//! This module contains the QuantizedVector struct and QuantizedData1D enum
//! along with their implementations for handling quantized vector data.

use half::{bf16, f16};
use scirs2_core::ndarray::Array1;

use super::types::QuantizedDataType;

11/// A vector with quantized values
/// A vector with quantized values.
///
/// `length` is the logical element count. For 4-bit data types two
/// elements are packed per stored byte, so the backing storage can be
/// roughly half of `length` (see `get_i8`, which indexes with `idx / 2`).
#[derive(Debug, Clone)]
pub struct QuantizedVector {
    /// The quantized values; can be stored in different backing formats
    pub data: QuantizedData1D,

    /// The original (logical) length of the vector
    pub length: usize,

    /// The data type used for quantization
    pub data_type: QuantizedDataType,
}
23
/// Storage for quantized 1D data (vectors) in different formats.
///
/// `Int8` is also used as the packed backing store for 4-bit types
/// (two nibbles per byte); the interpretation is driven by the owning
/// vector's `data_type`.
#[derive(Debug, Clone)]
pub enum QuantizedData1D {
    /// 8-bit integer storage (also backs packed Int4/UInt4 data)
    Int8(Array1<i8>),
    /// 16-bit float storage (IEEE 754 half-precision)
    Float16(Array1<f16>),
    /// 16-bit brain float storage
    BFloat16(Array1<bf16>),
}
34
35impl QuantizedData1D {
36    /// Get the number of elements in the storage
37    pub fn len(&self) -> usize {
38        match self {
39            QuantizedData1D::Int8(arr) => arr.len(),
40            QuantizedData1D::Float16(arr) => arr.len(),
41            QuantizedData1D::BFloat16(arr) => arr.len(),
42        }
43    }
44
45    /// Check if the storage is empty
46    pub fn is_empty(&self) -> bool {
47        self.len() == 0
48    }
49}
50
51/// Helper function to get the i8 data from a QuantizedVector if available
52///
53/// Returns None if the vector does not use Int8 storage
54#[allow(dead_code)]
55pub fn get_quantized_vector_1d_i8(vector: &QuantizedVector) -> Option<&Array1<i8>> {
56    match &vector.data {
57        QuantizedData1D::Int8(data) => Some(data),
58        _ => None,
59    }
60}
61
62impl QuantizedVector {
63    /// Creates a new quantized vector with int8 storage
64    pub fn new_i8(data: Array1<i8>, length: usize, datatype: QuantizedDataType) -> Self {
65        Self {
66            data: QuantizedData1D::Int8(data),
67            length,
68            data_type: datatype,
69        }
70    }
71
72    /// Creates a new f16 quantized vector
73    pub fn new_f16(data: Array1<f16>, length: usize) -> Self {
74        Self {
75            data: QuantizedData1D::Float16(data),
76            length,
77            data_type: QuantizedDataType::Float16,
78        }
79    }
80
81    /// Creates a new bf16 quantized vector
82    pub fn new_bf16(data: Array1<bf16>, length: usize) -> Self {
83        Self {
84            data: QuantizedData1D::BFloat16(data),
85            length,
86            data_type: QuantizedDataType::BFloat16,
87        }
88    }
89
90    /// Creates a standard Int8 quantized vector (for backward compatibility)
91    pub fn from_i8(data: Array1<i8>, length: usize) -> Self {
92        Self {
93            data: QuantizedData1D::Int8(data),
94            length,
95            data_type: QuantizedDataType::Int8,
96        }
97    }
98
99    // This method stays for backward compatibility but will be deprecated in the future
100    // Use get_i8 or get_f32 instead
101    #[deprecated(since = "0.1.0", note = "Use get_i8 or get_f32 instead")]
102    pub fn get(&self, idx: usize) -> i8 {
103        self.get_i8(idx)
104    }
105
106    /// Returns the length of the vector
107    pub fn len(&self) -> usize {
108        self.length
109    }
110
111    /// Returns true if the vector is empty
112    pub fn is_empty(&self) -> bool {
113        self.length == 0
114    }
115
116    /// Get value at specified position as i8 (for int quantization)
117    pub fn get_i8(&self, idx: usize) -> i8 {
118        match &self.data {
119            QuantizedData1D::Int8(arr) => {
120                match self.data_type {
121                    QuantizedDataType::Int8 => arr[idx],
122                    QuantizedDataType::Int4 => {
123                        let byte_idx = idx / 2;
124                        let nibble_idx = idx % 2;
125                        let byte = arr[byte_idx];
126
127                        if nibble_idx == 0 {
128                            // Upper 4 bits (including sign bit)
129                            byte >> 4
130                        } else {
131                            // Lower 4 bits (including sign bit)
132                            byte & 0x0F
133                        }
134                    }
135                    QuantizedDataType::UInt4 => {
136                        let byte_idx = idx / 2;
137                        let nibble_idx = idx % 2;
138                        let byte = arr[byte_idx];
139
140                        if nibble_idx == 0 {
141                            // Upper 4 bits (no sign bit)
142                            (byte >> 4) & 0x0F
143                        } else {
144                            // Lower 4 bits (no sign bit)
145                            byte & 0x0F
146                        }
147                    }
148                    _ => unreachable!(
149                        "Invalid quantization type for Int8 storage: expected Int8, Int4, or UInt4"
150                    ),
151                }
152            }
153            _ => unreachable!("Cannot get i8 value from floating-point quantized vector"),
154        }
155    }
156
157    /// Get value at specified position as f32 (for all quantization types)
158    pub fn get_f32(&self, idx: usize) -> f32 {
159        match &self.data {
160            QuantizedData1D::Int8(arr) => match self.data_type {
161                QuantizedDataType::Int8 => arr[idx] as f32,
162                QuantizedDataType::Int4 => self.get_i8(idx) as f32,
163                QuantizedDataType::UInt4 => self.get_i8(idx) as f32,
164                _ => unreachable!(
165                    "Invalid data type for Int8 storage: expected Int8, Int4, or UInt4"
166                ),
167            },
168            QuantizedData1D::Float16(arr) => arr[idx].to_f32(),
169            QuantizedData1D::BFloat16(arr) => arr[idx].to_f32(),
170        }
171    }
172}