sklears_model_selection/model_comparison.rs

//! Statistical Model Comparison Tests
//!
//! This module provides statistical tests for comparing the performance of different
//! machine learning models. It includes parametric and non-parametric tests for
//! model comparison with proper statistical significance assessment.
//!
//! Key tests include:
//! - Paired t-test for continuous performance metrics
//! - McNemar's test for binary classification performance
//! - Wilcoxon signed-rank test (non-parametric alternative to paired t-test)
//! - Friedman test for comparing multiple models across multiple datasets
//! - Nemenyi post-hoc test for pairwise comparisons after Friedman test
//! - Multiple-comparison corrections (Bonferroni, Holm, Benjamini-Hochberg)

use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::{Result, SklearsError};

/// Statistical test result
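///
/// # Examples
///
/// A minimal sketch of the builder-style API (all values are illustrative;
/// the fence is `ignore` because the full crate path is not shown here):
///
/// ```ignore
/// let result = StatisticalTestResult::new("Demo test".to_string(), 2.5, 0.03, 0.05)
///     .with_degrees_of_freedom(9.0)
///     .with_effect_size(0.8)
///     .with_confidence_interval(0.01, 0.09);
/// assert!(result.is_significant);
/// ```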
#[derive(Debug, Clone)]
pub struct StatisticalTestResult {
    /// Name of the test performed
    pub test_name: String,
    /// Test statistic value
    pub statistic: f64,
    /// P-value of the test
    pub p_value: f64,
    /// Critical value (if applicable)
    pub critical_value: Option<f64>,
    /// Degrees of freedom (if applicable)
    pub degrees_of_freedom: Option<f64>,
    /// Effect size (if applicable)
    pub effect_size: Option<f64>,
    /// Whether the result is statistically significant at the chosen alpha
    pub is_significant: bool,
    /// Confidence interval for the difference (if applicable)
    pub confidence_interval: Option<(f64, f64)>,
    /// Interpretation message
    pub interpretation: String,
}
42
43impl StatisticalTestResult {
44    /// Create a new test result
45    pub fn new(test_name: String, statistic: f64, p_value: f64, alpha: f64) -> Self {
46        let is_significant = p_value < alpha;
47        let interpretation = if is_significant {
48            format!(
49                "Statistically significant difference detected (p = {:.4})",
50                p_value
51            )
52        } else {
53            format!(
54                "No statistically significant difference (p = {:.4})",
55                p_value
56            )
57        };
58
59        Self {
60            test_name,
61            statistic,
62            p_value,
63            critical_value: None,
64            degrees_of_freedom: None,
65            effect_size: None,
66            is_significant,
67            confidence_interval: None,
68            interpretation,
69        }
70    }
71
72    /// Set critical value
73    pub fn with_critical_value(mut self, critical_value: f64) -> Self {
74        self.critical_value = Some(critical_value);
75        self
76    }
77
78    /// Set degrees of freedom
79    pub fn with_degrees_of_freedom(mut self, df: f64) -> Self {
80        self.degrees_of_freedom = Some(df);
81        self
82    }
83
84    /// Set effect size
85    pub fn with_effect_size(mut self, effect_size: f64) -> Self {
86        self.effect_size = Some(effect_size);
87        self
88    }
89
90    /// Set confidence interval
91    pub fn with_confidence_interval(mut self, lower: f64, upper: f64) -> Self {
92        self.confidence_interval = Some((lower, upper));
93        self
94    }
95}
96
/// Paired t-test for comparing two sets of continuous performance scores
///
/// This test assumes:
/// - Paired observations (same test instances for both models)
/// - Differences are normally distributed
/// - Continuous data
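///
/// # Examples
///
/// A minimal sketch comparing per-fold accuracy scores from two models
/// (the numbers are illustrative; the fence is `ignore` because the
/// surrounding crate paths are not shown here):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let model_a = array![0.91, 0.88, 0.90, 0.93, 0.87];
/// let model_b = array![0.89, 0.86, 0.88, 0.90, 0.85];
/// let result = paired_t_test(&model_a, &model_b, 0.05).unwrap();
/// println!("{}", result.interpretation);
/// ```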
pub fn paired_t_test(
    scores1: &Array1<f64>,
    scores2: &Array1<f64>,
    alpha: f64,
) -> Result<StatisticalTestResult> {
    if scores1.len() != scores2.len() {
        return Err(SklearsError::InvalidInput(
            "Score arrays must have the same length".to_string(),
        ));
    }

    let n = scores1.len();
    if n < 2 {
        return Err(SklearsError::InvalidInput(
            "Need at least 2 paired observations".to_string(),
        ));
    }

    // Calculate differences
    let differences: Array1<f64> = scores1 - scores2;

    // Calculate mean and standard deviation of differences
    let mean_diff = differences.mean().unwrap();
    let variance = differences.var(1.0); // Sample variance (Bessel's correction)
    let std_diff = variance.sqrt();

    if std_diff == 0.0 {
        // When std_diff is zero, all differences are the same (no variance)
        // We can determine the result without a statistical test
        let p_value = if mean_diff == 0.0 { 1.0 } else { 0.0 };
        let significant = p_value < alpha;

        return Ok(StatisticalTestResult {
            test_name: "Paired t-test (zero variance)".to_string(),
            statistic: 0.0,
            p_value,
            critical_value: None,
            degrees_of_freedom: Some((n - 1) as f64),
            effect_size: Some(mean_diff),
            is_significant: significant,
            confidence_interval: Some((mean_diff, mean_diff)),
            interpretation: if mean_diff == 0.0 {
                "No difference between models (identical performance)".to_string()
            } else {
                format!(
                    "Models differ by {:.6} with zero variance (deterministic difference)",
                    mean_diff
                )
            },
        });
    }

    // Calculate t-statistic
    let standard_error = std_diff / (n as f64).sqrt();
    let t_statistic = mean_diff / standard_error;

    // Degrees of freedom
    let df = (n - 1) as f64;

    // Calculate p-value (two-tailed)
    let p_value = 2.0 * (1.0 - student_t_cdf(t_statistic.abs(), df));

    // Calculate the (1 - alpha) confidence interval for the mean difference
    let t_critical = inverse_student_t(1.0 - alpha / 2.0, df);
    let margin_error = t_critical * standard_error;
    let ci_lower = mean_diff - margin_error;
    let ci_upper = mean_diff + margin_error;

    // Calculate Cohen's d (effect size)
    let pooled_std = ((scores1.var(1.0) + scores2.var(1.0)) / 2.0).sqrt();
    let cohens_d = mean_diff / pooled_std;

    Ok(
        StatisticalTestResult::new("Paired t-test".to_string(), t_statistic, p_value, alpha)
            .with_degrees_of_freedom(df)
            .with_effect_size(cohens_d)
            .with_confidence_interval(ci_lower, ci_upper)
            .with_critical_value(t_critical),
    )
}

/// McNemar's test for comparing two binary classifiers
///
/// This test is used when:
/// - Comparing two binary classifiers on the same test set
/// - Testing whether the two classifiers have significantly different error rates
/// - Data is in the form of a 2x2 contingency table
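///
/// # Examples
///
/// A sketch with made-up contingency counts (both correct, A-only correct,
/// B-only correct, both incorrect):
///
/// ```ignore
/// let result = mcnemar_test(85, 10, 5, 0, 0.05, true).unwrap();
/// if result.is_significant {
///     println!("Classifiers disagree systematically: {}", result.interpretation);
/// }
/// ```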
pub fn mcnemar_test(
    correct_a_correct_b: usize,     // Both classifiers correct
    correct_a_incorrect_b: usize,   // A correct, B incorrect
    incorrect_a_correct_b: usize,   // A incorrect, B correct
    incorrect_a_incorrect_b: usize, // Both classifiers incorrect
    alpha: f64,
    continuity_correction: bool,
) -> Result<StatisticalTestResult> {
    let b = correct_a_incorrect_b as f64;
    let c = incorrect_a_correct_b as f64;
    let total = (correct_a_correct_b
        + correct_a_incorrect_b
        + incorrect_a_correct_b
        + incorrect_a_incorrect_b) as f64;

    if total == 0.0 {
        return Err(SklearsError::InvalidInput(
            "No observations provided".to_string(),
        ));
    }

    // The chi-squared approximation is unreliable with few discordant pairs
    if b + c < 10.0 {
        return Err(SklearsError::InvalidInput(
            "McNemar's test requires at least 10 discordant pairs".to_string(),
        ));
    }

    // Calculate McNemar's statistic
    let statistic = if continuity_correction {
        // With continuity correction
        ((b - c).abs() - 1.0).powi(2) / (b + c)
    } else {
        // Without continuity correction
        (b - c).powi(2) / (b + c)
    };

    // Calculate p-value using chi-squared distribution with 1 df
    let p_value = 1.0 - chi_squared_cdf(statistic, 1.0);

    // Critical value for chi-squared with 1 df at given alpha
    let critical_value = inverse_chi_squared(1.0 - alpha, 1.0);

    let test_name = if continuity_correction {
        "McNemar's test (with continuity correction)".to_string()
    } else {
        "McNemar's test".to_string()
    };

    Ok(
        StatisticalTestResult::new(test_name, statistic, p_value, alpha)
            .with_degrees_of_freedom(1.0)
            .with_critical_value(critical_value),
    )
}

/// Wilcoxon signed-rank test (non-parametric alternative to paired t-test)
///
/// This test:
/// - Does not assume normal distribution
/// - Tests whether the median difference is zero
/// - Is more robust to outliers than the t-test
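///
/// # Examples
///
/// A sketch mirroring the paired t-test usage, for scores that may not be
/// normally distributed (values are illustrative):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let model_a = array![0.90, 0.80, 0.85, 0.92, 0.88, 0.91, 0.87];
/// let model_b = array![0.85, 0.75, 0.80, 0.87, 0.83, 0.86, 0.82];
/// let result = wilcoxon_signed_rank_test(&model_a, &model_b, 0.05).unwrap();
/// println!("W = {}, p = {}", result.statistic, result.p_value);
/// ```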
pub fn wilcoxon_signed_rank_test(
    scores1: &Array1<f64>,
    scores2: &Array1<f64>,
    alpha: f64,
) -> Result<StatisticalTestResult> {
    if scores1.len() != scores2.len() {
        return Err(SklearsError::InvalidInput(
            "Score arrays must have the same length".to_string(),
        ));
    }

    // Calculate differences and filter out zeros
    let differences: Vec<f64> = scores1
        .iter()
        .zip(scores2.iter())
        .map(|(a, b)| a - b)
        .filter(|&d| d != 0.0)
        .collect();

    let n = differences.len();
    if n < 6 {
        return Err(SklearsError::InvalidInput(
            "Wilcoxon test requires at least 6 non-zero differences".to_string(),
        ));
    }

    // Calculate absolute differences and their signs
    let mut abs_diffs_with_signs: Vec<(f64, f64)> =
        differences.iter().map(|&d| (d.abs(), d.signum())).collect();

    // Sort by absolute value for ranking
    abs_diffs_with_signs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

    // Calculate ranks (handle ties by averaging ranks)
    let mut ranks = vec![0.0; n];
    let mut i = 0;
    while i < n {
        let mut j = i;
        while j < n && abs_diffs_with_signs[j].0 == abs_diffs_with_signs[i].0 {
            j += 1;
        }
        let avg_rank = (i + j + 1) as f64 / 2.0;
        for k in i..j {
            ranks[k] = avg_rank;
        }
        i = j;
    }

    // Calculate sum of positive and negative ranks
    let mut w_plus = 0.0;
    let mut w_minus = 0.0;

    for (rank, (_, sign)) in ranks.iter().zip(abs_diffs_with_signs.iter()) {
        if *sign > 0.0 {
            w_plus += rank;
        } else {
            w_minus += rank;
        }
    }

    // Test statistic is the smaller of W+ and W-
    let w_statistic = w_plus.min(w_minus);

    // For large n (> 20), use normal approximation
    let p_value = if n > 20 {
        let expected_w = n as f64 * (n + 1) as f64 / 4.0;
        let variance_w = n as f64 * (n + 1) as f64 * (2 * n + 1) as f64 / 24.0;
        let z_score = (w_statistic - expected_w) / variance_w.sqrt();
        2.0 * (1.0 - standard_normal_cdf(z_score.abs()))
    } else {
        // For small n, compare against tabulated critical values and return
        // a coarse placeholder p-value on the appropriate side of alpha
        // (an exact p-value would require the full signed-rank distribution)
        let critical_value = wilcoxon_critical_value(n, alpha);
        if w_statistic <= critical_value {
            0.01 // Significant
        } else {
            0.10 // Not significant
        }
    };

    Ok(StatisticalTestResult::new(
        "Wilcoxon signed-rank test".to_string(),
        w_statistic,
        p_value,
        alpha,
    ))
}

/// Friedman test for comparing multiple models across multiple datasets
///
/// This is a non-parametric test for:
/// - Comparing k models on n datasets
/// - Testing whether models have significantly different performance ranks
/// - Extension of the Wilcoxon test to more than 2 models
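///
/// # Examples
///
/// A sketch with a 4-dataset x 3-model accuracy matrix (illustrative values):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let performance = array![
///     [0.90, 0.85, 0.80],
///     [0.88, 0.83, 0.78],
///     [0.92, 0.87, 0.82],
///     [0.89, 0.84, 0.79],
/// ];
/// let result = friedman_test(&performance, 0.05).unwrap();
/// println!("chi^2_F = {:.3}, p = {:.4}", result.statistic, result.p_value);
/// ```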
pub fn friedman_test(
    performance_matrix: &Array2<f64>, // Rows: datasets, Columns: models
    alpha: f64,
) -> Result<StatisticalTestResult> {
    let (n_datasets, k_models) = performance_matrix.dim();

    if n_datasets < 2 || k_models < 3 {
        return Err(SklearsError::InvalidInput(
            "Friedman test requires at least 2 datasets and 3 models".to_string(),
        ));
    }

    // Calculate ranks for each dataset
    let mut rank_matrix = Array2::zeros((n_datasets, k_models));

    for i in 0..n_datasets {
        let row = performance_matrix.row(i);
        let mut indexed_scores: Vec<(usize, f64)> = row
            .iter()
            .enumerate()
            .map(|(j, &score)| (j, score))
            .collect();

        // Sort by score (descending for performance metrics, so the best
        // model receives rank 1)
        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        // Assign ranks, averaging over ties
        let mut pos = 0;
        while pos < k_models {
            let mut end = pos;
            while end < k_models
                && (indexed_scores[end].1 - indexed_scores[pos].1).abs() < 1e-10
            {
                end += 1;
            }
            // Positions pos..end hold tied scores; they share the average
            // of the ranks they would otherwise occupy
            let avg_rank = (pos + end + 1) as f64 / 2.0;
            for &(model_idx, _) in &indexed_scores[pos..end] {
                rank_matrix[[i, model_idx]] = avg_rank;
            }
            pos = end;
        }
    }

    // Calculate rank sums for each model
    let rank_sums: Array1<f64> = rank_matrix.sum_axis(scirs2_core::ndarray::Axis(0));

    // Calculate Friedman statistic
    let sum_of_squares: f64 = rank_sums.iter().map(|&r| r * r).sum();
    let friedman_statistic = (12.0 / (n_datasets as f64 * k_models as f64 * (k_models + 1) as f64))
        * sum_of_squares
        - 3.0 * n_datasets as f64 * (k_models + 1) as f64;

    // Calculate p-value using chi-squared distribution
    let df = (k_models - 1) as f64;
    let p_value = 1.0 - chi_squared_cdf(friedman_statistic, df);

    let critical_value = inverse_chi_squared(1.0 - alpha, df);

    Ok(StatisticalTestResult::new(
        "Friedman test".to_string(),
        friedman_statistic,
        p_value,
        alpha,
    )
    .with_degrees_of_freedom(df)
    .with_critical_value(critical_value))
}

/// Nemenyi post-hoc test for pairwise comparisons after Friedman test
///
/// This test is used after a significant Friedman test to determine
/// which specific pairs of models differ significantly.
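///
/// # Examples
///
/// A sketch that runs the post-hoc comparisons after a significant Friedman
/// test (reusing the `performance` matrix from the Friedman example above):
///
/// ```ignore
/// for (i, j, test) in nemenyi_post_hoc_test(&performance, 0.05).unwrap() {
///     println!("model {} vs model {}: {}", i, j, test.interpretation);
/// }
/// ```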
pub fn nemenyi_post_hoc_test(
    performance_matrix: &Array2<f64>,
    alpha: f64,
) -> Result<Vec<(usize, usize, StatisticalTestResult)>> {
    let (n_datasets, k_models) = performance_matrix.dim();

    if n_datasets < 2 || k_models < 3 {
        return Err(SklearsError::InvalidInput(
            "Nemenyi test requires at least 2 datasets and 3 models".to_string(),
        ));
    }

    // First calculate average ranks for each model
    // (ties are given sequential ranks here, a simplification)
    let mut rank_matrix = Array2::zeros((n_datasets, k_models));

    for i in 0..n_datasets {
        let row = performance_matrix.row(i);
        let mut indexed_scores: Vec<(usize, f64)> = row
            .iter()
            .enumerate()
            .map(|(j, &score)| (j, score))
            .collect();

        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        for (rank, (idx, _)) in indexed_scores.iter().enumerate() {
            rank_matrix[[i, *idx]] = (rank + 1) as f64;
        }
    }

    let average_ranks: Array1<f64> = rank_matrix
        .mean_axis(scirs2_core::ndarray::Axis(0))
        .unwrap();

    // Critical difference for the Nemenyi test:
    // CD = q_alpha * sqrt(k(k + 1) / (6n))
    let q_alpha = nemenyi_critical_value(k_models, alpha);
    let critical_difference =
        q_alpha * ((k_models * (k_models + 1)) as f64 / (6.0 * n_datasets as f64)).sqrt();

    // Perform pairwise comparisons
    let mut results = Vec::new();

    for i in 0..k_models {
        for j in (i + 1)..k_models {
            let rank_diff = (average_ranks[i] - average_ranks[j]).abs();
            let is_significant = rank_diff > critical_difference;

            let test_result = StatisticalTestResult {
                test_name: format!("Nemenyi post-hoc (Model {} vs Model {})", i + 1, j + 1),
                statistic: rank_diff,
                // Placeholder p-values on either side of common alpha levels;
                // the decision is driven by the critical difference above
                p_value: if is_significant { 0.01 } else { 0.10 },
                critical_value: Some(critical_difference),
                degrees_of_freedom: None,
                effect_size: Some(rank_diff),
                is_significant,
                confidence_interval: None,
                interpretation: if is_significant {
                    format!("Significant difference in ranks: {:.3}", rank_diff)
                } else {
                    format!("No significant difference in ranks: {:.3}", rank_diff)
                },
            };

            results.push((i, j, test_result));
        }
    }

    Ok(results)
}

/// Multiple model comparison with correction for multiple testing
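///
/// # Examples
///
/// A sketch comparing three named models with Holm-corrected paired t-tests;
/// `metric_matrix` is a hypothetical `Array2<f64>` holding one column of
/// scores per model:
///
/// ```ignore
/// let comparison = multiple_model_comparison(
///     &[metric_matrix],
///     &["svm".to_string(), "rf".to_string(), "knn".to_string()],
///     0.05,
///     MultipleTestingCorrection::Holm,
/// ).unwrap();
/// println!(
///     "{} of {} pairwise comparisons significant",
///     comparison.significant_pairs, comparison.n_comparisons
/// );
/// ```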
pub fn multiple_model_comparison(
    performance_matrices: &[Array2<f64>], // Multiple performance matrices
    model_names: &[String],
    alpha: f64,
    correction_method: MultipleTestingCorrection,
) -> Result<ModelComparisonResult> {
    if performance_matrices.is_empty() {
        return Err(SklearsError::InvalidInput(
            "No performance data provided".to_string(),
        ));
    }

    let n_models = model_names.len();
    if n_models < 2 {
        return Err(SklearsError::InvalidInput(
            "Need at least 2 models to compare".to_string(),
        ));
    }

    let mut pairwise_results = Vec::new();
    let mut raw_p_values = Vec::new();

    // Perform all pairwise comparisons
    for (matrix_idx, matrix) in performance_matrices.iter().enumerate() {
        if matrix.ncols() != n_models {
            return Err(SklearsError::InvalidInput(format!(
                "Performance matrix {} has {} models, expected {}",
                matrix_idx,
                matrix.ncols(),
                n_models
            )));
        }

        for i in 0..n_models {
            for j in (i + 1)..n_models {
                let scores1 = matrix.column(i).to_owned();
                let scores2 = matrix.column(j).to_owned();

                let test_result = paired_t_test(&scores1, &scores2, alpha)?;
                raw_p_values.push(test_result.p_value);
                pairwise_results.push((i, j, matrix_idx, test_result));
            }
        }
    }

    // Apply multiple testing correction
    let corrected_p_values = match correction_method {
        MultipleTestingCorrection::Bonferroni => bonferroni_correction(&raw_p_values),
        MultipleTestingCorrection::BenjaminiHochberg => {
            benjamini_hochberg_correction(&raw_p_values, alpha)
        }
        MultipleTestingCorrection::Holm => holm_correction(&raw_p_values),
        MultipleTestingCorrection::None => raw_p_values.clone(),
    };

    // Update significance based on corrected p-values
    for (result, &corrected_p) in pairwise_results.iter_mut().zip(corrected_p_values.iter()) {
        result.3.p_value = corrected_p;
        result.3.is_significant = corrected_p < alpha;
    }

    let significant_pairs_count = pairwise_results
        .iter()
        .filter(|(_, _, _, result)| result.is_significant)
        .count();

    Ok(ModelComparisonResult {
        model_names: model_names.to_vec(),
        pairwise_results,
        correction_method,
        alpha,
        n_comparisons: raw_p_values.len(),
        significant_pairs: significant_pairs_count,
    })
}

/// Multiple testing correction methods
#[derive(Debug, Clone)]
pub enum MultipleTestingCorrection {
    /// No correction applied
    None,
    /// Bonferroni correction (most conservative)
    Bonferroni,
    /// Benjamini-Hochberg correction (controls FDR)
    BenjaminiHochberg,
    /// Holm correction (step-down method)
    Holm,
}

/// Result of multiple model comparison
#[derive(Debug, Clone)]
pub struct ModelComparisonResult {
    /// Names of the compared models
    pub model_names: Vec<String>,
    /// (model i, model j, matrix index, test result) for each comparison
    pub pairwise_results: Vec<(usize, usize, usize, StatisticalTestResult)>,
    /// Correction method that was applied to the p-values
    pub correction_method: MultipleTestingCorrection,
    /// Significance level used for all tests
    pub alpha: f64,
    /// Total number of pairwise comparisons performed
    pub n_comparisons: usize,
    /// Number of comparisons that remained significant after correction
    pub significant_pairs: usize,
}

// Helper functions for statistical distributions
fn student_t_cdf(t: f64, df: f64) -> f64 {
    // Simplified approximation of Student's t CDF
    if df > 30.0 {
        standard_normal_cdf(t)
    } else {
        // Use a crude closed-form approximation for small df
        0.5 + 0.5 * (t / (1.0 + t * t / df).sqrt()).tanh()
    }
}

fn inverse_student_t(p: f64, df: f64) -> f64 {
    // Approximation of the inverse Student's t CDF
    if df > 30.0 {
        inverse_standard_normal(p)
    } else {
        // First-order expansion of the t quantile in powers of 1/df:
        // t ~ z + (z^3 + z) / (4 df)
        let z = inverse_standard_normal(p);
        z * (1.0 + (z * z + 1.0) / (4.0 * df))
    }
}

fn standard_normal_cdf(z: f64) -> f64 {
    // Standard normal CDF via the error function:
    // Phi(z) = (1 + erf(z / sqrt(2))) / 2
    0.5 * (1.0 + erf(z / 2.0_f64.sqrt()))
}

fn inverse_standard_normal(p: f64) -> f64 {
    // Inverse standard normal CDF via Acklam's rational approximation
    // (absolute error roughly 1e-9 over the full domain)
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }

    // Coefficients for the central region
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 5] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
    ];
    // Coefficients for the tails
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 4] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
    ];

    const P_LOW: f64 = 0.02425;

    if p < P_LOW {
        // Lower tail
        let q = (-2.0 * p.ln()).sqrt();
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    } else if p <= 1.0 - P_LOW {
        // Central region
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        // Upper tail (by symmetry with the lower tail)
        let q = (-2.0 * (1.0 - p).ln()).sqrt();
        -((((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0))
    }
}

fn chi_squared_cdf(chi2: f64, df: f64) -> f64 {
    // Simplified approximation of the chi-squared CDF
    if chi2 <= 0.0 {
        0.0
    } else if df == 1.0 {
        2.0 * standard_normal_cdf(chi2.sqrt()) - 1.0
    } else if df == 2.0 {
        1.0 - (-chi2 / 2.0).exp()
    } else {
        // Fisher's approximation: sqrt(2 chi2) ~ N(sqrt(2 df - 1), 1)
        let z = (2.0 * chi2).sqrt() - (2.0 * df - 1.0).sqrt();
        standard_normal_cdf(z)
    }
}

fn inverse_chi_squared(p: f64, df: f64) -> f64 {
    // Simplified approximation of the inverse chi-squared CDF
    if df == 1.0 {
        let z = inverse_standard_normal((p + 1.0) / 2.0);
        z * z
    } else if df == 2.0 {
        -2.0 * (1.0 - p).ln()
    } else {
        // Wilson-Hilferty transformation
        let h = 2.0 / (9.0 * df);
        let z = inverse_standard_normal(p);
        df * (1.0 - h + z * h.sqrt()).powi(3)
    }
}

fn erf(x: f64) -> f64 {
    // Winitzki's approximation of the error function (absolute error ~ 1e-4)
    let a = 0.147;
    let x2 = x * x;
    let ax2 = a * x2;
    let sign = if x >= 0.0 { 1.0 } else { -1.0 };

    sign * (1.0 - (-(x2) * (4.0 / std::f64::consts::PI + ax2) / (1.0 + ax2)).exp()).sqrt()
}

fn wilcoxon_critical_value(n: usize, alpha: f64) -> f64 {
    // Simplified critical values for the Wilcoxon signed-rank test;
    // a full implementation would use the complete lookup table
    match (n, alpha) {
        (6, a) if a <= 0.05 => 0.0,
        (7, a) if a <= 0.05 => 2.0,
        (8, a) if a <= 0.05 => 3.0,
        (9, a) if a <= 0.05 => 5.0,
        (10, a) if a <= 0.05 => 8.0,
        _ => (n * (n + 1) / 4) as f64 * 0.1, // Rough approximation
    }
}

fn nemenyi_critical_value(k: usize, alpha: f64) -> f64 {
    // Critical values q_alpha for the Nemenyi test at alpha = 0.05
    // (as tabulated in Demsar, 2006); a full implementation would use
    // the complete lookup table
    match (k, alpha) {
        (3, a) if a <= 0.05 => 2.343,
        (4, a) if a <= 0.05 => 2.569,
        (5, a) if a <= 0.05 => 2.728,
        (6, a) if a <= 0.05 => 2.850,
        _ => 2.5 + (k as f64 - 3.0) * 0.1, // Rough approximation
    }
}

// Multiple testing correction methods
fn bonferroni_correction(p_values: &[f64]) -> Vec<f64> {
    let n = p_values.len() as f64;
    p_values.iter().map(|&p| (p * n).min(1.0)).collect()
}

fn benjamini_hochberg_correction(p_values: &[f64], _alpha: f64) -> Vec<f64> {
    let n = p_values.len();
    let mut indexed_p: Vec<(usize, f64)> =
        p_values.iter().enumerate().map(|(i, &p)| (i, p)).collect();
    indexed_p.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

    // Step-up adjustment: scale the i-th smallest p-value by n / i, then
    // walk from the largest down, enforcing monotonicity of adjusted values
    let mut corrected = vec![0.0; n];
    let mut running_min = 1.0_f64;
    for (rank, &(original_idx, p_val)) in indexed_p.iter().enumerate().rev() {
        let adjusted = (p_val * n as f64 / (rank + 1) as f64).min(1.0);
        running_min = running_min.min(adjusted);
        corrected[original_idx] = running_min;
    }

    corrected
}

fn holm_correction(p_values: &[f64]) -> Vec<f64> {
    let n = p_values.len();
    let mut indexed_p: Vec<(usize, f64)> =
        p_values.iter().enumerate().map(|(i, &p)| (i, p)).collect();
    indexed_p.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

    // Step-down adjustment: scale the i-th smallest p-value by (n - i + 1),
    // then walk from the smallest up, enforcing monotonicity
    let mut corrected = vec![0.0; n];
    let mut running_max = 0.0_f64;
    for (rank, &(original_idx, p_val)) in indexed_p.iter().enumerate() {
        let adjusted = (p_val * (n - rank) as f64).min(1.0);
        running_max = running_max.max(adjusted);
        corrected[original_idx] = running_max;
    }

    corrected
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_paired_t_test() {
        let scores1 = array![0.9, 0.8, 0.85, 0.92, 0.88];
        let scores2 = array![0.85, 0.75, 0.80, 0.87, 0.83];

        let result = paired_t_test(&scores1, &scores2, 0.05).unwrap();
        assert_eq!(result.test_name, "Paired t-test");
        assert!(result.statistic > 0.0); // scores1 > scores2
        assert!(result.degrees_of_freedom.is_some());
        assert!(result.effect_size.is_some());
    }

    #[test]
    fn test_mcnemar_test() {
        let result = mcnemar_test(85, 10, 5, 0, 0.05, false).unwrap();
        assert_eq!(result.test_name, "McNemar's test");
        assert!(result.statistic > 0.0);
        assert!(result.degrees_of_freedom == Some(1.0));
    }

    #[test]
    fn test_wilcoxon_signed_rank() {
        let scores1 = array![0.9, 0.8, 0.85, 0.92, 0.88, 0.91, 0.87];
        let scores2 = array![0.85, 0.75, 0.80, 0.87, 0.83, 0.86, 0.82];

        let result = wilcoxon_signed_rank_test(&scores1, &scores2, 0.05).unwrap();
        assert_eq!(result.test_name, "Wilcoxon signed-rank test");
        assert!(result.statistic >= 0.0);
    }

    #[test]
    fn test_friedman_test() {
        let performance = array![
            [0.9, 0.85, 0.80],
            [0.88, 0.83, 0.78],
            [0.92, 0.87, 0.82],
            [0.89, 0.84, 0.79]
        ];

        let result = friedman_test(&performance, 0.05).unwrap();
        assert_eq!(result.test_name, "Friedman test");
        assert!(result.degrees_of_freedom == Some(2.0));
    }

    #[test]
    fn test_bonferroni_correction() {
        let p_values = vec![0.01, 0.02, 0.03, 0.04];
        let corrected = bonferroni_correction(&p_values);

        assert_eq!(corrected[0], 0.04);
        assert_eq!(corrected[1], 0.08);
        assert_eq!(corrected[2], 0.12);
        assert_eq!(corrected[3], 0.16);
    }
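
    // A couple of extra sanity checks (hand-computed expected values) for
    // the step-down Holm and step-up Benjamini-Hochberg adjustments
    #[test]
    fn test_holm_and_benjamini_hochberg_corrections() {
        let p_values = vec![0.01, 0.02, 0.03, 0.04];

        // Holm: 0.01 * 4 = 0.04, 0.02 * 3 = 0.06, then the monotonicity
        // pass keeps the remaining adjusted values at 0.06
        let holm = holm_correction(&p_values);
        assert!((holm[0] - 0.04).abs() < 1e-12);
        assert!((holm[1] - 0.06).abs() < 1e-12);
        assert!((holm[2] - 0.06).abs() < 1e-12);
        assert!((holm[3] - 0.06).abs() < 1e-12);

        // Benjamini-Hochberg: p_i * 4 / i = 0.04 for every entry here
        let bh = benjamini_hochberg_correction(&p_values, 0.05);
        for &p in &bh {
            assert!((p - 0.04).abs() < 1e-12);
        }
    }

    // Sanity check for the inverse normal approximation used throughout
    #[test]
    fn test_inverse_standard_normal() {
        assert!(inverse_standard_normal(0.5).abs() < 1e-9);
        assert!((inverse_standard_normal(0.975) - 1.959964).abs() < 1e-3);
    }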

    #[test]
    fn test_statistical_test_result() {
        let result = StatisticalTestResult::new("Test".to_string(), 2.5, 0.03, 0.05);

        assert!(result.is_significant);
        assert_eq!(result.p_value, 0.03);
        assert_eq!(result.statistic, 2.5);
    }
}
835}