scirs2_linalg/quantization/
matrix.rs

1//! Matrix quantization types and implementations
2//!
3//! This module contains the QuantizedMatrix struct and QuantizedData2D enum
4//! along with their implementations for handling quantized matrix data.
5
6use half::{bf16, f16};
7use scirs2_core::ndarray::{Array1, Array2};
8
9use super::types::QuantizedDataType;
10
11/// A matrix with quantized values
12#[derive(Debug, Clone)]
13pub struct QuantizedMatrix {
14    /// The quantized values can be stored in different formats
15    pub data: QuantizedData2D,
16
17    /// The original shape of the matrix
18    pub shape: (usize, usize),
19
20    /// The data type used for quantization
21    pub data_type: QuantizedDataType,
22}
23
24/// Storage for quantized 2D data (matrices) in different formats
25#[derive(Debug, Clone)]
26pub enum QuantizedData2D {
27    /// 8-bit integer storage
28    Int8(Array2<i8>),
29    /// 16-bit float storage (IEEE 754 half-precision)
30    Float16(Array2<f16>),
31    /// 16-bit brain float storage
32    BFloat16(Array2<bf16>),
33}
34
35impl QuantizedData2D {
36    /// Get the number of elements in the storage
37    pub fn len(&self) -> usize {
38        match self {
39            QuantizedData2D::Int8(arr) => arr.len(),
40            QuantizedData2D::Float16(arr) => arr.len(),
41            QuantizedData2D::BFloat16(arr) => arr.len(),
42        }
43    }
44
45    /// Check if the storage is empty
46    pub fn is_empty(&self) -> bool {
47        self.len() == 0
48    }
49}
50
51/// Helper function to get the i8 data from a QuantizedMatrix if available
52///
53/// Returns None if the matrix does not use Int8 storage
54#[allow(dead_code)]
55pub fn get_quantizedmatrix_2d_i8(matrix: &QuantizedMatrix) -> Option<&Array2<i8>> {
56    match &matrix.data {
57        QuantizedData2D::Int8(data) => Some(data),
58        _ => None,
59    }
60}
61
62impl QuantizedMatrix {
63    /// Creates a new quantized matrix with int8 storage
64    pub fn new_i8(data: Array2<i8>, shape: (usize, usize), data_type: QuantizedDataType) -> Self {
65        Self {
66            data: QuantizedData2D::Int8(data),
67            shape,
68            data_type,
69        }
70    }
71
72    /// Creates a new f16 quantized matrix
73    pub fn new_f16(data: Array2<f16>, shape: (usize, usize)) -> Self {
74        Self {
75            data: QuantizedData2D::Float16(data),
76            shape,
77            data_type: QuantizedDataType::Float16,
78        }
79    }
80
81    /// Creates a new bf16 quantized matrix
82    pub fn new_bf16(data: Array2<bf16>, shape: (usize, usize)) -> Self {
83        Self {
84            data: QuantizedData2D::BFloat16(data),
85            shape,
86            data_type: QuantizedDataType::BFloat16,
87        }
88    }
89
90    /// Creates a standard Int8 quantized matrix (for backward compatibility)
91    pub fn from_i8(data: Array2<i8>, shape: (usize, usize)) -> Self {
92        Self {
93            data: QuantizedData2D::Int8(data),
94            shape,
95            data_type: QuantizedDataType::Int8,
96        }
97    }
98
99    // This method stays for backward compatibility but will be deprecated in the future
100    // Use get_i8 or get_f32 instead
101    #[deprecated(since = "0.1.0", note = "Use get_i8 or get_f32 instead")]
102    pub fn get(&self, row: usize, col: usize) -> i8 {
103        self.get_i8(row, col)
104    }
105
106    /// Returns the shape of the matrix
107    pub fn shape(&self) -> (usize, usize) {
108        self.shape
109    }
110
111    /// Returns the number of rows in the matrix
112    pub fn nrows(&self) -> usize {
113        self.shape.0
114    }
115
116    /// Returns the number of columns in the matrix
117    pub fn ncols(&self) -> usize {
118        self.shape.1
119    }
120
121    /// Get value at specified position as i8 (for int quantization)
122    pub fn get_i8(&self, row: usize, col: usize) -> i8 {
123        match &self.data {
124            QuantizedData2D::Int8(arr) => {
125                match self.data_type {
126                    QuantizedDataType::Int8 => arr[[row, col]],
127                    QuantizedDataType::Int4 => {
128                        let idx = row * self.shape.1 + col;
129                        let byte_idx = idx / 2;
130                        let nibble_idx = idx % 2;
131                        let byte = arr.as_slice().unwrap()[byte_idx];
132
133                        if nibble_idx == 0 {
134                            // Upper 4 bits
135                            byte >> 4
136                        } else {
137                            // Lower 4 bits
138                            byte & 0x0F
139                        }
140                    }
141                    QuantizedDataType::UInt4 => {
142                        let idx = row * self.shape.1 + col;
143                        let byte_idx = idx / 2;
144                        let nibble_idx = idx % 2;
145                        let byte = arr.as_slice().unwrap()[byte_idx];
146
147                        if nibble_idx == 0 {
148                            // Upper 4 bits
149                            (byte >> 4) & 0x0F
150                        } else {
151                            // Lower 4 bits
152                            byte & 0x0F
153                        }
154                    }
155                    _ => unreachable!(
156                        "Invalid quantization type for Int8 storage: expected Int8, Int4, or UInt4"
157                    ),
158                }
159            }
160            _ => unreachable!("Cannot get i8 value from floating-point quantized matrix"),
161        }
162    }
163
164    /// Get value at specified position as f32 (for all quantization types)
165    pub fn get_f32(&self, row: usize, col: usize) -> f32 {
166        match &self.data {
167            QuantizedData2D::Int8(arr) => match self.data_type {
168                QuantizedDataType::Int8 => arr[[row, col]] as f32,
169                QuantizedDataType::Int4 => self.get_i8(row, col) as f32,
170                QuantizedDataType::UInt4 => self.get_i8(row, col) as f32,
171                _ => unreachable!(
172                    "Invalid data type for Int8 storage: expected Int8, Int4, or UInt4"
173                ),
174            },
175            QuantizedData2D::Float16(arr) => arr[[row, col]].to_f32(),
176            QuantizedData2D::BFloat16(arr) => arr[[row, col]].to_f32(),
177        }
178    }
179}