torsh_quantization/algorithms.rs

//! Core quantization algorithms and tensor operations
//!
//! This module provides the fundamental quantization and dequantization algorithms
//! for tensor operations, including per-tensor and per-channel quantization schemes.
//!
//! # Features
//!
//! - **Per-tensor quantization**: Single scale/zero-point for entire tensor
//! - **Per-channel quantization**: Individual scale/zero-point per channel
//! - **Dequantization**: Reverse quantization to restore floating point values
//! - **Multiple schemes**: Affine and symmetric quantization support
//! - **Configuration-driven**: Integration with QuantConfig for flexible usage
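//!
//! # Example
//!
//! A minimal round-trip sketch. The paths below are assumed from this crate's
//! layout and its tests (`tensor_1d`, `QuantConfig::int8`); the block is marked
//! `ignore` since the exact re-exports may differ:
//!
//! ```ignore
//! use torsh_quantization::algorithms::{dequantize, quantize_with_config};
//! use torsh_quantization::config::QuantConfig;
//! use torsh_tensor::creation::tensor_1d;
//!
//! let tensor = tensor_1d(&[0.5, 1.0, 1.5, 2.0]).unwrap();
//! let config = QuantConfig::int8();
//!
//! // Quantize with automatically derived scale/zero-point...
//! let (quantized, scale, zero_point) = quantize_with_config(&tensor, &config).unwrap();
//! // ...then recover an approximation of the original values.
//! let restored = dequantize(&quantized, scale, zero_point).unwrap();
//! ```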

use crate::config::{QScheme, QuantConfig};
use torsh_core::{
    dtype::DType,
    error::{Result as TorshResult, TorshError},
};
use torsh_tensor::Tensor;

#[cfg(not(feature = "std"))]
extern crate alloc;

#[cfg(not(feature = "std"))]
use alloc::vec::Vec;

use scirs2_core::parallel_ops::*;

/// Quantize a tensor using specified configuration
pub fn quantize_with_config(
    tensor: &Tensor,
    config: &QuantConfig,
) -> TorshResult<(Tensor, f32, i32)> {
    config.validate()?;

    match config.scheme {
        QScheme::PerTensorAffine | QScheme::PerTensorSymmetric => {
            quantize_tensor_auto(tensor, config.dtype, config.scheme)
        }
        QScheme::PerChannelAffine | QScheme::PerChannelSymmetric => {
            let axis = config.ch_axis.unwrap_or(0);
            let (quantized, scales, zero_points) =
                quantize_per_channel_auto(tensor, axis, config.dtype, config.scheme)?;
            // Return only the first channel's parameters for API compatibility;
            // call `quantize_per_channel_auto` directly to keep every channel's parameters.
            Ok((quantized, scales[0], zero_points[0]))
        }
        QScheme::GroupWise => {
            let axis = config.ch_axis.unwrap_or(0);
            let group_size = config.group_size.unwrap_or(32);
            crate::specialized::quantize_group_wise(tensor, axis, group_size, config)
        }
        QScheme::Int4PerTensor => crate::specialized::quantize_int4_per_tensor(tensor, config),
        QScheme::Int4PerChannel => {
            let axis = config.ch_axis.unwrap_or(0);
            crate::specialized::quantize_int4_per_channel(tensor, axis, config)
        }
        QScheme::Binary => crate::specialized::quantize_binary(tensor),
        QScheme::Ternary => crate::specialized::quantize_ternary(tensor),
        QScheme::MixedPrecision => {
            // Mixed precision requires different handling
            Err(TorshError::InvalidArgument(
                "Mixed precision quantization requires specialized API".to_string(),
            ))
        }
    }
}

/// Quantize a tensor to INT8 using the given scale and zero point
/// (the `_dtype` parameter is currently ignored; values are clamped to the int8 range)
pub fn quantize_per_tensor(
    tensor: &Tensor,
    scale: f32,
    zero_point: i32,
    _dtype: DType,
) -> TorshResult<Tensor> {
    let (quantized, _, _) = quantize_per_tensor_affine(tensor, scale, zero_point)?;
    Ok(quantized)
}

/// Dequantize a quantized tensor using scale and zero_point
pub fn dequantize(tensor: &Tensor, scale: f32, zero_point: i32) -> TorshResult<Tensor> {
    dequantize_per_tensor_affine(tensor, scale, zero_point)
}

/// Auto-quantize a tensor using per-tensor scheme
pub fn quantize_tensor_auto(
    tensor: &Tensor,
    dtype: DType,
    scheme: QScheme,
) -> TorshResult<(Tensor, f32, i32)> {
    let data = tensor.data()?;

    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot quantize empty tensor".to_string(),
        ));
    }

    // Calculate min and max values using SIMD acceleration when beneficial
    let (min_val, max_val) = if data.len() > 64 && crate::simd_ops::is_simd_available() {
        crate::simd_ops::find_min_max_simd(&data)?
    } else {
        // Fallback for small tensors
        let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        (min_val, max_val)
    };

    // Determine quantization parameters based on scheme
    let (scale, zero_point) = match scheme {
        QScheme::PerTensorAffine => calculate_affine_quantization_params(min_val, max_val, dtype)?,
        QScheme::PerTensorSymmetric => {
            calculate_symmetric_quantization_params(min_val, max_val, dtype)?
        }
        _ => {
            return Err(TorshError::InvalidArgument(format!(
                "Unsupported scheme for auto quantization: {:?}",
                scheme
            )));
        }
    };

    quantize_per_tensor_affine(tensor, scale, zero_point)
}

/// Auto-quantize a tensor using per-channel scheme
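///
/// Each slice along `axis` gets its own `(scale, zero_point)` pair computed from
/// that channel's min/max; the returned vectors hold one entry per channel.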
pub fn quantize_per_channel_auto(
    tensor: &Tensor,
    axis: usize,
    dtype: DType,
    scheme: QScheme,
) -> TorshResult<(Tensor, Vec<f32>, Vec<i32>)> {
    let binding = tensor.shape();
    let shape = binding.dims();

    if axis >= shape.len() {
        return Err(TorshError::InvalidArgument(
            "Axis out of bounds".to_string(),
        ));
    }

    let num_channels = shape[axis];
    let data = tensor.data()?;
    let (qmin, qmax) = get_dtype_range(dtype);

    // Calculate strides for efficient channel access (row-major layout)
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len() - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    let mut scales = Vec::with_capacity(num_channels);
    let mut zero_points = Vec::with_capacity(num_channels);
    let mut quantized_data = vec![0.0f32; data.len()];

    // Process each channel
    for ch in 0..num_channels {
        let mut channel_min = f32::INFINITY;
        let mut channel_max = f32::NEG_INFINITY;

        // Calculate channel statistics
        let mut indices = vec![0; shape.len()];
        let channel_size = data.len() / num_channels;

        for i in 0..channel_size {
            // Calculate multi-dimensional index for this channel
            let mut temp_i = i;
            for dim in (0..shape.len()).rev() {
                if dim == axis {
                    indices[dim] = ch;
                } else {
                    // dim != axis here, so decompose over the full dimension size
                    indices[dim] = temp_i % shape[dim];
                    temp_i /= shape[dim];
                }
            }

            // Convert multi-dimensional index to flat index
            let flat_idx = indices
                .iter()
                .zip(strides.iter())
                .map(|(idx, stride)| idx * stride)
                .sum::<usize>();

            if flat_idx < data.len() {
                let val = data[flat_idx];
                channel_min = channel_min.min(val);
                channel_max = channel_max.max(val);
            }
        }

        // Calculate quantization parameters for this channel
        let (scale, zero_point) = match scheme {
            QScheme::PerChannelAffine => {
                calculate_affine_quantization_params(channel_min, channel_max, dtype)?
            }
            QScheme::PerChannelSymmetric => {
                calculate_symmetric_quantization_params(channel_min, channel_max, dtype)?
            }
            _ => {
                return Err(TorshError::InvalidArgument(format!(
                    "Unsupported scheme for per-channel quantization: {:?}",
                    scheme
                )));
            }
        };

        scales.push(scale);
        zero_points.push(zero_point);

        // Quantize channel data
        for i in 0..channel_size {
            let mut temp_i = i;
            for dim in (0..shape.len()).rev() {
                if dim == axis {
                    indices[dim] = ch;
                } else {
                    indices[dim] = temp_i % shape[dim];
                    temp_i /= shape[dim];
                }
            }

            let flat_idx = indices
                .iter()
                .zip(strides.iter())
                .map(|(idx, stride)| idx * stride)
                .sum::<usize>();

            if flat_idx < data.len() {
                let val = data[flat_idx];
                let quantized =
                    ((val / scale).round() + zero_point as f32).clamp(qmin as f32, qmax as f32);
                quantized_data[flat_idx] = quantized;
            }
        }
    }

    let quantized_tensor = Tensor::from_data(quantized_data, shape.to_vec(), tensor.device())?;

    Ok((quantized_tensor, scales, zero_points))
}

/// Quantize a tensor using per-tensor affine quantization (returns I8 tensor)
pub fn quantize_per_tensor_affine_i8(
    tensor: &Tensor,
    scale: f32,
    zero_point: i32,
) -> TorshResult<(Tensor<i8>, f32, i32)> {
    let data = tensor.data()?;

    if scale <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "Scale must be positive".to_string(),
        ));
    }

    let quantized_data: Vec<i8> = data
        .iter()
        .map(|&x| {
            let quantized = (x / scale).round() + zero_point as f32;
            // Clamp to int8 range and convert to i8
            quantized.clamp(-128.0, 127.0) as i8
        })
        .collect();

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    Ok((quantized_tensor, scale, zero_point))
}

/// Quantize a tensor using per-tensor affine quantization
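///
/// Each element is mapped as `q = clamp(round(x / scale) + zero_point, -128, 127)`;
/// the result is stored as an f32 tensor (see [`quantize_per_tensor_affine_i8`]
/// for an `i8`-typed output).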
pub fn quantize_per_tensor_affine(
    tensor: &Tensor,
    scale: f32,
    zero_point: i32,
) -> TorshResult<(Tensor, f32, i32)> {
    let data = tensor.data()?;

    if scale <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "Scale must be positive".to_string(),
        ));
    }

    // Use SIMD-accelerated quantization when available and beneficial
    let mut quantized_data = vec![0.0f32; data.len()];
    if data.len() > 64 && crate::simd_ops::is_simd_available() {
        // Use SIMD for larger tensors
        crate::simd_ops::quantize_per_tensor_affine_simd(
            &data,
            scale,
            zero_point,
            &mut quantized_data,
        )?;
    } else {
        // Fallback to scalar implementation for small tensors
        for (i, &x) in data.iter().enumerate() {
            let quantized = (x / scale).round() + zero_point as f32;
            quantized_data[i] = quantized.clamp(-128.0, 127.0);
        }
    }

    let quantized_tensor = Tensor::from_data(
        quantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    Ok((quantized_tensor, scale, zero_point))
}

/// Dequantize a tensor using per-tensor affine dequantization
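///
/// Each element is restored as `x ≈ (q - zero_point) * scale`, the inverse of the
/// affine mapping (up to rounding and clamping error).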
pub fn dequantize_per_tensor_affine(
    tensor: &Tensor,
    scale: f32,
    zero_point: i32,
) -> TorshResult<Tensor> {
    let data = tensor.data()?;

    if scale <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "Scale must be positive".to_string(),
        ));
    }

    // Use SIMD-accelerated dequantization when available and beneficial
    let mut dequantized_data = vec![0.0f32; data.len()];
    if data.len() > 64 && crate::simd_ops::is_simd_available() {
        // Use SIMD for larger tensors
        crate::simd_ops::dequantize_per_tensor_affine_simd(
            &data,
            scale,
            zero_point,
            &mut dequantized_data,
        )?;
    } else {
        // Fallback to scalar implementation for small tensors
        for (i, &x) in data.iter().enumerate() {
            dequantized_data[i] = (x - zero_point as f32) * scale;
        }
    }

    let dequantized_tensor = Tensor::from_data(
        dequantized_data,
        tensor.shape().dims().to_vec(),
        tensor.device(),
    )?;

    Ok(dequantized_tensor)
}

/// Calculate affine quantization parameters (scale and zero_point)
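///
/// The mapping follows `scale = (max - min) / (qmax - qmin)` and
/// `zero_point = round(qmin - min / scale)`, clamped to `[qmin, qmax]`.
///
/// A worked sketch (module path assumed from this crate's layout, hence `ignore`):
///
/// ```ignore
/// use torsh_core::dtype::DType;
/// use torsh_quantization::algorithms::calculate_affine_quantization_params;
///
/// // Range [0.0, 2.55] over U8's [0, 255] gives scale = 0.01, zero_point = 0.
/// let (scale, zero_point) =
///     calculate_affine_quantization_params(0.0, 2.55, DType::U8).unwrap();
/// assert!((scale - 0.01).abs() < 1e-6);
/// assert_eq!(zero_point, 0);
/// ```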
pub fn calculate_affine_quantization_params(
    min_val: f32,
    max_val: f32,
    dtype: DType,
) -> TorshResult<(f32, i32)> {
    if !min_val.is_finite() || !max_val.is_finite() {
        return Err(TorshError::InvalidArgument(
            "Min and max values must be finite".to_string(),
        ));
    }

    if min_val > max_val {
        return Err(TorshError::InvalidArgument(
            "Min value must be <= max value".to_string(),
        ));
    }

    let (qmin, qmax) = get_dtype_range(dtype);
    let qmin = qmin as f32;
    let qmax = qmax as f32;

    // Handle edge case where min == max
    if (max_val - min_val).abs() < f32::EPSILON {
        let scale = 1.0;
        let zero_point = qmin as i32;
        return Ok((scale, zero_point));
    }

    // Calculate scale
    let scale = (max_val - min_val) / (qmax - qmin);

    // Calculate zero_point
    let zero_point_fp = qmin - min_val / scale;
    let zero_point = zero_point_fp.round().clamp(qmin, qmax) as i32;

    Ok((scale, zero_point))
}

/// Calculate symmetric quantization parameters (scale only, zero_point = 0)
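///
/// Uses `scale = max(|min|, |max|) / qmax` with `zero_point = 0`, so the value
/// range is mapped symmetrically around zero (e.g. `[-2.0, 1.0]` over I8 gives
/// `scale = 2.0 / 127`).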
pub fn calculate_symmetric_quantization_params(
    min_val: f32,
    max_val: f32,
    dtype: DType,
) -> TorshResult<(f32, i32)> {
    if !min_val.is_finite() || !max_val.is_finite() {
        return Err(TorshError::InvalidArgument(
            "Min and max values must be finite".to_string(),
        ));
    }

    let (_qmin, qmax) = get_dtype_range(dtype);
    let abs_max = min_val.abs().max(max_val.abs());

    // Handle edge case where range is zero
    if abs_max < f32::EPSILON {
        return Ok((1.0, 0));
    }

    // For symmetric quantization, we use the maximum absolute value
    // and map it to the maximum quantized range
    let scale = abs_max / qmax as f32;
    let zero_point = 0; // Symmetric quantization always has zero_point = 0

    Ok((scale, zero_point))
}

/// Get the quantization range for a given data type
pub fn get_dtype_range(dtype: DType) -> (i32, i32) {
    match dtype {
        DType::I8 => (-128, 127),
        DType::U8 => (0, 255),
        DType::I16 => (-32768, 32767),
        DType::I32 => (i32::MIN, i32::MAX),
        _ => (-128, 127), // Default to int8 range
    }
}

/// Convenience function to quantize with automatic parameter calculation
pub fn quantize_auto(tensor: &Tensor, config: &QuantConfig) -> TorshResult<(Tensor, f32, i32)> {
    quantize_with_config(tensor, config)
}

// ===== Cache-Aware Algorithm Enhancements =====

/// Cache-aware quantization parameters for optimal memory access patterns
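///
/// A sketch of overriding the defaults for a machine with a larger L2 (field
/// names as defined below; the values are illustrative, and the module path is
/// assumed, hence `ignore`):
///
/// ```ignore
/// use torsh_quantization::algorithms::CacheAwareParams;
///
/// let params = CacheAwareParams {
///     l2_cache_size: 1024 * 1024, // 1MB L2
///     ..CacheAwareParams::default()
/// };
/// assert!(params.enable_chunking);
/// ```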
#[derive(Debug, Clone)]
pub struct CacheAwareParams {
    /// Cache line size in bytes (typically 64 bytes)
    pub cache_line_size: usize,
    /// L1 cache size in bytes (typically 32KB)
    pub l1_cache_size: usize,
    /// L2 cache size in bytes (typically 256KB)
    pub l2_cache_size: usize,
    /// L3 cache size in bytes (typically 8MB)
    pub l3_cache_size: usize,
    /// Prefetch distance (elements ahead to prefetch)
    pub prefetch_distance: usize,
    /// Enable cache-optimized chunking
    pub enable_chunking: bool,
}

impl Default for CacheAwareParams {
    fn default() -> Self {
        Self {
            cache_line_size: 64,
            l1_cache_size: 32 * 1024,       // 32KB L1
            l2_cache_size: 256 * 1024,      // 256KB L2
            l3_cache_size: 8 * 1024 * 1024, // 8MB L3
            prefetch_distance: 16,
            enable_chunking: true,
        }
    }
}

/// Cache-aware per-tensor quantization optimized for memory hierarchy
pub fn quantize_per_tensor_affine_cache_aware(
    input: &[f32],
    scale: f32,
    zero_point: i32,
    output: &mut [f32],
    cache_params: &CacheAwareParams,
) -> TorshResult<()> {
    if input.len() != output.len() {
        return Err(TorshError::InvalidArgument(
            "Input and output length mismatch".to_string(),
        ));
    }

    if scale <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "Scale must be positive".to_string(),
        ));
    }

    let inv_scale = 1.0 / scale;
    let zero_point_f32 = zero_point as f32;

    if !cache_params.enable_chunking || input.len() < cache_params.cache_line_size {
        // Direct processing for small arrays
        for (inp, out) in input.iter().zip(output.iter_mut()) {
            let quantized = (*inp * inv_scale).round() + zero_point_f32;
            *out = quantized.clamp(-128.0, 127.0);
        }
        return Ok(());
    }

    // Calculate optimal chunk size based on cache hierarchy
    let _elements_per_cache_line = cache_params.cache_line_size / std::mem::size_of::<f32>();
    let optimal_chunk_size =
        (cache_params.l2_cache_size / std::mem::size_of::<f32>() / 4).min(input.len());

    // Process in cache-friendly chunks
    input
        .par_chunks(optimal_chunk_size)
        .zip(output.par_chunks_mut(optimal_chunk_size))
        .for_each(|(input_chunk, output_chunk)| {
            // Process chunk with cache-friendly pattern
            for (inp, out) in input_chunk.iter().zip(output_chunk.iter_mut()) {
                let quantized = (*inp * inv_scale).round() + zero_point_f32;
                *out = quantized.clamp(-128.0, 127.0);
            }
        });

    Ok(())
}

/// Cache-optimized tensor statistics calculation with blocking
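///
/// Returns `(min, max, mean, std_dev)` computed in L2-sized blocks, combining
/// per-block partial sums (variance via `E[x^2] - E[x]^2`).
///
/// A minimal sketch (module path assumed, hence `ignore`):
///
/// ```ignore
/// use torsh_quantization::algorithms::{
///     calculate_tensor_stats_cache_optimized, CacheAwareParams,
/// };
///
/// let data: Vec<f32> = (1..=10).map(|x| x as f32).collect();
/// let (min, max, mean, _std_dev) =
///     calculate_tensor_stats_cache_optimized(&data, &CacheAwareParams::default()).unwrap();
/// assert_eq!((min, max), (1.0, 10.0));
/// assert!((mean - 5.5).abs() < 1e-3);
/// ```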
pub fn calculate_tensor_stats_cache_optimized(
    data: &[f32],
    cache_params: &CacheAwareParams,
) -> TorshResult<(f32, f32, f32, f32)> {
    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot calculate stats of empty tensor".to_string(),
        ));
    }

    let optimal_block_size = cache_params.l2_cache_size / std::mem::size_of::<f32>();
    let block_size = optimal_block_size.min(data.len());

    // Use blocked algorithm for better cache performance
    let results: Vec<(f32, f32, f64, f64)> = data
        .par_chunks(block_size)
        .map(|chunk| {
            let mut local_min = f32::INFINITY;
            let mut local_max = f32::NEG_INFINITY;
            let mut local_sum = 0.0f64;
            let mut local_sum_sq = 0.0f64;

            // Process block with good cache locality
            for &val in chunk {
                local_min = local_min.min(val);
                local_max = local_max.max(val);
                let val_f64 = val as f64;
                local_sum += val_f64;
                local_sum_sq += val_f64 * val_f64;
            }

            (local_min, local_max, local_sum, local_sum_sq)
        })
        .collect();

    // Combine results
    let mut min_val = f32::INFINITY;
    let mut max_val = f32::NEG_INFINITY;
    let mut total_sum = 0.0f64;
    let mut total_sum_sq = 0.0f64;

    for (local_min, local_max, local_sum, local_sum_sq) in results {
        min_val = min_val.min(local_min);
        max_val = max_val.max(local_max);
        total_sum += local_sum;
        total_sum_sq += local_sum_sq;
    }

    let n = data.len() as f64;
    let mean = (total_sum / n) as f32;
    let variance = ((total_sum_sq / n) - (mean as f64).powi(2)) as f32;

    Ok((min_val, max_val, mean, variance.sqrt()))
}

/// Cache-friendly matrix quantization using tiling for 2D tensors
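///
/// The row-major matrix is processed in tiles sized to fit in L2 cache, so each
/// tile stays resident while its elements are traversed.
///
/// A minimal sketch (module path assumed, hence `ignore`):
///
/// ```ignore
/// use torsh_quantization::algorithms::{quantize_matrix_cache_friendly, CacheAwareParams};
///
/// let matrix = vec![1.0_f32, 2.0, 3.0, 4.0]; // 2x2, row-major
/// let mut output = vec![0.0_f32; 4];
/// quantize_matrix_cache_friendly(&matrix, 2, 2, 0.1, 0, &mut output,
///                                &CacheAwareParams::default()).unwrap();
/// assert_eq!(output, vec![10.0, 20.0, 30.0, 40.0]);
/// ```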
pub fn quantize_matrix_cache_friendly(
    matrix: &[f32],
    rows: usize,
    cols: usize,
    scale: f32,
    zero_point: i32,
    output: &mut [f32],
    cache_params: &CacheAwareParams,
) -> TorshResult<()> {
    if matrix.len() != rows * cols || output.len() != rows * cols {
        return Err(TorshError::InvalidArgument(
            "Matrix dimensions don't match buffer sizes".to_string(),
        ));
    }

    if scale <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "Scale must be positive".to_string(),
        ));
    }

    let inv_scale = 1.0 / scale;
    let zero_point_f32 = zero_point as f32;

    // Calculate optimal tile sizes based on cache hierarchy
    let elements_per_cache_line = cache_params.cache_line_size / std::mem::size_of::<f32>();
    let l2_elements = cache_params.l2_cache_size / std::mem::size_of::<f32>();

    // Find good tile dimensions that fit in L2 cache
    let max_tile_size = (l2_elements / 4).min(1024); // Reserve space for other data
    let tile_rows = (max_tile_size / cols).max(1).min(rows);
    let tile_cols = (max_tile_size / tile_rows)
        .max(elements_per_cache_line)
        .min(cols);

    // Process matrix in cache-friendly tiles
    for row_start in (0..rows).step_by(tile_rows) {
        let row_end = (row_start + tile_rows).min(rows);

        for col_start in (0..cols).step_by(tile_cols) {
            let col_end = (col_start + tile_cols).min(cols);

            // Process tile with good spatial locality
            for row in row_start..row_end {
                for col in col_start..col_end {
                    let idx = row * cols + col;
                    let quantized = (matrix[idx] * inv_scale).round() + zero_point_f32;
                    output[idx] = quantized.clamp(-128.0, 127.0);
                }
            }
        }
    }

    Ok(())
}

/// Prefetch-aware sequential quantization for streaming data
pub fn quantize_streaming_with_prefetch(
    input: &[f32],
    scale: f32,
    zero_point: i32,
    output: &mut [f32],
    cache_params: &CacheAwareParams,
) -> TorshResult<()> {
    if input.len() != output.len() {
        return Err(TorshError::InvalidArgument(
            "Input and output length mismatch".to_string(),
        ));
    }

    if scale <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "Scale must be positive".to_string(),
        ));
    }

    let inv_scale = 1.0 / scale;
    let zero_point_f32 = zero_point as f32;
    let prefetch_distance = cache_params.prefetch_distance;

    // Sequential processing with a software "prefetch" hint
    for i in 0..input.len() {
        // Touch data ahead of the current position. Note that merely taking a
        // reference guarantees no hardware prefetch: real prefetching needs
        // platform intrinsics (e.g. `_mm_prefetch` on x86), and the compiler
        // may remove this borrow entirely.
        if i + prefetch_distance < input.len() {
            let _prefetch_addr = &input[i + prefetch_distance];
        }

        let quantized = (input[i] * inv_scale).round() + zero_point_f32;
        output[i] = quantized.clamp(-128.0, 127.0);
    }

    Ok(())
}

/// Get cache-aware optimization recommendations for tensor operations
pub fn get_cache_optimization_recommendations(
    tensor_size: usize,
    element_size: usize,
    cache_params: &CacheAwareParams,
) -> Vec<String> {
    let mut recommendations = Vec::new();
    let total_bytes = tensor_size * element_size;

    if total_bytes <= cache_params.l1_cache_size {
        recommendations.push("Tensor fits in L1 cache - use simple sequential access".to_string());
    } else if total_bytes <= cache_params.l2_cache_size {
        recommendations.push("Tensor fits in L2 cache - consider blocked algorithms".to_string());
    } else if total_bytes <= cache_params.l3_cache_size {
        recommendations
            .push("Tensor fits in L3 cache - use tiled processing with medium blocks".to_string());
    } else {
        recommendations
            .push("Large tensor - use streaming algorithms with prefetching".to_string());
        recommendations
            .push("Consider parallel processing to utilize multiple cache hierarchies".to_string());
    }

    let elements_per_cache_line = cache_params.cache_line_size / element_size;
    if tensor_size % elements_per_cache_line != 0 {
        recommendations.push(format!(
            "Consider padding to align with cache lines ({}B boundaries)",
            cache_params.cache_line_size
        ));
    }

    recommendations
}

/// Auto-select optimal quantization algorithm based on cache analysis
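///
/// Inputs that fit in L1 take the sequential prefetch path; anything larger goes
/// through the chunked, parallel cache-aware path.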
pub fn quantize_with_cache_optimization(
    input: &[f32],
    scale: f32,
    zero_point: i32,
    output: &mut [f32],
    cache_params: Option<&CacheAwareParams>,
) -> TorshResult<()> {
    let default_params = CacheAwareParams::default();
    let params = cache_params.unwrap_or(&default_params);
    let total_bytes = std::mem::size_of_val(input);

    if total_bytes <= params.l1_cache_size {
        // Small data - use simple sequential processing
        quantize_streaming_with_prefetch(input, scale, zero_point, output, params)
    } else if total_bytes <= params.l2_cache_size {
        // Medium data - use cache-aware chunking
        quantize_per_tensor_affine_cache_aware(input, scale, zero_point, output, params)
    } else {
        // Large data - use parallel processing with cache-friendly chunks
        quantize_per_tensor_affine_cache_aware(input, scale, zero_point, output, params)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{QScheme, QuantConfig};

    use torsh_tensor::creation::tensor_1d;

    #[test]
    fn test_calculate_affine_quantization_params() {
        // Test normal case
        let (scale, zero_point) =
            calculate_affine_quantization_params(-1.0, 1.0, DType::I8).unwrap();

        assert!(scale > 0.0);
        assert!(zero_point >= -128 && zero_point <= 127);

        // Test edge case: min == max
        let (scale, zero_point) =
            calculate_affine_quantization_params(1.0, 1.0, DType::I8).unwrap();

        assert_eq!(scale, 1.0);
        assert_eq!(zero_point, -128);

        // Test invalid case: min > max
        let result = calculate_affine_quantization_params(2.0, 1.0, DType::I8);
        assert!(result.is_err());
    }

    #[test]
    fn test_calculate_symmetric_quantization_params() {
        // Test normal case
        let (scale, zero_point) =
            calculate_symmetric_quantization_params(-2.0, 1.0, DType::I8).unwrap();

        assert!(scale > 0.0);
        assert_eq!(zero_point, 0); // Symmetric always has zero_point = 0

        // Test edge case: zero range
        let (scale, zero_point) =
            calculate_symmetric_quantization_params(0.0, 0.0, DType::I8).unwrap();

        assert_eq!(scale, 1.0);
        assert_eq!(zero_point, 0);
    }

    #[test]
    fn test_get_dtype_range() {
        assert_eq!(get_dtype_range(DType::I8), (-128, 127));
        assert_eq!(get_dtype_range(DType::U8), (0, 255));
        assert_eq!(get_dtype_range(DType::I16), (-32768, 32767));
    }

    #[test]
    fn test_quantize_per_tensor_affine() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = tensor_1d(&data).unwrap();

        let (quantized, scale, zero_point) = quantize_per_tensor_affine(&tensor, 0.1, 0).unwrap();

        let quantized_data = quantized.data().unwrap();

        // Verify quantization: (value / scale) + zero_point
        assert_eq!(quantized_data[0], 10.0); // (1.0 / 0.1) + 0 = 10
        assert_eq!(quantized_data[1], 20.0); // (2.0 / 0.1) + 0 = 20
        assert_eq!(scale, 0.1);
        assert_eq!(zero_point, 0);
    }

    #[test]
    fn test_dequantize_per_tensor_affine() {
        let quantized_data = vec![10.0, 20.0, 30.0, 40.0];
        let quantized_tensor = tensor_1d(&quantized_data).unwrap();

        let dequantized = dequantize_per_tensor_affine(&quantized_tensor, 0.1, 0).unwrap();
        let dequantized_data = dequantized.data().unwrap();

        // Verify dequantization: (quantized_value - zero_point) * scale
        assert!((dequantized_data[0] - 1.0).abs() < 1e-6); // (10 - 0) * 0.1 = 1.0
        assert!((dequantized_data[1] - 2.0).abs() < 1e-6); // (20 - 0) * 0.1 = 2.0
    }

    #[test]
    fn test_quantize_tensor_auto() {
        let data = vec![-1.0, 0.0, 1.0, 2.0];
        let tensor = tensor_1d(&data).unwrap();

        let (quantized, scale, zero_point) =
            quantize_tensor_auto(&tensor, DType::I8, QScheme::PerTensorAffine).unwrap();

        assert!(scale > 0.0);
        assert!(zero_point >= -128 && zero_point <= 127);

        // Verify tensor was quantized
        let quantized_data = quantized.data().unwrap();
        assert_eq!(quantized_data.len(), data.len());
    }

    #[test]
    fn test_quantize_with_config() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int8();

        let result = quantize_with_config(&tensor, &config);
        assert!(result.is_ok());

        let (quantized, scale, zero_point) = result.unwrap();
        assert!(scale > 0.0);
        assert!(zero_point >= -128 && zero_point <= 127);
        assert_eq!(quantized.shape().dims(), tensor.shape().dims());
    }

    #[test]
    fn test_dequantize() {
        let quantized_data = vec![64.0, 128.0, -64.0, 0.0];
        let quantized_tensor = tensor_1d(&quantized_data).unwrap();

        let dequantized = dequantize(&quantized_tensor, 0.5, 0).unwrap();
        let dequantized_data = dequantized.data().unwrap();

        // Test dequantization with scale 0.5
        assert!((dequantized_data[0] - 32.0).abs() < 1e-6);
        assert!((dequantized_data[1] - 64.0).abs() < 1e-6);
        assert!((dequantized_data[2] + 32.0).abs() < 1e-6);
        assert!((dequantized_data[3] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_quantize_auto() {
        let data = vec![0.5, 1.0, 1.5, 2.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int8();

        let result = quantize_auto(&tensor, &config);
        assert!(result.is_ok());

        let (quantized, scale, _zero_point) = result.unwrap();
        assert!(scale > 0.0);
        assert_eq!(quantized.shape().dims(), tensor.shape().dims());
    }

    #[test]
    fn test_error_cases() {
        // Test invalid scale
        let data = vec![1.0, 2.0];
        let tensor = tensor_1d(&data).unwrap();

        let result = quantize_per_tensor_affine(&tensor, -1.0, 0);
        assert!(result.is_err());

        let result = dequantize_per_tensor_affine(&tensor, 0.0, 0);
        assert!(result.is_err());

        // Test empty tensor
        let empty_data: Vec<f32> = vec![];
        let empty_tensor = tensor_1d(&empty_data).unwrap();

        let result = quantize_tensor_auto(&empty_tensor, DType::I8, QScheme::PerTensorAffine);
        assert!(result.is_err());
    }

    // ===== Cache-Aware Algorithm Tests =====

    #[test]
    fn test_cache_aware_quantization() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut output = vec![0.0; 8];
        let cache_params = CacheAwareParams::default();

        let result =
            quantize_per_tensor_affine_cache_aware(&input, 0.1, 0, &mut output, &cache_params);

        assert!(result.is_ok());
        assert_eq!(output[0], 10.0);
        assert_eq!(output[7], 80.0);
    }

    #[test]
    fn test_cache_optimized_stats() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let cache_params = CacheAwareParams::default();

        let result = calculate_tensor_stats_cache_optimized(&data, &cache_params);
        assert!(result.is_ok());

        let (min_val, max_val, mean, std_dev) = result.unwrap();
        assert_eq!(min_val, 1.0);
        assert_eq!(max_val, 10.0);
        assert!((mean - 5.5).abs() < 0.001);
        assert!(std_dev > 0.0);
    }

    #[test]
    fn test_matrix_cache_friendly_quantization() {
        let matrix = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let mut output = vec![0.0; 9];
        let cache_params = CacheAwareParams::default();

        let result =
            quantize_matrix_cache_friendly(&matrix, 3, 3, 0.1, 0, &mut output, &cache_params);

        assert!(result.is_ok());
        assert_eq!(output[0], 10.0); // 1.0 / 0.1 = 10
        assert_eq!(output[8], 90.0); // 9.0 / 0.1 = 90
    }

    #[test]
    fn test_streaming_with_prefetch() {
        let input = vec![0.1, 0.2, 0.3, 0.4, 0.5];
        let mut output = vec![0.0; 5];
        let cache_params = CacheAwareParams::default();

        let result = quantize_streaming_with_prefetch(&input, 0.01, 10, &mut output, &cache_params);

        assert!(result.is_ok());
        // 0.1 / 0.01 + 10 = 10 + 10 = 20
        assert_eq!(output[0], 20.0);
    }

    #[test]
    fn test_cache_optimization_recommendations() {
        let cache_params = CacheAwareParams::default();

        // Small tensor (L1 cache)
        let recommendations = get_cache_optimization_recommendations(1000, 4, &cache_params);
        assert!(!recommendations.is_empty());
        assert!(recommendations[0].contains("L1 cache"));

        // Large tensor (beyond L3)
        let large_recommendations =
            get_cache_optimization_recommendations(10_000_000, 4, &cache_params);
        assert!(large_recommendations
            .iter()
            .any(|r| r.contains("streaming")));
    }

    #[test]
    fn test_auto_cache_optimization() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0; 4];

        // Test with default cache parameters
        let result = quantize_with_cache_optimization(&input, 0.1, 0, &mut output, None);

        assert!(result.is_ok());
        assert_eq!(output[0], 10.0);
        assert_eq!(output[3], 40.0);
    }

    #[test]
    fn test_cache_params_default() {
        let params = CacheAwareParams::default();

        assert_eq!(params.cache_line_size, 64);
        assert_eq!(params.l1_cache_size, 32 * 1024);
        assert_eq!(params.l2_cache_size, 256 * 1024);
        assert_eq!(params.l3_cache_size, 8 * 1024 * 1024);
        assert_eq!(params.prefetch_distance, 16);
        assert!(params.enable_chunking);
    }

    #[test]
    fn test_cache_aware_error_cases() {
        let input = vec![1.0, 2.0];
        let mut output = vec![0.0; 3]; // Wrong size
        let cache_params = CacheAwareParams::default();

        let result =
            quantize_per_tensor_affine_cache_aware(&input, 0.1, 0, &mut output, &cache_params);
        assert!(result.is_err());

        // Test invalid scale
        let mut output_correct = vec![0.0; 2];
        let result = quantize_per_tensor_affine_cache_aware(
            &input,
            -0.1,
            0,
            &mut output_correct,
            &cache_params,
        );
        assert!(result.is_err());

        // Test empty data for stats
        let empty_data: Vec<f32> = vec![];
        let result = calculate_tensor_stats_cache_optimized(&empty_data, &cache_params);
        assert!(result.is_err());
    }
}