1use sklears_core::{
11 error::{Result, SklearsError},
12 traits::{Estimator, Fit, Predict},
13};
14use std::fmt::{self, Display, Formatter};
15
16#[derive(Debug, Clone)]
18pub struct BiasVarianceResult {
19 pub bias_squared: f64,
21 pub variance: f64,
23 pub noise: f64,
25 pub expected_error: f64,
27 pub bias_std_error: f64,
29 pub variance_std_error: f64,
31 pub n_bootstrap: usize,
33 pub sample_wise_results: Vec<SampleBiasVariance>,
35}
36
/// Bias-variance breakdown for a single test sample.
///
/// `PartialEq` is derived so that per-sample results can be compared directly
/// (for example with `assert_eq!` in reproducibility checks).
#[derive(Debug, Clone, PartialEq)]
pub struct SampleBiasVariance {
    /// Position of the sample within the test set.
    pub sample_index: usize,
    /// Ground-truth target of the sample.
    pub true_value: f64,
    /// Mean of the bootstrap models' predictions for this sample.
    pub mean_prediction: f64,
    /// Variance of the predictions around `mean_prediction`
    /// (normalised by the number of bootstrap rounds).
    pub prediction_variance: f64,
    /// `(mean_prediction - true_value)²`.
    pub squared_bias: f64,
    /// Prediction of every bootstrap model for this sample, in round order.
    pub predictions: Vec<f64>,
}
53
54impl Display for BiasVarianceResult {
55 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
56 write!(
57 f,
58 "Bias-Variance Decomposition Results:\n\
59 Expected Error: {:.6}\n\
60 Bias²: {:.6} (SE: {:.6})\n\
61 Variance: {:.6} (SE: {:.6})\n\
62 Noise: {:.6}\n\
63 Bootstrap Samples: {}",
64 self.expected_error,
65 self.bias_squared,
66 self.bias_std_error,
67 self.variance,
68 self.variance_std_error,
69 self.noise,
70 self.n_bootstrap
71 )
72 }
73}
74
/// Settings that control how the bootstrap decomposition is run.
///
/// `PartialEq` is derived so that two configurations can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq)]
pub struct BiasVarianceConfig {
    /// Number of bootstrap rounds; the analysis rejects `0`.
    pub n_bootstrap: usize,
    /// Fraction of the training set drawn in each round; the analysis requires
    /// a value in `(0, 1]`.
    pub sample_fraction: f64,
    /// Seed for the random number generator; `None` seeds it from the
    /// thread-local generator instead.
    pub random_seed: Option<u64>,
    /// `true` draws every index independently (classic bootstrap); `false`
    /// shuffles the training indices and keeps a prefix.
    pub with_replacement: bool,
    /// Whether the per-test-sample breakdown is collected into the result.
    pub compute_sample_wise: bool,
}
89
90impl Default for BiasVarianceConfig {
91 fn default() -> Self {
92 Self {
93 n_bootstrap: 100,
94 sample_fraction: 1.0,
95 random_seed: None,
96 with_replacement: true,
97 compute_sample_wise: true,
98 }
99 }
100}
101
102pub struct BiasVarianceAnalyzer {
104 config: BiasVarianceConfig,
105}
106
107impl BiasVarianceAnalyzer {
108 pub fn new() -> Self {
110 Self {
111 config: BiasVarianceConfig::default(),
112 }
113 }
114
115 pub fn with_config(config: BiasVarianceConfig) -> Self {
117 Self { config }
118 }
119
120 pub fn n_bootstrap(mut self, n_bootstrap: usize) -> Self {
122 self.config.n_bootstrap = n_bootstrap;
123 self
124 }
125
126 pub fn sample_fraction(mut self, fraction: f64) -> Self {
128 self.config.sample_fraction = fraction;
129 self
130 }
131
132 pub fn random_seed(mut self, seed: u64) -> Self {
134 self.config.random_seed = Some(seed);
135 self
136 }
137
138 pub fn with_replacement(mut self, with_replacement: bool) -> Self {
140 self.config.with_replacement = with_replacement;
141 self
142 }
143
144 pub fn compute_sample_wise(mut self, compute: bool) -> Self {
146 self.config.compute_sample_wise = compute;
147 self
148 }
149
150 pub fn decompose<E, X, Y>(
152 &self,
153 estimator: &E,
154 x_train: &[X],
155 y_train: &[Y],
156 x_test: &[X],
157 y_test: &[Y],
158 ) -> Result<BiasVarianceResult>
159 where
160 E: Estimator + Fit<Vec<X>, Vec<Y>> + Clone,
161 E::Fitted: Predict<Vec<X>, Vec<f64>>,
162 X: Clone,
163 Y: Clone + Into<f64>,
164 {
165 if self.config.n_bootstrap == 0 {
166 return Err(SklearsError::InvalidParameter {
167 name: "n_bootstrap".to_string(),
168 reason: "must be > 0".to_string(),
169 });
170 }
171
172 if self.config.sample_fraction <= 0.0 || self.config.sample_fraction > 1.0 {
173 return Err(SklearsError::InvalidParameter {
174 name: "sample_fraction".to_string(),
175 reason: "must be in (0, 1]".to_string(),
176 });
177 }
178
179 let mut rng = self.get_rng();
180 let n_train = x_train.len();
181 let _n_test = x_test.len();
182 let sample_size = (n_train as f64 * self.config.sample_fraction) as usize;
183
184 let y_test_f64: Vec<f64> = y_test.iter().map(|y| y.clone().into()).collect();
186
187 let mut all_predictions = Vec::with_capacity(self.config.n_bootstrap);
189
190 for _ in 0..self.config.n_bootstrap {
192 let (x_boot, y_boot) =
194 self.bootstrap_sample(x_train, y_train, sample_size, &mut rng)?;
195
196 let trained_model = estimator.clone().fit(&x_boot, &y_boot)?;
198
199 let x_test_vec: Vec<X> = x_test.to_vec();
201 let predictions = trained_model.predict(&x_test_vec)?;
202 all_predictions.push(predictions);
203 }
204
205 self.compute_decomposition(&all_predictions, &y_test_f64)
207 }
208
209 fn bootstrap_sample<X, Y>(
211 &self,
212 x_train: &[X],
213 y_train: &[Y],
214 sample_size: usize,
215 rng: &mut impl scirs2_core::random::Rng,
216 ) -> Result<(Vec<X>, Vec<Y>)>
217 where
218 X: Clone,
219 Y: Clone,
220 {
221 let n_train = x_train.len();
222 let mut x_boot = Vec::with_capacity(sample_size);
223 let mut y_boot = Vec::with_capacity(sample_size);
224
225 if self.config.with_replacement {
226 for _ in 0..sample_size {
228 let idx = rng.gen_range(0..n_train);
229 x_boot.push(x_train[idx].clone());
230 y_boot.push(y_train[idx].clone());
231 }
232 } else {
233 let mut indices: Vec<usize> = (0..n_train).collect();
235 indices.shuffle(rng);
236 indices.truncate(sample_size);
237
238 for &idx in &indices {
239 x_boot.push(x_train[idx].clone());
240 y_boot.push(y_train[idx].clone());
241 }
242 }
243
244 Ok((x_boot, y_boot))
245 }
246
247 fn compute_decomposition(
249 &self,
250 all_predictions: &[Vec<f64>],
251 y_test: &[f64],
252 ) -> Result<BiasVarianceResult> {
253 let n_test = y_test.len();
254 let n_bootstrap = all_predictions.len();
255
256 if n_bootstrap == 0 {
257 return Err(SklearsError::InvalidParameter {
258 name: "predictions".to_string(),
259 reason: "no bootstrap predictions provided".to_string(),
260 });
261 }
262
263 if all_predictions.iter().any(|p| p.len() != n_test) {
264 return Err(SklearsError::InvalidParameter {
265 name: "predictions".to_string(),
266 reason: "all prediction arrays must have same length as test set".to_string(),
267 });
268 }
269
270 let mut sample_wise_results = Vec::new();
271 let mut total_bias_squared = 0.0;
272 let mut total_variance = 0.0;
273 let mut bias_estimates = Vec::new();
274 let mut variance_estimates = Vec::new();
275
276 for i in 0..n_test {
278 let true_value = y_test[i];
279 let predictions: Vec<f64> = all_predictions.iter().map(|p| p[i]).collect();
280
281 let mean_prediction = predictions.iter().sum::<f64>() / n_bootstrap as f64;
283
284 let prediction_variance = predictions
286 .iter()
287 .map(|&p| (p - mean_prediction).powi(2))
288 .sum::<f64>()
289 / n_bootstrap as f64;
290
291 let squared_bias = (mean_prediction - true_value).powi(2);
293
294 total_bias_squared += squared_bias;
295 total_variance += prediction_variance;
296
297 bias_estimates.push(squared_bias);
298 variance_estimates.push(prediction_variance);
299
300 if self.config.compute_sample_wise {
301 sample_wise_results.push(SampleBiasVariance {
302 sample_index: i,
303 true_value,
304 mean_prediction,
305 prediction_variance,
306 squared_bias,
307 predictions,
308 });
309 }
310 }
311
312 let bias_squared = total_bias_squared / n_test as f64;
314 let variance = total_variance / n_test as f64;
315
316 let bias_std_error = self.compute_standard_error(&bias_estimates);
318 let variance_std_error = self.compute_standard_error(&variance_estimates);
319
320 let noise = self.estimate_noise(y_test);
323
324 let expected_error = bias_squared + variance + noise;
325
326 Ok(BiasVarianceResult {
327 bias_squared,
328 variance,
329 noise,
330 expected_error,
331 bias_std_error,
332 variance_std_error,
333 n_bootstrap,
334 sample_wise_results,
335 })
336 }
337
338 fn compute_standard_error(&self, estimates: &[f64]) -> f64 {
340 let n = estimates.len() as f64;
341 let mean = estimates.iter().sum::<f64>() / n;
342 let variance = estimates.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0);
343 (variance / n).sqrt()
344 }
345
346 fn estimate_noise(&self, _y_test: &[f64]) -> f64 {
348 0.0
352 }
353
354 fn get_rng(&self) -> impl scirs2_core::random::Rng {
356 use scirs2_core::random::rngs::StdRng;
357 use scirs2_core::random::SeedableRng;
358
359 match self.config.random_seed {
360 Some(seed) => StdRng::seed_from_u64(seed),
361 None => {
362 use scirs2_core::random::thread_rng;
363 StdRng::from_rng(&mut thread_rng())
364 }
365 }
366 }
367}
368
369impl Default for BiasVarianceAnalyzer {
370 fn default() -> Self {
371 Self::new()
372 }
373}
374
375pub fn bias_variance_decompose<E, X, Y>(
377 estimator: &E,
378 x_train: &[X],
379 y_train: &[Y],
380 x_test: &[X],
381 y_test: &[Y],
382 n_bootstrap: Option<usize>,
383) -> Result<BiasVarianceResult>
384where
385 E: Estimator + Fit<Vec<X>, Vec<Y>> + Clone,
386 E::Fitted: Predict<Vec<X>, Vec<f64>>,
387 X: Clone,
388 Y: Clone + Into<f64>,
389{
390 let mut analyzer = BiasVarianceAnalyzer::new();
391 if let Some(n) = n_bootstrap {
392 analyzer = analyzer.n_bootstrap(n);
393 }
394 analyzer.decompose(estimator, x_train, y_train, x_test, y_test)
395}
396
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Estimator stub: fitting ignores the data and only forwards the noise level.
    #[derive(Clone)]
    struct MockEstimator {
        noise_level: f64,
    }

    /// Fitted stub: predicts the input itself plus uniform noise.
    struct MockTrained {
        noise_level: f64,
    }

    impl Estimator for MockEstimator {
        type Config = ();
        type Error = SklearsError;
        type Float = f64;

        fn config(&self) -> &Self::Config {
            &()
        }
    }

    impl Fit<Vec<f64>, Vec<f64>> for MockEstimator {
        type Fitted = MockTrained;

        fn fit(self, _x: &Vec<f64>, _y: &Vec<f64>) -> Result<Self::Fitted> {
            let MockEstimator { noise_level } = self;
            Ok(MockTrained { noise_level })
        }
    }

    impl Predict<Vec<f64>, Vec<f64>> for MockTrained {
        fn predict(&self, x: &Vec<f64>) -> Result<Vec<f64>> {
            let mut rng = scirs2_core::random::thread_rng();
            let mut noisy = Vec::with_capacity(x.len());
            for &xi in x {
                noisy.push(xi + rng.gen_range(-self.noise_level..self.noise_level));
            }
            Ok(noisy)
        }
    }

    /// Builds `n` points with `x = i * 0.1 + x_offset` and
    /// `y = x * slope + intercept`.
    fn line(n: i32, x_offset: f64, slope: f64, intercept: f64) -> (Vec<f64>, Vec<f64>) {
        let xs: Vec<f64> = (0..n).map(|i| i as f64 * 0.1 + x_offset).collect();
        let ys: Vec<f64> = xs.iter().map(|&x| x * slope + intercept).collect();
        (xs, ys)
    }

    #[test]
    fn test_bias_variance_analyzer_creation() {
        let config = BiasVarianceAnalyzer::new().config;

        assert_eq!(config.n_bootstrap, 100);
        assert_eq!(config.sample_fraction, 1.0);
        assert!(config.random_seed.is_none());
        assert!(config.with_replacement);
        assert!(config.compute_sample_wise);
    }

    #[test]
    fn test_bias_variance_configuration() {
        let config = BiasVarianceAnalyzer::new()
            .n_bootstrap(50)
            .sample_fraction(0.8)
            .random_seed(42)
            .with_replacement(false)
            .compute_sample_wise(false)
            .config;

        assert_eq!(config.n_bootstrap, 50);
        assert_eq!(config.sample_fraction, 0.8);
        assert_eq!(config.random_seed, Some(42));
        assert!(!config.with_replacement);
        assert!(!config.compute_sample_wise);
    }

    #[test]
    fn test_bias_variance_decomposition() {
        let estimator = MockEstimator { noise_level: 0.1 };
        let (x_train, y_train) = line(100, 0.0, 2.0, 1.0);
        let (x_test, y_test) = line(20, 10.0, 2.0, 1.0);

        let report = BiasVarianceAnalyzer::new()
            .n_bootstrap(10)
            .random_seed(42)
            .decompose(&estimator, &x_train, &y_train, &x_test, &y_test)
            .expect("decomposition should succeed");

        assert_eq!(report.n_bootstrap, 10);
        assert!(report.bias_squared >= 0.0);
        assert!(report.variance >= 0.0);
        assert_eq!(report.noise, 0.0);
        assert_eq!(
            report.expected_error,
            report.bias_squared + report.variance + report.noise
        );
        assert_eq!(report.sample_wise_results.len(), x_test.len());
    }

    #[test]
    fn test_invalid_parameters() {
        let estimator = MockEstimator { noise_level: 0.1 };
        let train = vec![1.0, 2.0, 3.0];
        let test = vec![4.0, 5.0];

        let outcome = BiasVarianceAnalyzer::new()
            .n_bootstrap(0)
            .decompose(&estimator, &train, &train, &test, &test);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_convenience_function() {
        let estimator = MockEstimator { noise_level: 0.05 };
        let (x_train, y_train) = line(50, 0.0, 1.0, 0.5);
        let (x_test, y_test) = line(10, 5.0, 1.0, 0.5);

        let report =
            bias_variance_decompose(&estimator, &x_train, &y_train, &x_test, &y_test, Some(20))
                .expect("decomposition should succeed");

        assert_eq!(report.n_bootstrap, 20);
    }
}
521
522use scirs2_core::rand_prelude::SliceRandom;
524