scirs2_linalg/quantization/simd.rs
//! SIMD-accelerated operations for quantized matrices
//!
//! This module provides SIMD-accelerated implementations of matrix operations
//! on quantized data for improved performance. These implementations build on
//! the scirs2-core SIMD abstractions and work with the quantization types
//! defined in the parent module.
7
8use crate::error::{LinalgError, LinalgResult};
9use crate::quantization::{
10 dequantize_matrix, dequantize_vector, get_quantized_vector_1d_i8, get_quantizedmatrix_2d_i8,
11 quantize_vector, QuantizationMethod, QuantizationParams, QuantizedData2D, QuantizedDataType,
12 QuantizedMatrix, QuantizedVector,
13};
14use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
15use scirs2_core::simd_ops::SimdUnifiedOps;
16
17/// SIMD-accelerated quantized matrix-vector multiplication
18///
19/// Performs matrix-vector multiplication where the matrix is in quantized form
20/// and the vector is in f32 format. The result is returned as f32.
21///
22/// # Arguments
23///
24/// * `qmatrix` - Quantized matrix
25/// * `qparams` - Quantization parameters for the matrix
26/// * `vector` - Vector to multiply with
27///
28/// # Returns
29///
30/// * Result vector of the multiplication
31#[allow(dead_code)]
32pub fn simd_quantized_matvec(
33 qmatrix: &QuantizedMatrix,
34 qparams: &QuantizationParams,
35 vector: &ArrayView1<f32>,
36) -> LinalgResult<Array1<f32>> {
37 // Check dimensions
38 if qmatrix.shape.1 != vector.len() {
39 return Err(LinalgError::ShapeError(format!(
40 "Matrix columns ({}) must match vector length ({})",
41 qmatrix.shape.1,
42 vector.len()
43 )));
44 }
45
46 // Create result vector
47 let mut result = Array1::zeros(qmatrix.shape.0);
48 let vec_slice = vector.as_slice().unwrap();
49
50 // Handle based on data type
51 match &qmatrix.data {
52 QuantizedData2D::Int8(data) => {
53 // Get the scale factors for dequantization
54 let scale = qparams.scale;
55 let zero_point = qparams.zero_point;
56
57 // Handle per-channel quantization separately
58 if qparams.method == QuantizationMethod::PerChannelSymmetric
59 || qparams.method == QuantizationMethod::PerChannelAffine
60 {
61 let scales = qparams
62 .channel_scales
63 .as_ref()
64 .expect("Per-channel quantization requires channel scales");
65
66 let zero_points = if qparams.method == QuantizationMethod::PerChannelAffine {
67 qparams
68 .channel_zero_points
69 .as_ref()
70 .expect("Per-channel affine quantization requires channel zero points")
71 } else {
72 &vec![0; qmatrix.shape.1] // Symmetric doesn't use zero points
73 };
74
75 // Process each row of the matrix
76 for (i, row) in data.rows().into_iter().enumerate() {
77 // We'll use SIMD to accumulate 8 values at once
78 let chunksize = 8;
79 let mut acc = 0.0f32;
80
81 let row_slice = row.as_slice().unwrap();
82 let mut j = 0;
83
84 // Accumulate 8 elements at a time using SIMD
85 while j + chunksize <= row_slice.len() {
86 // Load chunks from row, scales, zero points and vector
87 let mut row_vals = [0.0f32; 8];
88
89 for (k, val) in row_vals.iter_mut().enumerate().take(chunksize) {
90 let idx = j + k;
91 // Dequantize the value: (val - zero_point) * scale
92 let dequantized =
93 (row_slice[idx] as f32 - zero_points[idx] as f32) * scales[idx];
94 *val = dequantized * vec_slice[idx];
95 }
96
97 // Sum the products into our accumulator using core SIMD operations
98 let row_vals_slice = ArrayView1::from(&row_vals);
99 acc += f32::simd_sum(&row_vals_slice);
100
101 j += chunksize;
102 }
103
104 // Handle remaining elements
105 for k in j..row_slice.len() {
106 let dequantized = (row_slice[k] as f32 - zero_points[k] as f32) * scales[k];
107 acc += dequantized * vec_slice[k];
108 }
109
110 result[i] = acc;
111 }
112 } else {
113 // Standard quantization (single scale/zero point)
114
115 // For Int4/UInt4, we need special handling
116 if qparams.data_type == QuantizedDataType::Int4
117 || qparams.data_type == QuantizedDataType::UInt4
118 {
119 // Process each row
120 for (i, row) in data.rows().into_iter().enumerate() {
121 let row_slice = row.as_slice().unwrap();
122 let mut acc = 0.0f32;
123
124 // For Int4/UInt4, we need to unpack two values from each byte
125 for (j, &byte) in row_slice.iter().enumerate() {
126 let col_idx = j * 2; // Each byte contains 2 values
127
128 // Unpack the first 4-bit value
129 let val1 = if qparams.data_type == QuantizedDataType::Int4 {
130 // Extract and sign-extend 4-bit signed int
131 let q = (byte >> 4) & 0x0F;
132 if q & 0x08 != 0 {
133 q | 0xF0u8 as i8
134 } else {
135 q
136 } // Sign extend
137 } else {
138 // UInt4
139 (byte >> 4) & 0x0F
140 };
141
142 // Process only if we're still within matrix bounds
143 if col_idx < qmatrix.shape.1 {
144 let dequantized = (val1 as f32 - zero_point as f32) * scale;
145 acc += dequantized * vec_slice[col_idx];
146 }
147
148 // Unpack the second 4-bit value
149 let val2 = if qparams.data_type == QuantizedDataType::Int4 {
150 // Extract and sign-extend 4-bit signed int
151 let q = byte & 0x0F;
152 if q & 0x08 != 0 {
153 q | 0xF0u8 as i8
154 } else {
155 q
156 } // Sign extend
157 } else {
158 // UInt4
159 byte & 0x0F
160 };
161
162 // Process only if we're still within matrix bounds
163 if col_idx + 1 < qmatrix.shape.1 {
164 let dequantized = (val2 as f32 - zero_point as f32) * scale;
165 acc += dequantized * vec_slice[col_idx + 1];
166 }
167 }
168
169 result[i] = acc;
170 }
171 } else {
172 // Standard Int8 processing
173 for (i, row) in data.rows().into_iter().enumerate() {
174 let row_slice = row.as_slice().unwrap();
175 let mut acc = 0.0f32;
176
177 // Process 8 elements at a time with SIMD
178 let chunksize = 8;
179 let mut j = 0;
180
181 while j + chunksize <= row_slice.len() {
182 // Load chunks from row and vector
183 let row_chunk = [
184 row_slice[j] as f32,
185 row_slice[j + 1] as f32,
186 row_slice[j + 2] as f32,
187 row_slice[j + 3] as f32,
188 row_slice[j + 4] as f32,
189 row_slice[j + 5] as f32,
190 row_slice[j + 6] as f32,
191 row_slice[j + 7] as f32,
192 ];
193
194 let vec_chunk = [
195 vec_slice[j],
196 vec_slice[j + 1],
197 vec_slice[j + 2],
198 vec_slice[j + 3],
199 vec_slice[j + 4],
200 vec_slice[j + 5],
201 vec_slice[j + 6],
202 vec_slice[j + 7],
203 ];
204
205 // Convert to ndarray views for core SIMD operations
206 let _row_view = ArrayView1::from(&row_chunk);
207 let vec_view = ArrayView1::from(&vec_chunk);
208
209 // Create dequantized values: (row - zero_point) * scale
210 let mut dequantized = [0.0f32; 8];
211 for (k, val) in dequantized.iter_mut().enumerate() {
212 *val = (row_chunk[k] - zero_point as f32) * scale;
213 }
214 let dequantized_view = ArrayView1::from(&dequantized);
215
216 // Multiply and accumulate using core SIMD
217 acc += f32::simd_dot(&dequantized_view, &vec_view);
218
219 j += chunksize;
220 }
221
222 // Process remaining elements
223 for k in j..row_slice.len() {
224 let dequantized = (row_slice[k] as f32 - zero_point as f32) * scale;
225 acc += dequantized * vec_slice[k];
226 }
227
228 result[i] = acc;
229 }
230 }
231 }
232 }
233 QuantizedData2D::Float16(data) => {
234 // Do a basic loop multiplication for now - optimize this later
235 for (i, row) in data.rows().into_iter().enumerate() {
236 let mut sum = 0.0f32;
237 for (j, &val) in row.iter().enumerate() {
238 sum += f32::from(val) * vec_slice[j];
239 }
240 result[i] = sum;
241 }
242 }
243 QuantizedData2D::BFloat16(data) => {
244 // Do a basic loop multiplication for now - optimize this later
245 for (i, row) in data.rows().into_iter().enumerate() {
246 let mut sum = 0.0f32;
247 for (j, &val) in row.iter().enumerate() {
248 sum += f32::from(val) * vec_slice[j];
249 }
250 result[i] = sum;
251 }
252 }
253 }
254
255 Ok(result)
256}
257
258/// SIMD-accelerated quantized matrix-matrix multiplication
259///
260/// Performs matrix-matrix multiplication where both matrices are in quantized form.
261/// The result is returned as f32.
262///
263/// # Arguments
264///
265/// * `a` - First quantized matrix
266/// * `a_params` - Quantization parameters for the first matrix
267/// * `b` - Second quantized matrix
268/// * `b_params` - Quantization parameters for the second matrix
269///
270/// # Returns
271///
272/// * Result matrix of the multiplication
273#[allow(dead_code)]
274pub fn simd_quantized_matmul(
275 a: &QuantizedMatrix,
276 a_params: &QuantizationParams,
277 b: &QuantizedMatrix,
278 b_params: &QuantizationParams,
279) -> LinalgResult<Array2<f32>> {
280 // Check dimensions
281 if a.shape.1 != b.shape.0 {
282 return Err(LinalgError::ShapeError(format!(
283 "Matrix dimensions mismatch for multiplication: ({}, {}) * ({}, {})",
284 a.shape.0, a.shape.1, b.shape.0, b.shape.1
285 )));
286 }
287
288 // Create result matrix
289 let (m, n) = (a.shape.0, b.shape.1);
290 let mut result = Array2::zeros((m, n));
291
292 // Get int8 data if available - we'll only handle Int8 SIMD acceleration for now
293 if let (Some(a_data), Some(b_data)) =
294 (get_quantizedmatrix_2d_i8(a), get_quantizedmatrix_2d_i8(b))
295 {
296 // If either matrix is per-channel quantized, we dequantize it fully first
297 // In the future, we can optimize this with specialized kernels
298 if a_params.method == QuantizationMethod::PerChannelSymmetric
299 || a_params.method == QuantizationMethod::PerChannelAffine
300 || b_params.method == QuantizationMethod::PerChannelSymmetric
301 || b_params.method == QuantizationMethod::PerChannelAffine
302 {
303 // Dequantize matrices
304 let a_dequant = dequantize_matrix(a, a_params);
305 let b_dequant = dequantize_matrix(b, b_params);
306
307 // Use standard matrix multiplication
308 return Ok(a_dequant.dot(&b_dequant));
309 }
310
311 // Get quantization parameters
312 let a_scale = a_params.scale;
313 let a_zero = a_params.zero_point as f32;
314 let b_scale = b_params.scale;
315 let b_zero = b_params.zero_point as f32;
316
317 // Combined scale for the output
318 let _output_scale = a_scale * b_scale; // Used in future optimizations
319
320 // For int4/uint4, each byte contains two values, and special handling is needed
321 let a_is_4bit = a_params.data_type == QuantizedDataType::Int4
322 || a_params.data_type == QuantizedDataType::UInt4;
323 let b_is_4bit = b_params.data_type == QuantizedDataType::Int4
324 || b_params.data_type == QuantizedDataType::UInt4;
325
326 // Cache-friendly block sizes
327 // These should be tuned based on target CPU cache sizes
328 const BLOCK_SIZE_M: usize = 32;
329 const BLOCK_SIZE_N: usize = 32;
330 const BLOCK_SIZE_K: usize = 32;
331
332 // Loop over blocks
333 for i0 in (0..m).step_by(BLOCK_SIZE_M) {
334 let i_end = (i0 + BLOCK_SIZE_M).min(m);
335
336 for j0 in (0..n).step_by(BLOCK_SIZE_N) {
337 let j_end = (j0 + BLOCK_SIZE_N).min(n);
338
339 // Process inner dimension in blocks
340 for k0 in (0..a.shape.1).step_by(BLOCK_SIZE_K) {
341 let k_end = (k0 + BLOCK_SIZE_K).min(a.shape.1);
342
343 // Process blocks
344 for i in i0..i_end {
345 for j in j0..j_end {
346 // Compute dot product of row i from A and column j from B
347 let mut sum = 0.0f32;
348
349 // Number of elements in this block of the inner dimension
350 let k_blocksize = k_end - k0;
351
352 // If we're using 4-bit quantization, we need to adjust
353 if a_is_4bit || b_is_4bit {
354 // Simplified handling for 4-bit quantization - dequantize on the fly
355 for k in k0..k_end {
356 let a_val = if a_is_4bit {
357 // Extract the right 4-bit value
358 let byte_idx = k / 2;
359 let byte = a_data[[i, byte_idx]];
360
361 if k % 2 == 0 {
362 // First 4 bits
363 let val = (byte >> 4) & 0x0F;
364 // Sign extend for Int4 if needed
365 if a_params.data_type == QuantizedDataType::Int4
366 && (val & 0x08) != 0
367 {
368 ((val | 0xF0u8 as i8) as f32 - a_zero) * a_scale
369 } else {
370 ((val & 0x0F) as f32 - a_zero) * a_scale
371 }
372 } else {
373 // Second 4 bits
374 let val = byte & 0x0F;
375 // Sign extend for Int4 if needed
376 if a_params.data_type == QuantizedDataType::Int4
377 && (val & 0x08) != 0
378 {
379 ((val | 0xF0u8 as i8) as f32 - a_zero) * a_scale
380 } else {
381 ((val & 0x0F) as f32 - a_zero) * a_scale
382 }
383 }
384 } else {
385 // Regular 8-bit quantization
386 (a_data[[i, k]] as f32 - a_zero) * a_scale
387 };
388
389 let b_val = if b_is_4bit {
390 // Extract the right 4-bit value
391 let byte_idx = k / 2;
392 let byte = b_data[[byte_idx, j]];
393
394 if k % 2 == 0 {
395 // First 4 bits
396 let val = (byte >> 4) & 0x0F;
397 // Sign extend for Int4 if needed
398 if b_params.data_type == QuantizedDataType::Int4
399 && (val & 0x08) != 0
400 {
401 ((val | 0xF0u8 as i8) as f32 - b_zero) * b_scale
402 } else {
403 ((val & 0x0F) as f32 - b_zero) * b_scale
404 }
405 } else {
406 // Second 4 bits
407 let val = byte & 0x0F;
408 // Sign extend for Int4 if needed
409 if b_params.data_type == QuantizedDataType::Int4
410 && (val & 0x08) != 0
411 {
412 ((val | 0xF0u8 as i8) as f32 - b_zero) * b_scale
413 } else {
414 ((val & 0x0F) as f32 - b_zero) * b_scale
415 }
416 }
417 } else {
418 // Regular 8-bit quantization
419 (b_data[[k, j]] as f32 - b_zero) * b_scale
420 };
421
422 sum += a_val * b_val;
423 }
424 } else {
425 // Regular 8-bit quantization - we can use SIMD
426
427 // Get row from A and column from B as slices if possible
428 let a_row = a_data.slice(scirs2_core::ndarray::s![i, k0..k_end]);
429 let b_col = b_data.slice(scirs2_core::ndarray::s![k0..k_end, j]);
430
431 if let (Some(a_slice), Some(b_slice)) =
432 (a_row.as_slice(), b_col.as_slice())
433 {
434 // Process with SIMD (8 elements at a time)
435 let mut l = 0;
436 let chunksize = 8;
437
438 while l + chunksize <= k_blocksize {
439 // Load chunks
440 let a_chunk = [
441 a_slice[l] as f32,
442 a_slice[l + 1] as f32,
443 a_slice[l + 2] as f32,
444 a_slice[l + 3] as f32,
445 a_slice[l + 4] as f32,
446 a_slice[l + 5] as f32,
447 a_slice[l + 6] as f32,
448 a_slice[l + 7] as f32,
449 ];
450 let b_chunk = [
451 b_slice[l] as f32,
452 b_slice[l + 1] as f32,
453 b_slice[l + 2] as f32,
454 b_slice[l + 3] as f32,
455 b_slice[l + 4] as f32,
456 b_slice[l + 5] as f32,
457 b_slice[l + 6] as f32,
458 b_slice[l + 7] as f32,
459 ];
460
461 // Dequantize chunks
462 let mut a_dequant = [0.0f32; 8];
463 let mut b_dequant = [0.0f32; 8];
464
465 for k in 0..8 {
466 a_dequant[k] = (a_chunk[k] - a_zero) * a_scale;
467 b_dequant[k] = (b_chunk[k] - b_zero) * b_scale;
468 }
469
470 // Convert to views and compute dot product using core SIMD
471 let a_view = ArrayView1::from(&a_dequant);
472 let b_view = ArrayView1::from(&b_dequant);
473 sum += f32::simd_dot(&a_view, &b_view);
474
475 l += chunksize;
476 }
477
478 // Process remaining elements
479 for m in l..k_blocksize {
480 let a_val = (a_slice[m] as f32 - a_zero) * a_scale;
481 let b_val = (b_slice[m] as f32 - b_zero) * b_scale;
482 sum += a_val * b_val;
483 }
484 } else {
485 // Fallback for non-contiguous data
486 for k in k0..k_end {
487 let a_val = (a_data[[i, k]] as f32 - a_zero) * a_scale;
488 let b_val = (b_data[[k, j]] as f32 - b_zero) * b_scale;
489 sum += a_val * b_val;
490 }
491 }
492 }
493
494 // Accumulate result
495 result[[i, j]] += sum;
496 }
497 }
498 }
499 }
500 }
501 } else {
502 // If we don't have Int8 data, fall back to dequantize and multiply
503 let a_dequant = dequantize_matrix(a, a_params);
504 let b_dequant = dequantize_matrix(b, b_params);
505
506 return Ok(a_dequant.dot(&b_dequant));
507 }
508
509 Ok(result)
510}
511
512/// SIMD-accelerated quantized dot product
513///
514/// Computes the dot product of two quantized vectors using SIMD instructions.
515///
516/// # Arguments
517///
518/// * `a` - First quantized vector
519/// * `a_params` - Quantization parameters for the first vector
520/// * `b` - Second quantized vector
521/// * `b_params` - Quantization parameters for the second vector
522///
523/// # Returns
524///
525/// * Dot product result
526#[allow(dead_code)]
527pub fn simd_quantized_dot(
528 a: &QuantizedVector,
529 a_params: &QuantizationParams,
530 b: &QuantizedVector,
531 b_params: &QuantizationParams,
532) -> LinalgResult<f32> {
533 // Check dimensions
534 if a.length != b.length {
535 return Err(LinalgError::ShapeError(format!(
536 "Vector dimensions must match for dot product: {} vs {}",
537 a.length, b.length
538 )));
539 }
540
541 // Get int8 data if available
542 if let (Some(a_data), Some(b_data)) =
543 (get_quantized_vector_1d_i8(a), get_quantized_vector_1d_i8(b))
544 {
545 // Get quantization parameters
546 let a_scale = a_params.scale;
547 let a_zero = a_params.zero_point as f32;
548 let b_scale = b_params.scale;
549 let b_zero = b_params.zero_point as f32;
550
551 // Combined scale for the output
552 let _output_scale = a_scale * b_scale; // Used in future optimizations
553
554 // For int4/uint4, each byte contains two values
555 let a_is_4bit = a_params.data_type == QuantizedDataType::Int4
556 || a_params.data_type == QuantizedDataType::UInt4;
557 let b_is_4bit = b_params.data_type == QuantizedDataType::Int4
558 || b_params.data_type == QuantizedDataType::UInt4;
559
560 if a_is_4bit || b_is_4bit {
561 // Handle 4-bit specially - we need to unpack values
562 let mut sum = 0.0f32;
563
564 // We need to adjust length for 4-bit values (each byte has 2 values)
565 let _a_byte_len = a.length.div_ceil(2); // Used for bounds checking
566 let _b_byte_len = b.length.div_ceil(2); // Used for bounds checking
567
568 for i in 0..a.length {
569 // Extract values from packed 4-bit representation
570 let a_val = if a_is_4bit {
571 let byte_idx = i / 2;
572 let byte = a_data[byte_idx];
573
574 if i % 2 == 0 {
575 // First 4 bits
576 let val = (byte >> 4) & 0x0F;
577 // Sign extend for Int4 if needed
578 if a_params.data_type == QuantizedDataType::Int4 && (val & 0x08) != 0 {
579 ((val | 0xF0u8 as i8) as f32 - a_zero) * a_scale
580 } else {
581 (val as f32 - a_zero) * a_scale
582 }
583 } else {
584 // Second 4 bits
585 let val = byte & 0x0F;
586 // Sign extend for Int4 if needed
587 if a_params.data_type == QuantizedDataType::Int4 && (val & 0x08) != 0 {
588 ((val | 0xF0u8 as i8) as f32 - a_zero) * a_scale
589 } else {
590 (val as f32 - a_zero) * a_scale
591 }
592 }
593 } else {
594 (a_data[i] as f32 - a_zero) * a_scale
595 };
596
597 let b_val = if b_is_4bit {
598 let byte_idx = i / 2;
599 let byte = b_data[byte_idx];
600
601 if i % 2 == 0 {
602 // First 4 bits
603 let val = (byte >> 4) & 0x0F;
604 // Sign extend for Int4 if needed
605 if b_params.data_type == QuantizedDataType::Int4 && (val & 0x08) != 0 {
606 ((val | 0xF0u8 as i8) as f32 - b_zero) * b_scale
607 } else {
608 (val as f32 - b_zero) * b_scale
609 }
610 } else {
611 // Second 4 bits
612 let val = byte & 0x0F;
613 // Sign extend for Int4 if needed
614 if b_params.data_type == QuantizedDataType::Int4 && (val & 0x08) != 0 {
615 ((val | 0xF0u8 as i8) as f32 - b_zero) * b_scale
616 } else {
617 (val as f32 - b_zero) * b_scale
618 }
619 }
620 } else {
621 (b_data[i] as f32 - b_zero) * b_scale
622 };
623
624 sum += a_val * b_val;
625 }
626
627 return Ok(sum);
628 }
629
630 // Standard 8-bit quantization
631 let a_slice = a_data.as_slice().unwrap();
632 let b_slice = b_data.as_slice().unwrap();
633
634 // Process 8 elements at a time with SIMD
635 let mut i = 0;
636 let chunksize = 8;
637 let mut sum = 0.0f32;
638
639 while i + chunksize <= a.length {
640 // Load chunks
641 let a_chunk = [
642 a_slice[i] as f32,
643 a_slice[i + 1] as f32,
644 a_slice[i + 2] as f32,
645 a_slice[i + 3] as f32,
646 a_slice[i + 4] as f32,
647 a_slice[i + 5] as f32,
648 a_slice[i + 6] as f32,
649 a_slice[i + 7] as f32,
650 ];
651
652 let b_chunk = [
653 b_slice[i] as f32,
654 b_slice[i + 1] as f32,
655 b_slice[i + 2] as f32,
656 b_slice[i + 3] as f32,
657 b_slice[i + 4] as f32,
658 b_slice[i + 5] as f32,
659 b_slice[i + 6] as f32,
660 b_slice[i + 7] as f32,
661 ];
662
663 // Dequantize chunks
664 let mut a_dequant = [0.0f32; 8];
665 let mut b_dequant = [0.0f32; 8];
666
667 for k in 0..8 {
668 a_dequant[k] = (a_chunk[k] - a_zero) * a_scale;
669 b_dequant[k] = (b_chunk[k] - b_zero) * b_scale;
670 }
671
672 // Convert to views and compute dot product using core SIMD
673 let a_view = ArrayView1::from(&a_dequant);
674 let b_view = ArrayView1::from(&b_dequant);
675 sum += f32::simd_dot(&a_view, &b_view);
676
677 i += chunksize;
678 }
679
680 // Process remaining elements
681 for j in i..a.length {
682 let a_val = (a_slice[j] as f32 - a_zero) * a_scale;
683 let b_val = (b_slice[j] as f32 - b_zero) * b_scale;
684 sum += a_val * b_val;
685 }
686
687 Ok(sum)
688 } else {
689 // If we don't have Int8 data, fall back to dequantize and dot
690 let a_dequant = dequantize_vector(a, a_params);
691 let b_dequant = dequantize_vector(b, b_params);
692 Ok(a_dequant.dot(&b_dequant))
693 }
694}
695
#[cfg(test)]
mod tests {
    use super::*;
    use crate::quantization::{
        quantize_matrix, quantize_matrix_per_channel, quantize_vector, QuantizationMethod,
    };
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Asserts that two vectors have the same length and agree element-wise
    /// within `tolerance`.
    fn assert_vectors_close(actual: &Array1<f32>, expected: &Array1<f32>, tolerance: f32) {
        assert_eq!(actual.len(), expected.len());
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = tolerance);
        }
    }

    /// 8-bit symmetric matrix times an `f32` vector.
    #[test]
    #[ignore = "timeout"]
    fn test_simd_quantized_matvec() {
        let matrix = array![
            [1.0f32, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0]
        ];
        let x = array![2.0f32, 3.0, 4.0, 5.0];

        let (qmat, qparams) = quantize_matrix(&matrix.view(), 8, QuantizationMethod::Symmetric);
        let product = simd_quantized_matvec(&qmat, &qparams, &x.view()).unwrap();

        // Exact product of the unquantized operands; the tolerance absorbs
        // the quantization error.
        let reference = array![40.0f32, 96.0, 152.0];
        assert_vectors_close(&product, &reference, 0.5);
    }

    /// Product of two 8-bit symmetric matrices.
    #[test]
    #[ignore = "timeout"]
    fn test_simd_quantized_matmul() {
        let lhs = array![[1.0f32, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let rhs = array![[7.0f32, 8.0, 9.0], [10.0, 11.0, 12.0], [13.0, 14.0, 15.0]];

        let (q_lhs, lhs_params) = quantize_matrix(&lhs.view(), 8, QuantizationMethod::Symmetric);
        let (q_rhs, rhs_params) = quantize_matrix(&rhs.view(), 8, QuantizationMethod::Symmetric);

        let product = simd_quantized_matmul(&q_lhs, &lhs_params, &q_rhs, &rhs_params).unwrap();

        // Exact product of the unquantized operands.
        let reference = array![[66.0f32, 72.0, 78.0], [156.0, 171.0, 186.0]];

        assert_eq!(product.shape(), reference.shape());
        for ((row, col), &entry) in product.indexed_iter() {
            assert_relative_eq!(entry, reference[[row, col]], epsilon = 1.0);
        }
    }

    /// Dot product of two 8-bit symmetric vectors.
    #[test]
    #[ignore = "timeout"]
    fn test_simd_quantized_dot() {
        let lhs = array![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let rhs = array![5.0f32, 4.0, 3.0, 2.0, 1.0];

        let (q_lhs, lhs_params) = quantize_vector(&lhs.view(), 8, QuantizationMethod::Symmetric);
        let (q_rhs, rhs_params) = quantize_vector(&rhs.view(), 8, QuantizationMethod::Symmetric);

        let dot = simd_quantized_dot(&q_lhs, &lhs_params, &q_rhs, &rhs_params).unwrap();

        // Exact dot product of the unquantized vectors.
        let reference = 1.0 * 5.0 + 2.0 * 4.0 + 3.0 * 3.0 + 4.0 * 2.0 + 5.0 * 1.0;
        assert_relative_eq!(dot, reference, epsilon = 0.5);

        // Temporary: sanity-check the hand-computed reference value itself.
        assert_eq!(reference, 35.0);
    }

    /// Packed Int4 matrix times an `f32` vector.
    #[test]
    #[ignore = "timeout"]
    fn test_simd_quantized_int4_operations() {
        let matrix = array![[1.0f32, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]];
        let x = array![2.0f32, 3.0, 4.0, 5.0];

        let (qmat, qparams) = quantize_matrix(&matrix.view(), 4, QuantizationMethod::Int4);
        let product = simd_quantized_matvec(&qmat, &qparams, &x.view()).unwrap();

        // Four bits leave far fewer levels, hence the wider tolerance.
        let reference = array![40.0f32, 96.0];
        assert_vectors_close(&product, &reference, 3.0);
    }

    /// Per-channel symmetric matrix times an `f32` vector.
    #[test]
    #[ignore = "timeout"]
    fn test_simd_quantized_per_channel() {
        // Columns differ by orders of magnitude so that a single shared
        // scale would be a poor fit.
        let matrix = array![
            [0.1f32, 10.0, 100.0],
            [0.2, 20.0, 200.0],
            [0.3, 30.0, 300.0]
        ];
        let x = array![1.0f32, 0.5, 0.25];

        let (qmat, qparams) =
            quantize_matrix_per_channel(&matrix.view(), 8, QuantizationMethod::PerChannelSymmetric);
        let product = simd_quantized_matvec(&qmat, &qparams, &x.view()).unwrap();

        // Exact product of the unquantized operands.
        let reference = array![
            0.1 * 1.0 + 10.0 * 0.5 + 100.0 * 0.25,
            0.2 * 1.0 + 20.0 * 0.5 + 200.0 * 0.25,
            0.3 * 1.0 + 30.0 * 0.5 + 300.0 * 0.25
        ];
        assert_vectors_close(&product, &reference, 0.5);
    }
}