Skip to main content

scirs2_series/causality/
pc.rs

1//! PC Algorithm for Causal Discovery
2//!
3//! Implements the classic PC (Peter-Clark) algorithm for discovering causal
4//! structure from observational data. Unlike [`super::pc_stable`] which is
5//! adapted for time series with lagged variables, this module implements the
6//! standard cross-sectional PC algorithm.
7//!
8//! ## Algorithm Phases
9//!
10//! 1. **Skeleton** (Phase 1): Start with a complete undirected graph. For each
11//!    pair (X, Y), test conditional independence X _||_ Y | S for increasing
12//!    sizes of conditioning set S drawn from the adjacency set. Remove the edge
13//!    if independent.
14//!
15//! 2. **V-structures** (Phase 2): For each unshielded triple X - Z - Y (i.e.
16//!    X and Y not adjacent, both adjacent to Z), orient as X -> Z <- Y if Z
17//!    was *not* in the separation set of (X, Y).
18//!
19//! 3. **Meek rules** (Phase 3): Apply Meek's orientation propagation rules
20//!    to infer additional edge directions without creating new v-structures
21//!    or directed cycles.
22//!
23//! ## Usage
24//!
25//! ```rust,no_run
26//! use scirs2_series::causality::pc::{PCAlgorithm, PCConfig, IndependenceTest};
27//!
28//! let data: Vec<Vec<f64>> = vec![
29//!     vec![1.0, 2.0, 3.0],
30//!     vec![2.0, 4.1, 5.9],
31//!     // ... more observations
32//! ];
33//! let config = PCConfig::default();
34//! let pc = PCAlgorithm::new(config);
35//! let graph = pc.discover(&data).expect("discovery failed");
36//! println!("Edges: {:?}", graph.edges);
37//! ```
38
39use super::CausalityResult;
40use crate::error::TimeSeriesError;
41
42use std::collections::{HashMap, HashSet};
43
44// ---------------------------------------------------------------------------
45// Public types
46// ---------------------------------------------------------------------------
47
/// Configuration for the PC algorithm.
///
/// Marked `#[non_exhaustive]`: outside this crate, start from
/// `PCConfig::default()` and overwrite the fields you need.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct PCConfig {
    /// Significance level for conditional independence tests.
    ///
    /// During skeleton search an edge is removed as soon as one test returns a
    /// p-value strictly greater than this level. Default: `0.05`.
    pub significance_level: f64,
    /// Maximum conditioning set size to consider.
    /// The algorithm stops increasing the conditioning set size once it exceeds
    /// this value or when no edge is testable. Default: `4`.
    pub max_cond_set_size: usize,
    /// Type of independence test to use.
    /// Default: `IndependenceTest::PartialCorrelation`.
    pub test_type: IndependenceTest,
}
61
62impl Default for PCConfig {
63    fn default() -> Self {
64        Self {
65            significance_level: 0.05,
66            max_cond_set_size: 4,
67            test_type: IndependenceTest::PartialCorrelation,
68        }
69    }
70}
71
/// Type of conditional independence test.
///
/// Every variant is evaluated from the sample covariance matrix only; the
/// raw observations are not available to the test.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum IndependenceTest {
    /// Partial correlation test with Fisher's z-transform.
    PartialCorrelation,
    /// Mutual information test (Gaussian approximation).
    ///
    /// Uses `MI = -0.5 * ln(1 - r^2)` of the partial correlation `r` and a
    /// chi-squared(1) reference distribution for the statistic `2 * n * MI`.
    MutualInformation,
    /// Kernel-based independence test (HSIC-like, simplified).
    ///
    /// NOTE: not yet a kernel test. `PCAlgorithm` currently falls back to the
    /// partial-correlation (Fisher z) test for this variant, because a real
    /// HSIC test needs the raw data rather than the covariance matrix.
    KernelBased,
}
83
/// Type of an edge in the causal graph.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeType {
    /// Directed edge (arrow from `from` to `to`).
    Directed,
    /// Undirected edge: neither v-structure detection nor Meek propagation
    /// fixed its orientation.
    Undirected,
    /// Bidirected edge (latent common cause).
    Bidirected,
}
95
/// A single edge in the causal graph.
///
/// Only for `EdgeType::Directed` does the order of `from`/`to` carry meaning
/// (the arrow points from `from` to `to`). Undirected edges produced by
/// `PCAlgorithm::discover` are reported with `from < to`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CausalEdge {
    /// Source node index (column index in the input samples).
    pub from: usize,
    /// Target node index (column index in the input samples).
    pub to: usize,
    /// Type of edge.
    pub edge_type: EdgeType,
}
106
/// Result of causal discovery: the estimated causal graph.
#[derive(Debug, Clone)]
pub struct CausalGraph {
    /// Number of nodes (variables), i.e. the length of each input sample.
    pub nodes: usize,
    /// Discovered edges between pairs that stayed adjacent after skeleton search.
    pub edges: Vec<CausalEdge>,
    /// Separation sets for pairs that were separated.
    /// Key: (i, j) with i < j; Value: conditioning set that separated them.
    /// Pairs that remained adjacent have no entry.
    pub separation_sets: HashMap<(usize, usize), Vec<usize>>,
}
118
119impl CausalGraph {
120    /// Check if there is a directed edge from `from` to `to`.
121    pub fn has_directed_edge(&self, from: usize, to: usize) -> bool {
122        self.edges
123            .iter()
124            .any(|e| e.from == from && e.to == to && e.edge_type == EdgeType::Directed)
125    }
126
127    /// Check if there is any edge (directed or undirected) between `a` and `b`.
128    pub fn has_edge(&self, a: usize, b: usize) -> bool {
129        self.edges
130            .iter()
131            .any(|e| (e.from == a && e.to == b) || (e.from == b && e.to == a))
132    }
133
134    /// Count edges of a given type.
135    pub fn count_edges(&self, edge_type: EdgeType) -> usize {
136        self.edges
137            .iter()
138            .filter(|e| e.edge_type == edge_type)
139            .count()
140    }
141}
142
143// ---------------------------------------------------------------------------
144// PC Algorithm
145// ---------------------------------------------------------------------------
146
/// The PC algorithm for causal discovery from observational data.
///
/// Holds only its configuration; `discover` takes `&self` and keeps all
/// per-run state local, so one instance can be reused for many datasets.
#[derive(Debug, Clone)]
pub struct PCAlgorithm {
    /// Significance level, conditioning-set limit and test type used by every run.
    config: PCConfig,
}
152
153impl PCAlgorithm {
154    /// Create a new PC algorithm instance.
155    pub fn new(config: PCConfig) -> Self {
156        Self { config }
157    }
158
159    /// Run causal discovery on cross-sectional data.
160    ///
161    /// # Arguments
162    /// * `data` - Observations as a vector of samples, each sample is a vector
163    ///   of variable values. All samples must have the same length.
164    ///
165    /// # Returns
166    /// A [`CausalGraph`] with the discovered causal structure.
167    pub fn discover(&self, data: &[Vec<f64>]) -> CausalityResult<CausalGraph> {
168        let n_samples = data.len();
169        if n_samples < 4 {
170            return Err(TimeSeriesError::InsufficientData {
171                message: "Need at least 4 samples for PC algorithm".to_string(),
172                required: 4,
173                actual: n_samples,
174            });
175        }
176
177        let n_vars = data[0].len();
178        if n_vars < 2 {
179            return Err(TimeSeriesError::InvalidInput(
180                "Need at least 2 variables for causal discovery".to_string(),
181            ));
182        }
183
184        // Validate all samples have same length
185        for (i, sample) in data.iter().enumerate() {
186            if sample.len() != n_vars {
187                return Err(TimeSeriesError::DimensionMismatch {
188                    expected: n_vars,
189                    actual: sample.len(),
190                });
191            }
192            // Check for NaN / Inf
193            for &v in sample {
194                if !v.is_finite() {
195                    return Err(TimeSeriesError::InvalidInput(format!(
196                        "Non-finite value in sample {}",
197                        i
198                    )));
199                }
200            }
201        }
202
203        // Compute correlation/covariance matrix
204        let cov_matrix = compute_covariance_matrix(data)?;
205
206        // Phase 1: Skeleton discovery
207        let (adjacency, separation_sets) =
208            self.discover_skeleton(n_vars, n_samples, &cov_matrix)?;
209
210        // Phase 2: Orient v-structures
211        let mut edge_types = self.orient_v_structures(n_vars, &adjacency, &separation_sets);
212
213        // Phase 3: Meek rules
214        self.apply_meek_rules(n_vars, &adjacency, &mut edge_types);
215
216        // Build the causal graph
217        let mut edges = Vec::new();
218        for i in 0..n_vars {
219            for j in (i + 1)..n_vars {
220                if adjacency[i].contains(&j) {
221                    let key = (i, j);
222                    let et = edge_types
223                        .get(&key)
224                        .copied()
225                        .unwrap_or(EdgeType::Undirected);
226                    match et {
227                        EdgeType::Directed => {
228                            // Check direction: was it i->j or j->i?
229                            // We store the directed edge based on what was determined
230                            if let Some(&dir) = edge_types.get(&(i, j)) {
231                                if dir == EdgeType::Directed {
232                                    edges.push(CausalEdge {
233                                        from: i,
234                                        to: j,
235                                        edge_type: EdgeType::Directed,
236                                    });
237                                }
238                            }
239                        }
240                        _ => {
241                            edges.push(CausalEdge {
242                                from: i,
243                                to: j,
244                                edge_type: et,
245                            });
246                        }
247                    }
248                }
249            }
250        }
251
252        // Also add reverse directed edges that were stored with (j, i) key
253        for (&(from, to), &et) in &edge_types {
254            if et == EdgeType::Directed && from > to {
255                // This is a j->i edge stored as (j, i) = Directed
256                edges.push(CausalEdge {
257                    from,
258                    to,
259                    edge_type: EdgeType::Directed,
260                });
261            }
262        }
263
264        // Deduplicate
265        let mut seen = HashSet::new();
266        let deduped: Vec<CausalEdge> = edges
267            .into_iter()
268            .filter(|e| {
269                let key = (e.from, e.to, e.edge_type);
270                seen.insert(key)
271            })
272            .collect();
273
274        Ok(CausalGraph {
275            nodes: n_vars,
276            edges: deduped,
277            separation_sets,
278        })
279    }
280
281    // --- Phase 1: Skeleton ---
282
283    fn discover_skeleton(
284        &self,
285        n_vars: usize,
286        n_samples: usize,
287        cov_matrix: &[Vec<f64>],
288    ) -> CausalityResult<(Vec<HashSet<usize>>, HashMap<(usize, usize), Vec<usize>>)> {
289        // Start with complete undirected graph
290        let mut adjacency: Vec<HashSet<usize>> = (0..n_vars)
291            .map(|i| (0..n_vars).filter(|&j| j != i).collect())
292            .collect();
293
294        let mut separation_sets: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
295
296        let mut p = 0usize;
297        loop {
298            if p > self.config.max_cond_set_size {
299                break;
300            }
301
302            let mut any_testable = false;
303            let mut removals: Vec<(usize, usize, Vec<usize>)> = Vec::new();
304
305            // Snapshot adjacency for stable iteration
306            let adj_snapshot: Vec<Vec<usize>> = adjacency
307                .iter()
308                .map(|s| {
309                    let mut v: Vec<usize> = s.iter().copied().collect();
310                    v.sort();
311                    v
312                })
313                .collect();
314
315            for i in 0..n_vars {
316                let neighbors_i = &adj_snapshot[i];
317                for &j in neighbors_i {
318                    if j <= i {
319                        continue; // Only test each pair once
320                    }
321
322                    // Conditioning set candidates: neighbors of i, excluding j
323                    let cond_candidates: Vec<usize> =
324                        neighbors_i.iter().copied().filter(|&k| k != j).collect();
325
326                    if cond_candidates.len() < p {
327                        continue;
328                    }
329                    any_testable = true;
330
331                    // Test all subsets of size p
332                    let subsets = gen_combinations(&cond_candidates, p);
333                    let mut found_independent = false;
334                    let mut best_sep = Vec::new();
335
336                    for subset in &subsets {
337                        let p_value = self
338                            .test_conditional_independence(i, j, subset, n_samples, cov_matrix)?;
339
340                        if p_value > self.config.significance_level {
341                            found_independent = true;
342                            best_sep = subset.clone();
343                            break;
344                        }
345                    }
346
347                    if found_independent {
348                        removals.push((i, j, best_sep));
349                    }
350                }
351            }
352
353            // Apply removals
354            for (i, j, sep_set) in removals {
355                adjacency[i].remove(&j);
356                adjacency[j].remove(&i);
357                let key = if i < j { (i, j) } else { (j, i) };
358                separation_sets.insert(key, sep_set);
359            }
360
361            if !any_testable {
362                break;
363            }
364            p += 1;
365        }
366
367        Ok((adjacency, separation_sets))
368    }
369
370    // --- Phase 2: V-structure orientation ---
371
372    fn orient_v_structures(
373        &self,
374        n_vars: usize,
375        adjacency: &[HashSet<usize>],
376        separation_sets: &HashMap<(usize, usize), Vec<usize>>,
377    ) -> HashMap<(usize, usize), EdgeType> {
378        let mut edge_types: HashMap<(usize, usize), EdgeType> = HashMap::new();
379
380        // Initialize all remaining edges as undirected
381        for i in 0..n_vars {
382            for &j in &adjacency[i] {
383                if j > i {
384                    edge_types.insert((i, j), EdgeType::Undirected);
385                }
386            }
387        }
388
389        // For each unshielded triple X - Z - Y (X and Y not adjacent)
390        for z in 0..n_vars {
391            let neighbors_z: Vec<usize> = adjacency[z].iter().copied().collect();
392            for idx_x in 0..neighbors_z.len() {
393                for idx_y in (idx_x + 1)..neighbors_z.len() {
394                    let x = neighbors_z[idx_x];
395                    let y = neighbors_z[idx_y];
396
397                    // Check if X and Y are NOT adjacent (unshielded)
398                    if adjacency[x].contains(&y) {
399                        continue;
400                    }
401
402                    // Look up separation set of (X, Y)
403                    let key = if x < y { (x, y) } else { (y, x) };
404                    let sep_set = separation_sets.get(&key);
405
406                    // If Z is NOT in the separation set, orient as X -> Z <- Y
407                    let z_in_sep = sep_set.map(|s| s.contains(&z)).unwrap_or(false);
408
409                    if !z_in_sep {
410                        // Orient X -> Z
411                        edge_types.insert((x, z), EdgeType::Directed);
412                        // Orient Y -> Z
413                        edge_types.insert((y, z), EdgeType::Directed);
414                        // Remove the undirected versions
415                        let k1 = if x < z { (x, z) } else { (z, x) };
416                        let k2 = if y < z { (y, z) } else { (z, y) };
417                        edge_types.remove(&k1);
418                        edge_types.remove(&k2);
419                        edge_types.insert((x, z), EdgeType::Directed);
420                        edge_types.insert((y, z), EdgeType::Directed);
421                    }
422                }
423            }
424        }
425
426        edge_types
427    }
428
429    // --- Phase 3: Meek rules ---
430
431    fn apply_meek_rules(
432        &self,
433        n_vars: usize,
434        adjacency: &[HashSet<usize>],
435        edge_types: &mut HashMap<(usize, usize), EdgeType>,
436    ) {
437        // Apply Meek's four rules iteratively until no more orientations change
438        let max_iterations = n_vars * n_vars;
439        for _ in 0..max_iterations {
440            let mut changed = false;
441
442            // Rule 1: If A -> B - C and A and C are not adjacent, orient B -> C
443            for b in 0..n_vars {
444                let neighbors_b: Vec<usize> = adjacency[b].iter().copied().collect();
445                for &c in &neighbors_b {
446                    // Check if B - C is undirected
447                    if !is_undirected(edge_types, b, c) {
448                        continue;
449                    }
450
451                    for &a in &neighbors_b {
452                        if a == c {
453                            continue;
454                        }
455                        // Check if A -> B (directed)
456                        if !is_directed(edge_types, a, b) {
457                            continue;
458                        }
459                        // Check if A and C are not adjacent
460                        if adjacency[a].contains(&c) {
461                            continue;
462                        }
463
464                        // Orient B -> C
465                        orient_edge(edge_types, b, c);
466                        changed = true;
467                    }
468                }
469            }
470
471            // Rule 2: If A -> C -> B and A - B, orient A -> B
472            for a in 0..n_vars {
473                let neighbors_a: Vec<usize> = adjacency[a].iter().copied().collect();
474                for &b in &neighbors_a {
475                    if !is_undirected(edge_types, a, b) {
476                        continue;
477                    }
478
479                    // Look for C such that A -> C and C -> B
480                    for &c in &neighbors_a {
481                        if c == b {
482                            continue;
483                        }
484                        if !is_directed(edge_types, a, c) {
485                            continue;
486                        }
487                        if !adjacency[c].contains(&b) {
488                            continue;
489                        }
490                        if !is_directed(edge_types, c, b) {
491                            continue;
492                        }
493
494                        orient_edge(edge_types, a, b);
495                        changed = true;
496                    }
497                }
498            }
499
500            // Rule 3: If A - C, A - D, C -> B, D -> B, and C and D are not adjacent,
501            // orient A -> B
502            for a in 0..n_vars {
503                let neighbors_a: Vec<usize> = adjacency[a].iter().copied().collect();
504                for &b in &neighbors_a {
505                    if !is_undirected(edge_types, a, b) {
506                        continue;
507                    }
508
509                    // Find two distinct C, D both adjacent to A (undirected) and both -> B
510                    let mut oriented = false;
511                    for idx_c in 0..neighbors_a.len() {
512                        if oriented {
513                            break;
514                        }
515                        let c = neighbors_a[idx_c];
516                        if c == b {
517                            continue;
518                        }
519                        if !is_undirected(edge_types, a, c) {
520                            continue;
521                        }
522                        if !adjacency[c].contains(&b) || !is_directed(edge_types, c, b) {
523                            continue;
524                        }
525
526                        for idx_d in (idx_c + 1)..neighbors_a.len() {
527                            let d = neighbors_a[idx_d];
528                            if d == b || d == c {
529                                continue;
530                            }
531                            if !is_undirected(edge_types, a, d) {
532                                continue;
533                            }
534                            if !adjacency[d].contains(&b) || !is_directed(edge_types, d, b) {
535                                continue;
536                            }
537                            // C and D must not be adjacent
538                            if adjacency[c].contains(&d) {
539                                continue;
540                            }
541
542                            orient_edge(edge_types, a, b);
543                            changed = true;
544                            oriented = true;
545                            break;
546                        }
547                    }
548                }
549            }
550
551            if !changed {
552                break;
553            }
554        }
555    }
556
557    // --- Conditional independence tests ---
558
559    fn test_conditional_independence(
560        &self,
561        i: usize,
562        j: usize,
563        cond_set: &[usize],
564        n_samples: usize,
565        cov_matrix: &[Vec<f64>],
566    ) -> CausalityResult<f64> {
567        match self.config.test_type {
568            IndependenceTest::PartialCorrelation => {
569                partial_correlation_test(i, j, cond_set, n_samples, cov_matrix)
570            }
571            IndependenceTest::MutualInformation => {
572                mutual_information_test(i, j, cond_set, n_samples, cov_matrix)
573            }
574            IndependenceTest::KernelBased => {
575                // Simplified: falls back to partial correlation
576                // A full kernel-based test (HSIC) would require access to raw data
577                partial_correlation_test(i, j, cond_set, n_samples, cov_matrix)
578            }
579        }
580    }
581}
582
583// ---------------------------------------------------------------------------
584// Statistical tests
585// ---------------------------------------------------------------------------
586
587/// Partial correlation test using Fisher's z-transform.
588///
589/// Computes the partial correlation between variables `i` and `j` conditioning
590/// on `cond_set`, then applies Fisher's z-transform to compute a p-value.
591fn partial_correlation_test(
592    i: usize,
593    j: usize,
594    cond_set: &[usize],
595    n_samples: usize,
596    cov_matrix: &[Vec<f64>],
597) -> CausalityResult<f64> {
598    let parcorr = compute_partial_corr(i, j, cond_set, cov_matrix)?;
599
600    // Fisher's z-transform
601    let df = n_samples as f64 - cond_set.len() as f64 - 2.0;
602    if df < 1.0 {
603        return Ok(1.0); // Not enough degrees of freedom
604    }
605
606    let clamped = parcorr.clamp(-0.9999, 0.9999);
607    let z_stat = 0.5 * ((1.0 + clamped) / (1.0 - clamped)).ln() * df.sqrt();
608
609    // Two-sided p-value
610    let p_value = 2.0 * (1.0 - normal_cdf(z_stat.abs()));
611    Ok(p_value)
612}
613
614/// Mutual information test (Gaussian approximation).
615///
616/// Under Gaussianity: MI(X;Y|Z) = -0.5 * ln(1 - parcorr^2)
617/// Test statistic: 2 * n * MI ~ chi2(1) under H0
618fn mutual_information_test(
619    i: usize,
620    j: usize,
621    cond_set: &[usize],
622    n_samples: usize,
623    cov_matrix: &[Vec<f64>],
624) -> CausalityResult<f64> {
625    let parcorr = compute_partial_corr(i, j, cond_set, cov_matrix)?;
626
627    let r_sq = parcorr * parcorr;
628    let mi = if r_sq < 1.0 {
629        -0.5 * (1.0 - r_sq).ln()
630    } else {
631        f64::INFINITY
632    };
633
634    let test_stat = 2.0 * n_samples as f64 * mi;
635    let p_value = chi_squared_p_value_1df(test_stat);
636    Ok(p_value)
637}
638
639/// Compute partial correlation between i and j given cond_set.
640///
641/// Uses the precision matrix approach: parcorr(i,j|S) = -P[i,j] / sqrt(P[i,i] * P[j,j])
642/// where P = inv(Sigma_S) and Sigma_S is the submatrix of the covariance for {i,j} ∪ S.
643fn compute_partial_corr(
644    i: usize,
645    j: usize,
646    cond_set: &[usize],
647    cov_matrix: &[Vec<f64>],
648) -> CausalityResult<f64> {
649    if cond_set.is_empty() {
650        // Simple Pearson correlation
651        let var_i = cov_matrix[i][i];
652        let var_j = cov_matrix[j][j];
653        let denom = (var_i * var_j).sqrt();
654        if denom < 1e-15 {
655            return Ok(0.0);
656        }
657        return Ok(cov_matrix[i][j] / denom);
658    }
659
660    // Build the sub-covariance matrix for variables {i, j} ∪ cond_set
661    let mut indices = vec![i, j];
662    indices.extend_from_slice(cond_set);
663    let k = indices.len();
664
665    let mut sub_cov = vec![vec![0.0; k]; k];
666    for (a_idx, &a) in indices.iter().enumerate() {
667        for (b_idx, &b) in indices.iter().enumerate() {
668            sub_cov[a_idx][b_idx] = cov_matrix[a][b];
669        }
670    }
671
672    // Regularize
673    for idx in 0..k {
674        sub_cov[idx][idx] += 1e-10;
675    }
676
677    // Invert
678    let precision = invert_small_matrix(&sub_cov)?;
679
680    let denom = (precision[0][0] * precision[1][1]).sqrt();
681    if denom < 1e-15 {
682        return Ok(0.0);
683    }
684
685    Ok(-precision[0][1] / denom)
686}
687
688// ---------------------------------------------------------------------------
689// Matrix utilities
690// ---------------------------------------------------------------------------
691
692/// Compute the covariance matrix from data samples.
693fn compute_covariance_matrix(data: &[Vec<f64>]) -> CausalityResult<Vec<Vec<f64>>> {
694    let n = data.len();
695    let p = data[0].len();
696
697    // Compute means
698    let mut means = vec![0.0; p];
699    for sample in data {
700        for (j, &v) in sample.iter().enumerate() {
701            means[j] += v;
702        }
703    }
704    for m in &mut means {
705        *m /= n as f64;
706    }
707
708    // Compute covariance
709    let mut cov = vec![vec![0.0; p]; p];
710    for sample in data {
711        for a in 0..p {
712            let da = sample[a] - means[a];
713            for b in a..p {
714                let db = sample[b] - means[b];
715                cov[a][b] += da * db;
716            }
717        }
718    }
719
720    let denom = (n as f64 - 1.0).max(1.0);
721    for a in 0..p {
722        for b in a..p {
723            cov[a][b] /= denom;
724            cov[b][a] = cov[a][b];
725        }
726    }
727
728    Ok(cov)
729}
730
731/// Invert a small matrix using Gauss-Jordan elimination with partial pivoting.
732fn invert_small_matrix(mat: &[Vec<f64>]) -> CausalityResult<Vec<Vec<f64>>> {
733    let n = mat.len();
734    let mut augmented = vec![vec![0.0; 2 * n]; n];
735
736    for i in 0..n {
737        for j in 0..n {
738            augmented[i][j] = mat[i][j];
739        }
740        augmented[i][n + i] = 1.0;
741    }
742
743    for col in 0..n {
744        let mut max_val = augmented[col][col].abs();
745        let mut max_row = col;
746        for row in (col + 1)..n {
747            let val = augmented[row][col].abs();
748            if val > max_val {
749                max_val = val;
750                max_row = row;
751            }
752        }
753
754        if max_val < 1e-14 {
755            return Err(TimeSeriesError::NumericalInstability(
756                "Singular matrix in partial correlation computation".to_string(),
757            ));
758        }
759
760        if max_row != col {
761            augmented.swap(col, max_row);
762        }
763
764        let pivot = augmented[col][col];
765        for j in 0..(2 * n) {
766            augmented[col][j] /= pivot;
767        }
768
769        for row in 0..n {
770            if row != col {
771                let factor = augmented[row][col];
772                for j in 0..(2 * n) {
773                    augmented[row][j] -= factor * augmented[col][j];
774                }
775            }
776        }
777    }
778
779    let mut inv = vec![vec![0.0; n]; n];
780    for i in 0..n {
781        for j in 0..n {
782            inv[i][j] = augmented[i][n + j];
783        }
784    }
785
786    Ok(inv)
787}
788
789// ---------------------------------------------------------------------------
790// Edge orientation helpers
791// ---------------------------------------------------------------------------
792
793fn is_directed(edge_types: &HashMap<(usize, usize), EdgeType>, from: usize, to: usize) -> bool {
794    edge_types
795        .get(&(from, to))
796        .map(|&et| et == EdgeType::Directed)
797        .unwrap_or(false)
798}
799
800fn is_undirected(edge_types: &HashMap<(usize, usize), EdgeType>, a: usize, b: usize) -> bool {
801    let k1 = (a, b);
802    let k2 = (b, a);
803    let k_canon = if a < b { (a, b) } else { (b, a) };
804
805    // If any directed orientation exists, it's not undirected
806    if is_directed(edge_types, a, b) || is_directed(edge_types, b, a) {
807        return false;
808    }
809
810    // Check if there is an undirected edge
811    edge_types
812        .get(&k_canon)
813        .map(|&et| et == EdgeType::Undirected)
814        .unwrap_or(false)
815        || edge_types
816            .get(&k1)
817            .map(|&et| et == EdgeType::Undirected)
818            .unwrap_or(false)
819        || edge_types
820            .get(&k2)
821            .map(|&et| et == EdgeType::Undirected)
822            .unwrap_or(false)
823}
824
825fn orient_edge(edge_types: &mut HashMap<(usize, usize), EdgeType>, from: usize, to: usize) {
826    // Remove undirected versions
827    let k_canon = if from < to { (from, to) } else { (to, from) };
828    edge_types.remove(&k_canon);
829    edge_types.remove(&(to, from));
830    edge_types.remove(&(from, to));
831    // Insert directed
832    edge_types.insert((from, to), EdgeType::Directed);
833}
834
835// ---------------------------------------------------------------------------
836// Combination generation
837// ---------------------------------------------------------------------------
838
/// All size-`k` subsets of `items`, each preserving the input order, listed in
/// lexicographic order of positions.
///
/// `k == 0` yields one empty subset; `k > items.len()` yields none.
fn gen_combinations(items: &[usize], k: usize) -> Vec<Vec<usize>> {
    match k {
        0 => vec![Vec::new()],
        _ if k > items.len() => Vec::new(),
        _ if k == items.len() => vec![items.to_vec()],
        _ => {
            let mut out = Vec::new();
            let mut scratch = Vec::with_capacity(k);
            gen_combinations_rec(items, k, 0, &mut scratch, &mut out);
            out
        }
    }
}

/// Depth-first helper: extends `current` with elements from `items[start..]`
/// and appends every completed size-`k` subset to `result`.
fn gen_combinations_rec(
    items: &[usize],
    k: usize,
    start: usize,
    current: &mut Vec<usize>,
    result: &mut Vec<Vec<usize>>,
) {
    let needed = k - current.len();
    if needed == 0 {
        result.push(current.clone());
        return;
    }
    // Prune: not enough elements left to complete the subset.
    if items.len() - start < needed {
        return;
    }
    for (offset, &item) in items[start..].iter().enumerate() {
        current.push(item);
        gen_combinations_rec(items, k, start + offset + 1, current, result);
        current.pop();
    }
}
877
878// ---------------------------------------------------------------------------
879// Standard normal CDF
880// ---------------------------------------------------------------------------
881
/// Standard normal cumulative distribution function `Phi(x)`.
fn normal_cdf(x: f64) -> f64 {
    let scaled = x / std::f64::consts::SQRT_2;
    0.5 * (1.0 + erf(scaled))
}

/// Error function via Abramowitz & Stegun formula 7.1.26
/// (absolute error below about 1.5e-7).
fn erf(x: f64) -> f64 {
    const P: f64 = 0.327_591_1;
    // Coefficients a5, a4, a3, a2, a1 in Horner evaluation order.
    const COEFFS: [f64; 5] = [
        1.061_405_429,
        -1.453_152_027,
        1.421_413_741,
        -0.284_496_736,
        0.254_829_592,
    ];

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    let poly = COEFFS.iter().fold(0.0, |acc, &c| acc * t + c);
    sign * (1.0 - poly * t * (-ax * ax).exp())
}

/// Upper-tail p-value of a chi-squared statistic with one degree of freedom.
///
/// Uses `P(chi2_1 > c) = 2 * (1 - Phi(sqrt(c)))`; non-positive statistics
/// give `1.0`.
fn chi_squared_p_value_1df(chi2: f64) -> f64 {
    if chi2 <= 0.0 {
        1.0
    } else {
        2.0 * (1.0 - normal_cdf(chi2.sqrt()))
    }
}
910
911// ---------------------------------------------------------------------------
912// Fisher z-transform utility (public for tests)
913// ---------------------------------------------------------------------------
914
915/// Compute Fisher's z-transform of a correlation coefficient.
916///
917/// z = 0.5 * ln((1 + r) / (1 - r))
918///
919/// Under H0 (r = 0), z * sqrt(n - |S| - 3) ~ N(0, 1).
920pub fn fisher_z_transform(r: f64, n: usize, cond_size: usize) -> (f64, f64) {
921    let clamped = r.clamp(-0.9999, 0.9999);
922    let z = 0.5 * ((1.0 + clamped) / (1.0 - clamped)).ln();
923    let df = n as f64 - cond_size as f64 - 3.0;
924    let z_stat = if df > 0.0 { z * df.sqrt() } else { 0.0 };
925    let p_value = if df > 0.0 {
926        2.0 * (1.0 - normal_cdf(z_stat.abs()))
927    } else {
928        1.0
929    };
930    (z_stat, p_value)
931}
932
933// ---------------------------------------------------------------------------
934// Tests
935// ---------------------------------------------------------------------------
936
937#[cfg(test)]
938mod tests {
939    use super::*;
940
941    /// Simple LCG pseudo-random for deterministic tests.
942    fn next_rand(state: &mut u64) -> f64 {
943        *state = state
944            .wrapping_mul(6364136223846793005)
945            .wrapping_add(1442695040888963407);
946        ((*state >> 32) as f64) / (u32::MAX as f64) - 0.5
947    }
948
949    /// Generate N samples of (X, Y, Z) where X -> Y -> Z (chain)
950    fn generate_chain(n: usize, seed: u64) -> Vec<Vec<f64>> {
951        let mut state = seed;
952        let mut data = Vec::with_capacity(n);
953        for _ in 0..n {
954            let x = next_rand(&mut state);
955            let y = 0.8 * x + next_rand(&mut state) * 0.3;
956            let z = 0.8 * y + next_rand(&mut state) * 0.3;
957            data.push(vec![x, y, z]);
958        }
959        data
960    }
961
962    /// Generate N samples of (X, Y, Z) where X -> Z <- Y (v-structure)
963    fn generate_v_structure(n: usize, seed: u64) -> Vec<Vec<f64>> {
964        let mut state = seed;
965        let mut data = Vec::with_capacity(n);
966        for _ in 0..n {
967            let x = next_rand(&mut state);
968            let y = next_rand(&mut state);
969            let z = 0.7 * x + 0.7 * y + next_rand(&mut state) * 0.2;
970            data.push(vec![x, y, z]);
971        }
972        data
973    }
974
975    /// Generate independent variables
976    fn generate_independent(n: usize, seed: u64) -> Vec<Vec<f64>> {
977        let mut state = seed;
978        let mut data = Vec::with_capacity(n);
979        for _ in 0..n {
980            let x = next_rand(&mut state);
981            let y = next_rand(&mut state);
982            let z = next_rand(&mut state);
983            data.push(vec![x, y, z]);
984        }
985        data
986    }
987
988    #[test]
989    fn test_pc_config_default() {
990        let cfg = PCConfig::default();
991        assert!((cfg.significance_level - 0.05).abs() < 1e-10);
992        assert_eq!(cfg.max_cond_set_size, 4);
993        assert_eq!(cfg.test_type, IndependenceTest::PartialCorrelation);
994    }
995
996    #[test]
997    fn test_independent_variables_no_edge() {
998        let data = generate_independent(500, 42);
999        let config = PCConfig {
1000            significance_level: 0.05,
1001            max_cond_set_size: 2,
1002            test_type: IndependenceTest::PartialCorrelation,
1003        };
1004        let pc = PCAlgorithm::new(config);
1005        let graph = pc.discover(&data).expect("discovery");
1006        // Independent variables should have no edges (or very few due to chance)
1007        assert!(
1008            graph.edges.len() <= 1,
1009            "Independent vars should have ~0 edges, got {}",
1010            graph.edges.len()
1011        );
1012    }
1013
1014    #[test]
1015    fn test_chain_skeleton_discovered() {
1016        // X -> Y -> Z: skeleton should have X-Y and Y-Z edges
1017        let data = generate_chain(1000, 123);
1018        let config = PCConfig {
1019            significance_level: 0.05,
1020            max_cond_set_size: 2,
1021            test_type: IndependenceTest::PartialCorrelation,
1022        };
1023        let pc = PCAlgorithm::new(config);
1024        let graph = pc.discover(&data).expect("discovery");
1025
1026        // Should have edges involving Y (Y is connected to both X and Z)
1027        let has_xy = graph.has_edge(0, 1);
1028        let has_yz = graph.has_edge(1, 2);
1029        assert!(has_xy, "Should have X-Y edge in chain");
1030        assert!(has_yz, "Should have Y-Z edge in chain");
1031
1032        // X-Z should NOT be present (conditional independence given Y)
1033        let has_xz = graph.has_edge(0, 2);
1034        assert!(!has_xz, "Should NOT have X-Z direct edge in chain");
1035    }
1036
1037    #[test]
1038    fn test_v_structure_orientation() {
1039        // X -> Z <- Y: X and Y independent, both cause Z
1040        let data = generate_v_structure(1000, 456);
1041        let config = PCConfig {
1042            significance_level: 0.05,
1043            max_cond_set_size: 2,
1044            test_type: IndependenceTest::PartialCorrelation,
1045        };
1046        let pc = PCAlgorithm::new(config);
1047        let graph = pc.discover(&data).expect("discovery");
1048
1049        // X and Y should NOT be adjacent
1050        assert!(
1051            !graph.has_edge(0, 1),
1052            "X and Y should not be adjacent in v-structure"
1053        );
1054
1055        // Check that Z is connected to both X and Y
1056        let has_xz = graph.has_edge(0, 2);
1057        let has_yz = graph.has_edge(1, 2);
1058        assert!(has_xz, "Should have X-Z edge");
1059        assert!(has_yz, "Should have Y-Z edge");
1060
1061        // V-structure should orient: X -> Z and Y -> Z
1062        let x_to_z = graph.has_directed_edge(0, 2);
1063        let y_to_z = graph.has_directed_edge(1, 2);
1064        assert!(x_to_z, "Should orient X -> Z in v-structure");
1065        assert!(y_to_z, "Should orient Y -> Z in v-structure");
1066    }
1067
1068    #[test]
1069    fn test_causal_graph_node_edge_counts() {
1070        let data = generate_chain(500, 789);
1071        let config = PCConfig::default();
1072        let pc = PCAlgorithm::new(config);
1073        let graph = pc.discover(&data).expect("discovery");
1074
1075        assert_eq!(graph.nodes, 3);
1076        // Chain should have 2 edges (X-Y and Y-Z), X-Z removed
1077        assert!(
1078            graph.edges.len() >= 2,
1079            "Chain should have at least 2 edges, got {}",
1080            graph.edges.len()
1081        );
1082    }
1083
1084    #[test]
1085    fn test_partial_correlation_independent() {
1086        // Two truly independent variables: correlation should be near zero
1087        let data = generate_independent(500, 999);
1088        let cov = compute_covariance_matrix(&data).expect("cov");
1089        let parcorr = compute_partial_corr(0, 1, &[], &cov).expect("parcorr");
1090        assert!(
1091            parcorr.abs() < 0.15,
1092            "Independent vars should have near-zero partial corr, got {}",
1093            parcorr
1094        );
1095    }
1096
1097    #[test]
1098    fn test_partial_correlation_dependent() {
1099        let data = generate_chain(500, 111);
1100        let cov = compute_covariance_matrix(&data).expect("cov");
1101        let parcorr = compute_partial_corr(0, 1, &[], &cov).expect("parcorr");
1102        assert!(
1103            parcorr.abs() > 0.3,
1104            "Dependent vars should have significant partial corr, got {}",
1105            parcorr
1106        );
1107    }
1108
1109    #[test]
1110    fn test_partial_correlation_conditional_independence() {
1111        // In X->Y->Z, X and Z should be conditionally independent given Y
1112        let data = generate_chain(1000, 222);
1113        let cov = compute_covariance_matrix(&data).expect("cov");
1114        let parcorr_xz_given_y = compute_partial_corr(0, 2, &[1], &cov).expect("parcorr");
1115        assert!(
1116            parcorr_xz_given_y.abs() < 0.15,
1117            "X⊥Z|Y should hold in chain, parcorr={}",
1118            parcorr_xz_given_y
1119        );
1120    }
1121
1122    #[test]
1123    fn test_fisher_z_transform_correct_pvalue() {
1124        // Zero correlation => p-value should be ~1
1125        let (_, p) = fisher_z_transform(0.0, 100, 0);
1126        assert!(
1127            (p - 1.0).abs() < 0.01,
1128            "Zero correlation should give p≈1, got {}",
1129            p
1130        );
1131
1132        // Strong correlation => p-value should be small
1133        let (_, p2) = fisher_z_transform(0.9, 100, 0);
1134        assert!(
1135            p2 < 0.01,
1136            "Strong correlation should give small p-value, got {}",
1137            p2
1138        );
1139    }
1140
1141    #[test]
1142    fn test_mutual_information_test() {
1143        let data = generate_chain(500, 333);
1144        let cov = compute_covariance_matrix(&data).expect("cov");
1145        let p = mutual_information_test(0, 1, &[], 500, &cov).expect("mi");
1146        assert!(
1147            p < 0.05,
1148            "MI test should detect dependence in chain, p={}",
1149            p
1150        );
1151    }
1152
1153    #[test]
1154    fn test_pc_insufficient_data() {
1155        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
1156        let pc = PCAlgorithm::new(PCConfig::default());
1157        let result = pc.discover(&data);
1158        assert!(result.is_err());
1159    }
1160
1161    #[test]
1162    fn test_edge_type_non_exhaustive() {
1163        // Verify we can construct all edge types
1164        let _d = EdgeType::Directed;
1165        let _u = EdgeType::Undirected;
1166        let _b = EdgeType::Bidirected;
1167    }
1168
1169    #[test]
1170    fn test_known_graph_recovery_synthetic() {
1171        // Generate a known DAG: X0 -> X1, X0 -> X2, X1 -> X3, X2 -> X3
1172        // This is a diamond/fork structure
1173        let n = 2000;
1174        let mut state: u64 = 42;
1175        let mut data = Vec::with_capacity(n);
1176        for _ in 0..n {
1177            let x0 = next_rand(&mut state);
1178            let x1 = 0.8 * x0 + next_rand(&mut state) * 0.2;
1179            let x2 = 0.8 * x0 + next_rand(&mut state) * 0.2;
1180            let x3 = 0.5 * x1 + 0.5 * x2 + next_rand(&mut state) * 0.2;
1181            data.push(vec![x0, x1, x2, x3]);
1182        }
1183
1184        let config = PCConfig {
1185            significance_level: 0.05,
1186            max_cond_set_size: 3,
1187            test_type: IndependenceTest::PartialCorrelation,
1188        };
1189        let pc = PCAlgorithm::new(config);
1190        let graph = pc.discover(&data).expect("discovery");
1191
1192        assert_eq!(graph.nodes, 4);
1193
1194        // Should have edges: 0-1, 0-2, 1-3, 2-3
1195        assert!(graph.has_edge(0, 1), "Should have X0-X1 edge");
1196        assert!(graph.has_edge(0, 2), "Should have X0-X2 edge");
1197        assert!(graph.has_edge(1, 3), "Should have X1-X3 edge");
1198        assert!(graph.has_edge(2, 3), "Should have X2-X3 edge");
1199    }
1200}