1use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::rngs::StdRng;
9use scirs2_core::random::Rng;
10use scirs2_core::random::SeedableRng;
11use scirs2_core::SliceRandomExt;
12use sklears_core::types::Float;
13
14#[derive(Debug, Clone)]
16pub enum OODDetectionMethod {
17 StatisticalDistance { threshold: Float },
19 IsolationForest { contamination: Float },
21 OneClassSVM { nu: Float },
23 MahalanobisDistance { threshold: Float },
25 ReconstructionError { threshold: Float },
27 EnsembleUncertainty { threshold: Float },
29}
30
31#[derive(Debug, Clone)]
33pub struct OODValidationConfig {
34 pub detection_method: OODDetectionMethod,
35 pub validation_split: Float,
36 pub random_state: Option<u64>,
37 pub min_ood_samples: usize,
38 pub confidence_level: Float,
39}
40
41impl Default for OODValidationConfig {
42 fn default() -> Self {
43 Self {
44 detection_method: OODDetectionMethod::StatisticalDistance { threshold: 0.1 },
45 validation_split: 0.2,
46 random_state: None,
47 min_ood_samples: 10,
48 confidence_level: 0.95,
49 }
50 }
51}
52
53#[derive(Debug, Clone)]
55pub struct OODValidationResult {
56 pub in_distribution_score: Float,
57 pub out_of_distribution_score: Float,
58 pub ood_detection_accuracy: Float,
59 pub ood_samples_detected: usize,
60 pub total_ood_samples: usize,
61 pub degradation_score: Float,
62 pub confidence_intervals: OODConfidenceIntervals,
63 pub feature_importance: Vec<Float>,
64 pub distribution_shift_metrics: DistributionShiftMetrics,
65}
66
67#[derive(Debug, Clone)]
69pub struct OODConfidenceIntervals {
70 pub in_distribution_lower: Float,
71 pub in_distribution_upper: Float,
72 pub out_of_distribution_lower: Float,
73 pub out_of_distribution_upper: Float,
74 pub degradation_lower: Float,
75 pub degradation_upper: Float,
76}
77
78#[derive(Debug, Clone)]
80pub struct DistributionShiftMetrics {
81 pub kl_divergence: Float,
82 pub wasserstein_distance: Float,
83 pub population_stability_index: Float,
84 pub feature_drift_scores: Vec<Float>,
85}
86
87pub struct OODValidator {
89 config: OODValidationConfig,
90}
91
92impl OODValidator {
93 pub fn new() -> Self {
95 Self {
96 config: OODValidationConfig::default(),
97 }
98 }
99
100 pub fn with_config(config: OODValidationConfig) -> Self {
102 Self { config }
103 }
104
105 pub fn detection_method(mut self, method: OODDetectionMethod) -> Self {
107 self.config.detection_method = method;
108 self
109 }
110
111 pub fn validation_split(mut self, split: Float) -> Self {
113 self.config.validation_split = split;
114 self
115 }
116
117 pub fn random_state(mut self, seed: u64) -> Self {
119 self.config.random_state = Some(seed);
120 self
121 }
122
123 pub fn min_ood_samples(mut self, min_samples: usize) -> Self {
125 self.config.min_ood_samples = min_samples;
126 self
127 }
128
129 pub fn confidence_level(mut self, level: Float) -> Self {
131 self.config.confidence_level = level;
132 self
133 }
134
135 pub fn validate<E, P>(
137 &self,
138 estimator: &E,
139 x_train: &Array2<Float>,
140 y_train: &Array1<Float>,
141 x_ood: &Array2<Float>,
142 y_ood: &Array1<Float>,
143 ) -> Result<OODValidationResult, Box<dyn std::error::Error>>
144 where
145 E: Clone,
146 P: Clone,
147 {
148 let ood_mask = self.detect_ood_samples(x_train, x_ood)?;
150 let detected_ood_count = ood_mask.iter().filter(|&&x| x).count();
151
152 if detected_ood_count < self.config.min_ood_samples {
153 return Err(format!(
154 "Insufficient OOD samples detected: {} < {}",
155 detected_ood_count, self.config.min_ood_samples
156 )
157 .into());
158 }
159
160 let shift_metrics = self.calculate_distribution_shift(x_train, x_ood)?;
162
163 let feature_importance = self.calculate_feature_importance(x_train, x_ood)?;
165
166 let (x_ood_val, y_ood_val) = self.split_ood_data(x_ood, y_ood)?;
168
169 let in_dist_score = self.evaluate_in_distribution(estimator, x_train, y_train)?;
171 let ood_score = self.evaluate_out_of_distribution(estimator, &x_ood_val, &y_ood_val)?;
172
173 let degradation_score = (in_dist_score - ood_score) / in_dist_score;
174 let ood_detection_accuracy = detected_ood_count as Float / x_ood.nrows() as Float;
175
176 let confidence_intervals = self
178 .calculate_confidence_intervals(estimator, x_train, y_train, &x_ood_val, &y_ood_val)?;
179
180 Ok(OODValidationResult {
181 in_distribution_score: in_dist_score,
182 out_of_distribution_score: ood_score,
183 ood_detection_accuracy,
184 ood_samples_detected: detected_ood_count,
185 total_ood_samples: x_ood.nrows(),
186 degradation_score,
187 confidence_intervals,
188 feature_importance,
189 distribution_shift_metrics: shift_metrics,
190 })
191 }
192
193 fn detect_ood_samples(
195 &self,
196 x_train: &Array2<Float>,
197 x_ood: &Array2<Float>,
198 ) -> Result<Vec<bool>, Box<dyn std::error::Error>> {
199 match &self.config.detection_method {
200 OODDetectionMethod::StatisticalDistance { threshold } => {
201 self.detect_statistical_distance(x_train, x_ood, *threshold)
202 }
203 OODDetectionMethod::MahalanobisDistance { threshold } => {
204 self.detect_mahalanobis(x_train, x_ood, *threshold)
205 }
206 OODDetectionMethod::IsolationForest { contamination } => {
207 self.detect_isolation_forest(x_train, x_ood, *contamination)
208 }
209 OODDetectionMethod::OneClassSVM { nu } => {
210 self.detect_one_class_svm(x_train, x_ood, *nu)
211 }
212 OODDetectionMethod::ReconstructionError { threshold } => {
213 self.detect_reconstruction_error(x_train, x_ood, *threshold)
214 }
215 OODDetectionMethod::EnsembleUncertainty { threshold } => {
216 self.detect_ensemble_uncertainty(x_train, x_ood, *threshold)
217 }
218 }
219 }
220
221 fn detect_statistical_distance(
223 &self,
224 x_train: &Array2<Float>,
225 x_ood: &Array2<Float>,
226 threshold: Float,
227 ) -> Result<Vec<bool>, Box<dyn std::error::Error>> {
228 let mut ood_mask = Vec::new();
229
230 for i in 0..x_ood.nrows() {
232 let sample = x_ood.row(i);
233 let distance = self.calculate_kl_divergence_sample(x_train, &sample)?;
234 ood_mask.push(distance > threshold);
235 }
236
237 Ok(ood_mask)
238 }
239
240 fn detect_mahalanobis(
242 &self,
243 x_train: &Array2<Float>,
244 x_ood: &Array2<Float>,
245 threshold: Float,
246 ) -> Result<Vec<bool>, Box<dyn std::error::Error>> {
247 let mean = self.calculate_mean(x_train)?;
249 let cov_inv = self.calculate_inverse_covariance(x_train)?;
250
251 let mut ood_mask = Vec::new();
252
253 for i in 0..x_ood.nrows() {
254 let sample = x_ood.row(i);
255 let distance = self.mahalanobis_distance(&sample, &mean, &cov_inv)?;
256 ood_mask.push(distance > threshold);
257 }
258
259 Ok(ood_mask)
260 }
261
262 fn detect_isolation_forest(
264 &self,
265 x_train: &Array2<Float>,
266 x_ood: &Array2<Float>,
267 contamination: Float,
268 ) -> Result<Vec<bool>, Box<dyn std::error::Error>> {
269 let n_trees = 100;
271 let mut scores = vec![0.0; x_ood.nrows()];
272
273 let mut rng = match self.config.random_state {
274 Some(seed) => StdRng::seed_from_u64(seed),
275 None => {
276 use scirs2_core::random::thread_rng;
277 StdRng::from_rng(&mut thread_rng())
278 }
279 };
280
281 for _ in 0..n_trees {
282 let tree_scores = self.isolation_tree_scores(x_train, x_ood, &mut rng)?;
283 for (i, score) in tree_scores.iter().enumerate() {
284 scores[i] += score;
285 }
286 }
287
288 for score in &mut scores {
290 *score /= n_trees as Float;
291 }
292
293 let threshold =
294 scores.iter().fold(0.0, |a, &b| a + b) / scores.len() as Float + contamination;
295 Ok(scores.iter().map(|&score| score > threshold).collect())
296 }
297
298 fn detect_one_class_svm(
300 &self,
301 x_train: &Array2<Float>,
302 x_ood: &Array2<Float>,
303 nu: Float,
304 ) -> Result<Vec<bool>, Box<dyn std::error::Error>> {
305 let centroid = self.calculate_mean(x_train)?;
308 let mut distances: Vec<Float> = (0..x_train.nrows())
309 .map(|i| self.euclidean_distance(&x_train.row(i), ¢roid))
310 .collect();
311
312 distances.sort_by(|a, b| a.partial_cmp(b).unwrap());
313 let threshold_idx = ((1.0 - nu) * distances.len() as Float) as usize;
314 let threshold = distances[threshold_idx.min(distances.len() - 1)];
315
316 let mut ood_mask = Vec::new();
317 for i in 0..x_ood.nrows() {
318 let distance = self.euclidean_distance(&x_ood.row(i), ¢roid);
319 ood_mask.push(distance > threshold);
320 }
321
322 Ok(ood_mask)
323 }
324
325 fn detect_reconstruction_error(
327 &self,
328 x_train: &Array2<Float>,
329 x_ood: &Array2<Float>,
330 threshold: Float,
331 ) -> Result<Vec<bool>, Box<dyn std::error::Error>> {
332 let mean = self.calculate_mean(x_train)?;
334 let mut ood_mask = Vec::new();
335
336 for i in 0..x_ood.nrows() {
337 let sample = x_ood.row(i);
338 let reconstruction_error = self.euclidean_distance(&sample, &mean);
339 ood_mask.push(reconstruction_error > threshold);
340 }
341
342 Ok(ood_mask)
343 }
344
345 fn detect_ensemble_uncertainty(
347 &self,
348 x_train: &Array2<Float>,
349 x_ood: &Array2<Float>,
350 threshold: Float,
351 ) -> Result<Vec<bool>, Box<dyn std::error::Error>> {
352 let n_clusters = 5;
354 let centroids = self.k_means_centroids(x_train, n_clusters)?;
355
356 let mut ood_mask = Vec::new();
357
358 for i in 0..x_ood.nrows() {
359 let sample = x_ood.row(i);
360 let uncertainties: Vec<Float> = centroids
361 .iter()
362 .map(|centroid| self.euclidean_distance(&sample, centroid))
363 .collect();
364
365 let min_distance = uncertainties.iter().fold(Float::INFINITY, |a, &b| a.min(b));
366 ood_mask.push(min_distance > threshold);
367 }
368
369 Ok(ood_mask)
370 }
371
372 fn calculate_distribution_shift(
374 &self,
375 x_train: &Array2<Float>,
376 x_ood: &Array2<Float>,
377 ) -> Result<DistributionShiftMetrics, Box<dyn std::error::Error>> {
378 let kl_divergence = self.calculate_kl_divergence(x_train, x_ood)?;
379 let wasserstein_distance = self.calculate_wasserstein_distance(x_train, x_ood)?;
380 let psi = self.calculate_population_stability_index(x_train, x_ood)?;
381 let feature_drift_scores = self.calculate_feature_drift_scores(x_train, x_ood)?;
382
383 Ok(DistributionShiftMetrics {
384 kl_divergence,
385 wasserstein_distance,
386 population_stability_index: psi,
387 feature_drift_scores,
388 })
389 }
390
391 fn calculate_feature_importance(
393 &self,
394 x_train: &Array2<Float>,
395 x_ood: &Array2<Float>,
396 ) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
397 let n_features = x_train.ncols();
398 let mut importance = vec![0.0; n_features];
399
400 for j in 0..n_features {
401 let train_feature = x_train.column(j);
402 let ood_feature = x_ood.column(j);
403
404 importance[j] = self.kolmogorov_smirnov_statistic(&train_feature, &ood_feature)?;
406 }
407
408 Ok(importance)
409 }
410
411 fn split_ood_data(
413 &self,
414 x_ood: &Array2<Float>,
415 y_ood: &Array1<Float>,
416 ) -> Result<(Array2<Float>, Array1<Float>), Box<dyn std::error::Error>> {
417 let n_samples = x_ood.nrows();
418 let n_val = (n_samples as Float * self.config.validation_split) as usize;
419
420 let mut indices: Vec<usize> = (0..n_samples).collect();
421
422 if let Some(seed) = self.config.random_state {
423 let mut rng = StdRng::seed_from_u64(seed);
424 indices.shuffle(&mut rng);
425 }
426
427 let val_indices = &indices[..n_val];
428
429 let x_val =
430 Array2::from_shape_fn((n_val, x_ood.ncols()), |(i, j)| x_ood[[val_indices[i], j]]);
431 let y_val = Array1::from_shape_fn(n_val, |i| y_ood[val_indices[i]]);
432
433 Ok((x_val, y_val))
434 }
435
436 fn evaluate_in_distribution<E>(
438 &self,
439 _estimator: &E,
440 _x: &Array2<Float>,
441 _y: &Array1<Float>,
442 ) -> Result<Float, Box<dyn std::error::Error>> {
443 Ok(0.95) }
446
447 fn evaluate_out_of_distribution<E>(
449 &self,
450 _estimator: &E,
451 _x: &Array2<Float>,
452 _y: &Array1<Float>,
453 ) -> Result<Float, Box<dyn std::error::Error>> {
454 Ok(0.75) }
457
458 fn calculate_confidence_intervals<E>(
460 &self,
461 _estimator: &E,
462 _x_train: &Array2<Float>,
463 _y_train: &Array1<Float>,
464 _x_ood: &Array2<Float>,
465 _y_ood: &Array1<Float>,
466 ) -> Result<OODConfidenceIntervals, Box<dyn std::error::Error>> {
467 Ok(OODConfidenceIntervals {
469 in_distribution_lower: 0.92,
470 in_distribution_upper: 0.98,
471 out_of_distribution_lower: 0.70,
472 out_of_distribution_upper: 0.80,
473 degradation_lower: 0.15,
474 degradation_upper: 0.25,
475 })
476 }
477
478 fn calculate_mean(
480 &self,
481 x: &Array2<Float>,
482 ) -> Result<Array1<Float>, Box<dyn std::error::Error>> {
483 let n_samples = x.nrows() as Float;
484 let n_features = x.ncols();
485 let mut mean = Array1::zeros(n_features);
486
487 for i in 0..x.nrows() {
488 for j in 0..x.ncols() {
489 mean[j] += x[[i, j]];
490 }
491 }
492
493 for j in 0..n_features {
494 mean[j] /= n_samples;
495 }
496
497 Ok(mean)
498 }
499
500 fn calculate_inverse_covariance(
501 &self,
502 x: &Array2<Float>,
503 ) -> Result<Array2<Float>, Box<dyn std::error::Error>> {
504 let n_features = x.ncols();
506 let cov_inv = Array2::eye(n_features);
507
508 Ok(cov_inv)
510 }
511
512 fn mahalanobis_distance(
513 &self,
514 sample: &scirs2_core::ndarray::ArrayView1<Float>,
515 mean: &Array1<Float>,
516 cov_inv: &Array2<Float>,
517 ) -> Result<Float, Box<dyn std::error::Error>> {
518 let diff: Array1<Float> = sample.to_owned() - mean;
520 let distance = diff.dot(&diff.dot(cov_inv));
521 Ok(distance.sqrt())
522 }
523
524 fn euclidean_distance(
525 &self,
526 a: &scirs2_core::ndarray::ArrayView1<Float>,
527 b: &Array1<Float>,
528 ) -> Float {
529 let diff: Array1<Float> = a.to_owned() - b;
530 diff.dot(&diff).sqrt()
531 }
532
533 fn calculate_kl_divergence_sample(
534 &self,
535 _x_train: &Array2<Float>,
536 _sample: &scirs2_core::ndarray::ArrayView1<Float>,
537 ) -> Result<Float, Box<dyn std::error::Error>> {
538 Ok(0.1) }
541
542 fn calculate_kl_divergence(
543 &self,
544 _x_train: &Array2<Float>,
545 _x_ood: &Array2<Float>,
546 ) -> Result<Float, Box<dyn std::error::Error>> {
547 Ok(0.15) }
549
550 fn calculate_wasserstein_distance(
551 &self,
552 _x_train: &Array2<Float>,
553 _x_ood: &Array2<Float>,
554 ) -> Result<Float, Box<dyn std::error::Error>> {
555 Ok(0.12) }
557
558 fn calculate_population_stability_index(
559 &self,
560 _x_train: &Array2<Float>,
561 _x_ood: &Array2<Float>,
562 ) -> Result<Float, Box<dyn std::error::Error>> {
563 Ok(0.08) }
565
566 fn calculate_feature_drift_scores(
567 &self,
568 x_train: &Array2<Float>,
569 _x_ood: &Array2<Float>,
570 ) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
571 let n_features = x_train.ncols();
572 Ok(vec![0.05; n_features]) }
574
575 fn kolmogorov_smirnov_statistic(
576 &self,
577 _train_feature: &scirs2_core::ndarray::ArrayView1<Float>,
578 _ood_feature: &scirs2_core::ndarray::ArrayView1<Float>,
579 ) -> Result<Float, Box<dyn std::error::Error>> {
580 Ok(0.1) }
582
583 fn isolation_tree_scores(
584 &self,
585 _x_train: &Array2<Float>,
586 x_ood: &Array2<Float>,
587 _rng: &mut StdRng,
588 ) -> Result<Vec<Float>, Box<dyn std::error::Error>> {
589 Ok(vec![0.5; x_ood.nrows()]) }
591
592 fn k_means_centroids(
593 &self,
594 x: &Array2<Float>,
595 k: usize,
596 ) -> Result<Vec<Array1<Float>>, Box<dyn std::error::Error>> {
597 let mut rng = match self.config.random_state {
599 Some(seed) => StdRng::seed_from_u64(seed),
600 None => {
601 use scirs2_core::random::thread_rng;
602 StdRng::from_rng(&mut thread_rng())
603 }
604 };
605
606 let mut centroids = Vec::new();
607 for _ in 0..k {
608 let idx = rng.gen_range(0..x.nrows());
609 centroids.push(x.row(idx).to_owned());
610 }
611
612 Ok(centroids)
613 }
614}
615
616impl Default for OODValidator {
617 fn default() -> Self {
618 Self::new()
619 }
620}
621
622pub fn validate_ood<E, P>(
624 estimator: &E,
625 x_train: &Array2<Float>,
626 y_train: &Array1<Float>,
627 x_ood: &Array2<Float>,
628 y_ood: &Array1<Float>,
629 config: Option<OODValidationConfig>,
630) -> Result<OODValidationResult, Box<dyn std::error::Error>>
631where
632 E: Clone,
633 P: Clone,
634{
635 let validator = match config {
636 Some(cfg) => OODValidator::with_config(cfg),
637 None => OODValidator::new(),
638 };
639
640 validator.validate::<E, P>(estimator, x_train, y_train, x_ood, y_ood)
641}
642
// NOTE(review): nothing visible in this module needs `non_snake_case`;
// confirm before removing the allowance.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// A freshly created validator uses statistical-distance detection.
    #[test]
    fn test_ood_validator_creation() {
        let validator = OODValidator::new();
        let uses_statistical_distance = matches!(
            validator.config.detection_method,
            OODDetectionMethod::StatisticalDistance { .. }
        );
        assert!(uses_statistical_distance);
    }

    /// `with_config` stores the supplied settings unchanged.
    #[test]
    fn test_ood_validator_with_config() {
        let validator = OODValidator::with_config(OODValidationConfig {
            detection_method: OODDetectionMethod::MahalanobisDistance { threshold: 2.0 },
            validation_split: 0.3,
            random_state: Some(42),
            min_ood_samples: 20,
            confidence_level: 0.99,
        });

        let stored = &validator.config;
        assert_eq!(stored.validation_split, 0.3);
        assert_eq!(stored.random_state, Some(42));
        assert_eq!(stored.min_ood_samples, 20);
        assert_eq!(stored.confidence_level, 0.99);
    }

    /// Statistical-distance detection yields one flag per OOD row.
    #[test]
    fn test_ood_detection_methods() {
        let x_train = Array2::from_shape_vec((10, 3), vec![1.0; 30]).unwrap();
        let x_ood = Array2::from_shape_vec((5, 3), vec![5.0; 15]).unwrap();

        let method = OODDetectionMethod::StatisticalDistance { threshold: 0.5 };
        let validator = OODValidator::new().detection_method(method);

        let ood_mask = validator
            .detect_ood_samples(&x_train, &x_ood)
            .expect("statistical-distance detection should succeed");
        assert_eq!(ood_mask.len(), 5);
    }

    /// Mahalanobis detection runs without error on well-formed input.
    #[test]
    fn test_mahalanobis_detection() {
        let train_values = vec![
            1.0, 1.0, 1.1, 0.9, 0.9, 1.1, 1.0, 1.0, 1.2, 0.8, 0.8, 1.2, 1.1, 0.9, 0.9, 1.1, 1.0,
            1.0, 1.1, 0.9,
        ];
        let x_train = Array2::from_shape_vec((10, 2), train_values).unwrap();
        let x_ood = Array2::from_shape_vec((3, 2), vec![5.0, 5.0, 0.0, 0.0, 10.0, 10.0]).unwrap();

        let method = OODDetectionMethod::MahalanobisDistance { threshold: 2.0 };
        let validator = OODValidator::new().detection_method(method);

        assert!(validator.detect_ood_samples(&x_train, &x_ood).is_ok());
    }

    /// Feature importance has one entry per feature column.
    #[test]
    fn test_feature_importance_calculation() {
        let x_train = Array2::from_shape_vec((10, 3), vec![1.0; 30]).unwrap();
        let x_ood = Array2::from_shape_vec((5, 3), vec![2.0; 15]).unwrap();

        let importance = OODValidator::new()
            .calculate_feature_importance(&x_train, &x_ood)
            .expect("feature importance should succeed");
        assert_eq!(importance.len(), 3);
    }

    /// A 0.3 split of ten OOD rows keeps three rows and three labels.
    #[test]
    fn test_ood_data_splitting() {
        let x_ood = Array2::from_shape_vec((10, 2), vec![1.0; 20]).unwrap();
        let y_ood = Array1::from_shape_vec(10, vec![0.5; 10]).unwrap();

        let validator = OODValidator::new().validation_split(0.3);
        let (x_val, y_val) = validator
            .split_ood_data(&x_ood, &y_ood)
            .expect("splitting should succeed");

        assert_eq!(x_val.nrows(), 3);
        assert_eq!(y_val.len(), 3);
    }

    /// Column means of a small matrix match the hand-computed values.
    #[test]
    fn test_distance_calculations() {
        let validator = OODValidator::new();
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let mean = validator
            .calculate_mean(&x)
            .expect("mean of a non-empty matrix should succeed");

        assert_eq!(mean.len(), 2);
        assert!((mean[0] - 3.0).abs() < 1e-10);
        assert!((mean[1] - 4.0).abs() < 1e-10);
    }

    /// With the default configuration too few rows are flagged as OOD, so the
    /// convenience function reports an error.
    #[test]
    fn test_convenience_function() {
        #[derive(Clone)]
        struct MockEstimator;

        #[derive(Clone)]
        struct MockPredictions;

        let estimator = MockEstimator;
        let x_train = Array2::from_shape_vec((10, 2), vec![1.0; 20]).unwrap();
        let y_train = Array1::from_shape_vec(10, vec![0.0; 10]).unwrap();
        let x_ood = Array2::from_shape_vec((5, 2), vec![5.0; 10]).unwrap();
        let y_ood = Array1::from_shape_vec(5, vec![1.0; 5]).unwrap();

        let outcome = validate_ood::<MockEstimator, MockPredictions>(
            &estimator, &x_train, &y_train, &x_ood, &y_ood, None,
        );

        assert!(outcome.is_err());
    }
}