1use crate::config::{ObserverType, QScheme, QuantBackend, QuantConfig};
32use torsh_core::{Result as TorshResult, TorshError};
33use torsh_tensor::Tensor;
34
/// Optimization objective that drives automatic quantization-configuration
/// selection in [`AutoConfigurator`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigObjective {
    /// Minimize model size; favors binary/ternary/int4 schemes.
    MaximumCompression,
    /// Minimize quantization error; favors int8 with careful observers.
    MaximumAccuracy,
    /// Trade off compression and accuracy.
    BalancedQuality,
    /// Minimize inference latency; favors simple per-tensor schemes.
    MaximumSpeed,
    /// Minimize memory footprint.
    MinimumMemory,
    /// Target edge/mobile deployment (QNNPACK-friendly settings).
    EdgeOptimized,
}
51
/// Characterization of a tensor used to pick a quantization configuration.
#[derive(Debug, Clone)]
pub struct TensorProfile {
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// Total element count.
    pub numel: usize,
    /// Summary statistics of the element values.
    pub stats: TensorStats,
    /// Fraction of elements that are (numerically) zero, in [0, 1].
    pub sparsity: f32,
    /// Coarse classification of the value distribution.
    pub distribution: DistributionProfile,
}
66
/// Summary statistics over a tensor's element values.
#[derive(Debug, Clone)]
pub struct TensorStats {
    /// Smallest element value.
    pub min: f32,
    /// Largest element value.
    pub max: f32,
    /// Arithmetic mean.
    pub mean: f32,
    /// Population standard deviation.
    pub std_dev: f32,
    /// `max - min`.
    pub range: f32,
    /// True if any value lies outside the 1.5×IQR outlier fences.
    pub has_outliers: bool,
    /// Fraction of values within 1% of the range of zero.
    pub near_zero_ratio: f32,
}
85
/// Coarse shape class of a tensor's value distribution, derived from
/// sample skewness/kurtosis heuristics.
#[derive(Debug, Clone, PartialEq)]
pub enum DistributionProfile {
    /// Approximately Gaussian (kurtosis near 3, low skew).
    Normal,
    /// Flat distribution (low kurtosis).
    Uniform,
    /// Heavier tails than Gaussian (high kurtosis).
    HeavyTailed,
    /// Fallback class for shapes matching none of the other heuristics.
    Bimodal,
    /// Strongly asymmetric (|skewness| > 1).
    Skewed,
    /// Dominated by (near-)zero values.
    Sparse,
}
102
/// Recommends quantization configurations for tensors based on an
/// optimization objective, and adapts its scoring weights from observed
/// performance feedback.
pub struct AutoConfigurator {
    /// Objective the recommendations optimize for.
    objective: ConfigObjective,
    /// Observed (config, tensor, outcome) records fed back by the caller.
    history: Vec<ConfigPerformance>,
    /// Weights applied when scoring candidate configurations.
    feature_weights: FeatureWeights,
}
111
/// One recorded outcome of applying a configuration to a tensor; used as
/// feedback for adapting `FeatureWeights`.
#[derive(Debug, Clone)]
struct ConfigPerformance {
    /// Configuration that was applied (retained for future analysis).
    #[allow(dead_code)]
    config: QuantConfig,
    /// Profile of the tensor it was applied to.
    profile: TensorProfile,
    /// Observed quantization error (lower is better).
    error: f32,
    /// Observed compression ratio (retained for future analysis).
    #[allow(dead_code)]
    compression: f32,
    /// Observed speedup, when measured (retained for future analysis).
    #[allow(dead_code)]
    speedup: Option<f32>,
}
127
/// Relative importance of tensor features when scoring candidate
/// configurations; adapted over time from performance history.
#[derive(Debug, Clone)]
struct FeatureWeights {
    /// Weight of the value-range / observer-choice score.
    range_weight: f32,
    /// Weight of sparsity-related evidence (the only weight currently adapted).
    sparsity_weight: f32,
    /// Weight of the distribution/scheme-match score.
    distribution_weight: f32,
    /// Weight of the tensor-size score.
    size_weight: f32,
}
140
impl Default for FeatureWeights {
    /// Initial heuristic weights; `update_feature_weights` may later adjust
    /// `sparsity_weight` within [0.5, 2.0] based on observed outcomes.
    fn default() -> Self {
        Self {
            range_weight: 1.0,
            sparsity_weight: 0.8,
            distribution_weight: 0.9,
            size_weight: 0.7,
        }
    }
}
151
152impl AutoConfigurator {
153 pub fn new(objective: ConfigObjective) -> Self {
155 Self {
156 objective,
157 history: Vec::new(),
158 feature_weights: FeatureWeights::default(),
159 }
160 }
161
162 pub fn recommend(
164 &self,
165 tensor: &Tensor,
166 constraints: Option<ConfigConstraints>,
167 ) -> TorshResult<QuantConfig> {
168 let profile = self.analyze_tensor(tensor)?;
170
171 let config = self.select_configuration(&profile, constraints)?;
173
174 Ok(config)
175 }
176
177 pub fn recommend_ranked(
179 &self,
180 tensor: &Tensor,
181 top_k: usize,
182 constraints: Option<ConfigConstraints>,
183 ) -> TorshResult<Vec<(QuantConfig, f32)>> {
184 let profile = self.analyze_tensor(tensor)?;
185 let mut candidates = self.generate_candidates(&profile, constraints)?;
186
187 for (config, score) in &mut candidates {
189 *score = self.score_configuration(config, &profile);
190 }
191
192 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
194
195 candidates.truncate(top_k);
197 Ok(candidates)
198 }
199
200 pub fn update_performance(
202 &mut self,
203 config: &QuantConfig,
204 tensor: &Tensor,
205 observed_error: f32,
206 observed_compression: f32,
207 speedup: Option<f32>,
208 ) -> TorshResult<()> {
209 let profile = self.analyze_tensor(tensor)?;
210
211 let performance = ConfigPerformance {
212 config: config.clone(),
213 profile,
214 error: observed_error,
215 compression: observed_compression,
216 speedup,
217 };
218
219 self.history.push(performance);
220
221 if self.history.len() >= 10 {
223 self.update_feature_weights();
224 }
225
226 Ok(())
227 }
228
229 fn analyze_tensor(&self, tensor: &Tensor) -> TorshResult<TensorProfile> {
235 let data = tensor.data()?;
236 let shape = tensor.shape().dims().to_vec();
237 let numel = tensor.shape().numel();
238
239 let stats = self.calculate_stats(&data)?;
241
242 let sparsity = self.calculate_sparsity(&data);
244
245 let distribution = self.classify_distribution(&data, &stats);
247
248 Ok(TensorProfile {
249 shape,
250 numel,
251 stats,
252 sparsity,
253 distribution,
254 })
255 }
256
257 fn calculate_stats(&self, data: &[f32]) -> TorshResult<TensorStats> {
259 if data.is_empty() {
260 return Err(TorshError::InvalidArgument(
261 "Cannot calculate stats for empty tensor".to_string(),
262 ));
263 }
264
265 let min = data.iter().copied().fold(f32::INFINITY, f32::min);
266 let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
267 let range = max - min;
268
269 let mean = data.iter().sum::<f32>() / data.len() as f32;
270
271 let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
272 let std_dev = variance.sqrt();
273
274 let mut sorted = data.to_vec();
276 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
277
278 let q1_idx = sorted.len() / 4;
279 let q3_idx = 3 * sorted.len() / 4;
280 let q1 = sorted[q1_idx];
281 let q3 = sorted[q3_idx];
282 let iqr = q3 - q1;
283
284 let outlier_threshold_low = q1 - 1.5 * iqr;
285 let outlier_threshold_high = q3 + 1.5 * iqr;
286
287 let has_outliers = data
288 .iter()
289 .any(|&x| x < outlier_threshold_low || x > outlier_threshold_high);
290
291 let zero_threshold = range.abs() * 0.01; let near_zero_count = data.iter().filter(|&&x| x.abs() < zero_threshold).count();
294 let near_zero_ratio = near_zero_count as f32 / data.len() as f32;
295
296 Ok(TensorStats {
297 min,
298 max,
299 mean,
300 std_dev,
301 range,
302 has_outliers,
303 near_zero_ratio,
304 })
305 }
306
307 fn calculate_sparsity(&self, data: &[f32]) -> f32 {
309 let zero_count = data.iter().filter(|&&x| x.abs() < 1e-8).count();
310 zero_count as f32 / data.len() as f32
311 }
312
313 fn classify_distribution(&self, data: &[f32], stats: &TensorStats) -> DistributionProfile {
315 if stats.near_zero_ratio > 0.6 {
317 return DistributionProfile::Sparse;
318 }
319
320 let skewness = data
322 .iter()
323 .map(|&x| ((x - stats.mean) / stats.std_dev).powi(3))
324 .sum::<f32>()
325 / data.len() as f32;
326
327 let kurtosis = data
329 .iter()
330 .map(|&x| ((x - stats.mean) / stats.std_dev).powi(4))
331 .sum::<f32>()
332 / data.len() as f32;
333
334 if skewness.abs() > 1.0 {
336 DistributionProfile::Skewed
337 } else if kurtosis > 4.0 {
338 DistributionProfile::HeavyTailed
339 } else if (kurtosis - 3.0).abs() < 0.5 && skewness.abs() < 0.5 {
340 DistributionProfile::Normal
341 } else if kurtosis < 2.0 {
342 DistributionProfile::Uniform
343 } else {
344 DistributionProfile::Bimodal
345 }
346 }
347
348 fn select_configuration(
350 &self,
351 profile: &TensorProfile,
352 constraints: Option<ConfigConstraints>,
353 ) -> TorshResult<QuantConfig> {
354 let mut config = match self.objective {
355 ConfigObjective::MaximumCompression => self.select_for_compression(profile),
356 ConfigObjective::MaximumAccuracy => self.select_for_accuracy(profile),
357 ConfigObjective::BalancedQuality => self.select_balanced(profile),
358 ConfigObjective::MaximumSpeed => self.select_for_speed(profile),
359 ConfigObjective::MinimumMemory => self.select_for_memory(profile),
360 ConfigObjective::EdgeOptimized => self.select_for_edge(profile),
361 }?;
362
363 if let Some(constraints) = constraints {
365 config = self.apply_constraints(config, constraints)?;
366 }
367
368 Ok(config)
369 }
370
371 fn select_for_compression(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
373 if profile.sparsity > 0.5 {
375 if profile.distribution == DistributionProfile::Sparse {
377 Ok(QuantConfig::binary())
378 } else {
379 Ok(QuantConfig::ternary())
380 }
381 } else if profile.numel < 1000 {
382 Ok(QuantConfig::int4())
384 } else {
385 let group_size = (profile.numel / 100).min(128).max(16);
387 Ok(QuantConfig::group_wise(0, group_size))
388 }
389 }
390
391 fn select_for_accuracy(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
393 let mut config = if profile.stats.has_outliers
394 || profile.distribution == DistributionProfile::HeavyTailed
395 {
396 QuantConfig::int8().with_observer(ObserverType::Histogram)
398 } else if profile.stats.range > 1000.0 {
399 QuantConfig::per_channel(0).with_observer(ObserverType::Percentile)
401 } else {
402 QuantConfig::int8().with_observer(ObserverType::Percentile)
404 };
405
406 if profile.stats.range > 10000.0 {
408 config = config.with_reduce_range(crate::config::ReduceRange::Reduce);
409 }
410
411 Ok(config)
412 }
413
414 fn select_balanced(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
416 if profile.numel > 100000 && profile.sparsity < 0.1 {
417 let group_size = if profile.stats.has_outliers { 32 } else { 64 };
419 Ok(QuantConfig::group_wise(0, group_size).with_observer(ObserverType::Histogram))
420 } else if profile.sparsity > 0.3 {
421 Ok(QuantConfig::int4().with_observer(ObserverType::MinMax))
423 } else {
424 Ok(QuantConfig::int8().with_observer(ObserverType::Histogram))
426 }
427 }
428
429 fn select_for_speed(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
431 let mut config = if profile.numel < 10000 {
433 QuantConfig::int8()
434 } else {
435 QuantConfig::int8().with_observer(ObserverType::MinMax) };
437
438 config = config.with_backend(QuantBackend::Fbgemm);
440
441 Ok(config)
442 }
443
444 fn select_for_memory(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
446 if profile.sparsity > 0.4 {
448 Ok(QuantConfig::binary())
449 } else if profile.numel > 50000 {
450 Ok(QuantConfig::int4())
451 } else {
452 Ok(QuantConfig::int8())
453 }
454 }
455
456 fn select_for_edge(&self, _profile: &TensorProfile) -> TorshResult<QuantConfig> {
458 Ok(QuantConfig::int8()
460 .with_backend(QuantBackend::Qnnpack)
461 .with_observer(ObserverType::MinMax))
462 }
463
464 fn generate_candidates(
466 &self,
467 profile: &TensorProfile,
468 constraints: Option<ConfigConstraints>,
469 ) -> TorshResult<Vec<(QuantConfig, f32)>> {
470 let mut candidates = vec![
471 (QuantConfig::int8(), 0.0),
472 (QuantConfig::int4(), 0.0),
473 (QuantConfig::per_channel(0), 0.0),
474 ];
475
476 if profile.sparsity > 0.3 {
478 candidates.push((QuantConfig::binary(), 0.0));
479 candidates.push((QuantConfig::ternary(), 0.0));
480 }
481
482 if profile.numel > 10000 {
483 candidates.push((QuantConfig::group_wise(0, 64), 0.0));
484 candidates.push((QuantConfig::group_wise(0, 32), 0.0));
485 }
486
487 if let Some(constraints) = constraints {
489 candidates.retain(|(config, _)| self.satisfies_constraints(config, &constraints));
490 }
491
492 Ok(candidates)
493 }
494
495 fn score_configuration(&self, config: &QuantConfig, profile: &TensorProfile) -> f32 {
497 let mut score = 0.0;
498
499 let scheme_score = self.score_scheme(config.scheme, profile);
501 score += scheme_score * self.feature_weights.distribution_weight;
502
503 let observer_score = self.score_observer(config.observer_type, profile);
505 score += observer_score * self.feature_weights.range_weight;
506
507 let backend_score = self.score_backend(config.backend, profile);
509 score += backend_score * 0.5;
510
511 let size_score = self.score_size(config.scheme, profile.numel);
513 score += size_score * self.feature_weights.size_weight;
514
515 score
516 }
517
518 fn score_scheme(&self, scheme: QScheme, _profile: &TensorProfile) -> f32 {
520 match (self.objective, scheme) {
521 (ConfigObjective::MaximumCompression, QScheme::Binary) => 10.0,
522 (ConfigObjective::MaximumCompression, QScheme::Ternary) => 9.0,
523 (ConfigObjective::MaximumCompression, QScheme::Int4PerTensor) => 8.0,
524 (ConfigObjective::MaximumAccuracy, QScheme::PerChannelAffine) => 10.0,
525 (ConfigObjective::MaximumAccuracy, QScheme::PerTensorAffine) => 8.5,
526 (ConfigObjective::MaximumSpeed, QScheme::PerTensorAffine) => 10.0,
527 (ConfigObjective::MaximumSpeed, QScheme::PerTensorSymmetric) => 9.5,
528 (ConfigObjective::BalancedQuality, QScheme::GroupWise) => 9.0,
529 (ConfigObjective::BalancedQuality, QScheme::PerTensorAffine) => 8.0,
530 _ => 5.0,
531 }
532 }
533
534 fn score_observer(&self, observer: ObserverType, profile: &TensorProfile) -> f32 {
536 match observer {
537 ObserverType::Histogram if profile.stats.has_outliers => 10.0,
538 ObserverType::Percentile
539 if profile.distribution == DistributionProfile::HeavyTailed =>
540 {
541 9.5
542 }
543 ObserverType::MinMax => 7.0, _ => 6.0,
545 }
546 }
547
548 fn score_backend(&self, backend: QuantBackend, _profile: &TensorProfile) -> f32 {
550 match (self.objective, backend) {
551 (ConfigObjective::MaximumSpeed, QuantBackend::Fbgemm) => 10.0,
552 (ConfigObjective::EdgeOptimized, QuantBackend::Qnnpack) => 10.0,
553 _ => 5.0,
554 }
555 }
556
557 fn score_size(&self, scheme: QScheme, numel: usize) -> f32 {
559 match scheme {
560 QScheme::GroupWise if numel > 100000 => 10.0,
561 QScheme::PerChannelAffine if numel > 10000 => 8.0,
562 QScheme::Binary if numel < 1000 => 3.0, _ => 5.0,
564 }
565 }
566
567 fn apply_constraints(
569 &self,
570 mut config: QuantConfig,
571 constraints: ConfigConstraints,
572 ) -> TorshResult<QuantConfig> {
573 if let Some(backend) = constraints.required_backend {
574 config = config.with_backend(backend);
575 }
576
577 if let Some(min_bits) = constraints.min_bits {
578 if min_bits >= 8
580 && matches!(
581 config.scheme,
582 QScheme::Int4PerTensor | QScheme::Binary | QScheme::Ternary
583 )
584 {
585 config = QuantConfig::int8();
586 }
587 }
588
589 Ok(config)
590 }
591
592 fn satisfies_constraints(&self, config: &QuantConfig, constraints: &ConfigConstraints) -> bool {
594 if let Some(backend) = constraints.required_backend {
595 if config.backend != backend {
596 return false;
597 }
598 }
599
600 if let Some(min_bits) = constraints.min_bits {
601 let scheme_bits = match config.scheme {
602 QScheme::Binary => 1,
603 QScheme::Ternary => 2,
604 QScheme::Int4PerTensor | QScheme::Int4PerChannel => 4,
605 _ => 8,
606 };
607 if scheme_bits < min_bits {
608 return false;
609 }
610 }
611
612 true
613 }
614
615 fn update_feature_weights(&mut self) {
617 if self.history.len() < 10 {
621 return;
622 }
623
624 let sparse_configs: Vec<&ConfigPerformance> = self
626 .history
627 .iter()
628 .filter(|p| p.profile.sparsity > 0.3)
629 .collect();
630
631 let dense_configs: Vec<&ConfigPerformance> = self
632 .history
633 .iter()
634 .filter(|p| p.profile.sparsity <= 0.3)
635 .collect();
636
637 if !sparse_configs.is_empty() {
639 let avg_sparse_error =
640 sparse_configs.iter().map(|p| p.error).sum::<f32>() / sparse_configs.len() as f32;
641 let avg_dense_error =
642 dense_configs.iter().map(|p| p.error).sum::<f32>() / dense_configs.len() as f32;
643
644 if avg_sparse_error < avg_dense_error {
645 self.feature_weights.sparsity_weight *= 1.1;
646 } else {
647 self.feature_weights.sparsity_weight *= 0.95;
648 }
649
650 self.feature_weights.sparsity_weight =
652 self.feature_weights.sparsity_weight.clamp(0.5, 2.0);
653 }
654 }
655}
656
/// Caller-supplied restrictions applied to recommendations. Unset fields
/// impose no restriction.
#[derive(Debug, Clone, Default)]
pub struct ConfigConstraints {
    /// Backend the configuration must target.
    pub required_backend: Option<QuantBackend>,
    /// Minimum bits per quantized value (>= 8 rules out sub-byte schemes).
    pub min_bits: Option<u32>,
    /// Maximum memory budget in bytes.
    /// NOTE(review): not yet enforced by `apply_constraints`/`satisfies_constraints`.
    pub max_memory: Option<usize>,
    /// Desired compression ratio.
    /// NOTE(review): not yet enforced by `apply_constraints`/`satisfies_constraints`.
    pub target_compression: Option<f32>,
}
669
670impl ConfigConstraints {
671 pub fn new() -> Self {
673 Self::default()
674 }
675
676 pub fn with_backend(mut self, backend: QuantBackend) -> Self {
678 self.required_backend = Some(backend);
679 self
680 }
681
682 pub fn with_min_bits(mut self, bits: u32) -> Self {
684 self.min_bits = Some(bits);
685 self
686 }
687
688 pub fn with_max_memory(mut self, bytes: usize) -> Self {
690 self.max_memory = Some(bytes);
691 self
692 }
693
694 pub fn with_target_compression(mut self, ratio: f32) -> Self {
696 self.target_compression = Some(ratio);
697 self
698 }
699}
700
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::tensor_1d;

    // The recommendation pipeline produces a configuration that passes the
    // crate's own validation for a simple dense tensor.
    #[test]
    fn test_auto_configurator_basic() {
        let configurator = AutoConfigurator::new(ConfigObjective::BalancedQuality);
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();

        let config = configurator.recommend(&tensor, None).unwrap();
        assert!(config.validate().is_ok());
    }

    // Profiling detects the single extreme value (100.0) via the IQR fence.
    #[test]
    fn test_tensor_profile_analysis() {
        let configurator = AutoConfigurator::new(ConfigObjective::MaximumAccuracy);
        // Nine small values plus one outlier.
        let data = vec![1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0, 100.0];
        let tensor = tensor_1d(&data).unwrap();

        let profile = configurator.analyze_tensor(&tensor).unwrap();
        assert!(
            profile.stats.has_outliers,
            "Expected outliers to be detected"
        );
        assert_eq!(profile.numel, 10);
        assert!(profile.stats.max > 90.0, "Max value should be around 100");
    }

    // A 75%-zero tensor under the compression objective should get an
    // extreme (binary or ternary) scheme.
    #[test]
    fn test_sparse_tensor_recommendation() {
        let configurator = AutoConfigurator::new(ConfigObjective::MaximumCompression);
        let data = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 2.0];
        let tensor = tensor_1d(&data).unwrap();

        let config = configurator.recommend(&tensor, None).unwrap();
        assert!(matches!(config.scheme, QScheme::Binary | QScheme::Ternary));
    }

    // A required-backend constraint overrides the objective's default
    // backend (speed would otherwise pick FBGEMM).
    #[test]
    fn test_constraints_application() {
        let configurator = AutoConfigurator::new(ConfigObjective::MaximumSpeed);
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = tensor_1d(&data).unwrap();

        let constraints = ConfigConstraints::new()
            .with_backend(QuantBackend::Qnnpack)
            .with_min_bits(8);

        let config = configurator.recommend(&tensor, Some(constraints)).unwrap();
        assert_eq!(config.backend, QuantBackend::Qnnpack);
    }

    // Ranked recommendations honor top_k and come back sorted best-first.
    #[test]
    fn test_ranked_recommendations() {
        let configurator = AutoConfigurator::new(ConfigObjective::BalancedQuality);
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = tensor_1d(&data).unwrap();

        let ranked = configurator.recommend_ranked(&tensor, 3, None).unwrap();
        assert_eq!(ranked.len(), 3);

        // Scores must be non-increasing.
        assert!(ranked[0].1 >= ranked[1].1);
        assert!(ranked[1].1 >= ranked[2].1);
    }

    // Feedback records are appended to the configurator's history.
    #[test]
    fn test_performance_update() {
        let mut configurator = AutoConfigurator::new(ConfigObjective::MaximumAccuracy);
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int8();

        configurator
            .update_performance(&config, &tensor, 0.1, 4.0, Some(1.5))
            .unwrap();

        assert_eq!(configurator.history.len(), 1);
    }

    // A tensor dominated by zeros is classified as Sparse; the dense case is
    // only checked for successful analysis (its class depends on moments).
    #[test]
    fn test_distribution_classification() {
        let configurator = AutoConfigurator::new(ConfigObjective::BalancedQuality);

        let normal_data = vec![1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0, 2.0];
        let tensor = tensor_1d(&normal_data).unwrap();
        let _profile = configurator.analyze_tensor(&tensor).unwrap();
        let sparse_data = vec![0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        let tensor = tensor_1d(&sparse_data).unwrap();
        let _profile = configurator.analyze_tensor(&tensor).unwrap();
        assert_eq!(_profile.distribution, DistributionProfile::Sparse);
    }

    // Every objective produces a configuration that validates.
    #[test]
    fn test_objective_specific_selection() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = tensor_1d(&data).unwrap();

        let objectives = vec![
            ConfigObjective::MaximumCompression,
            ConfigObjective::MaximumAccuracy,
            ConfigObjective::BalancedQuality,
            ConfigObjective::MaximumSpeed,
            ConfigObjective::MinimumMemory,
            ConfigObjective::EdgeOptimized,
        ];

        for objective in objectives {
            let configurator = AutoConfigurator::new(objective);
            let config = configurator.recommend(&tensor, None).unwrap();
            assert!(
                config.validate().is_ok(),
                "Failed for objective {:?}",
                objective
            );
        }
    }
}