1use sklears_core::{
8 error::{Result, SklearsError},
9 traits::{Estimator, Fit, Predict},
10};
11use std::collections::HashMap;
12use std::fmt::{self, Display, Formatter};
13
/// Outcome of a model-complexity analysis over a train/validation split.
#[derive(Debug, Clone)]
pub struct ComplexityAnalysisResult {
    /// Error (MSE) on the training split.
    pub train_error: f64,
    /// Error (MSE) on the validation split.
    pub validation_error: f64,
    /// Aggregated complexity score (weighted mix of the individual measures).
    pub complexity_score: f64,
    /// Overfitting severity, clamped to `[0, 1]`.
    pub overfitting_score: f64,
    /// `validation_error - train_error`; positive means worse on unseen data.
    pub generalization_gap: f64,
    /// Named complexity measures (e.g. `"estimated_parameters"`).
    pub complexity_measures: HashMap<String, f64>,
    /// True when the generalization gap exceeds the configured threshold.
    pub overfitting_detected: bool,
    /// Suggested next step based on the diagnostics.
    pub recommendation: ComplexityRecommendation,
}
34
/// Actionable suggestion produced by the complexity analysis.
#[derive(Debug, Clone, PartialEq)]
pub enum ComplexityRecommendation {
    /// Complexity looks appropriate for the data.
    Appropriate,
    /// Model appears to underfit; increase capacity.
    IncreaseComplexity,
    /// Model overfits with high complexity; reduce capacity.
    ReduceComplexity,
    /// Model overfits with moderate complexity; add regularization.
    UseRegularization,
    /// Validation error far above training error; gather more data.
    CollectMoreData,
    /// Both errors moderately high; try ensemble methods.
    TryEnsembles,
}
51
52impl Display for ComplexityRecommendation {
53 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
54 let msg = match self {
55 ComplexityRecommendation::Appropriate => "Model complexity is appropriate",
56 ComplexityRecommendation::IncreaseComplexity => "Consider increasing model complexity",
57 ComplexityRecommendation::ReduceComplexity => "Consider reducing model complexity",
58 ComplexityRecommendation::UseRegularization => "Consider using regularization",
59 ComplexityRecommendation::CollectMoreData => "Consider collecting more training data",
60 ComplexityRecommendation::TryEnsembles => "Consider using ensemble methods",
61 };
62 write!(f, "{}", msg)
63 }
64}
65
66impl Display for ComplexityAnalysisResult {
67 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
68 write!(
69 f,
70 "Model Complexity Analysis:\n\
71 Train Error: {:.6}\n\
72 Validation Error: {:.6}\n\
73 Generalization Gap: {:.6}\n\
74 Complexity Score: {:.6}\n\
75 Overfitting Score: {:.6}\n\
76 Overfitting Detected: {}\n\
77 Recommendation: {}",
78 self.train_error,
79 self.validation_error,
80 self.generalization_gap,
81 self.complexity_score,
82 self.overfitting_score,
83 self.overfitting_detected,
84 self.recommendation
85 )
86 }
87}
88
/// Tuning knobs for complexity analysis and overfitting detection.
#[derive(Debug, Clone)]
pub struct ComplexityAnalysisConfig {
    /// Generalization gap above this value counts as overfitting.
    pub overfitting_threshold: f64,
    /// Training error above this value counts as underfitting.
    pub underfitting_threshold: f64,
    /// Weight of the data-size term.
    /// NOTE(review): not read by the code visible in this file — confirm usage.
    pub data_size_weight: f64,
    /// Whether to include information-theoretic measures.
    /// NOTE(review): not read by the code visible in this file — confirm usage.
    pub include_information_measures: bool,
    /// Whether to use cross-validation during analysis.
    /// NOTE(review): stored via the builder but not consumed in this file.
    pub use_cross_validation: bool,
    /// Number of folds when cross-validation is enabled.
    pub cv_folds: usize,
}
105
106impl Default for ComplexityAnalysisConfig {
107 fn default() -> Self {
108 Self {
109 overfitting_threshold: 0.1,
110 underfitting_threshold: 0.3,
111 data_size_weight: 0.1,
112 include_information_measures: true,
113 use_cross_validation: false,
114 cv_folds: 5,
115 }
116 }
117}
118
/// Kinds of model-complexity measures.
/// NOTE(review): not referenced by the code visible in this file — presumably
/// consumed elsewhere in the crate; confirm before removing.
#[derive(Debug, Clone)]
pub enum ComplexityMeasure {
    /// Raw count of trainable parameters.
    ParameterCount,
    /// Effective degrees of freedom.
    DegreesOfFreedom,
    /// Vapnik–Chervonenkis dimension.
    VCDimension,
    /// Rademacher complexity.
    RademacherComplexity,
    /// Path length — presumably for tree-based models; TODO confirm semantics.
    PathLength,
    /// Number of support vectors — presumably for SVM-style models.
    SupportVectorCount,
    /// Spectral (norm-based) complexity — TODO confirm semantics.
    SpectralComplexity,
}
137
/// Analyzes a model's complexity and overfitting behavior by fitting it on a
/// training split and comparing errors against a validation split.
pub struct ModelComplexityAnalyzer {
    /// Thresholds and options controlling the analysis.
    config: ComplexityAnalysisConfig,
}
142
143impl ModelComplexityAnalyzer {
144 pub fn new() -> Self {
146 Self {
147 config: ComplexityAnalysisConfig::default(),
148 }
149 }
150
151 pub fn with_config(config: ComplexityAnalysisConfig) -> Self {
153 Self { config }
154 }
155
156 pub fn overfitting_threshold(mut self, threshold: f64) -> Self {
158 self.config.overfitting_threshold = threshold;
159 self
160 }
161
162 pub fn underfitting_threshold(mut self, threshold: f64) -> Self {
164 self.config.underfitting_threshold = threshold;
165 self
166 }
167
168 pub fn use_cross_validation(mut self, use_cv: bool) -> Self {
170 self.config.use_cross_validation = use_cv;
171 self
172 }
173
174 pub fn cv_folds(mut self, folds: usize) -> Self {
176 self.config.cv_folds = folds;
177 self
178 }
179
180 pub fn analyze<E, X, Y>(
182 &self,
183 estimator: &E,
184 x_train: &[X],
185 y_train: &[Y],
186 x_val: &[X],
187 y_val: &[Y],
188 ) -> Result<ComplexityAnalysisResult>
189 where
190 E: Estimator + Fit<Vec<X>, Vec<Y>> + Clone,
191 E::Fitted: Predict<Vec<X>, Vec<f64>>,
192 X: Clone,
193 Y: Clone + Into<f64>,
194 {
195 let x_train_vec = x_train.to_vec();
197 let y_train_vec = y_train.to_vec();
198 let trained_model = estimator.clone().fit(&x_train_vec, &y_train_vec)?;
199
200 let train_predictions = trained_model.predict(&x_train_vec)?;
202 let train_targets: Vec<f64> = y_train.iter().map(|y| y.clone().into()).collect();
203 let train_error = self.calculate_error(&train_predictions, &train_targets);
204
205 let x_val_vec = x_val.to_vec();
207 let val_predictions = trained_model.predict(&x_val_vec)?;
208 let val_targets: Vec<f64> = y_val.iter().map(|y| y.clone().into()).collect();
209 let validation_error = self.calculate_error(&val_predictions, &val_targets);
210
211 let generalization_gap = validation_error - train_error;
213
214 let complexity_measures = self.estimate_complexity(x_train, y_train, &trained_model)?;
216 let complexity_score = self.aggregate_complexity(&complexity_measures);
217
218 let overfitting_score = self.calculate_overfitting_score(
220 train_error,
221 validation_error,
222 complexity_score,
223 x_train.len(),
224 );
225
226 let overfitting_detected = generalization_gap > self.config.overfitting_threshold;
228
229 let recommendation = self.generate_recommendation(
231 train_error,
232 validation_error,
233 generalization_gap,
234 complexity_score,
235 overfitting_detected,
236 );
237
238 Ok(ComplexityAnalysisResult {
239 train_error,
240 validation_error,
241 complexity_score,
242 overfitting_score,
243 generalization_gap,
244 complexity_measures,
245 overfitting_detected,
246 recommendation,
247 })
248 }
249
250 fn calculate_error(&self, predictions: &[f64], targets: &[f64]) -> f64 {
252 if predictions.len() != targets.len() {
253 return f64::INFINITY;
254 }
255
256 let mse = predictions
257 .iter()
258 .zip(targets.iter())
259 .map(|(&pred, &target)| (pred - target).powi(2))
260 .sum::<f64>()
261 / predictions.len() as f64;
262
263 mse
264 }
265
266 fn estimate_complexity<X, Y>(
268 &self,
269 x_train: &[X],
270 y_train: &[Y],
271 _trained_model: &impl Predict<Vec<X>, Vec<f64>>,
272 ) -> Result<HashMap<String, f64>>
273 where
274 X: Clone,
275 Y: Clone + Into<f64>,
276 {
277 let mut measures = HashMap::new();
278
279 let n_samples = x_train.len() as f64;
281 let n_features = self.estimate_feature_count(x_train);
282
283 measures.insert("training_set_size".to_string(), n_samples);
284 measures.insert("feature_count".to_string(), n_features);
285
286 let param_count = self.estimate_parameter_count(n_features);
288 measures.insert("estimated_parameters".to_string(), param_count);
289
290 let data_complexity = self.calculate_data_complexity(x_train, y_train);
292 measures.insert("data_complexity".to_string(), data_complexity);
293
294 let eff_dof = self.estimate_effective_dof(n_samples, param_count);
296 measures.insert("effective_dof".to_string(), eff_dof);
297
298 Ok(measures)
299 }
300
301 fn estimate_feature_count<X>(&self, _x_train: &[X]) -> f64 {
303 10.0 }
307
308 fn estimate_parameter_count(&self, n_features: f64) -> f64 {
310 n_features + 1.0
312 }
313
314 fn calculate_data_complexity<X, Y>(&self, _x_train: &[X], y_train: &[Y]) -> f64
316 where
317 Y: Clone + Into<f64>,
318 {
319 let targets: Vec<f64> = y_train.iter().map(|y| y.clone().into()).collect();
320 if targets.is_empty() {
321 return 0.0;
322 }
323
324 let mean = targets.iter().sum::<f64>() / targets.len() as f64;
325 let variance =
326 targets.iter().map(|&y| (y - mean).powi(2)).sum::<f64>() / targets.len() as f64;
327
328 variance.sqrt()
329 }
330
331 fn estimate_effective_dof(&self, n_samples: f64, param_count: f64) -> f64 {
333 param_count.min(n_samples * 0.1)
335 }
336
337 fn aggregate_complexity(&self, measures: &HashMap<String, f64>) -> f64 {
339 let mut score = 0.0;
340 let mut weight_sum = 0.0;
341
342 if let Some(¶m_count) = measures.get("estimated_parameters") {
344 score += param_count * 0.4;
345 weight_sum += 0.4;
346 }
347
348 if let Some(&eff_dof) = measures.get("effective_dof") {
349 score += eff_dof * 0.3;
350 weight_sum += 0.3;
351 }
352
353 if let Some(&data_complexity) = measures.get("data_complexity") {
354 score += data_complexity * 0.2;
355 weight_sum += 0.2;
356 }
357
358 if let Some(&n_samples) = measures.get("training_set_size") {
359 score += (1.0 / (n_samples + 1.0)) * 100.0 * 0.1;
361 weight_sum += 0.1;
362 }
363
364 if weight_sum > 0.0 {
365 score / weight_sum
366 } else {
367 0.0
368 }
369 }
370
371 fn calculate_overfitting_score(
373 &self,
374 train_error: f64,
375 validation_error: f64,
376 complexity_score: f64,
377 n_samples: usize,
378 ) -> f64 {
379 let generalization_gap = validation_error - train_error;
381 let relative_gap = if train_error > 0.0 {
382 generalization_gap / train_error
383 } else {
384 generalization_gap
385 };
386
387 let size_factor = 1.0 / (n_samples as f64).sqrt();
389
390 let overfitting_score = relative_gap * (1.0 + complexity_score * 0.1) * (1.0 + size_factor);
392
393 overfitting_score.clamp(0.0, 1.0)
395 }
396
397 fn generate_recommendation(
399 &self,
400 train_error: f64,
401 validation_error: f64,
402 generalization_gap: f64,
403 complexity_score: f64,
404 overfitting_detected: bool,
405 ) -> ComplexityRecommendation {
406 if train_error > self.config.underfitting_threshold {
408 return ComplexityRecommendation::IncreaseComplexity;
409 }
410
411 if overfitting_detected {
413 if complexity_score > 10.0 {
414 ComplexityRecommendation::ReduceComplexity
415 } else {
416 ComplexityRecommendation::UseRegularization
417 }
418 } else if generalization_gap > 0.05
419 && generalization_gap <= self.config.overfitting_threshold
420 {
421 if validation_error > train_error * 1.5 {
423 ComplexityRecommendation::CollectMoreData
424 } else {
425 ComplexityRecommendation::UseRegularization
426 }
427 } else if train_error > 0.1 && validation_error > 0.1 {
428 ComplexityRecommendation::TryEnsembles
430 } else {
431 ComplexityRecommendation::Appropriate
432 }
433 }
434}
435
436impl Default for ModelComplexityAnalyzer {
437 fn default() -> Self {
438 Self::new()
439 }
440}
441
/// Detects overfitting from learning curves and validation curves.
pub struct OverfittingDetector {
    /// Thresholds and options controlling detection.
    config: ComplexityAnalysisConfig,
}
446
447impl OverfittingDetector {
448 pub fn new() -> Self {
450 Self {
451 config: ComplexityAnalysisConfig::default(),
452 }
453 }
454
455 pub fn with_config(config: ComplexityAnalysisConfig) -> Self {
457 Self { config }
458 }
459
460 pub fn detect_from_learning_curve(
462 &self,
463 train_sizes: &[usize],
464 train_scores: &[f64],
465 val_scores: &[f64],
466 ) -> Result<bool> {
467 if train_sizes.len() != train_scores.len() || train_scores.len() != val_scores.len() {
468 return Err(SklearsError::InvalidParameter {
469 name: "arrays".to_string(),
470 reason: "array lengths must match".to_string(),
471 });
472 }
473
474 if train_sizes.is_empty() {
475 return Err(SklearsError::InvalidParameter {
476 name: "arrays".to_string(),
477 reason: "arrays cannot be empty".to_string(),
478 });
479 }
480
481 let mut divergence_count = 0;
483 for i in 1..train_scores.len() {
484 let train_improvement = train_scores[i - 1] - train_scores[i];
485 let val_improvement = val_scores[i - 1] - val_scores[i];
486
487 if train_improvement > 0.01 && val_improvement < 0.01 {
489 divergence_count += 1;
490 }
491 }
492
493 Ok(divergence_count > train_scores.len() / 2)
495 }
496
497 pub fn detect_from_validation_curve(
499 &self,
500 param_values: &[f64],
501 train_scores: &[f64],
502 val_scores: &[f64],
503 ) -> Result<(bool, Option<f64>)> {
504 if param_values.len() != train_scores.len() || train_scores.len() != val_scores.len() {
505 return Err(SklearsError::InvalidParameter {
506 name: "arrays".to_string(),
507 reason: "array lengths must match".to_string(),
508 });
509 }
510
511 if param_values.is_empty() {
512 return Err(SklearsError::InvalidParameter {
513 name: "arrays".to_string(),
514 reason: "arrays cannot be empty".to_string(),
515 });
516 }
517
518 let min_val_idx = val_scores
520 .iter()
521 .enumerate()
522 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
523 .map(|(idx, _)| idx)
524 .unwrap();
525
526 let optimal_param = param_values[min_val_idx];
527 let min_val_score = val_scores[min_val_idx];
528
529 let mut overfitting_detected = false;
531 for i in (min_val_idx + 1)..val_scores.len() {
532 if val_scores[i] > min_val_score + self.config.overfitting_threshold {
533 overfitting_detected = true;
534 break;
535 }
536 }
537
538 Ok((overfitting_detected, Some(optimal_param)))
539 }
540}
541
542impl Default for OverfittingDetector {
543 fn default() -> Self {
544 Self::new()
545 }
546}
547
548pub fn analyze_model_complexity<E, X, Y>(
550 estimator: &E,
551 x_train: &[X],
552 y_train: &[Y],
553 x_val: &[X],
554 y_val: &[Y],
555) -> Result<ComplexityAnalysisResult>
556where
557 E: Estimator + Fit<Vec<X>, Vec<Y>> + Clone,
558 E::Fitted: Predict<Vec<X>, Vec<f64>>,
559 X: Clone,
560 Y: Clone + Into<f64>,
561{
562 let analyzer = ModelComplexityAnalyzer::new();
563 analyzer.analyze(estimator, x_train, y_train, x_val, y_val)
564}
565
566pub fn detect_overfitting_learning_curve(
568 train_sizes: &[usize],
569 train_scores: &[f64],
570 val_scores: &[f64],
571) -> Result<bool> {
572 let detector = OverfittingDetector::new();
573 detector.detect_from_learning_curve(train_sizes, train_scores, val_scores)
574}
575
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal estimator used to exercise the analysis pipeline.
    #[derive(Clone)]
    struct MockEstimator;

    /// Fitted counterpart of `MockEstimator`.
    struct MockTrained;

    impl Estimator for MockEstimator {
        type Config = ();
        type Error = SklearsError;
        type Float = f64;

        fn config(&self) -> &Self::Config {
            &()
        }
    }

    impl Fit<Vec<f64>, Vec<f64>> for MockEstimator {
        type Fitted = MockTrained;

        fn fit(self, _x: &Vec<f64>, _y: &Vec<f64>) -> Result<Self::Fitted> {
            Ok(MockTrained)
        }
    }

    impl Predict<Vec<f64>, Vec<f64>> for MockTrained {
        fn predict(&self, x: &Vec<f64>) -> Result<Vec<f64>> {
            // Fixed affine model: y = 0.5 * x + 0.1.
            Ok(x.iter().map(|&xi| xi * 0.5 + 0.1).collect())
        }
    }

    #[test]
    fn test_complexity_analyzer_creation() {
        let analyzer = ModelComplexityAnalyzer::new();
        assert_eq!(analyzer.config.overfitting_threshold, 0.1);
        assert_eq!(analyzer.config.underfitting_threshold, 0.3);
    }

    #[test]
    fn test_complexity_analysis() {
        let estimator = MockEstimator;
        let x_train: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
        let y_train: Vec<f64> = x_train.iter().map(|&x| x * 0.5).collect();
        let x_val: Vec<f64> = (0..20).map(|i| i as f64 * 0.1 + 10.0).collect();
        let y_val: Vec<f64> = x_val.iter().map(|&x| x * 0.5).collect();

        let analyzer = ModelComplexityAnalyzer::new();
        let result = analyzer.analyze(&estimator, &x_train, &y_train, &x_val, &y_val);

        assert!(result.is_ok());
        let result = result.unwrap();
        assert!(result.train_error >= 0.0);
        assert!(result.validation_error >= 0.0);
        assert!(result.complexity_score >= 0.0);
        assert!(result.overfitting_score >= 0.0 && result.overfitting_score <= 1.0);
    }

    #[test]
    fn test_overfitting_detector() {
        let detector = OverfittingDetector::new();

        // Learning curve: training keeps improving, validation degrades.
        // (Originally these three statements were fused onto one line.)
        let train_sizes = vec![10, 20, 30, 40, 50];
        let train_scores = vec![0.5, 0.3, 0.2, 0.1, 0.05];
        let val_scores = vec![0.6, 0.4, 0.4, 0.45, 0.5];
        let result = detector.detect_from_learning_curve(&train_sizes, &train_scores, &val_scores);
        assert!(result.is_ok());

        // Validation curve with a clear minimum at param 1.0.
        let param_values = vec![0.1, 0.5, 1.0, 2.0, 5.0];
        let train_scores = vec![0.5, 0.3, 0.2, 0.1, 0.05];
        let val_scores = vec![0.6, 0.4, 0.35, 0.4, 0.5];

        // BUGFIX: `&param_values` had been corrupted to `¶m_values`
        // (HTML-entity mangling of `&para`); the reference is restored.
        let result =
            detector.detect_from_validation_curve(&param_values, &train_scores, &val_scores);
        assert!(result.is_ok());
        let (_overfitting, optimal_param) = result.unwrap();
        assert!(optimal_param.is_some());
    }

    #[test]
    fn test_convenience_functions() {
        let estimator = MockEstimator;
        let x_train: Vec<f64> = (0..50).map(|i| i as f64 * 0.1).collect();
        let y_train: Vec<f64> = x_train.iter().map(|&x| x * 0.3).collect();
        let x_val: Vec<f64> = (0..10).map(|i| i as f64 * 0.1 + 5.0).collect();
        let y_val: Vec<f64> = x_val.iter().map(|&x| x * 0.3).collect();

        let result = analyze_model_complexity(&estimator, &x_train, &y_train, &x_val, &y_val);
        assert!(result.is_ok());

        let train_sizes = vec![10, 20, 30];
        let train_scores = vec![0.5, 0.3, 0.2];
        let val_scores = vec![0.6, 0.4, 0.45];

        let result = detect_overfitting_learning_curve(&train_sizes, &train_scores, &val_scores);
        assert!(result.is_ok());
    }

    #[test]
    fn test_complexity_recommendations() {
        use ComplexityRecommendation::*;

        let recommendation = Appropriate;
        assert_eq!(
            format!("{}", recommendation),
            "Model complexity is appropriate"
        );

        let recommendation = IncreaseComplexity;
        assert_eq!(
            format!("{}", recommendation),
            "Consider increasing model complexity"
        );

        let recommendation = ReduceComplexity;
        assert_eq!(
            format!("{}", recommendation),
            "Consider reducing model complexity"
        );
    }
}