Skip to main content

torsh_quantization/
simd_ops.rs

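//! SIMD-friendly, parallel quantization operations
//!
//! This module provides optimized implementations of performance-critical
//! quantization operations. Element-wise work is distributed across threads with the
//! parallel iterators from `scirs2_core::parallel_ops`; no explicit SIMD intrinsics are
//! used, so vectorization of the inner loops is left to the compiler.
//!
//! # Features
//!
//! - **Vectorized Quantization**: per-tensor, per-channel and direct-to-INT8 quantization
//! - **Vectorized Dequantization**: affine dequantization operations
//! - **Fast Min/Max Finding**: parallel min/max computation for calibration
//! - **Batch Operations**: parallel processing of multiple tensors that share parameters
//! - **Mobile / AArch64 Variants**: sequential, chunked variants for constrained targets
//! - **Portable Code Paths**: the same code runs on every target; `is_simd_available` and
//!   `get_simd_width` only report the target features enabled at compile time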
13
14use scirs2_core::parallel_ops::*;
15use torsh_core::error::{Result as TorshResult, TorshError};
16
17/// SIMD-accelerated per-tensor quantization
18pub fn quantize_per_tensor_affine_simd(
19    input: &[f32],
20    scale: f32,
21    zero_point: i32,
22    output: &mut [f32],
23) -> TorshResult<()> {
24    if input.len() != output.len() {
25        return Err(TorshError::InvalidArgument(
26            "Input and output length mismatch".to_string(),
27        ));
28    }
29
30    if scale <= 0.0 {
31        return Err(TorshError::InvalidArgument(
32            "Scale must be positive".to_string(),
33        ));
34    }
35
36    let inv_scale = 1.0 / scale;
37    let zero_point_f32 = zero_point as f32;
38
39    // Use optimized parallel processing for quantization operation
40    input
41        .par_iter()
42        .zip(output.par_iter_mut())
43        .for_each(|(&x, out)| {
44            let quantized = (x * inv_scale).round() + zero_point_f32;
45            *out = quantized.clamp(-128.0, 127.0);
46        });
47
48    Ok(())
49}
50
51/// SIMD-accelerated per-tensor dequantization
52pub fn dequantize_per_tensor_affine_simd(
53    input: &[f32],
54    scale: f32,
55    zero_point: i32,
56    output: &mut [f32],
57) -> TorshResult<()> {
58    if input.len() != output.len() {
59        return Err(TorshError::InvalidArgument(
60            "Input and output length mismatch".to_string(),
61        ));
62    }
63
64    if scale <= 0.0 {
65        return Err(TorshError::InvalidArgument(
66            "Scale must be positive".to_string(),
67        ));
68    }
69
70    let zero_point_f32 = zero_point as f32;
71
72    // Use optimized parallel processing for dequantization operation
73    input
74        .par_iter()
75        .zip(output.par_iter_mut())
76        .for_each(|(&x, out)| {
77            *out = (x - zero_point_f32) * scale;
78        });
79
80    Ok(())
81}
82
83/// SIMD-accelerated min/max finding for calibration
84pub fn find_min_max_simd(data: &[f32]) -> TorshResult<(f32, f32)> {
85    if data.is_empty() {
86        return Err(TorshError::InvalidArgument(
87            "Cannot find min/max of empty array".to_string(),
88        ));
89    }
90
91    // Use parallel operations for min/max reduction
92    const CHUNK_SIZE: usize = 1024; // Process in cache-friendly chunks
93    let (min_val, max_val) = if data.len() > CHUNK_SIZE {
94        // Use parallel processing for large datasets
95        data.par_chunks(CHUNK_SIZE)
96            .map(|chunk| {
97                let mut local_min = f32::INFINITY;
98                let mut local_max = f32::NEG_INFINITY;
99                for &val in chunk {
100                    local_min = local_min.min(val);
101                    local_max = local_max.max(val);
102                }
103                (local_min, local_max)
104            })
105            .reduce(
106                || (f32::INFINITY, f32::NEG_INFINITY),
107                |(min1, max1), (min2, max2)| (min1.min(min2), max1.max(max2)),
108            )
109    } else {
110        // Sequential processing for small datasets
111        let mut min_val = f32::INFINITY;
112        let mut max_val = f32::NEG_INFINITY;
113        for &val in data {
114            min_val = min_val.min(val);
115            max_val = max_val.max(val);
116        }
117        (min_val, max_val)
118    };
119
120    Ok((min_val, max_val))
121}
122
123/// SIMD-accelerated per-channel quantization
124pub fn quantize_per_channel_simd(
125    input: &[f32],
126    scales: &[f32],
127    zero_points: &[i32],
128    channel_size: usize,
129    output: &mut [f32],
130) -> TorshResult<()> {
131    if input.len() != output.len() {
132        return Err(TorshError::InvalidArgument(
133            "Input and output length mismatch".to_string(),
134        ));
135    }
136
137    let num_channels = scales.len();
138    if num_channels != zero_points.len() {
139        return Err(TorshError::InvalidArgument(
140            "Scales and zero_points length mismatch".to_string(),
141        ));
142    }
143
144    if input.len() != num_channels * channel_size {
145        return Err(TorshError::InvalidArgument(
146            "Input size does not match channel configuration".to_string(),
147        ));
148    }
149
150    // Process each channel with SIMD acceleration
151    for (ch, (&scale, &zero_point)) in scales.iter().zip(zero_points.iter()).enumerate() {
152        if scale <= 0.0 {
153            return Err(TorshError::InvalidArgument(format!(
154                "Scale for channel {} must be positive",
155                ch
156            )));
157        }
158
159        let channel_start = ch * channel_size;
160        let channel_end = channel_start + channel_size;
161
162        let input_slice = &input[channel_start..channel_end];
163        let output_slice = &mut output[channel_start..channel_end];
164
165        quantize_per_tensor_affine_simd(input_slice, scale, zero_point, output_slice)?;
166    }
167
168    Ok(())
169}
170
171/// SIMD-accelerated batch quantization for consistent parameters
172pub fn quantize_batch_consistent_simd(
173    tensors: &[&[f32]],
174    scale: f32,
175    zero_point: i32,
176    outputs: &mut [&mut [f32]],
177) -> TorshResult<()> {
178    if tensors.len() != outputs.len() {
179        return Err(TorshError::InvalidArgument(
180            "Number of input tensors must match output tensors".to_string(),
181        ));
182    }
183
184    // Use parallel processing for each tensor
185    tensors
186        .par_iter()
187        .zip(outputs.par_iter_mut())
188        .try_for_each(|(input, output)| {
189            quantize_per_tensor_affine_simd(input, scale, zero_point, output)
190        })?;
191
192    Ok(())
193}
194
195/// SIMD-accelerated floating-point to integer quantization (optimized for INT8)
196pub fn quantize_to_int8_simd(
197    input: &[f32],
198    scale: f32,
199    zero_point: i32,
200    output: &mut [i8],
201) -> TorshResult<()> {
202    if input.len() != output.len() {
203        return Err(TorshError::InvalidArgument(
204            "Input and output length mismatch".to_string(),
205        ));
206    }
207
208    if scale <= 0.0 {
209        return Err(TorshError::InvalidArgument(
210            "Scale must be positive".to_string(),
211        ));
212    }
213
214    let inv_scale = 1.0 / scale;
215    let zero_point_f32 = zero_point as f32;
216
217    // Use optimized parallel processing for quantization to INT8
218    input
219        .par_iter()
220        .zip(output.par_iter_mut())
221        .for_each(|(&x, out)| {
222            let quantized = (x * inv_scale).round() + zero_point_f32;
223            *out = quantized.clamp(-128.0, 127.0) as i8;
224        });
225
226    Ok(())
227}
228
229/// SIMD-accelerated statistics calculation for quantization calibration
230pub fn calculate_tensor_stats_simd(data: &[f32]) -> TorshResult<TensorStats> {
231    if data.is_empty() {
232        return Err(TorshError::InvalidArgument(
233            "Cannot calculate stats of empty tensor".to_string(),
234        ));
235    }
236
237    let (min_val, max_val) = find_min_max_simd(data)?;
238
239    // Calculate mean using parallel reduction
240    let sum: f64 = data.par_iter().map(|&x| x as f64).sum();
241    let mean = sum / data.len() as f64;
242
243    // Calculate variance using parallel reduction
244    let variance_sum: f64 = data
245        .par_iter()
246        .map(|&x| {
247            let diff = x as f64 - mean;
248            diff * diff
249        })
250        .sum();
251    let variance = variance_sum / data.len() as f64;
252    let std_dev = variance.sqrt();
253
254    Ok(TensorStats {
255        min: min_val,
256        max: max_val,
257        mean: mean as f32,
258        std_dev: std_dev as f32,
259        variance: variance as f32,
260    })
261}
262
/// Summary statistics of a tensor, e.g. as produced by `calculate_tensor_stats_simd`.
///
/// All fields are plain `f32`s, so the struct is `Copy` and comparable with `==`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TensorStats {
    /// Smallest element.
    pub min: f32,
    /// Largest element.
    pub max: f32,
    /// Arithmetic mean of the elements.
    pub mean: f32,
    /// Standard deviation (square root of `variance`).
    pub std_dev: f32,
    /// Variance; `calculate_tensor_stats_simd` fills it with the population variance.
    pub variance: f32,
}
272
/// Reports whether this build targets a CPU with one of the SIMD extensions listed below.
///
/// This is a **compile-time** check. It inspects the `target_feature`s enabled for the
/// build (AVX-512F, AVX2, AVX or SSE2 on x86; NEON on ARM), not the CPU the program
/// actually runs on.
pub fn is_simd_available() -> bool {
    let x86_simd = cfg!(any(
        target_feature = "avx512f",
        target_feature = "avx2",
        target_feature = "avx",
        target_feature = "sse2"
    ));
    let arm_simd = cfg!(target_feature = "neon");

    x86_simd || arm_simd
}
284
/// Returns the number of `f32` lanes in the widest SIMD register this build targets.
///
/// The answer is decided at **compile time** from the enabled `target_feature`s, not by
/// probing the running CPU:
///
/// | feature            | register width | `f32` lanes |
/// |--------------------|----------------|-------------|
/// | AVX-512F           | 512-bit        | 16          |
/// | AVX or AVX2        | 256-bit        | 8           |
/// | SSE2 or NEON       | 128-bit        | 4           |
/// | none of the above  | scalar         | 1           |
pub fn get_simd_width() -> usize {
    if cfg!(target_feature = "avx512f") {
        16
    } else if cfg!(any(target_feature = "avx2", target_feature = "avx")) {
        // AVX (not only AVX2) already provides 256-bit floating-point registers; AVX2
        // mainly widens the integer instructions.
        8
    } else if cfg!(any(target_feature = "sse2", target_feature = "neon")) {
        4
    } else {
        1
    }
}
299
/// Quantizes `input` on AArch64, writing integral `f32` values clamped to `[-128, 127]`.
///
/// The per-element arithmetic is `clamp(round(x / scale) + zero_point, -128, 127)`, the
/// same as `quantize_per_tensor_affine_simd`, but it runs sequentially on the calling
/// thread. Elements are processed in fixed groups of 4 (the `f32` width of a 128-bit NEON
/// register) via `chunks_exact`. This gives the compiler a bounds-check-free inner loop
/// it may vectorize; no NEON intrinsics are used directly.
///
/// # Errors
///
/// Returns [`TorshError::InvalidArgument`] if `input` and `output` differ in length, or if
/// `scale` is not a positive, finite number (this includes `NaN`).
#[cfg(target_arch = "aarch64")]
pub fn quantize_neon_optimized(
    input: &[f32],
    scale: f32,
    zero_point: i32,
    output: &mut [f32],
) -> TorshResult<()> {
    if input.len() != output.len() {
        return Err(TorshError::InvalidArgument(
            "Input and output length mismatch".to_string(),
        ));
    }

    // `!(… > 0.0)` rejects NaN, which `scale <= 0.0` would let through.
    if !(scale.is_finite() && scale > 0.0) {
        return Err(TorshError::InvalidArgument(
            "Scale must be positive and finite".to_string(),
        ));
    }

    const NEON_WIDTH: usize = 4;

    let inv_scale = 1.0 / scale;
    let zero_point_f32 = zero_point as f32;
    let quantize = |x: f32| ((x * inv_scale).round() + zero_point_f32).clamp(-128.0, 127.0);

    let mut in_chunks = input.chunks_exact(NEON_WIDTH);
    let mut out_chunks = output.chunks_exact_mut(NEON_WIDTH);

    for (src, dst) in (&mut in_chunks).zip(&mut out_chunks) {
        for (&x, out) in src.iter().zip(dst.iter_mut()) {
            *out = quantize(x);
        }
    }

    // Tail of fewer than NEON_WIDTH elements.
    for (&x, out) in in_chunks
        .remainder()
        .iter()
        .zip(out_chunks.into_remainder().iter_mut())
    {
        *out = quantize(x);
    }

    Ok(())
}
351
/// Finds the minimum and maximum of `data` on AArch64, sequentially on the calling thread.
///
/// The scan keeps one running `(min, max)` per lane of a 4 × `f32` group (the width of a
/// 128-bit NEON register). The lanes are therefore independent and the loop can map onto
/// vector min/max operations. The lanes and the `< 4`-element tail are folded together at
/// the end. `NaN` elements are skipped, because `f32::min`/`f32::max` return the non-`NaN`
/// operand.
///
/// # Errors
///
/// Returns [`TorshError::InvalidArgument`] if `data` is empty or contains only `NaN`s.
#[cfg(target_arch = "aarch64")]
pub fn find_min_max_neon(data: &[f32]) -> TorshResult<(f32, f32)> {
    if data.is_empty() {
        return Err(TorshError::InvalidArgument(
            "Cannot find min/max of empty array".to_string(),
        ));
    }

    const NEON_WIDTH: usize = 4;

    let mut lane_min = [f32::INFINITY; NEON_WIDTH];
    let mut lane_max = [f32::NEG_INFINITY; NEON_WIDTH];

    let chunks = data.chunks_exact(NEON_WIDTH);
    let tail = chunks.remainder();
    for chunk in chunks {
        for ((lo, hi), &v) in lane_min.iter_mut().zip(lane_max.iter_mut()).zip(chunk) {
            *lo = lo.min(v);
            *hi = hi.max(v);
        }
    }

    // Combine the per-lane results, then the leftover elements.
    let (min_val, max_val) = lane_min
        .iter()
        .zip(lane_max.iter())
        .map(|(&lo, &hi)| (lo, hi))
        .chain(tail.iter().map(|&v| (v, v)))
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), (a, b)| {
            (lo.min(a), hi.max(b))
        });

    // The initial (INFINITY, NEG_INFINITY) pair survives only when every element was NaN.
    if min_val > max_val {
        return Err(TorshError::InvalidArgument(
            "Cannot find min/max of array containing only NaN values".to_string(),
        ));
    }

    Ok((min_val, max_val))
}
388
389/// Mobile-optimized quantization with reduced memory usage
390pub fn quantize_mobile_optimized(
391    input: &[f32],
392    scale: f32,
393    zero_point: i32,
394    output: &mut [i8],
395    use_reduced_precision: bool,
396) -> TorshResult<()> {
397    if input.len() != output.len() {
398        return Err(TorshError::InvalidArgument(
399            "Input and output length mismatch".to_string(),
400        ));
401    }
402
403    if scale <= 0.0 {
404        return Err(TorshError::InvalidArgument(
405            "Scale must be positive".to_string(),
406        ));
407    }
408
409    let inv_scale = if use_reduced_precision {
410        // Use faster but less precise arithmetic for mobile
411        1.0 / scale
412    } else {
413        1.0 / scale
414    };
415
416    let zero_point_f32 = zero_point as f32;
417
418    // Use smaller chunk sizes for better mobile cache performance
419    const MOBILE_CHUNK_SIZE: usize = 256;
420
421    if input.len() > MOBILE_CHUNK_SIZE {
422        // Process in mobile-optimized chunks
423        input
424            .chunks(MOBILE_CHUNK_SIZE)
425            .zip(output.chunks_mut(MOBILE_CHUNK_SIZE))
426            .for_each(|(input_chunk, output_chunk)| {
427                for (&x, out) in input_chunk.iter().zip(output_chunk.iter_mut()) {
428                    let quantized = if use_reduced_precision {
429                        // Faster rounding for mobile
430                        (x * inv_scale + 0.5).floor() + zero_point_f32
431                    } else {
432                        (x * inv_scale).round() + zero_point_f32
433                    };
434                    *out = quantized.clamp(-128.0, 127.0) as i8;
435                }
436            });
437    } else {
438        // Direct processing for small tensors
439        for (&x, out) in input.iter().zip(output.iter_mut()) {
440            let quantized = (x * inv_scale).round() + zero_point_f32;
441            *out = quantized.clamp(-128.0, 127.0) as i8;
442        }
443    }
444
445    Ok(())
446}
447
/// Returns tuning hints for running quantization on mobile-class targets.
///
/// Every value is fixed at compile time from the build target:
/// - `use_reduced_precision` is enabled only when building for Android or iOS;
/// - `optimal_chunk_size` is 256 on AArch64 and 512 everywhere else;
/// - `prefer_int8` and `enable_fast_math` are always `true`;
/// - `prefer_sequential` is always `false`.
pub fn get_mobile_optimization_hints() -> MobileOptimizationHints {
    let is_mobile_os = cfg!(any(target_os = "android", target_os = "ios"));
    let chunk_size = if cfg!(target_arch = "aarch64") { 256 } else { 512 };

    MobileOptimizationHints {
        prefer_int8: true,
        use_reduced_precision: is_mobile_os,
        optimal_chunk_size: chunk_size,
        enable_fast_math: true,
        // Keep parallelism allowed: mobile devices often still benefit from some of it.
        prefer_sequential: false,
    }
}

/// Advisory tuning hints returned by [`get_mobile_optimization_hints`].
#[derive(Debug, Clone)]
pub struct MobileOptimizationHints {
    /// Whether INT8 is the recommended quantized format.
    pub prefer_int8: bool,
    /// Whether reduced-precision arithmetic is recommended (presumably intended for the
    /// `use_reduced_precision` flag of `quantize_mobile_optimized`; confirm with callers).
    pub use_reduced_precision: bool,
    /// Recommended number of elements to process per chunk.
    pub optimal_chunk_size: usize,
    /// Whether fast-math style optimizations are recommended.
    pub enable_fast_math: bool,
    /// Whether sequential processing is recommended over parallel processing.
    pub prefer_sequential: bool,
}
472
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_quantize_per_tensor_affine_simd() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0; 4];

        quantize_per_tensor_affine_simd(&input, 0.1, 0, &mut output).unwrap();

        assert_relative_eq!(output[0], 10.0, epsilon = 1e-6);
        assert_relative_eq!(output[1], 20.0, epsilon = 1e-6);
        assert_relative_eq!(output[2], 30.0, epsilon = 1e-6);
        assert_relative_eq!(output[3], 40.0, epsilon = 1e-6);
    }

    #[test]
    fn test_dequantize_per_tensor_affine_simd() {
        let input = vec![10.0, 20.0, 30.0, 40.0];
        let mut output = vec![0.0; 4];

        dequantize_per_tensor_affine_simd(&input, 0.1, 0, &mut output).unwrap();

        assert_relative_eq!(output[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(output[1], 2.0, epsilon = 1e-6);
        assert_relative_eq!(output[2], 3.0, epsilon = 1e-6);
        assert_relative_eq!(output[3], 4.0, epsilon = 1e-6);
    }

    #[test]
    fn test_quantize_dequantize_round_trip() {
        // Every input is an exact multiple of the scale, so the round trip is lossless.
        let input = vec![-1.0, -0.25, 0.0, 0.5, 1.0];
        let (scale, zero_point) = (0.25, 3);
        let mut quantized = vec![0.0; input.len()];
        let mut restored = vec![0.0; input.len()];

        quantize_per_tensor_affine_simd(&input, scale, zero_point, &mut quantized).unwrap();
        dequantize_per_tensor_affine_simd(&quantized, scale, zero_point, &mut restored)
            .unwrap();

        for (original, back) in input.iter().zip(&restored) {
            assert_relative_eq!(*original, *back, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_find_min_max_simd() {
        let data = vec![-1.5, 0.0, 2.3, -0.8, 4.7, 1.2];
        let (min_val, max_val) = find_min_max_simd(&data).unwrap();

        assert_relative_eq!(min_val, -1.5, epsilon = 1e-6);
        assert_relative_eq!(max_val, 4.7, epsilon = 1e-6);
    }

    #[test]
    fn test_find_min_max_simd_parallel_path() {
        // Longer than the 1024-element chunk size, so the parallel branch is exercised.
        let data: Vec<f32> = (0..5000).map(|i| i as f32 - 2500.0).collect();
        let (min_val, max_val) = find_min_max_simd(&data).unwrap();

        assert_eq!(min_val, -2500.0);
        assert_eq!(max_val, 2499.0);
    }

    #[test]
    fn test_calculate_tensor_stats_simd() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let stats = calculate_tensor_stats_simd(&data).unwrap();

        assert_relative_eq!(stats.min, 1.0, epsilon = 1e-6);
        assert_relative_eq!(stats.max, 5.0, epsilon = 1e-6);
        assert_relative_eq!(stats.mean, 3.0, epsilon = 1e-6);
        assert_relative_eq!(stats.std_dev, (2.0f64).sqrt() as f32, epsilon = 1e-4);
    }

    #[test]
    fn test_quantize_to_int8_simd() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0i8; 4];

        quantize_to_int8_simd(&input, 0.1, 0, &mut output).unwrap();

        assert_eq!(output[0], 10i8);
        assert_eq!(output[1], 20i8);
        assert_eq!(output[2], 30i8);
        assert_eq!(output[3], 40i8);
    }

    #[test]
    fn test_quantize_per_channel_simd() {
        // Two channels of two elements each, laid out channel-major.
        let input = vec![1.0, 2.0, 1.0, 2.0];
        let scales = vec![0.5, 1.0];
        let zero_points = vec![0, 10];
        let mut output = vec![0.0; 4];

        quantize_per_channel_simd(&input, &scales, &zero_points, 2, &mut output).unwrap();

        assert_eq!(output, vec![2.0, 4.0, 11.0, 12.0]);
    }

    #[test]
    fn test_quantize_per_channel_simd_errors() {
        let input = vec![1.0; 4];
        let mut output = vec![0.0; 4];

        // scales / zero_points length mismatch
        assert!(quantize_per_channel_simd(&input, &[1.0, 1.0], &[0], 2, &mut output).is_err());
        // channel configuration does not cover the whole input
        assert!(quantize_per_channel_simd(&input, &[1.0], &[0], 2, &mut output).is_err());
        // non-positive scale for one channel
        assert!(
            quantize_per_channel_simd(&input, &[1.0, 0.0], &[0, 0], 2, &mut output).is_err()
        );
    }

    #[test]
    fn test_quantize_batch_consistent_simd() {
        let a = vec![1.0, 2.0];
        let b = vec![-1.0, 0.5, 20.0];
        let mut out_a = vec![0.0; 2];
        let mut out_b = vec![0.0; 3];

        {
            let tensors = vec![a.as_slice(), b.as_slice()];
            let mut outputs = vec![out_a.as_mut_slice(), out_b.as_mut_slice()];
            quantize_batch_consistent_simd(&tensors, 0.1, 0, &mut outputs).unwrap();
        }

        assert_eq!(out_a, vec![10.0, 20.0]);
        assert_eq!(out_b, vec![-10.0, 5.0, 127.0]);
    }

    #[test]
    fn test_quantize_batch_consistent_simd_count_mismatch() {
        let a = vec![1.0];
        let tensors = vec![a.as_slice()];
        let mut outputs: Vec<&mut [f32]> = Vec::new();

        assert!(quantize_batch_consistent_simd(&tensors, 0.1, 0, &mut outputs).is_err());
    }

    #[test]
    fn test_error_cases() {
        let input = vec![1.0, 2.0];
        let mut output = vec![0.0; 3]; // Wrong size

        let result = quantize_per_tensor_affine_simd(&input, 0.1, 0, &mut output);
        assert!(result.is_err());

        let mut output_correct = vec![0.0; 2];
        let result = quantize_per_tensor_affine_simd(&input, -0.1, 0, &mut output_correct);
        assert!(result.is_err());

        let empty_data: Vec<f32> = vec![];
        let result = find_min_max_simd(&empty_data);
        assert!(result.is_err());
    }

    #[test]
    fn test_simd_availability() {
        let available = is_simd_available();
        let width = get_simd_width();

        // Both answers come from compile-time target features, so only check that they
        // are reasonable and consistent with each other.
        assert!(width >= 1); // At least scalar width.

        if available {
            assert!(width > 1); // A SIMD-capable target has more than one lane.
        }
    }

    #[test]
    fn test_mobile_optimized_quantization() {
        let input = vec![1.0, 2.0, 3.0, 4.0, -1.0, -2.0];
        let mut output = vec![0i8; 6];

        quantize_mobile_optimized(&input, 0.1, 0, &mut output, false).unwrap();

        assert_eq!(output[0], 10i8);
        assert_eq!(output[1], 20i8);
        assert_eq!(output[2], 30i8);
        assert_eq!(output[3], 40i8);
        assert_eq!(output[4], -10i8);
        assert_eq!(output[5], -20i8);
    }

    #[test]
    fn test_mobile_optimized_quantization_reduced_precision() {
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0i8; 4];

        // Test with reduced precision enabled
        quantize_mobile_optimized(&input, 0.1, 0, &mut output, true).unwrap();

        // The alternative rounding may differ by at most one step from exact rounding.
        assert!((output[0] as f32 - 10.0).abs() <= 1.0);
        assert!((output[1] as f32 - 20.0).abs() <= 1.0);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_neon_quantization() {
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut output = vec![0.0; 8];

        quantize_neon_optimized(&input, 0.1, 0, &mut output).unwrap();

        assert_relative_eq!(output[0], 10.0, epsilon = 1e-6);
        assert_relative_eq!(output[1], 20.0, epsilon = 1e-6);
        assert_relative_eq!(output[7], 80.0, epsilon = 1e-6);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn test_neon_min_max() {
        let data = vec![-1.5, 0.0, 2.3, -0.8, 4.7, 1.2, 9.5, -2.1];
        let (min_val, max_val) = find_min_max_neon(&data).unwrap();

        assert_relative_eq!(min_val, -2.1, epsilon = 1e-6);
        assert_relative_eq!(max_val, 9.5, epsilon = 1e-6);
    }

    #[test]
    fn test_mobile_optimization_hints() {
        let hints = get_mobile_optimization_hints();

        assert!(hints.prefer_int8); // Should prefer INT8 for mobile
        assert!(hints.optimal_chunk_size > 0); // Should have a reasonable chunk size
        assert!(!hints.prefer_sequential); // Should allow some parallelism
    }
}