// scirs2_linalg/quantization/matrix.rs
use half::{bf16, f16};
use scirs2_core::ndarray::{Array1, Array2};

use super::types::QuantizedDataType;
11#[derive(Debug, Clone)]
13pub struct QuantizedMatrix {
14 pub data: QuantizedData2D,
16
17 pub shape: (usize, usize),
19
20 pub data_type: QuantizedDataType,
22}
23
24#[derive(Debug, Clone)]
26pub enum QuantizedData2D {
27 Int8(Array2<i8>),
29 Float16(Array2<f16>),
31 BFloat16(Array2<bf16>),
33}
34
35impl QuantizedData2D {
36 pub fn len(&self) -> usize {
38 match self {
39 QuantizedData2D::Int8(arr) => arr.len(),
40 QuantizedData2D::Float16(arr) => arr.len(),
41 QuantizedData2D::BFloat16(arr) => arr.len(),
42 }
43 }
44
45 pub fn is_empty(&self) -> bool {
47 self.len() == 0
48 }
49}
50
51#[allow(dead_code)]
55pub fn get_quantizedmatrix_2d_i8(matrix: &QuantizedMatrix) -> Option<&Array2<i8>> {
56 match &matrix.data {
57 QuantizedData2D::Int8(data) => Some(data),
58 _ => None,
59 }
60}
61
62impl QuantizedMatrix {
63 pub fn new_i8(data: Array2<i8>, shape: (usize, usize), data_type: QuantizedDataType) -> Self {
65 Self {
66 data: QuantizedData2D::Int8(data),
67 shape,
68 data_type,
69 }
70 }
71
72 pub fn new_f16(data: Array2<f16>, shape: (usize, usize)) -> Self {
74 Self {
75 data: QuantizedData2D::Float16(data),
76 shape,
77 data_type: QuantizedDataType::Float16,
78 }
79 }
80
81 pub fn new_bf16(data: Array2<bf16>, shape: (usize, usize)) -> Self {
83 Self {
84 data: QuantizedData2D::BFloat16(data),
85 shape,
86 data_type: QuantizedDataType::BFloat16,
87 }
88 }
89
90 pub fn from_i8(data: Array2<i8>, shape: (usize, usize)) -> Self {
92 Self {
93 data: QuantizedData2D::Int8(data),
94 shape,
95 data_type: QuantizedDataType::Int8,
96 }
97 }
98
99 #[deprecated(since = "0.1.0", note = "Use get_i8 or get_f32 instead")]
102 pub fn get(&self, row: usize, col: usize) -> i8 {
103 self.get_i8(row, col)
104 }
105
106 pub fn shape(&self) -> (usize, usize) {
108 self.shape
109 }
110
111 pub fn nrows(&self) -> usize {
113 self.shape.0
114 }
115
116 pub fn ncols(&self) -> usize {
118 self.shape.1
119 }
120
121 pub fn get_i8(&self, row: usize, col: usize) -> i8 {
123 match &self.data {
124 QuantizedData2D::Int8(arr) => {
125 match self.data_type {
126 QuantizedDataType::Int8 => arr[[row, col]],
127 QuantizedDataType::Int4 => {
128 let idx = row * self.shape.1 + col;
129 let byte_idx = idx / 2;
130 let nibble_idx = idx % 2;
131 let byte = arr.as_slice().unwrap()[byte_idx];
132
133 if nibble_idx == 0 {
134 byte >> 4
136 } else {
137 byte & 0x0F
139 }
140 }
141 QuantizedDataType::UInt4 => {
142 let idx = row * self.shape.1 + col;
143 let byte_idx = idx / 2;
144 let nibble_idx = idx % 2;
145 let byte = arr.as_slice().unwrap()[byte_idx];
146
147 if nibble_idx == 0 {
148 (byte >> 4) & 0x0F
150 } else {
151 byte & 0x0F
153 }
154 }
155 _ => unreachable!(
156 "Invalid quantization type for Int8 storage: expected Int8, Int4, or UInt4"
157 ),
158 }
159 }
160 _ => unreachable!("Cannot get i8 value from floating-point quantized matrix"),
161 }
162 }
163
164 pub fn get_f32(&self, row: usize, col: usize) -> f32 {
166 match &self.data {
167 QuantizedData2D::Int8(arr) => match self.data_type {
168 QuantizedDataType::Int8 => arr[[row, col]] as f32,
169 QuantizedDataType::Int4 => self.get_i8(row, col) as f32,
170 QuantizedDataType::UInt4 => self.get_i8(row, col) as f32,
171 _ => unreachable!(
172 "Invalid data type for Int8 storage: expected Int8, Int4, or UInt4"
173 ),
174 },
175 QuantizedData2D::Float16(arr) => arr[[row, col]].to_f32(),
176 QuantizedData2D::BFloat16(arr) => arr[[row, col]].to_f32(),
177 }
178 }
179}