1use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
8use scirs2_core::random::{Rng, RngExt};
9use scirs2_core::SliceRandomExt;
10use sklears_core::{
11 error::{Result as SklResult, SklearsError},
12 traits::{Fit, Transform},
13 types::Float,
14};
15use std::collections::HashMap;
16
/// Strategy for partitioning samples into train/test folds.
#[derive(Debug, Clone)]
pub enum CrossValidationStrategy {
    /// Standard k-fold split; optionally shuffles sample order first.
    KFold { n_splits: usize, shuffle: bool },
    /// K-fold preserving label proportions per fold.
    /// NOTE(review): not handled by `generate_fold_indices` — it returns an error.
    StratifiedKFold { n_splits: usize, shuffle: bool },
    /// Each sample is its own test fold; all remaining samples train.
    LeaveOneOut,
    /// Expanding-window split for ordered data: each fold tests on a
    /// contiguous window and trains only on samples that come before it.
    TimeSeriesSplit {
        /// Number of test windows to generate.
        n_splits: usize,
        /// Optional cap on the training window (most recent samples only).
        max_train_size: Option<usize>,
    },
    /// K-fold keeping all samples of a group inside the same fold.
    /// NOTE(review): not handled by `generate_fold_indices` — it returns an error.
    GroupKFold { n_splits: usize },
}
34
/// Synthetic missingness mechanism used to mask entries for evaluation.
#[derive(Debug, Clone)]
pub enum MissingDataPattern {
    /// Missing Completely At Random: every cell masked with equal probability.
    MCAR { missing_rate: f64 },
    /// Missing At Random: masking probability depends on an observed value
    /// (in `introduce_missing_data`, on the first feature column).
    MAR {
        /// Base masking probability before scaling.
        missing_rate: f64,
        /// NOTE(review): currently ignored by `introduce_missing_data`.
        dependency_strength: f64,
    },
    /// Missing Not At Random: cells above `mean + threshold_factor * std`
    /// of their own column are masked more often.
    MNAR {
        missing_rate: f64,
        threshold_factor: f64,
    },
    /// Contiguous rectangular blocks of missing cells.
    /// NOTE(review): not implemented by `introduce_missing_data` — it errors.
    Block {
        block_size: (usize, usize),
        n_blocks: usize,
    },
    /// Per-feature missing rates forming a monotone pattern.
    /// NOTE(review): not implemented by `introduce_missing_data` — it errors.
    Monotone { missing_rates: Vec<f64> },
}
58
/// Quality metrics comparing imputed values against held-out ground truth.
///
/// The continuous-data evaluation in this module fills the
/// classification-oriented fields (`accuracy`, `f1_score`) and `coverage`
/// with NaN.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ImputationMetrics {
    /// Root mean squared error over the masked entries.
    pub rmse: f64,
    /// Mean absolute error over the masked entries.
    pub mae: f64,
    /// Coefficient of determination (R²) of imputed vs. true values.
    pub r2: f64,
    /// Classification accuracy; NaN for the continuous evaluation here.
    pub accuracy: f64,
    /// F1 score; NaN for the continuous evaluation here.
    pub f1_score: f64,
    /// Mean signed error (true − imputed); positive means imputations run low.
    pub bias: f64,
    /// Interval coverage rate; NaN — not computed in this module.
    pub coverage: f64,
    /// Kolmogorov–Smirnov statistic between true and imputed value samples.
    pub ks_statistic: f64,
    /// Approximate p-value for the KS statistic.
    pub ks_pvalue: f64,
}
81
/// Aggregated results from cross-validated imputation evaluation.
#[derive(Debug, Clone)]
pub struct CrossValidationResults {
    /// Metrics computed independently on each fold.
    pub fold_metrics: Vec<ImputationMetrics>,
    /// Per-metric mean across folds.
    pub mean_metrics: ImputationMetrics,
    /// Per-metric standard deviation across folds.
    pub std_metrics: ImputationMetrics,
    /// Metric name → (lower, upper) approximate confidence interval.
    pub confidence_intervals: HashMap<String, (f64, f64)>,
}
94
/// Cross-validation driver for imputation quality: splits samples into
/// folds, synthetically masks entries with a chosen missingness pattern,
/// and scores imputations against the held-out truth.
#[derive(Debug, Clone)]
pub struct ImputationCrossValidator {
    /// How samples are partitioned into train/test folds.
    cv_strategy: CrossValidationStrategy,
    /// Mechanism used to synthetically mask entries.
    missing_pattern: MissingDataPattern,
    /// Fraction that scales the effective masking probability.
    test_fraction: f64,
    /// Seed for reproducible shuffling/masking; None = nondeterministic.
    random_state: Option<u64>,
    /// Requested parallelism. NOTE(review): not consumed by visible code.
    n_jobs: usize,
}
135
/// Single train/test split validator for imputation evaluation.
#[derive(Debug, Clone)]
pub struct HoldOutValidator {
    /// Fraction of samples held out for testing.
    test_size: f64,
    /// Mechanism used to synthetically mask entries.
    missing_pattern: MissingDataPattern,
    /// Seed for reproducibility; None = nondeterministic.
    random_state: Option<u64>,
    /// Whether to stratify the split. NOTE(review): unused in visible code.
    stratify: bool,
}
147
/// Validates imputers on synthetic datasets with known ground truth,
/// sweeping data generators, missingness patterns, and dataset sizes.
/// NOTE(review): no methods for this type are visible in this chunk.
#[derive(Debug, Clone)]
pub struct SyntheticMissingValidator {
    /// Generators producing the complete (pre-masking) datasets.
    data_generators: Vec<DataGenerator>,
    /// Missingness mechanisms to apply to each generated dataset.
    missing_patterns: Vec<MissingDataPattern>,
    /// Number of datasets to generate per configuration.
    n_datasets: usize,
    /// (n_samples, n_features) shapes to sweep.
    dataset_sizes: Vec<(usize, usize)>,
    /// Seed for reproducibility; None = nondeterministic.
    random_state: Option<u64>,
}
160
/// Synthetic dataset generator configurations.
/// NOTE(review): the generator implementations are not visible in this
/// chunk; variant descriptions below are inferred from field names —
/// confirm against the generating code.
#[derive(Debug, Clone)]
pub enum DataGenerator {
    /// Multivariate normal with the given mean vector and covariance matrix.
    MultivariateNormal { mean: Array1<f64>, cov: Array2<f64> },
    /// Features related through a linear map plus noise.
    LinearRelationships {
        coefficients: Array2<f64>,
        noise_std: f64,
    },
    /// Non-linear relationships; `function_type` names the transform.
    NonLinear {
        function_type: String,
        noise_std: f64,
    },
    /// Mix of continuous and categorical features.
    MixedType {
        continuous_props: f64,
        n_categories: Vec<usize>,
    },
}
182
/// Evaluates imputation methods across predefined case studies.
/// NOTE(review): no methods for this type are visible in this chunk.
#[derive(Debug, Clone)]
pub struct CaseStudyValidator {
    /// Scenarios to evaluate.
    case_studies: Vec<CaseStudy>,
    /// Names of metrics to report.
    evaluation_metrics: Vec<String>,
    /// Names of baseline methods to compare against.
    comparison_methods: Vec<String>,
}
193
/// One named evaluation scenario: a data profile plus the missingness
/// patterns and criteria it should be judged under.
#[derive(Debug, Clone)]
pub struct CaseStudy {
    /// Short identifier for the scenario.
    name: String,
    /// Human-readable description.
    description: String,
    /// Statistical profile of the underlying data.
    data_characteristics: DataCharacteristics,
    /// Missingness mechanisms to apply in this scenario.
    missing_patterns: Vec<MissingDataPattern>,
    /// Names of the criteria used to judge results.
    evaluation_criteria: Vec<String>,
}
203
/// Statistical profile of a case-study dataset.
#[derive(Debug, Clone)]
pub struct DataCharacteristics {
    /// Number of rows.
    n_samples: usize,
    /// Number of columns.
    n_features: usize,
    /// Per-feature type labels.
    feature_types: Vec<String>,
    /// Named correlation structure of the features.
    correlation_structure: String,
    /// Expected fraction of outlying samples.
    outlier_fraction: f64,
    /// Relative noise magnitude.
    noise_level: f64,
}
214
215impl ImputationCrossValidator {
218 pub fn new() -> Self {
220 Self {
221 cv_strategy: CrossValidationStrategy::KFold {
222 n_splits: 5,
223 shuffle: true,
224 },
225 missing_pattern: MissingDataPattern::MCAR { missing_rate: 0.2 },
226 test_fraction: 0.1,
227 random_state: None,
228 n_jobs: 1,
229 }
230 }
231
232 pub fn cv_strategy(mut self, strategy: CrossValidationStrategy) -> Self {
234 self.cv_strategy = strategy;
235 self
236 }
237
238 pub fn missing_pattern(mut self, pattern: MissingDataPattern) -> Self {
240 self.missing_pattern = pattern;
241 self
242 }
243
244 pub fn test_fraction(mut self, fraction: f64) -> Self {
246 self.test_fraction = fraction;
247 self
248 }
249
250 pub fn random_state(mut self, random_state: u64) -> Self {
252 self.random_state = Some(random_state);
253 self
254 }
255
256 pub fn n_jobs(mut self, n_jobs: usize) -> Self {
258 self.n_jobs = n_jobs;
259 self
260 }
261
262 #[allow(non_snake_case)]
267 #[allow(dead_code)]
268 fn validate_imputer_disabled<'a, I, F>(
269 &self,
270 _imputer: &I,
271 _X: &ArrayView2<'_, Float>,
272 ) -> SklResult<CrossValidationResults>
273 where
274 I: Clone,
275 I: Fit<ArrayView2<'a, Float>, (), Fitted = F>,
276 F: Transform<ArrayView2<'a, Float>, Array2<Float>>,
277 {
278 Err(SklearsError::NotImplemented(
339 "validate_imputer temporarily disabled due to HRTB compilation issues".to_string(),
340 ))
341 }
342
343 fn generate_fold_indices(
344 &self,
345 n_samples: usize,
346 rng: &mut impl Rng,
347 ) -> SklResult<Vec<(Vec<usize>, Vec<usize>)>> {
348 let mut indices: Vec<usize> = (0..n_samples).collect();
349
350 match &self.cv_strategy {
351 CrossValidationStrategy::KFold { n_splits, shuffle } => {
352 if *shuffle {
353 indices.shuffle(rng);
354 }
355
356 let fold_size = n_samples / n_splits;
357 let mut folds = Vec::new();
358
359 for i in 0..*n_splits {
360 let start = i * fold_size;
361 let end = if i == n_splits - 1 {
362 n_samples
363 } else {
364 (i + 1) * fold_size
365 };
366
367 let test_indices: Vec<usize> = indices[start..end].to_vec();
368 let train_indices: Vec<usize> = indices[..start]
369 .iter()
370 .chain(indices[end..].iter())
371 .cloned()
372 .collect();
373
374 folds.push((train_indices, test_indices));
375 }
376
377 Ok(folds)
378 }
379
380 CrossValidationStrategy::LeaveOneOut => {
381 let mut folds = Vec::new();
382 for i in 0..n_samples {
383 let test_indices = vec![i];
384 let train_indices: Vec<usize> = (0..n_samples).filter(|&x| x != i).collect();
385 folds.push((train_indices, test_indices));
386 }
387 Ok(folds)
388 }
389
390 CrossValidationStrategy::TimeSeriesSplit {
391 n_splits,
392 max_train_size,
393 } => {
394 let mut folds = Vec::new();
395 let test_size = n_samples / (n_splits + 1);
396
397 for i in 1..=*n_splits {
398 let test_start = i * test_size;
399 let test_end = (test_start + test_size).min(n_samples);
400 let test_indices: Vec<usize> = (test_start..test_end).collect();
401
402 let train_end = test_start;
403 let train_start = if let Some(max_size) = max_train_size {
404 train_end.saturating_sub(*max_size)
405 } else {
406 0
407 };
408 let train_indices: Vec<usize> = (train_start..train_end).collect();
409
410 if !train_indices.is_empty() && !test_indices.is_empty() {
411 folds.push((train_indices, test_indices));
412 }
413 }
414
415 Ok(folds)
416 }
417
418 _ => Err(SklearsError::InvalidInput(
419 "Unsupported CV strategy".to_string(),
420 )),
421 }
422 }
423
424 fn introduce_missing_data(
425 &self,
426 X: &Array2<f64>,
427 rng: &mut impl Rng,
428 ) -> SklResult<(Array2<f64>, Array2<bool>)> {
429 let (n_samples, n_features) = X.dim();
430 let mut X_missing = X.clone();
431 let mut missing_mask = Array2::from_elem((n_samples, n_features), false);
432
433 match &self.missing_pattern {
434 MissingDataPattern::MCAR { missing_rate } => {
435 let n_missing = (n_samples * n_features) as f64 * missing_rate * self.test_fraction;
436 let n_missing = n_missing as usize;
437
438 let mut positions: Vec<(usize, usize)> = Vec::new();
439 for i in 0..n_samples {
440 for j in 0..n_features {
441 positions.push((i, j));
442 }
443 }
444
445 positions.shuffle(rng);
446
447 for &(i, j) in positions.iter().take(n_missing) {
448 X_missing[[i, j]] = f64::NAN;
449 missing_mask[[i, j]] = true;
450 }
451 }
452
453 MissingDataPattern::MAR {
454 missing_rate,
455 dependency_strength: _,
456 } => {
457 if n_features > 1 {
459 let threshold = X.column(0).mean().unwrap_or(0.0);
460
461 for i in 0..n_samples {
462 for j in 1..n_features {
463 let prob = if X[[i, 0]] > threshold {
464 missing_rate * 2.0 * self.test_fraction
465 } else {
466 missing_rate * 0.5 * self.test_fraction
467 };
468
469 if rng.random::<f64>() < prob {
470 X_missing[[i, j]] = f64::NAN;
471 missing_mask[[i, j]] = true;
472 }
473 }
474 }
475 }
476 }
477
478 MissingDataPattern::MNAR {
479 missing_rate,
480 threshold_factor,
481 } => {
482 for j in 0..n_features {
484 let column = X.column(j);
485 let mean = column.mean().unwrap_or(0.0);
486 let std = column.var(0.0).sqrt();
487 let threshold = mean + threshold_factor * std;
488
489 for i in 0..n_samples {
490 let prob = if X[[i, j]] > threshold {
491 missing_rate * 3.0 * self.test_fraction
492 } else {
493 missing_rate * 0.3 * self.test_fraction
494 };
495
496 if rng.random::<f64>() < prob {
497 X_missing[[i, j]] = f64::NAN;
498 missing_mask[[i, j]] = true;
499 }
500 }
501 }
502 }
503
504 _ => {
505 return Err(SklearsError::InvalidInput(
506 "Unsupported missing pattern".to_string(),
507 ));
508 }
509 }
510
511 Ok((X_missing, missing_mask))
512 }
513
514 fn compute_metrics(
515 &self,
516 X_true: &Array2<f64>,
517 X_imputed: &Array2<f64>,
518 missing_mask: &Array2<bool>,
519 ) -> SklResult<ImputationMetrics> {
520 let mut mse_sum = 0.0;
521 let mut mae_sum = 0.0;
522 let mut bias_sum = 0.0;
523 let mut count = 0;
524
525 let mut true_values = Vec::new();
526 let mut imputed_values = Vec::new();
527
528 for ((i, j), &is_missing) in missing_mask.indexed_iter() {
529 if is_missing {
530 let true_val = X_true[[i, j]];
531 let imputed_val = X_imputed[[i, j]];
532
533 if !true_val.is_nan() && !imputed_val.is_nan() {
534 let error = true_val - imputed_val;
535 mse_sum += error * error;
536 mae_sum += error.abs();
537 bias_sum += error;
538 count += 1;
539
540 true_values.push(true_val);
541 imputed_values.push(imputed_val);
542 }
543 }
544 }
545
546 if count == 0 {
547 return Ok(ImputationMetrics {
548 rmse: f64::NAN,
549 mae: f64::NAN,
550 r2: f64::NAN,
551 accuracy: f64::NAN,
552 f1_score: f64::NAN,
553 bias: f64::NAN,
554 coverage: f64::NAN,
555 ks_statistic: f64::NAN,
556 ks_pvalue: f64::NAN,
557 });
558 }
559
560 let mse = mse_sum / count as f64;
561 let rmse = mse.sqrt();
562 let mae = mae_sum / count as f64;
563 let bias = bias_sum / count as f64;
564
565 let true_mean = true_values.iter().sum::<f64>() / true_values.len() as f64;
567 let ss_tot: f64 = true_values.iter().map(|&x| (x - true_mean).powi(2)).sum();
568 let ss_res: f64 = true_values
569 .iter()
570 .zip(imputed_values.iter())
571 .map(|(&t, &p)| (t - p).powi(2))
572 .sum();
573
574 let r2 = if ss_tot > 0.0 {
575 1.0 - ss_res / ss_tot
576 } else {
577 f64::NAN
578 };
579
580 let (ks_statistic, ks_pvalue) = compute_ks_test(&true_values, &imputed_values);
582
583 Ok(ImputationMetrics {
584 rmse,
585 mae,
586 r2,
587 accuracy: f64::NAN, f1_score: f64::NAN, bias,
590 coverage: f64::NAN, ks_statistic,
592 ks_pvalue,
593 })
594 }
595
596 fn compute_mean_metrics(
597 &self,
598 fold_metrics: &[ImputationMetrics],
599 ) -> SklResult<ImputationMetrics> {
600 if fold_metrics.is_empty() {
601 return Err(SklearsError::InvalidInput(
602 "No fold metrics provided".to_string(),
603 ));
604 }
605
606 let n = fold_metrics.len() as f64;
607
608 let rmse = fold_metrics
609 .iter()
610 .map(|m| m.rmse)
611 .filter(|x| !x.is_nan())
612 .sum::<f64>()
613 / n;
614 let mae = fold_metrics
615 .iter()
616 .map(|m| m.mae)
617 .filter(|x| !x.is_nan())
618 .sum::<f64>()
619 / n;
620 let r2 = fold_metrics
621 .iter()
622 .map(|m| m.r2)
623 .filter(|x| !x.is_nan())
624 .sum::<f64>()
625 / n;
626 let bias = fold_metrics
627 .iter()
628 .map(|m| m.bias)
629 .filter(|x| !x.is_nan())
630 .sum::<f64>()
631 / n;
632 let ks_statistic = fold_metrics
633 .iter()
634 .map(|m| m.ks_statistic)
635 .filter(|x| !x.is_nan())
636 .sum::<f64>()
637 / n;
638 let ks_pvalue = fold_metrics
639 .iter()
640 .map(|m| m.ks_pvalue)
641 .filter(|x| !x.is_nan())
642 .sum::<f64>()
643 / n;
644
645 Ok(ImputationMetrics {
646 rmse,
647 mae,
648 r2,
649 accuracy: f64::NAN,
650 f1_score: f64::NAN,
651 bias,
652 coverage: f64::NAN,
653 ks_statistic,
654 ks_pvalue,
655 })
656 }
657
658 fn compute_std_metrics(
659 &self,
660 fold_metrics: &[ImputationMetrics],
661 mean_metrics: &ImputationMetrics,
662 ) -> SklResult<ImputationMetrics> {
663 if fold_metrics.is_empty() {
664 return Err(SklearsError::InvalidInput(
665 "No fold metrics provided".to_string(),
666 ));
667 }
668
669 let n = fold_metrics.len() as f64;
670
671 let rmse_var = fold_metrics
672 .iter()
673 .map(|m| (m.rmse - mean_metrics.rmse).powi(2))
674 .filter(|x| !x.is_nan())
675 .sum::<f64>()
676 / (n - 1.0);
677
678 let mae_var = fold_metrics
679 .iter()
680 .map(|m| (m.mae - mean_metrics.mae).powi(2))
681 .filter(|x| !x.is_nan())
682 .sum::<f64>()
683 / (n - 1.0);
684
685 let r2_var = fold_metrics
686 .iter()
687 .map(|m| (m.r2 - mean_metrics.r2).powi(2))
688 .filter(|x| !x.is_nan())
689 .sum::<f64>()
690 / (n - 1.0);
691
692 let bias_var = fold_metrics
693 .iter()
694 .map(|m| (m.bias - mean_metrics.bias).powi(2))
695 .filter(|x| !x.is_nan())
696 .sum::<f64>()
697 / (n - 1.0);
698
699 Ok(ImputationMetrics {
700 rmse: rmse_var.sqrt(),
701 mae: mae_var.sqrt(),
702 r2: r2_var.sqrt(),
703 accuracy: f64::NAN,
704 f1_score: f64::NAN,
705 bias: bias_var.sqrt(),
706 coverage: f64::NAN,
707 ks_statistic: f64::NAN,
708 ks_pvalue: f64::NAN,
709 })
710 }
711
712 fn compute_confidence_intervals(
713 &self,
714 fold_metrics: &[ImputationMetrics],
715 ) -> SklResult<HashMap<String, (f64, f64)>> {
716 let mut intervals = HashMap::new();
717
718 if fold_metrics.len() < 2 {
719 return Ok(intervals);
720 }
721
722 let _n = fold_metrics.len() as f64;
724 let t_critical = 2.0; let rmse_values: Vec<f64> = fold_metrics
728 .iter()
729 .map(|m| m.rmse)
730 .filter(|x| !x.is_nan())
731 .collect();
732 if !rmse_values.is_empty() {
733 let mean = rmse_values.iter().sum::<f64>() / rmse_values.len() as f64;
734 let std = (rmse_values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
735 / (rmse_values.len() - 1) as f64)
736 .sqrt();
737 let margin = t_critical * std / (rmse_values.len() as f64).sqrt();
738 intervals.insert("rmse".to_string(), (mean - margin, mean + margin));
739 }
740
741 let mae_values: Vec<f64> = fold_metrics
743 .iter()
744 .map(|m| m.mae)
745 .filter(|x| !x.is_nan())
746 .collect();
747 if !mae_values.is_empty() {
748 let mean = mae_values.iter().sum::<f64>() / mae_values.len() as f64;
749 let std = (mae_values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
750 / (mae_values.len() - 1) as f64)
751 .sqrt();
752 let margin = t_critical * std / (mae_values.len() as f64).sqrt();
753 intervals.insert("mae".to_string(), (mean - margin, mean + margin));
754 }
755
756 let r2_values: Vec<f64> = fold_metrics
758 .iter()
759 .map(|m| m.r2)
760 .filter(|x| !x.is_nan())
761 .collect();
762 if !r2_values.is_empty() {
763 let mean = r2_values.iter().sum::<f64>() / r2_values.len() as f64;
764 let std = (r2_values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
765 / (r2_values.len() - 1) as f64)
766 .sqrt();
767 let margin = t_critical * std / (r2_values.len() as f64).sqrt();
768 intervals.insert("r2".to_string(), (mean - margin, mean + margin));
769 }
770
771 Ok(intervals)
772 }
773}
774
775impl Default for ImputationCrossValidator {
776 fn default() -> Self {
777 Self::new()
778 }
779}
780
/// Two-sample Kolmogorov–Smirnov test.
///
/// Returns `(statistic, p_value)`: the statistic is the maximum absolute
/// difference between the two empirical CDFs, and the p-value uses the
/// asymptotic approximation `p ≈ 2·exp(-2·λ²)` with
/// `λ = D·sqrt(n1·n2/(n1+n2))`, clamped to at most 1.0.
///
/// Returns `(NaN, NaN)` when either sample is empty. Sorting uses
/// `f64::total_cmp`, so NaN inputs no longer panic (the original
/// `partial_cmp(..).expect(..)` aborted on NaN); NaNs sort last and never
/// satisfy `x <= value`, matching the previous counting semantics.
fn compute_ks_test(sample1: &[f64], sample2: &[f64]) -> (f64, f64) {
    if sample1.is_empty() || sample2.is_empty() {
        return (f64::NAN, f64::NAN);
    }

    // Sorted copies let each CDF evaluation be a binary search instead of
    // a full scan, turning the original O(n·m) loop into O((n+m)·log(n+m)).
    let mut sorted1 = sample1.to_vec();
    sorted1.sort_by(|a, b| a.total_cmp(b));
    let mut sorted2 = sample2.to_vec();
    sorted2.sort_by(|a, b| a.total_cmp(b));

    // Evaluate both CDFs at every distinct observed value.
    let mut all_values: Vec<f64> = sorted1.iter().chain(sorted2.iter()).copied().collect();
    all_values.sort_by(|a, b| a.total_cmp(b));
    all_values.dedup();

    let n1 = sorted1.len() as f64;
    let n2 = sorted2.len() as f64;
    let mut max_diff = 0.0;

    for &value in &all_values {
        // partition_point counts elements <= value in a sorted slice,
        // exactly the count the original filter-based scan produced.
        let cdf1 = sorted1.partition_point(|&x| x <= value) as f64 / n1;
        let cdf2 = sorted2.partition_point(|&x| x <= value) as f64 / n2;

        let diff = (cdf1 - cdf2).abs();
        if diff > max_diff {
            max_diff = diff;
        }
    }

    // Asymptotic two-sample p-value approximation.
    let n_eff = (n1 * n2) / (n1 + n2);
    let lambda = max_diff * n_eff.sqrt();
    let p_value = 2.0 * (-2.0 * lambda * lambda).exp();

    (max_diff, p_value.min(1.0))
}
816
817#[allow(dead_code)]
822fn validate_with_holdout_disabled<I>(
823 _imputer: &I,
824 _X: &ArrayView2<'_, Float>,
825 _test_size: f64,
826 _missing_pattern: MissingDataPattern,
827 _random_state: Option<u64>,
828) -> SklResult<ImputationMetrics>
829where
830 I: Clone,
831{
832 Err(SklearsError::NotImplemented(
833 "validate_with_holdout temporarily disabled due to HRTB compilation issues".to_string(),
834 ))
835}
836
837impl HoldOutValidator {
838 pub fn new(test_size: f64) -> Self {
840 Self {
841 test_size,
842 missing_pattern: MissingDataPattern::MCAR { missing_rate: 0.2 },
843 random_state: None,
844 stratify: false,
845 }
846 }
847
848 pub fn missing_pattern(mut self, pattern: MissingDataPattern) -> Self {
850 self.missing_pattern = pattern;
851 self
852 }
853
854 pub fn random_state(mut self, random_state: u64) -> Self {
856 self.random_state = Some(random_state);
857 self
858 }
859
860 #[allow(non_snake_case)]
865 #[allow(dead_code)]
866 fn validate_disabled<I>(
867 &self,
868 _imputer: &I,
869 _X: &ArrayView2<'_, Float>,
870 ) -> SklResult<ImputationMetrics>
871 where
872 I: Clone,
873 {
874 Err(SklearsError::NotImplemented(
924 "HoldOutValidator::validate temporarily disabled due to HRTB compilation issues"
925 .to_string(),
926 ))
927 }
928}