Skip to main content

torsh_quantization/
auto_config.rs

1//! ML-Powered Auto-Configuration System
2//!
3//! This module provides intelligent quantization configuration recommendations
4//! based on tensor characteristics, performance metrics, and learned patterns.
5//!
6//! ## Features
7//!
8//! - **Tensor Analysis**: Analyzes tensor properties (shape, distribution, sparsity, etc.)
9//! - **Performance Prediction**: Estimates quantization quality and performance trade-offs
10//! - **Configuration Selection**: Automatically selects optimal quantization schemes
11//! - **Adaptive Recommendations**: Learns from historical quantization results
12//!
13//! ## Usage
14//!
15//! ```rust
16//! use torsh_quantization::auto_config::{AutoConfigurator, ConfigObjective};
17//! use torsh_tensor::creation::tensor_1d;
18//!
19//! // Create auto-configurator with specific objectives
20//! let configurator = AutoConfigurator::new(ConfigObjective::BalancedQuality);
21//!
22//! // Create a tensor to analyze
23//! let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
24//! let tensor = tensor_1d(&data).unwrap();
25//!
26//! // Get optimal configuration for a tensor
27//! let optimal_config = configurator.recommend(&tensor, None).unwrap();
28//! assert!(optimal_config.validate().is_ok());
29//! ```
30
31use crate::config::{ObserverType, QScheme, QuantBackend, QuantConfig};
32use torsh_core::{Result as TorshResult, TorshError};
33use torsh_tensor::Tensor;
34
/// Objectives for configuration selection
///
/// Passed to [`AutoConfigurator::new`] to steer which quantization
/// scheme/observer/backend combinations are preferred when recommending
/// a [`QuantConfig`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigObjective {
    /// Maximize compression ratio
    MaximumCompression,
    /// Maximize accuracy (minimize quantization error)
    MaximumAccuracy,
    /// Balance between compression and accuracy
    BalancedQuality,
    /// Optimize for inference speed
    MaximumSpeed,
    /// Optimize for memory efficiency
    MinimumMemory,
    /// Optimize for mobile/edge devices
    EdgeOptimized,
}
51
/// Tensor characteristics for ML-based analysis
///
/// Produced by `AutoConfigurator::analyze_tensor`; every recommendation
/// decision is driven by the fields collected here.
#[derive(Debug, Clone)]
pub struct TensorProfile {
    /// Shape dimensions
    pub shape: Vec<usize>,
    /// Total number of elements
    pub numel: usize,
    /// Data statistics (min/max/mean/std-dev, outlier flag, near-zero ratio)
    pub stats: TensorStats,
    /// Sparsity level (0.0 = dense, 1.0 = all zeros)
    pub sparsity: f32,
    /// Distribution characteristics
    pub distribution: DistributionProfile,
}
66
/// Statistical properties of tensor data
///
/// Computed in a single pass (plus a sort for quartiles) by
/// `AutoConfigurator::calculate_stats`.
#[derive(Debug, Clone)]
pub struct TensorStats {
    /// Minimum value
    pub min: f32,
    /// Maximum value
    pub max: f32,
    /// Mean value
    pub mean: f32,
    /// Standard deviation (population form, divides by n)
    pub std_dev: f32,
    /// Dynamic range (`max - min`)
    pub range: f32,
    /// Presence of outliers (1.5·IQR fence)
    pub has_outliers: bool,
    /// Percentage of near-zero values (within 1% of the range of zero)
    pub near_zero_ratio: f32,
}
85
/// Distribution profile for intelligent scheme selection
///
/// Coarse classification produced by `classify_distribution` from the
/// skewness/kurtosis of the data; used to bias scheme and observer choice.
#[derive(Debug, Clone, PartialEq)]
pub enum DistributionProfile {
    /// Normal/Gaussian distribution
    Normal,
    /// Uniform distribution
    Uniform,
    /// Heavy-tailed distribution (many outliers)
    HeavyTailed,
    /// Bimodal distribution
    Bimodal,
    /// Highly skewed distribution
    Skewed,
    /// Sparse distribution
    Sparse,
}
102
/// ML-powered auto-configurator
///
/// Recommends quantization configurations for tensors based on a chosen
/// [`ConfigObjective`], and can refine its internal feature weights from
/// observed quantization results (see `update_performance`).
pub struct AutoConfigurator {
    // Objective fixed at construction; drives all selection heuristics.
    objective: ConfigObjective,
    /// Historical performance data for learning
    history: Vec<ConfigPerformance>,
    /// Feature importance weights (learned from experience)
    feature_weights: FeatureWeights,
}
111
/// Performance metrics for a configuration
///
/// One history entry recorded by `AutoConfigurator::update_performance`;
/// consumed by the online weight-learning step (`update_feature_weights`).
#[derive(Debug, Clone)]
struct ConfigPerformance {
    /// Configuration that produced these metrics (retained for analysis)
    #[allow(dead_code)]
    config: QuantConfig,
    /// Profile of the tensor the configuration was applied to
    profile: TensorProfile,
    /// Observed quantization error
    error: f32,
    #[allow(dead_code)]
    /// Compression ratio achieved
    compression: f32,
    #[allow(dead_code)]
    /// Inference speedup (if measured)
    speedup: Option<f32>,
}
127
/// Learned feature importance weights
///
/// Multipliers applied in `score_configuration`; the sparsity weight is
/// adjusted online by `update_feature_weights`.
#[derive(Debug, Clone)]
struct FeatureWeights {
    /// Weight for data range consideration
    range_weight: f32,
    /// Weight for sparsity consideration
    sparsity_weight: f32,
    /// Weight for distribution type
    distribution_weight: f32,
    /// Weight for tensor size
    size_weight: f32,
}
140
141impl Default for FeatureWeights {
142    fn default() -> Self {
143        Self {
144            range_weight: 1.0,
145            sparsity_weight: 0.8,
146            distribution_weight: 0.9,
147            size_weight: 0.7,
148        }
149    }
150}
151
152impl AutoConfigurator {
153    /// Create a new auto-configurator with specified objective
154    pub fn new(objective: ConfigObjective) -> Self {
155        Self {
156            objective,
157            history: Vec::new(),
158            feature_weights: FeatureWeights::default(),
159        }
160    }
161
162    /// Recommend optimal configuration for a tensor
163    pub fn recommend(
164        &self,
165        tensor: &Tensor,
166        constraints: Option<ConfigConstraints>,
167    ) -> TorshResult<QuantConfig> {
168        // Analyze tensor characteristics
169        let profile = self.analyze_tensor(tensor)?;
170
171        // Select optimal configuration based on profile and objective
172        let config = self.select_configuration(&profile, constraints)?;
173
174        Ok(config)
175    }
176
177    /// Recommend multiple configurations ranked by expected performance
178    pub fn recommend_ranked(
179        &self,
180        tensor: &Tensor,
181        top_k: usize,
182        constraints: Option<ConfigConstraints>,
183    ) -> TorshResult<Vec<(QuantConfig, f32)>> {
184        let profile = self.analyze_tensor(tensor)?;
185        let mut candidates = self.generate_candidates(&profile, constraints)?;
186
187        // Score each candidate
188        for (config, score) in &mut candidates {
189            *score = self.score_configuration(config, &profile);
190        }
191
192        // Sort by score (descending)
193        candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
194
195        // Return top-k
196        candidates.truncate(top_k);
197        Ok(candidates)
198    }
199
200    /// Update the configurator with observed performance
201    pub fn update_performance(
202        &mut self,
203        config: &QuantConfig,
204        tensor: &Tensor,
205        observed_error: f32,
206        observed_compression: f32,
207        speedup: Option<f32>,
208    ) -> TorshResult<()> {
209        let profile = self.analyze_tensor(tensor)?;
210
211        let performance = ConfigPerformance {
212            config: config.clone(),
213            profile,
214            error: observed_error,
215            compression: observed_compression,
216            speedup,
217        };
218
219        self.history.push(performance);
220
221        // Update feature weights based on new data (simple online learning)
222        if self.history.len() >= 10 {
223            self.update_feature_weights();
224        }
225
226        Ok(())
227    }
228
229    // -------------------------------------------------------------------------
230    // Private helper methods
231    // -------------------------------------------------------------------------
232
233    /// Analyze tensor to extract characteristics
234    fn analyze_tensor(&self, tensor: &Tensor) -> TorshResult<TensorProfile> {
235        let data = tensor.data()?;
236        let shape = tensor.shape().dims().to_vec();
237        let numel = tensor.shape().numel();
238
239        // Calculate statistics
240        let stats = self.calculate_stats(&data)?;
241
242        // Calculate sparsity
243        let sparsity = self.calculate_sparsity(&data);
244
245        // Determine distribution profile
246        let distribution = self.classify_distribution(&data, &stats);
247
248        Ok(TensorProfile {
249            shape,
250            numel,
251            stats,
252            sparsity,
253            distribution,
254        })
255    }
256
257    /// Calculate statistical properties
258    fn calculate_stats(&self, data: &[f32]) -> TorshResult<TensorStats> {
259        if data.is_empty() {
260            return Err(TorshError::InvalidArgument(
261                "Cannot calculate stats for empty tensor".to_string(),
262            ));
263        }
264
265        let min = data.iter().copied().fold(f32::INFINITY, f32::min);
266        let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
267        let range = max - min;
268
269        let mean = data.iter().sum::<f32>() / data.len() as f32;
270
271        let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
272        let std_dev = variance.sqrt();
273
274        // Detect outliers using IQR method
275        let mut sorted = data.to_vec();
276        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
277
278        let q1_idx = sorted.len() / 4;
279        let q3_idx = 3 * sorted.len() / 4;
280        let q1 = sorted[q1_idx];
281        let q3 = sorted[q3_idx];
282        let iqr = q3 - q1;
283
284        let outlier_threshold_low = q1 - 1.5 * iqr;
285        let outlier_threshold_high = q3 + 1.5 * iqr;
286
287        let has_outliers = data
288            .iter()
289            .any(|&x| x < outlier_threshold_low || x > outlier_threshold_high);
290
291        // Calculate near-zero ratio
292        let zero_threshold = range.abs() * 0.01; // 1% of range
293        let near_zero_count = data.iter().filter(|&&x| x.abs() < zero_threshold).count();
294        let near_zero_ratio = near_zero_count as f32 / data.len() as f32;
295
296        Ok(TensorStats {
297            min,
298            max,
299            mean,
300            std_dev,
301            range,
302            has_outliers,
303            near_zero_ratio,
304        })
305    }
306
307    /// Calculate sparsity level
308    fn calculate_sparsity(&self, data: &[f32]) -> f32 {
309        let zero_count = data.iter().filter(|&&x| x.abs() < 1e-8).count();
310        zero_count as f32 / data.len() as f32
311    }
312
313    /// Classify distribution type
314    fn classify_distribution(&self, data: &[f32], stats: &TensorStats) -> DistributionProfile {
315        // Check for sparsity first
316        if stats.near_zero_ratio > 0.6 {
317            return DistributionProfile::Sparse;
318        }
319
320        // Calculate skewness
321        let skewness = data
322            .iter()
323            .map(|&x| ((x - stats.mean) / stats.std_dev).powi(3))
324            .sum::<f32>()
325            / data.len() as f32;
326
327        // Calculate kurtosis for tail heaviness
328        let kurtosis = data
329            .iter()
330            .map(|&x| ((x - stats.mean) / stats.std_dev).powi(4))
331            .sum::<f32>()
332            / data.len() as f32;
333
334        // Classification logic
335        if skewness.abs() > 1.0 {
336            DistributionProfile::Skewed
337        } else if kurtosis > 4.0 {
338            DistributionProfile::HeavyTailed
339        } else if (kurtosis - 3.0).abs() < 0.5 && skewness.abs() < 0.5 {
340            DistributionProfile::Normal
341        } else if kurtosis < 2.0 {
342            DistributionProfile::Uniform
343        } else {
344            DistributionProfile::Bimodal
345        }
346    }
347
348    /// Select optimal configuration based on profile
349    fn select_configuration(
350        &self,
351        profile: &TensorProfile,
352        constraints: Option<ConfigConstraints>,
353    ) -> TorshResult<QuantConfig> {
354        let mut config = match self.objective {
355            ConfigObjective::MaximumCompression => self.select_for_compression(profile),
356            ConfigObjective::MaximumAccuracy => self.select_for_accuracy(profile),
357            ConfigObjective::BalancedQuality => self.select_balanced(profile),
358            ConfigObjective::MaximumSpeed => self.select_for_speed(profile),
359            ConfigObjective::MinimumMemory => self.select_for_memory(profile),
360            ConfigObjective::EdgeOptimized => self.select_for_edge(profile),
361        }?;
362
363        // Apply constraints if provided
364        if let Some(constraints) = constraints {
365            config = self.apply_constraints(config, constraints)?;
366        }
367
368        Ok(config)
369    }
370
371    /// Select configuration optimized for compression
372    fn select_for_compression(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
373        // Use aggressive quantization
374        if profile.sparsity > 0.5 {
375            // Sparse data - use binary or ternary
376            if profile.distribution == DistributionProfile::Sparse {
377                Ok(QuantConfig::binary())
378            } else {
379                Ok(QuantConfig::ternary())
380            }
381        } else if profile.numel < 1000 {
382            // Small tensors - INT4 is good
383            Ok(QuantConfig::int4())
384        } else {
385            // Large tensors - group-wise INT4
386            let group_size = (profile.numel / 100).min(128).max(16);
387            Ok(QuantConfig::group_wise(0, group_size))
388        }
389    }
390
391    /// Select configuration optimized for accuracy
392    fn select_for_accuracy(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
393        let mut config = if profile.stats.has_outliers
394            || profile.distribution == DistributionProfile::HeavyTailed
395        {
396            // Use histogram observer for outliers
397            QuantConfig::int8().with_observer(ObserverType::Histogram)
398        } else if profile.stats.range > 1000.0 {
399            // Large range - use per-channel quantization
400            QuantConfig::per_channel(0).with_observer(ObserverType::Percentile)
401        } else {
402            // Standard case - per-tensor with percentile
403            QuantConfig::int8().with_observer(ObserverType::Percentile)
404        };
405
406        // Use reduced range for better numerical stability if needed
407        if profile.stats.range > 10000.0 {
408            config = config.with_reduce_range(crate::config::ReduceRange::Reduce);
409        }
410
411        Ok(config)
412    }
413
414    /// Select balanced configuration
415    fn select_balanced(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
416        if profile.numel > 100000 && profile.sparsity < 0.1 {
417            // Large, dense tensors - group-wise for balance
418            let group_size = if profile.stats.has_outliers { 32 } else { 64 };
419            Ok(QuantConfig::group_wise(0, group_size).with_observer(ObserverType::Histogram))
420        } else if profile.sparsity > 0.3 {
421            // Moderately sparse - INT4
422            Ok(QuantConfig::int4().with_observer(ObserverType::MinMax))
423        } else {
424            // Standard case - INT8 with histogram
425            Ok(QuantConfig::int8().with_observer(ObserverType::Histogram))
426        }
427    }
428
429    /// Select configuration optimized for speed
430    fn select_for_speed(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
431        // Prefer simpler schemes and backends
432        let mut config = if profile.numel < 10000 {
433            QuantConfig::int8()
434        } else {
435            QuantConfig::int8().with_observer(ObserverType::MinMax) // MinMax is fastest
436        };
437
438        // Use optimized backend
439        config = config.with_backend(QuantBackend::Fbgemm);
440
441        Ok(config)
442    }
443
444    /// Select configuration optimized for memory
445    fn select_for_memory(&self, profile: &TensorProfile) -> TorshResult<QuantConfig> {
446        // Similar to compression but with per-channel for better quality
447        if profile.sparsity > 0.4 {
448            Ok(QuantConfig::binary())
449        } else if profile.numel > 50000 {
450            Ok(QuantConfig::int4())
451        } else {
452            Ok(QuantConfig::int8())
453        }
454    }
455
456    /// Select configuration optimized for edge devices
457    fn select_for_edge(&self, _profile: &TensorProfile) -> TorshResult<QuantConfig> {
458        // Edge devices prefer simple, fast quantization
459        Ok(QuantConfig::int8()
460            .with_backend(QuantBackend::Qnnpack)
461            .with_observer(ObserverType::MinMax))
462    }
463
464    /// Generate candidate configurations
465    fn generate_candidates(
466        &self,
467        profile: &TensorProfile,
468        constraints: Option<ConfigConstraints>,
469    ) -> TorshResult<Vec<(QuantConfig, f32)>> {
470        let mut candidates = vec![
471            (QuantConfig::int8(), 0.0),
472            (QuantConfig::int4(), 0.0),
473            (QuantConfig::per_channel(0), 0.0),
474        ];
475
476        // Add specialized candidates based on profile
477        if profile.sparsity > 0.3 {
478            candidates.push((QuantConfig::binary(), 0.0));
479            candidates.push((QuantConfig::ternary(), 0.0));
480        }
481
482        if profile.numel > 10000 {
483            candidates.push((QuantConfig::group_wise(0, 64), 0.0));
484            candidates.push((QuantConfig::group_wise(0, 32), 0.0));
485        }
486
487        // Apply constraints
488        if let Some(constraints) = constraints {
489            candidates.retain(|(config, _)| self.satisfies_constraints(config, &constraints));
490        }
491
492        Ok(candidates)
493    }
494
495    /// Score a configuration for the current objective
496    fn score_configuration(&self, config: &QuantConfig, profile: &TensorProfile) -> f32 {
497        let mut score = 0.0;
498
499        // Base score from scheme
500        let scheme_score = self.score_scheme(config.scheme, profile);
501        score += scheme_score * self.feature_weights.distribution_weight;
502
503        // Score from observer type
504        let observer_score = self.score_observer(config.observer_type, profile);
505        score += observer_score * self.feature_weights.range_weight;
506
507        // Score from backend
508        let backend_score = self.score_backend(config.backend, profile);
509        score += backend_score * 0.5;
510
511        // Adjust based on tensor size
512        let size_score = self.score_size(config.scheme, profile.numel);
513        score += size_score * self.feature_weights.size_weight;
514
515        score
516    }
517
518    /// Score quantization scheme
519    fn score_scheme(&self, scheme: QScheme, _profile: &TensorProfile) -> f32 {
520        match (self.objective, scheme) {
521            (ConfigObjective::MaximumCompression, QScheme::Binary) => 10.0,
522            (ConfigObjective::MaximumCompression, QScheme::Ternary) => 9.0,
523            (ConfigObjective::MaximumCompression, QScheme::Int4PerTensor) => 8.0,
524            (ConfigObjective::MaximumAccuracy, QScheme::PerChannelAffine) => 10.0,
525            (ConfigObjective::MaximumAccuracy, QScheme::PerTensorAffine) => 8.5,
526            (ConfigObjective::MaximumSpeed, QScheme::PerTensorAffine) => 10.0,
527            (ConfigObjective::MaximumSpeed, QScheme::PerTensorSymmetric) => 9.5,
528            (ConfigObjective::BalancedQuality, QScheme::GroupWise) => 9.0,
529            (ConfigObjective::BalancedQuality, QScheme::PerTensorAffine) => 8.0,
530            _ => 5.0,
531        }
532    }
533
534    /// Score observer type
535    fn score_observer(&self, observer: ObserverType, profile: &TensorProfile) -> f32 {
536        match observer {
537            ObserverType::Histogram if profile.stats.has_outliers => 10.0,
538            ObserverType::Percentile
539                if profile.distribution == DistributionProfile::HeavyTailed =>
540            {
541                9.5
542            }
543            ObserverType::MinMax => 7.0, // Fast but less accurate
544            _ => 6.0,
545        }
546    }
547
548    /// Score backend
549    fn score_backend(&self, backend: QuantBackend, _profile: &TensorProfile) -> f32 {
550        match (self.objective, backend) {
551            (ConfigObjective::MaximumSpeed, QuantBackend::Fbgemm) => 10.0,
552            (ConfigObjective::EdgeOptimized, QuantBackend::Qnnpack) => 10.0,
553            _ => 5.0,
554        }
555    }
556
557    /// Score based on tensor size
558    fn score_size(&self, scheme: QScheme, numel: usize) -> f32 {
559        match scheme {
560            QScheme::GroupWise if numel > 100000 => 10.0,
561            QScheme::PerChannelAffine if numel > 10000 => 8.0,
562            QScheme::Binary if numel < 1000 => 3.0, // Binary not great for small tensors
563            _ => 5.0,
564        }
565    }
566
567    /// Apply constraints to configuration
568    fn apply_constraints(
569        &self,
570        mut config: QuantConfig,
571        constraints: ConfigConstraints,
572    ) -> TorshResult<QuantConfig> {
573        if let Some(backend) = constraints.required_backend {
574            config = config.with_backend(backend);
575        }
576
577        if let Some(min_bits) = constraints.min_bits {
578            // Ensure scheme uses at least min_bits
579            if min_bits >= 8
580                && matches!(
581                    config.scheme,
582                    QScheme::Int4PerTensor | QScheme::Binary | QScheme::Ternary
583                )
584            {
585                config = QuantConfig::int8();
586            }
587        }
588
589        Ok(config)
590    }
591
592    /// Check if configuration satisfies constraints
593    fn satisfies_constraints(&self, config: &QuantConfig, constraints: &ConfigConstraints) -> bool {
594        if let Some(backend) = constraints.required_backend {
595            if config.backend != backend {
596                return false;
597            }
598        }
599
600        if let Some(min_bits) = constraints.min_bits {
601            let scheme_bits = match config.scheme {
602                QScheme::Binary => 1,
603                QScheme::Ternary => 2,
604                QScheme::Int4PerTensor | QScheme::Int4PerChannel => 4,
605                _ => 8,
606            };
607            if scheme_bits < min_bits {
608                return false;
609            }
610        }
611
612        true
613    }
614
615    /// Update feature weights based on historical performance
616    fn update_feature_weights(&mut self) {
617        // Simple online learning: boost weights for features that correlate with good performance
618        // This is a simplified version - production would use more sophisticated ML
619
620        if self.history.len() < 10 {
621            return;
622        }
623
624        // Calculate average error for different feature combinations
625        let sparse_configs: Vec<&ConfigPerformance> = self
626            .history
627            .iter()
628            .filter(|p| p.profile.sparsity > 0.3)
629            .collect();
630
631        let dense_configs: Vec<&ConfigPerformance> = self
632            .history
633            .iter()
634            .filter(|p| p.profile.sparsity <= 0.3)
635            .collect();
636
637        // Adjust sparsity weight based on performance
638        if !sparse_configs.is_empty() {
639            let avg_sparse_error =
640                sparse_configs.iter().map(|p| p.error).sum::<f32>() / sparse_configs.len() as f32;
641            let avg_dense_error =
642                dense_configs.iter().map(|p| p.error).sum::<f32>() / dense_configs.len() as f32;
643
644            if avg_sparse_error < avg_dense_error {
645                self.feature_weights.sparsity_weight *= 1.1;
646            } else {
647                self.feature_weights.sparsity_weight *= 0.95;
648            }
649
650            // Keep weights in reasonable range
651            self.feature_weights.sparsity_weight =
652                self.feature_weights.sparsity_weight.clamp(0.5, 2.0);
653        }
654    }
655}
656
/// Constraints for configuration selection
///
/// Optional hard requirements layered on top of the objective-driven
/// recommendation; unset fields impose no restriction. Note: `max_memory`
/// and `target_compression` are stored but not yet enforced by the
/// selection logic in this module.
#[derive(Debug, Clone, Default)]
pub struct ConfigConstraints {
    /// Required backend (if any)
    pub required_backend: Option<QuantBackend>,
    /// Minimum number of quantization bits
    pub min_bits: Option<u32>,
    /// Maximum memory usage (bytes)
    pub max_memory: Option<usize>,
    /// Target compression ratio
    pub target_compression: Option<f32>,
}
669
670impl ConfigConstraints {
671    /// Create new constraints
672    pub fn new() -> Self {
673        Self::default()
674    }
675
676    /// Set required backend
677    pub fn with_backend(mut self, backend: QuantBackend) -> Self {
678        self.required_backend = Some(backend);
679        self
680    }
681
682    /// Set minimum bits
683    pub fn with_min_bits(mut self, bits: u32) -> Self {
684        self.min_bits = Some(bits);
685        self
686    }
687
688    /// Set maximum memory
689    pub fn with_max_memory(mut self, bytes: usize) -> Self {
690        self.max_memory = Some(bytes);
691        self
692    }
693
694    /// Set target compression ratio
695    pub fn with_target_compression(mut self, ratio: f32) -> Self {
696        self.target_compression = Some(ratio);
697        self
698    }
699}
700
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::tensor_1d;

    /// Smoke test: the balanced objective must yield a valid config.
    #[test]
    fn test_auto_configurator_basic() {
        let configurator = AutoConfigurator::new(ConfigObjective::BalancedQuality);
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = tensor_1d(&data).unwrap();

        let config = configurator.recommend(&tensor, None).unwrap();
        assert!(config.validate().is_ok());
    }

    /// Profiling must detect the injected outlier and basic stats.
    #[test]
    fn test_tensor_profile_analysis() {
        let configurator = AutoConfigurator::new(ConfigObjective::MaximumAccuracy);
        // Create data with more values to make outlier detection more reliable
        let data = vec![1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0, 2.0, 1.0, 100.0]; // Has outlier
        let tensor = tensor_1d(&data).unwrap();

        let profile = configurator.analyze_tensor(&tensor).unwrap();
        assert!(
            profile.stats.has_outliers,
            "Expected outliers to be detected"
        );
        assert_eq!(profile.numel, 10);
        assert!(profile.stats.max > 90.0, "Max value should be around 100");
    }

    /// Mostly-zero input under MaximumCompression should get a low-bit scheme.
    #[test]
    fn test_sparse_tensor_recommendation() {
        let configurator = AutoConfigurator::new(ConfigObjective::MaximumCompression);
        let data = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 2.0];
        let tensor = tensor_1d(&data).unwrap();

        let config = configurator.recommend(&tensor, None).unwrap();
        // Should recommend binary or ternary for sparse data
        assert!(matches!(config.scheme, QScheme::Binary | QScheme::Ternary));
    }

    /// A required backend constraint must be honored in the result.
    #[test]
    fn test_constraints_application() {
        let configurator = AutoConfigurator::new(ConfigObjective::MaximumSpeed);
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = tensor_1d(&data).unwrap();

        let constraints = ConfigConstraints::new()
            .with_backend(QuantBackend::Qnnpack)
            .with_min_bits(8);

        let config = configurator.recommend(&tensor, Some(constraints)).unwrap();
        assert_eq!(config.backend, QuantBackend::Qnnpack);
    }

    /// Ranked output must respect `top_k` and be sorted descending by score.
    #[test]
    fn test_ranked_recommendations() {
        let configurator = AutoConfigurator::new(ConfigObjective::BalancedQuality);
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = tensor_1d(&data).unwrap();

        let ranked = configurator.recommend_ranked(&tensor, 3, None).unwrap();
        assert_eq!(ranked.len(), 3);

        // Scores should be descending
        assert!(ranked[0].1 >= ranked[1].1);
        assert!(ranked[1].1 >= ranked[2].1);
    }

    /// Recording a performance sample must grow the history.
    #[test]
    fn test_performance_update() {
        let mut configurator = AutoConfigurator::new(ConfigObjective::MaximumAccuracy);
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int8();

        configurator
            .update_performance(&config, &tensor, 0.1, 4.0, Some(1.5))
            .unwrap();

        assert_eq!(configurator.history.len(), 1);
    }

    /// Sparse data must classify as `DistributionProfile::Sparse`.
    #[test]
    fn test_distribution_classification() {
        let configurator = AutoConfigurator::new(ConfigObjective::BalancedQuality);

        // Normal distribution
        let normal_data = vec![1.0, 2.0, 3.0, 2.0, 1.0, 2.0, 3.0, 2.0];
        let tensor = tensor_1d(&normal_data).unwrap();
        let _profile = configurator.analyze_tensor(&tensor).unwrap();
        // Distribution classification depends on stats

        // Sparse distribution
        let sparse_data = vec![0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        let tensor = tensor_1d(&sparse_data).unwrap();
        let _profile = configurator.analyze_tensor(&tensor).unwrap();
        assert_eq!(_profile.distribution, DistributionProfile::Sparse);
    }

    /// Every objective must produce a config that validates.
    #[test]
    fn test_objective_specific_selection() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let tensor = tensor_1d(&data).unwrap();

        // Test each objective
        let objectives = vec![
            ConfigObjective::MaximumCompression,
            ConfigObjective::MaximumAccuracy,
            ConfigObjective::BalancedQuality,
            ConfigObjective::MaximumSpeed,
            ConfigObjective::MinimumMemory,
            ConfigObjective::EdgeOptimized,
        ];

        for objective in objectives {
            let configurator = AutoConfigurator::new(objective);
            let config = configurator.recommend(&tensor, None).unwrap();
            assert!(
                config.validate().is_ok(),
                "Failed for objective {:?}",
                objective
            );
        }
    }
}