scirs2_linalg/quantization/
fusion.rs

1//! Fusion of consecutive quantized operations
2//!
3//! This module provides optimized implementations for fusing multiple
4//! quantized operations, avoiding the overhead of intermediate dequantization
5//! and requantization steps in performance-critical code paths.
6
7use crate::error::{LinalgError, LinalgResult};
8use crate::quantization::{
9    dequantize_matrix, get_quantizedmatrix_2d_i8, QuantizationMethod, QuantizationParams,
10    QuantizedData2D, QuantizedMatrix,
11};
12use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
13use std::fmt::Debug;
14
15/// Fused quantized matrix multiplication chain
16///
17/// Computes a chain of matrix multiplications (A * B * C * ...) where all
18/// matrices are in quantized form. This is more efficient than performing
19/// individual multiplications and re-quantizing intermediate results.
20///
21/// # Arguments
22///
23/// * `matrices` - A slice of quantized matrices to multiply
24/// * `params` - A slice of quantization parameters for each matrix
25///
26/// # Returns
27///
28/// * The result of the matrix multiplication chain
29#[allow(dead_code)]
30pub fn fused_quantized_matmul_chain(
31    matrices: &[&QuantizedMatrix],
32    params: &[&QuantizationParams],
33) -> LinalgResult<Array2<f32>> {
34    // Validate input
35    if matrices.len() < 2 {
36        return Err(LinalgError::ShapeError(
37            "At least two matrices are required for matmul chain".to_string(),
38        ));
39    }
40
41    if matrices.len() != params.len() {
42        return Err(LinalgError::ShapeError(
43            "Number of matrices must match number of quantization parameters".to_string(),
44        ));
45    }
46
47    // Check dimension compatibility
48    for i in 0..matrices.len() - 1 {
49        if matrices[i].shape.1 != matrices[i + 1].shape.0 {
50            return Err(LinalgError::ShapeError(format!(
51                "Matrix dimensions mismatch at position {}: ({}, {}) * ({}, {})",
52                i,
53                matrices[i].shape.0,
54                matrices[i].shape.1,
55                matrices[i + 1].shape.0,
56                matrices[i + 1].shape.1
57            )));
58        }
59    }
60
61    // Check if all matrices are Int8 format (for now, we only optimize this case)
62    let all_int8 = matrices
63        .iter()
64        .all(|m| matches!(m.data, QuantizedData2D::Int8(_)));
65
66    let all_symmetric = params
67        .iter()
68        .all(|p| p.method == QuantizationMethod::Symmetric || p.method == QuantizationMethod::Int4);
69
70    if all_int8 && all_symmetric {
71        // Optimized path for symmetric Int8 quantization
72        fused_quantized_matmul_chain_int8_symmetric(matrices, params)
73    } else {
74        // Fallback path: dequantize all matrices and perform regular matmul
75        // This could be optimized further in the future
76        let mut dequantized_matrices = Vec::with_capacity(matrices.len());
77
78        for (matrix, param) in matrices.iter().zip(params.iter()) {
79            dequantized_matrices.push(dequantize_matrix(matrix, param));
80        }
81
82        // Compute the matrix multiplication chain
83        let mut result = dequantized_matrices[0].clone();
84        for mat in dequantized_matrices.iter().skip(1) {
85            result = result.dot(mat);
86        }
87
88        Ok(result)
89    }
90}
91
92/// Optimized implementation for Int8 symmetric quantized matrices
93#[allow(dead_code)]
94fn fused_quantized_matmul_chain_int8_symmetric(
95    matrices: &[&QuantizedMatrix],
96    params: &[&QuantizationParams],
97) -> LinalgResult<Array2<f32>> {
98    // Extract Int8 data from matrices
99    let int8_matrices: Vec<&Array2<i8>> = matrices
100        .iter()
101        .map(|m| get_quantizedmatrix_2d_i8(m).unwrap())
102        .collect();
103
104    // Scales from the quantization parameters
105    let scales: Vec<f32> = params.iter().map(|p| p.scale).collect();
106
107    // Result dimensions
108    let rows_ = matrices[0].shape.0;
109    let cols = matrices.last().unwrap().shape.1;
110    let mut result = Array2::zeros((rows_, cols));
111
112    // Compute the fused scale factor - product of all scales
113    let fused_scale: f32 = scales.iter().product();
114
115    // Use block multiplication for better cache efficiency
116    const BLOCK_SIZE: usize = 32;
117    for i0 in (0..rows_).step_by(BLOCK_SIZE) {
118        let i_end = (i0 + BLOCK_SIZE).min(rows_);
119
120        for j0 in (0..cols).step_by(BLOCK_SIZE) {
121            let j_end = (j0 + BLOCK_SIZE).min(cols);
122
123            // Process each cell in the output block
124            for i in i0..i_end {
125                for j in j0..j_end {
126                    // We need to do the matrix chain multiplication for this cell
127                    // This is a dot product through the different matrices
128
129                    // Use an intermediate buffer for partial results in the chain
130                    let mut middle_dim = matrices[0].shape.1;
131                    let mut intermediate = vec![0i32; middle_dim];
132
133                    // Initialize with first matrix row
134                    for (k, val) in intermediate.iter_mut().enumerate().take(middle_dim) {
135                        *val = int8_matrices[0][[i, k]] as i32;
136                    }
137
138                    // Process intermediate matrices (all except first and last)
139                    for mat_idx in 1..matrices.len() - 1 {
140                        let mat = int8_matrices[mat_idx];
141                        let (_, inner_dim) = matrices[mat_idx].shape;
142
143                        let mut new_intermediate = vec![0i32; inner_dim];
144
145                        // Propagate through the next matrix
146                        for l in 0..inner_dim {
147                            for k in 0..middle_dim {
148                                new_intermediate[l] += intermediate[k] * (mat[[k, l]] as i32);
149                            }
150                        }
151
152                        // Update intermediate and dimension for next iteration
153                        intermediate = new_intermediate;
154                        middle_dim = inner_dim;
155                    }
156
157                    // Final matrix
158                    let last_mat = int8_matrices.last().unwrap();
159                    let mut sum = 0i32;
160
161                    for k in 0..middle_dim {
162                        sum += intermediate[k] * (last_mat[[k, j]] as i32);
163                    }
164
165                    // Apply fused scaling factor
166                    result[[i, j]] = (sum as f32) * fused_scale;
167                }
168            }
169        }
170    }
171
172    Ok(result)
173}
174
175/// Fused quantized matrix-vector multiplication sequence
176///
177/// Computes the matrix-vector sequence (A * B * ... * x) where matrices and
178/// vector are in quantized form. This avoids dequantizing and requantizing
179/// intermediate results.
180///
181/// # Arguments
182///
183/// * `matrices` - A slice of quantized matrices to multiply
184/// * `matrix_params` - A slice of quantization parameters for each matrix
185/// * `vector` - The quantized vector to multiply with
186/// * `vector_params` - Quantization parameters for the vector
187///
188/// # Returns
189///
190/// * The result of the matrix-vector multiplication sequence
191#[allow(dead_code)]
192pub fn fused_quantized_matvec_sequence<F>(
193    matrices: &[&QuantizedMatrix],
194    matrix_params: &[&QuantizationParams],
195    vector: &ArrayView1<F>,
196    output_quantize: bool,
197) -> LinalgResult<Array1<F>>
198where
199    F: scirs2_core::numeric::Float
200        + Debug
201        + scirs2_core::numeric::AsPrimitive<f32>
202        + scirs2_core::numeric::FromPrimitive,
203    f32: scirs2_core::numeric::AsPrimitive<F>,
204{
205    // Validate input
206    if matrices.is_empty() {
207        return Err(LinalgError::ShapeError(
208            "At least one matrix is required for matvec sequence".to_string(),
209        ));
210    }
211
212    if matrices.len() != matrix_params.len() {
213        return Err(LinalgError::ShapeError(
214            "Number of matrices must match number of quantization parameters".to_string(),
215        ));
216    }
217
218    // Check dimension compatibility
219    let vector_len = vector.len();
220    if matrices.last().unwrap().shape.1 != vector_len {
221        return Err(LinalgError::ShapeError(format!(
222            "Last matrix columns ({}) must match vector length ({})",
223            matrices.last().unwrap().shape.1,
224            vector_len
225        )));
226    }
227
228    for i in 0..matrices.len() - 1 {
229        if matrices[i].shape.1 != matrices[i + 1].shape.0 {
230            return Err(LinalgError::ShapeError(format!(
231                "Matrix dimensions mismatch at position {}: ({}, {}) * ({}, {})",
232                i,
233                matrices[i].shape.0,
234                matrices[i].shape.1,
235                matrices[i + 1].shape.0,
236                matrices[i + 1].shape.1
237            )));
238        }
239    }
240
241    // Check if all matrices are Int8 format (for now, we only optimize this case)
242    let all_int8 = matrices
243        .iter()
244        .all(|m| matches!(m.data, QuantizedData2D::Int8(_)));
245
246    if all_int8 {
247        // Convert vector to f32
248        let vector_f32 = vector.mapv(|x| x.as_());
249        let vector_f32_view = vector_f32.view();
250
251        // Compute result as f32
252        let result_f32 = if matrices.len() == 1 {
253            // Single matrix case - use the existing SIMD function
254            use crate::quantization::simd::simd_quantized_matvec;
255            simd_quantized_matvec(matrices[0], matrix_params[0], &vector_f32_view)?
256        } else {
257            // Multiple matrices case - fuse the operation
258            fused_quantized_matvec_sequence_int8(matrices, matrix_params, &vector_f32_view)?
259        };
260
261        // Convert back to the original type
262        if output_quantize {
263            // In a complete implementation, we would _quantize the result to the same bit depth
264            // But for simplicity, just convert back to the original type
265            Ok(result_f32.mapv(|x| scirs2_core::numeric::FromPrimitive::from_f32(x).unwrap()))
266        } else {
267            // Return as float directly
268            Ok(result_f32.mapv(|x| scirs2_core::numeric::FromPrimitive::from_f32(x).unwrap()))
269        }
270    } else {
271        // Fallback path: dequantize all matrices and perform regular matmul
272        let mut dequantized_matrices = Vec::with_capacity(matrices.len());
273
274        for (matrix, param) in matrices.iter().zip(matrix_params.iter()) {
275            dequantized_matrices.push(dequantize_matrix(matrix, param));
276        }
277
278        // Convert to f32 for internal calculations
279        let vector_f32 = vector.mapv(|x| x.as_());
280
281        // Create a column vector from the 1D array
282        let mut result_f32 = vector_f32.insert_axis(scirs2_core::ndarray::Axis(1));
283
284        // Apply matrices in reverse order (rightmost first)
285        for mat in dequantized_matrices.iter().rev() {
286            result_f32 = mat.dot(&result_f32);
287        }
288
289        // Convert back to 1D array and then to original type
290        let result_1d_f32 = result_f32.remove_axis(scirs2_core::ndarray::Axis(1));
291
292        // Convert back to the original type
293        let result_f =
294            result_1d_f32.mapv(|x| scirs2_core::numeric::FromPrimitive::from_f32(x).unwrap());
295
296        Ok(result_f)
297    }
298}
299
300/// Optimized implementation for Int8 quantized matrices in a matvec sequence
301#[allow(dead_code)]
302fn fused_quantized_matvec_sequence_int8(
303    matrices: &[&QuantizedMatrix],
304    params: &[&QuantizationParams],
305    vector: &ArrayView1<f32>,
306) -> LinalgResult<Array1<f32>> {
307    // Extract Int8 data
308    let int8_matrices: Vec<&Array2<i8>> = matrices
309        .iter()
310        .map(|m| get_quantizedmatrix_2d_i8(m).unwrap())
311        .collect();
312
313    // Get scales from the parameters
314    let scales: Vec<f32> = params.iter().map(|p| p.scale).collect();
315    // Zero points used only in the asymmetric path
316    let _zero_points: Vec<i32> = params.iter().map(|p| p.zero_point).collect();
317
318    // For symmetric quantization, zero points should be zero
319    let symmetric = params
320        .iter()
321        .all(|p| p.method == QuantizationMethod::Symmetric);
322
323    // Get output dimensions
324    let output_dim = matrices[0].shape.0;
325    let mut result = Array1::zeros(output_dim);
326
327    // Compute the result using block-based approach for better cache efficiency
328    if symmetric {
329        // Faster path for symmetric quantization
330        let fused_scale: f32 = scales.iter().product();
331
332        // We'll compute one result element at a time
333        for i in 0..output_dim {
334            let row = int8_matrices[0].row(i);
335
336            // For each element in the output, we need to compute a complex contraction
337            // Initialize with first matrix row
338            let middle_dim = matrices[0].shape.1;
339            let mut intermediate = vec![0i32; middle_dim];
340
341            for k in 0..middle_dim {
342                intermediate[k] = row[k] as i32;
343            }
344
345            // Propagate through the other matrices
346            for mat_idx in 1..matrices.len() {
347                let mat = int8_matrices[mat_idx];
348                let (rows, cols) = matrices[mat_idx].shape;
349
350                let mut new_intermediate = vec![0i32; cols];
351
352                for c in 0..cols {
353                    for r in 0..rows {
354                        new_intermediate[c] += intermediate[r] * (mat[[r, c]] as i32);
355                    }
356                }
357
358                intermediate = new_intermediate;
359            }
360
361            // Final dot product with the vector
362            let mut sum = 0.0;
363            for k in 0..intermediate.len() {
364                sum += (intermediate[k] as f32) * vector[k];
365            }
366
367            result[i] = sum * fused_scale;
368        }
369    } else {
370        // Path for asymmetric quantization (we need to handle zero points)
371        // This is more complex and less optimized
372
373        // Dequantize the matrices first
374        let mut dequantized_matrices = Vec::with_capacity(matrices.len());
375
376        for (matrix, param) in matrices.iter().zip(params.iter()) {
377            dequantized_matrices.push(dequantize_matrix(matrix, param));
378        }
379
380        // Create a column vector for matrix operations
381        let vector_2d = vector.to_owned().insert_axis(scirs2_core::ndarray::Axis(1));
382
383        // Apply matrices in reverse order
384        let mut result_2d = vector_2d;
385        for mat in dequantized_matrices.iter().rev() {
386            result_2d = mat.dot(&result_2d);
387        }
388
389        // Extract the column back to 1D
390        result = result_2d.remove_axis(scirs2_core::ndarray::Axis(1));
391    }
392
393    Ok(result)
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantization::{quantize_matrix, QuantizationMethod};
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// A three-matrix symmetric 8-bit chain must agree with the plain f32
    /// product up to quantization error.
    #[test]
    fn test_fused_matmul_chain() {
        // Operands: (2x3) * (3x2) * (2x3)
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0f32, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let c = array![[13.0f32, 14.0, 15.0], [16.0, 17.0, 18.0]];

        // Unquantized reference
        let expected = a.dot(&b).dot(&c);

        // Quantized operands
        let (qa, qa_params) = quantize_matrix(&a.view(), 8, QuantizationMethod::Symmetric);
        let (qb, qb_params) = quantize_matrix(&b.view(), 8, QuantizationMethod::Symmetric);
        let (qc, qc_params) = quantize_matrix(&c.view(), 8, QuantizationMethod::Symmetric);

        let result = fused_quantized_matmul_chain(
            &[&qa, &qb, &qc],
            &[&qa_params, &qb_params, &qc_params],
        )
        .unwrap();

        // Same shape, and every entry within the quantization tolerance
        assert_eq!(result.shape(), expected.shape());
        for (&got, &want) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 12.0);
        }
    }

    /// A two-matrix symmetric 8-bit sequence applied to a 1-D vector must
    /// agree with the plain f32 product up to quantization error.
    #[test]
    fn test_fused_matvec_sequence() {
        // Operands: (2x3) * (3x2) * (2,)
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0f32, 8.0], [9.0, 10.0], [11.0, 12.0]];
        let x = array![13.0f32, 14.0];

        // Unquantized reference, evaluated right to left
        let expected = a.dot(&b.dot(&x));

        // Only the matrices are quantized; the vector stays in f32
        let (qa, qa_params) = quantize_matrix(&a.view(), 8, QuantizationMethod::Symmetric);
        let (qb, qb_params) = quantize_matrix(&b.view(), 8, QuantizationMethod::Symmetric);

        let result = fused_quantized_matvec_sequence(
            &[&qa, &qb],
            &[&qa_params, &qb_params],
            &x.view(),
            false,
        )
        .unwrap();

        // Same length, and every entry within the quantization tolerance
        assert_eq!(result.len(), expected.len());
        for (&got, &want) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = 5.0);
        }
    }
}