1use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
29
30use crate::error::{StatsError, StatsResult};
31
/// Outcome of one conditional-independence test.
#[derive(Clone, Debug)]
pub struct CITestResult {
    /// Value of the test statistic (its meaning depends on the test that
    /// produced it, e.g. |z| for Fisher's z, G² or HSIC).
    pub statistic: f64,
    /// p-value associated with `statistic` under the independence hypothesis.
    pub p_value: f64,
    /// Whether the producing test rejected independence at its own
    /// significance level.
    pub reject: bool,
}
46
47pub trait ConditionalIndependenceTest {
49 fn test(
54 &self,
55 x: usize,
56 y: usize,
57 z_set: &[usize],
58 data: ArrayView2<f64>,
59 ) -> StatsResult<CITestResult>;
60
61 fn is_independent(
63 &self,
64 x: usize,
65 y: usize,
66 z_set: &[usize],
67 data: ArrayView2<f64>,
68 alpha: f64,
69 ) -> StatsResult<bool> {
70 let result = self.test(x, y, z_set, data)?;
71 Ok(result.p_value > alpha)
72 }
73}
74
/// Conditional-independence test based on the partial correlation of two
/// columns, assessed through Fisher's z-transform.
#[derive(Clone, Debug)]
pub struct PartialCorrelationTest {
    /// Significance level; independence is rejected when `p_value <= alpha`.
    pub alpha: f64,
}

impl Default for PartialCorrelationTest {
    /// Uses a significance level of 0.05.
    fn default() -> Self {
        PartialCorrelationTest { alpha: 0.05 }
    }
}
97
98impl PartialCorrelationTest {
99 pub fn new(alpha: f64) -> Self {
101 Self { alpha }
102 }
103
104 pub fn partial_correlation(
106 &self,
107 x: usize,
108 y: usize,
109 z_set: &[usize],
110 data: ArrayView2<f64>,
111 ) -> StatsResult<f64> {
112 if z_set.is_empty() {
113 return Ok(pearson_r(data, x, y));
114 }
115 let res_x = ols_residuals(data, x, z_set)?;
117 let res_y = ols_residuals(data, y, z_set)?;
118 Ok(pearson_r_arrays(res_x.view(), res_y.view()))
119 }
120}
121
122impl ConditionalIndependenceTest for PartialCorrelationTest {
123 fn test(
124 &self,
125 x: usize,
126 y: usize,
127 z_set: &[usize],
128 data: ArrayView2<f64>,
129 ) -> StatsResult<CITestResult> {
130 let n = data.nrows();
131 let k = z_set.len();
132
133 if n <= k + 3 {
134 return Err(StatsError::InvalidArgument(
135 "Not enough observations for partial correlation test".to_owned(),
136 ));
137 }
138
139 let rho = self.partial_correlation(x, y, z_set, data)?;
140
141 let rho_clamped = rho.clamp(-0.9999, 0.9999);
144 let z = 0.5 * ((1.0 + rho_clamped) / (1.0 - rho_clamped)).ln();
145 let se = 1.0 / ((n as f64 - k as f64 - 3.0).max(1.0)).sqrt();
146 let statistic = (z / se).abs();
147
148 let p_value = 2.0 * (1.0 - normal_cdf(statistic));
150
151 Ok(CITestResult {
152 statistic,
153 p_value,
154 reject: p_value <= self.alpha,
155 })
156 }
157}
158
/// G² (likelihood-ratio) conditional-independence test on discretised data.
#[derive(Clone, Debug)]
pub struct GSquaredTest {
    /// Significance level; independence is rejected when `p_value <= alpha`.
    pub alpha: f64,
    /// Number of equal-width bins per column. `0` means the data are taken
    /// as already discrete and each value is rounded to the nearest integer.
    pub n_bins: usize,
}

impl Default for GSquaredTest {
    /// Significance level 0.05, no binning (`n_bins = 0`).
    fn default() -> Self {
        GSquaredTest {
            alpha: 0.05,
            n_bins: 0,
        }
    }
}
185
186impl GSquaredTest {
187 pub fn new(alpha: f64, n_bins: usize) -> Self {
189 Self { alpha, n_bins }
190 }
191
192 fn discretise(&self, data: ArrayView2<f64>) -> Array2<i64> {
194 let (n, p) = data.dim();
195 let mut result = Array2::<i64>::zeros((n, p));
196
197 if self.n_bins == 0 {
198 for i in 0..n {
200 for j in 0..p {
201 result[[i, j]] = data[[i, j]].round() as i64;
202 }
203 }
204 } else {
205 for j in 0..p {
207 let mut col_vals: Vec<f64> = (0..n).map(|i| data[[i, j]]).collect();
208 col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
209 let min_v = col_vals.first().copied().unwrap_or(0.0);
210 let max_v = col_vals.last().copied().unwrap_or(1.0);
211 let range = (max_v - min_v).max(f64::EPSILON);
212 for i in 0..n {
213 let bin = ((data[[i, j]] - min_v) / range * self.n_bins as f64) as i64;
214 result[[i, j]] = bin.min(self.n_bins as i64 - 1).max(0);
215 }
216 }
217 }
218 result
219 }
220}
221
222impl ConditionalIndependenceTest for GSquaredTest {
223 fn test(
224 &self,
225 x: usize,
226 y: usize,
227 z_set: &[usize],
228 data: ArrayView2<f64>,
229 ) -> StatsResult<CITestResult> {
230 let n = data.nrows();
231 let discrete = self.discretise(data);
232
233 let x_levels = unique_levels(&discrete, x);
235 let y_levels = unique_levels(&discrete, y);
236
237 let z_configs = if z_set.is_empty() {
239 vec![vec![0i64]] } else {
241 cartesian_z_configs(&discrete, z_set)
242 };
243
244 let mut g2 = 0.0_f64;
245 let mut df = 0_usize;
246
247 for z_config in &z_configs {
248 let z_mask: Vec<bool> = (0..n)
250 .map(|i| {
251 if z_set.is_empty() {
252 true
253 } else {
254 z_set
255 .iter()
256 .enumerate()
257 .all(|(k, &zj)| discrete[[i, zj]] == z_config[k])
258 }
259 })
260 .collect();
261
262 let n_z: f64 = z_mask.iter().filter(|&&b| b).count() as f64;
263 if n_z < 1.0 {
264 continue;
265 }
266
267 for &xv in &x_levels {
268 for &yv in &y_levels {
269 let n_xyz = z_mask
270 .iter()
271 .enumerate()
272 .filter(|&(i, &b)| b && discrete[[i, x]] == xv && discrete[[i, y]] == yv)
273 .count() as f64;
274 let n_xz = z_mask
275 .iter()
276 .enumerate()
277 .filter(|&(i, &b)| b && discrete[[i, x]] == xv)
278 .count() as f64;
279 let n_yz = z_mask
280 .iter()
281 .enumerate()
282 .filter(|&(i, &b)| b && discrete[[i, y]] == yv)
283 .count() as f64;
284
285 if n_xyz > 0.0 && n_xz > 0.0 && n_yz > 0.0 && n_z > 0.0 {
286 g2 += n_xyz * (n_xyz * n_z / (n_xz * n_yz)).ln();
287 }
288 }
289 }
290 df += (x_levels.len().saturating_sub(1)) * (y_levels.len().saturating_sub(1));
291 }
292 g2 *= 2.0;
293
294 if df == 0 {
295 return Ok(CITestResult {
296 statistic: 0.0,
297 p_value: 1.0,
298 reject: false,
299 });
300 }
301
302 let p_value = chi2_survival(g2, df as f64);
304
305 Ok(CITestResult {
306 statistic: g2,
307 p_value,
308 reject: p_value <= self.alpha,
309 })
310 }
311}
312
/// Kernel (HSIC) independence test with a permutation-based p-value.
#[derive(Clone, Debug)]
pub struct KernelCITest {
    /// Significance level; independence is rejected when `p_value <= alpha`.
    pub alpha: f64,
    /// Number of random permutations used to build the null distribution.
    pub n_permutations: usize,
    /// Initial state of the internal pseudo-random generator.
    pub seed: u64,
}

impl Default for KernelCITest {
    /// Significance level 0.05, 100 permutations, seed 42.
    fn default() -> Self {
        KernelCITest {
            alpha: 0.05,
            n_permutations: 100,
            seed: 42,
        }
    }
}
343
344impl KernelCITest {
345 pub fn new(alpha: f64, n_permutations: usize, seed: u64) -> Self {
347 Self {
348 alpha,
349 n_permutations,
350 seed,
351 }
352 }
353
354 fn kernel_matrix(&self, data: ArrayView2<f64>, cols: &[usize], bandwidth: f64) -> Array2<f64> {
356 let n = data.nrows();
357 let mut k = Array2::<f64>::zeros((n, n));
358 let bw2 = 2.0 * bandwidth * bandwidth;
359
360 for i in 0..n {
361 for j in i..n {
362 let mut dist2 = 0.0_f64;
363 for &c in cols {
364 let d = data[[i, c]] - data[[j, c]];
365 dist2 += d * d;
366 }
367 let val = (-dist2 / bw2.max(f64::EPSILON)).exp();
368 k[[i, j]] = val;
369 k[[j, i]] = val;
370 }
371 }
372 k
373 }
374
375 fn median_bandwidth(&self, data: ArrayView2<f64>, cols: &[usize]) -> f64 {
377 let n = data.nrows();
378 let max_pairs = 500; let step = if n * (n - 1) / 2 > max_pairs {
380 (n as f64 / (max_pairs as f64).sqrt()).ceil() as usize
381 } else {
382 1
383 };
384
385 let mut dists = Vec::new();
386 let mut i = 0;
387 while i < n {
388 let mut j = i + 1;
389 while j < n {
390 let mut d2 = 0.0_f64;
391 for &c in cols {
392 let d = data[[i, c]] - data[[j, c]];
393 d2 += d * d;
394 }
395 dists.push(d2.sqrt());
396 j += step;
397 }
398 i += step;
399 }
400
401 if dists.is_empty() {
402 return 1.0;
403 }
404 dists.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
405 let median = dists[dists.len() / 2];
406 median.max(0.01)
407 }
408
409 fn centre_kernel(&self, k: &Array2<f64>) -> Array2<f64> {
411 let n = k.nrows();
412 let nf = n as f64;
413
414 let row_means: Vec<f64> = (0..n)
416 .map(|i| (0..n).map(|j| k[[i, j]]).sum::<f64>() / nf)
417 .collect();
418 let grand_mean: f64 = row_means.iter().sum::<f64>() / nf;
419
420 let mut kc = Array2::<f64>::zeros((n, n));
421 for i in 0..n {
422 for j in 0..n {
423 kc[[i, j]] = k[[i, j]] - row_means[i] - row_means[j] + grand_mean;
424 }
425 }
426 kc
427 }
428
429 fn hsic(&self, kx: &Array2<f64>, ky: &Array2<f64>) -> f64 {
431 let n = kx.nrows();
432 let nf = n as f64;
433 let kx_c = self.centre_kernel(kx);
434 let ky_c = self.centre_kernel(ky);
435
436 let mut trace = 0.0_f64;
437 for i in 0..n {
438 for j in 0..n {
439 trace += kx_c[[i, j]] * ky_c[[j, i]];
440 }
441 }
442 trace / (nf * nf)
443 }
444}
445
446impl ConditionalIndependenceTest for KernelCITest {
447 fn test(
448 &self,
449 x: usize,
450 y: usize,
451 z_set: &[usize],
452 data: ArrayView2<f64>,
453 ) -> StatsResult<CITestResult> {
454 let n = data.nrows();
455 if n < 5 {
456 return Err(StatsError::InvalidArgument(
457 "Need at least 5 observations for kernel CI test".to_owned(),
458 ));
459 }
460
461 let x_cols = vec![x];
465 let y_cols = vec![y];
466
467 let bw_x = self.median_bandwidth(data, &x_cols);
468 let bw_y = self.median_bandwidth(data, &y_cols);
469
470 if z_set.is_empty() {
471 let kx = self.kernel_matrix(data, &x_cols, bw_x);
473 let ky = self.kernel_matrix(data, &y_cols, bw_y);
474 let observed_hsic = self.hsic(&kx, &ky);
475
476 let mut count_ge = 0usize;
478 let mut lcg = self.seed;
479 for _ in 0..self.n_permutations {
480 let mut perm: Vec<usize> = (0..n).collect();
482 fisher_yates_shuffle(&mut perm, &mut lcg);
483 let mut ky_perm = Array2::<f64>::zeros((n, n));
484 for i in 0..n {
485 for j in 0..n {
486 ky_perm[[i, j]] = ky[[perm[i], perm[j]]];
487 }
488 }
489 let perm_hsic = self.hsic(&kx, &ky_perm);
490 if perm_hsic >= observed_hsic {
491 count_ge += 1;
492 }
493 }
494
495 let p_value = (count_ge as f64 + 1.0) / (self.n_permutations as f64 + 1.0);
496 Ok(CITestResult {
497 statistic: observed_hsic,
498 p_value,
499 reject: p_value <= self.alpha,
500 })
501 } else {
502 let res_x = ols_residuals(data, x, z_set)?;
504 let res_y = ols_residuals(data, y, z_set)?;
505
506 let mut res_data = Array2::<f64>::zeros((n, 2));
508 for i in 0..n {
509 res_data[[i, 0]] = res_x[i];
510 res_data[[i, 1]] = res_y[i];
511 }
512
513 let bw_rx = self.median_bandwidth(res_data.view(), &[0]);
514 let bw_ry = self.median_bandwidth(res_data.view(), &[1]);
515
516 let kx = self.kernel_matrix(res_data.view(), &[0], bw_rx);
517 let ky = self.kernel_matrix(res_data.view(), &[1], bw_ry);
518 let observed_hsic = self.hsic(&kx, &ky);
519
520 let mut count_ge = 0usize;
522 let mut lcg = self.seed;
523 for _ in 0..self.n_permutations {
524 let mut perm: Vec<usize> = (0..n).collect();
525 fisher_yates_shuffle(&mut perm, &mut lcg);
526 let mut ky_perm = Array2::<f64>::zeros((n, n));
527 for i in 0..n {
528 for j in 0..n {
529 ky_perm[[i, j]] = ky[[perm[i], perm[j]]];
530 }
531 }
532 let perm_hsic = self.hsic(&kx, &ky_perm);
533 if perm_hsic >= observed_hsic {
534 count_ge += 1;
535 }
536 }
537
538 let p_value = (count_ge as f64 + 1.0) / (self.n_permutations as f64 + 1.0);
539 Ok(CITestResult {
540 statistic: observed_hsic,
541 p_value,
542 reject: p_value <= self.alpha,
543 })
544 }
545 }
546}
547
548fn pearson_r(data: ArrayView2<f64>, x: usize, y: usize) -> f64 {
554 let n = data.nrows() as f64;
555 let mx: f64 = data.column(x).iter().sum::<f64>() / n;
556 let my: f64 = data.column(y).iter().sum::<f64>() / n;
557 let mut cov = 0.0_f64;
558 let mut vx = 0.0_f64;
559 let mut vy = 0.0_f64;
560 for i in 0..data.nrows() {
561 let dx = data[[i, x]] - mx;
562 let dy = data[[i, y]] - my;
563 cov += dx * dy;
564 vx += dx * dx;
565 vy += dy * dy;
566 }
567 cov / (vx * vy).sqrt().max(f64::EPSILON)
568}
569
570fn pearson_r_arrays(
572 a: scirs2_core::ndarray::ArrayView1<f64>,
573 b: scirs2_core::ndarray::ArrayView1<f64>,
574) -> f64 {
575 let n = a.len() as f64;
576 let ma = a.iter().sum::<f64>() / n;
577 let mb = b.iter().sum::<f64>() / n;
578 let mut cov = 0.0_f64;
579 let mut va = 0.0_f64;
580 let mut vb = 0.0_f64;
581 for (&ai, &bi) in a.iter().zip(b.iter()) {
582 let da = ai - ma;
583 let db = bi - mb;
584 cov += da * db;
585 va += da * da;
586 vb += db * db;
587 }
588 cov / (va * vb).sqrt().max(f64::EPSILON)
589}
590
591fn ols_residuals(
593 data: ArrayView2<f64>,
594 target: usize,
595 predictors: &[usize],
596) -> StatsResult<Array1<f64>> {
597 let n = data.nrows();
598 let p = predictors.len();
599 let mut design = Array2::<f64>::ones((n, p + 1));
600 for (j, &pred) in predictors.iter().enumerate() {
601 for i in 0..n {
602 design[[i, j + 1]] = data[[i, pred]];
603 }
604 }
605 let y: Array1<f64> = data.column(target).to_owned();
606 let coef = ols_solve(design.view(), y.view())?;
607 let mut residuals = y;
608 for i in 0..n {
609 let pred: f64 = (0..=p).map(|j| design[[i, j]] * coef[j]).sum();
610 residuals[i] -= pred;
611 }
612 Ok(residuals)
613}
614
615fn ols_solve(
617 x: ArrayView2<f64>,
618 y: scirs2_core::ndarray::ArrayView1<f64>,
619) -> StatsResult<Array1<f64>> {
620 let (n, p) = x.dim();
621 let mut xtx = Array2::<f64>::zeros((p, p));
622 let mut xty = Array1::<f64>::zeros(p);
623 for i in 0..n {
624 for j in 0..p {
625 xty[j] += x[[i, j]] * y[i];
626 for k in 0..p {
627 xtx[[j, k]] += x[[i, j]] * x[[i, k]];
628 }
629 }
630 }
631 for j in 0..p {
632 xtx[[j, j]] += 1e-8;
633 }
634 gauss_jordan_solve(xtx, xty)
635}
636
637fn gauss_jordan_solve(mut a: Array2<f64>, mut b: Array1<f64>) -> StatsResult<Array1<f64>> {
639 let n = b.len();
640 for col in 0..n {
641 let pivot_row = (col..n)
642 .max_by(|&i, &j| {
643 a[[i, col]]
644 .abs()
645 .partial_cmp(&a[[j, col]].abs())
646 .unwrap_or(std::cmp::Ordering::Equal)
647 })
648 .ok_or_else(|| StatsError::ComputationError("Singular matrix in CI test".to_owned()))?;
649 for k in 0..n {
650 let tmp = a[[col, k]];
651 a[[col, k]] = a[[pivot_row, k]];
652 a[[pivot_row, k]] = tmp;
653 }
654 let tmp = b[col];
655 b[col] = b[pivot_row];
656 b[pivot_row] = tmp;
657
658 let pivot = a[[col, col]];
659 if pivot.abs() < 1e-12 {
660 return Err(StatsError::ComputationError(
661 "Singular OLS system in CI test".to_owned(),
662 ));
663 }
664 for k in col..n {
665 a[[col, k]] /= pivot;
666 }
667 b[col] /= pivot;
668 for row in 0..n {
669 if row != col {
670 let factor = a[[row, col]];
671 for k in col..n {
672 let av = a[[col, k]];
673 a[[row, k]] -= factor * av;
674 }
675 b[row] -= factor * b[col];
676 }
677 }
678 }
679 Ok(b)
680}
681
682fn normal_cdf(x: f64) -> f64 {
684 0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
685}
686
/// Polynomial approximation of the error function.
///
/// Evaluates `1 - poly(t) * exp(-x²)` with `t = 1 / (1 + 0.3275911 |x|)` and
/// a degree-5 polynomial, then applies the sign of `x`.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const COEFFS: [f64; 5] = [
        0.254_829_592,
        -0.284_496_736,
        1.421_413_741,
        -1.453_152_027,
        1.061_405_429,
    ];

    let t = 1.0 / (1.0 + P * x.abs());
    // Horner evaluation from the highest coefficient down.
    let poly = t * COEFFS.iter().rev().fold(0.0_f64, |acc, &c| c + t * acc);
    let magnitude = 1.0 - poly * (-x * x).exp();

    if x >= 0.0 {
        magnitude
    } else {
        -magnitude
    }
}
700
701fn unique_levels(data: &Array2<i64>, col: usize) -> Vec<i64> {
703 let mut levels: Vec<i64> = data.column(col).iter().copied().collect();
704 levels.sort();
705 levels.dedup();
706 levels
707}
708
709fn cartesian_z_configs(data: &Array2<i64>, z_set: &[usize]) -> Vec<Vec<i64>> {
711 let n = data.nrows();
712 let mut configs = std::collections::HashSet::new();
713 for i in 0..n {
714 let config: Vec<i64> = z_set.iter().map(|&zj| data[[i, zj]]).collect();
715 configs.insert(config);
716 }
717 configs.into_iter().collect()
718}
719
720fn chi2_survival(x: f64, df: f64) -> f64 {
723 if x <= 0.0 || df <= 0.0 {
724 return 1.0;
725 }
726 upper_gamma_q(df / 2.0, x / 2.0)
729}
730
731fn upper_gamma_q(a: f64, x: f64) -> f64 {
733 if x < 0.0 {
734 return 1.0;
735 }
736 if x < a + 1.0 {
737 1.0 - lower_gamma_series(a, x)
739 } else {
740 upper_gamma_cf(a, x)
742 }
743}
744
745fn lower_gamma_series(a: f64, x: f64) -> f64 {
747 if x <= 0.0 {
748 return 0.0;
749 }
750 let mut sum = 1.0 / a;
751 let mut term = 1.0 / a;
752 for n in 1..200 {
753 term *= x / (a + n as f64);
754 sum += term;
755 if term.abs() < 1e-12 * sum.abs() {
756 break;
757 }
758 }
759 let log_prefix = a * x.ln() - x - lgamma(a);
760 (log_prefix.exp() * sum).clamp(0.0, 1.0)
761}
762
763fn upper_gamma_cf(a: f64, x: f64) -> f64 {
765 let mut f = 1e-30_f64;
767 let mut c = 1e-30_f64;
768 let mut d = 1.0 / (x + 1.0 - a);
769 f = d;
770
771 for i in 1..200 {
772 let an = (a - i as f64) * i as f64;
773 let bn = x + 2.0 * i as f64 + 1.0 - a;
774 d = 1.0 / (bn + an * d).max(1e-30);
775 c = (bn + an / c).max(1e-30);
776 let delta = c * d;
777 f *= delta;
778 if (delta - 1.0).abs() < 1e-10 {
779 break;
780 }
781 }
782
783 let log_prefix = a * x.ln() - x - lgamma(a);
784 (log_prefix.exp() * f).clamp(0.0, 1.0)
785}
786
/// Natural logarithm of the gamma function.
///
/// For `x >= 0.5` a nine-term series with offset 7.5 is evaluated; smaller
/// arguments go through the reflection formula
/// `ln Γ(x) = ln π - ln|sin(πx)| - ln Γ(1 - x)`.
fn lgamma(x: f64) -> f64 {
    const COEFFS: [f64; 9] = [
        0.999_999_999_999_809_9,
        676.520_368_121_885_1,
        -1_259.139_216_722_402_8,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.138_571_095_265_720_12,
        9.984_369_578_019_572e-6,
        1.505_632_735_149_312e-7,
    ];

    if x < 0.5 {
        let pi = std::f64::consts::PI;
        return pi.ln() - (pi * x).sin().abs().ln() - lgamma(1.0 - x);
    }

    let z = x - 1.0;
    let t = z + 7.5;
    let series = COEFFS[1..]
        .iter()
        .enumerate()
        .fold(COEFFS[0], |acc, (i, &c)| acc + c / (z + 1.0 + i as f64));

    0.5 * (2.0 * std::f64::consts::PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
812
/// Shuffles `perm` in place with the Fisher-Yates algorithm.
///
/// Randomness comes from a 64-bit linear congruential generator whose state
/// is `lcg`; the state is advanced once per swap, so consecutive calls
/// continue the same sequence.
fn fisher_yates_shuffle(perm: &mut [usize], lcg: &mut u64) {
    const MULTIPLIER: u64 = 6364136223846793005;

    let mut upper = perm.len();
    while upper > 1 {
        upper -= 1;
        *lcg = lcg.wrapping_mul(MULTIPLIER).wrapping_add(1);
        // The high bits of the state pick a partner in 0..=upper.
        let pick = (*lcg >> 33) as usize % (upper + 1);
        perm.swap(upper, pick);
    }
}
822
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Uniform sample in [0, 1) from a 64-bit LCG; `s` is the generator state.
    fn lcg_uniform(s: &mut u64) -> f64 {
        *s = s
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        ((*s >> 11) as f64) / ((1u64 << 53) as f64)
    }

    /// Standard normal sample via the Box-Muller transform.
    fn lcg_normal(s: &mut u64) -> f64 {
        // Floor u1 so that ln(u1) stays finite.
        let u1 = lcg_uniform(s).max(1e-15);
        let u2 = lcg_uniform(s);
        (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
    }

    /// Linear chain: column 0 drives column 1, which drives column 2.
    fn chain_data(n: usize) -> Array2<f64> {
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg: u64 = 12345;
        for i in 0..n {
            data[[i, 0]] = lcg_normal(&mut lcg);
            data[[i, 1]] = 0.9 * data[[i, 0]] + lcg_normal(&mut lcg) * 0.3;
            data[[i, 2]] = 0.9 * data[[i, 1]] + lcg_normal(&mut lcg) * 0.3;
        }
        data
    }

    /// Three columns drawn independently from the same generator.
    fn independent_data(n: usize) -> Array2<f64> {
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg: u64 = 54321;
        for i in 0..n {
            data[[i, 0]] = lcg_normal(&mut lcg);
            data[[i, 1]] = lcg_normal(&mut lcg);
            data[[i, 2]] = lcg_normal(&mut lcg);
        }
        data
    }

    #[test]
    fn test_partial_corr_dependent() {
        let data = chain_data(200);
        let test = PartialCorrelationTest::new(0.05);
        let result = test.test(0, 1, &[], data.view()).expect("test failed");
        assert!(
            result.p_value < 0.05,
            "Expected dependent: p={}",
            result.p_value
        );
    }

    #[test]
    fn test_partial_corr_conditional_independence() {
        let data = chain_data(200);
        let test = PartialCorrelationTest::new(0.05);
        // In the chain, columns 0 and 2 are linked only through column 1.
        let result = test.test(0, 2, &[1], data.view()).expect("test failed");
        assert!(
            result.p_value > 0.01,
            "Expected CI given Y: p={}",
            result.p_value
        );
    }

    #[test]
    fn test_partial_corr_independent_pair() {
        let data = independent_data(200);
        let test = PartialCorrelationTest::new(0.05);
        let result = test.test(0, 1, &[], data.view()).expect("test failed");
        assert!(
            result.p_value > 0.05,
            "Expected independent: p={}",
            result.p_value
        );
    }

    #[test]
    fn test_partial_corr_value() {
        let data = chain_data(200);
        let test = PartialCorrelationTest::default();
        let rho = test
            .partial_correlation(0, 1, &[], data.view())
            .expect("failed");
        assert!(rho > 0.5, "Expected strong correlation: rho={rho}");
    }

    #[test]
    fn test_partial_corr_is_independent() {
        let data = independent_data(200);
        let test = PartialCorrelationTest::new(0.05);
        let indep = test
            .is_independent(0, 2, &[], data.view(), 0.05)
            .expect("failed");
        assert!(indep, "Expected independent pair to pass");
    }

    #[test]
    fn test_gsquared_dependent() {
        let n = 200;
        // Column 1 is an exact copy of column 0.
        let mut data = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let x = (i % 3) as f64;
            data[[i, 0]] = x;
            data[[i, 1]] = x;
        }
        let test = GSquaredTest::new(0.05, 0);
        let result = test.test(0, 1, &[], data.view()).expect("test failed");
        assert!(
            result.p_value < 0.05,
            "Expected dependent: p={}",
            result.p_value
        );
    }

    #[test]
    fn test_gsquared_independent() {
        let n = 300;
        // Column 0 cycles through 0, 1, 2; column 1 is pseudo-random.
        let mut data = Array2::<f64>::zeros((n, 2));
        let mut lcg: u64 = 99999;
        for i in 0..n {
            lcg = lcg.wrapping_mul(6364136223846793005).wrapping_add(1);
            data[[i, 0]] = (i % 3) as f64;
            data[[i, 1]] = ((lcg >> 33) % 3) as f64;
        }
        let test = GSquaredTest::new(0.05, 0);
        let result = test.test(0, 1, &[], data.view()).expect("test failed");
        assert!(
            result.p_value > 0.01,
            "Expected independent: p={}",
            result.p_value
        );
    }

    #[test]
    fn test_gsquared_conditional() {
        let n = 300;
        // x, y and z are identical three-level columns.
        let mut data = Array2::<f64>::zeros((n, 3));
        for i in 0..n {
            let x = (i % 3) as f64;
            let z = x;
            let y = z;
            data[[i, 0]] = x;
            data[[i, 1]] = y;
            data[[i, 2]] = z;
        }
        let test = GSquaredTest::new(0.05, 0);

        // Marginally, x and y are perfectly dependent.
        let r1 = test.test(0, 1, &[], data.view()).expect("test failed");
        assert!(r1.p_value < 0.05, "Expected dependent: p={}", r1.p_value);

        // Within each z stratum x and y are constant, so every G² term is
        // ln(1) = 0 and the conditional test must not reject.
        let r2 = test.test(0, 1, &[2], data.view()).expect("test failed");
        assert!(r2.p_value > 0.05, "Expected CI given Z: p={}", r2.p_value);
    }

    #[test]
    fn test_kernel_ci_dependent() {
        let data = chain_data(100);
        let test = KernelCITest::new(0.05, 200, 42);
        let result = test.test(0, 1, &[], data.view()).expect("test failed");
        assert!(
            result.p_value < 0.1,
            "Expected dependent: p={}",
            result.p_value
        );
    }

    #[test]
    fn test_kernel_ci_independent() {
        let data = independent_data(80);
        let test = KernelCITest::new(0.05, 500, 12345);
        let result = test.test(0, 1, &[], data.view()).expect("test failed");
        // Sanity check only: a valid p-value and a finite statistic.
        assert!(
            result.p_value >= 0.0 && result.p_value <= 1.0,
            "p-value should be in [0,1]: p={}",
            result.p_value
        );
        assert!(result.statistic.is_finite());
    }

    #[test]
    fn test_kernel_ci_conditional() {
        let data = chain_data(80);
        let test = KernelCITest::new(0.05, 200, 42);
        let result = test.test(0, 2, &[1], data.view()).expect("test failed");
        // Sanity check only: a finite statistic and a valid p-value.
        assert!(
            result.statistic.is_finite(),
            "HSIC statistic should be finite"
        );
        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
    }

    #[test]
    fn test_ci_result_fields() {
        let data = chain_data(100);
        let test = PartialCorrelationTest::new(0.05);
        let result = test.test(0, 1, &[], data.view()).expect("test failed");
        assert!(result.statistic.is_finite());
        assert!(result.p_value >= 0.0 && result.p_value <= 1.0);
        // `reject` must agree with the configured significance level.
        assert_eq!(result.reject, result.p_value <= 0.05);
    }
}