scirs2_linalg/quantization/
operations.rs

//! Quantized linear algebra operations
//!
//! This module contains functions for performing linear algebra operations
//! on quantized matrices and vectors, including matrix multiplication,
//! matrix-vector multiplication, and dot products.

use scirs2_core::ndarray::{Array1, Array2};

use crate::error::{LinalgError, LinalgResult};

use super::conversions::dequantize_matrix;
use super::matrix::QuantizedMatrix;
use super::types::{QuantizationMethod, QuantizationParams, QuantizedDataType};
use super::vector::QuantizedVector;
16/// Perform matrix multiplication with two quantized matrices
17///
18/// # Arguments
19///
20/// * `a` - The first quantized matrix
21/// * `a_params` - Quantization parameters for the first matrix
22/// * `b` - The second quantized matrix
23/// * `b_params` - Quantization parameters for the second matrix
24///
25/// # Returns
26///
27/// The result of the matrix multiplication in floating-point
28pub fn quantized_matmul(
29    a: &QuantizedMatrix,
30    a_params: &QuantizationParams,
31    b: &QuantizedMatrix,
32    b_params: &QuantizationParams,
33) -> LinalgResult<Array2<f32>> {
34    // Check dimensions
35    if a.ncols() != b.nrows() {
36        return Err(LinalgError::DimensionError(format!(
37            "Cannot multiply matrices with shapes {:?} and {:?}",
38            a.shape(),
39            b.shape()
40        )));
41    }
42
43    let (m, k) = a.shape();
44    let (_, n) = b.shape();
45
46    // Create result matrix
47    let mut result = Array2::zeros((m, n));
48
49    // For floating point quantization types, we use floating point operations
50    if matches!(
51        a.data_type,
52        QuantizedDataType::Float16 | QuantizedDataType::BFloat16
53    ) || matches!(
54        b.data_type,
55        QuantizedDataType::Float16 | QuantizedDataType::BFloat16
56    ) {
57        // Perform floating-point matrix multiplication
58        for i in 0..m {
59            for j in 0..n {
60                let mut sum = 0.0_f32;
61                for l in 0..k {
62                    let a_val = a.get_f32(i, l);
63                    let b_val = b.get_f32(l, j);
64                    sum += a_val * b_val;
65                }
66                result[[i, j]] = sum;
67            }
68        }
69        return Ok(result);
70    }
71
72    // Check if either matrix uses per-channel quantization
73    let a_per_channel = a_params.method == QuantizationMethod::PerChannelSymmetric
74        || a_params.method == QuantizationMethod::PerChannelAffine;
75
76    let b_per_channel = b_params.method == QuantizationMethod::PerChannelSymmetric
77        || b_params.method == QuantizationMethod::PerChannelAffine;
78
79    // If either matrix uses per-channel quantization, we'll dequantize to f32 and do regular matmul
80    if a_per_channel || b_per_channel {
81        // Dequantize both matrices
82        let a_dequant = dequantize_matrix(a, a_params);
83        let b_dequant = dequantize_matrix(b, b_params);
84
85        // Perform standard matrix multiplication using dequantized matrices
86        for i in 0..m {
87            for j in 0..n {
88                let mut sum = 0.0_f32;
89                for l in 0..k {
90                    sum += a_dequant[[i, l]] * b_dequant[[l, j]];
91                }
92                result[[i, j]] = sum;
93            }
94        }
95
96        return Ok(result);
97    }
98
99    // For integer quantization, use the original approach
100    for i in 0..m {
101        for j in 0..n {
102            let mut sum = 0i32;
103            for l in 0..k {
104                // Use the get_i8 method for integer types
105                let a_val = a.get_i8(i, l) as i32;
106                let b_val = b.get_i8(l, j) as i32;
107                sum += a_val * b_val;
108            }
109
110            // Dequantize the result - scale is the same regardless of method
111            let a_scale = a_params.scale;
112            let b_scale = b_params.scale;
113
114            // Apply zero-point correction for affine quantization
115            if (a_params.method == QuantizationMethod::Affine
116                || a_params.method == QuantizationMethod::UInt4)
117                && (b_params.method == QuantizationMethod::Affine
118                    || b_params.method == QuantizationMethod::UInt4)
119            {
120                // For affine quantization, we need to correct for zero points
121                let a_zero_sum: i32 =
122                    (0..k).map(|l| b.get_i8(l, j) as i32).sum::<i32>() * a_params.zero_point;
123                let b_zero_sum: i32 =
124                    (0..k).map(|l| a.get_i8(i, l) as i32).sum::<i32>() * b_params.zero_point;
125                let zero_product = k as i32 * a_params.zero_point * b_params.zero_point;
126
127                sum = sum - a_zero_sum - b_zero_sum + zero_product;
128            }
129
130            result[[i, j]] = sum as f32 * a_scale * b_scale;
131        }
132    }
133
134    Ok(result)
135}
136
137/// Perform matrix-vector multiplication with quantized matrix and vector
138///
139/// # Arguments
140///
141/// * `a` - The quantized matrix
142/// * `a_params` - Quantization parameters for the matrix
143/// * `b` - The quantized vector
144/// * `b_params` - Quantization parameters for the vector
145///
146/// # Returns
147///
148/// The result of the matrix-vector multiplication in floating-point
149pub fn quantized_matvec(
150    a: &QuantizedMatrix,
151    a_params: &QuantizationParams,
152    b: &QuantizedVector,
153    b_params: &QuantizationParams,
154) -> LinalgResult<Array1<f32>> {
155    // Check dimensions
156    if a.ncols() != b.len() {
157        return Err(LinalgError::DimensionError(format!(
158            "Cannot multiply matrix with shape {:?} and vector with length {}",
159            a.shape(),
160            b.len()
161        )));
162    }
163
164    let m = a.nrows();
165    let n = a.ncols();
166
167    // Create result vector
168    let mut result = Array1::zeros(m);
169
170    // For floating point quantization types
171    if matches!(
172        a_params.data_type,
173        QuantizedDataType::Float16 | QuantizedDataType::BFloat16
174    ) && matches!(
175        b_params.data_type,
176        QuantizedDataType::Float16 | QuantizedDataType::BFloat16
177    ) {
178        // Dequantize and compute in floating point
179        let a_full = dequantize_matrix(a, a_params);
180        let b_full = dequantize_vector(b, b_params);
181
182        for i in 0..m {
183            let mut sum = 0.0_f32;
184            for j in 0..n {
185                sum += a_full[[i, j]] * b_full[j];
186            }
187            result[i] = sum;
188        }
189
190        return Ok(result);
191    }
192
193    // For integer quantization
194    let a_scale = a_params.scale;
195    let b_scale = b_params.scale;
196
197    for i in 0..m {
198        let mut sum: i32 = 0;
199
200        for j in 0..n {
201            let a_val = a.get_i8(i, j) as i32;
202            let b_val = b.get_i8(j) as i32;
203            sum += a_val * b_val;
204        }
205
206        result[i] = sum as f32 * a_scale * b_scale;
207    }
208
209    Ok(result)
210}
211
212/// Perform dot product with two quantized vectors
213///
214/// # Arguments
215///
216/// * `a` - The first quantized vector
217/// * `a_params` - Quantization parameters for the first vector
218/// * `b` - The second quantized vector
219/// * `b_params` - Quantization parameters for the second vector
220///
221/// # Returns
222///
223/// The dot product result in floating-point
224pub fn quantized_dot(
225    a: &QuantizedVector,
226    a_params: &QuantizationParams,
227    b: &QuantizedVector,
228    b_params: &QuantizationParams,
229) -> LinalgResult<f32> {
230    // Check dimensions
231    if a.len() != b.len() {
232        return Err(LinalgError::DimensionError(format!(
233            "Cannot compute dot product of vectors with lengths {} and {}",
234            a.len(),
235            b.len()
236        )));
237    }
238
239    let n = a.len();
240
241    // For floating point quantization types
242    if matches!(
243        a_params.data_type,
244        QuantizedDataType::Float16 | QuantizedDataType::BFloat16
245    ) && matches!(
246        b_params.data_type,
247        QuantizedDataType::Float16 | QuantizedDataType::BFloat16
248    ) {
249        // Dequantize and compute in floating point
250        let a_full = dequantize_vector(a, a_params);
251        let b_full = dequantize_vector(b, b_params);
252
253        let mut sum = 0.0_f32;
254        for i in 0..n {
255            sum += a_full[i] * b_full[i];
256        }
257
258        return Ok(sum);
259    }
260
261    // For integer quantization
262    let a_scale = a_params.scale;
263    let b_scale = b_params.scale;
264
265    let mut sum: i32 = 0;
266
267    for i in 0..n {
268        let a_val = a.get_i8(i) as i32;
269        let b_val = b.get_i8(i) as i32;
270        sum += a_val * b_val;
271    }
272
273    Ok(sum as f32 * a_scale * b_scale)
274}
275
276// Helper function to dequantize a vector
277fn dequantize_vector(vec: &QuantizedVector, _params: &QuantizationParams) -> Array1<f32> {
278    let n = vec.len();
279    let mut result = Array1::zeros(n);
280
281    for i in 0..n {
282        result[i] = vec.get_f32(i);
283    }
284
285    result
286}