sklears_model_selection/model_complexity.rs

//! Model complexity analysis and overfitting detection
//!
//! This module provides tools for analyzing model complexity and detecting overfitting.
//! It includes various complexity measures, overfitting detection strategies, and
//! methods for optimal model selection based on complexity-performance trade-offs.

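//! # Examples
//!
//! A minimal sketch of the intended workflow (`estimator` here is a stand-in
//! for any type implementing [`Fit`] whose fitted model implements
//! [`Predict`]):
//!
//! ```ignore
//! let analyzer = ModelComplexityAnalyzer::new()
//!     .overfitting_threshold(0.05)
//!     .use_cross_validation(false);
//! let result = analyzer.analyze(&estimator, &x_train, &y_train, &x_val, &y_val)?;
//! if result.overfitting_detected {
//!     println!("{}", result.recommendation);
//! }
//! ```
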
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Predict},
};
use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};

/// Result of model complexity analysis
#[derive(Debug, Clone)]
pub struct ComplexityAnalysisResult {
    /// Training error
    pub train_error: f64,
    /// Validation error
    pub validation_error: f64,
    /// Estimated model complexity
    pub complexity_score: f64,
    /// Overfitting indicator (0 = no overfitting, 1 = severe overfitting)
    pub overfitting_score: f64,
    /// Generalization gap (validation_error - train_error)
    pub generalization_gap: f64,
    /// Complexity measures used
    pub complexity_measures: HashMap<String, f64>,
    /// Whether overfitting is detected
    pub overfitting_detected: bool,
    /// Recommended action
    pub recommendation: ComplexityRecommendation,
}

/// Recommendations based on complexity analysis
#[derive(Debug, Clone, PartialEq)]
pub enum ComplexityRecommendation {
    /// Model is appropriate
    Appropriate,
    /// Model is too simple (underfitting)
    IncreaseComplexity,
    /// Model is too complex (overfitting)
    ReduceComplexity,
    /// Use regularization
    UseRegularization,
    /// Collect more data
    CollectMoreData,
    /// Try ensemble methods
    TryEnsembles,
}

impl Display for ComplexityRecommendation {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let msg = match self {
            ComplexityRecommendation::Appropriate => "Model complexity is appropriate",
            ComplexityRecommendation::IncreaseComplexity => "Consider increasing model complexity",
            ComplexityRecommendation::ReduceComplexity => "Consider reducing model complexity",
            ComplexityRecommendation::UseRegularization => "Consider using regularization",
            ComplexityRecommendation::CollectMoreData => "Consider collecting more training data",
            ComplexityRecommendation::TryEnsembles => "Consider using ensemble methods",
        };
        write!(f, "{}", msg)
    }
}

impl Display for ComplexityAnalysisResult {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Model Complexity Analysis:\n\
             Train Error: {:.6}\n\
             Validation Error: {:.6}\n\
             Generalization Gap: {:.6}\n\
             Complexity Score: {:.6}\n\
             Overfitting Score: {:.6}\n\
             Overfitting Detected: {}\n\
             Recommendation: {}",
            self.train_error,
            self.validation_error,
            self.generalization_gap,
            self.complexity_score,
            self.overfitting_score,
            self.overfitting_detected,
            self.recommendation
        )
    }
}

/// Configuration for complexity analysis
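///
/// A config can also be assembled as a struct literal over the defaults
/// (a sketch; the field values are illustrative):
///
/// ```ignore
/// let config = ComplexityAnalysisConfig {
///     overfitting_threshold: 0.05,
///     cv_folds: 10,
///     ..Default::default()
/// };
/// let analyzer = ModelComplexityAnalyzer::with_config(config);
/// ```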
#[derive(Debug, Clone)]
pub struct ComplexityAnalysisConfig {
    /// Threshold for overfitting detection (generalization gap)
    pub overfitting_threshold: f64,
    /// Threshold for underfitting detection (high training error)
    pub underfitting_threshold: f64,
    /// Weight for training set size in complexity calculation
    pub data_size_weight: f64,
    /// Whether to include information-theoretic measures
    pub include_information_measures: bool,
    /// Whether to perform cross-validation for robustness
    pub use_cross_validation: bool,
    /// Number of CV folds if using cross-validation
    pub cv_folds: usize,
}

impl Default for ComplexityAnalysisConfig {
    fn default() -> Self {
        Self {
            overfitting_threshold: 0.1,
            underfitting_threshold: 0.3,
            data_size_weight: 0.1,
            include_information_measures: true,
            use_cross_validation: false,
            cv_folds: 5,
        }
    }
}

/// Complexity measures for different types of models
#[derive(Debug, Clone)]
pub enum ComplexityMeasure {
    /// Number of parameters
    ParameterCount,
    /// Effective degrees of freedom
    DegreesOfFreedom,
    /// VC dimension estimate
    VCDimension,
    /// Rademacher complexity
    RademacherComplexity,
    /// Path length (for tree models)
    PathLength,
    /// Number of support vectors (for SVM)
    SupportVectorCount,
    /// Spectral complexity (for neural networks)
    SpectralComplexity,
}

/// Model complexity analyzer
pub struct ModelComplexityAnalyzer {
    config: ComplexityAnalysisConfig,
}

impl ModelComplexityAnalyzer {
    /// Create a new complexity analyzer with default configuration
    pub fn new() -> Self {
        Self {
            config: ComplexityAnalysisConfig::default(),
        }
    }

    /// Create a new complexity analyzer with custom configuration
    pub fn with_config(config: ComplexityAnalysisConfig) -> Self {
        Self { config }
    }

    /// Set overfitting threshold
    pub fn overfitting_threshold(mut self, threshold: f64) -> Self {
        self.config.overfitting_threshold = threshold;
        self
    }

    /// Set underfitting threshold
    pub fn underfitting_threshold(mut self, threshold: f64) -> Self {
        self.config.underfitting_threshold = threshold;
        self
    }

    /// Enable or disable cross-validation
    pub fn use_cross_validation(mut self, use_cv: bool) -> Self {
        self.config.use_cross_validation = use_cv;
        self
    }

    /// Set number of CV folds
    pub fn cv_folds(mut self, folds: usize) -> Self {
        self.config.cv_folds = folds;
        self
    }

    /// Analyze model complexity
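    ///
    /// Fits a clone of `estimator` on the training split, computes MSE on both
    /// splits, and derives complexity and overfitting indicators from the gap.
    ///
    /// A minimal sketch of a call (the estimator and data are placeholders):
    ///
    /// ```ignore
    /// let result = ModelComplexityAnalyzer::new()
    ///     .analyze(&estimator, &x_train, &y_train, &x_val, &y_val)?;
    /// println!("gap = {:.4}", result.generalization_gap);
    /// ```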
    pub fn analyze<E, X, Y>(
        &self,
        estimator: &E,
        x_train: &[X],
        y_train: &[Y],
        x_val: &[X],
        y_val: &[Y],
    ) -> Result<ComplexityAnalysisResult>
    where
        E: Estimator + Fit<Vec<X>, Vec<Y>> + Clone,
        E::Fitted: Predict<Vec<X>, Vec<f64>>,
        X: Clone,
        Y: Clone + Into<f64>,
    {
        // Train the model
        let x_train_vec = x_train.to_vec();
        let y_train_vec = y_train.to_vec();
        let trained_model = estimator.clone().fit(&x_train_vec, &y_train_vec)?;

        // Calculate training error
        let train_predictions = trained_model.predict(&x_train_vec)?;
        let train_targets: Vec<f64> = y_train.iter().map(|y| y.clone().into()).collect();
        let train_error = self.calculate_error(&train_predictions, &train_targets);

        // Calculate validation error
        let x_val_vec = x_val.to_vec();
        let val_predictions = trained_model.predict(&x_val_vec)?;
        let val_targets: Vec<f64> = y_val.iter().map(|y| y.clone().into()).collect();
        let validation_error = self.calculate_error(&val_predictions, &val_targets);

        // Calculate generalization gap
        let generalization_gap = validation_error - train_error;

        // Estimate model complexity
        let complexity_measures = self.estimate_complexity(x_train, y_train, &trained_model)?;
        let complexity_score = self.aggregate_complexity(&complexity_measures);

        // Calculate overfitting score
        let overfitting_score = self.calculate_overfitting_score(
            train_error,
            validation_error,
            complexity_score,
            x_train.len(),
        );

        // Detect overfitting
        let overfitting_detected = generalization_gap > self.config.overfitting_threshold;

        // Generate recommendation
        let recommendation = self.generate_recommendation(
            train_error,
            validation_error,
            generalization_gap,
            complexity_score,
            overfitting_detected,
        );

        Ok(ComplexityAnalysisResult {
            train_error,
            validation_error,
            complexity_score,
            overfitting_score,
            generalization_gap,
            complexity_measures,
            overfitting_detected,
            recommendation,
        })
    }

    /// Calculate prediction error (MSE for regression)
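    ///
    /// MSE = (1/n) * Σ (prediction_i - target_i)²; mismatched input lengths
    /// return `f64::INFINITY` so they register as a worst-case error.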
    fn calculate_error(&self, predictions: &[f64], targets: &[f64]) -> f64 {
        if predictions.len() != targets.len() {
            return f64::INFINITY;
        }

        predictions
            .iter()
            .zip(targets.iter())
            .map(|(&pred, &target)| (pred - target).powi(2))
            .sum::<f64>()
            / predictions.len() as f64
    }

    /// Estimate model complexity using various measures
    fn estimate_complexity<X, Y>(
        &self,
        x_train: &[X],
        y_train: &[Y],
        _trained_model: &impl Predict<Vec<X>, Vec<f64>>,
    ) -> Result<HashMap<String, f64>>
    where
        X: Clone,
        Y: Clone + Into<f64>,
    {
        let mut measures = HashMap::new();

        // Basic complexity based on training set size
        let n_samples = x_train.len() as f64;
        let n_features = self.estimate_feature_count(x_train);

        measures.insert("training_set_size".to_string(), n_samples);
        measures.insert("feature_count".to_string(), n_features);

        // Parameter count estimation (approximate)
        let param_count = self.estimate_parameter_count(n_features);
        measures.insert("estimated_parameters".to_string(), param_count);

        // Data-dependent complexity measures
        let data_complexity = self.calculate_data_complexity(x_train, y_train);
        measures.insert("data_complexity".to_string(), data_complexity);

        // Effective degrees of freedom estimation
        let eff_dof = self.estimate_effective_dof(n_samples, param_count);
        measures.insert("effective_dof".to_string(), eff_dof);

        Ok(measures)
    }

    /// Estimate number of features (simplified)
    fn estimate_feature_count<X>(&self, _x_train: &[X]) -> f64 {
        // This is a simplified estimation - in practice, this would depend on the actual data type
        // For now, we'll use a default estimate
        10.0 // Placeholder
    }

    /// Estimate parameter count based on features
    fn estimate_parameter_count(&self, n_features: f64) -> f64 {
        // Simple linear model assumption: one parameter per feature plus intercept
        n_features + 1.0
    }

    /// Calculate data complexity (standard deviation of targets)
    fn calculate_data_complexity<X, Y>(&self, _x_train: &[X], y_train: &[Y]) -> f64
    where
        Y: Clone + Into<f64>,
    {
        let targets: Vec<f64> = y_train.iter().map(|y| y.clone().into()).collect();
        if targets.is_empty() {
            return 0.0;
        }

        let mean = targets.iter().sum::<f64>() / targets.len() as f64;
        let variance =
            targets.iter().map(|&y| (y - mean).powi(2)).sum::<f64>() / targets.len() as f64;

        variance.sqrt()
    }

    /// Estimate effective degrees of freedom
    fn estimate_effective_dof(&self, n_samples: f64, param_count: f64) -> f64 {
        // Simple heuristic: effective DOF is limited by sample size
        param_count.min(n_samples * 0.1)
    }

    /// Aggregate complexity measures into a single score
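    ///
    /// Uses fixed weights: estimated parameters (0.4), effective degrees of
    /// freedom (0.3), data complexity (0.2), and an inverse sample-size term
    /// (0.1); the score is the weighted mean over whichever measures are present.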
    fn aggregate_complexity(&self, measures: &HashMap<String, f64>) -> f64 {
        let mut score = 0.0;
        let mut weight_sum = 0.0;

        // Weight different complexity measures
        if let Some(&param_count) = measures.get("estimated_parameters") {
            score += param_count * 0.4;
            weight_sum += 0.4;
        }

        if let Some(&eff_dof) = measures.get("effective_dof") {
            score += eff_dof * 0.3;
            weight_sum += 0.3;
        }

        if let Some(&data_complexity) = measures.get("data_complexity") {
            score += data_complexity * 0.2;
            weight_sum += 0.2;
        }

        if let Some(&n_samples) = measures.get("training_set_size") {
            // Complexity decreases with more data
            score += (1.0 / (n_samples + 1.0)) * 100.0 * 0.1;
            weight_sum += 0.1;
        }

        if weight_sum > 0.0 {
            score / weight_sum
        } else {
            0.0
        }
    }

    /// Calculate overfitting score
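    ///
    /// score = clamp(relative_gap · (1 + 0.1 · complexity) · (1 + 1/√n), 0, 1),
    /// where relative_gap is the generalization gap divided by the training
    /// error whenever that error is positive.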
    fn calculate_overfitting_score(
        &self,
        train_error: f64,
        validation_error: f64,
        complexity_score: f64,
        n_samples: usize,
    ) -> f64 {
        // Overfitting score combines generalization gap with complexity
        let generalization_gap = validation_error - train_error;
        let relative_gap = if train_error > 0.0 {
            generalization_gap / train_error
        } else {
            generalization_gap
        };

        // Adjust for sample size (small datasets are more prone to overfitting)
        let size_factor = 1.0 / (n_samples as f64).sqrt();

        // Combine factors
        let overfitting_score = relative_gap * (1.0 + complexity_score * 0.1) * (1.0 + size_factor);

        // Clamp to [0, 1]
        overfitting_score.clamp(0.0, 1.0)
    }

    /// Generate recommendation based on analysis
    fn generate_recommendation(
        &self,
        train_error: f64,
        validation_error: f64,
        generalization_gap: f64,
        complexity_score: f64,
        overfitting_detected: bool,
    ) -> ComplexityRecommendation {
        // High training error suggests underfitting
        if train_error > self.config.underfitting_threshold {
            return ComplexityRecommendation::IncreaseComplexity;
        }

        // Overfitting detected
        if overfitting_detected {
            if complexity_score > 10.0 {
                ComplexityRecommendation::ReduceComplexity
            } else {
                ComplexityRecommendation::UseRegularization
            }
        } else if generalization_gap > 0.05
            && generalization_gap <= self.config.overfitting_threshold
        {
            // Mild overfitting
            if validation_error > train_error * 1.5 {
                ComplexityRecommendation::CollectMoreData
            } else {
                ComplexityRecommendation::UseRegularization
            }
        } else if train_error > 0.1 && validation_error > 0.1 {
            // Both errors are high
            ComplexityRecommendation::TryEnsembles
        } else {
            ComplexityRecommendation::Appropriate
        }
    }
}

impl Default for ModelComplexityAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}

/// Overfitting detector based on learning and validation curves
pub struct OverfittingDetector {
    config: ComplexityAnalysisConfig,
}

impl OverfittingDetector {
    /// Create a new overfitting detector
    pub fn new() -> Self {
        Self {
            config: ComplexityAnalysisConfig::default(),
        }
    }

    /// Create with custom configuration
    pub fn with_config(config: ComplexityAnalysisConfig) -> Self {
        Self { config }
    }

    /// Detect overfitting using learning curves (scores are error values; lower is better)
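    ///
    /// A sketch with error scores recorded at growing training-set sizes:
    ///
    /// ```ignore
    /// let detector = OverfittingDetector::new();
    /// let overfit = detector.detect_from_learning_curve(
    ///     &[10, 20, 30],
    ///     &[0.5, 0.3, 0.2],  // training error keeps falling
    ///     &[0.6, 0.4, 0.45], // validation error turns back up
    /// )?;
    /// ```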
    pub fn detect_from_learning_curve(
        &self,
        train_sizes: &[usize],
        train_scores: &[f64],
        val_scores: &[f64],
    ) -> Result<bool> {
        if train_sizes.len() != train_scores.len() || train_scores.len() != val_scores.len() {
            return Err(SklearsError::InvalidParameter {
                name: "arrays".to_string(),
                reason: "array lengths must match".to_string(),
            });
        }

        if train_sizes.is_empty() {
            return Err(SklearsError::InvalidParameter {
                name: "arrays".to_string(),
                reason: "arrays cannot be empty".to_string(),
            });
        }

        // Look for diverging learning curves
        let mut divergence_count = 0;
        for i in 1..train_scores.len() {
            let train_improvement = train_scores[i - 1] - train_scores[i];
            let val_improvement = val_scores[i - 1] - val_scores[i];

            // If training error improves but validation error doesn't (or gets worse)
            if train_improvement > 0.01 && val_improvement < 0.01 {
                divergence_count += 1;
            }
        }

        // Overfitting if curves diverge in more than half the steps
        Ok(divergence_count > train_scores.len() / 2)
    }

    /// Detect overfitting using validation curves (error scores across increasing model complexity)
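    ///
    /// Returns whether overfitting was detected together with the parameter
    /// value that minimizes validation error. A sketch over four settings of
    /// a regularization-style parameter:
    ///
    /// ```ignore
    /// let (overfit, best_param) = OverfittingDetector::new()
    ///     .detect_from_validation_curve(
    ///         &[0.1, 0.5, 1.0, 2.0],
    ///         &[0.5, 0.3, 0.2, 0.1],
    ///         &[0.6, 0.4, 0.35, 0.5],
    ///     )?;
    /// ```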
    pub fn detect_from_validation_curve(
        &self,
        param_values: &[f64],
        train_scores: &[f64],
        val_scores: &[f64],
    ) -> Result<(bool, Option<f64>)> {
        if param_values.len() != train_scores.len() || train_scores.len() != val_scores.len() {
            return Err(SklearsError::InvalidParameter {
                name: "arrays".to_string(),
                reason: "array lengths must match".to_string(),
            });
        }

        if param_values.is_empty() {
            return Err(SklearsError::InvalidParameter {
                name: "arrays".to_string(),
                reason: "arrays cannot be empty".to_string(),
            });
        }

        // Find optimal parameter value (minimum validation error);
        // total_cmp gives a NaN-safe total ordering
        let min_val_idx = val_scores
            .iter()
            .enumerate()
            .min_by(|a, b| a.1.total_cmp(b.1))
            .map(|(idx, _)| idx)
            .expect("val_scores is non-empty");

        let optimal_param = param_values[min_val_idx];
        let min_val_score = val_scores[min_val_idx];

        // Check whether validation error rises again past the optimum
        // (assumes param_values are ordered by increasing complexity)
        let mut overfitting_detected = false;
        for i in (min_val_idx + 1)..val_scores.len() {
            if val_scores[i] > min_val_score + self.config.overfitting_threshold {
                overfitting_detected = true;
                break;
            }
        }

        Ok((overfitting_detected, Some(optimal_param)))
    }
}

impl Default for OverfittingDetector {
    fn default() -> Self {
        Self::new()
    }
}

/// Convenience function for analyzing model complexity
pub fn analyze_model_complexity<E, X, Y>(
    estimator: &E,
    x_train: &[X],
    y_train: &[Y],
    x_val: &[X],
    y_val: &[Y],
) -> Result<ComplexityAnalysisResult>
where
    E: Estimator + Fit<Vec<X>, Vec<Y>> + Clone,
    E::Fitted: Predict<Vec<X>, Vec<f64>>,
    X: Clone,
    Y: Clone + Into<f64>,
{
    let analyzer = ModelComplexityAnalyzer::new();
    analyzer.analyze(estimator, x_train, y_train, x_val, y_val)
}

/// Convenience function for detecting overfitting from learning curves
pub fn detect_overfitting_learning_curve(
    train_sizes: &[usize],
    train_scores: &[f64],
    val_scores: &[f64],
) -> Result<bool> {
    let detector = OverfittingDetector::new();
    detector.detect_from_learning_curve(train_sizes, train_scores, val_scores)
}

#[cfg(test)]
mod tests {
    use super::*;

    // Mock estimator for testing
    #[derive(Clone)]
    struct MockEstimator;

    struct MockTrained;

    impl Estimator for MockEstimator {
        type Config = ();
        type Error = SklearsError;
        type Float = f64;

        fn config(&self) -> &Self::Config {
            &()
        }
    }

    impl Fit<Vec<f64>, Vec<f64>> for MockEstimator {
        type Fitted = MockTrained;

        fn fit(self, _x: &Vec<f64>, _y: &Vec<f64>) -> Result<Self::Fitted> {
            Ok(MockTrained)
        }
    }

    impl Predict<Vec<f64>, Vec<f64>> for MockTrained {
        fn predict(&self, x: &Vec<f64>) -> Result<Vec<f64>> {
            // Deterministic linear prediction with a constant bias so the errors are nonzero
            Ok(x.iter().map(|&xi| xi * 0.5 + 0.1).collect())
        }
    }

    #[test]
    fn test_complexity_analyzer_creation() {
        let analyzer = ModelComplexityAnalyzer::new();
        assert_eq!(analyzer.config.overfitting_threshold, 0.1);
        assert_eq!(analyzer.config.underfitting_threshold, 0.3);
    }

    #[test]
    fn test_complexity_analysis() {
        let estimator = MockEstimator;
        let x_train: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
        let y_train: Vec<f64> = x_train.iter().map(|&x| x * 0.5).collect();
        let x_val: Vec<f64> = (0..20).map(|i| i as f64 * 0.1 + 10.0).collect();
        let y_val: Vec<f64> = x_val.iter().map(|&x| x * 0.5).collect();

        let analyzer = ModelComplexityAnalyzer::new();
        let result = analyzer.analyze(&estimator, &x_train, &y_train, &x_val, &y_val);

        assert!(result.is_ok());
        let result = result.unwrap();
        assert!(result.train_error >= 0.0);
        assert!(result.validation_error >= 0.0);
        assert!(result.complexity_score >= 0.0);
        assert!(result.overfitting_score >= 0.0 && result.overfitting_score <= 1.0);
    }

    #[test]
    fn test_overfitting_detector() {
        let detector = OverfittingDetector::new();

        // Test learning curve overfitting detection
        let train_sizes = vec![10, 20, 30, 40, 50];
        let train_scores = vec![0.5, 0.3, 0.2, 0.1, 0.05]; // Improving
        let val_scores = vec![0.6, 0.4, 0.4, 0.45, 0.5]; // Getting worse

        let result = detector.detect_from_learning_curve(&train_sizes, &train_scores, &val_scores);
        assert!(result.is_ok());

        // Test validation curve overfitting detection
        let param_values = vec![0.1, 0.5, 1.0, 2.0, 5.0];
        let train_scores = vec![0.5, 0.3, 0.2, 0.1, 0.05];
        let val_scores = vec![0.6, 0.4, 0.35, 0.4, 0.5];

        let result =
            detector.detect_from_validation_curve(&param_values, &train_scores, &val_scores);
        assert!(result.is_ok());
        let (_overfitting, optimal_param) = result.unwrap();
        assert!(optimal_param.is_some());
    }

    #[test]
    fn test_convenience_functions() {
        let estimator = MockEstimator;
        let x_train: Vec<f64> = (0..50).map(|i| i as f64 * 0.1).collect();
        let y_train: Vec<f64> = x_train.iter().map(|&x| x * 0.3).collect();
        let x_val: Vec<f64> = (0..10).map(|i| i as f64 * 0.1 + 5.0).collect();
        let y_val: Vec<f64> = x_val.iter().map(|&x| x * 0.3).collect();

        let result = analyze_model_complexity(&estimator, &x_train, &y_train, &x_val, &y_val);
        assert!(result.is_ok());

        let train_sizes = vec![10, 20, 30];
        let train_scores = vec![0.5, 0.3, 0.2];
        let val_scores = vec![0.6, 0.4, 0.45];

        let result = detect_overfitting_learning_curve(&train_sizes, &train_scores, &val_scores);
        assert!(result.is_ok());
    }

    #[test]
    fn test_complexity_recommendations() {
        use ComplexityRecommendation::*;

        let recommendation = Appropriate;
        assert_eq!(
            format!("{}", recommendation),
            "Model complexity is appropriate"
        );

        let recommendation = IncreaseComplexity;
        assert_eq!(
            format!("{}", recommendation),
            "Consider increasing model complexity"
        );

        let recommendation = ReduceComplexity;
        assert_eq!(
            format!("{}", recommendation),
            "Consider reducing model complexity"
        );
    }
}
703}