scirs2_linalg/quantization/
fusion.rs1use crate::error::{LinalgError, LinalgResult};
8use crate::quantization::{
9 dequantize_matrix, get_quantizedmatrix_2d_i8, QuantizationMethod, QuantizationParams,
10 QuantizedData2D, QuantizedMatrix,
11};
12use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
13use std::fmt::Debug;
14
15#[allow(dead_code)]
30pub fn fused_quantized_matmul_chain(
31 matrices: &[&QuantizedMatrix],
32 params: &[&QuantizationParams],
33) -> LinalgResult<Array2<f32>> {
34 if matrices.len() < 2 {
36 return Err(LinalgError::ShapeError(
37 "At least two matrices are required for matmul chain".to_string(),
38 ));
39 }
40
41 if matrices.len() != params.len() {
42 return Err(LinalgError::ShapeError(
43 "Number of matrices must match number of quantization parameters".to_string(),
44 ));
45 }
46
47 for i in 0..matrices.len() - 1 {
49 if matrices[i].shape.1 != matrices[i + 1].shape.0 {
50 return Err(LinalgError::ShapeError(format!(
51 "Matrix dimensions mismatch at position {}: ({}, {}) * ({}, {})",
52 i,
53 matrices[i].shape.0,
54 matrices[i].shape.1,
55 matrices[i + 1].shape.0,
56 matrices[i + 1].shape.1
57 )));
58 }
59 }
60
61 let all_int8 = matrices
63 .iter()
64 .all(|m| matches!(m.data, QuantizedData2D::Int8(_)));
65
66 let all_symmetric = params
67 .iter()
68 .all(|p| p.method == QuantizationMethod::Symmetric || p.method == QuantizationMethod::Int4);
69
70 if all_int8 && all_symmetric {
71 fused_quantized_matmul_chain_int8_symmetric(matrices, params)
73 } else {
74 let mut dequantized_matrices = Vec::with_capacity(matrices.len());
77
78 for (matrix, param) in matrices.iter().zip(params.iter()) {
79 dequantized_matrices.push(dequantize_matrix(matrix, param));
80 }
81
82 let mut result = dequantized_matrices[0].clone();
84 for mat in dequantized_matrices.iter().skip(1) {
85 result = result.dot(mat);
86 }
87
88 Ok(result)
89 }
90}
91
92#[allow(dead_code)]
94fn fused_quantized_matmul_chain_int8_symmetric(
95 matrices: &[&QuantizedMatrix],
96 params: &[&QuantizationParams],
97) -> LinalgResult<Array2<f32>> {
98 let int8_matrices: Vec<&Array2<i8>> = matrices
100 .iter()
101 .map(|m| get_quantizedmatrix_2d_i8(m).unwrap())
102 .collect();
103
104 let scales: Vec<f32> = params.iter().map(|p| p.scale).collect();
106
107 let rows_ = matrices[0].shape.0;
109 let cols = matrices.last().unwrap().shape.1;
110 let mut result = Array2::zeros((rows_, cols));
111
112 let fused_scale: f32 = scales.iter().product();
114
115 const BLOCK_SIZE: usize = 32;
117 for i0 in (0..rows_).step_by(BLOCK_SIZE) {
118 let i_end = (i0 + BLOCK_SIZE).min(rows_);
119
120 for j0 in (0..cols).step_by(BLOCK_SIZE) {
121 let j_end = (j0 + BLOCK_SIZE).min(cols);
122
123 for i in i0..i_end {
125 for j in j0..j_end {
126 let mut middle_dim = matrices[0].shape.1;
131 let mut intermediate = vec![0i32; middle_dim];
132
133 for (k, val) in intermediate.iter_mut().enumerate().take(middle_dim) {
135 *val = int8_matrices[0][[i, k]] as i32;
136 }
137
138 for mat_idx in 1..matrices.len() - 1 {
140 let mat = int8_matrices[mat_idx];
141 let (_, inner_dim) = matrices[mat_idx].shape;
142
143 let mut new_intermediate = vec![0i32; inner_dim];
144
145 for l in 0..inner_dim {
147 for k in 0..middle_dim {
148 new_intermediate[l] += intermediate[k] * (mat[[k, l]] as i32);
149 }
150 }
151
152 intermediate = new_intermediate;
154 middle_dim = inner_dim;
155 }
156
157 let last_mat = int8_matrices.last().unwrap();
159 let mut sum = 0i32;
160
161 for k in 0..middle_dim {
162 sum += intermediate[k] * (last_mat[[k, j]] as i32);
163 }
164
165 result[[i, j]] = (sum as f32) * fused_scale;
167 }
168 }
169 }
170 }
171
172 Ok(result)
173}
174
175#[allow(dead_code)]
192pub fn fused_quantized_matvec_sequence<F>(
193 matrices: &[&QuantizedMatrix],
194 matrix_params: &[&QuantizationParams],
195 vector: &ArrayView1<F>,
196 output_quantize: bool,
197) -> LinalgResult<Array1<F>>
198where
199 F: scirs2_core::numeric::Float
200 + Debug
201 + scirs2_core::numeric::AsPrimitive<f32>
202 + scirs2_core::numeric::FromPrimitive,
203 f32: scirs2_core::numeric::AsPrimitive<F>,
204{
205 if matrices.is_empty() {
207 return Err(LinalgError::ShapeError(
208 "At least one matrix is required for matvec sequence".to_string(),
209 ));
210 }
211
212 if matrices.len() != matrix_params.len() {
213 return Err(LinalgError::ShapeError(
214 "Number of matrices must match number of quantization parameters".to_string(),
215 ));
216 }
217
218 let vector_len = vector.len();
220 if matrices.last().unwrap().shape.1 != vector_len {
221 return Err(LinalgError::ShapeError(format!(
222 "Last matrix columns ({}) must match vector length ({})",
223 matrices.last().unwrap().shape.1,
224 vector_len
225 )));
226 }
227
228 for i in 0..matrices.len() - 1 {
229 if matrices[i].shape.1 != matrices[i + 1].shape.0 {
230 return Err(LinalgError::ShapeError(format!(
231 "Matrix dimensions mismatch at position {}: ({}, {}) * ({}, {})",
232 i,
233 matrices[i].shape.0,
234 matrices[i].shape.1,
235 matrices[i + 1].shape.0,
236 matrices[i + 1].shape.1
237 )));
238 }
239 }
240
241 let all_int8 = matrices
243 .iter()
244 .all(|m| matches!(m.data, QuantizedData2D::Int8(_)));
245
246 if all_int8 {
247 let vector_f32 = vector.mapv(|x| x.as_());
249 let vector_f32_view = vector_f32.view();
250
251 let result_f32 = if matrices.len() == 1 {
253 use crate::quantization::simd::simd_quantized_matvec;
255 simd_quantized_matvec(matrices[0], matrix_params[0], &vector_f32_view)?
256 } else {
257 fused_quantized_matvec_sequence_int8(matrices, matrix_params, &vector_f32_view)?
259 };
260
261 if output_quantize {
263 Ok(result_f32.mapv(|x| scirs2_core::numeric::FromPrimitive::from_f32(x).unwrap()))
266 } else {
267 Ok(result_f32.mapv(|x| scirs2_core::numeric::FromPrimitive::from_f32(x).unwrap()))
269 }
270 } else {
271 let mut dequantized_matrices = Vec::with_capacity(matrices.len());
273
274 for (matrix, param) in matrices.iter().zip(matrix_params.iter()) {
275 dequantized_matrices.push(dequantize_matrix(matrix, param));
276 }
277
278 let vector_f32 = vector.mapv(|x| x.as_());
280
281 let mut result_f32 = vector_f32.insert_axis(scirs2_core::ndarray::Axis(1));
283
284 for mat in dequantized_matrices.iter().rev() {
286 result_f32 = mat.dot(&result_f32);
287 }
288
289 let result_1d_f32 = result_f32.remove_axis(scirs2_core::ndarray::Axis(1));
291
292 let result_f =
294 result_1d_f32.mapv(|x| scirs2_core::numeric::FromPrimitive::from_f32(x).unwrap());
295
296 Ok(result_f)
297 }
298}
299
300#[allow(dead_code)]
302fn fused_quantized_matvec_sequence_int8(
303 matrices: &[&QuantizedMatrix],
304 params: &[&QuantizationParams],
305 vector: &ArrayView1<f32>,
306) -> LinalgResult<Array1<f32>> {
307 let int8_matrices: Vec<&Array2<i8>> = matrices
309 .iter()
310 .map(|m| get_quantizedmatrix_2d_i8(m).unwrap())
311 .collect();
312
313 let scales: Vec<f32> = params.iter().map(|p| p.scale).collect();
315 let _zero_points: Vec<i32> = params.iter().map(|p| p.zero_point).collect();
317
318 let symmetric = params
320 .iter()
321 .all(|p| p.method == QuantizationMethod::Symmetric);
322
323 let output_dim = matrices[0].shape.0;
325 let mut result = Array1::zeros(output_dim);
326
327 if symmetric {
329 let fused_scale: f32 = scales.iter().product();
331
332 for i in 0..output_dim {
334 let row = int8_matrices[0].row(i);
335
336 let middle_dim = matrices[0].shape.1;
339 let mut intermediate = vec![0i32; middle_dim];
340
341 for k in 0..middle_dim {
342 intermediate[k] = row[k] as i32;
343 }
344
345 for mat_idx in 1..matrices.len() {
347 let mat = int8_matrices[mat_idx];
348 let (rows, cols) = matrices[mat_idx].shape;
349
350 let mut new_intermediate = vec![0i32; cols];
351
352 for c in 0..cols {
353 for r in 0..rows {
354 new_intermediate[c] += intermediate[r] * (mat[[r, c]] as i32);
355 }
356 }
357
358 intermediate = new_intermediate;
359 }
360
361 let mut sum = 0.0;
363 for k in 0..intermediate.len() {
364 sum += (intermediate[k] as f32) * vector[k];
365 }
366
367 result[i] = sum * fused_scale;
368 }
369 } else {
370 let mut dequantized_matrices = Vec::with_capacity(matrices.len());
375
376 for (matrix, param) in matrices.iter().zip(params.iter()) {
377 dequantized_matrices.push(dequantize_matrix(matrix, param));
378 }
379
380 let vector_2d = vector.to_owned().insert_axis(scirs2_core::ndarray::Axis(1));
382
383 let mut result_2d = vector_2d;
385 for mat in dequantized_matrices.iter().rev() {
386 result_2d = mat.dot(&result_2d);
387 }
388
389 result = result_2d.remove_axis(scirs2_core::ndarray::Axis(1));
391 }
392
393 Ok(result)
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantization::{quantize_matrix, QuantizationMethod};
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// A three-matrix symmetric int8 chain should match the f32 product
    /// within the tolerance allowed for quantization error.
    #[test]
    fn test_fused_matmul_chain() {
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0f32, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let c = array![[13.0f32, 14.0, 15.0], [16.0, 17.0, 18.0]];

        let (qa, qa_params) = quantize_matrix(&a.view(), 8, QuantizationMethod::Symmetric);
        let (qb, qb_params) = quantize_matrix(&b.view(), 8, QuantizationMethod::Symmetric);
        let (qc, qc_params) = quantize_matrix(&c.view(), 8, QuantizationMethod::Symmetric);

        // Reference result computed in full precision.
        let ab = a.dot(&b);
        let expected = ab.dot(&c);

        let matrices = [&qa, &qb, &qc];
        let params = [&qa_params, &qb_params, &qc_params];
        let result = fused_quantized_matmul_chain(&matrices, &params).unwrap();

        assert_eq!(result.shape(), expected.shape());
        for ((i, j), &val) in result.indexed_iter() {
            assert_relative_eq!(val, expected[[i, j]], epsilon = 12.0);
        }
    }

    /// A two-matrix symmetric int8 sequence applied to a vector should match
    /// the f32 result within the tolerance allowed for quantization error.
    #[test]
    fn test_fused_matvec_sequence() {
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0f32, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let x = array![13.0f32, 14.0];

        let (qa, qa_params) = quantize_matrix(&a.view(), 8, QuantizationMethod::Symmetric);
        let (qb, qb_params) = quantize_matrix(&b.view(), 8, QuantizationMethod::Symmetric);

        // Reference result computed in full precision: a * (b * x).
        let bx = b.dot(&x);
        let expected = a.dot(&bx);

        let matrices = [&qa, &qb];
        let params = [&qa_params, &qb_params];
        let result = fused_quantized_matvec_sequence(&matrices, &params, &x.view(), false).unwrap();

        assert_eq!(result.len(), expected.len());
        for (i, &val) in result.iter().enumerate() {
            assert_relative_eq!(val, expected[i], epsilon = 5.0);
        }
    }
}