1use crate::{
8 device_info::{MobileDeviceInfo, PerformanceTier},
9 optimization::memory_pool::MobileMemoryPool,
10 Result,
11};
12use half::f16;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::{Arc, Mutex};
16use trustformers_core::errors::{invalid_config, invalid_input};
17use trustformers_core::Tensor;
18
/// Mobile-oriented quantization engine: converts model weights to lower
/// precision (INT4/INT8/FP16 and mixed schemes) and caches the result per
/// model id.
pub struct MobileQuantizationEngine {
    /// User-supplied quantization settings (target precision, granularity, ...).
    config: QuantizationConfig,
    /// Detected device capabilities used for hardware-aware decisions.
    device_info: MobileDeviceInfo,
    /// Optional calibration samples; refine the quality estimate when present.
    calibration_data: Option<CalibrationDataset>,
    /// Cache of quantized models keyed by model id, shared across threads.
    quantization_cache: Arc<Mutex<HashMap<String, QuantizedModel>>>,
    /// Optional shared memory pool (stored but not read elsewhere in this file
    /// — presumably used by other modules; confirm).
    memory_pool: Option<Arc<MobileMemoryPool>>,
}
27
/// Tunable settings controlling how models are quantized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    /// Requested precision; `DYNAMIC` lets the engine pick per device tier.
    pub target_precision: MobilePrecision,
    /// Allow mixing bit widths across layers.
    pub enable_mixed_precision: bool,
    /// Policy for runtime precision adaptation.
    pub dynamic_strategy: DynamicQuantizationStrategy,
    /// Take NPU/GPU/CPU capabilities into account.
    pub hardware_aware: bool,
    /// Scale/zero-point granularity (per tensor, channel, group, or layer).
    pub granularity: QuantizationGranularity,
    /// Minimum acceptable estimated quality in [0, 1].
    pub quality_threshold: f32,
    /// Memory budget for the quantized model, in MiB.
    pub memory_constraint_mb: usize,
    /// Also quantize gradients (training-time option).
    pub enable_gradient_quantization: bool,
    /// KL-divergence threshold used by calibration statistics.
    pub kl_threshold: f32,
    /// Enable post-training quantization.
    pub enable_ptq: bool,
    /// Enable quantization-aware training.
    pub enable_qat: bool,
}
54
/// Precision targets supported on mobile devices.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum MobilePrecision {
    /// 4-bit integer quantization.
    INT4,
    /// 8-bit integer quantization.
    INT8,
    /// IEEE half precision.
    FP16,
    /// Mixed 4-bit / 8-bit per layer.
    Mixed4_8,
    /// Mixed 8-bit / FP16 per layer.
    Mixed8_16,
    /// Resolved at runtime from the device performance tier.
    DYNAMIC,
}

/// Signal the runtime adapts precision to (battery, heat, memory, speed).
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum DynamicQuantizationStrategy {
    BatteryAware,
    ThermalAware,
    MemoryAware,
    PerformanceAware,
    /// Combine all of the above signals.
    Adaptive,
}

/// Granularity at which scales/zero-points are computed.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum QuantizationGranularity {
    PerTensor,
    PerChannel,
    /// One scale per contiguous group of `group_size` elements.
    PerGroup { group_size: usize },
    PerLayer,
}
99
/// Representative input samples used to calibrate quantization ranges.
#[derive(Debug, Clone)]
pub struct CalibrationDataset {
    /// Calibration inputs; must be non-empty (enforced by `set_calibration_data`).
    pub samples: Vec<Tensor>,
    /// Optional per-sample weights.
    pub weights: Option<Vec<f32>>,
    /// Aggregated per-layer statistics collected from the samples.
    pub statistics: DatasetStatistics,
}

/// Per-layer statistics gathered during calibration.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct DatasetStatistics {
    /// Observed (min, max) activation range per layer name.
    pub activation_ranges: HashMap<String, (f32, f32)>,
    /// Mean activation per layer.
    pub layer_means: HashMap<String, f32>,
    /// Activation variance per layer.
    pub layer_variances: HashMap<String, f32>,
    /// KL-divergence scores per layer.
    pub kl_scores: HashMap<String, f32>,
}
123
/// Result of quantizing a model: per-layer tensors plus derived metadata.
#[derive(Debug, Clone)]
pub struct QuantizedModel {
    /// Quantized tensor per layer name.
    pub weights: HashMap<String, QuantizedTensor>,
    /// Scales/zero-points and overhead estimates for dequantization.
    pub parameters: QuantizationParameters,
    /// Size, compression ratio, quality, and timestamp.
    pub metadata: ModelMetadata,
    /// Estimated speed/memory/power benchmarks.
    pub benchmarks: QuantizationBenchmarks,
}

/// One quantized tensor with its quantization metadata.
#[derive(Debug, Clone)]
pub struct QuantizedTensor {
    /// Raw quantized bytes. For INT4/INT8 one value per byte; for FP16 two
    /// little-endian bytes per element.
    pub data: Vec<i8>,
    /// One scale for per-tensor schemes; possibly more for finer granularity.
    pub scales: Vec<f32>,
    /// Zero points matching `scales`.
    pub zero_points: Vec<i32>,
    /// Original tensor shape.
    pub shape: Vec<usize>,
    /// How the values were quantized.
    pub scheme: QuantizationScheme,
}

/// Describes the numeric scheme a tensor was quantized with.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationScheme {
    /// Bit width of the stored values (4, 8, or 16).
    pub bits: u8,
    /// True when no zero-point offset is used.
    pub symmetric: bool,
    /// True when stored values are signed.
    pub signed: bool,
    /// Mapping used from float to quantized values.
    pub method: QuantizationMethod,
}
164
/// Float-to-integer mapping strategies.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum QuantizationMethod {
    /// Affine (scale + zero point) mapping; the only method produced in this file.
    Linear,
    Logarithmic,
    PowerOfTwo,
    KMeans,
    Learned,
}

/// Serialization formats recognized by `detect_model_format`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ModelFormat {
    SafeTensors,
    PyTorchPickle,
    TensorFlow,
    ONNX,
    /// Fallback when no known magic bytes match.
    Custom,
}
194
/// Aggregated quantization parameters for a whole model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParameters {
    /// Mean of the per-layer scales.
    pub global_scale: f32,
    /// First scale of each layer, keyed by layer name.
    pub layer_scales: HashMap<String, f32>,
    /// First zero point of each layer, keyed by layer name.
    pub layer_zero_points: HashMap<String, i32>,
    /// Estimated total dequantization overhead in milliseconds.
    pub dequant_overhead_ms: f32,
}

/// Size/quality bookkeeping for a quantized model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Input model size in bytes.
    pub original_size_bytes: usize,
    /// Estimated serialized size after quantization.
    pub quantized_size_bytes: usize,
    /// original / quantized size.
    pub compression_ratio: f32,
    /// Estimated quality in [0, 1].
    pub quality_score: f32,
    /// When the quantization was performed.
    pub timestamp: std::time::SystemTime,
}

/// Estimated performance deltas between original and quantized model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationBenchmarks {
    pub original_inference_ms: f32,
    pub quantized_inference_ms: f32,
    /// original / quantized inference time (1.0 when unknown).
    pub speedup_factor: f32,
    /// May be negative if metadata overhead outweighs the savings.
    pub memory_reduction_mb: f32,
    pub power_reduction_mw: f32,
}
237
238impl MobileQuantizationEngine {
239 pub fn new(
241 config: QuantizationConfig,
242 device_info: MobileDeviceInfo,
243 memory_pool: Option<Arc<MobileMemoryPool>>,
244 ) -> Result<Self> {
245 Ok(Self {
246 config,
247 device_info,
248 calibration_data: None,
249 quantization_cache: Arc::new(Mutex::new(HashMap::new())),
250 memory_pool,
251 })
252 }
253
254 pub fn set_calibration_data(&mut self, dataset: CalibrationDataset) -> Result<()> {
256 if dataset.samples.is_empty() {
258 return Err(invalid_config(
259 "set_calibration_data",
260 "Calibration dataset cannot be empty",
261 ));
262 }
263
264 self.calibration_data = Some(dataset);
265 Ok(())
266 }
267
268 pub fn quantize_model(&self, model_id: &str, model_data: &[u8]) -> Result<QuantizedModel> {
270 {
272 let cache = self.quantization_cache.lock().expect("Operation failed");
273 if let Some(cached_model) = cache.get(model_id) {
274 return Ok(cached_model.clone());
275 }
276 }
277
278 let strategy = self.determine_quantization_strategy()?;
280
281 let hardware_config = self.get_hardware_quantization_config()?;
283
284 let quantized_model = self.perform_quantization(model_data, &strategy, &hardware_config)?;
286
287 let benchmarks = self.benchmark_quantized_model(&quantized_model)?;
289
290 let final_model = QuantizedModel {
291 weights: quantized_model.weights,
292 parameters: quantized_model.parameters,
293 metadata: quantized_model.metadata,
294 benchmarks,
295 };
296
297 {
299 let mut cache = self.quantization_cache.lock().expect("Operation failed");
300 cache.insert(model_id.to_string(), final_model.clone());
301 }
302
303 Ok(final_model)
304 }
305
306 fn determine_quantization_strategy(&self) -> Result<MobilePrecision> {
308 match (
309 &self.device_info.performance_scores.overall_tier,
310 &self.config.target_precision,
311 ) {
312 (PerformanceTier::High, MobilePrecision::DYNAMIC) => {
313 Ok(MobilePrecision::Mixed8_16)
315 },
316 (PerformanceTier::Mid, MobilePrecision::DYNAMIC) => {
317 Ok(MobilePrecision::INT8)
319 },
320 (PerformanceTier::Budget, MobilePrecision::DYNAMIC) => {
321 Ok(MobilePrecision::Mixed4_8)
323 },
324 (_, precision) => Ok(*precision),
325 }
326 }
327
328 fn get_hardware_quantization_config(&self) -> Result<HardwareQuantizationConfig> {
330 let mut config = HardwareQuantizationConfig::default();
331
332 if self.device_info.npu_info.is_some() {
334 config.use_npu_kernels = true;
335 config.preferred_precision = MobilePrecision::INT8;
336 }
337
338 if self.device_info.gpu_info.is_some() {
339 config.use_gpu_kernels = true;
340 config.gpu_memory_optimization = true;
341 }
342
343 if self.device_info.cpu_info.architecture.contains("arm")
345 || self.device_info.cpu_info.architecture.contains("aarch64")
346 {
347 config.use_neon_instructions = true;
348 config.arm_specific_kernels = true;
349 }
350
351 Ok(config)
352 }
353
354 fn perform_quantization(
356 &self,
357 model_data: &[u8],
358 strategy: &MobilePrecision,
359 hardware_config: &HardwareQuantizationConfig,
360 ) -> Result<QuantizedModel> {
361 let weights = self.parse_model_weights(model_data)?;
363
364 let quantized_weights = match strategy {
366 MobilePrecision::INT4 => self.quantize_to_int4(&weights)?,
367 MobilePrecision::INT8 => self.quantize_to_int8(&weights)?,
368 MobilePrecision::FP16 => self.quantize_to_fp16(&weights)?,
369 MobilePrecision::Mixed4_8 => self.quantize_mixed_4_8(&weights)?,
370 MobilePrecision::Mixed8_16 => self.quantize_mixed_8_16(&weights)?,
371 MobilePrecision::DYNAMIC => self.quantize_dynamic(&weights)?,
372 };
373
374 let parameters = self.calculate_quantization_parameters(&quantized_weights)?;
376
377 let metadata = ModelMetadata {
379 original_size_bytes: model_data.len(),
380 quantized_size_bytes: self.calculate_quantized_size(&quantized_weights),
381 compression_ratio: model_data.len() as f32
382 / self.calculate_quantized_size(&quantized_weights) as f32,
383 quality_score: self.estimate_quality_score(&quantized_weights)?,
384 timestamp: std::time::SystemTime::now(),
385 };
386
387 Ok(QuantizedModel {
388 weights: quantized_weights,
389 parameters,
390 metadata,
391 benchmarks: QuantizationBenchmarks::default(), })
393 }
394
395 fn quantize_to_int4(
397 &self,
398 weights: &HashMap<String, Tensor>,
399 ) -> Result<HashMap<String, QuantizedTensor>> {
400 let mut quantized = HashMap::new();
401
402 for (layer_name, tensor) in weights {
403 let tensor_data = tensor.data()?.to_vec();
404
405 let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
407 let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
408
409 let scale = (max_val - min_val) / 15.0; let zero_point = (-min_val / scale).round() as i32;
411
412 let quantized_data: Vec<i8> = tensor_data
414 .iter()
415 .map(|&x| {
416 let quantized = ((x / scale) + zero_point as f32).round();
417 quantized.max(0.0).min(15.0) as i8
418 })
419 .collect();
420
421 let quantized_tensor = QuantizedTensor {
422 data: quantized_data,
423 scales: vec![scale],
424 zero_points: vec![zero_point],
425 shape: tensor.shape().to_vec(),
426 scheme: QuantizationScheme {
427 bits: 4,
428 symmetric: false,
429 signed: false,
430 method: QuantizationMethod::Linear,
431 },
432 };
433
434 quantized.insert(layer_name.clone(), quantized_tensor);
435 }
436
437 Ok(quantized)
438 }
439
440 fn quantize_to_int8(
442 &self,
443 weights: &HashMap<String, Tensor>,
444 ) -> Result<HashMap<String, QuantizedTensor>> {
445 let mut quantized = HashMap::new();
446
447 for (layer_name, tensor) in weights {
448 let tensor_data = tensor.data()?.to_vec();
449
450 let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
452 let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
453
454 let scale = (max_val - min_val) / 255.0; let zero_point = (-min_val / scale).round() as i32;
456
457 let quantized_data: Vec<i8> = tensor_data
459 .iter()
460 .map(|&x| {
461 let quantized = ((x / scale) + zero_point as f32).round();
462 (quantized.max(0.0).min(255.0) as i32 - 128) as i8 })
464 .collect();
465
466 let quantized_tensor = QuantizedTensor {
467 data: quantized_data,
468 scales: vec![scale],
469 zero_points: vec![zero_point],
470 shape: tensor.shape().to_vec(),
471 scheme: QuantizationScheme {
472 bits: 8,
473 symmetric: false,
474 signed: true,
475 method: QuantizationMethod::Linear,
476 },
477 };
478
479 quantized.insert(layer_name.clone(), quantized_tensor);
480 }
481
482 Ok(quantized)
483 }
484
485 fn quantize_to_fp16(
487 &self,
488 weights: &HashMap<String, Tensor>,
489 ) -> Result<HashMap<String, QuantizedTensor>> {
490 let mut quantized = HashMap::new();
492
493 for (layer_name, tensor) in weights {
494 let tensor_data = tensor.data()?.to_vec();
495
496 let quantized_data: Vec<i8> = tensor_data
498 .iter()
499 .flat_map(|&x| {
500 let fp16_bits = f16::from_f32(x).to_bits();
501 [(fp16_bits & 0xFF) as i8, ((fp16_bits >> 8) & 0xFF) as i8]
502 })
503 .collect();
504
505 let quantized_tensor = QuantizedTensor {
506 data: quantized_data,
507 scales: vec![1.0], zero_points: vec![0],
509 shape: tensor.shape().to_vec(),
510 scheme: QuantizationScheme {
511 bits: 16,
512 symmetric: true,
513 signed: true,
514 method: QuantizationMethod::Linear,
515 },
516 };
517
518 quantized.insert(layer_name.clone(), quantized_tensor);
519 }
520
521 Ok(quantized)
522 }
523
524 fn quantize_mixed_4_8(
526 &self,
527 weights: &HashMap<String, Tensor>,
528 ) -> Result<HashMap<String, QuantizedTensor>> {
529 let mut quantized = HashMap::new();
530
531 for (layer_name, tensor) in weights {
532 let use_4bit = self.should_use_4bit_for_layer(layer_name, tensor)?;
534
535 if use_4bit {
536 let quantized_4bit = self.quantize_to_int4(
537 &[(layer_name.clone(), tensor.clone())].iter().cloned().collect(),
538 )?;
539 quantized.extend(quantized_4bit);
540 } else {
541 let quantized_8bit = self.quantize_to_int8(
542 &[(layer_name.clone(), tensor.clone())].iter().cloned().collect(),
543 )?;
544 quantized.extend(quantized_8bit);
545 }
546 }
547
548 Ok(quantized)
549 }
550
551 fn quantize_mixed_8_16(
553 &self,
554 weights: &HashMap<String, Tensor>,
555 ) -> Result<HashMap<String, QuantizedTensor>> {
556 let mut quantized = HashMap::new();
557
558 for (layer_name, tensor) in weights {
559 let use_8bit = self.should_use_8bit_for_layer(layer_name, tensor)?;
561
562 if use_8bit {
563 let quantized_8bit = self.quantize_to_int8(
564 &[(layer_name.clone(), tensor.clone())].iter().cloned().collect(),
565 )?;
566 quantized.extend(quantized_8bit);
567 } else {
568 let quantized_16bit = self.quantize_to_fp16(
569 &[(layer_name.clone(), tensor.clone())].iter().cloned().collect(),
570 )?;
571 quantized.extend(quantized_16bit);
572 }
573 }
574
575 Ok(quantized)
576 }
577
    /// Dynamic quantization entry point. Currently delegates straight to
    /// INT8.
    ///
    /// NOTE(review): `config.dynamic_strategy` (battery/thermal/memory aware)
    /// is never consulted here — presumably runtime adaptation is planned but
    /// not implemented yet; confirm.
    fn quantize_dynamic(
        &self,
        weights: &HashMap<String, Tensor>,
    ) -> Result<HashMap<String, QuantizedTensor>> {
        self.quantize_to_int8(weights)
    }
586
587 fn should_use_4bit_for_layer(&self, layer_name: &str, tensor: &Tensor) -> Result<bool> {
589 let is_embedding = layer_name.contains("embed") || layer_name.contains("token");
591 let is_output = layer_name.contains("output") || layer_name.contains("head");
592 let is_large = tensor.shape().iter().product::<usize>() > 1000000;
593
594 Ok(is_embedding || (is_large && !is_output))
595 }
596
597 fn should_use_8bit_for_layer(&self, layer_name: &str, tensor: &Tensor) -> Result<bool> {
599 let is_attention = layer_name.contains("attn") || layer_name.contains("attention");
601 let is_output = layer_name.contains("output") || layer_name.contains("head");
602 let is_norm = layer_name.contains("norm") || layer_name.contains("ln");
603
604 Ok(!(is_attention || is_output || is_norm))
605 }
606
607 fn parse_model_weights(&self, model_data: &[u8]) -> Result<HashMap<String, Tensor>> {
609 #[allow(dead_code)]
611 let mut weights = HashMap::new();
612
613 let format = self.detect_model_format(model_data)?;
615
616 match format {
617 ModelFormat::SafeTensors => {
618 weights = self.parse_safetensors(model_data)?;
619 },
620 ModelFormat::PyTorchPickle => {
621 weights = self.parse_pytorch_pickle(model_data)?;
622 },
623 ModelFormat::TensorFlow => {
624 weights = self.parse_tensorflow(model_data)?;
625 },
626 ModelFormat::ONNX => {
627 weights = self.parse_onnx(model_data)?;
628 },
629 ModelFormat::Custom => {
630 weights = self.parse_custom_format(model_data)?;
631 },
632 }
633
634 self.validate_parsed_weights(&weights)?;
636
637 Ok(weights)
638 }
639
640 fn calculate_quantization_parameters(
642 &self,
643 weights: &HashMap<String, QuantizedTensor>,
644 ) -> Result<QuantizationParameters> {
645 let mut layer_scales = HashMap::new();
646 let mut layer_zero_points = HashMap::new();
647 let mut total_dequant_overhead = 0.0;
648
649 for (layer_name, quantized_tensor) in weights {
651 let scale = quantized_tensor.scales.first().copied().unwrap_or(1.0);
653 let zero_point = quantized_tensor.zero_points.first().copied().unwrap_or(0);
654
655 layer_scales.insert(layer_name.clone(), scale);
656 layer_zero_points.insert(layer_name.clone(), zero_point);
657
658 let tensor_size = quantized_tensor.data.len();
660 let overhead_factor = match quantized_tensor.scheme.bits {
661 4 => 0.05, 8 => 0.03, 16 => 0.01, _ => 0.04, };
666
667 total_dequant_overhead += (tensor_size as f32 * overhead_factor) / 1000.0;
668 }
670
671 let total_elements: f32 = weights.values().map(|t| t.data.len() as f32).sum();
673 let global_scale = if total_elements > 0.0 {
674 layer_scales.values().sum::<f32>() / layer_scales.len() as f32
675 } else {
676 1.0
677 };
678
679 Ok(QuantizationParameters {
680 global_scale,
681 layer_scales,
682 layer_zero_points,
683 dequant_overhead_ms: total_dequant_overhead,
684 })
685 }
686
687 fn calculate_quantized_size(&self, weights: &HashMap<String, QuantizedTensor>) -> usize {
689 let mut total_size = 0;
690
691 for (layer_name, quantized_tensor) in weights {
692 let data_size = quantized_tensor.data.len();
694
695 let metadata_size = quantized_tensor.scales.len() * 4 + quantized_tensor.zero_points.len() * 4 + quantized_tensor.shape.len() * 8 + layer_name.len() + 32; total_size += data_size + metadata_size;
702 }
703
704 total_size
705 }
706
707 fn estimate_quality_score(&self, weights: &HashMap<String, QuantizedTensor>) -> Result<f32> {
709 if weights.is_empty() {
710 return Ok(1.0);
711 }
712
713 let mut total_quality = 0.0;
714 let mut total_weight = 0.0;
715
716 for (layer_name, quantized_tensor) in weights {
717 let layer_weight = quantized_tensor.data.len() as f32;
718
719 let base_quality = match quantized_tensor.scheme.bits {
721 4 => 0.85, 8 => 0.93, 16 => 0.98, _ => 0.90, };
726
727 let layer_quality_factor = self.get_layer_quality_factor(layer_name);
729 let adjusted_quality = base_quality * layer_quality_factor;
730
731 total_quality += adjusted_quality * layer_weight;
733 total_weight += layer_weight;
734 }
735
736 let overall_quality = if total_weight > 0.0 { total_quality / total_weight } else { 1.0 };
737
738 let calibration_factor = if let Some(ref cal_data) = self.calibration_data {
740 self.estimate_calibration_quality_impact(cal_data)?
741 } else {
742 0.95 };
744
745 Ok((overall_quality * calibration_factor).min(1.0))
746 }
747
748 fn benchmark_quantized_model(&self, model: &QuantizedModel) -> Result<QuantizationBenchmarks> {
750 let mut benchmarks = QuantizationBenchmarks::default();
751
752 let original_params = model.metadata.original_size_bytes / 4; benchmarks.original_inference_ms =
755 self.estimate_inference_time(original_params, MobilePrecision::FP16)?;
756
757 let quantized_params = model.metadata.quantized_size_bytes;
759 let avg_precision = self.estimate_average_precision(&model.weights);
760 benchmarks.quantized_inference_ms =
761 self.estimate_inference_time(quantized_params, avg_precision)?;
762
763 benchmarks.speedup_factor = if benchmarks.quantized_inference_ms > 0.0 {
765 benchmarks.original_inference_ms / benchmarks.quantized_inference_ms
766 } else {
767 1.0
768 };
769
770 benchmarks.memory_reduction_mb = (model.metadata.original_size_bytes
772 - model.metadata.quantized_size_bytes) as f32
773 / (1024.0 * 1024.0);
774
775 benchmarks.power_reduction_mw = self.estimate_power_reduction(&model.weights)?;
777
778 Ok(benchmarks)
779 }
780
781 fn detect_model_format(&self, data: &[u8]) -> Result<ModelFormat> {
785 if data.len() < 8 {
786 return Err(invalid_input("Model data too small to detect format"));
787 }
788
789 if data.starts_with(b"STFR") || data.starts_with(&[0x53, 0x54, 0x46, 0x52]) {
791 return Ok(ModelFormat::SafeTensors);
792 }
793
794 if data.starts_with(&[0x80, 0x02]) || data.starts_with(&[0x80, 0x03]) {
796 return Ok(ModelFormat::PyTorchPickle);
797 }
798
799 if data.starts_with(b"TF") {
801 return Ok(ModelFormat::TensorFlow);
802 }
803
804 if data.starts_with(&[0x08, 0x01]) {
806 return Ok(ModelFormat::ONNX);
807 }
808
809 Ok(ModelFormat::Custom)
811 }
812
    /// Parses SafeTensors-format weights.
    ///
    /// TODO: stub — returns an empty map, which `validate_parsed_weights`
    /// rejects; real parsing is not implemented in this chunk.
    fn parse_safetensors(&self, _data: &[u8]) -> Result<HashMap<String, Tensor>> {
        Ok(HashMap::new())
    }

    /// Parses PyTorch pickle-format weights. TODO: stub, see `parse_safetensors`.
    fn parse_pytorch_pickle(&self, _data: &[u8]) -> Result<HashMap<String, Tensor>> {
        Ok(HashMap::new())
    }

    /// Parses TensorFlow-format weights. TODO: stub, see `parse_safetensors`.
    fn parse_tensorflow(&self, _data: &[u8]) -> Result<HashMap<String, Tensor>> {
        Ok(HashMap::new())
    }

    /// Parses ONNX-format weights. TODO: stub, see `parse_safetensors`.
    fn parse_onnx(&self, _data: &[u8]) -> Result<HashMap<String, Tensor>> {
        Ok(HashMap::new())
    }

    /// Parses the project's custom format. TODO: stub, see `parse_safetensors`.
    fn parse_custom_format(&self, _data: &[u8]) -> Result<HashMap<String, Tensor>> {
        Ok(HashMap::new())
    }
842
843 fn validate_parsed_weights(&self, weights: &HashMap<String, Tensor>) -> Result<()> {
845 if weights.is_empty() {
846 return Err(invalid_input("No weights found in model"));
847 }
848
849 for (layer_name, tensor) in weights {
850 if tensor.shape().is_empty() {
852 return Err(invalid_input(format!(
853 "Invalid tensor shape for layer: {}",
854 layer_name
855 )));
856 }
857
858 let total_elements: usize = tensor.shape().iter().product();
860 if total_elements == 0 {
861 return Err(invalid_input(format!(
862 "Empty tensor for layer: {}",
863 layer_name
864 )));
865 }
866
867 if total_elements > 100_000_000 {
869 tracing::warn!(
870 "Large tensor detected in layer {}: {} elements",
871 layer_name,
872 total_elements
873 );
874 }
875 }
876
877 Ok(())
878 }
879
880 fn get_layer_quality_factor(&self, layer_name: &str) -> f32 {
882 if layer_name.contains("output") || layer_name.contains("head") {
884 0.95 } else if layer_name.contains("attention") || layer_name.contains("attn") {
886 0.92 } else if layer_name.contains("norm") || layer_name.contains("ln") {
888 0.98 } else if layer_name.contains("embed") || layer_name.contains("token") {
890 0.90 } else {
892 1.0 }
894 }
895
896 fn estimate_calibration_quality_impact(&self, cal_data: &CalibrationDataset) -> Result<f32> {
898 let sample_factor = (cal_data.samples.len() as f32 / 100.0).min(1.0);
900
901 let stats_quality =
903 if !cal_data.statistics.activation_ranges.is_empty() { 1.0 } else { 0.9 };
904
905 Ok(0.95 + 0.05 * sample_factor * stats_quality)
906 }
907
908 fn estimate_inference_time(&self, params: usize, precision: MobilePrecision) -> Result<f32> {
910 let base_time_per_param = match self.device_info.performance_scores.overall_tier {
912 PerformanceTier::VeryLow => 0.01, PerformanceTier::Low => 0.008, PerformanceTier::Budget => 0.005, PerformanceTier::Medium => 0.003, PerformanceTier::Mid => 0.002, PerformanceTier::High => 0.001, PerformanceTier::VeryHigh => 0.0007, PerformanceTier::Flagship => 0.0005, };
921
922 let precision_factor = match precision {
924 MobilePrecision::INT4 => 0.5,
925 MobilePrecision::INT8 => 0.7,
926 MobilePrecision::FP16 => 1.0,
927 MobilePrecision::Mixed4_8 => 0.6,
928 MobilePrecision::Mixed8_16 => 0.85,
929 MobilePrecision::DYNAMIC => 0.8,
930 };
931
932 let hw_factor = if self.device_info.npu_info.is_some() {
934 0.6
935 } else if self.device_info.gpu_info.is_some() {
936 0.8
937 } else {
938 1.0
939 };
940
941 let total_time = params as f32 * base_time_per_param * precision_factor * hw_factor;
942 Ok(total_time)
943 }
944
945 fn estimate_average_precision(
947 &self,
948 weights: &HashMap<String, QuantizedTensor>,
949 ) -> MobilePrecision {
950 if weights.is_empty() {
951 return MobilePrecision::FP16;
952 }
953
954 let mut total_bits = 0;
955 let mut total_tensors = 0;
956
957 for tensor in weights.values() {
958 total_bits += tensor.scheme.bits as u32;
959 total_tensors += 1;
960 }
961
962 let avg_bits = total_bits as f32 / total_tensors as f32;
963
964 match avg_bits.round() as u8 {
965 4 => MobilePrecision::INT4,
966 8 => MobilePrecision::INT8,
967 16 => MobilePrecision::FP16,
968 _ => MobilePrecision::INT8, }
970 }
971
972 fn estimate_power_reduction(&self, weights: &HashMap<String, QuantizedTensor>) -> Result<f32> {
974 let mut total_power_reduction = 0.0;
975
976 for tensor in weights.values() {
977 let tensor_size = tensor.data.len() as f32;
978
979 let reduction_per_op = match tensor.scheme.bits {
981 4 => 0.08, 8 => 0.05, 16 => 0.02, _ => 0.04, };
986
987 total_power_reduction += tensor_size * reduction_per_op / 1000.0;
988 }
989
990 Ok(total_power_reduction)
991 }
992}
993
/// Kernel/backend preferences derived from the detected hardware by
/// `get_hardware_quantization_config`.
#[derive(Debug, Clone)]
struct HardwareQuantizationConfig {
    // Prefer NPU kernels when an NPU is present.
    use_npu_kernels: bool,
    // Prefer GPU kernels when a GPU is present.
    use_gpu_kernels: bool,
    // Enable NEON SIMD paths on ARM CPUs.
    use_neon_instructions: bool,
    // Enable ARM-specific kernel variants.
    arm_specific_kernels: bool,
    // Enable GPU memory optimizations.
    gpu_memory_optimization: bool,
    // Precision the detected hardware handles best.
    preferred_precision: MobilePrecision,
}

impl Default for HardwareQuantizationConfig {
    /// Conservative CPU-only baseline; hardware detection flips flags on.
    fn default() -> Self {
        Self {
            use_npu_kernels: false,
            use_gpu_kernels: false,
            use_neon_instructions: false,
            arm_specific_kernels: false,
            gpu_memory_optimization: false,
            preferred_precision: MobilePrecision::INT8,
        }
    }
}

impl Default for QuantizationBenchmarks {
    /// Neutral benchmarks: zero timings/savings and a 1.0 speedup.
    fn default() -> Self {
        Self {
            original_inference_ms: 0.0,
            quantized_inference_ms: 0.0,
            speedup_factor: 1.0,
            memory_reduction_mb: 0.0,
            power_reduction_mw: 0.0,
        }
    }
}

impl Default for QuantizationConfig {
    /// Sensible mobile defaults: INT8 PTQ, per-channel granularity,
    /// adaptive dynamic strategy, 512 MiB memory budget.
    fn default() -> Self {
        Self {
            target_precision: MobilePrecision::INT8,
            enable_mixed_precision: true,
            dynamic_strategy: DynamicQuantizationStrategy::Adaptive,
            hardware_aware: true,
            granularity: QuantizationGranularity::PerChannel,
            quality_threshold: 0.9,
            memory_constraint_mb: 512,
            enable_gradient_quantization: false,
            kl_threshold: 0.01,
            enable_ptq: true,
            enable_qat: false,
        }
    }
}
1047
#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::Tensor;

    // Exercises magic-byte detection for every supported format plus the
    // Custom fallback.
    #[test]
    fn test_model_format_detection() {
        let engine = create_test_engine();

        let safetensors_data = b"STFR\x00\x00\x00\x00test data";
        let format = engine.detect_model_format(safetensors_data).expect("Operation failed");
        assert_eq!(format, ModelFormat::SafeTensors);

        let pytorch_data = b"\x80\x02test data";
        let format = engine.detect_model_format(pytorch_data).expect("Operation failed");
        assert_eq!(format, ModelFormat::PyTorchPickle);

        let tf_data = b"TFtest data";
        let format = engine.detect_model_format(tf_data).expect("Operation failed");
        assert_eq!(format, ModelFormat::TensorFlow);

        let onnx_data = b"\x08\x01test data";
        let format = engine.detect_model_format(onnx_data).expect("Operation failed");
        assert_eq!(format, ModelFormat::ONNX);

        let custom_data = b"custom test data";
        let format = engine.detect_model_format(custom_data).expect("Operation failed");
        assert_eq!(format, ModelFormat::Custom);
    }

    // Parameters derived from a small quantized-weight fixture must be
    // populated and non-degenerate.
    #[test]
    fn test_quantization_parameters_calculation() {
        let engine = create_test_engine();
        let weights = create_test_quantized_weights();

        let params = engine.calculate_quantization_parameters(&weights).expect("Operation failed");

        assert!(params.global_scale > 0.0);
        assert!(!params.layer_scales.is_empty());
        assert!(!params.layer_zero_points.is_empty());
        assert!(params.dequant_overhead_ms >= 0.0);
    }

    // Quality scores are always within [0, 1].
    #[test]
    fn test_quality_score_estimation() {
        let engine = create_test_engine();
        let weights = create_test_quantized_weights();

        let quality = engine.estimate_quality_score(&weights).expect("Operation failed");

        assert!((0.0..=1.0).contains(&quality));
    }

    // Pins the name-based quality factors for each recognized layer kind.
    #[test]
    fn test_layer_quality_factors() {
        let engine = create_test_engine();

        assert_eq!(engine.get_layer_quality_factor("model.output.weight"), 0.95);
        assert_eq!(
            engine.get_layer_quality_factor("model.attention.weight"),
            0.92
        );
        assert_eq!(
            engine.get_layer_quality_factor("model.layer_norm.weight"),
            0.98
        );
        assert_eq!(
            engine.get_layer_quality_factor("model.embedding.weight"),
            0.90
        );
        assert_eq!(engine.get_layer_quality_factor("model.hidden.weight"), 1.0);
    }

    // Estimated time must be positive and ordered by precision (narrower
    // precision => faster).
    #[test]
    fn test_inference_time_estimation() {
        let engine = create_test_engine();

        let time = engine
            .estimate_inference_time(1000, MobilePrecision::INT8)
            .expect("Operation failed");
        assert!(time > 0.0);

        let time_fp16 = engine
            .estimate_inference_time(1000, MobilePrecision::FP16)
            .expect("Operation failed");
        let time_int4 = engine
            .estimate_inference_time(1000, MobilePrecision::INT4)
            .expect("Operation failed");

        assert!(time_int4 < time_fp16);
    }

    #[test]
    fn test_power_reduction_estimation() {
        let engine = create_test_engine();
        let weights = create_test_quantized_weights();

        let power_reduction = engine.estimate_power_reduction(&weights).expect("Operation failed");
        assert!(power_reduction >= 0.0);
    }

    #[test]
    fn test_quantized_size_calculation() {
        let engine = create_test_engine();
        let weights = create_test_quantized_weights();

        let size = engine.calculate_quantized_size(&weights);
        assert!(size > 0);
    }

    // Empty maps are rejected; well-formed maps pass validation.
    #[test]
    fn test_weight_validation() {
        let engine = create_test_engine();

        let empty_weights = HashMap::new();
        assert!(engine.validate_parsed_weights(&empty_weights).is_err());

        let valid_weights = create_test_weights();
        assert!(engine.validate_parsed_weights(&valid_weights).is_ok());
    }

    // Calibration datasets must contain at least one sample.
    #[test]
    fn test_calibration_data_validation() {
        let mut engine = create_test_engine();

        let empty_dataset = CalibrationDataset {
            samples: vec![],
            weights: None,
            statistics: DatasetStatistics::default(),
        };
        assert!(engine.set_calibration_data(empty_dataset).is_err());

        let valid_dataset = CalibrationDataset {
            samples: vec![Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("Operation failed")],
            weights: None,
            statistics: DatasetStatistics::default(),
        };
        assert!(engine.set_calibration_data(valid_dataset).is_ok());
    }

    // Builds an engine from real device detection when available, otherwise
    // from a fully synthetic mid-tier ARM device fixture.
    fn create_test_engine() -> MobileQuantizationEngine {
        let config = QuantizationConfig::default();
        let device_info = crate::device_info::MobileDeviceDetector::detect().unwrap_or_else(|_| {
            use crate::device_info::*;
            MobileDeviceInfo {
                platform: crate::MobilePlatform::Generic,
                basic_info: BasicDeviceInfo {
                    platform: crate::MobilePlatform::Generic,
                    manufacturer: "Test".to_string(),
                    model: "Test Device".to_string(),
                    os_version: "1.0".to_string(),
                    hardware_id: "test".to_string(),
                    device_generation: Some(2023),
                },
                cpu_info: CpuInfo {
                    architecture: "aarch64".to_string(),
                    total_cores: 8,
                    core_count: 8,
                    performance_cores: 4,
                    efficiency_cores: 4,
                    max_frequency_mhz: Some(3000),
                    l1_cache_kb: Some(64),
                    l2_cache_kb: Some(512),
                    l3_cache_kb: Some(4096),
                    features: vec!["neon".to_string(), "fp16".to_string()],
                    simd_support: SimdSupport::Advanced,
                },
                memory_info: MemoryInfo {
                    total_mb: 8192,
                    available_mb: 6144,
                    total_memory: 8192,
                    available_memory: 6144,
                    bandwidth_mbps: Some(51200),
                    memory_type: "LPDDR5".to_string(),
                    frequency_mhz: Some(6400),
                    is_low_memory_device: false,
                },
                gpu_info: Some(GpuInfo {
                    vendor: "ARM".to_string(),
                    model: "Mali-G78".to_string(),
                    driver_version: "1.0".to_string(),
                    memory_mb: Some(2048),
                    compute_units: Some(14),
                    supported_apis: vec![GpuApi::OpenGLES3, GpuApi::Vulkan11],
                    performance_tier: GpuPerformanceTier::High,
                }),
                npu_info: None,
                thermal_info: ThermalInfo {
                    current_state: ThermalState::Nominal,
                    state: ThermalState::Nominal,
                    throttling_supported: true,
                    temperature_sensors: vec![],
                    thermal_zones: vec![],
                },
                power_info: PowerInfo {
                    battery_capacity_mah: Some(4000),
                    battery_level_percent: Some(80),
                    battery_level: Some(80),
                    battery_health_percent: Some(100),
                    charging_status: ChargingStatus::Discharging,
                    is_charging: false,
                    power_save_mode: false,
                    low_power_mode_available: true,
                },
                available_backends: vec![crate::MobileBackend::CPU, crate::MobileBackend::GPU],
                performance_scores: PerformanceScores {
                    cpu_single_core: Some(1200),
                    cpu_multi_core: Some(4800),
                    gpu_score: Some(2500),
                    memory_score: Some(1800),
                    overall_tier: PerformanceTier::Mid,
                    tier: PerformanceTier::Mid,
                },
            }
        });

        MobileQuantizationEngine::new(config, device_info, None).expect("Operation failed")
    }

    // Two small quantized tensors: one INT8, one INT4.
    fn create_test_quantized_weights() -> HashMap<String, QuantizedTensor> {
        let mut weights = HashMap::new();

        weights.insert(
            "layer1.weight".to_string(),
            QuantizedTensor {
                data: vec![1, 2, 3, 4, 5],
                scales: vec![0.1],
                zero_points: vec![0],
                shape: vec![5],
                scheme: QuantizationScheme {
                    bits: 8,
                    symmetric: false,
                    signed: true,
                    method: QuantizationMethod::Linear,
                },
            },
        );

        weights.insert(
            "layer2.weight".to_string(),
            QuantizedTensor {
                data: vec![6, 7, 8, 9, 10],
                scales: vec![0.2],
                zero_points: vec![1],
                shape: vec![5],
                scheme: QuantizationScheme {
                    bits: 4,
                    symmetric: false,
                    signed: false,
                    method: QuantizationMethod::Linear,
                },
            },
        );

        weights
    }

    // Two small float tensors used for validation tests.
    fn create_test_weights() -> HashMap<String, Tensor> {
        let mut weights = HashMap::new();

        weights.insert(
            "layer1.weight".to_string(),
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).expect("Operation failed"),
        );
        weights.insert(
            "layer2.weight".to_string(),
            Tensor::from_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0], &[5]).expect("Operation failed"),
        );

        weights
    }
}