// sklears_impute/validation.rs
1//! Validation framework for imputation methods
2//!
3//! This module provides comprehensive validation tools for assessing the quality
4//! and reliability of imputation methods including cross-validation, hold-out validation,
5//! and synthetic missing data validation.
6
7use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
8use scirs2_core::random::{Rng, RngExt};
9use scirs2_core::SliceRandomExt;
10use sklears_core::{
11    error::{Result as SklResult, SklearsError},
12    traits::{Fit, Transform},
13    types::Float,
14};
15use std::collections::HashMap;
16
/// Cross-validation strategies for imputation evaluation
///
/// Consumed by `ImputationCrossValidator::generate_fold_indices` to split
/// samples into (train, test) index sets before synthetic missing data is
/// introduced into the test portion.
#[derive(Debug, Clone)]
pub enum CrossValidationStrategy {
    /// K-fold cross validation
    KFold { n_splits: usize, shuffle: bool },
    /// Stratified K-fold (for datasets with class labels)
    ///
    /// NOTE(review): not yet handled by `generate_fold_indices`; selecting it
    /// currently returns an "Unsupported CV strategy" error.
    StratifiedKFold { n_splits: usize, shuffle: bool },
    /// Leave-one-out cross validation
    LeaveOneOut,
    /// Time series split (for temporal data)
    TimeSeriesSplit {
        /// Number of (train, test) splits to generate
        n_splits: usize,
        /// Optional cap on the number of most recent training samples per fold
        max_train_size: Option<usize>,
    },
    /// Group-based cross validation
    ///
    /// NOTE(review): not yet handled by `generate_fold_indices`; selecting it
    /// currently returns an "Unsupported CV strategy" error.
    GroupKFold { n_splits: usize },
}
34
/// Missing data simulation patterns for synthetic validation
///
/// Used by `ImputationCrossValidator::introduce_missing_data` to hide known
/// values so that imputation quality can be measured against ground truth.
#[derive(Debug, Clone)]
pub enum MissingDataPattern {
    /// Missing Completely At Random
    MCAR { missing_rate: f64 },
    /// Missing At Random (depends on observed variables)
    MAR {
        missing_rate: f64,
        /// NOTE(review): currently ignored by `introduce_missing_data`
        /// (bound as `dependency_strength: _`); the simplified MAR pattern
        /// only conditions on the first feature.
        dependency_strength: f64,
    },
    /// Missing Not At Random (depends on unobserved values)
    MNAR {
        missing_rate: f64,
        /// Number of standard deviations above the column mean beyond which
        /// values become much more likely to go missing
        threshold_factor: f64,
    },
    /// Block missing pattern
    ///
    /// NOTE(review): not yet handled by `introduce_missing_data`; selecting
    /// it currently returns an "Unsupported missing pattern" error.
    Block {
        block_size: (usize, usize),
        n_blocks: usize,
    },
    /// Monotone missing pattern
    ///
    /// NOTE(review): not yet handled by `introduce_missing_data`; selecting
    /// it currently returns an "Unsupported missing pattern" error.
    Monotone { missing_rates: Vec<f64> },
}
58
/// Imputation validation metrics
///
/// Continuous-data metrics (`rmse`, `mae`, `r2`, `bias`) and categorical
/// metrics (`accuracy`, `f1_score`) share one struct; metrics that were not
/// computed for the evaluated data are reported as `f64::NAN` rather than
/// being omitted.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ImputationMetrics {
    /// Root Mean Squared Error for continuous variables
    pub rmse: f64,
    /// Mean Absolute Error for continuous variables
    pub mae: f64,
    /// R-squared correlation coefficient
    pub r2: f64,
    /// Accuracy for categorical variables (NaN for continuous-only evaluation)
    pub accuracy: f64,
    /// F1 score for categorical variables (NaN for continuous-only evaluation)
    pub f1_score: f64,
    /// Bias in imputed values (mean of `true - imputed` over held-out cells)
    pub bias: f64,
    /// Coverage of confidence intervals (if provided)
    pub coverage: f64,
    /// Kolmogorov-Smirnov test statistic for distribution similarity
    pub ks_statistic: f64,
    /// p-value for KS test
    pub ks_pvalue: f64,
}
81
/// Cross-validation results for imputation
///
/// Aggregated output of a cross-validated imputation evaluation: per-fold
/// metrics plus their mean, spread and approximate confidence intervals.
#[derive(Debug, Clone)]
pub struct CrossValidationResults {
    /// Metrics for each fold
    pub fold_metrics: Vec<ImputationMetrics>,
    /// Mean metrics across all folds
    pub mean_metrics: ImputationMetrics,
    /// Standard deviation of metrics across folds
    pub std_metrics: ImputationMetrics,
    /// Confidence intervals for metrics (95%), keyed by metric name
    /// (e.g. "rmse", "mae", "r2")
    pub confidence_intervals: HashMap<String, (f64, f64)>,
}
94
/// Imputation Cross-Validator
///
/// Performs cross-validation to assess imputation quality by artificially
/// creating missing data and evaluating how well the imputation method recovers
/// the true values.
///
/// # Parameters
///
/// * `cv_strategy` - Cross-validation strategy to use
/// * `missing_pattern` - Pattern for creating synthetic missing data
/// * `test_fraction` - Fraction of observed values to artificially make missing for testing
/// * `random_state` - Random state for reproducibility
/// * `n_jobs` - Number of parallel jobs (currently not implemented)
///
/// # Examples
///
/// The example below is marked `ignore` because the public `validate_imputer`
/// entry point is temporarily disabled (see the private
/// `validate_imputer_disabled`); it documents the intended usage once the
/// method is re-enabled.
///
/// ```ignore
/// use sklears_impute::{ImputationCrossValidator, CrossValidationStrategy, MissingDataPattern};
/// use sklears_impute::SimpleImputer;
/// use sklears_core::traits::{Transform};
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
///
/// let cv = ImputationCrossValidator::new()
///     .cv_strategy(CrossValidationStrategy::KFold { n_splits: 3, shuffle: true })
///     .missing_pattern(MissingDataPattern::MCAR { missing_rate: 0.2 })
///     .test_fraction(0.1);
///
/// let imputer = SimpleImputer::new().strategy("mean".to_string());
/// let results = cv.validate_imputer(&imputer, &X.view()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct ImputationCrossValidator {
    /// How samples are split into train/test folds
    cv_strategy: CrossValidationStrategy,
    /// Pattern used to synthesize missing entries in each test fold
    missing_pattern: MissingDataPattern,
    /// Scales the pattern's missing rate when hiding values for evaluation
    /// (see `introduce_missing_data`)
    test_fraction: f64,
    /// Seed for reproducibility
    /// NOTE(review): stored but not consumed by the (disabled) validation
    /// path, which uses `Random::default()` — confirm before relying on it.
    random_state: Option<u64>,
    /// Number of parallel jobs (currently not implemented)
    n_jobs: usize,
}
135
/// Hold-out validator for imputation methods
///
/// Validates imputation by holding out a portion of the dataset and evaluating
/// how well missing values in that portion are imputed.
#[derive(Debug, Clone)]
pub struct HoldOutValidator {
    /// Fraction of samples held out as the test split
    test_size: f64,
    /// Pattern used to synthesize missing entries in the held-out split
    missing_pattern: MissingDataPattern,
    /// Seed for reproducibility
    /// NOTE(review): stored but not used by the (disabled) validation path.
    random_state: Option<u64>,
    /// Whether to stratify the split
    /// NOTE(review): no builder exposes this and the (disabled) validation
    /// path does not read it — presumably reserved for future use.
    stratify: bool,
}
147
/// Synthetic missing data validator
///
/// Creates synthetic datasets with known missing patterns to test
/// imputation methods under controlled conditions.
///
/// NOTE(review): no `impl` for this type is visible in this module chunk —
/// construction and validation logic presumably live elsewhere; confirm.
#[derive(Debug, Clone)]
pub struct SyntheticMissingValidator {
    /// Generators producing the synthetic ground-truth datasets
    data_generators: Vec<DataGenerator>,
    /// Missing-data patterns to apply to each generated dataset
    missing_patterns: Vec<MissingDataPattern>,
    /// Number of datasets to generate per configuration
    n_datasets: usize,
    /// Dataset shapes to generate — presumably (n_samples, n_features) pairs;
    /// TODO confirm against the constructing code
    dataset_sizes: Vec<(usize, usize)>,
    /// Seed for reproducibility
    random_state: Option<u64>,
}
160
/// Data generator for synthetic validation
///
/// Describes how ground-truth data is synthesized before missing values are
/// introduced. NOTE(review): no generation code for these variants is visible
/// in this module chunk.
#[derive(Debug, Clone)]
pub enum DataGenerator {
    /// Multivariate normal data
    MultivariateNormal { mean: Array1<f64>, cov: Array2<f64> },
    /// Linear relationships with noise
    LinearRelationships {
        /// Coefficient matrix defining the linear feature relationships
        coefficients: Array2<f64>,
        /// Standard deviation of the additive noise
        noise_std: f64,
    },
    /// Non-linear relationships
    NonLinear {
        /// Name of the non-linear function family — stringly-typed; the
        /// accepted values are defined by the (unseen) generation code
        function_type: String,
        /// Standard deviation of the additive noise
        noise_std: f64,
    },
    /// Mixed-type data (continuous + categorical)
    MixedType {
        /// Proportion of continuous features
        continuous_props: f64,
        /// Number of categories for each categorical feature
        n_categories: Vec<usize>,
    },
}
182
/// Real-world case study validator
///
/// Validates imputation methods on real datasets with known complete cases
/// by artificially introducing missing data.
///
/// NOTE(review): no `impl` for this type is visible in this module chunk.
#[derive(Debug, Clone)]
pub struct CaseStudyValidator {
    /// Case studies to run
    case_studies: Vec<CaseStudy>,
    /// Names of the metrics to evaluate — stringly-typed; valid names are
    /// defined by the (unseen) evaluation code
    evaluation_metrics: Vec<String>,
    /// Names of imputation methods to compare against
    comparison_methods: Vec<String>,
}
193
/// Case study configuration
///
/// Bundles a named dataset description with the missing-data patterns and
/// evaluation criteria to apply to it.
#[derive(Debug, Clone)]
pub struct CaseStudy {
    /// Short identifier for the case study
    name: String,
    /// Human-readable description
    description: String,
    /// Statistical properties of the underlying dataset
    data_characteristics: DataCharacteristics,
    /// Missing-data patterns to simulate for this study
    missing_patterns: Vec<MissingDataPattern>,
    /// Names of the criteria used to judge imputation quality
    evaluation_criteria: Vec<String>,
}
203
/// Data characteristics for case studies
///
/// Describes the shape and statistical flavor of a case-study dataset.
#[derive(Debug, Clone)]
pub struct DataCharacteristics {
    /// Number of rows in the dataset
    n_samples: usize,
    /// Number of columns in the dataset
    n_features: usize,
    feature_types: Vec<String>, // "continuous", "categorical", "ordinal"
    correlation_structure: String, // "low", "medium", "high"
    /// Fraction of samples that are outliers
    outlier_fraction: f64,
    /// Relative magnitude of the noise in the data
    noise_level: f64,
}
214
215// ImputationCrossValidator implementation
216
217impl ImputationCrossValidator {
218    /// Create a new ImputationCrossValidator
219    pub fn new() -> Self {
220        Self {
221            cv_strategy: CrossValidationStrategy::KFold {
222                n_splits: 5,
223                shuffle: true,
224            },
225            missing_pattern: MissingDataPattern::MCAR { missing_rate: 0.2 },
226            test_fraction: 0.1,
227            random_state: None,
228            n_jobs: 1,
229        }
230    }
231
232    /// Set the cross-validation strategy
233    pub fn cv_strategy(mut self, strategy: CrossValidationStrategy) -> Self {
234        self.cv_strategy = strategy;
235        self
236    }
237
238    /// Set the missing data pattern for synthetic testing
239    pub fn missing_pattern(mut self, pattern: MissingDataPattern) -> Self {
240        self.missing_pattern = pattern;
241        self
242    }
243
244    /// Set the fraction of observed values to use for testing
245    pub fn test_fraction(mut self, fraction: f64) -> Self {
246        self.test_fraction = fraction;
247        self
248    }
249
250    /// Set the random state
251    pub fn random_state(mut self, random_state: u64) -> Self {
252        self.random_state = Some(random_state);
253        self
254    }
255
256    /// Set the number of parallel jobs
257    pub fn n_jobs(mut self, n_jobs: usize) -> Self {
258        self.n_jobs = n_jobs;
259        self
260    }
261
262    /// Validate an imputation method using cross-validation
263    ///
264    /// Note: Temporarily disabled due to HRTB compilation issues.
265    /// This will be re-enabled once the trait bound issues are resolved.
266    #[allow(non_snake_case)]
267    #[allow(dead_code)]
268    fn validate_imputer_disabled<'a, I, F>(
269        &self,
270        _imputer: &I,
271        _X: &ArrayView2<'_, Float>,
272    ) -> SklResult<CrossValidationResults>
273    where
274        I: Clone,
275        I: Fit<ArrayView2<'a, Float>, (), Fitted = F>,
276        F: Transform<ArrayView2<'a, Float>, Array2<Float>>,
277    {
278        /*
279        let X = X.mapv(|x| x);
280        let (n_samples, n_features) = X.dim();
281
282        if n_samples == 0 || n_features == 0 {
283            return Err(SklearsError::InvalidInput("Empty dataset".to_string()));
284        }
285
286        let mut rng = Random::default();
287
288        // Generate fold indices
289        let fold_indices = self.generate_fold_indices(n_samples, &mut rng)?;
290        let mut fold_metrics = Vec::new();
291
292        for (train_indices, test_indices) in fold_indices {
293            // Create training and test sets
294            let mut X_train = Array2::zeros((train_indices.len(), n_features));
295            let mut X_test = Array2::zeros((test_indices.len(), n_features));
296
297            for (i, &idx) in train_indices.iter().enumerate() {
298                X_train.row_mut(i).assign(&X.row(idx));
299            }
300
301            for (i, &idx) in test_indices.iter().enumerate() {
302                X_test.row_mut(i).assign(&X.row(idx));
303            }
304
305            // Introduce synthetic missing data in test set
306            let (X_test_with_missing, missing_mask) =
307                self.introduce_missing_data(&X_test, &mut rng)?;
308
309            // Convert to Float arrays upfront to avoid lifetime issues
310            let X_train_float = X_train.mapv(|x| x as Float);
311            let X_test_missing_float = X_test_with_missing.mapv(|x| x as Float);
312
313            // Train imputer on training data
314            let fitted_imputer = imputer.clone().fit(&X_train_float.view(), &())?;
315
316            // Impute test data
317            let X_test_imputed = fitted_imputer.transform(&X_test_missing_float.view())?;
318
319            // Compute metrics
320            let metrics =
321                self.compute_metrics(&X_test, &X_test_imputed.mapv(|x| x), &missing_mask)?;
322            fold_metrics.push(metrics);
323        }
324
325        // Aggregate results
326        let mean_metrics = self.compute_mean_metrics(&fold_metrics)?;
327        let std_metrics = self.compute_std_metrics(&fold_metrics, &mean_metrics)?;
328        let confidence_intervals = self.compute_confidence_intervals(&fold_metrics)?;
329
330        Ok(CrossValidationResults {
331            fold_metrics,
332            mean_metrics,
333            std_metrics,
334            confidence_intervals,
335        })
336        */
337        // Temporary placeholder until HRTB issues are resolved
338        Err(SklearsError::NotImplemented(
339            "validate_imputer temporarily disabled due to HRTB compilation issues".to_string(),
340        ))
341    }
342
343    fn generate_fold_indices(
344        &self,
345        n_samples: usize,
346        rng: &mut impl Rng,
347    ) -> SklResult<Vec<(Vec<usize>, Vec<usize>)>> {
348        let mut indices: Vec<usize> = (0..n_samples).collect();
349
350        match &self.cv_strategy {
351            CrossValidationStrategy::KFold { n_splits, shuffle } => {
352                if *shuffle {
353                    indices.shuffle(rng);
354                }
355
356                let fold_size = n_samples / n_splits;
357                let mut folds = Vec::new();
358
359                for i in 0..*n_splits {
360                    let start = i * fold_size;
361                    let end = if i == n_splits - 1 {
362                        n_samples
363                    } else {
364                        (i + 1) * fold_size
365                    };
366
367                    let test_indices: Vec<usize> = indices[start..end].to_vec();
368                    let train_indices: Vec<usize> = indices[..start]
369                        .iter()
370                        .chain(indices[end..].iter())
371                        .cloned()
372                        .collect();
373
374                    folds.push((train_indices, test_indices));
375                }
376
377                Ok(folds)
378            }
379
380            CrossValidationStrategy::LeaveOneOut => {
381                let mut folds = Vec::new();
382                for i in 0..n_samples {
383                    let test_indices = vec![i];
384                    let train_indices: Vec<usize> = (0..n_samples).filter(|&x| x != i).collect();
385                    folds.push((train_indices, test_indices));
386                }
387                Ok(folds)
388            }
389
390            CrossValidationStrategy::TimeSeriesSplit {
391                n_splits,
392                max_train_size,
393            } => {
394                let mut folds = Vec::new();
395                let test_size = n_samples / (n_splits + 1);
396
397                for i in 1..=*n_splits {
398                    let test_start = i * test_size;
399                    let test_end = (test_start + test_size).min(n_samples);
400                    let test_indices: Vec<usize> = (test_start..test_end).collect();
401
402                    let train_end = test_start;
403                    let train_start = if let Some(max_size) = max_train_size {
404                        train_end.saturating_sub(*max_size)
405                    } else {
406                        0
407                    };
408                    let train_indices: Vec<usize> = (train_start..train_end).collect();
409
410                    if !train_indices.is_empty() && !test_indices.is_empty() {
411                        folds.push((train_indices, test_indices));
412                    }
413                }
414
415                Ok(folds)
416            }
417
418            _ => Err(SklearsError::InvalidInput(
419                "Unsupported CV strategy".to_string(),
420            )),
421        }
422    }
423
424    fn introduce_missing_data(
425        &self,
426        X: &Array2<f64>,
427        rng: &mut impl Rng,
428    ) -> SklResult<(Array2<f64>, Array2<bool>)> {
429        let (n_samples, n_features) = X.dim();
430        let mut X_missing = X.clone();
431        let mut missing_mask = Array2::from_elem((n_samples, n_features), false);
432
433        match &self.missing_pattern {
434            MissingDataPattern::MCAR { missing_rate } => {
435                let n_missing = (n_samples * n_features) as f64 * missing_rate * self.test_fraction;
436                let n_missing = n_missing as usize;
437
438                let mut positions: Vec<(usize, usize)> = Vec::new();
439                for i in 0..n_samples {
440                    for j in 0..n_features {
441                        positions.push((i, j));
442                    }
443                }
444
445                positions.shuffle(rng);
446
447                for &(i, j) in positions.iter().take(n_missing) {
448                    X_missing[[i, j]] = f64::NAN;
449                    missing_mask[[i, j]] = true;
450                }
451            }
452
453            MissingDataPattern::MAR {
454                missing_rate,
455                dependency_strength: _,
456            } => {
457                // Simplified MAR pattern: missingness depends on first feature
458                if n_features > 1 {
459                    let threshold = X.column(0).mean().unwrap_or(0.0);
460
461                    for i in 0..n_samples {
462                        for j in 1..n_features {
463                            let prob = if X[[i, 0]] > threshold {
464                                missing_rate * 2.0 * self.test_fraction
465                            } else {
466                                missing_rate * 0.5 * self.test_fraction
467                            };
468
469                            if rng.random::<f64>() < prob {
470                                X_missing[[i, j]] = f64::NAN;
471                                missing_mask[[i, j]] = true;
472                            }
473                        }
474                    }
475                }
476            }
477
478            MissingDataPattern::MNAR {
479                missing_rate,
480                threshold_factor,
481            } => {
482                // MNAR: high values are more likely to be missing
483                for j in 0..n_features {
484                    let column = X.column(j);
485                    let mean = column.mean().unwrap_or(0.0);
486                    let std = column.var(0.0).sqrt();
487                    let threshold = mean + threshold_factor * std;
488
489                    for i in 0..n_samples {
490                        let prob = if X[[i, j]] > threshold {
491                            missing_rate * 3.0 * self.test_fraction
492                        } else {
493                            missing_rate * 0.3 * self.test_fraction
494                        };
495
496                        if rng.random::<f64>() < prob {
497                            X_missing[[i, j]] = f64::NAN;
498                            missing_mask[[i, j]] = true;
499                        }
500                    }
501                }
502            }
503
504            _ => {
505                return Err(SklearsError::InvalidInput(
506                    "Unsupported missing pattern".to_string(),
507                ));
508            }
509        }
510
511        Ok((X_missing, missing_mask))
512    }
513
514    fn compute_metrics(
515        &self,
516        X_true: &Array2<f64>,
517        X_imputed: &Array2<f64>,
518        missing_mask: &Array2<bool>,
519    ) -> SklResult<ImputationMetrics> {
520        let mut mse_sum = 0.0;
521        let mut mae_sum = 0.0;
522        let mut bias_sum = 0.0;
523        let mut count = 0;
524
525        let mut true_values = Vec::new();
526        let mut imputed_values = Vec::new();
527
528        for ((i, j), &is_missing) in missing_mask.indexed_iter() {
529            if is_missing {
530                let true_val = X_true[[i, j]];
531                let imputed_val = X_imputed[[i, j]];
532
533                if !true_val.is_nan() && !imputed_val.is_nan() {
534                    let error = true_val - imputed_val;
535                    mse_sum += error * error;
536                    mae_sum += error.abs();
537                    bias_sum += error;
538                    count += 1;
539
540                    true_values.push(true_val);
541                    imputed_values.push(imputed_val);
542                }
543            }
544        }
545
546        if count == 0 {
547            return Ok(ImputationMetrics {
548                rmse: f64::NAN,
549                mae: f64::NAN,
550                r2: f64::NAN,
551                accuracy: f64::NAN,
552                f1_score: f64::NAN,
553                bias: f64::NAN,
554                coverage: f64::NAN,
555                ks_statistic: f64::NAN,
556                ks_pvalue: f64::NAN,
557            });
558        }
559
560        let mse = mse_sum / count as f64;
561        let rmse = mse.sqrt();
562        let mae = mae_sum / count as f64;
563        let bias = bias_sum / count as f64;
564
565        // Compute R-squared
566        let true_mean = true_values.iter().sum::<f64>() / true_values.len() as f64;
567        let ss_tot: f64 = true_values.iter().map(|&x| (x - true_mean).powi(2)).sum();
568        let ss_res: f64 = true_values
569            .iter()
570            .zip(imputed_values.iter())
571            .map(|(&t, &p)| (t - p).powi(2))
572            .sum();
573
574        let r2 = if ss_tot > 0.0 {
575            1.0 - ss_res / ss_tot
576        } else {
577            f64::NAN
578        };
579
580        // Compute KS statistic (simplified)
581        let (ks_statistic, ks_pvalue) = compute_ks_test(&true_values, &imputed_values);
582
583        Ok(ImputationMetrics {
584            rmse,
585            mae,
586            r2,
587            accuracy: f64::NAN, // For continuous data
588            f1_score: f64::NAN, // For continuous data
589            bias,
590            coverage: f64::NAN, // Not computed in this basic version
591            ks_statistic,
592            ks_pvalue,
593        })
594    }
595
596    fn compute_mean_metrics(
597        &self,
598        fold_metrics: &[ImputationMetrics],
599    ) -> SklResult<ImputationMetrics> {
600        if fold_metrics.is_empty() {
601            return Err(SklearsError::InvalidInput(
602                "No fold metrics provided".to_string(),
603            ));
604        }
605
606        let n = fold_metrics.len() as f64;
607
608        let rmse = fold_metrics
609            .iter()
610            .map(|m| m.rmse)
611            .filter(|x| !x.is_nan())
612            .sum::<f64>()
613            / n;
614        let mae = fold_metrics
615            .iter()
616            .map(|m| m.mae)
617            .filter(|x| !x.is_nan())
618            .sum::<f64>()
619            / n;
620        let r2 = fold_metrics
621            .iter()
622            .map(|m| m.r2)
623            .filter(|x| !x.is_nan())
624            .sum::<f64>()
625            / n;
626        let bias = fold_metrics
627            .iter()
628            .map(|m| m.bias)
629            .filter(|x| !x.is_nan())
630            .sum::<f64>()
631            / n;
632        let ks_statistic = fold_metrics
633            .iter()
634            .map(|m| m.ks_statistic)
635            .filter(|x| !x.is_nan())
636            .sum::<f64>()
637            / n;
638        let ks_pvalue = fold_metrics
639            .iter()
640            .map(|m| m.ks_pvalue)
641            .filter(|x| !x.is_nan())
642            .sum::<f64>()
643            / n;
644
645        Ok(ImputationMetrics {
646            rmse,
647            mae,
648            r2,
649            accuracy: f64::NAN,
650            f1_score: f64::NAN,
651            bias,
652            coverage: f64::NAN,
653            ks_statistic,
654            ks_pvalue,
655        })
656    }
657
658    fn compute_std_metrics(
659        &self,
660        fold_metrics: &[ImputationMetrics],
661        mean_metrics: &ImputationMetrics,
662    ) -> SklResult<ImputationMetrics> {
663        if fold_metrics.is_empty() {
664            return Err(SklearsError::InvalidInput(
665                "No fold metrics provided".to_string(),
666            ));
667        }
668
669        let n = fold_metrics.len() as f64;
670
671        let rmse_var = fold_metrics
672            .iter()
673            .map(|m| (m.rmse - mean_metrics.rmse).powi(2))
674            .filter(|x| !x.is_nan())
675            .sum::<f64>()
676            / (n - 1.0);
677
678        let mae_var = fold_metrics
679            .iter()
680            .map(|m| (m.mae - mean_metrics.mae).powi(2))
681            .filter(|x| !x.is_nan())
682            .sum::<f64>()
683            / (n - 1.0);
684
685        let r2_var = fold_metrics
686            .iter()
687            .map(|m| (m.r2 - mean_metrics.r2).powi(2))
688            .filter(|x| !x.is_nan())
689            .sum::<f64>()
690            / (n - 1.0);
691
692        let bias_var = fold_metrics
693            .iter()
694            .map(|m| (m.bias - mean_metrics.bias).powi(2))
695            .filter(|x| !x.is_nan())
696            .sum::<f64>()
697            / (n - 1.0);
698
699        Ok(ImputationMetrics {
700            rmse: rmse_var.sqrt(),
701            mae: mae_var.sqrt(),
702            r2: r2_var.sqrt(),
703            accuracy: f64::NAN,
704            f1_score: f64::NAN,
705            bias: bias_var.sqrt(),
706            coverage: f64::NAN,
707            ks_statistic: f64::NAN,
708            ks_pvalue: f64::NAN,
709        })
710    }
711
712    fn compute_confidence_intervals(
713        &self,
714        fold_metrics: &[ImputationMetrics],
715    ) -> SklResult<HashMap<String, (f64, f64)>> {
716        let mut intervals = HashMap::new();
717
718        if fold_metrics.len() < 2 {
719            return Ok(intervals);
720        }
721
722        // Simple 95% confidence intervals using t-distribution approximation
723        let _n = fold_metrics.len() as f64;
724        let t_critical = 2.0; // Approximation for 95% CI
725
726        // RMSE
727        let rmse_values: Vec<f64> = fold_metrics
728            .iter()
729            .map(|m| m.rmse)
730            .filter(|x| !x.is_nan())
731            .collect();
732        if !rmse_values.is_empty() {
733            let mean = rmse_values.iter().sum::<f64>() / rmse_values.len() as f64;
734            let std = (rmse_values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
735                / (rmse_values.len() - 1) as f64)
736                .sqrt();
737            let margin = t_critical * std / (rmse_values.len() as f64).sqrt();
738            intervals.insert("rmse".to_string(), (mean - margin, mean + margin));
739        }
740
741        // MAE
742        let mae_values: Vec<f64> = fold_metrics
743            .iter()
744            .map(|m| m.mae)
745            .filter(|x| !x.is_nan())
746            .collect();
747        if !mae_values.is_empty() {
748            let mean = mae_values.iter().sum::<f64>() / mae_values.len() as f64;
749            let std = (mae_values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
750                / (mae_values.len() - 1) as f64)
751                .sqrt();
752            let margin = t_critical * std / (mae_values.len() as f64).sqrt();
753            intervals.insert("mae".to_string(), (mean - margin, mean + margin));
754        }
755
756        // R2
757        let r2_values: Vec<f64> = fold_metrics
758            .iter()
759            .map(|m| m.r2)
760            .filter(|x| !x.is_nan())
761            .collect();
762        if !r2_values.is_empty() {
763            let mean = r2_values.iter().sum::<f64>() / r2_values.len() as f64;
764            let std = (r2_values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
765                / (r2_values.len() - 1) as f64)
766                .sqrt();
767            let margin = t_critical * std / (r2_values.len() as f64).sqrt();
768            intervals.insert("r2".to_string(), (mean - margin, mean + margin));
769        }
770
771        Ok(intervals)
772    }
773}
774
775impl Default for ImputationCrossValidator {
776    fn default() -> Self {
777        Self::new()
778    }
779}
780
781// Helper functions
782
/// Two-sample Kolmogorov-Smirnov test (simplified).
///
/// Returns `(statistic, p_value)`: the statistic is the maximum absolute
/// difference between the two empirical CDFs, and the p-value uses the
/// first-term asymptotic approximation `2 * exp(-2 * lambda^2)` with
/// `lambda = statistic * sqrt(n1 * n2 / (n1 + n2))`, clamped to at most 1.
/// Returns `(NaN, NaN)` when either sample is empty.
fn compute_ks_test(sample1: &[f64], sample2: &[f64]) -> (f64, f64) {
    if sample1.is_empty() || sample2.is_empty() {
        return (f64::NAN, f64::NAN);
    }

    // Evaluate both empirical CDFs at every distinct observed value; the KS
    // statistic is the largest gap between them.
    let mut all_values: Vec<f64> = sample1.iter().chain(sample2.iter()).copied().collect();
    // Bug fix: total_cmp gives a total order over f64, so a stray NaN in the
    // input no longer panics the sort (the previous partial_cmp + expect did).
    all_values.sort_by(|a, b| a.total_cmp(b));
    all_values.dedup();

    let mut max_diff = 0.0_f64;
    for &value in &all_values {
        let cdf1 = sample1.iter().filter(|&&x| x <= value).count() as f64 / sample1.len() as f64;
        let cdf2 = sample2.iter().filter(|&&x| x <= value).count() as f64 / sample2.len() as f64;
        max_diff = max_diff.max((cdf1 - cdf2).abs());
    }

    // Approximate p-value (first term of the asymptotic KS series).
    let n1 = sample1.len() as f64;
    let n2 = sample2.len() as f64;
    let n_eff = (n1 * n2) / (n1 + n2);
    let lambda = max_diff * n_eff.sqrt();
    let p_value = 2.0 * (-2.0 * lambda * lambda).exp();

    (max_diff, p_value.min(1.0))
}
816
/// Validate imputation method with simple hold-out strategy
///
/// Intended behavior: hold out `test_size` of the samples, hide values in the
/// held-out split according to `missing_pattern`, and score the imputer's
/// recovery of them.
///
/// Note: Temporarily disabled due to HRTB compilation issues.
/// This will be re-enabled once the trait bound issues are resolved.
/// Until then this stub always returns `SklearsError::NotImplemented` and
/// ignores every argument.
#[allow(dead_code)]
fn validate_with_holdout_disabled<I>(
    _imputer: &I,
    _X: &ArrayView2<'_, Float>,
    _test_size: f64,
    _missing_pattern: MissingDataPattern,
    _random_state: Option<u64>,
) -> SklResult<ImputationMetrics>
where
    I: Clone,
{
    Err(SklearsError::NotImplemented(
        "validate_with_holdout temporarily disabled due to HRTB compilation issues".to_string(),
    ))
}
836
impl HoldOutValidator {
    /// Create a new HoldOutValidator
    ///
    /// `test_size` is the fraction of samples to hold out. Defaults: MCAR
    /// pattern at rate 0.2, no fixed random state, no stratification.
    pub fn new(test_size: f64) -> Self {
        Self {
            test_size,
            missing_pattern: MissingDataPattern::MCAR { missing_rate: 0.2 },
            random_state: None,
            stratify: false,
        }
    }

    /// Set the missing data pattern (builder style)
    pub fn missing_pattern(mut self, pattern: MissingDataPattern) -> Self {
        self.missing_pattern = pattern;
        self
    }

    /// Set the random state (builder style)
    ///
    /// NOTE(review): the seed is stored but not consumed by the (disabled)
    /// validation path below, which uses `Random::default()`.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Validate an imputation method
    ///
    /// Note: Temporarily disabled due to HRTB compilation issues.
    /// This will be re-enabled once the trait bound issues are resolved.
    /// The commented-out body below documents the intended implementation;
    /// until then this stub always returns `SklearsError::NotImplemented`.
    #[allow(non_snake_case)]
    #[allow(dead_code)]
    fn validate_disabled<I>(
        &self,
        _imputer: &I,
        _X: &ArrayView2<'_, Float>,
    ) -> SklResult<ImputationMetrics>
    where
        I: Clone,
    {
        /*
        let X = X.mapv(|x| x);
        let (n_samples, n_features) = X.dim();

        if n_samples == 0 || n_features == 0 {
            return Err(SklearsError::InvalidInput("Empty dataset".to_string()));
        }

        let mut rng = Random::default();

        // Split data
        let test_size = (n_samples as f64 * self.test_size) as usize;
        let mut indices: Vec<usize> = (0..n_samples).collect();
        indices.shuffle(&mut rng);

        let test_indices = &indices[..test_size];
        let train_indices = &indices[test_size..];

        // Create train and test sets
        let mut X_train = Array2::zeros((train_indices.len(), n_features));
        let mut X_test = Array2::zeros((test_indices.len(), n_features));

        for (i, &idx) in train_indices.iter().enumerate() {
            X_train.row_mut(i).assign(&X.row(idx));
        }

        for (i, &idx) in test_indices.iter().enumerate() {
            X_test.row_mut(i).assign(&X.row(idx));
        }

        // Introduce missing data in test set
        let cv = ImputationCrossValidator::new()
            .missing_pattern(self.missing_pattern.clone())
            .test_fraction(1.0); // Use all test data for validation

        let (X_test_with_missing, missing_mask) = cv.introduce_missing_data(&X_test, &mut rng)?;

        // Convert to Float arrays upfront to avoid lifetime issues
        let X_train_float = X_train.mapv(|x| x as Float);
        let X_test_missing_float = X_test_with_missing.mapv(|x| x as Float);

        // Train and apply imputer
        let fitted_imputer = imputer.clone().fit(&X_train_float.view(), &())?;
        let X_test_imputed = fitted_imputer.transform(&X_test_missing_float.view())?;

        // Compute metrics
        cv.compute_metrics(&X_test, &X_test_imputed.mapv(|x| x), &missing_mask)
        */
        // Temporary placeholder until HRTB issues are resolved
        Err(SklearsError::NotImplemented(
            "HoldOutValidator::validate temporarily disabled due to HRTB compilation issues"
                .to_string(),
        ))
    }
}