sklears_model_selection/
bias_variance.rs

//! Bias-variance decomposition analysis for understanding model performance
//!
//! This module provides tools for performing bias-variance decomposition of model predictions,
//! which helps identify the sources of generalization error. The decomposition separates
//! the expected test error into three components:
//! - Bias²: Error due to overly simplistic model assumptions
//! - Variance: Error due to sensitivity to small fluctuations in the training data
//! - Noise: Irreducible error inherent in the problem
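//!
//! For squared loss, with targets y = f(x) + ε and Var(ε) = σ², the expected
//! prediction error at a point x decomposes (over training sets) as:
//!
//! ```text
//! E[(y - ŷ(x))²] = (E[ŷ(x)] - f(x))²  +  E[(ŷ(x) - E[ŷ(x)])²]  +  σ²
//!                        Bias²                  Variance          Noise
//! ```
//!
//! The analyzer in this module estimates the bias and variance terms empirically
//! by refitting the estimator on repeated bootstrap resamples of the training data
//! and aggregating its predictions on a fixed test set, sample by sample.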

use scirs2_core::rand_prelude::SliceRandom;
use sklears_core::{
    error::{Result, SklearsError},
    traits::{Estimator, Fit, Predict},
};
use std::fmt::{self, Display, Formatter};

/// Results of bias-variance decomposition analysis
#[derive(Debug, Clone)]
pub struct BiasVarianceResult {
    /// Bias component (squared)
    pub bias_squared: f64,
    /// Variance component
    pub variance: f64,
    /// Noise component (irreducible error)
    pub noise: f64,
    /// Total expected error
    pub expected_error: f64,
    /// Standard error of bias estimate
    pub bias_std_error: f64,
    /// Standard error of variance estimate
    pub variance_std_error: f64,
    /// Number of bootstrap samples used
    pub n_bootstrap: usize,
    /// Sample-wise bias and variance estimates
    pub sample_wise_results: Vec<SampleBiasVariance>,
}

/// Bias-variance results for individual test samples
#[derive(Debug, Clone)]
pub struct SampleBiasVariance {
    /// Sample index
    pub sample_index: usize,
    /// True target value
    pub true_value: f64,
    /// Mean prediction across bootstrap samples
    pub mean_prediction: f64,
    /// Variance of predictions across bootstrap samples
    pub prediction_variance: f64,
    /// Squared bias for this sample
    pub squared_bias: f64,
    /// Individual predictions from each bootstrap sample
    pub predictions: Vec<f64>,
}

impl Display for BiasVarianceResult {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "Bias-Variance Decomposition Results:\n\
             Expected Error: {:.6}\n\
             Bias²: {:.6} (SE: {:.6})\n\
             Variance: {:.6} (SE: {:.6})\n\
             Noise: {:.6}\n\
             Bootstrap Samples: {}",
            self.expected_error,
            self.bias_squared,
            self.bias_std_error,
            self.variance,
            self.variance_std_error,
            self.noise,
            self.n_bootstrap
        )
    }
}

/// Configuration for bias-variance decomposition
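///
/// See the `Default` implementation for the default values: 100 bootstrap
/// samples, full-size resamples drawn with replacement, no fixed random seed,
/// and sample-wise results enabled.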
#[derive(Debug, Clone)]
pub struct BiasVarianceConfig {
    /// Number of bootstrap samples to generate
    pub n_bootstrap: usize,
    /// Fraction of original dataset to sample in each bootstrap
    pub sample_fraction: f64,
    /// Random seed for reproducibility
    pub random_seed: Option<u64>,
    /// Whether to use sampling with replacement
    pub with_replacement: bool,
    /// Whether to compute sample-wise decomposition
    pub compute_sample_wise: bool,
}

impl Default for BiasVarianceConfig {
    fn default() -> Self {
        Self {
            n_bootstrap: 100,
            sample_fraction: 1.0,
            random_seed: None,
            with_replacement: true,
            compute_sample_wise: true,
        }
    }
}

/// Bias-variance decomposition analyzer
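///
/// The analyzer is configured through a builder-style API; a minimal sketch
/// (not compiled as a doc-test, imports omitted):
///
/// ```ignore
/// let analyzer = BiasVarianceAnalyzer::new()
///     .n_bootstrap(200)
///     .sample_fraction(0.8)
///     .random_seed(42);
/// ```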
pub struct BiasVarianceAnalyzer {
    config: BiasVarianceConfig,
}

impl BiasVarianceAnalyzer {
    /// Create a new bias-variance analyzer with default configuration
    pub fn new() -> Self {
        Self {
            config: BiasVarianceConfig::default(),
        }
    }

    /// Create a new bias-variance analyzer with custom configuration
    pub fn with_config(config: BiasVarianceConfig) -> Self {
        Self { config }
    }

    /// Set the number of bootstrap samples
    pub fn n_bootstrap(mut self, n_bootstrap: usize) -> Self {
        self.config.n_bootstrap = n_bootstrap;
        self
    }

    /// Set the sample fraction for bootstrap sampling
    pub fn sample_fraction(mut self, fraction: f64) -> Self {
        self.config.sample_fraction = fraction;
        self
    }

    /// Set the random seed for reproducibility
    pub fn random_seed(mut self, seed: u64) -> Self {
        self.config.random_seed = Some(seed);
        self
    }

    /// Enable or disable sampling with replacement
    pub fn with_replacement(mut self, with_replacement: bool) -> Self {
        self.config.with_replacement = with_replacement;
        self
    }

    /// Enable or disable sample-wise computation
    pub fn compute_sample_wise(mut self, compute: bool) -> Self {
        self.config.compute_sample_wise = compute;
        self
    }

    /// Perform bias-variance decomposition
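    ///
    /// Trains `estimator` on `n_bootstrap` resamples of (`x_train`, `y_train`),
    /// evaluates each fitted model on `x_test`, and aggregates the predictions
    /// against `y_test` into a [`BiasVarianceResult`].
    ///
    /// # Errors
    ///
    /// Returns an error if `n_bootstrap` is 0, if `sample_fraction` lies outside
    /// `(0, 1]`, or if fitting or prediction fails for a bootstrap sample.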
    pub fn decompose<E, X, Y>(
        &self,
        estimator: &E,
        x_train: &[X],
        y_train: &[Y],
        x_test: &[X],
        y_test: &[Y],
    ) -> Result<BiasVarianceResult>
    where
        E: Estimator + Fit<Vec<X>, Vec<Y>> + Clone,
        E::Fitted: Predict<Vec<X>, Vec<f64>>,
        X: Clone,
        Y: Clone + Into<f64>,
    {
        if self.config.n_bootstrap == 0 {
            return Err(SklearsError::InvalidParameter {
                name: "n_bootstrap".to_string(),
                reason: "must be > 0".to_string(),
            });
        }

        if self.config.sample_fraction <= 0.0 || self.config.sample_fraction > 1.0 {
            return Err(SklearsError::InvalidParameter {
                name: "sample_fraction".to_string(),
                reason: "must be in (0, 1]".to_string(),
            });
        }

        let mut rng = self.get_rng();
        let n_train = x_train.len();
        let _n_test = x_test.len();
        let sample_size = (n_train as f64 * self.config.sample_fraction) as usize;

        // Convert y_test to f64 values
        let y_test_f64: Vec<f64> = y_test.iter().map(|y| y.clone().into()).collect();

        // The test inputs are reused for every bootstrap model, so clone them once
        let x_test_vec: Vec<X> = x_test.to_vec();

        // Store predictions from each bootstrap sample
        let mut all_predictions = Vec::with_capacity(self.config.n_bootstrap);

        // Generate bootstrap samples and train models
        for _ in 0..self.config.n_bootstrap {
            // Create bootstrap sample
            let (x_boot, y_boot) =
                self.bootstrap_sample(x_train, y_train, sample_size, &mut rng)?;

            // Train model on bootstrap sample
            let trained_model = estimator.clone().fit(&x_boot, &y_boot)?;

            // Make predictions on the test set
            let predictions = trained_model.predict(&x_test_vec)?;
            all_predictions.push(predictions);
        }

        // Compute bias-variance decomposition
        self.compute_decomposition(&all_predictions, &y_test_f64)
    }

    /// Generate a bootstrap sample
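    ///
    /// Draws `sample_size` `(x, y)` pairs from the training data, either with
    /// replacement (classic bootstrap) or without replacement (subsampling),
    /// depending on `config.with_replacement`.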
    fn bootstrap_sample<X, Y>(
        &self,
        x_train: &[X],
        y_train: &[Y],
        sample_size: usize,
        rng: &mut impl scirs2_core::random::Rng,
    ) -> Result<(Vec<X>, Vec<Y>)>
    where
        X: Clone,
        Y: Clone,
    {
        let n_train = x_train.len();
        let mut x_boot = Vec::with_capacity(sample_size);
        let mut y_boot = Vec::with_capacity(sample_size);

        if self.config.with_replacement {
            // Sample with replacement
            for _ in 0..sample_size {
                let idx = rng.gen_range(0..n_train);
                x_boot.push(x_train[idx].clone());
                y_boot.push(y_train[idx].clone());
            }
        } else {
            // Sample without replacement
            let mut indices: Vec<usize> = (0..n_train).collect();
            indices.shuffle(rng);
            indices.truncate(sample_size);

            for &idx in &indices {
                x_boot.push(x_train[idx].clone());
                y_boot.push(y_train[idx].clone());
            }
        }

        Ok((x_boot, y_boot))
    }

    /// Compute bias-variance decomposition from predictions
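    ///
    /// For each test sample `i` with true value `y_i` and bootstrap predictions
    /// `p_{1,i}, ..., p_{B,i}`:
    ///
    /// ```text
    /// mean_i  = (1/B) * Σ_b p_{b,i}
    /// bias²_i = (mean_i - y_i)²
    /// var_i   = (1/B) * Σ_b (p_{b,i} - mean_i)²
    /// ```
    ///
    /// The reported `bias_squared` and `variance` are averages of these
    /// per-sample quantities over the test set.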
    fn compute_decomposition(
        &self,
        all_predictions: &[Vec<f64>],
        y_test: &[f64],
    ) -> Result<BiasVarianceResult> {
        let n_test = y_test.len();
        let n_bootstrap = all_predictions.len();

        if n_bootstrap == 0 {
            return Err(SklearsError::InvalidParameter {
                name: "predictions".to_string(),
                reason: "no bootstrap predictions provided".to_string(),
            });
        }

        if n_test == 0 {
            return Err(SklearsError::InvalidParameter {
                name: "y_test".to_string(),
                reason: "test set must not be empty".to_string(),
            });
        }

        if all_predictions.iter().any(|p| p.len() != n_test) {
            return Err(SklearsError::InvalidParameter {
                name: "predictions".to_string(),
                reason: "all prediction arrays must have the same length as the test set".to_string(),
            });
        }

        let mut sample_wise_results = Vec::new();
        let mut total_bias_squared = 0.0;
        let mut total_variance = 0.0;
        let mut bias_estimates = Vec::new();
        let mut variance_estimates = Vec::new();

        // Compute bias and variance for each test sample
        for i in 0..n_test {
            let true_value = y_test[i];
            let predictions: Vec<f64> = all_predictions.iter().map(|p| p[i]).collect();

            // Mean prediction across bootstrap samples
            let mean_prediction = predictions.iter().sum::<f64>() / n_bootstrap as f64;

            // Variance of predictions
            let prediction_variance = predictions
                .iter()
                .map(|&p| (p - mean_prediction).powi(2))
                .sum::<f64>()
                / n_bootstrap as f64;

            // Squared bias
            let squared_bias = (mean_prediction - true_value).powi(2);

            total_bias_squared += squared_bias;
            total_variance += prediction_variance;

            bias_estimates.push(squared_bias);
            variance_estimates.push(prediction_variance);

            if self.config.compute_sample_wise {
                sample_wise_results.push(SampleBiasVariance {
                    sample_index: i,
                    true_value,
                    mean_prediction,
                    prediction_variance,
                    squared_bias,
                    predictions,
                });
            }
        }

        // Average across all test samples
        let bias_squared = total_bias_squared / n_test as f64;
        let variance = total_variance / n_test as f64;

        // Compute standard errors
        let bias_std_error = self.compute_standard_error(&bias_estimates);
        let variance_std_error = self.compute_standard_error(&variance_estimates);

        // Estimate the noise (irreducible error) component. The current
        // `estimate_noise` implementation returns 0.0; see its docs for how this
        // could be estimated in practice.
        let noise = self.estimate_noise(y_test);

        let expected_error = bias_squared + variance + noise;

        Ok(BiasVarianceResult {
            bias_squared,
            variance,
            noise,
            expected_error,
            bias_std_error,
            variance_std_error,
            n_bootstrap,
            sample_wise_results,
        })
    }

    /// Compute the standard error of the mean of the given per-sample estimates
    fn compute_standard_error(&self, estimates: &[f64]) -> f64 {
        let n = estimates.len() as f64;
        if estimates.len() < 2 {
            // The standard error is undefined for fewer than two estimates
            return 0.0;
        }
        let mean = estimates.iter().sum::<f64>() / n;
        // Sample variance (n - 1 denominator); SE of the mean = sqrt(s^2 / n)
        let variance = estimates.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0);
        (variance / n).sqrt()
    }

    /// Estimate noise component (irreducible error)
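    ///
    /// The current implementation always returns `0.0`. A proper noise estimate
    /// requires repeated observations of the same inputs (or domain knowledge);
    /// with repeated measurements it could, for example, be approximated by the
    /// average within-group variance of the targets.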
    fn estimate_noise(&self, _y_test: &[f64]) -> f64 {
        // In practice, noise estimation requires multiple observations of the same input.
        // For simplicity, we assume noise is 0 here, but in real applications,
        // this could be estimated from repeated measurements or domain knowledge.
        0.0
    }

    /// Get random number generator
    fn get_rng(&self) -> impl scirs2_core::random::Rng {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::SeedableRng;

        match self.config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => {
                use scirs2_core::random::thread_rng;
                StdRng::from_rng(&mut thread_rng())
            }
        }
    }
}

impl Default for BiasVarianceAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}

/// Convenience function for performing bias-variance decomposition
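///
/// Uses a [`BiasVarianceAnalyzer`] with the default configuration, optionally
/// overriding the number of bootstrap samples via `n_bootstrap`.
///
/// A minimal usage sketch (not compiled as a doc-test; assumes `estimator` and
/// the data slices satisfy the trait bounds below):
///
/// ```ignore
/// let result = bias_variance_decompose(
///     &estimator, &x_train, &y_train, &x_test, &y_test, Some(200),
/// )?;
/// println!("{result}");
/// ```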
pub fn bias_variance_decompose<E, X, Y>(
    estimator: &E,
    x_train: &[X],
    y_train: &[Y],
    x_test: &[X],
    y_test: &[Y],
    n_bootstrap: Option<usize>,
) -> Result<BiasVarianceResult>
where
    E: Estimator + Fit<Vec<X>, Vec<Y>> + Clone,
    E::Fitted: Predict<Vec<X>, Vec<f64>>,
    X: Clone,
    Y: Clone + Into<f64>,
{
    let mut analyzer = BiasVarianceAnalyzer::new();
    if let Some(n) = n_bootstrap {
        analyzer = analyzer.n_bootstrap(n);
    }
    analyzer.decompose(estimator, x_train, y_train, x_test, y_test)
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    // Mock estimator for testing
    #[derive(Clone)]
    struct MockEstimator {
        noise_level: f64,
    }

    struct MockTrained {
        noise_level: f64,
    }

    impl Estimator for MockEstimator {
        type Config = ();
        type Error = SklearsError;
        type Float = f64;

        fn config(&self) -> &Self::Config {
            &()
        }
    }

    impl Fit<Vec<f64>, Vec<f64>> for MockEstimator {
        type Fitted = MockTrained;

        fn fit(self, _x: &Vec<f64>, _y: &Vec<f64>) -> Result<Self::Fitted> {
            Ok(MockTrained {
                noise_level: self.noise_level,
            })
        }
    }

    impl Predict<Vec<f64>, Vec<f64>> for MockTrained {
        fn predict(&self, x: &Vec<f64>) -> Result<Vec<f64>> {
            let mut rng = scirs2_core::random::thread_rng();
            Ok(x.iter()
                .map(|&xi| xi + rng.gen_range(-self.noise_level..self.noise_level))
                .collect())
        }
    }

    #[test]
    fn test_bias_variance_analyzer_creation() {
        let analyzer = BiasVarianceAnalyzer::new();
        assert_eq!(analyzer.config.n_bootstrap, 100);
        assert_eq!(analyzer.config.sample_fraction, 1.0);
        assert!(analyzer.config.random_seed.is_none());
        assert!(analyzer.config.with_replacement);
        assert!(analyzer.config.compute_sample_wise);
    }

    #[test]
    fn test_bias_variance_configuration() {
        let analyzer = BiasVarianceAnalyzer::new()
            .n_bootstrap(50)
            .sample_fraction(0.8)
            .random_seed(42)
            .with_replacement(false)
            .compute_sample_wise(false);

        assert_eq!(analyzer.config.n_bootstrap, 50);
        assert_eq!(analyzer.config.sample_fraction, 0.8);
        assert_eq!(analyzer.config.random_seed, Some(42));
        assert!(!analyzer.config.with_replacement);
        assert!(!analyzer.config.compute_sample_wise);
    }

    #[test]
    fn test_bias_variance_decomposition() {
        let estimator = MockEstimator { noise_level: 0.1 };
        let x_train: Vec<f64> = (0..100).map(|i| i as f64 * 0.1).collect();
        let y_train: Vec<f64> = x_train.iter().map(|&x| x * 2.0 + 1.0).collect();
        let x_test: Vec<f64> = (0..20).map(|i| i as f64 * 0.1 + 10.0).collect();
        let y_test: Vec<f64> = x_test.iter().map(|&x| x * 2.0 + 1.0).collect();

        let analyzer = BiasVarianceAnalyzer::new().n_bootstrap(10).random_seed(42);

        let result = analyzer.decompose(&estimator, &x_train, &y_train, &x_test, &y_test);
        assert!(result.is_ok());

        let result = result.unwrap();
        assert_eq!(result.n_bootstrap, 10);
        assert!(result.bias_squared >= 0.0);
        assert!(result.variance >= 0.0);
        assert_eq!(result.noise, 0.0); // estimate_noise currently returns 0.0
        assert_eq!(
            result.expected_error,
            result.bias_squared + result.variance + result.noise
        );
        assert_eq!(result.sample_wise_results.len(), x_test.len());
    }

    #[test]
    fn test_invalid_parameters() {
        let analyzer = BiasVarianceAnalyzer::new().n_bootstrap(0);
        let estimator = MockEstimator { noise_level: 0.1 };
        let x_train = vec![1.0, 2.0, 3.0];
        let y_train = vec![1.0, 2.0, 3.0];
        let x_test = vec![4.0, 5.0];
        let y_test = vec![4.0, 5.0];

        let result = analyzer.decompose(&estimator, &x_train, &y_train, &x_test, &y_test);
        assert!(result.is_err());
    }

    #[test]
    fn test_convenience_function() {
        let estimator = MockEstimator { noise_level: 0.05 };
        let x_train: Vec<f64> = (0..50).map(|i| i as f64 * 0.1).collect();
        let y_train: Vec<f64> = x_train.iter().map(|&x| x + 0.5).collect();
        let x_test: Vec<f64> = (0..10).map(|i| i as f64 * 0.1 + 5.0).collect();
        let y_test: Vec<f64> = x_test.iter().map(|&x| x + 0.5).collect();

        let result =
            bias_variance_decompose(&estimator, &x_train, &y_train, &x_test, &y_test, Some(20));
        assert!(result.is_ok());

        let result = result.unwrap();
        assert_eq!(result.n_bootstrap, 20);
    }
}