1use scirs2_core::ndarray::{Array2, ArrayView2};
8use sklears_core::{error::Result as SklResult, error::SklearsError, types::Float};
9use std::collections::HashMap;
10
/// Result of a chi-square test of independence between a feature's
/// missingness indicator and another, discretized feature.
#[derive(Debug, Clone)]
pub struct ChiSquareTestResult {
    /// The chi-square test statistic.
    pub chi_square_statistic: f64,
    /// Degrees of freedom: (rows - 1) * (cols - 1) of the contingency table.
    pub degrees_of_freedom: usize,
    /// Approximate p-value (coarse bucket, not an exact CDF evaluation).
    pub p_value: f64,
    /// Tabulated critical value at the 0.05 significance level.
    pub critical_value: f64,
    /// True when the statistic exceeds the critical value.
    pub reject_independence: bool,
    /// Expected cell counts under the independence hypothesis.
    pub expected_frequencies: Array2<f64>,
    /// Observed cell counts (the contingency table itself).
    pub observed_frequencies: Array2<f64>,
}
29
/// Result of the (approximate) Fisher exact test on a 2x2 table of
/// missingness vs. a binarized companion feature.
#[derive(Debug, Clone)]
pub struct FisherExactTestResult {
    /// Two-sided p-value (coarse approximation).
    pub p_value: f64,
    /// One-sided p-value for the "less" alternative.
    pub p_value_less: f64,
    /// One-sided p-value for the "greater" alternative.
    pub p_value_greater: f64,
    /// Sample odds ratio (a*d)/(b*c); infinite when b*c == 0.
    pub odds_ratio: f64,
    /// 95% confidence interval (lower, upper) for the odds ratio.
    pub confidence_interval: (f64, f64),
    /// True when the p-value is below 0.05.
    pub reject_independence: bool,
}
46
/// Result of a Cramér's V association test between a feature's missingness
/// and another, discretized feature.
#[derive(Debug, Clone)]
pub struct CramersVTestResult {
    /// Cramér's V statistic in [0, 1].
    pub cramers_v: f64,
    /// The underlying chi-square statistic.
    pub chi_square_statistic: f64,
    /// Number of observation pairs used.
    pub n: usize,
    /// min(rows - 1, cols - 1) of the contingency table.
    pub min_dimension: usize,
    /// Human-readable strength label (negligible / weak / moderate / ...).
    pub association_strength: String,
}
61
/// Result of a two-sample Kolmogorov-Smirnov test comparing a companion
/// feature's distribution between missing and observed groups.
#[derive(Debug, Clone)]
pub struct KolmogorovSmirnovTestResult {
    /// Maximum absolute difference between the two empirical CDFs.
    pub ks_statistic: f64,
    /// Approximate p-value (coarse bucket).
    pub p_value: f64,
    /// Critical value at approximately the 0.05 level.
    pub critical_value: f64,
    /// True when the statistic exceeds the critical value.
    pub reject_same_distribution: bool,
    /// Sizes of the (missing, observed) groups.
    pub sample_sizes: (usize, usize),
}
76
/// Aggregated output of `run_independence_test_suite`.
#[derive(Debug, Clone)]
pub struct IndependenceTestSuite {
    /// Per-feature results, one entry per feature containing missing values.
    pub feature_results: Vec<FeatureIndependenceResult>,
    /// Overall summary across all tested features.
    pub summary: IndependenceTestSummary,
}
85
/// Independence-test results for a single feature's missingness.
#[derive(Debug, Clone)]
pub struct FeatureIndependenceResult {
    /// Column index of the tested feature.
    pub feature_index: usize,
    /// Optional human-readable feature name.
    pub feature_name: Option<String>,
    /// Chi-square result, when that test was selected.
    pub chi_square_test: Option<ChiSquareTestResult>,
    /// Fisher exact result, when that test was selected.
    pub fisher_exact_test: Option<FisherExactTestResult>,
    /// Cramér's V effect size, computed alongside the chi-square test.
    pub cramers_v_test: Option<CramersVTestResult>,
    /// Kolmogorov-Smirnov result, when that test was selected.
    pub ks_test: Option<KolmogorovSmirnovTestResult>,
    /// Which test was chosen and why.
    pub test_recommendation: String,
}
104
/// Summary statistics over all features tested for missingness independence.
#[derive(Debug, Clone)]
pub struct IndependenceTestSummary {
    /// Number of features that contained missing values and were tested.
    pub features_tested: usize,
    /// Number of those showing a statistically significant dependence.
    pub features_with_dependence: usize,
    /// features_with_dependence / features_tested (0.0 when none tested).
    pub dependence_rate: f64,
    /// Verbal verdict on the likely missing-data mechanism (MCAR/MAR/MNAR).
    pub mechanism_assessment: String,
    /// Actionable suggestions derived from the dependence rate.
    pub recommendations: Vec<String>,
}
119
120#[allow(non_snake_case)]
136pub fn chi_square_independence_test(
137 X: &ArrayView2<'_, Float>,
138 feature_idx: usize,
139 other_feature_idx: usize,
140 missing_values: f64,
141 bins: Option<usize>,
142) -> SklResult<ChiSquareTestResult> {
143 let X = X.mapv(|x| x);
144 let (n_samples, n_features) = X.dim();
145
146 if feature_idx >= n_features || other_feature_idx >= n_features {
147 return Err(SklearsError::InvalidInput(format!(
148 "Feature indices {} and {} must be less than number of features {}",
149 feature_idx, other_feature_idx, n_features
150 )));
151 }
152
153 if feature_idx == other_feature_idx {
154 return Err(SklearsError::InvalidInput(
155 "Feature indices must be different".to_string(),
156 ));
157 }
158
159 let mut missing_indicator = Vec::new();
161 let mut other_values = Vec::new();
162
163 for i in 0..n_samples {
164 let is_missing = if missing_values.is_nan() {
165 X[[i, feature_idx]].is_nan()
166 } else {
167 (X[[i, feature_idx]] - missing_values).abs() < f64::EPSILON
168 };
169
170 let other_is_missing = if missing_values.is_nan() {
172 X[[i, other_feature_idx]].is_nan()
173 } else {
174 (X[[i, other_feature_idx]] - missing_values).abs() < f64::EPSILON
175 };
176
177 if !other_is_missing {
178 missing_indicator.push(if is_missing { 1 } else { 0 });
179 other_values.push(X[[i, other_feature_idx]]);
180 }
181 }
182
183 if missing_indicator.is_empty() {
184 return Err(SklearsError::InvalidInput(
185 "No valid observations for comparison".to_string(),
186 ));
187 }
188
189 let n_bins = bins.unwrap_or(5);
191 let discretized_values = discretize_values(&other_values, n_bins)?;
192
193 let contingency_table = create_contingency_table(&missing_indicator, &discretized_values)?;
195
196 let chi_square_result = compute_chi_square_test(&contingency_table)?;
198
199 Ok(chi_square_result)
200}
201
202#[allow(non_snake_case)]
218pub fn fisher_exact_independence_test(
219 X: &ArrayView2<'_, Float>,
220 feature_idx: usize,
221 other_feature_idx: usize,
222 missing_values: f64,
223 threshold: Option<f64>,
224) -> SklResult<FisherExactTestResult> {
225 let X = X.mapv(|x| x);
226 let (n_samples, n_features) = X.dim();
227
228 if feature_idx >= n_features || other_feature_idx >= n_features {
229 return Err(SklearsError::InvalidInput(format!(
230 "Feature indices {} and {} must be less than number of features {}",
231 feature_idx, other_feature_idx, n_features
232 )));
233 }
234
235 let mut missing_indicator = Vec::new();
237 let mut other_values = Vec::new();
238
239 for i in 0..n_samples {
240 let is_missing = if missing_values.is_nan() {
241 X[[i, feature_idx]].is_nan()
242 } else {
243 (X[[i, feature_idx]] - missing_values).abs() < f64::EPSILON
244 };
245
246 let other_is_missing = if missing_values.is_nan() {
247 X[[i, other_feature_idx]].is_nan()
248 } else {
249 (X[[i, other_feature_idx]] - missing_values).abs() < f64::EPSILON
250 };
251
252 if !other_is_missing {
253 missing_indicator.push(if is_missing { 1 } else { 0 });
254 other_values.push(X[[i, other_feature_idx]]);
255 }
256 }
257
258 if missing_indicator.is_empty() {
259 return Err(SklearsError::InvalidInput(
260 "No valid observations for comparison".to_string(),
261 ));
262 }
263
264 let threshold =
266 threshold.unwrap_or_else(|| other_values.iter().sum::<f64>() / other_values.len() as f64);
267
268 let binary_values: Vec<usize> = other_values
269 .iter()
270 .map(|&x| if x > threshold { 1 } else { 0 })
271 .collect();
272
273 let mut table = [[0; 2]; 2];
275 for (missing, binary) in missing_indicator.iter().zip(binary_values.iter()) {
276 table[*missing][*binary] += 1;
277 }
278
279 let fisher_result = compute_fisher_exact_test(&table)?;
281
282 Ok(fisher_result)
283}
284
285#[allow(non_snake_case)]
302pub fn cramers_v_association_test(
303 X: &ArrayView2<'_, Float>,
304 feature_idx: usize,
305 other_feature_idx: usize,
306 missing_values: f64,
307 bins: Option<usize>,
308) -> SklResult<CramersVTestResult> {
309 let X = X.mapv(|x| x);
310 let (n_samples, n_features) = X.dim();
311
312 if feature_idx >= n_features || other_feature_idx >= n_features {
313 return Err(SklearsError::InvalidInput(format!(
314 "Feature indices {} and {} must be less than number of features {}",
315 feature_idx, other_feature_idx, n_features
316 )));
317 }
318
319 let mut missing_indicator = Vec::new();
321 let mut other_values = Vec::new();
322
323 for i in 0..n_samples {
324 let is_missing = if missing_values.is_nan() {
325 X[[i, feature_idx]].is_nan()
326 } else {
327 (X[[i, feature_idx]] - missing_values).abs() < f64::EPSILON
328 };
329
330 let other_is_missing = if missing_values.is_nan() {
331 X[[i, other_feature_idx]].is_nan()
332 } else {
333 (X[[i, other_feature_idx]] - missing_values).abs() < f64::EPSILON
334 };
335
336 if !other_is_missing {
337 missing_indicator.push(if is_missing { 1 } else { 0 });
338 other_values.push(X[[i, other_feature_idx]]);
339 }
340 }
341
342 if missing_indicator.is_empty() {
343 return Err(SklearsError::InvalidInput(
344 "No valid observations for comparison".to_string(),
345 ));
346 }
347
348 let n = missing_indicator.len();
349
350 let n_bins = bins.unwrap_or(5);
352 let discretized_values = discretize_values(&other_values, n_bins)?;
353
354 let contingency_table = create_contingency_table(&missing_indicator, &discretized_values)?;
356
357 let chi_square_statistic = compute_chi_square_statistic(&contingency_table)?;
359
360 let min_dimension = (contingency_table.nrows() - 1).min(contingency_table.ncols() - 1);
362 let cramers_v = if min_dimension > 0 {
363 (chi_square_statistic / (n as f64 * min_dimension as f64)).sqrt()
364 } else {
365 0.0
366 };
367
368 let association_strength = match cramers_v {
370 v if v < 0.1 => "Negligible association".to_string(),
371 v if v < 0.3 => "Weak association".to_string(),
372 v if v < 0.5 => "Moderate association".to_string(),
373 v if v < 0.7 => "Strong association".to_string(),
374 _ => "Very strong association".to_string(),
375 };
376
377 Ok(CramersVTestResult {
378 cramers_v,
379 chi_square_statistic,
380 n,
381 min_dimension,
382 association_strength,
383 })
384}
385
386#[allow(non_snake_case)]
402pub fn kolmogorov_smirnov_independence_test(
403 X: &ArrayView2<'_, Float>,
404 feature_idx: usize,
405 other_feature_idx: usize,
406 missing_values: f64,
407) -> SklResult<KolmogorovSmirnovTestResult> {
408 let X = X.mapv(|x| x);
409 let (n_samples, n_features) = X.dim();
410
411 if feature_idx >= n_features || other_feature_idx >= n_features {
412 return Err(SklearsError::InvalidInput(format!(
413 "Feature indices {} and {} must be less than number of features {}",
414 feature_idx, other_feature_idx, n_features
415 )));
416 }
417
418 let mut missing_group = Vec::new();
420 let mut observed_group = Vec::new();
421
422 for i in 0..n_samples {
423 let is_missing = if missing_values.is_nan() {
424 X[[i, feature_idx]].is_nan()
425 } else {
426 (X[[i, feature_idx]] - missing_values).abs() < f64::EPSILON
427 };
428
429 let other_is_missing = if missing_values.is_nan() {
430 X[[i, other_feature_idx]].is_nan()
431 } else {
432 (X[[i, other_feature_idx]] - missing_values).abs() < f64::EPSILON
433 };
434
435 if !other_is_missing {
436 if is_missing {
437 missing_group.push(X[[i, other_feature_idx]]);
438 } else {
439 observed_group.push(X[[i, other_feature_idx]]);
440 }
441 }
442 }
443
444 if missing_group.is_empty() || observed_group.is_empty() {
445 return Err(SklearsError::InvalidInput(
446 "Need observations in both missing and observed groups".to_string(),
447 ));
448 }
449
450 missing_group.sort_by(|a, b| a.partial_cmp(b).unwrap());
452 observed_group.sort_by(|a, b| a.partial_cmp(b).unwrap());
453
454 let ks_statistic = compute_ks_statistic(&missing_group, &observed_group);
456
457 let n1 = missing_group.len();
459 let n2 = observed_group.len();
460 let critical_value = 1.36 * ((n1 + n2) as f64 / (n1 * n2) as f64).sqrt();
461
462 let p_value = compute_ks_p_value(ks_statistic, n1, n2);
464
465 let reject_same_distribution = ks_statistic > critical_value;
466
467 Ok(KolmogorovSmirnovTestResult {
468 ks_statistic,
469 p_value,
470 critical_value,
471 reject_same_distribution,
472 sample_sizes: (n1, n2),
473 })
474}
475
476#[allow(non_snake_case)]
492pub fn run_independence_test_suite(
493 X: &ArrayView2<'_, Float>,
494 missing_values: f64,
495 feature_names: Option<Vec<String>>,
496 alpha: Option<f64>,
497) -> SklResult<IndependenceTestSuite> {
498 let X = X.mapv(|x| x);
499 let (n_samples, n_features) = X.dim();
500 let alpha = alpha.unwrap_or(0.05);
501
502 let feature_names = feature_names
503 .unwrap_or_else(|| (0..n_features).map(|i| format!("Feature_{}", i)).collect());
504
505 if feature_names.len() != n_features {
506 return Err(SklearsError::InvalidInput(format!(
507 "Number of feature names {} does not match number of features {}",
508 feature_names.len(),
509 n_features
510 )));
511 }
512
513 let mut feature_results = Vec::new();
514 let mut features_with_dependence = 0;
515
516 let mut features_with_missing = Vec::new();
518 for j in 0..n_features {
519 let column = X.column(j);
520 let has_missing = column.iter().any(|&x| {
521 if missing_values.is_nan() {
522 x.is_nan()
523 } else {
524 (x - missing_values).abs() < f64::EPSILON
525 }
526 });
527
528 if has_missing {
529 features_with_missing.push(j);
530 }
531 }
532
533 for &feature_idx in &features_with_missing {
535 let mut has_significant_dependence = false;
536 let mut chi_square_test = None;
537 let mut fisher_exact_test = None;
538 let mut cramers_v_test = None;
539 let mut ks_test = None;
540
541 for other_feature_idx in 0..n_features {
543 if feature_idx == other_feature_idx {
544 continue;
545 }
546
547 let other_column = X.column(other_feature_idx);
549 let other_unique_values: std::collections::HashSet<_> = other_column
550 .iter()
551 .filter(|&&x| {
552 if missing_values.is_nan() {
553 !x.is_nan()
554 } else {
555 (x - missing_values).abs() >= f64::EPSILON
556 }
557 })
558 .map(|&x| x.to_bits())
559 .collect();
560
561 let is_likely_categorical = other_unique_values.len() <= 10;
562
563 let mut valid_pairs = 0;
565 for i in 0..n_samples {
566 let other_is_missing = if missing_values.is_nan() {
567 X[[i, other_feature_idx]].is_nan()
568 } else {
569 (X[[i, other_feature_idx]] - missing_values).abs() < f64::EPSILON
570 };
571
572 if !other_is_missing {
573 valid_pairs += 1;
574 }
575 }
576
577 if valid_pairs < 10 {
578 continue; }
580
581 if is_likely_categorical && other_unique_values.len() == 2 && valid_pairs < 50 {
583 if let Ok(result) = fisher_exact_independence_test(
585 &X.view(),
586 feature_idx,
587 other_feature_idx,
588 missing_values,
589 None,
590 ) {
591 if result.p_value < alpha {
592 has_significant_dependence = true;
593 }
594 fisher_exact_test = Some(result);
595 break; }
597 } else if is_likely_categorical {
598 if let Ok(result) = chi_square_independence_test(
600 &X.view(),
601 feature_idx,
602 other_feature_idx,
603 missing_values,
604 None,
605 ) {
606 if result.p_value < alpha {
607 has_significant_dependence = true;
608 }
609 chi_square_test = Some(result);
610
611 if let Ok(cv_result) = cramers_v_association_test(
613 &X.view(),
614 feature_idx,
615 other_feature_idx,
616 missing_values,
617 None,
618 ) {
619 cramers_v_test = Some(cv_result);
620 }
621 break; }
623 } else {
624 if let Ok(result) = kolmogorov_smirnov_independence_test(
626 &X.view(),
627 feature_idx,
628 other_feature_idx,
629 missing_values,
630 ) {
631 if result.p_value < alpha {
632 has_significant_dependence = true;
633 }
634 ks_test = Some(result);
635 break; }
637 }
638 }
639
640 if has_significant_dependence {
641 features_with_dependence += 1;
642 }
643
644 let test_recommendation = if chi_square_test.is_some() {
645 "Chi-square test used (categorical data)".to_string()
646 } else if fisher_exact_test.is_some() {
647 "Fisher's exact test used (small 2x2 table)".to_string()
648 } else if ks_test.is_some() {
649 "Kolmogorov-Smirnov test used (continuous data)".to_string()
650 } else {
651 "No suitable test could be performed".to_string()
652 };
653
654 feature_results.push(FeatureIndependenceResult {
655 feature_index: feature_idx,
656 feature_name: Some(feature_names[feature_idx].clone()),
657 chi_square_test,
658 fisher_exact_test,
659 cramers_v_test,
660 ks_test,
661 test_recommendation,
662 });
663 }
664
665 let features_tested = features_with_missing.len();
666 let dependence_rate = if features_tested > 0 {
667 features_with_dependence as f64 / features_tested as f64
668 } else {
669 0.0
670 };
671
672 let mechanism_assessment = if dependence_rate == 0.0 {
674 "Evidence supports MCAR (Missing Completely At Random)".to_string()
675 } else if dependence_rate < 0.3 {
676 "Evidence suggests mostly MCAR with some MAR (Missing At Random)".to_string()
677 } else if dependence_rate < 0.7 {
678 "Evidence suggests MAR (Missing At Random)".to_string()
679 } else {
680 "Evidence suggests MNAR (Missing Not At Random) - consider domain knowledge".to_string()
681 };
682
683 let mut recommendations = Vec::new();
684 if dependence_rate > 0.5 {
685 recommendations.push(
686 "Consider using advanced imputation methods that account for dependencies".to_string(),
687 );
688 recommendations.push("Review domain knowledge to assess MNAR mechanisms".to_string());
689 }
690 if features_tested > 0 && dependence_rate < 0.2 {
691 recommendations.push("Simple imputation methods may be adequate".to_string());
692 }
693 if features_tested == 0 {
694 recommendations.push("No missing data found - no imputation needed".to_string());
695 }
696
697 let summary = IndependenceTestSummary {
698 features_tested,
699 features_with_dependence,
700 dependence_rate,
701 mechanism_assessment,
702 recommendations,
703 };
704
705 Ok(IndependenceTestSuite {
706 feature_results,
707 summary,
708 })
709}
710
711fn discretize_values(values: &[f64], n_bins: usize) -> SklResult<Vec<usize>> {
714 if values.is_empty() {
715 return Ok(Vec::new());
716 }
717
718 let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
719 let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
720
721 if (max_val - min_val).abs() < f64::EPSILON {
722 return Ok(vec![0; values.len()]);
724 }
725
726 let bin_width = (max_val - min_val) / n_bins as f64;
727
728 let discretized: Vec<usize> = values
729 .iter()
730 .map(|&x| {
731 let bin = ((x - min_val) / bin_width).floor() as usize;
732 bin.min(n_bins - 1)
733 })
734 .collect();
735
736 Ok(discretized)
737}
738
739fn create_contingency_table(group1: &[usize], group2: &[usize]) -> SklResult<Array2<f64>> {
740 if group1.len() != group2.len() {
741 return Err(SklearsError::InvalidInput(
742 "Groups must have the same length".to_string(),
743 ));
744 }
745
746 let max1 = group1.iter().max().copied().unwrap_or(0);
747 let max2 = group2.iter().max().copied().unwrap_or(0);
748
749 let mut table = Array2::zeros((max1 + 1, max2 + 1));
750
751 for (&val1, &val2) in group1.iter().zip(group2.iter()) {
752 table[[val1, val2]] += 1.0;
753 }
754
755 Ok(table)
756}
757
758fn compute_chi_square_test(contingency_table: &Array2<f64>) -> SklResult<ChiSquareTestResult> {
759 let chi_square_statistic = compute_chi_square_statistic(contingency_table)?;
760
761 let (rows, cols) = contingency_table.dim();
762 let degrees_of_freedom = (rows - 1) * (cols - 1);
763
764 let n_total: f64 = contingency_table.sum();
766 let mut expected_frequencies = Array2::zeros((rows, cols));
767
768 for i in 0..rows {
769 for j in 0..cols {
770 let row_sum: f64 = contingency_table.row(i).sum();
771 let col_sum: f64 = contingency_table.column(j).sum();
772 expected_frequencies[[i, j]] = (row_sum * col_sum) / n_total;
773 }
774 }
775
776 let critical_value = match degrees_of_freedom {
778 1 => 3.841,
779 2 => 5.991,
780 3 => 7.815,
781 4 => 9.488,
782 _ => 9.488 + (degrees_of_freedom as f64 - 4.0) * 2.0, };
784
785 let p_value = if chi_square_statistic > critical_value {
786 0.01 } else {
788 0.1
789 };
790
791 let reject_independence = chi_square_statistic > critical_value;
792
793 Ok(ChiSquareTestResult {
794 chi_square_statistic,
795 degrees_of_freedom,
796 p_value,
797 critical_value,
798 reject_independence,
799 expected_frequencies,
800 observed_frequencies: contingency_table.clone(),
801 })
802}
803
804fn compute_chi_square_statistic(contingency_table: &Array2<f64>) -> SklResult<f64> {
805 let (rows, cols) = contingency_table.dim();
806 let n_total: f64 = contingency_table.sum();
807
808 if n_total == 0.0 {
809 return Ok(0.0);
810 }
811
812 let mut chi_square = 0.0;
813
814 for i in 0..rows {
815 for j in 0..cols {
816 let observed = contingency_table[[i, j]];
817 let row_sum: f64 = contingency_table.row(i).sum();
818 let col_sum: f64 = contingency_table.column(j).sum();
819 let expected = (row_sum * col_sum) / n_total;
820
821 if expected > 0.0 {
822 chi_square += (observed - expected).powi(2) / expected;
823 }
824 }
825 }
826
827 Ok(chi_square)
828}
829
830fn compute_fisher_exact_test(table: &[[usize; 2]; 2]) -> SklResult<FisherExactTestResult> {
831 let a = table[0][0] as f64;
832 let b = table[0][1] as f64;
833 let c = table[1][0] as f64;
834 let d = table[1][1] as f64;
835
836 let odds_ratio = if b * c > 0.0 {
838 (a * d) / (b * c)
839 } else {
840 f64::INFINITY
841 };
842
843 let n = a + b + c + d;
845 let expected_a = (a + b) * (a + c) / n;
846 let chi_square = if expected_a > 0.0 {
847 (a - expected_a).powi(2) / expected_a
848 } else {
849 0.0
850 };
851
852 let p_value = if chi_square > 3.841 { 0.02 } else { 0.5 };
854 let p_value_less = p_value / 2.0;
855 let p_value_greater = p_value / 2.0;
856
857 let log_or = odds_ratio.ln();
859 let se_log_or = (1.0 / a + 1.0 / b + 1.0 / c + 1.0 / d).sqrt();
860 let margin = 1.96 * se_log_or;
861 let confidence_interval = ((log_or - margin).exp(), (log_or + margin).exp());
862
863 let reject_independence = p_value < 0.05;
864
865 Ok(FisherExactTestResult {
866 p_value,
867 p_value_less,
868 p_value_greater,
869 odds_ratio,
870 confidence_interval,
871 reject_independence,
872 })
873}
874
/// Two-sample Kolmogorov-Smirnov statistic: the maximum absolute difference
/// between the empirical CDFs of the two samples, evaluated at every distinct
/// observed value. Returns 0.0 when either sample is empty.
fn compute_ks_statistic(sample1: &[f64], sample2: &[f64]) -> f64 {
    if sample1.is_empty() || sample2.is_empty() {
        return 0.0;
    }

    let n1 = sample1.len() as f64;
    let n2 = sample2.len() as f64;

    // Evaluation grid: the union of both samples, sorted and deduplicated.
    let mut grid: Vec<f64> = sample1.iter().chain(sample2.iter()).copied().collect();
    grid.sort_by(|a, b| a.partial_cmp(b).unwrap());
    grid.dedup();

    grid.iter()
        .map(|&t| {
            let cdf1 = sample1.iter().filter(|&&x| x <= t).count() as f64 / n1;
            let cdf2 = sample2.iter().filter(|&&x| x <= t).count() as f64 / n2;
            (cdf1 - cdf2).abs()
        })
        .fold(0.0_f64, f64::max)
}
900
/// Very coarse p-value for the two-sample KS statistic, bucketed from the
/// scaled statistic lambda = sqrt(n1*n2 / (n1+n2)) * D.
fn compute_ks_p_value(ks_statistic: f64, n1: usize, n2: usize) -> f64 {
    // Effective (pooled) sample size for the two-sample statistic.
    let pooled = (n1 * n2) as f64 / (n1 + n2) as f64;
    let lambda = pooled.sqrt() * ks_statistic;

    match lambda {
        l if l > 1.36 => 0.02,
        l if l > 1.0 => 0.1,
        _ => 0.5,
    }
}
915
/// Complete output of `sensitivity_analysis`.
#[derive(Debug, Clone)]
pub struct SensitivityAnalysisResult {
    /// Baseline assessment of the data as observed (MCAR hypothesis).
    pub mcar_results: MissingDataAssessment,
    /// One case per simulated (strength, feature) MAR scenario.
    pub mar_sensitivity: Vec<MARSensitivityCase>,
    /// One case per simulated (strength, feature) MNAR scenario.
    pub mnar_sensitivity: Vec<MNARSensitivityCase>,
    /// Aggregated robustness verdict across all scenarios.
    pub robustness_summary: RobustnessSummary,
}
933
/// Summary measurements of a dataset's missing-data structure.
#[derive(Debug, Clone)]
pub struct MissingDataAssessment {
    /// Fraction of all cells that are missing.
    pub missing_proportion: f64,
    /// Shannon entropy (bits) of the row-wise missingness patterns.
    pub pattern_entropy: f64,
    /// Mean absolute correlation between per-feature missingness indicators.
    pub missingness_predictability: f64,
    /// Full independence-test suite results for the same data.
    pub independence_results: IndependenceTestSuite,
}
946
/// One simulated MAR (Missing At Random) sensitivity scenario.
#[derive(Debug, Clone)]
pub struct MARSensitivityCase {
    /// Strength of the simulated dependence on the predictor feature.
    pub correlation_strength: f64,
    /// Features whose missingness was injected in this scenario.
    pub affected_features: Vec<usize>,
    /// Assessment of the data after injecting the MAR mechanism.
    pub assessment: MissingDataAssessment,
    /// Scalar distance between the baseline and the simulated assessment.
    pub conclusion_change: f64,
}
959
/// One simulated MNAR (Missing Not At Random) sensitivity scenario.
#[derive(Debug, Clone)]
pub struct MNARSensitivityCase {
    /// Name of the selection mechanism used for the simulation.
    pub selection_mechanism: String,
    /// Strength of the simulated self-selection.
    pub selection_strength: f64,
    /// Features whose missingness was injected in this scenario.
    pub affected_features: Vec<usize>,
    /// Assessment of the data after injecting the MNAR mechanism.
    pub assessment: MissingDataAssessment,
    /// Scalar distance between the baseline and the simulated assessment.
    pub conclusion_change: f64,
}
974
/// Aggregated robustness verdict over all sensitivity scenarios.
#[derive(Debug, Clone)]
pub struct RobustnessSummary {
    /// 1.0 minus the mean conclusion change, clamped to [0, 1].
    pub robustness_score: f64,
    /// Aspects flagged as sensitive (MAR/MNAR assumptions, missing rate).
    pub sensitive_aspects: Vec<String>,
    /// Suggested imputation approach given the robustness picture.
    pub recommended_approach: String,
    /// Blend of robustness and baseline independence confidence.
    pub mechanism_confidence: f64,
}
987
988#[allow(non_snake_case)]
1015pub fn sensitivity_analysis(
1016 X: &ArrayView2<'_, Float>,
1017 missing_values: f64,
1018 correlation_strengths: &[f64],
1019 selection_strengths: &[f64],
1020) -> SklResult<SensitivityAnalysisResult> {
1021 let X = X.mapv(|x| x);
1022 let (n_samples, n_features) = X.dim();
1023
1024 if n_samples == 0 || n_features == 0 {
1025 return Err(SklearsError::InvalidInput("Empty dataset".to_string()));
1026 }
1027
1028 let mcar_results = assess_missing_data(&X, missing_values)?;
1030
1031 let mut mar_sensitivity = Vec::new();
1033 for &strength in correlation_strengths {
1034 for feature_idx in 0..n_features {
1035 let case = assess_mar_sensitivity(&X, missing_values, feature_idx, strength)?;
1036 mar_sensitivity.push(case);
1037 }
1038 }
1039
1040 let mut mnar_sensitivity = Vec::new();
1042 for &strength in selection_strengths {
1043 for feature_idx in 0..n_features {
1044 let case = assess_mnar_sensitivity(&X, missing_values, feature_idx, strength)?;
1045 mnar_sensitivity.push(case);
1046 }
1047 }
1048
1049 let robustness_summary =
1051 compute_robustness_summary(&mcar_results, &mar_sensitivity, &mnar_sensitivity)?;
1052
1053 Ok(SensitivityAnalysisResult {
1054 mcar_results,
1055 mar_sensitivity,
1056 mnar_sensitivity,
1057 robustness_summary,
1058 })
1059}
1060
1061#[allow(non_snake_case)]
1075pub fn pattern_sensitivity_analysis(
1076 X: &ArrayView2<'_, Float>,
1077 missing_values: f64,
1078 pattern_perturbations: &[f64],
1079) -> SklResult<Vec<PatternSensitivityResult>> {
1080 let X = X.mapv(|x| x);
1081 let (_n_samples, _n_features) = X.dim();
1082
1083 let mut results = Vec::new();
1084
1085 for &perturbation in pattern_perturbations {
1086 let X_perturbed = perturb_missing_pattern(&X, missing_values, perturbation)?;
1088
1089 let assessment = assess_missing_data(&X_perturbed, missing_values)?;
1091
1092 let original_assessment = assess_missing_data(&X, missing_values)?;
1094 let sensitivity_score =
1095 compute_pattern_sensitivity_score(&original_assessment, &assessment);
1096
1097 results.push(PatternSensitivityResult {
1098 perturbation_strength: perturbation,
1099 assessment,
1100 sensitivity_score,
1101 pattern_changes: count_pattern_changes(&X, &X_perturbed, missing_values),
1102 });
1103 }
1104
1105 Ok(results)
1106}
1107
/// Result of a single pattern-perturbation sensitivity run.
#[derive(Debug, Clone)]
pub struct PatternSensitivityResult {
    /// Strength of the applied perturbation.
    pub perturbation_strength: f64,
    /// Assessment of the perturbed data.
    pub assessment: MissingDataAssessment,
    /// Score comparing the perturbed assessment against the baseline
    /// (computed by `compute_pattern_sensitivity_score`).
    pub sensitivity_score: f64,
    /// Number of cells whose missingness changed under the perturbation.
    pub pattern_changes: usize,
}
1120
/// Computes the baseline missing-data assessment for a dataset: overall
/// missing proportion, pattern entropy, cross-feature missingness
/// predictability, and the full independence-test suite.
#[allow(non_snake_case)]
fn assess_missing_data(X: &Array2<f64>, missing_values: f64) -> SklResult<MissingDataAssessment> {
    let (n_samples, n_features) = X.dim();

    // Count missing cells across the entire matrix.
    let mut total_missing = 0;
    let total_values = n_samples * n_features;

    // A NaN sentinel must be matched with `is_nan`; any other sentinel uses
    // an epsilon comparison.
    let is_missing_nan = missing_values.is_nan();

    for i in 0..n_samples {
        for j in 0..n_features {
            let is_missing = if is_missing_nan {
                X[[i, j]].is_nan()
            } else {
                (X[[i, j]] - missing_values).abs() < f64::EPSILON
            };

            if is_missing {
                total_missing += 1;
            }
        }
    }

    let missing_proportion = total_missing as f64 / total_values as f64;

    // Entropy of the row-wise missingness patterns.
    let pattern_entropy = compute_pattern_entropy(X, missing_values)?;

    // Mean absolute correlation between per-feature missingness indicators.
    let missingness_predictability = compute_missingness_predictability(X, missing_values)?;

    // Convert to the `Float` element type expected by the test-suite API.
    // NOTE(review): when `Float` is f64 this copy is redundant — confirm
    // whether `X.view()` could be passed directly.
    let X_view = X.view().mapv(|x| x as Float);
    let independence_results =
        run_independence_test_suite(&X_view.view(), missing_values as Float, None, None)?;

    Ok(MissingDataAssessment {
        missing_proportion,
        pattern_entropy,
        missingness_predictability,
        independence_results,
    })
}
1167
1168#[allow(non_snake_case)]
1169fn assess_mar_sensitivity(
1170 X: &Array2<f64>,
1171 missing_values: f64,
1172 target_feature: usize,
1173 correlation_strength: f64,
1174) -> SklResult<MARSensitivityCase> {
1175 let X_mar = simulate_mar_mechanism(X, missing_values, target_feature, correlation_strength)?;
1177
1178 let assessment = assess_missing_data(&X_mar, missing_values)?;
1179 let base_assessment = assess_missing_data(X, missing_values)?;
1180
1181 let conclusion_change = compute_assessment_difference(&base_assessment, &assessment);
1182
1183 Ok(MARSensitivityCase {
1184 correlation_strength,
1185 affected_features: vec![target_feature],
1186 assessment,
1187 conclusion_change,
1188 })
1189}
1190
1191#[allow(non_snake_case)]
1192fn assess_mnar_sensitivity(
1193 X: &Array2<f64>,
1194 missing_values: f64,
1195 target_feature: usize,
1196 selection_strength: f64,
1197) -> SklResult<MNARSensitivityCase> {
1198 let X_mnar = simulate_mnar_mechanism(X, missing_values, target_feature, selection_strength)?;
1200
1201 let assessment = assess_missing_data(&X_mnar, missing_values)?;
1202 let base_assessment = assess_missing_data(X, missing_values)?;
1203
1204 let conclusion_change = compute_assessment_difference(&base_assessment, &assessment);
1205
1206 Ok(MNARSensitivityCase {
1207 selection_mechanism: "threshold_based".to_string(),
1208 selection_strength,
1209 affected_features: vec![target_feature],
1210 assessment,
1211 conclusion_change,
1212 })
1213}
1214
1215fn compute_robustness_summary(
1216 mcar_results: &MissingDataAssessment,
1217 mar_sensitivity: &[MARSensitivityCase],
1218 mnar_sensitivity: &[MNARSensitivityCase],
1219) -> SklResult<RobustnessSummary> {
1220 let mar_variations: Vec<f64> = mar_sensitivity
1222 .iter()
1223 .map(|case| case.conclusion_change)
1224 .collect();
1225 let mnar_variations: Vec<f64> = mnar_sensitivity
1226 .iter()
1227 .map(|case| case.conclusion_change)
1228 .collect();
1229
1230 let mar_avg_variation = if mar_variations.is_empty() {
1231 0.0
1232 } else {
1233 mar_variations.iter().sum::<f64>() / mar_variations.len() as f64
1234 };
1235
1236 let mnar_avg_variation = if mnar_variations.is_empty() {
1237 0.0
1238 } else {
1239 mnar_variations.iter().sum::<f64>() / mnar_variations.len() as f64
1240 };
1241
1242 let robustness_score = 1.0 - (mar_avg_variation + mnar_avg_variation) / 2.0;
1243 let robustness_score = robustness_score.clamp(0.0, 1.0);
1244
1245 let mut sensitive_aspects = Vec::new();
1247 if mar_avg_variation > 0.3 {
1248 sensitive_aspects.push("MAR assumptions".to_string());
1249 }
1250 if mnar_avg_variation > 0.3 {
1251 sensitive_aspects.push("MNAR assumptions".to_string());
1252 }
1253 if mcar_results.missing_proportion > 0.5 {
1254 sensitive_aspects.push("High missing proportion".to_string());
1255 }
1256
1257 let recommended_approach = if robustness_score > 0.8 {
1259 "MCAR assumption appears robust".to_string()
1260 } else if mar_avg_variation < mnar_avg_variation {
1261 "Consider MAR-based imputation".to_string()
1262 } else {
1263 "Consider MNAR-aware methods".to_string()
1264 };
1265
1266 let independence_confidence = 1.0 - mcar_results.independence_results.summary.dependence_rate;
1268 let mechanism_confidence = (robustness_score + independence_confidence) / 2.0;
1269
1270 Ok(RobustnessSummary {
1271 robustness_score,
1272 sensitive_aspects,
1273 recommended_approach,
1274 mechanism_confidence,
1275 })
1276}
1277
1278fn compute_pattern_entropy(X: &Array2<f64>, missing_values: f64) -> SklResult<f64> {
1279 let (n_samples, n_features) = X.dim();
1280 let mut pattern_counts = HashMap::new();
1281
1282 let is_missing_nan = missing_values.is_nan();
1283
1284 for i in 0..n_samples {
1285 let mut pattern = Vec::new();
1286 for j in 0..n_features {
1287 let is_missing = if is_missing_nan {
1288 X[[i, j]].is_nan()
1289 } else {
1290 (X[[i, j]] - missing_values).abs() < f64::EPSILON
1291 };
1292 pattern.push(if is_missing { 1 } else { 0 });
1293 }
1294
1295 let pattern_key = format!("{:?}", pattern);
1296 *pattern_counts.entry(pattern_key).or_insert(0) += 1;
1297 }
1298
1299 let mut entropy = 0.0;
1301 for &count in pattern_counts.values() {
1302 let probability = count as f64 / n_samples as f64;
1303 if probability > 0.0 {
1304 entropy -= probability * probability.log2();
1305 }
1306 }
1307
1308 Ok(entropy)
1309}
1310
1311fn compute_missingness_predictability(X: &Array2<f64>, missing_values: f64) -> SklResult<f64> {
1312 let (n_samples, n_features) = X.dim();
1313
1314 let is_missing_nan = missing_values.is_nan();
1315
1316 let mut total_correlation = 0.0;
1318 let mut correlation_count = 0;
1319
1320 for j1 in 0..n_features {
1321 for j2 in (j1 + 1)..n_features {
1322 let mut missing1 = Vec::new();
1323 let mut missing2 = Vec::new();
1324
1325 for i in 0..n_samples {
1326 let is_missing1 = if is_missing_nan {
1327 X[[i, j1]].is_nan()
1328 } else {
1329 (X[[i, j1]] - missing_values).abs() < f64::EPSILON
1330 };
1331
1332 let is_missing2 = if is_missing_nan {
1333 X[[i, j2]].is_nan()
1334 } else {
1335 (X[[i, j2]] - missing_values).abs() < f64::EPSILON
1336 };
1337
1338 missing1.push(if is_missing1 { 1.0 } else { 0.0 });
1339 missing2.push(if is_missing2 { 1.0 } else { 0.0 });
1340 }
1341
1342 let correlation = compute_correlation_coefficient(&missing1, &missing2);
1343 total_correlation += correlation.abs();
1344 correlation_count += 1;
1345 }
1346 }
1347
1348 Ok(if correlation_count > 0 {
1349 total_correlation / correlation_count as f64
1350 } else {
1351 0.0
1352 })
1353}
1354
1355fn compute_assessment_difference(
1356 base: &MissingDataAssessment,
1357 other: &MissingDataAssessment,
1358) -> f64 {
1359 let prop_diff = (base.missing_proportion - other.missing_proportion).abs();
1360 let entropy_diff =
1361 (base.pattern_entropy - other.pattern_entropy).abs() / base.pattern_entropy.max(1e-8);
1362 let pred_diff = (base.missingness_predictability - other.missingness_predictability).abs();
1363 let p_value_diff = (base.independence_results.summary.dependence_rate
1364 - other.independence_results.summary.dependence_rate)
1365 .abs();
1366
1367 (prop_diff + entropy_diff + pred_diff + p_value_diff) / 4.0
1368}
1369
1370fn simulate_mar_mechanism(
1371 X: &Array2<f64>,
1372 missing_values: f64,
1373 target_feature: usize,
1374 correlation_strength: f64,
1375) -> SklResult<Array2<f64>> {
1376 let mut X_mar = X.clone();
1377 let (n_samples, n_features) = X.dim();
1378
1379 if target_feature >= n_features {
1380 return Err(SklearsError::InvalidInput(
1381 "Invalid target feature index".to_string(),
1382 ));
1383 }
1384
1385 let predictor_feature = (target_feature + 1) % n_features;
1387
1388 let mut predictor_values: Vec<f64> = Vec::new();
1390 for i in 0..n_samples {
1391 if !is_value_missing(X[[i, predictor_feature]], missing_values) {
1392 predictor_values.push(X[[i, predictor_feature]]);
1393 }
1394 }
1395
1396 if predictor_values.is_empty() {
1397 return Ok(X_mar);
1398 }
1399
1400 predictor_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
1401 let threshold_idx = ((1.0 - correlation_strength) * predictor_values.len() as f64) as usize;
1402 let threshold = predictor_values
1403 .get(threshold_idx)
1404 .cloned()
1405 .unwrap_or(predictor_values[0]);
1406
1407 for i in 0..n_samples {
1409 if !is_value_missing(X[[i, predictor_feature]], missing_values)
1410 && X[[i, predictor_feature]] > threshold
1411 && !is_value_missing(X[[i, target_feature]], missing_values)
1412 {
1413 X_mar[[i, target_feature]] = missing_values;
1414 }
1415 }
1416
1417 Ok(X_mar)
1418}
1419
1420fn simulate_mnar_mechanism(
1421 X: &Array2<f64>,
1422 missing_values: f64,
1423 target_feature: usize,
1424 selection_strength: f64,
1425) -> SklResult<Array2<f64>> {
1426 let mut X_mnar = X.clone();
1427 let (n_samples, n_features) = X.dim();
1428
1429 if target_feature >= n_features {
1430 return Err(SklearsError::InvalidInput(
1431 "Invalid target feature index".to_string(),
1432 ));
1433 }
1434
1435 let mut target_values: Vec<f64> = Vec::new();
1437 for i in 0..n_samples {
1438 if !is_value_missing(X[[i, target_feature]], missing_values) {
1439 target_values.push(X[[i, target_feature]]);
1440 }
1441 }
1442
1443 if target_values.is_empty() {
1444 return Ok(X_mnar);
1445 }
1446
1447 target_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
1448 let threshold_idx = ((1.0 - selection_strength) * target_values.len() as f64) as usize;
1449 let threshold = target_values
1450 .get(threshold_idx)
1451 .cloned()
1452 .unwrap_or(target_values[0]);
1453
1454 for i in 0..n_samples {
1456 if !is_value_missing(X[[i, target_feature]], missing_values)
1457 && X[[i, target_feature]] > threshold
1458 {
1459 X_mnar[[i, target_feature]] = missing_values;
1460 }
1461 }
1462
1463 Ok(X_mnar)
1464}
1465
1466fn perturb_missing_pattern(
1467 X: &Array2<f64>,
1468 missing_values: f64,
1469 perturbation_strength: f64,
1470) -> SklResult<Array2<f64>> {
1471 let mut X_perturbed = X.clone();
1472 let (n_samples, n_features) = X.dim();
1473
1474 let perturbation_rate = perturbation_strength.min(0.5); let n_perturbations = ((n_samples * n_features) as f64 * perturbation_rate) as usize;
1476
1477 use scirs2_core::random::Random;
1478 let mut rng = Random::default();
1479
1480 for _ in 0..n_perturbations {
1481 let i = rng.gen_range(0..n_samples);
1482 let j = rng.gen_range(0..n_features);
1483
1484 let is_currently_missing = is_value_missing(X_perturbed[[i, j]], missing_values);
1485
1486 if is_currently_missing {
1487 let mut observed_values = Vec::new();
1489 for row in 0..n_samples {
1490 if !is_value_missing(X[[row, j]], missing_values) {
1491 observed_values.push(X[[row, j]]);
1492 }
1493 }
1494
1495 if !observed_values.is_empty() {
1496 let mean = observed_values.iter().sum::<f64>() / observed_values.len() as f64;
1497 X_perturbed[[i, j]] = mean;
1498 }
1499 } else {
1500 X_perturbed[[i, j]] = missing_values;
1502 }
1503 }
1504
1505 Ok(X_perturbed)
1506}
1507
/// Scalar sensitivity score for a missing-pattern perturbation: how much the
/// overall missing-data assessment changed between the original data and the
/// perturbed data. Currently just delegates to `compute_assessment_difference`
/// (higher = more sensitive to perturbation).
fn compute_pattern_sensitivity_score(
    original: &MissingDataAssessment,
    perturbed: &MissingDataAssessment,
) -> f64 {
    compute_assessment_difference(original, perturbed)
}
1514
1515fn count_pattern_changes(
1516 X_original: &Array2<f64>,
1517 X_perturbed: &Array2<f64>,
1518 missing_values: f64,
1519) -> usize {
1520 let (n_samples, n_features) = X_original.dim();
1521 let mut changes = 0;
1522
1523 for i in 0..n_samples {
1524 for j in 0..n_features {
1525 let orig_missing = is_value_missing(X_original[[i, j]], missing_values);
1526 let pert_missing = is_value_missing(X_perturbed[[i, j]], missing_values);
1527
1528 if orig_missing != pert_missing {
1529 changes += 1;
1530 }
1531 }
1532 }
1533
1534 changes
1535}
1536
/// Returns true when `value` matches the missing-value sentinel: NaN-equality
/// when the sentinel is NaN, otherwise an epsilon-tolerant comparison.
fn is_value_missing(value: f64, missing_values: f64) -> bool {
    // NaN != NaN, so a NaN sentinel needs the dedicated check.
    if missing_values.is_nan() {
        return value.is_nan();
    }
    (value - missing_values).abs() < f64::EPSILON
}
1544
/// Pearson correlation coefficient of two equal-length samples.
///
/// Returns 0.0 for mismatched lengths, empty input, or when either sample has
/// zero variance (correlation is undefined there).
fn compute_correlation_coefficient(x: &[f64], y: &[f64]) -> f64 {
    if x.len() != y.len() || x.is_empty() {
        return 0.0;
    }

    let n = x.len() as f64;
    let mean_x = x.iter().sum::<f64>() / n;
    let mean_y = y.iter().sum::<f64>() / n;

    // Accumulate covariance and both variances in a single pass.
    let (cov, var_x, var_y) =
        x.iter()
            .zip(y.iter())
            .fold((0.0, 0.0, 0.0), |(c, vx, vy), (&xi, &yi)| {
                let dx = xi - mean_x;
                let dy = yi - mean_y;
                (c + dx * dy, vx + dx * dx, vy + dy * dy)
            });

    let scale = (var_x * var_y).sqrt();
    if scale == 0.0 {
        0.0
    } else {
        cov / scale
    }
}