// torsh_backend/quantization/hardware.rs

1//! Hardware acceleration features for quantization operations
2//!
3//! This module provides hardware-specific optimizations and feature detection
4//! for quantization operations. It includes support for various CPU and GPU
5//! acceleration technologies including SIMD, VNNI, DP4A, and Tensor Cores.
6
7// Framework infrastructure - components designed for future use
8#![allow(dead_code)]
9use crate::{BackendResult, Device};
10use torsh_core::error::TorshError;
11
12#[cfg(not(feature = "std"))]
13use alloc::{boxed::Box, string::String, vec::Vec};
14
/// Hardware capability flags relevant to quantized computation
///
/// Summarizes which acceleration features the current device offers so that
/// callers can pick the fastest quantization code path the hardware supports.
#[derive(Debug, Clone)]
pub struct QuantizationHardwareFeatures {
    /// Vectorized INT8 arithmetic is available (e.g. SSE2/AVX2/NEON),
    /// which significantly accelerates quantized computations.
    pub supports_int8_simd: bool,

    /// Packed sub-byte (INT4) formats can be handled efficiently,
    /// with multiple values stored in a single byte.
    pub supports_int4_packed: bool,

    /// Intel VNNI (Vector Neural Network Instructions) are present,
    /// accelerating INT8 dot-product-style neural network workloads.
    pub supports_vnni: bool,

    /// NVIDIA DP4A (4-element dot product and accumulate) is available,
    /// ideal for quantized matrix operations on CUDA devices.
    pub supports_dp4a: bool,

    /// Dedicated tensor cores for mixed-precision / quantized
    /// neural network computations exist on this device.
    pub supports_tensor_cores: bool,

    /// Mixing several quantization precisions inside a single
    /// computation is hardware-efficient.
    pub supports_mixed_precision: bool,

    /// Recommended degree of parallelism, used for scheduling and
    /// batching decisions.
    pub max_parallel_ops: usize,
}

impl Default for QuantizationHardwareFeatures {
    /// Most conservative capability set: no acceleration, single-threaded.
    ///
    /// Safe on any hardware; used when detection is unavailable or the
    /// device type is unknown.
    fn default() -> Self {
        Self {
            supports_int8_simd: false,
            supports_int4_packed: false,
            supports_vnni: false,
            supports_dp4a: false,
            supports_tensor_cores: false,
            supports_mixed_precision: false,
            max_parallel_ops: 1,
        }
    }
}
82
83impl QuantizationHardwareFeatures {
84    /// Detect hardware features for the given device
85    ///
86    /// Performs runtime detection of available hardware acceleration
87    /// features and returns a capabilities structure.
88    ///
89    /// # Arguments
90    ///
91    /// * `device` - The target device to analyze
92    ///
93    /// # Returns
94    ///
95    /// A `QuantizationHardwareFeatures` struct with detected capabilities
96    pub fn detect_for_device(device: &Device) -> Self {
97        match device.device_type() {
98            torsh_core::device::DeviceType::Cpu => Self::detect_cpu_features(),
99            torsh_core::device::DeviceType::Cuda(_) => Self::detect_cuda_features(),
100            _ => Self::default(),
101        }
102    }
103
104    /// Detect CPU-specific quantization features
105    fn detect_cpu_features() -> Self {
106        Self {
107            supports_int8_simd: Self::detect_int8_simd(),
108            supports_int4_packed: true, // Generally available through software
109            supports_vnni: Self::detect_vnni(),
110            supports_dp4a: false,         // DP4A is CUDA-specific
111            supports_tensor_cores: false, // Tensor cores are GPU-specific
112            supports_mixed_precision: true,
113            max_parallel_ops: std::thread::available_parallelism()
114                .map(|n| n.get())
115                .unwrap_or(1),
116        }
117    }
118
119    /// Detect CUDA GPU quantization features
120    fn detect_cuda_features() -> Self {
121        Self {
122            supports_int8_simd: true, // CUDA has vectorized INT8 support
123            supports_int4_packed: true,
124            supports_vnni: false, // VNNI is Intel-specific
125            supports_dp4a: Self::detect_dp4a(),
126            supports_tensor_cores: Self::detect_tensor_cores(),
127            supports_mixed_precision: true,
128            max_parallel_ops: 1024, // Many CUDA cores available
129        }
130    }
131
132    /// Detect Intel VNNI (Vector Neural Network Instructions) support
133    ///
134    /// VNNI instructions accelerate neural network computations by providing
135    /// hardware support for common operations like dot products with INT8 data.
136    fn detect_vnni() -> bool {
137        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
138        {
139            // Check for VNNI support via CPUID
140            // Note: This checks for AVX512-VNNI specifically
141            // AVX-VNNI support would require different detection
142            is_x86_feature_detected!("avx512vnni")
143        }
144        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
145        {
146            false
147        }
148    }
149
150    /// Detect general INT8 SIMD support
151    ///
152    /// Checks for hardware support of vectorized INT8 operations,
153    /// which are crucial for efficient quantized computation.
154    fn detect_int8_simd() -> bool {
155        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
156        {
157            // Most modern x86 CPUs support some form of INT8 SIMD
158            is_x86_feature_detected!("sse2") || is_x86_feature_detected!("avx2")
159        }
160        #[cfg(target_arch = "aarch64")]
161        {
162            // ARM NEON supports INT8 operations
163            std::arch::is_aarch64_feature_detected!("neon")
164        }
165        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
166        {
167            false
168        }
169    }
170
171    /// Detect NVIDIA DP4A support
172    ///
173    /// DP4A (4-element dot product and accumulate) is available on
174    /// modern NVIDIA GPUs and provides efficient INT8 matrix operations.
175    fn detect_dp4a() -> bool {
176        // In a real implementation, this would query CUDA device properties
177        // For now, assume modern CUDA GPUs have DP4A support
178        true
179    }
180
181    /// Detect tensor core support
182    ///
183    /// Tensor cores provide specialized acceleration for mixed-precision
184    /// and quantized neural network operations on modern GPUs.
185    fn detect_tensor_cores() -> bool {
186        // In a real implementation, this would check GPU architecture
187        // (Volta, Turing, Ampere, etc.) for tensor core availability
188        true
189    }
190
191    /// Check if the hardware supports a specific quantization data type efficiently
192    ///
193    /// # Arguments
194    ///
195    /// * `dtype` - The quantization data type to check
196    ///
197    /// # Returns
198    ///
199    /// `true` if the hardware can efficiently process this data type
200    pub fn supports_dtype_efficiently(&self, dtype: &crate::quantization::QuantizedDType) -> bool {
201        use crate::quantization::QuantizedDType;
202
203        match dtype {
204            QuantizedDType::Int8 | QuantizedDType::UInt8 => self.supports_int8_simd,
205            QuantizedDType::Int4 | QuantizedDType::UInt4 => self.supports_int4_packed,
206            QuantizedDType::Binary => self.supports_int8_simd, // Can use SIMD for binary ops
207            QuantizedDType::Int16 | QuantizedDType::UInt16 => true, // Generally well supported
208            QuantizedDType::Mixed(_) => self.supports_mixed_precision,
209        }
210    }
211
212    /// Get the optimal block size for parallel operations
213    ///
214    /// Returns the recommended block size for batching operations
215    /// based on hardware characteristics and parallelism capabilities.
216    pub fn optimal_block_size(&self) -> usize {
217        if self.supports_tensor_cores {
218            // Tensor cores work well with larger blocks
219            256
220        } else if self.supports_int8_simd {
221            // SIMD operations benefit from medium-sized blocks
222            64
223        } else {
224            // Conservative block size for scalar operations
225            16
226        }
227    }
228
229    /// Get the performance preference ranking for quantization schemes
230    ///
231    /// Returns quantization schemes ordered by expected performance
232    /// on this hardware, with the fastest schemes first.
233    pub fn performance_ranking(&self) -> Vec<crate::quantization::QuantizationScheme> {
234        use crate::quantization::QuantizationScheme;
235
236        let mut schemes = vec![
237            QuantizationScheme::Symmetric,   // Often fastest due to no zero point
238            QuantizationScheme::Linear,      // Standard implementation
239            QuantizationScheme::Asymmetric,  // Requires zero point handling
240            QuantizationScheme::ChannelWise, // More complex but better accuracy
241            QuantizationScheme::BlockWise,   // Complex memory access patterns
242            QuantizationScheme::Logarithmic, // Requires expensive log operations
243        ];
244
245        // Adjust ranking based on hardware capabilities
246        if self.supports_vnni || self.supports_dp4a {
247            // Hardware-accelerated schemes can handle complexity better
248            schemes.swap(2, 3); // Prefer channel-wise over asymmetric
249        }
250
251        schemes
252    }
253}
254
/// Vectorized (SIMD) implementations of quantization primitives
///
/// Bundles the CPU's SIMD availability status with the detected register
/// width so quantize/dequantize kernels can batch their work appropriately.
#[derive(Debug, Clone)]
pub struct SimdQuantizationOps {
    /// True when the host CPU offers usable SIMD instructions.
    simd_available: bool,
    /// Preferred SIMD register width in bytes for this machine.
    vector_width: usize,
}
266
267impl SimdQuantizationOps {
268    /// Create new SIMD quantization operations
269    pub fn new() -> Self {
270        Self {
271            simd_available: QuantizationHardwareFeatures::detect_int8_simd(),
272            vector_width: Self::detect_vector_width(),
273        }
274    }
275
276    /// Detect optimal vector width for SIMD operations
277    fn detect_vector_width() -> usize {
278        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
279        {
280            if is_x86_feature_detected!("avx512f") {
281                64 // AVX-512 can handle 64 bytes (512 bits)
282            } else if is_x86_feature_detected!("avx2") {
283                32 // AVX2 can handle 32 bytes (256 bits)
284            } else if is_x86_feature_detected!("sse2") {
285                16 // SSE2 can handle 16 bytes (128 bits)
286            } else {
287                4 // Fallback to scalar with some vectorization
288            }
289        }
290        #[cfg(target_arch = "aarch64")]
291        {
292            16 // ARM NEON typically handles 128-bit vectors
293        }
294        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
295        {
296            4 // Conservative fallback
297        }
298    }
299
300    /// SIMD-accelerated f32 to u8 quantization
301    ///
302    /// Uses vectorized operations to quantize multiple floating-point values
303    /// to 8-bit unsigned integers simultaneously.
304    pub fn quantize_f32_to_u8_simd(
305        &self,
306        input: &[f32],
307        scale: f32,
308        zero_point: f32,
309    ) -> BackendResult<Vec<u8>> {
310        if !self.simd_available {
311            return Err(TorshError::BackendError("SIMD not available".to_string()));
312        }
313
314        let mut output = Vec::with_capacity(input.len());
315        let inv_scale = 1.0 / scale;
316
317        // Process in chunks optimized for the vector width
318        let chunk_size = self.vector_width / 4; // 4 bytes per f32
319
320        for chunk in input.chunks(chunk_size) {
321            // In a real implementation, this would use platform-specific SIMD intrinsics
322            // For portability, we use a vectorized approach that compilers can optimize
323            for &val in chunk {
324                let quantized = (val * inv_scale + zero_point).round().clamp(0.0, 255.0) as u8;
325                output.push(quantized);
326            }
327        }
328
329        Ok(output)
330    }
331
332    /// SIMD-accelerated u8 to f32 dequantization
333    pub fn dequantize_u8_to_f32_simd(
334        &self,
335        input: &[u8],
336        scale: f32,
337        zero_point: f32,
338    ) -> BackendResult<Vec<f32>> {
339        if !self.simd_available {
340            return Err(TorshError::BackendError("SIMD not available".to_string()));
341        }
342
343        let mut output = Vec::with_capacity(input.len());
344        let chunk_size = self.vector_width; // 1 byte per u8
345
346        for chunk in input.chunks(chunk_size) {
347            for &val in chunk {
348                let dequantized = (val as f32 - zero_point) * scale;
349                output.push(dequantized);
350            }
351        }
352
353        Ok(output)
354    }
355
356    /// SIMD-accelerated INT8 vector addition
357    pub fn add_int8_simd(&self, a: &[i8], b: &[i8]) -> BackendResult<Vec<i8>> {
358        if !self.simd_available || a.len() != b.len() {
359            return Err(TorshError::BackendError(
360                "Invalid input for SIMD addition".to_string(),
361            ));
362        }
363
364        let mut result = Vec::with_capacity(a.len());
365        let chunk_size = self.vector_width;
366
367        for (a_chunk, b_chunk) in a.chunks(chunk_size).zip(b.chunks(chunk_size)) {
368            for (&a_val, &b_val) in a_chunk.iter().zip(b_chunk.iter()) {
369                let sum = (a_val as i16 + b_val as i16).clamp(-128, 127) as i8;
370                result.push(sum);
371            }
372        }
373
374        Ok(result)
375    }
376
377    /// Check if SIMD operations are available
378    pub fn is_available(&self) -> bool {
379        self.simd_available
380    }
381
382    /// Get the optimal vector width for this hardware
383    pub fn vector_width(&self) -> usize {
384        self.vector_width
385    }
386}
387
/// Memory layout preferences for quantized data
///
/// Describes how quantized buffers should be arranged in memory so the
/// available hardware acceleration can access them efficiently.
#[derive(Debug, Clone)]
pub struct QuantizedMemoryLayout {
    /// Pack sub-byte values (e.g. two INT4s per byte) instead of widening them.
    pub use_packed_layout: bool,
    /// Byte alignment that loads and stores should honor.
    pub alignment: usize,
    /// Interleave data, which suits tensor-core access patterns.
    pub use_interleaving: bool,
}
401
402impl QuantizedMemoryLayout {
403    /// Create optimal memory layout for the given hardware features
404    pub fn optimal_for_hardware(features: &QuantizationHardwareFeatures) -> Self {
405        Self {
406            use_packed_layout: features.supports_int4_packed,
407            alignment: if features.supports_int8_simd { 32 } else { 16 },
408            use_interleaving: features.supports_tensor_cores,
409        }
410    }
411
412    /// Calculate optimal stride for accessing quantized data
413    pub fn optimal_stride(&self, data_width: usize) -> usize {
414        // Align stride to hardware requirements
415        let aligned_width = (data_width + self.alignment - 1) & !(self.alignment - 1);
416        aligned_width
417    }
418
419    /// Check if the given memory layout is hardware-optimal
420    pub fn is_layout_optimal(&self, data_size: usize, stride: usize) -> bool {
421        let optimal_stride = self.optimal_stride(data_size);
422        stride >= optimal_stride && stride % self.alignment == 0
423    }
424}
425
426/// Hardware-specific performance hints
427///
428/// Provides recommendations for optimal quantization strategies
429/// based on detected hardware capabilities.
430#[derive(Debug, Clone)]
431pub struct QuantizationPerformanceHints {
432    /// Recommended quantization data types in order of preference
433    pub preferred_dtypes: Vec<crate::quantization::QuantizedDType>,
434    /// Recommended quantization schemes in order of preference
435    pub preferred_schemes: Vec<crate::quantization::QuantizationScheme>,
436    /// Optimal batch size for operations
437    pub optimal_batch_size: usize,
438    /// Whether to prefer in-place operations
439    pub prefer_inplace: bool,
440}
441
442impl QuantizationPerformanceHints {
443    /// Generate performance hints for the given hardware features
444    pub fn for_hardware(features: &QuantizationHardwareFeatures) -> Self {
445        use crate::quantization::QuantizedDType;
446
447        let mut preferred_dtypes = vec![];
448
449        // Order data types by hardware support and performance
450        if features.supports_int8_simd {
451            preferred_dtypes.extend([QuantizedDType::Int8, QuantizedDType::UInt8]);
452        }
453        if features.supports_int4_packed {
454            preferred_dtypes.extend([QuantizedDType::Int4, QuantizedDType::UInt4]);
455        }
456        if features.supports_mixed_precision {
457            preferred_dtypes.push(QuantizedDType::Mixed(vec![8, 4, 8]));
458        }
459
460        // Add remaining types
461        preferred_dtypes.extend([
462            QuantizedDType::Int16,
463            QuantizedDType::UInt16,
464            QuantizedDType::Binary,
465        ]);
466
467        // Use hardware-specific scheme ranking
468        let preferred_schemes = features.performance_ranking();
469
470        Self {
471            preferred_dtypes,
472            preferred_schemes,
473            optimal_batch_size: features.optimal_block_size(),
474            prefer_inplace: !features.supports_tensor_cores, // Tensor cores often prefer separate output
475        }
476    }
477
478    /// Get the best quantization data type for the given requirements
479    pub fn best_dtype_for_accuracy(
480        &self,
481        min_accuracy: f64,
482    ) -> Option<&crate::quantization::QuantizedDType> {
483        use crate::quantization::QuantizedDType;
484
485        // Higher bit widths generally provide better accuracy
486        for dtype in &self.preferred_dtypes {
487            let expected_accuracy = match dtype {
488                QuantizedDType::Int16 | QuantizedDType::UInt16 => 0.99,
489                QuantizedDType::Int8 | QuantizedDType::UInt8 => 0.95,
490                QuantizedDType::Int4 | QuantizedDType::UInt4 => 0.85,
491                QuantizedDType::Binary => 0.70,
492                QuantizedDType::Mixed(_) => 0.90,
493            };
494
495            if expected_accuracy >= min_accuracy {
496                return Some(dtype);
497            }
498        }
499
500        None
501    }
502
503    /// Get the best quantization scheme for the given performance requirements
504    pub fn best_scheme_for_latency(
505        &self,
506        max_latency_factor: f64,
507    ) -> Option<&crate::quantization::QuantizationScheme> {
508        use crate::quantization::QuantizationScheme;
509
510        // Different schemes have different computational complexity
511        for scheme in &self.preferred_schemes {
512            let latency_factor = match scheme {
513                QuantizationScheme::Symmetric => 1.0,
514                QuantizationScheme::Linear => 1.1,
515                QuantizationScheme::Asymmetric => 1.2,
516                QuantizationScheme::ChannelWise => 1.3,
517                QuantizationScheme::BlockWise => 1.4,
518                QuantizationScheme::Logarithmic => 2.0,
519            };
520
521            if latency_factor <= max_latency_factor {
522                return Some(scheme);
523            }
524        }
525
526        None
527    }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hardware_features_detection() {
        let features = QuantizationHardwareFeatures::default();

        // Default features should be conservative
        assert!(!features.supports_int8_simd);
        assert!(!features.supports_vnni);
        assert!(!features.supports_dp4a);
        assert!(!features.supports_tensor_cores);
        assert_eq!(features.max_parallel_ops, 1);
    }

    #[test]
    fn test_cpu_features_detection() {
        let features = QuantizationHardwareFeatures::detect_cpu_features();

        // CPU features should never include GPU-specific capabilities
        assert!(!features.supports_dp4a);
        assert!(!features.supports_tensor_cores);
        assert!(features.max_parallel_ops >= 1);
    }

    #[test]
    fn test_cuda_features_detection() {
        let features = QuantizationHardwareFeatures::detect_cuda_features();

        // CUDA features should include GPU-specific capabilities
        assert!(features.supports_int8_simd);
        assert!(!features.supports_vnni); // VNNI is Intel-specific
        assert!(features.max_parallel_ops > 1);
    }

    #[test]
    fn test_device_feature_detection() {
        let cpu_device = Device::cpu().unwrap();
        let cpu_features = QuantizationHardwareFeatures::detect_for_device(&cpu_device);

        // Should detect CPU-appropriate features
        assert!(!cpu_features.supports_dp4a);
        assert!(!cpu_features.supports_tensor_cores);
    }

    #[test]
    fn test_dtype_support_check() {
        use crate::quantization::QuantizedDType;

        // Struct-update syntax instead of mutating a fresh default
        // (clippy::field_reassign_with_default).
        let features = QuantizationHardwareFeatures {
            supports_int8_simd: true,
            supports_int4_packed: true,
            ..Default::default()
        };

        assert!(features.supports_dtype_efficiently(&QuantizedDType::Int8));
        assert!(features.supports_dtype_efficiently(&QuantizedDType::Int4));
        assert!(!features.supports_dtype_efficiently(&QuantizedDType::Mixed(vec![8, 4])));
    }

    #[test]
    fn test_optimal_block_size() {
        // Tensor cores dominate the block-size decision
        let mut features = QuantizationHardwareFeatures {
            supports_tensor_cores: true,
            ..Default::default()
        };
        assert_eq!(features.optimal_block_size(), 256);

        // SIMD-only hardware gets a medium block
        features.supports_tensor_cores = false;
        features.supports_int8_simd = true;
        assert_eq!(features.optimal_block_size(), 64);

        // Scalar hardware gets the conservative block
        features.supports_int8_simd = false;
        assert_eq!(features.optimal_block_size(), 16);
    }

    #[test]
    fn test_performance_ranking() {
        let features = QuantizationHardwareFeatures::default();
        let ranking = features.performance_ranking();

        // Should have all schemes
        assert_eq!(ranking.len(), 6);

        // Symmetric should typically be first (fastest)
        use crate::quantization::QuantizationScheme;
        assert_eq!(ranking[0], QuantizationScheme::Symmetric);
    }

    #[test]
    fn test_simd_ops_creation() {
        let simd_ops = SimdQuantizationOps::new();

        // Should detect SIMD availability appropriately for the platform
        assert!(simd_ops.vector_width() >= 4);
    }

    #[test]
    fn test_vector_width_detection() {
        let width = SimdQuantizationOps::detect_vector_width();

        // Should return a reasonable vector width
        assert!(width >= 4);
        assert!(width <= 64);

        // Should be a power of 2 or multiple of 4
        assert!(width % 4 == 0);
    }

    #[test]
    fn test_memory_layout_optimization() {
        let features = QuantizationHardwareFeatures::default();
        let layout = QuantizedMemoryLayout::optimal_for_hardware(&features);

        assert!(layout.alignment >= 16);
        assert!(!layout.use_packed_layout); // Default doesn't support packed
    }

    #[test]
    fn test_optimal_stride_calculation() {
        let layout = QuantizedMemoryLayout {
            use_packed_layout: false,
            alignment: 32,
            use_interleaving: false,
        };

        // Test stride calculation with different data widths
        assert_eq!(layout.optimal_stride(10), 32); // Rounds up to alignment
        assert_eq!(layout.optimal_stride(32), 32); // Already aligned
        assert_eq!(layout.optimal_stride(50), 64); // Rounds up to next alignment
    }

    #[test]
    fn test_layout_optimality_check() {
        let layout = QuantizedMemoryLayout {
            use_packed_layout: false,
            alignment: 16,
            use_interleaving: false,
        };

        assert!(layout.is_layout_optimal(10, 16)); // Properly aligned
        assert!(layout.is_layout_optimal(10, 32)); // Over-aligned (OK)
        assert!(!layout.is_layout_optimal(10, 15)); // Under-aligned
        assert!(!layout.is_layout_optimal(10, 17)); // Misaligned
    }

    #[test]
    fn test_performance_hints_generation() {
        let features = QuantizationHardwareFeatures {
            supports_int8_simd: true,
            supports_int4_packed: true,
            supports_mixed_precision: true,
            ..Default::default()
        };

        let hints = QuantizationPerformanceHints::for_hardware(&features);

        // Should have preferences based on hardware support
        assert!(!hints.preferred_dtypes.is_empty());
        assert!(!hints.preferred_schemes.is_empty());
        assert!(hints.optimal_batch_size > 0);
    }

    #[test]
    fn test_best_dtype_for_accuracy() {
        let hints = QuantizationPerformanceHints {
            preferred_dtypes: vec![
                crate::quantization::QuantizedDType::Int8,
                crate::quantization::QuantizedDType::Int4,
                crate::quantization::QuantizedDType::Binary,
            ],
            preferred_schemes: vec![],
            optimal_batch_size: 64,
            prefer_inplace: false,
        };

        // Should return INT8 for high accuracy requirements
        let dtype = hints.best_dtype_for_accuracy(0.90);
        assert!(dtype.is_some());

        // Should return None for impossible accuracy requirements
        // (Int8 caps out at the 0.95 estimate; 0.99 needs Int16 which is absent)
        let dtype = hints.best_dtype_for_accuracy(0.99);
        assert!(dtype.is_none());
    }

    #[test]
    fn test_best_scheme_for_latency() {
        use crate::quantization::QuantizationScheme;

        let hints = QuantizationPerformanceHints {
            preferred_dtypes: vec![],
            preferred_schemes: vec![
                QuantizationScheme::Symmetric,
                QuantizationScheme::Linear,
                QuantizationScheme::Asymmetric,
            ],
            optimal_batch_size: 64,
            prefer_inplace: false,
        };

        // Should return fastest scheme for strict latency requirements
        let scheme = hints.best_scheme_for_latency(1.1);
        assert!(scheme.is_some());

        // Should return None for impossible latency requirements
        let scheme = hints.best_scheme_for_latency(0.5);
        assert!(scheme.is_none());
    }

    #[test]
    fn test_simd_quantization_operations() {
        let simd_ops = SimdQuantizationOps::new();

        if simd_ops.is_available() {
            let input = vec![1.0, 2.0, 3.0, 4.0];
            let result = simd_ops.quantize_f32_to_u8_simd(&input, 1.0, 0.0);

            if let Ok(quantized) = result {
                assert_eq!(quantized.len(), input.len());
                // Values should be approximately correct
                assert!(quantized[0] <= 2); // 1.0 rounded
                assert!(quantized[3] <= 5); // 4.0 rounded
            }
        }
    }

    #[test]
    fn test_simd_dequantization_operations() {
        let simd_ops = SimdQuantizationOps::new();

        if simd_ops.is_available() {
            let input = vec![1u8, 2u8, 3u8, 4u8];
            let result = simd_ops.dequantize_u8_to_f32_simd(&input, 1.0, 0.0);

            if let Ok(dequantized) = result {
                assert_eq!(dequantized.len(), input.len());
                // Values should match input (scale=1, zero_point=0)
                assert!((dequantized[0] - 1.0).abs() < 0.001);
                assert!((dequantized[3] - 4.0).abs() < 0.001);
            }
        }
    }

    #[test]
    fn test_simd_int8_addition() {
        let simd_ops = SimdQuantizationOps::new();

        if simd_ops.is_available() {
            let a = vec![10i8, 20i8, 30i8, 40i8];
            let b = vec![5i8, 10i8, 15i8, 20i8];
            let result = simd_ops.add_int8_simd(&a, &b);

            if let Ok(sum) = result {
                assert_eq!(sum.len(), a.len());
                assert_eq!(sum[0], 15i8);
                assert_eq!(sum[1], 30i8);
                assert_eq!(sum[2], 45i8);
                assert_eq!(sum[3], 60i8);
            }
        }
    }
}