1#![allow(dead_code)]
9use crate::{BackendResult, Device};
10use torsh_core::error::TorshError;
11
12#[cfg(not(feature = "std"))]
13use alloc::{boxed::Box, string::String, vec::Vec};
14
/// Hardware capabilities relevant to quantized-tensor execution on a device.
///
/// All flags are conservative: `false` means "not detected / not assumed",
/// and the `Default` impl reports no acceleration at all.
#[derive(Debug, Clone)]
pub struct QuantizationHardwareFeatures {
    /// SIMD int8 arithmetic is available (SSE2/AVX2 on x86, NEON on aarch64).
    pub supports_int8_simd: bool,

    /// Packed (two-per-byte) int4 storage/arithmetic is supported.
    pub supports_int4_packed: bool,

    /// x86 AVX-512 VNNI dot-product instructions are available (CPU only).
    pub supports_vnni: bool,

    /// CUDA DP4A 4-element int8 dot-product is available (GPU only).
    pub supports_dp4a: bool,

    /// Device has tensor cores (GPU only).
    pub supports_tensor_cores: bool,

    /// Mixed-precision quantization (e.g. per-layer bit widths) is supported.
    pub supports_mixed_precision: bool,

    /// Upper bound on operations that can usefully run in parallel
    /// (CPU: thread count; GPU: a placeholder constant — see detection code).
    pub max_parallel_ops: usize,
}
64
65impl Default for QuantizationHardwareFeatures {
66 fn default() -> Self {
71 Self {
72 supports_int8_simd: false,
73 supports_int4_packed: false,
74 supports_vnni: false,
75 supports_dp4a: false,
76 supports_tensor_cores: false,
77 supports_mixed_precision: false,
78 max_parallel_ops: 1,
79 }
80 }
81}
82
83impl QuantizationHardwareFeatures {
84 pub fn detect_for_device(device: &Device) -> Self {
97 match device.device_type() {
98 torsh_core::device::DeviceType::Cpu => Self::detect_cpu_features(),
99 torsh_core::device::DeviceType::Cuda(_) => Self::detect_cuda_features(),
100 _ => Self::default(),
101 }
102 }
103
104 fn detect_cpu_features() -> Self {
106 Self {
107 supports_int8_simd: Self::detect_int8_simd(),
108 supports_int4_packed: true, supports_vnni: Self::detect_vnni(),
110 supports_dp4a: false, supports_tensor_cores: false, supports_mixed_precision: true,
113 max_parallel_ops: std::thread::available_parallelism()
114 .map(|n| n.get())
115 .unwrap_or(1),
116 }
117 }
118
119 fn detect_cuda_features() -> Self {
121 Self {
122 supports_int8_simd: true, supports_int4_packed: true,
124 supports_vnni: false, supports_dp4a: Self::detect_dp4a(),
126 supports_tensor_cores: Self::detect_tensor_cores(),
127 supports_mixed_precision: true,
128 max_parallel_ops: 1024, }
130 }
131
132 fn detect_vnni() -> bool {
137 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
138 {
139 is_x86_feature_detected!("avx512vnni")
143 }
144 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
145 {
146 false
147 }
148 }
149
150 fn detect_int8_simd() -> bool {
155 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
156 {
157 is_x86_feature_detected!("sse2") || is_x86_feature_detected!("avx2")
159 }
160 #[cfg(target_arch = "aarch64")]
161 {
162 std::arch::is_aarch64_feature_detected!("neon")
164 }
165 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
166 {
167 false
168 }
169 }
170
171 fn detect_dp4a() -> bool {
176 true
179 }
180
181 fn detect_tensor_cores() -> bool {
186 true
189 }
190
191 pub fn supports_dtype_efficiently(&self, dtype: &crate::quantization::QuantizedDType) -> bool {
201 use crate::quantization::QuantizedDType;
202
203 match dtype {
204 QuantizedDType::Int8 | QuantizedDType::UInt8 => self.supports_int8_simd,
205 QuantizedDType::Int4 | QuantizedDType::UInt4 => self.supports_int4_packed,
206 QuantizedDType::Binary => self.supports_int8_simd, QuantizedDType::Int16 | QuantizedDType::UInt16 => true, QuantizedDType::Mixed(_) => self.supports_mixed_precision,
209 }
210 }
211
212 pub fn optimal_block_size(&self) -> usize {
217 if self.supports_tensor_cores {
218 256
220 } else if self.supports_int8_simd {
221 64
223 } else {
224 16
226 }
227 }
228
229 pub fn performance_ranking(&self) -> Vec<crate::quantization::QuantizationScheme> {
234 use crate::quantization::QuantizationScheme;
235
236 let mut schemes = vec![
237 QuantizationScheme::Symmetric, QuantizationScheme::Linear, QuantizationScheme::Asymmetric, QuantizationScheme::ChannelWise, QuantizationScheme::BlockWise, QuantizationScheme::Logarithmic, ];
244
245 if self.supports_vnni || self.supports_dp4a {
247 schemes.swap(2, 3); }
250
251 schemes
252 }
253}
254
/// Helper for quantization arithmetic gated on detected CPU SIMD support.
///
/// NOTE(review): the operations below currently execute scalar loops even
/// when SIMD is reported available — the flag gates availability only.
#[derive(Debug, Clone)]
pub struct SimdQuantizationOps {
    // True when the host CPU exposes int8-capable SIMD
    // (see QuantizationHardwareFeatures::detect_int8_simd).
    simd_available: bool,
    // Widest SIMD register width in bytes; 4 when no SIMD was detected.
    vector_width: usize,
}
266
267impl SimdQuantizationOps {
268 pub fn new() -> Self {
270 Self {
271 simd_available: QuantizationHardwareFeatures::detect_int8_simd(),
272 vector_width: Self::detect_vector_width(),
273 }
274 }
275
276 fn detect_vector_width() -> usize {
278 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
279 {
280 if is_x86_feature_detected!("avx512f") {
281 64 } else if is_x86_feature_detected!("avx2") {
283 32 } else if is_x86_feature_detected!("sse2") {
285 16 } else {
287 4 }
289 }
290 #[cfg(target_arch = "aarch64")]
291 {
292 16 }
294 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
295 {
296 4 }
298 }
299
300 pub fn quantize_f32_to_u8_simd(
305 &self,
306 input: &[f32],
307 scale: f32,
308 zero_point: f32,
309 ) -> BackendResult<Vec<u8>> {
310 if !self.simd_available {
311 return Err(TorshError::BackendError("SIMD not available".to_string()));
312 }
313
314 let mut output = Vec::with_capacity(input.len());
315 let inv_scale = 1.0 / scale;
316
317 let chunk_size = self.vector_width / 4; for chunk in input.chunks(chunk_size) {
321 for &val in chunk {
324 let quantized = (val * inv_scale + zero_point).round().clamp(0.0, 255.0) as u8;
325 output.push(quantized);
326 }
327 }
328
329 Ok(output)
330 }
331
332 pub fn dequantize_u8_to_f32_simd(
334 &self,
335 input: &[u8],
336 scale: f32,
337 zero_point: f32,
338 ) -> BackendResult<Vec<f32>> {
339 if !self.simd_available {
340 return Err(TorshError::BackendError("SIMD not available".to_string()));
341 }
342
343 let mut output = Vec::with_capacity(input.len());
344 let chunk_size = self.vector_width; for chunk in input.chunks(chunk_size) {
347 for &val in chunk {
348 let dequantized = (val as f32 - zero_point) * scale;
349 output.push(dequantized);
350 }
351 }
352
353 Ok(output)
354 }
355
356 pub fn add_int8_simd(&self, a: &[i8], b: &[i8]) -> BackendResult<Vec<i8>> {
358 if !self.simd_available || a.len() != b.len() {
359 return Err(TorshError::BackendError(
360 "Invalid input for SIMD addition".to_string(),
361 ));
362 }
363
364 let mut result = Vec::with_capacity(a.len());
365 let chunk_size = self.vector_width;
366
367 for (a_chunk, b_chunk) in a.chunks(chunk_size).zip(b.chunks(chunk_size)) {
368 for (&a_val, &b_val) in a_chunk.iter().zip(b_chunk.iter()) {
369 let sum = (a_val as i16 + b_val as i16).clamp(-128, 127) as i8;
370 result.push(sum);
371 }
372 }
373
374 Ok(result)
375 }
376
377 pub fn is_available(&self) -> bool {
379 self.simd_available
380 }
381
382 pub fn vector_width(&self) -> usize {
384 self.vector_width
385 }
386}
387
/// Memory-layout preferences for quantized tensors on a given device.
#[derive(Debug, Clone)]
pub struct QuantizedMemoryLayout {
    /// Store sub-byte values packed (e.g. two int4 per byte).
    pub use_packed_layout: bool,
    /// Required byte alignment for buffers/strides.
    /// NOTE(review): the stride math assumes this is a power of two.
    pub alignment: usize,
    /// Interleave data for tensor-core-friendly access patterns.
    pub use_interleaving: bool,
}
401
402impl QuantizedMemoryLayout {
403 pub fn optimal_for_hardware(features: &QuantizationHardwareFeatures) -> Self {
405 Self {
406 use_packed_layout: features.supports_int4_packed,
407 alignment: if features.supports_int8_simd { 32 } else { 16 },
408 use_interleaving: features.supports_tensor_cores,
409 }
410 }
411
412 pub fn optimal_stride(&self, data_width: usize) -> usize {
414 let aligned_width = (data_width + self.alignment - 1) & !(self.alignment - 1);
416 aligned_width
417 }
418
419 pub fn is_layout_optimal(&self, data_size: usize, stride: usize) -> bool {
421 let optimal_stride = self.optimal_stride(data_size);
422 stride >= optimal_stride && stride % self.alignment == 0
423 }
424}
425
/// Tuning hints derived from a device's quantization hardware features.
#[derive(Debug, Clone)]
pub struct QuantizationPerformanceHints {
    /// Quantized dtypes in order of preference (most efficient first).
    pub preferred_dtypes: Vec<crate::quantization::QuantizedDType>,
    /// Quantization schemes in order of preference (fastest first).
    pub preferred_schemes: Vec<crate::quantization::QuantizationScheme>,
    /// Suggested batch/block size for quantized kernels.
    pub optimal_batch_size: usize,
    /// Prefer in-place operations (true when no tensor cores are present).
    pub prefer_inplace: bool,
}
441
442impl QuantizationPerformanceHints {
443 pub fn for_hardware(features: &QuantizationHardwareFeatures) -> Self {
445 use crate::quantization::QuantizedDType;
446
447 let mut preferred_dtypes = vec![];
448
449 if features.supports_int8_simd {
451 preferred_dtypes.extend([QuantizedDType::Int8, QuantizedDType::UInt8]);
452 }
453 if features.supports_int4_packed {
454 preferred_dtypes.extend([QuantizedDType::Int4, QuantizedDType::UInt4]);
455 }
456 if features.supports_mixed_precision {
457 preferred_dtypes.push(QuantizedDType::Mixed(vec![8, 4, 8]));
458 }
459
460 preferred_dtypes.extend([
462 QuantizedDType::Int16,
463 QuantizedDType::UInt16,
464 QuantizedDType::Binary,
465 ]);
466
467 let preferred_schemes = features.performance_ranking();
469
470 Self {
471 preferred_dtypes,
472 preferred_schemes,
473 optimal_batch_size: features.optimal_block_size(),
474 prefer_inplace: !features.supports_tensor_cores, }
476 }
477
478 pub fn best_dtype_for_accuracy(
480 &self,
481 min_accuracy: f64,
482 ) -> Option<&crate::quantization::QuantizedDType> {
483 use crate::quantization::QuantizedDType;
484
485 for dtype in &self.preferred_dtypes {
487 let expected_accuracy = match dtype {
488 QuantizedDType::Int16 | QuantizedDType::UInt16 => 0.99,
489 QuantizedDType::Int8 | QuantizedDType::UInt8 => 0.95,
490 QuantizedDType::Int4 | QuantizedDType::UInt4 => 0.85,
491 QuantizedDType::Binary => 0.70,
492 QuantizedDType::Mixed(_) => 0.90,
493 };
494
495 if expected_accuracy >= min_accuracy {
496 return Some(dtype);
497 }
498 }
499
500 None
501 }
502
503 pub fn best_scheme_for_latency(
505 &self,
506 max_latency_factor: f64,
507 ) -> Option<&crate::quantization::QuantizationScheme> {
508 use crate::quantization::QuantizationScheme;
509
510 for scheme in &self.preferred_schemes {
512 let latency_factor = match scheme {
513 QuantizationScheme::Symmetric => 1.0,
514 QuantizationScheme::Linear => 1.1,
515 QuantizationScheme::Asymmetric => 1.2,
516 QuantizationScheme::ChannelWise => 1.3,
517 QuantizationScheme::BlockWise => 1.4,
518 QuantizationScheme::Logarithmic => 2.0,
519 };
520
521 if latency_factor <= max_latency_factor {
522 return Some(scheme);
523 }
524 }
525
526 None
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hardware_features_detection() {
        // The Default impl must be maximally conservative.
        let features = QuantizationHardwareFeatures::default();
        assert!(!features.supports_int8_simd);
        assert!(!features.supports_vnni);
        assert!(!features.supports_dp4a);
        assert!(!features.supports_tensor_cores);
        assert_eq!(features.max_parallel_ops, 1);
    }

    #[test]
    fn test_cpu_features_detection() {
        let features = QuantizationHardwareFeatures::detect_cpu_features();
        // GPU-only capabilities must never be reported for the CPU.
        assert!(!features.supports_dp4a);
        assert!(!features.supports_tensor_cores);
        assert!(features.max_parallel_ops >= 1);
    }

    #[test]
    fn test_cuda_features_detection() {
        let features = QuantizationHardwareFeatures::detect_cuda_features();
        assert!(features.supports_int8_simd);
        // VNNI is a CPU extension and must stay off for CUDA.
        assert!(!features.supports_vnni);
        assert!(features.max_parallel_ops > 1);
    }

    #[test]
    fn test_device_feature_detection() {
        let cpu_device = Device::cpu().unwrap();
        let cpu_features = QuantizationHardwareFeatures::detect_for_device(&cpu_device);
        assert!(!cpu_features.supports_dp4a);
        assert!(!cpu_features.supports_tensor_cores);
    }

    #[test]
    fn test_dtype_support_check() {
        use crate::quantization::QuantizedDType;

        let features = QuantizationHardwareFeatures {
            supports_int8_simd: true,
            supports_int4_packed: true,
            ..Default::default()
        };
        assert!(features.supports_dtype_efficiently(&QuantizedDType::Int8));
        assert!(features.supports_dtype_efficiently(&QuantizedDType::Int4));
        // Mixed precision stays disabled by default.
        assert!(!features.supports_dtype_efficiently(&QuantizedDType::Mixed(vec![8, 4])));
    }

    #[test]
    fn test_optimal_block_size() {
        let with_tensor_cores = QuantizationHardwareFeatures {
            supports_tensor_cores: true,
            ..Default::default()
        };
        assert_eq!(with_tensor_cores.optimal_block_size(), 256);

        let with_simd = QuantizationHardwareFeatures {
            supports_int8_simd: true,
            ..Default::default()
        };
        assert_eq!(with_simd.optimal_block_size(), 64);

        assert_eq!(QuantizationHardwareFeatures::default().optimal_block_size(), 16);
    }

    #[test]
    fn test_performance_ranking() {
        use crate::quantization::QuantizationScheme;

        let ranking = QuantizationHardwareFeatures::default().performance_ranking();
        assert_eq!(ranking.len(), 6);
        // Without dot-product acceleration, Symmetric stays fastest.
        assert_eq!(ranking[0], QuantizationScheme::Symmetric);
    }

    #[test]
    fn test_simd_ops_creation() {
        // Even the scalar fallback reports at least a 4-byte width.
        assert!(SimdQuantizationOps::new().vector_width() >= 4);
    }

    #[test]
    fn test_vector_width_detection() {
        let width = SimdQuantizationOps::detect_vector_width();
        assert!(width >= 4);
        assert!(width <= 64);
        assert!(width % 4 == 0);
    }

    #[test]
    fn test_memory_layout_optimization() {
        let layout =
            QuantizedMemoryLayout::optimal_for_hardware(&QuantizationHardwareFeatures::default());
        assert!(layout.alignment >= 16);
        assert!(!layout.use_packed_layout);
    }

    #[test]
    fn test_optimal_stride_calculation() {
        let layout = QuantizedMemoryLayout {
            use_packed_layout: false,
            alignment: 32,
            use_interleaving: false,
        };
        assert_eq!(layout.optimal_stride(10), 32); // rounds up
        assert_eq!(layout.optimal_stride(32), 32); // already aligned
        assert_eq!(layout.optimal_stride(50), 64); // next multiple
    }

    #[test]
    fn test_layout_optimality_check() {
        let layout = QuantizedMemoryLayout {
            use_packed_layout: false,
            alignment: 16,
            use_interleaving: false,
        };
        assert!(layout.is_layout_optimal(10, 16));
        assert!(layout.is_layout_optimal(10, 32));
        assert!(!layout.is_layout_optimal(10, 15)); // under-sized stride
        assert!(!layout.is_layout_optimal(10, 17)); // misaligned stride
    }

    #[test]
    fn test_performance_hints_generation() {
        let features = QuantizationHardwareFeatures {
            supports_int8_simd: true,
            supports_int4_packed: true,
            supports_mixed_precision: true,
            ..Default::default()
        };
        let hints = QuantizationPerformanceHints::for_hardware(&features);
        assert!(!hints.preferred_dtypes.is_empty());
        assert!(!hints.preferred_schemes.is_empty());
        assert!(hints.optimal_batch_size > 0);
    }

    #[test]
    fn test_best_dtype_for_accuracy() {
        let hints = QuantizationPerformanceHints {
            preferred_dtypes: vec![
                crate::quantization::QuantizedDType::Int8,
                crate::quantization::QuantizedDType::Int4,
                crate::quantization::QuantizedDType::Binary,
            ],
            preferred_schemes: vec![],
            optimal_batch_size: 64,
            prefer_inplace: false,
        };
        // Int8 (0.95) satisfies a 0.90 accuracy floor.
        assert!(hints.best_dtype_for_accuracy(0.90).is_some());
        // Nothing in this list reaches 0.99.
        assert!(hints.best_dtype_for_accuracy(0.99).is_none());
    }

    #[test]
    fn test_best_scheme_for_latency() {
        use crate::quantization::QuantizationScheme;

        let hints = QuantizationPerformanceHints {
            preferred_dtypes: vec![],
            preferred_schemes: vec![
                QuantizationScheme::Symmetric,
                QuantizationScheme::Linear,
                QuantizationScheme::Asymmetric,
            ],
            optimal_batch_size: 64,
            prefer_inplace: false,
        };
        assert!(hints.best_scheme_for_latency(1.1).is_some());
        // No scheme beats a 0.5x latency budget.
        assert!(hints.best_scheme_for_latency(0.5).is_none());
    }

    #[test]
    fn test_simd_quantization_operations() {
        let simd_ops = SimdQuantizationOps::new();
        if simd_ops.is_available() {
            let input = vec![1.0, 2.0, 3.0, 4.0];
            if let Ok(quantized) = simd_ops.quantize_f32_to_u8_simd(&input, 1.0, 0.0) {
                assert_eq!(quantized.len(), input.len());
                assert!(quantized[0] <= 2);
                assert!(quantized[3] <= 5);
            }
        }
    }

    #[test]
    fn test_simd_dequantization_operations() {
        let simd_ops = SimdQuantizationOps::new();
        if simd_ops.is_available() {
            let input = vec![1u8, 2u8, 3u8, 4u8];
            if let Ok(dequantized) = simd_ops.dequantize_u8_to_f32_simd(&input, 1.0, 0.0) {
                assert_eq!(dequantized.len(), input.len());
                assert!((dequantized[0] - 1.0).abs() < 0.001);
                assert!((dequantized[3] - 4.0).abs() < 0.001);
            }
        }
    }

    #[test]
    fn test_simd_int8_addition() {
        let simd_ops = SimdQuantizationOps::new();
        if simd_ops.is_available() {
            let a = vec![10i8, 20i8, 30i8, 40i8];
            let b = vec![5i8, 10i8, 15i8, 20i8];
            if let Ok(sum) = simd_ops.add_int8_simd(&a, &b) {
                assert_eq!(sum.len(), a.len());
                assert_eq!(sum, vec![15i8, 30, 45, 60]);
            }
        }
    }
}