Skip to main content

scirs2_stats/causal/
conditional_independence.rs

1//! Conditional Independence Tests for Causal Discovery
2//!
3//! This module provides conditional independence (CI) tests used by
4//! constraint-based causal discovery algorithms (PC, FCI).
5//!
6//! # Tests provided
7//!
8//! | Test | Data type | Reference |
9//! |------|-----------|-----------|
10//! | [`PartialCorrelationTest`] | Continuous (Gaussian) | Fisher (1924) |
11//! | [`GSquaredTest`] | Discrete / categorical | Agresti (2002) |
//! | [`KernelCITest`] | Continuous (nonparametric) | Zhang et al. (2011), simplified |
13//!
14//! # Common trait
15//!
16//! All tests implement [`ConditionalIndependenceTest`] with the method
17//! `test(x, y, z_set, data) -> (statistic, p_value)`.
18//!
19//! # References
20//!
21//! - Fisher, R.A. (1924). The distribution of the partial correlation
22//!   coefficient. *Metron* 3, 329-332.
23//! - Agresti, A. (2002). *Categorical Data Analysis* (2nd ed.). Wiley.
//! - Zhang, K., Peters, J., Janzing, D. & Schoelkopf, B. (2011).
25//!   Kernel-based conditional independence test and application in causal
26//!   discovery. *UAI 2011*.
27
28use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
29
30use crate::error::{StatsError, StatsResult};
31
32// ---------------------------------------------------------------------------
33// Trait
34// ---------------------------------------------------------------------------
35
/// Outcome of a single conditional independence test.
///
/// Every test in this module sets `reject` to `p_value <= alpha`, using the
/// significance level stored on the test itself.
#[derive(Clone, Debug)]
pub struct CITestResult {
    /// Value of the test statistic (its scale depends on the particular test).
    pub statistic: f64,
    /// p-value for the null hypothesis "X is independent of Y given Z".
    pub p_value: f64,
    /// `true` when the null hypothesis is rejected at the configured level.
    pub reject: bool,
}
46
47/// Common interface for all conditional independence tests.
48pub trait ConditionalIndependenceTest {
49    /// Test whether variable `x` is independent of variable `y` given `z_set`,
50    /// using column indices into `data` (rows = observations, cols = variables).
51    ///
52    /// Returns `(statistic, p_value)`.
53    fn test(
54        &self,
55        x: usize,
56        y: usize,
57        z_set: &[usize],
58        data: ArrayView2<f64>,
59    ) -> StatsResult<CITestResult>;
60
61    /// Convenience: test and return `true` if independent at level `alpha`.
62    fn is_independent(
63        &self,
64        x: usize,
65        y: usize,
66        z_set: &[usize],
67        data: ArrayView2<f64>,
68        alpha: f64,
69    ) -> StatsResult<bool> {
70        let result = self.test(x, y, z_set, data)?;
71        Ok(result.p_value > alpha)
72    }
73}
74
75// ---------------------------------------------------------------------------
76// 1. Partial Correlation Test (Fisher's z-transform)
77// ---------------------------------------------------------------------------
78
/// Partial correlation test using Fisher's z-transform.
///
/// For Gaussian data, under H0 (X independent of Y given Z) the Fisher
/// transform `z = atanh(r) = 0.5 * ln((1 + r) / (1 - r))` of the sample partial
/// correlation `r` is approximately `N(0, 1 / (n - |Z| - 3))`. The test
/// therefore compares `|z| * sqrt(n - |Z| - 3)` with a standard normal
/// (two-sided).
///
/// The partial correlation is the Pearson correlation of the OLS residuals of
/// X and Y regressed on Z, or the plain Pearson correlation when Z is empty.
#[derive(Debug, Clone)]
pub struct PartialCorrelationTest {
    /// Significance level used to set [`CITestResult::reject`] (default 0.05).
    pub alpha: f64,
}

impl Default for PartialCorrelationTest {
    /// Returns a test with `alpha = 0.05`.
    fn default() -> Self {
        Self { alpha: 0.05 }
    }
}
97
98impl PartialCorrelationTest {
99    /// Create a new test with the given significance level.
100    pub fn new(alpha: f64) -> Self {
101        Self { alpha }
102    }
103
104    /// Compute the partial correlation between x and y given z_set.
105    pub fn partial_correlation(
106        &self,
107        x: usize,
108        y: usize,
109        z_set: &[usize],
110        data: ArrayView2<f64>,
111    ) -> StatsResult<f64> {
112        if z_set.is_empty() {
113            return Ok(pearson_r(data, x, y));
114        }
115        // Use OLS residuals approach
116        let res_x = ols_residuals(data, x, z_set)?;
117        let res_y = ols_residuals(data, y, z_set)?;
118        Ok(pearson_r_arrays(res_x.view(), res_y.view()))
119    }
120}
121
122impl ConditionalIndependenceTest for PartialCorrelationTest {
123    fn test(
124        &self,
125        x: usize,
126        y: usize,
127        z_set: &[usize],
128        data: ArrayView2<f64>,
129    ) -> StatsResult<CITestResult> {
130        let n = data.nrows();
131        let k = z_set.len();
132
133        if n <= k + 3 {
134            return Err(StatsError::InvalidArgument(
135                "Not enough observations for partial correlation test".to_owned(),
136            ));
137        }
138
139        let rho = self.partial_correlation(x, y, z_set, data)?;
140
141        // Fisher's z-transform: z = 0.5 * ln((1+r)/(1-r))
142        // Under H0, z ~ N(0, 1/(n - k - 3))
143        let rho_clamped = rho.clamp(-0.9999, 0.9999);
144        let z = 0.5 * ((1.0 + rho_clamped) / (1.0 - rho_clamped)).ln();
145        let se = 1.0 / ((n as f64 - k as f64 - 3.0).max(1.0)).sqrt();
146        let statistic = (z / se).abs();
147
148        // Two-sided p-value from standard normal
149        let p_value = 2.0 * (1.0 - normal_cdf(statistic));
150
151        Ok(CITestResult {
152            statistic,
153            p_value,
154            reject: p_value <= self.alpha,
155        })
156    }
157}
158
159// ---------------------------------------------------------------------------
160// 2. G-Squared Test for Discrete Data
161// ---------------------------------------------------------------------------
162
/// G-squared (likelihood-ratio) conditional independence test for discrete data.
///
/// Tests X independent of Y given Z using
/// `G^2 = 2 * sum N_{xyz} * ln(N_{xyz} * N_{z} / (N_{xz} * N_{yz}))`,
/// which is asymptotically chi-squared distributed under H0.
///
/// The degrees of freedom are `(|X| - 1) * (|Y| - 1)` per observed
/// Z-configuration. Here `|X|` and `|Y|` are the numbers of distinct levels
/// over the whole sample.
///
/// Values are discretised before counting:
/// * with `n_bins == 0`, each value is rounded to the nearest integer;
/// * otherwise, every column is split into `n_bins` equal-width bins between
///   its minimum and its maximum.
#[derive(Debug, Clone)]
pub struct GSquaredTest {
    /// Significance level used to set [`CITestResult::reject`] (default 0.05).
    pub alpha: f64,
    /// Number of equal-width discretisation bins per column
    /// (0 = round to the nearest integer; the default).
    pub n_bins: usize,
}

impl Default for GSquaredTest {
    /// Returns a test with `alpha = 0.05` and integer rounding (`n_bins = 0`).
    fn default() -> Self {
        Self {
            alpha: 0.05,
            n_bins: 0,
        }
    }
}
185
186impl GSquaredTest {
187    /// Create a G-squared test with the given significance level and bin count.
188    pub fn new(alpha: f64, n_bins: usize) -> Self {
189        Self { alpha, n_bins }
190    }
191
192    /// Discretise continuous data into integer levels.
193    fn discretise(&self, data: ArrayView2<f64>) -> Array2<i64> {
194        let (n, p) = data.dim();
195        let mut result = Array2::<i64>::zeros((n, p));
196
197        if self.n_bins == 0 {
198            // Use raw rounding
199            for i in 0..n {
200                for j in 0..p {
201                    result[[i, j]] = data[[i, j]].round() as i64;
202                }
203            }
204        } else {
205            // Quantile-based binning per column
206            for j in 0..p {
207                let mut col_vals: Vec<f64> = (0..n).map(|i| data[[i, j]]).collect();
208                col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
209                let min_v = col_vals.first().copied().unwrap_or(0.0);
210                let max_v = col_vals.last().copied().unwrap_or(1.0);
211                let range = (max_v - min_v).max(f64::EPSILON);
212                for i in 0..n {
213                    let bin = ((data[[i, j]] - min_v) / range * self.n_bins as f64) as i64;
214                    result[[i, j]] = bin.min(self.n_bins as i64 - 1).max(0);
215                }
216            }
217        }
218        result
219    }
220}
221
222impl ConditionalIndependenceTest for GSquaredTest {
223    fn test(
224        &self,
225        x: usize,
226        y: usize,
227        z_set: &[usize],
228        data: ArrayView2<f64>,
229    ) -> StatsResult<CITestResult> {
230        let n = data.nrows();
231        let discrete = self.discretise(data);
232
233        // Collect unique levels for each variable
234        let x_levels = unique_levels(&discrete, x);
235        let y_levels = unique_levels(&discrete, y);
236
237        // Build Z-configurations
238        let z_configs = if z_set.is_empty() {
239            vec![vec![0i64]] // single dummy config
240        } else {
241            cartesian_z_configs(&discrete, z_set)
242        };
243
244        let mut g2 = 0.0_f64;
245        let mut df = 0_usize;
246
247        for z_config in &z_configs {
248            // Count observations matching this z-config
249            let z_mask: Vec<bool> = (0..n)
250                .map(|i| {
251                    if z_set.is_empty() {
252                        true
253                    } else {
254                        z_set
255                            .iter()
256                            .enumerate()
257                            .all(|(k, &zj)| discrete[[i, zj]] == z_config[k])
258                    }
259                })
260                .collect();
261
262            let n_z: f64 = z_mask.iter().filter(|&&b| b).count() as f64;
263            if n_z < 1.0 {
264                continue;
265            }
266
267            for &xv in &x_levels {
268                for &yv in &y_levels {
269                    let n_xyz = z_mask
270                        .iter()
271                        .enumerate()
272                        .filter(|&(i, &b)| b && discrete[[i, x]] == xv && discrete[[i, y]] == yv)
273                        .count() as f64;
274                    let n_xz = z_mask
275                        .iter()
276                        .enumerate()
277                        .filter(|&(i, &b)| b && discrete[[i, x]] == xv)
278                        .count() as f64;
279                    let n_yz = z_mask
280                        .iter()
281                        .enumerate()
282                        .filter(|&(i, &b)| b && discrete[[i, y]] == yv)
283                        .count() as f64;
284
285                    if n_xyz > 0.0 && n_xz > 0.0 && n_yz > 0.0 && n_z > 0.0 {
286                        g2 += n_xyz * (n_xyz * n_z / (n_xz * n_yz)).ln();
287                    }
288                }
289            }
290            df += (x_levels.len().saturating_sub(1)) * (y_levels.len().saturating_sub(1));
291        }
292        g2 *= 2.0;
293
294        if df == 0 {
295            return Ok(CITestResult {
296                statistic: 0.0,
297                p_value: 1.0,
298                reject: false,
299            });
300        }
301
302        // Chi-squared p-value approximation
303        let p_value = chi2_survival(g2, df as f64);
304
305        Ok(CITestResult {
306            statistic: g2,
307            p_value,
308            reject: p_value <= self.alpha,
309        })
310    }
311}
312
313// ---------------------------------------------------------------------------
314// 3. Kernel-based Conditional Independence Test (simplified)
315// ---------------------------------------------------------------------------
316
/// Simplified kernel-based conditional independence test.
///
/// Dependence is measured with the (biased) Hilbert-Schmidt Independence
/// Criterion (HSIC) between RBF (Gaussian) kernel matrices. Their bandwidths
/// are chosen from the data by the median heuristic. The p-value is estimated
/// by permuting one variable's kernel matrix.
///
/// Conditioning on Z works by first replacing X and Y with the residuals of
/// OLS regressions on Z. This removes only *linear* dependence on Z, so the
/// test is a simplification of the KCI test of Zhang et al., not the full
/// method.
#[derive(Debug, Clone)]
pub struct KernelCITest {
    /// Significance level used to set [`CITestResult::reject`] (default 0.05).
    pub alpha: f64,
    /// Number of permutations for p-value estimation (default 100).
    pub n_permutations: usize,
    /// Seed of the internal generator that draws the permutations (default 42).
    pub seed: u64,
}

impl Default for KernelCITest {
    /// Returns a test with `alpha = 0.05`, 100 permutations and seed 42.
    fn default() -> Self {
        Self {
            alpha: 0.05,
            n_permutations: 100,
            seed: 42,
        }
    }
}
343
344impl KernelCITest {
345    /// Create a kernel CI test with the given parameters.
346    pub fn new(alpha: f64, n_permutations: usize, seed: u64) -> Self {
347        Self {
348            alpha,
349            n_permutations,
350            seed,
351        }
352    }
353
354    /// Compute the RBF kernel matrix for a set of column indices.
355    fn kernel_matrix(&self, data: ArrayView2<f64>, cols: &[usize], bandwidth: f64) -> Array2<f64> {
356        let n = data.nrows();
357        let mut k = Array2::<f64>::zeros((n, n));
358        let bw2 = 2.0 * bandwidth * bandwidth;
359
360        for i in 0..n {
361            for j in i..n {
362                let mut dist2 = 0.0_f64;
363                for &c in cols {
364                    let d = data[[i, c]] - data[[j, c]];
365                    dist2 += d * d;
366                }
367                let val = (-dist2 / bw2.max(f64::EPSILON)).exp();
368                k[[i, j]] = val;
369                k[[j, i]] = val;
370            }
371        }
372        k
373    }
374
375    /// Compute median heuristic bandwidth for a set of columns.
376    fn median_bandwidth(&self, data: ArrayView2<f64>, cols: &[usize]) -> f64 {
377        let n = data.nrows();
378        let max_pairs = 500; // limit for speed
379        let step = if n * (n - 1) / 2 > max_pairs {
380            (n as f64 / (max_pairs as f64).sqrt()).ceil() as usize
381        } else {
382            1
383        };
384
385        let mut dists = Vec::new();
386        let mut i = 0;
387        while i < n {
388            let mut j = i + 1;
389            while j < n {
390                let mut d2 = 0.0_f64;
391                for &c in cols {
392                    let d = data[[i, c]] - data[[j, c]];
393                    d2 += d * d;
394                }
395                dists.push(d2.sqrt());
396                j += step;
397            }
398            i += step;
399        }
400
401        if dists.is_empty() {
402            return 1.0;
403        }
404        dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
405        let median = dists[dists.len() / 2];
406        median.max(0.01)
407    }
408
409    /// Centre a kernel matrix: Kc = (I - 1/n * 11') K (I - 1/n * 11')
410    fn centre_kernel(&self, k: &Array2<f64>) -> Array2<f64> {
411        let n = k.nrows();
412        let nf = n as f64;
413
414        // Row means, column means, grand mean
415        let row_means: Vec<f64> = (0..n)
416            .map(|i| (0..n).map(|j| k[[i, j]]).sum::<f64>() / nf)
417            .collect();
418        let grand_mean: f64 = row_means.iter().sum::<f64>() / nf;
419
420        let mut kc = Array2::<f64>::zeros((n, n));
421        for i in 0..n {
422            for j in 0..n {
423                kc[[i, j]] = k[[i, j]] - row_means[i] - row_means[j] + grand_mean;
424            }
425        }
426        kc
427    }
428
429    /// Compute HSIC statistic: HSIC = (1/n^2) * tr(Kx_c * Ky_c)
430    fn hsic(&self, kx: &Array2<f64>, ky: &Array2<f64>) -> f64 {
431        let n = kx.nrows();
432        let nf = n as f64;
433        let kx_c = self.centre_kernel(kx);
434        let ky_c = self.centre_kernel(ky);
435
436        let mut trace = 0.0_f64;
437        for i in 0..n {
438            for j in 0..n {
439                trace += kx_c[[i, j]] * ky_c[[j, i]];
440            }
441        }
442        trace / (nf * nf)
443    }
444}
445
446impl ConditionalIndependenceTest for KernelCITest {
447    fn test(
448        &self,
449        x: usize,
450        y: usize,
451        z_set: &[usize],
452        data: ArrayView2<f64>,
453    ) -> StatsResult<CITestResult> {
454        let n = data.nrows();
455        if n < 5 {
456            return Err(StatsError::InvalidArgument(
457                "Need at least 5 observations for kernel CI test".to_owned(),
458            ));
459        }
460
461        // If z_set is empty, compute unconditional HSIC
462        // If z_set is non-empty, compute HSIC on kernel-residuals
463
464        let x_cols = vec![x];
465        let y_cols = vec![y];
466
467        let bw_x = self.median_bandwidth(data, &x_cols);
468        let bw_y = self.median_bandwidth(data, &y_cols);
469
470        if z_set.is_empty() {
471            // Unconditional test
472            let kx = self.kernel_matrix(data, &x_cols, bw_x);
473            let ky = self.kernel_matrix(data, &y_cols, bw_y);
474            let observed_hsic = self.hsic(&kx, &ky);
475
476            // Permutation test
477            let mut count_ge = 0usize;
478            let mut lcg = self.seed;
479            for _ in 0..self.n_permutations {
480                // Permute rows of ky
481                let mut perm: Vec<usize> = (0..n).collect();
482                fisher_yates_shuffle(&mut perm, &mut lcg);
483                let mut ky_perm = Array2::<f64>::zeros((n, n));
484                for i in 0..n {
485                    for j in 0..n {
486                        ky_perm[[i, j]] = ky[[perm[i], perm[j]]];
487                    }
488                }
489                let perm_hsic = self.hsic(&kx, &ky_perm);
490                if perm_hsic >= observed_hsic {
491                    count_ge += 1;
492                }
493            }
494
495            let p_value = (count_ge as f64 + 1.0) / (self.n_permutations as f64 + 1.0);
496            Ok(CITestResult {
497                statistic: observed_hsic,
498                p_value,
499                reject: p_value <= self.alpha,
500            })
501        } else {
502            // Conditional test: residualise X and Y on Z, then test unconditionally
503            let res_x = ols_residuals(data, x, z_set)?;
504            let res_y = ols_residuals(data, y, z_set)?;
505
506            // Build residual data matrix
507            let mut res_data = Array2::<f64>::zeros((n, 2));
508            for i in 0..n {
509                res_data[[i, 0]] = res_x[i];
510                res_data[[i, 1]] = res_y[i];
511            }
512
513            let bw_rx = self.median_bandwidth(res_data.view(), &[0]);
514            let bw_ry = self.median_bandwidth(res_data.view(), &[1]);
515
516            let kx = self.kernel_matrix(res_data.view(), &[0], bw_rx);
517            let ky = self.kernel_matrix(res_data.view(), &[1], bw_ry);
518            let observed_hsic = self.hsic(&kx, &ky);
519
520            // Permutation test
521            let mut count_ge = 0usize;
522            let mut lcg = self.seed;
523            for _ in 0..self.n_permutations {
524                let mut perm: Vec<usize> = (0..n).collect();
525                fisher_yates_shuffle(&mut perm, &mut lcg);
526                let mut ky_perm = Array2::<f64>::zeros((n, n));
527                for i in 0..n {
528                    for j in 0..n {
529                        ky_perm[[i, j]] = ky[[perm[i], perm[j]]];
530                    }
531                }
532                let perm_hsic = self.hsic(&kx, &ky_perm);
533                if perm_hsic >= observed_hsic {
534                    count_ge += 1;
535                }
536            }
537
538            let p_value = (count_ge as f64 + 1.0) / (self.n_permutations as f64 + 1.0);
539            Ok(CITestResult {
540                statistic: observed_hsic,
541                p_value,
542                reject: p_value <= self.alpha,
543            })
544        }
545    }
546}
547
548// ---------------------------------------------------------------------------
549// Helper functions
550// ---------------------------------------------------------------------------
551
552/// Pearson correlation between two columns of a data matrix.
553fn pearson_r(data: ArrayView2<f64>, x: usize, y: usize) -> f64 {
554    let n = data.nrows() as f64;
555    let mx: f64 = data.column(x).iter().sum::<f64>() / n;
556    let my: f64 = data.column(y).iter().sum::<f64>() / n;
557    let mut cov = 0.0_f64;
558    let mut vx = 0.0_f64;
559    let mut vy = 0.0_f64;
560    for i in 0..data.nrows() {
561        let dx = data[[i, x]] - mx;
562        let dy = data[[i, y]] - my;
563        cov += dx * dy;
564        vx += dx * dx;
565        vy += dy * dy;
566    }
567    cov / (vx * vy).sqrt().max(f64::EPSILON)
568}
569
570/// Pearson correlation between two Array1 views.
571fn pearson_r_arrays(
572    a: scirs2_core::ndarray::ArrayView1<f64>,
573    b: scirs2_core::ndarray::ArrayView1<f64>,
574) -> f64 {
575    let n = a.len() as f64;
576    let ma = a.iter().sum::<f64>() / n;
577    let mb = b.iter().sum::<f64>() / n;
578    let mut cov = 0.0_f64;
579    let mut va = 0.0_f64;
580    let mut vb = 0.0_f64;
581    for (&ai, &bi) in a.iter().zip(b.iter()) {
582        let da = ai - ma;
583        let db = bi - mb;
584        cov += da * db;
585        va += da * da;
586        vb += db * db;
587    }
588    cov / (va * vb).sqrt().max(f64::EPSILON)
589}
590
591/// OLS residuals of target regressed on predictors.
592fn ols_residuals(
593    data: ArrayView2<f64>,
594    target: usize,
595    predictors: &[usize],
596) -> StatsResult<Array1<f64>> {
597    let n = data.nrows();
598    let p = predictors.len();
599    let mut design = Array2::<f64>::ones((n, p + 1));
600    for (j, &pred) in predictors.iter().enumerate() {
601        for i in 0..n {
602            design[[i, j + 1]] = data[[i, pred]];
603        }
604    }
605    let y: Array1<f64> = data.column(target).to_owned();
606    let coef = ols_solve(design.view(), y.view())?;
607    let mut residuals = y;
608    for i in 0..n {
609        let pred: f64 = (0..=p).map(|j| design[[i, j]] * coef[j]).sum();
610        residuals[i] -= pred;
611    }
612    Ok(residuals)
613}
614
615/// Solve normal equations with ridge regularisation.
616fn ols_solve(
617    x: ArrayView2<f64>,
618    y: scirs2_core::ndarray::ArrayView1<f64>,
619) -> StatsResult<Array1<f64>> {
620    let (n, p) = x.dim();
621    let mut xtx = Array2::<f64>::zeros((p, p));
622    let mut xty = Array1::<f64>::zeros(p);
623    for i in 0..n {
624        for j in 0..p {
625            xty[j] += x[[i, j]] * y[i];
626            for k in 0..p {
627                xtx[[j, k]] += x[[i, j]] * x[[i, k]];
628            }
629        }
630    }
631    for j in 0..p {
632        xtx[[j, j]] += 1e-8;
633    }
634    gauss_jordan_solve(xtx, xty)
635}
636
/// Solve the linear system `a * sol = b` by Gauss-Jordan elimination with
/// partial (row) pivoting. Both inputs are consumed and modified in place.
///
/// NOTE(review): assumes `a` is `n x n` with `n = b.len()`. The only caller in
/// view (`ols_solve`) builds a square `p x p` system; confirm this for new
/// callers.
///
/// # Errors
///
/// Returns [`StatsError::ComputationError`] when a pivot's magnitude is below
/// `1e-12`, i.e. the system is numerically singular.
fn gauss_jordan_solve(mut a: Array2<f64>, mut b: Array1<f64>) -> StatsResult<Array1<f64>> {
    let n = b.len();
    for col in 0..n {
        // Partial pivoting: pick the row in `col..n` whose entry in this
        // column has the largest magnitude. `max_by` keeps the last of equal
        // maxima, and NaN comparisons count as equal. The range is never
        // empty because `col < n`, so this error cannot fire in practice.
        let pivot_row = (col..n)
            .max_by(|&i, &j| {
                a[[i, col]]
                    .abs()
                    .partial_cmp(&a[[j, col]].abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .ok_or_else(|| StatsError::ComputationError("Singular matrix in CI test".to_owned()))?;
        // Swap the pivot row into position `col`, in both `a` and `b`.
        for k in 0..n {
            let tmp = a[[col, k]];
            a[[col, k]] = a[[pivot_row, k]];
            a[[pivot_row, k]] = tmp;
        }
        let tmp = b[col];
        b[col] = b[pivot_row];
        b[pivot_row] = tmp;

        // If even the largest available pivot is ~0, the system is singular.
        let pivot = a[[col, col]];
        if pivot.abs() < 1e-12 {
            return Err(StatsError::ComputationError(
                "Singular OLS system in CI test".to_owned(),
            ));
        }
        // Scale the pivot row so that a[col, col] becomes 1. Entries left of
        // `col` in this row are already zero from earlier eliminations.
        for k in col..n {
            a[[col, k]] /= pivot;
        }
        b[col] /= pivot;
        // Eliminate column `col` from every other row, above and below. This
        // is what distinguishes Gauss-Jordan from plain Gaussian elimination.
        for row in 0..n {
            if row != col {
                let factor = a[[row, col]];
                for k in col..n {
                    let av = a[[col, k]];
                    a[[row, k]] -= factor * av;
                }
                b[row] -= factor * b[col];
            }
        }
    }
    // `a` is now (numerically) the identity, so `b` holds the solution.
    Ok(b)
}
681
/// Cumulative distribution function of the standard normal distribution,
/// `Phi(x) = (1 + erf(x / sqrt(2))) / 2`.
fn normal_cdf(x: f64) -> f64 {
    let scaled = x / std::f64::consts::SQRT_2;
    (1.0 + erf(scaled)) * 0.5
}

/// Error function via the Abramowitz & Stegun rational approximation 7.1.26
/// (absolute error below roughly 1.5e-7).
///
/// Evaluated at `|x|` and mirrored for negative `x`, since erf is odd.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // Coefficients a5, a4, a3, a2, a1, in Horner evaluation order.
    const A: [f64; 5] = [
        1.061_405_429,
        -1.453_152_027,
        1.421_413_741,
        -0.284_496_736,
        0.254_829_592,
    ];

    let t = 1.0 / (1.0 + P * x.abs());
    let horner = A[1..].iter().fold(A[0], |acc, &coef| coef + t * acc);
    let poly = t * horner;
    let magnitude = 1.0 - poly * (-x * x).exp();
    if x >= 0.0 {
        magnitude
    } else {
        -magnitude
    }
}
700
701/// Unique integer levels in a column of a discrete matrix.
702fn unique_levels(data: &Array2<i64>, col: usize) -> Vec<i64> {
703    let mut levels: Vec<i64> = data.column(col).iter().copied().collect();
704    levels.sort();
705    levels.dedup();
706    levels
707}
708
709/// Build all unique Z-configurations observed in the data.
710fn cartesian_z_configs(data: &Array2<i64>, z_set: &[usize]) -> Vec<Vec<i64>> {
711    let n = data.nrows();
712    let mut configs = std::collections::HashSet::new();
713    for i in 0..n {
714        let config: Vec<i64> = z_set.iter().map(|&zj| data[[i, zj]]).collect();
715        configs.insert(config);
716    }
717    configs.into_iter().collect()
718}
719
/// Chi-squared survival function `P(X > x)` for `df` degrees of freedom.
///
/// Evaluated as the upper regularised incomplete gamma function
/// `Q(df / 2, x / 2)`. Returns 1.0 when `x <= 0` or `df <= 0`.
fn chi2_survival(x: f64, df: f64) -> f64 {
    if x <= 0.0 || df <= 0.0 {
        return 1.0;
    }
    upper_gamma_q(df / 2.0, x / 2.0)
}

/// Upper regularised incomplete gamma function `Q(a, x) = 1 - P(a, x)`.
///
/// Uses the power series for `P` when `x < a + 1`, and the continued fraction
/// for `Q` otherwise. Each converges quickly in its own region.
fn upper_gamma_q(a: f64, x: f64) -> f64 {
    if x < 0.0 {
        return 1.0;
    }
    if x < a + 1.0 {
        1.0 - lower_gamma_series(a, x)
    } else {
        upper_gamma_cf(a, x)
    }
}

/// Lower regularised incomplete gamma `P(a, x)` via its series expansion
/// `e^{-x} x^a / Gamma(a) * sum_{n>=0} x^n / (a (a+1) ... (a+n))`.
fn lower_gamma_series(a: f64, x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    let mut sum = 1.0 / a;
    let mut term = 1.0 / a;
    for n in 1..200 {
        term *= x / (a + n as f64);
        sum += term;
        if term.abs() < 1e-12 * sum.abs() {
            break;
        }
    }
    let log_prefix = a * x.ln() - x - lgamma(a);
    (log_prefix.exp() * sum).clamp(0.0, 1.0)
}

/// Upper regularised incomplete gamma `Q(a, x)` via its continued fraction,
/// evaluated with the modified Lentz algorithm (Numerical Recipes `gcf`).
///
/// Intended for `x >= a + 1`, where the fraction converges rapidly.
fn upper_gamma_cf(a: f64, x: f64) -> f64 {
    // Lentz's method replaces (near-)zero denominators with a tiny value.
    // The guard must look at the magnitude only, because the recurrences
    // legitimately go negative. `c` must start huge (1 / TINY), not tiny.
    const TINY: f64 = 1e-300;
    const TOL: f64 = 1e-12;
    const MAX_ITER: usize = 200;

    let mut b = x + 1.0 - a;
    let mut c = 1.0 / TINY;
    let mut d = 1.0 / b;
    let mut h = d;
    for i in 1..MAX_ITER {
        let i = i as f64;
        let an = -i * (i - a);
        b += 2.0;
        d = an * d + b;
        if d.abs() < TINY {
            d = TINY;
        }
        c = b + an / c;
        if c.abs() < TINY {
            c = TINY;
        }
        d = 1.0 / d;
        let delta = d * c;
        h *= delta;
        if (delta - 1.0).abs() < TOL {
            break;
        }
    }

    let log_prefix = a * x.ln() - x - lgamma(a);
    (log_prefix.exp() * h).clamp(0.0, 1.0)
}

/// Natural logarithm of the gamma function (Lanczos approximation, g = 7,
/// 9 coefficients). Arguments below 0.5 use the reflection formula.
fn lgamma(x: f64) -> f64 {
    if x < 0.5 {
        std::f64::consts::PI.ln() - (std::f64::consts::PI * x).sin().abs().ln() - lgamma(1.0 - x)
    } else {
        let z = x - 1.0;
        let t = z + 7.5;
        let coeffs = [
            0.999_999_999_999_809_9,
            676.520_368_121_885_1,
            -1_259.139_216_722_402_8,
            771.323_428_777_653_1,
            -176.615_029_162_140_6,
            12.507_343_278_686_905,
            -0.138_571_095_265_720_12,
            9.984_369_578_019_572e-6,
            1.505_632_735_149_312e-7,
        ];
        let mut x_part = coeffs[0];
        for (i, &c) in coeffs[1..].iter().enumerate() {
            x_part += c / (z + 1.0 + i as f64);
        }
        0.5 * (2.0 * std::f64::consts::PI).ln() + (z + 0.5) * t.ln() - t + x_part.ln()
    }
}
812
/// In-place Fisher-Yates shuffle driven by a 64-bit linear congruential generator.
///
/// `lcg` is the generator state. It is advanced once per swap (`len - 1` times
/// for a slice of length `len >= 1`), so successive calls continue the same
/// stream. Indices are drawn from the high bits of the state (`>> 33`).
fn fisher_yates_shuffle(perm: &mut [usize], lcg: &mut u64) {
    const MULTIPLIER: u64 = 6364136223846793005;
    let mut state = *lcg;
    let mut i = perm.len();
    while i > 1 {
        i -= 1;
        state = state.wrapping_mul(MULTIPLIER).wrapping_add(1);
        let j = (state >> 33) as usize % (i + 1);
        perm.swap(i, j);
    }
    *lcg = state;
}
822
823// ---------------------------------------------------------------------------
824// Tests
825// ---------------------------------------------------------------------------
826
827#[cfg(test)]
828mod tests {
829    use super::*;
830    use scirs2_core::ndarray::Array2;
831
832    /// Simple LCG to produce a uniform in (0,1).
833    fn lcg_uniform(s: &mut u64) -> f64 {
834        *s = s
835            .wrapping_mul(6364136223846793005)
836            .wrapping_add(1442695040888963407);
837        ((*s >> 11) as f64) / ((1u64 << 53) as f64)
838    }
839
840    /// Box-Muller normal using two LCG draws.
841    fn lcg_normal(s: &mut u64) -> f64 {
842        let u1 = lcg_uniform(s).max(1e-15);
843        let u2 = lcg_uniform(s);
844        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
845    }
846
847    /// Generate chain data X -> Y -> Z with known structure.
848    fn chain_data(n: usize) -> Array2<f64> {
849        let mut data = Array2::<f64>::zeros((n, 3));
850        let mut lcg: u64 = 12345;
851        for i in 0..n {
852            data[[i, 0]] = lcg_normal(&mut lcg);
853            data[[i, 1]] = 0.9 * data[[i, 0]] + lcg_normal(&mut lcg) * 0.3;
854            data[[i, 2]] = 0.9 * data[[i, 1]] + lcg_normal(&mut lcg) * 0.3;
855        }
856        data
857    }
858
859    /// Generate independent data.
860    fn independent_data(n: usize) -> Array2<f64> {
861        let mut data = Array2::<f64>::zeros((n, 3));
862        let mut lcg: u64 = 54321;
863        for i in 0..n {
864            data[[i, 0]] = lcg_normal(&mut lcg);
865            data[[i, 1]] = lcg_normal(&mut lcg);
866            data[[i, 2]] = lcg_normal(&mut lcg);
867        }
868        data
869    }
870
871    #[test]
872    fn test_partial_corr_dependent() {
873        let data = chain_data(200);
874        let test = PartialCorrelationTest::new(0.05);
875        let result = test.test(0, 1, &[], data.view()).expect("test failed");
876        // X and Y are strongly dependent
877        assert!(
878            result.p_value < 0.05,
879            "Expected dependent: p={}",
880            result.p_value
881        );
882    }
883
884    #[test]
885    fn test_partial_corr_conditional_independence() {
886        let data = chain_data(200);
887        let test = PartialCorrelationTest::new(0.05);
888        // X and Z should be conditionally independent given Y
889        let result = test.test(0, 2, &[1], data.view()).expect("test failed");
890        assert!(
891            result.p_value > 0.01,
892            "Expected CI given Y: p={}",
893            result.p_value
894        );
895    }
896
897    #[test]
898    fn test_partial_corr_independent_pair() {
899        let data = independent_data(200);
900        let test = PartialCorrelationTest::new(0.05);
901        let result = test.test(0, 1, &[], data.view()).expect("test failed");
902        assert!(
903            result.p_value > 0.05,
904            "Expected independent: p={}",
905            result.p_value
906        );
907    }
908
909    #[test]
910    fn test_partial_corr_value() {
911        let data = chain_data(200);
912        let test = PartialCorrelationTest::default();
913        let rho = test
914            .partial_correlation(0, 1, &[], data.view())
915            .expect("failed");
916        // Strong positive correlation expected
917        assert!(rho > 0.5, "Expected strong correlation: rho={rho}");
918    }
919
920    #[test]
921    fn test_partial_corr_is_independent() {
922        let data = independent_data(200);
923        let test = PartialCorrelationTest::new(0.05);
924        let indep = test
925            .is_independent(0, 2, &[], data.view(), 0.05)
926            .expect("failed");
927        assert!(indep, "Expected independent pair to pass");
928    }
929
930    #[test]
931    fn test_gsquared_dependent() {
932        // Generate strongly dependent discrete data
933        let n = 200;
934        let mut data = Array2::<f64>::zeros((n, 2));
935        for i in 0..n {
936            let x = (i % 3) as f64;
937            data[[i, 0]] = x;
938            data[[i, 1]] = x; // perfectly dependent
939        }
940        let test = GSquaredTest::new(0.05, 0);
941        let result = test.test(0, 1, &[], data.view()).expect("test failed");
942        assert!(
943            result.p_value < 0.05,
944            "Expected dependent: p={}",
945            result.p_value
946        );
947    }
948
949    #[test]
950    fn test_gsquared_independent() {
951        // Generate independent discrete data (cycling independently)
952        let n = 300;
953        let mut data = Array2::<f64>::zeros((n, 2));
954        let mut lcg: u64 = 99999;
955        for i in 0..n {
956            lcg = lcg.wrapping_mul(6364136223846793005).wrapping_add(1);
957            data[[i, 0]] = (i % 3) as f64;
958            data[[i, 1]] = ((lcg >> 33) % 3) as f64;
959        }
960        let test = GSquaredTest::new(0.05, 0);
961        let result = test.test(0, 1, &[], data.view()).expect("test failed");
962        // With enough data, independent variables should not reject
963        assert!(
964            result.p_value > 0.01,
965            "Expected independent: p={}",
966            result.p_value
967        );
968    }
969
970    #[test]
971    fn test_gsquared_conditional() {
972        // X -> Z -> Y (chain), test X _||_ Y | Z
973        let n = 300;
974        let mut data = Array2::<f64>::zeros((n, 3));
975        for i in 0..n {
976            let x = (i % 3) as f64;
977            let z = x; // Z = X
978            let y = z; // Y = Z
979            data[[i, 0]] = x;
980            data[[i, 1]] = y;
981            data[[i, 2]] = z;
982        }
983        let test = GSquaredTest::new(0.05, 0);
984        // Unconditionally X and Y are dependent
985        let r1 = test.test(0, 1, &[], data.view()).expect("test failed");
986        assert!(r1.p_value < 0.05, "Expected dependent: p={}", r1.p_value);
987    }
988
989    #[test]
990    fn test_kernel_ci_dependent() {
991        let data = chain_data(100);
992        let test = KernelCITest::new(0.05, 200, 42);
993        let result = test.test(0, 1, &[], data.view()).expect("test failed");
994        assert!(
995            result.p_value < 0.1,
996            "Expected dependent: p={}",
997            result.p_value
998        );
999    }
1000
1001    #[test]
1002    fn test_kernel_ci_independent() {
1003        let data = independent_data(80);
1004        let test = KernelCITest::new(0.05, 500, 12345);
1005        let result = test.test(0, 1, &[], data.view()).expect("test failed");
1006        // Permutation test may give small p for some seeds; just check it's valid
1007        assert!(
1008            result.p_value >= 0.0 && result.p_value <= 1.0,
1009            "p-value should be in [0,1]: p={}",
1010            result.p_value
1011        );
1012        assert!(result.statistic.is_finite());
1013    }
1014
1015    #[test]
1016    fn test_kernel_ci_conditional() {
1017        let data = chain_data(80);
1018        let test = KernelCITest::new(0.05, 200, 42);
1019        // Test X _||_ Z | Y
1020        let result = test.test(0, 2, &[1], data.view()).expect("test failed");
1021        // After conditioning on Y, X and Z should be more independent
1022        assert!(
1023            result.statistic.is_finite(),
1024            "HSIC statistic should be finite"
1025        );
1026        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
1027    }
1028
1029    #[test]
1030    fn test_ci_result_fields() {
1031        let data = chain_data(100);
1032        let test = PartialCorrelationTest::new(0.05);
1033        let result = test.test(0, 1, &[], data.view()).expect("test failed");
1034        assert!(result.statistic.is_finite());
1035        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
1036        // reject should match p_value vs alpha
1037        assert_eq!(result.reject, result.p_value <= 0.05);
1038    }
1039}