// scirs2_linalg/quantization/simd.rs

//! SIMD-accelerated operations for quantized matrices
//!
//! This module provides SIMD-accelerated implementations of matrix operations
//! on quantized data for improved performance. These implementations build on
//! the scirs2-core SIMD abstractions and work with the quantization types
//! defined in the parent module.

8use crate::error::{LinalgError, LinalgResult};
9use crate::quantization::{
10    dequantize_matrix, dequantize_vector, get_quantized_vector_1d_i8, get_quantizedmatrix_2d_i8,
11    quantize_vector, QuantizationMethod, QuantizationParams, QuantizedData2D, QuantizedDataType,
12    QuantizedMatrix, QuantizedVector,
13};
14use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
15use scirs2_core::simd_ops::SimdUnifiedOps;
16
17/// SIMD-accelerated quantized matrix-vector multiplication
18///
19/// Performs matrix-vector multiplication where the matrix is in quantized form
20/// and the vector is in f32 format. The result is returned as f32.
21///
22/// # Arguments
23///
24/// * `qmatrix` - Quantized matrix
25/// * `qparams` - Quantization parameters for the matrix
26/// * `vector` - Vector to multiply with
27///
28/// # Returns
29///
30/// * Result vector of the multiplication
31#[allow(dead_code)]
32pub fn simd_quantized_matvec(
33    qmatrix: &QuantizedMatrix,
34    qparams: &QuantizationParams,
35    vector: &ArrayView1<f32>,
36) -> LinalgResult<Array1<f32>> {
37    // Check dimensions
38    if qmatrix.shape.1 != vector.len() {
39        return Err(LinalgError::ShapeError(format!(
40            "Matrix columns ({}) must match vector length ({})",
41            qmatrix.shape.1,
42            vector.len()
43        )));
44    }
45
46    // Create result vector
47    let mut result = Array1::zeros(qmatrix.shape.0);
48    let vec_slice = vector.as_slice().unwrap();
49
50    // Handle based on data type
51    match &qmatrix.data {
52        QuantizedData2D::Int8(data) => {
53            // Get the scale factors for dequantization
54            let scale = qparams.scale;
55            let zero_point = qparams.zero_point;
56
57            // Handle per-channel quantization separately
58            if qparams.method == QuantizationMethod::PerChannelSymmetric
59                || qparams.method == QuantizationMethod::PerChannelAffine
60            {
61                let scales = qparams
62                    .channel_scales
63                    .as_ref()
64                    .expect("Per-channel quantization requires channel scales");
65
66                let zero_points = if qparams.method == QuantizationMethod::PerChannelAffine {
67                    qparams
68                        .channel_zero_points
69                        .as_ref()
70                        .expect("Per-channel affine quantization requires channel zero points")
71                } else {
72                    &vec![0; qmatrix.shape.1] // Symmetric doesn't use zero points
73                };
74
75                // Process each row of the matrix
76                for (i, row) in data.rows().into_iter().enumerate() {
77                    // We'll use SIMD to accumulate 8 values at once
78                    let chunksize = 8;
79                    let mut acc = 0.0f32;
80
81                    let row_slice = row.as_slice().unwrap();
82                    let mut j = 0;
83
84                    // Accumulate 8 elements at a time using SIMD
85                    while j + chunksize <= row_slice.len() {
86                        // Load chunks from row, scales, zero points and vector
87                        let mut row_vals = [0.0f32; 8];
88
89                        for (k, val) in row_vals.iter_mut().enumerate().take(chunksize) {
90                            let idx = j + k;
91                            // Dequantize the value: (val - zero_point) * scale
92                            let dequantized =
93                                (row_slice[idx] as f32 - zero_points[idx] as f32) * scales[idx];
94                            *val = dequantized * vec_slice[idx];
95                        }
96
97                        // Sum the products into our accumulator using core SIMD operations
98                        let row_vals_slice = ArrayView1::from(&row_vals);
99                        acc += f32::simd_sum(&row_vals_slice);
100
101                        j += chunksize;
102                    }
103
104                    // Handle remaining elements
105                    for k in j..row_slice.len() {
106                        let dequantized = (row_slice[k] as f32 - zero_points[k] as f32) * scales[k];
107                        acc += dequantized * vec_slice[k];
108                    }
109
110                    result[i] = acc;
111                }
112            } else {
113                // Standard quantization (single scale/zero point)
114
115                // For Int4/UInt4, we need special handling
116                if qparams.data_type == QuantizedDataType::Int4
117                    || qparams.data_type == QuantizedDataType::UInt4
118                {
119                    // Process each row
120                    for (i, row) in data.rows().into_iter().enumerate() {
121                        let row_slice = row.as_slice().unwrap();
122                        let mut acc = 0.0f32;
123
124                        // For Int4/UInt4, we need to unpack two values from each byte
125                        for (j, &byte) in row_slice.iter().enumerate() {
126                            let col_idx = j * 2; // Each byte contains 2 values
127
128                            // Unpack the first 4-bit value
129                            let val1 = if qparams.data_type == QuantizedDataType::Int4 {
130                                // Extract and sign-extend 4-bit signed int
131                                let q = (byte >> 4) & 0x0F;
132                                if q & 0x08 != 0 {
133                                    q | 0xF0u8 as i8
134                                } else {
135                                    q
136                                } // Sign extend
137                            } else {
138                                // UInt4
139                                (byte >> 4) & 0x0F
140                            };
141
142                            // Process only if we're still within matrix bounds
143                            if col_idx < qmatrix.shape.1 {
144                                let dequantized = (val1 as f32 - zero_point as f32) * scale;
145                                acc += dequantized * vec_slice[col_idx];
146                            }
147
148                            // Unpack the second 4-bit value
149                            let val2 = if qparams.data_type == QuantizedDataType::Int4 {
150                                // Extract and sign-extend 4-bit signed int
151                                let q = byte & 0x0F;
152                                if q & 0x08 != 0 {
153                                    q | 0xF0u8 as i8
154                                } else {
155                                    q
156                                } // Sign extend
157                            } else {
158                                // UInt4
159                                byte & 0x0F
160                            };
161
162                            // Process only if we're still within matrix bounds
163                            if col_idx + 1 < qmatrix.shape.1 {
164                                let dequantized = (val2 as f32 - zero_point as f32) * scale;
165                                acc += dequantized * vec_slice[col_idx + 1];
166                            }
167                        }
168
169                        result[i] = acc;
170                    }
171                } else {
172                    // Standard Int8 processing
173                    for (i, row) in data.rows().into_iter().enumerate() {
174                        let row_slice = row.as_slice().unwrap();
175                        let mut acc = 0.0f32;
176
177                        // Process 8 elements at a time with SIMD
178                        let chunksize = 8;
179                        let mut j = 0;
180
181                        while j + chunksize <= row_slice.len() {
182                            // Load chunks from row and vector
183                            let row_chunk = [
184                                row_slice[j] as f32,
185                                row_slice[j + 1] as f32,
186                                row_slice[j + 2] as f32,
187                                row_slice[j + 3] as f32,
188                                row_slice[j + 4] as f32,
189                                row_slice[j + 5] as f32,
190                                row_slice[j + 6] as f32,
191                                row_slice[j + 7] as f32,
192                            ];
193
194                            let vec_chunk = [
195                                vec_slice[j],
196                                vec_slice[j + 1],
197                                vec_slice[j + 2],
198                                vec_slice[j + 3],
199                                vec_slice[j + 4],
200                                vec_slice[j + 5],
201                                vec_slice[j + 6],
202                                vec_slice[j + 7],
203                            ];
204
205                            // Convert to ndarray views for core SIMD operations
206                            let _row_view = ArrayView1::from(&row_chunk);
207                            let vec_view = ArrayView1::from(&vec_chunk);
208
209                            // Create dequantized values: (row - zero_point) * scale
210                            let mut dequantized = [0.0f32; 8];
211                            for (k, val) in dequantized.iter_mut().enumerate() {
212                                *val = (row_chunk[k] - zero_point as f32) * scale;
213                            }
214                            let dequantized_view = ArrayView1::from(&dequantized);
215
216                            // Multiply and accumulate using core SIMD
217                            acc += f32::simd_dot(&dequantized_view, &vec_view);
218
219                            j += chunksize;
220                        }
221
222                        // Process remaining elements
223                        for k in j..row_slice.len() {
224                            let dequantized = (row_slice[k] as f32 - zero_point as f32) * scale;
225                            acc += dequantized * vec_slice[k];
226                        }
227
228                        result[i] = acc;
229                    }
230                }
231            }
232        }
233        QuantizedData2D::Float16(data) => {
234            // Do a basic loop multiplication for now - optimize this later
235            for (i, row) in data.rows().into_iter().enumerate() {
236                let mut sum = 0.0f32;
237                for (j, &val) in row.iter().enumerate() {
238                    sum += f32::from(val) * vec_slice[j];
239                }
240                result[i] = sum;
241            }
242        }
243        QuantizedData2D::BFloat16(data) => {
244            // Do a basic loop multiplication for now - optimize this later
245            for (i, row) in data.rows().into_iter().enumerate() {
246                let mut sum = 0.0f32;
247                for (j, &val) in row.iter().enumerate() {
248                    sum += f32::from(val) * vec_slice[j];
249                }
250                result[i] = sum;
251            }
252        }
253    }
254
255    Ok(result)
256}
257
258/// SIMD-accelerated quantized matrix-matrix multiplication
259///
260/// Performs matrix-matrix multiplication where both matrices are in quantized form.
261/// The result is returned as f32.
262///
263/// # Arguments
264///
265/// * `a` - First quantized matrix
266/// * `a_params` - Quantization parameters for the first matrix
267/// * `b` - Second quantized matrix
268/// * `b_params` - Quantization parameters for the second matrix
269///
270/// # Returns
271///
272/// * Result matrix of the multiplication
273#[allow(dead_code)]
274pub fn simd_quantized_matmul(
275    a: &QuantizedMatrix,
276    a_params: &QuantizationParams,
277    b: &QuantizedMatrix,
278    b_params: &QuantizationParams,
279) -> LinalgResult<Array2<f32>> {
280    // Check dimensions
281    if a.shape.1 != b.shape.0 {
282        return Err(LinalgError::ShapeError(format!(
283            "Matrix dimensions mismatch for multiplication: ({}, {}) * ({}, {})",
284            a.shape.0, a.shape.1, b.shape.0, b.shape.1
285        )));
286    }
287
288    // Create result matrix
289    let (m, n) = (a.shape.0, b.shape.1);
290    let mut result = Array2::zeros((m, n));
291
292    // Get int8 data if available - we'll only handle Int8 SIMD acceleration for now
293    if let (Some(a_data), Some(b_data)) =
294        (get_quantizedmatrix_2d_i8(a), get_quantizedmatrix_2d_i8(b))
295    {
296        // If either matrix is per-channel quantized, we dequantize it fully first
297        // In the future, we can optimize this with specialized kernels
298        if a_params.method == QuantizationMethod::PerChannelSymmetric
299            || a_params.method == QuantizationMethod::PerChannelAffine
300            || b_params.method == QuantizationMethod::PerChannelSymmetric
301            || b_params.method == QuantizationMethod::PerChannelAffine
302        {
303            // Dequantize matrices
304            let a_dequant = dequantize_matrix(a, a_params);
305            let b_dequant = dequantize_matrix(b, b_params);
306
307            // Use standard matrix multiplication
308            return Ok(a_dequant.dot(&b_dequant));
309        }
310
311        // Get quantization parameters
312        let a_scale = a_params.scale;
313        let a_zero = a_params.zero_point as f32;
314        let b_scale = b_params.scale;
315        let b_zero = b_params.zero_point as f32;
316
317        // Combined scale for the output
318        let _output_scale = a_scale * b_scale; // Used in future optimizations
319
320        // For int4/uint4, each byte contains two values, and special handling is needed
321        let a_is_4bit = a_params.data_type == QuantizedDataType::Int4
322            || a_params.data_type == QuantizedDataType::UInt4;
323        let b_is_4bit = b_params.data_type == QuantizedDataType::Int4
324            || b_params.data_type == QuantizedDataType::UInt4;
325
326        // Cache-friendly block sizes
327        // These should be tuned based on target CPU cache sizes
328        const BLOCK_SIZE_M: usize = 32;
329        const BLOCK_SIZE_N: usize = 32;
330        const BLOCK_SIZE_K: usize = 32;
331
332        // Loop over blocks
333        for i0 in (0..m).step_by(BLOCK_SIZE_M) {
334            let i_end = (i0 + BLOCK_SIZE_M).min(m);
335
336            for j0 in (0..n).step_by(BLOCK_SIZE_N) {
337                let j_end = (j0 + BLOCK_SIZE_N).min(n);
338
339                // Process inner dimension in blocks
340                for k0 in (0..a.shape.1).step_by(BLOCK_SIZE_K) {
341                    let k_end = (k0 + BLOCK_SIZE_K).min(a.shape.1);
342
343                    // Process blocks
344                    for i in i0..i_end {
345                        for j in j0..j_end {
346                            // Compute dot product of row i from A and column j from B
347                            let mut sum = 0.0f32;
348
349                            // Number of elements in this block of the inner dimension
350                            let k_blocksize = k_end - k0;
351
352                            // If we're using 4-bit quantization, we need to adjust
353                            if a_is_4bit || b_is_4bit {
354                                // Simplified handling for 4-bit quantization - dequantize on the fly
355                                for k in k0..k_end {
356                                    let a_val = if a_is_4bit {
357                                        // Extract the right 4-bit value
358                                        let byte_idx = k / 2;
359                                        let byte = a_data[[i, byte_idx]];
360
361                                        if k % 2 == 0 {
362                                            // First 4 bits
363                                            let val = (byte >> 4) & 0x0F;
364                                            // Sign extend for Int4 if needed
365                                            if a_params.data_type == QuantizedDataType::Int4
366                                                && (val & 0x08) != 0
367                                            {
368                                                ((val | 0xF0u8 as i8) as f32 - a_zero) * a_scale
369                                            } else {
370                                                ((val & 0x0F) as f32 - a_zero) * a_scale
371                                            }
372                                        } else {
373                                            // Second 4 bits
374                                            let val = byte & 0x0F;
375                                            // Sign extend for Int4 if needed
376                                            if a_params.data_type == QuantizedDataType::Int4
377                                                && (val & 0x08) != 0
378                                            {
379                                                ((val | 0xF0u8 as i8) as f32 - a_zero) * a_scale
380                                            } else {
381                                                ((val & 0x0F) as f32 - a_zero) * a_scale
382                                            }
383                                        }
384                                    } else {
385                                        // Regular 8-bit quantization
386                                        (a_data[[i, k]] as f32 - a_zero) * a_scale
387                                    };
388
389                                    let b_val = if b_is_4bit {
390                                        // Extract the right 4-bit value
391                                        let byte_idx = k / 2;
392                                        let byte = b_data[[byte_idx, j]];
393
394                                        if k % 2 == 0 {
395                                            // First 4 bits
396                                            let val = (byte >> 4) & 0x0F;
397                                            // Sign extend for Int4 if needed
398                                            if b_params.data_type == QuantizedDataType::Int4
399                                                && (val & 0x08) != 0
400                                            {
401                                                ((val | 0xF0u8 as i8) as f32 - b_zero) * b_scale
402                                            } else {
403                                                ((val & 0x0F) as f32 - b_zero) * b_scale
404                                            }
405                                        } else {
406                                            // Second 4 bits
407                                            let val = byte & 0x0F;
408                                            // Sign extend for Int4 if needed
409                                            if b_params.data_type == QuantizedDataType::Int4
410                                                && (val & 0x08) != 0
411                                            {
412                                                ((val | 0xF0u8 as i8) as f32 - b_zero) * b_scale
413                                            } else {
414                                                ((val & 0x0F) as f32 - b_zero) * b_scale
415                                            }
416                                        }
417                                    } else {
418                                        // Regular 8-bit quantization
419                                        (b_data[[k, j]] as f32 - b_zero) * b_scale
420                                    };
421
422                                    sum += a_val * b_val;
423                                }
424                            } else {
425                                // Regular 8-bit quantization - we can use SIMD
426
427                                // Get row from A and column from B as slices if possible
428                                let a_row = a_data.slice(scirs2_core::ndarray::s![i, k0..k_end]);
429                                let b_col = b_data.slice(scirs2_core::ndarray::s![k0..k_end, j]);
430
431                                if let (Some(a_slice), Some(b_slice)) =
432                                    (a_row.as_slice(), b_col.as_slice())
433                                {
434                                    // Process with SIMD (8 elements at a time)
435                                    let mut l = 0;
436                                    let chunksize = 8;
437
438                                    while l + chunksize <= k_blocksize {
439                                        // Load chunks
440                                        let a_chunk = [
441                                            a_slice[l] as f32,
442                                            a_slice[l + 1] as f32,
443                                            a_slice[l + 2] as f32,
444                                            a_slice[l + 3] as f32,
445                                            a_slice[l + 4] as f32,
446                                            a_slice[l + 5] as f32,
447                                            a_slice[l + 6] as f32,
448                                            a_slice[l + 7] as f32,
449                                        ];
450                                        let b_chunk = [
451                                            b_slice[l] as f32,
452                                            b_slice[l + 1] as f32,
453                                            b_slice[l + 2] as f32,
454                                            b_slice[l + 3] as f32,
455                                            b_slice[l + 4] as f32,
456                                            b_slice[l + 5] as f32,
457                                            b_slice[l + 6] as f32,
458                                            b_slice[l + 7] as f32,
459                                        ];
460
461                                        // Dequantize chunks
462                                        let mut a_dequant = [0.0f32; 8];
463                                        let mut b_dequant = [0.0f32; 8];
464
465                                        for k in 0..8 {
466                                            a_dequant[k] = (a_chunk[k] - a_zero) * a_scale;
467                                            b_dequant[k] = (b_chunk[k] - b_zero) * b_scale;
468                                        }
469
470                                        // Convert to views and compute dot product using core SIMD
471                                        let a_view = ArrayView1::from(&a_dequant);
472                                        let b_view = ArrayView1::from(&b_dequant);
473                                        sum += f32::simd_dot(&a_view, &b_view);
474
475                                        l += chunksize;
476                                    }
477
478                                    // Process remaining elements
479                                    for m in l..k_blocksize {
480                                        let a_val = (a_slice[m] as f32 - a_zero) * a_scale;
481                                        let b_val = (b_slice[m] as f32 - b_zero) * b_scale;
482                                        sum += a_val * b_val;
483                                    }
484                                } else {
485                                    // Fallback for non-contiguous data
486                                    for k in k0..k_end {
487                                        let a_val = (a_data[[i, k]] as f32 - a_zero) * a_scale;
488                                        let b_val = (b_data[[k, j]] as f32 - b_zero) * b_scale;
489                                        sum += a_val * b_val;
490                                    }
491                                }
492                            }
493
494                            // Accumulate result
495                            result[[i, j]] += sum;
496                        }
497                    }
498                }
499            }
500        }
501    } else {
502        // If we don't have Int8 data, fall back to dequantize and multiply
503        let a_dequant = dequantize_matrix(a, a_params);
504        let b_dequant = dequantize_matrix(b, b_params);
505
506        return Ok(a_dequant.dot(&b_dequant));
507    }
508
509    Ok(result)
510}
511
512/// SIMD-accelerated quantized dot product
513///
514/// Computes the dot product of two quantized vectors using SIMD instructions.
515///
516/// # Arguments
517///
518/// * `a` - First quantized vector
519/// * `a_params` - Quantization parameters for the first vector
520/// * `b` - Second quantized vector
521/// * `b_params` - Quantization parameters for the second vector
522///
523/// # Returns
524///
525/// * Dot product result
526#[allow(dead_code)]
527pub fn simd_quantized_dot(
528    a: &QuantizedVector,
529    a_params: &QuantizationParams,
530    b: &QuantizedVector,
531    b_params: &QuantizationParams,
532) -> LinalgResult<f32> {
533    // Check dimensions
534    if a.length != b.length {
535        return Err(LinalgError::ShapeError(format!(
536            "Vector dimensions must match for dot product: {} vs {}",
537            a.length, b.length
538        )));
539    }
540
541    // Get int8 data if available
542    if let (Some(a_data), Some(b_data)) =
543        (get_quantized_vector_1d_i8(a), get_quantized_vector_1d_i8(b))
544    {
545        // Get quantization parameters
546        let a_scale = a_params.scale;
547        let a_zero = a_params.zero_point as f32;
548        let b_scale = b_params.scale;
549        let b_zero = b_params.zero_point as f32;
550
551        // Combined scale for the output
552        let _output_scale = a_scale * b_scale; // Used in future optimizations
553
554        // For int4/uint4, each byte contains two values
555        let a_is_4bit = a_params.data_type == QuantizedDataType::Int4
556            || a_params.data_type == QuantizedDataType::UInt4;
557        let b_is_4bit = b_params.data_type == QuantizedDataType::Int4
558            || b_params.data_type == QuantizedDataType::UInt4;
559
560        if a_is_4bit || b_is_4bit {
561            // Handle 4-bit specially - we need to unpack values
562            let mut sum = 0.0f32;
563
564            // We need to adjust length for 4-bit values (each byte has 2 values)
565            let _a_byte_len = a.length.div_ceil(2); // Used for bounds checking
566            let _b_byte_len = b.length.div_ceil(2); // Used for bounds checking
567
568            for i in 0..a.length {
569                // Extract values from packed 4-bit representation
570                let a_val = if a_is_4bit {
571                    let byte_idx = i / 2;
572                    let byte = a_data[byte_idx];
573
574                    if i % 2 == 0 {
575                        // First 4 bits
576                        let val = (byte >> 4) & 0x0F;
577                        // Sign extend for Int4 if needed
578                        if a_params.data_type == QuantizedDataType::Int4 && (val & 0x08) != 0 {
579                            ((val | 0xF0u8 as i8) as f32 - a_zero) * a_scale
580                        } else {
581                            (val as f32 - a_zero) * a_scale
582                        }
583                    } else {
584                        // Second 4 bits
585                        let val = byte & 0x0F;
586                        // Sign extend for Int4 if needed
587                        if a_params.data_type == QuantizedDataType::Int4 && (val & 0x08) != 0 {
588                            ((val | 0xF0u8 as i8) as f32 - a_zero) * a_scale
589                        } else {
590                            (val as f32 - a_zero) * a_scale
591                        }
592                    }
593                } else {
594                    (a_data[i] as f32 - a_zero) * a_scale
595                };
596
597                let b_val = if b_is_4bit {
598                    let byte_idx = i / 2;
599                    let byte = b_data[byte_idx];
600
601                    if i % 2 == 0 {
602                        // First 4 bits
603                        let val = (byte >> 4) & 0x0F;
604                        // Sign extend for Int4 if needed
605                        if b_params.data_type == QuantizedDataType::Int4 && (val & 0x08) != 0 {
606                            ((val | 0xF0u8 as i8) as f32 - b_zero) * b_scale
607                        } else {
608                            (val as f32 - b_zero) * b_scale
609                        }
610                    } else {
611                        // Second 4 bits
612                        let val = byte & 0x0F;
613                        // Sign extend for Int4 if needed
614                        if b_params.data_type == QuantizedDataType::Int4 && (val & 0x08) != 0 {
615                            ((val | 0xF0u8 as i8) as f32 - b_zero) * b_scale
616                        } else {
617                            (val as f32 - b_zero) * b_scale
618                        }
619                    }
620                } else {
621                    (b_data[i] as f32 - b_zero) * b_scale
622                };
623
624                sum += a_val * b_val;
625            }
626
627            return Ok(sum);
628        }
629
630        // Standard 8-bit quantization
631        let a_slice = a_data.as_slice().unwrap();
632        let b_slice = b_data.as_slice().unwrap();
633
634        // Process 8 elements at a time with SIMD
635        let mut i = 0;
636        let chunksize = 8;
637        let mut sum = 0.0f32;
638
639        while i + chunksize <= a.length {
640            // Load chunks
641            let a_chunk = [
642                a_slice[i] as f32,
643                a_slice[i + 1] as f32,
644                a_slice[i + 2] as f32,
645                a_slice[i + 3] as f32,
646                a_slice[i + 4] as f32,
647                a_slice[i + 5] as f32,
648                a_slice[i + 6] as f32,
649                a_slice[i + 7] as f32,
650            ];
651
652            let b_chunk = [
653                b_slice[i] as f32,
654                b_slice[i + 1] as f32,
655                b_slice[i + 2] as f32,
656                b_slice[i + 3] as f32,
657                b_slice[i + 4] as f32,
658                b_slice[i + 5] as f32,
659                b_slice[i + 6] as f32,
660                b_slice[i + 7] as f32,
661            ];
662
663            // Dequantize chunks
664            let mut a_dequant = [0.0f32; 8];
665            let mut b_dequant = [0.0f32; 8];
666
667            for k in 0..8 {
668                a_dequant[k] = (a_chunk[k] - a_zero) * a_scale;
669                b_dequant[k] = (b_chunk[k] - b_zero) * b_scale;
670            }
671
672            // Convert to views and compute dot product using core SIMD
673            let a_view = ArrayView1::from(&a_dequant);
674            let b_view = ArrayView1::from(&b_dequant);
675            sum += f32::simd_dot(&a_view, &b_view);
676
677            i += chunksize;
678        }
679
680        // Process remaining elements
681        for j in i..a.length {
682            let a_val = (a_slice[j] as f32 - a_zero) * a_scale;
683            let b_val = (b_slice[j] as f32 - b_zero) * b_scale;
684            sum += a_val * b_val;
685        }
686
687        Ok(sum)
688    } else {
689        // If we don't have Int8 data, fall back to dequantize and dot
690        let a_dequant = dequantize_vector(a, a_params);
691        let b_dequant = dequantize_vector(b, b_params);
692        Ok(a_dequant.dot(&b_dequant))
693    }
694}
695
#[cfg(test)]
mod tests {
    //! Tests for the SIMD quantized kernels in this module.
    //!
    //! Every test is marked `#[ignore = "timeout"]`; run them explicitly with
    //! `cargo test -- --ignored`.

    use super::*;
    use crate::quantization::{
        dequantize_vector, quantize_matrix, quantize_matrix_per_channel, quantize_vector,
        QuantizationMethod,
    };
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    #[ignore = "timeout"]
    fn test_simd_quantized_matvec() {
        // Create test matrix and vector
        let mat = array![
            [1.0f32, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0]
        ];

        let vec = array![2.0f32, 3.0, 4.0, 5.0];

        // Quantize the matrix
        let (qmat, qparams) = quantize_matrix(&mat.view(), 8, QuantizationMethod::Symmetric);

        // Compute result with SIMD acceleration
        let result = simd_quantized_matvec(&qmat, &qparams, &vec.view()).unwrap();

        // Expected result (regular matmul)
        let expected = array![40.0f32, 96.0, 152.0];

        // Verify correctness with tolerance for quantization error
        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 0.5);
        }
    }

    #[test]
    #[ignore = "timeout"]
    fn test_simd_quantized_matmul() {
        // Create test matrices
        let a = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let b = array![[7.0f32, 8.0, 9.0], [10.0, 11.0, 12.0], [13.0, 14.0, 15.0]];

        // Quantize matrices
        let (qa, qa_params) = quantize_matrix(&a.view(), 8, QuantizationMethod::Symmetric);
        let (qb, qb_params) = quantize_matrix(&b.view(), 8, QuantizationMethod::Symmetric);

        // Compute result with SIMD acceleration
        let result = simd_quantized_matmul(&qa, &qa_params, &qb, &qb_params).unwrap();

        // Expected result (regular matmul)
        let expected = array![[66.0f32, 72.0, 78.0], [156.0, 171.0, 186.0]];

        // Verify correctness with tolerance for quantization error
        assert_eq!(result.shape(), expected.shape());
        for ((i, j), &val) in result.indexed_iter() {
            assert_relative_eq!(val, expected[[i, j]], epsilon = 1.0);
        }
    }

    #[test]
    #[ignore = "timeout"]
    fn test_simd_quantized_dot() {
        // Create test vectors (shorter than one 8-wide chunk, so only the
        // scalar tail loop of the 8-bit path is used)
        let a = array![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let b = array![5.0f32, 4.0, 3.0, 2.0, 1.0];

        // Quantize vectors
        let (qa, qa_params) = quantize_vector(&a.view(), 8, QuantizationMethod::Symmetric);
        let (qb, qb_params) = quantize_vector(&b.view(), 8, QuantizationMethod::Symmetric);

        // Compute result with SIMD acceleration
        let result = simd_quantized_dot(&qa, &qa_params, &qb, &qb_params).unwrap();

        // Expected result (regular dot product of the unquantized inputs) = 35
        let expected: f32 = 1.0 * 5.0 + 2.0 * 4.0 + 3.0 * 3.0 + 4.0 * 2.0 + 5.0 * 1.0;

        // Verify correctness with tolerance for quantization error
        assert_relative_eq!(result, expected, epsilon = 0.5);
    }

    #[test]
    #[ignore = "timeout"]
    fn test_simd_quantized_dot_chunked() {
        // 12 elements = one full 8-wide chunk plus a 4-element remainder, so both the
        // chunked loop and the scalar tail of the 8-bit path contribute to the sum.
        let a = array![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0];
        let b = array![12.0f32, 11.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];

        let (qa, qa_params) = quantize_vector(&a.view(), 8, QuantizationMethod::Symmetric);
        let (qb, qb_params) = quantize_vector(&b.view(), 8, QuantizationMethod::Symmetric);

        let result = simd_quantized_dot(&qa, &qa_params, &qb, &qb_params).unwrap();

        // Reference: plain dot product of the dequantized vectors. Both sides see the
        // same quantization error, so they should differ only by f32 rounding from
        // the different summation order.
        let a_ref = dequantize_vector(&qa, &qa_params);
        let b_ref = dequantize_vector(&qb, &qb_params);
        let reference = a_ref.dot(&b_ref);

        assert_relative_eq!(result, reference, max_relative = 1e-4);
    }

    #[test]
    #[ignore = "timeout"]
    fn test_simd_quantized_int4_operations() {
        // Create test matrix and vector
        let mat = array![[1.0f32, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];

        let vec = array![2.0f32, 3.0, 4.0, 5.0];

        // Quantize the matrix to Int4
        let (qmat, qparams) = quantize_matrix(&mat.view(), 4, QuantizationMethod::Int4);

        // Compute result with SIMD acceleration
        let result = simd_quantized_matvec(&qmat, &qparams, &vec.view()).unwrap();

        // Expected result (regular matmul)
        let expected = array![40.0f32, 96.0];

        // Verify correctness with tolerance for Int4 quantization error (higher error expected)
        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 3.0);
        }
    }

    #[test]
    #[ignore = "timeout"]
    fn test_simd_quantized_per_channel() {
        // Create a test matrix with very different scales in each column
        let mat = array![
            [0.1f32, 10.0, 100.0],
            [0.2, 20.0, 200.0],
            [0.3, 30.0, 300.0]
        ];

        let vec = array![1.0f32, 0.5, 0.25];

        // Quantize with per-channel method
        let (qmat, qparams) =
            quantize_matrix_per_channel(&mat.view(), 8, QuantizationMethod::PerChannelSymmetric);

        // Compute result with SIMD acceleration
        let result = simd_quantized_matvec(&qmat, &qparams, &vec.view()).unwrap();

        // Expected result (regular matmul)
        let expected = array![
            0.1 * 1.0 + 10.0 * 0.5 + 100.0 * 0.25,
            0.2 * 1.0 + 20.0 * 0.5 + 200.0 * 0.25,
            0.3 * 1.0 + 30.0 * 0.5 + 300.0 * 0.25
        ];

        // Verify correctness with tolerance for quantization error
        assert_eq!(result.len(), expected.len());
        for (a, b) in result.iter().zip(expected.iter()) {
            assert_relative_eq!(a, b, epsilon = 0.5);
        }
    }
}