// scirs2_linalg/quantization/operations.rs
use scirs2_core::ndarray::{Array1, Array2};

use crate::error::{LinalgError, LinalgResult};

use super::conversions::dequantize_matrix;
use super::matrix::QuantizedMatrix;
use super::types::{QuantizationMethod, QuantizationParams, QuantizedDataType};
use super::vector::QuantizedVector;
16pub fn quantized_matmul(
29 a: &QuantizedMatrix,
30 a_params: &QuantizationParams,
31 b: &QuantizedMatrix,
32 b_params: &QuantizationParams,
33) -> LinalgResult<Array2<f32>> {
34 if a.ncols() != b.nrows() {
36 return Err(LinalgError::DimensionError(format!(
37 "Cannot multiply matrices with shapes {:?} and {:?}",
38 a.shape(),
39 b.shape()
40 )));
41 }
42
43 let (m, k) = a.shape();
44 let (_, n) = b.shape();
45
46 let mut result = Array2::zeros((m, n));
48
49 if matches!(
51 a.data_type,
52 QuantizedDataType::Float16 | QuantizedDataType::BFloat16
53 ) || matches!(
54 b.data_type,
55 QuantizedDataType::Float16 | QuantizedDataType::BFloat16
56 ) {
57 for i in 0..m {
59 for j in 0..n {
60 let mut sum = 0.0_f32;
61 for l in 0..k {
62 let a_val = a.get_f32(i, l);
63 let b_val = b.get_f32(l, j);
64 sum += a_val * b_val;
65 }
66 result[[i, j]] = sum;
67 }
68 }
69 return Ok(result);
70 }
71
72 let a_per_channel = a_params.method == QuantizationMethod::PerChannelSymmetric
74 || a_params.method == QuantizationMethod::PerChannelAffine;
75
76 let b_per_channel = b_params.method == QuantizationMethod::PerChannelSymmetric
77 || b_params.method == QuantizationMethod::PerChannelAffine;
78
79 if a_per_channel || b_per_channel {
81 let a_dequant = dequantize_matrix(a, a_params);
83 let b_dequant = dequantize_matrix(b, b_params);
84
85 for i in 0..m {
87 for j in 0..n {
88 let mut sum = 0.0_f32;
89 for l in 0..k {
90 sum += a_dequant[[i, l]] * b_dequant[[l, j]];
91 }
92 result[[i, j]] = sum;
93 }
94 }
95
96 return Ok(result);
97 }
98
99 for i in 0..m {
101 for j in 0..n {
102 let mut sum = 0i32;
103 for l in 0..k {
104 let a_val = a.get_i8(i, l) as i32;
106 let b_val = b.get_i8(l, j) as i32;
107 sum += a_val * b_val;
108 }
109
110 let a_scale = a_params.scale;
112 let b_scale = b_params.scale;
113
114 if (a_params.method == QuantizationMethod::Affine
116 || a_params.method == QuantizationMethod::UInt4)
117 && (b_params.method == QuantizationMethod::Affine
118 || b_params.method == QuantizationMethod::UInt4)
119 {
120 let a_zero_sum: i32 =
122 (0..k).map(|l| b.get_i8(l, j) as i32).sum::<i32>() * a_params.zero_point;
123 let b_zero_sum: i32 =
124 (0..k).map(|l| a.get_i8(i, l) as i32).sum::<i32>() * b_params.zero_point;
125 let zero_product = k as i32 * a_params.zero_point * b_params.zero_point;
126
127 sum = sum - a_zero_sum - b_zero_sum + zero_product;
128 }
129
130 result[[i, j]] = sum as f32 * a_scale * b_scale;
131 }
132 }
133
134 Ok(result)
135}
136
137pub fn quantized_matvec(
150 a: &QuantizedMatrix,
151 a_params: &QuantizationParams,
152 b: &QuantizedVector,
153 b_params: &QuantizationParams,
154) -> LinalgResult<Array1<f32>> {
155 if a.ncols() != b.len() {
157 return Err(LinalgError::DimensionError(format!(
158 "Cannot multiply matrix with shape {:?} and vector with length {}",
159 a.shape(),
160 b.len()
161 )));
162 }
163
164 let m = a.nrows();
165 let n = a.ncols();
166
167 let mut result = Array1::zeros(m);
169
170 if matches!(
172 a_params.data_type,
173 QuantizedDataType::Float16 | QuantizedDataType::BFloat16
174 ) && matches!(
175 b_params.data_type,
176 QuantizedDataType::Float16 | QuantizedDataType::BFloat16
177 ) {
178 let a_full = dequantize_matrix(a, a_params);
180 let b_full = dequantize_vector(b, b_params);
181
182 for i in 0..m {
183 let mut sum = 0.0_f32;
184 for j in 0..n {
185 sum += a_full[[i, j]] * b_full[j];
186 }
187 result[i] = sum;
188 }
189
190 return Ok(result);
191 }
192
193 let a_scale = a_params.scale;
195 let b_scale = b_params.scale;
196
197 for i in 0..m {
198 let mut sum: i32 = 0;
199
200 for j in 0..n {
201 let a_val = a.get_i8(i, j) as i32;
202 let b_val = b.get_i8(j) as i32;
203 sum += a_val * b_val;
204 }
205
206 result[i] = sum as f32 * a_scale * b_scale;
207 }
208
209 Ok(result)
210}
211
212pub fn quantized_dot(
225 a: &QuantizedVector,
226 a_params: &QuantizationParams,
227 b: &QuantizedVector,
228 b_params: &QuantizationParams,
229) -> LinalgResult<f32> {
230 if a.len() != b.len() {
232 return Err(LinalgError::DimensionError(format!(
233 "Cannot compute dot product of vectors with lengths {} and {}",
234 a.len(),
235 b.len()
236 )));
237 }
238
239 let n = a.len();
240
241 if matches!(
243 a_params.data_type,
244 QuantizedDataType::Float16 | QuantizedDataType::BFloat16
245 ) && matches!(
246 b_params.data_type,
247 QuantizedDataType::Float16 | QuantizedDataType::BFloat16
248 ) {
249 let a_full = dequantize_vector(a, a_params);
251 let b_full = dequantize_vector(b, b_params);
252
253 let mut sum = 0.0_f32;
254 for i in 0..n {
255 sum += a_full[i] * b_full[i];
256 }
257
258 return Ok(sum);
259 }
260
261 let a_scale = a_params.scale;
263 let b_scale = b_params.scale;
264
265 let mut sum: i32 = 0;
266
267 for i in 0..n {
268 let a_val = a.get_i8(i) as i32;
269 let b_val = b.get_i8(i) as i32;
270 sum += a_val * b_val;
271 }
272
273 Ok(sum as f32 * a_scale * b_scale)
274}
275
276fn dequantize_vector(vec: &QuantizedVector, _params: &QuantizationParams) -> Array1<f32> {
278 let n = vec.len();
279 let mut result = Array1::zeros(n);
280
281 for i in 0..n {
282 result[i] = vec.get_f32(i);
283 }
284
285 result
286}