Skip to main content

scirs2_stats/causal_graph/
structure_learning.rs

1//! Causal Structure Learning from Data
2//!
3//! # Algorithms provided
4//!
5//! | Algorithm | Type | Reference |
6//! |-----------|------|-----------|
7//! | [`PcAlgorithm`] | Constraint-based (observed variables) | Spirtes et al. (2000) |
8//! | [`FciAlgorithm`] | Constraint-based (latent variables allowed) | Richardson & Spirtes (2002) |
9//! | [`BicGreedySearch`] | Score-based (BIC / BDe scores) | Chickering (2002) |
10//! | [`LiNGAM`] | Non-Gaussian, continuous data | Shimizu et al. (2006) |
11//! | [`Notears`] | Gradient-based continuous optimisation | Zheng et al. (2018) |
12//!
13//! # References
14//!
15//! - Spirtes, P., Glymour, C. & Scheines, R. (2000). *Causation, Prediction,
16//!   and Search* (2nd ed.). MIT Press.
17//! - Zheng, X., Aragam, B., Ravikumar, P. & Xing, E.P. (2018).
18//!   DAGs with NO TEARS. *NeurIPS 2018*.
19//! - Shimizu, S. et al. (2006). A Linear Non-Gaussian Acyclic Model for
20//!   Causal Discovery. *JMLR* 7, 2003-2030.
21//! - Chickering, D.M. (2002). Optimal Structure Identification with Greedy
22//!   Search. *JMLR* 3, 507-554.
23
24use std::collections::{HashMap, HashSet};
25
26use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
27
28use crate::causal_graph::dag::CausalDAG;
29use crate::error::{StatsError, StatsResult};
30
31// ---------------------------------------------------------------------------
32// Shared output type
33// ---------------------------------------------------------------------------
34
35/// Result of a structure-learning algorithm.
36#[derive(Debug, Clone)]
37pub struct StructureLearningResult {
38    /// The learned causal DAG (or CPDAG / PAG skeleton for PC/FCI).
39    pub dag: CausalDAG,
40    /// Score of the learned graph (e.g. BIC; `NaN` when not applicable).
41    pub score: f64,
42    /// Algorithm name.
43    pub algorithm: String,
44    /// Number of conditional independence tests performed (CI-based) or
45    /// gradient steps (NOTEARS).
46    pub n_tests: usize,
47    /// Edge confidence / orientation info (for PC / FCI).
48    pub edge_info: HashMap<(usize, usize), EdgeType>,
49}
50
/// Mark describing how an edge of a learned graph is oriented.
///
/// The enum is fieldless, so it is `Copy` and can be hashed (e.g. used as a
/// set element or map key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeType {
    /// Directed edge `parent → child` whose orientation was identified.
    Directed,
    /// Undirected skeleton edge `a – b`; orientation not resolved by PC.
    Undirected,
    /// Bidirected edge `a ↔ b`, indicating a latent common cause (FCI).
    Bidirected,
    /// Partially oriented edge carrying circle marks; FCI uses it for
    /// edges it reports as `o–o`.
    PartiallyDirected,
}
63
64// ---------------------------------------------------------------------------
65// Conditional Independence Test (partial correlation)
66// ---------------------------------------------------------------------------
67
68/// Test conditional independence X ⊥ Y | Z using partial correlation.
69/// Returns the p-value under H₀: ρ_{XY·Z} = 0.
70fn partial_correlation_test(
71    data: ArrayView2<f64>,
72    x: usize,
73    y: usize,
74    z_set: &[usize],
75) -> StatsResult<f64> {
76    let n = data.nrows();
77    if z_set.is_empty() {
78        // Simple Pearson correlation test
79        let rho = pearson_r(data.column(x), data.column(y));
80        return Ok(pearson_p_value(rho, n));
81    }
82
83    // Partial correlation via OLS residuals
84    let res_x = ols_residuals(data, x, z_set)?;
85    let res_y = ols_residuals(data, y, z_set)?;
86    let rho = pearson_r(res_x.view(), res_y.view());
87    Ok(pearson_p_value(rho, n.saturating_sub(z_set.len())))
88}
89
90fn pearson_r(
91    a: scirs2_core::ndarray::ArrayView1<f64>,
92    b: scirs2_core::ndarray::ArrayView1<f64>,
93) -> f64 {
94    let n = a.len() as f64;
95    let ma = a.mean().unwrap_or(0.0);
96    let mb = b.mean().unwrap_or(0.0);
97    let cov: f64 = a
98        .iter()
99        .zip(b.iter())
100        .map(|(&ai, &bi)| (ai - ma) * (bi - mb))
101        .sum::<f64>();
102    let va: f64 = a.iter().map(|&ai| (ai - ma).powi(2)).sum::<f64>();
103    let vb: f64 = b.iter().map(|&bi| (bi - mb).powi(2)).sum::<f64>();
104    cov / (va * vb).sqrt().max(f64::EPSILON)
105}
106
/// Two-sided p-value for a sample correlation `rho` computed from `n`
/// (effective) observations.
///
/// Uses `t = rho · sqrt((n − 2) / (1 − rho²))`, which follows a Student-t
/// distribution with `n − 2` degrees of freedom under `H₀: ρ = 0`.
/// Returns `1.0` when `n < 3` (too few observations to test anything).
fn pearson_p_value(rho: f64, n: usize) -> f64 {
    if n < 3 {
        return 1.0;
    }
    let df = (n - 2) as f64;
    // Floor 1 − rho² so that |rho| ≈ 1 yields a very large but finite t.
    let t = rho * (df / (1.0 - rho * rho).max(1e-12)).sqrt();
    t_dist_two_sided_p(t, df)
}

/// Degrees of freedom above which the Student-t tail is replaced by the
/// standard-normal tail (the two agree to far better than `1e-4` there).
const T_NORMAL_APPROX_DF: f64 = 1.0e4;

/// Two-sided tail probability `P(|T| > |t|)` of a Student-t distribution with
/// `df` degrees of freedom.
///
/// Evaluated exactly through `P(|T| > |t|) = I_x(df/2, 1/2)` with
/// `x = df / (df + t²)`, where `I_x` is the regularised incomplete beta
/// function. The normal approximation is used only for
/// `df > T_NORMAL_APPROX_DF`; a cut-off as low as 30 misstates p-values near
/// α = 0.05 by almost 0.01.
///
/// Returns `1.0` for non-finite input or `df < 1`.
fn t_dist_two_sided_p(t: f64, df: f64) -> f64 {
    if !t.is_finite() || !df.is_finite() || df < 1.0 {
        return 1.0;
    }
    let p = if df > T_NORMAL_APPROX_DF {
        2.0 * (1.0 - normal_cdf(t.abs()))
    } else {
        inc_beta_series(0.5 * df, 0.5, df / (df + t * t))
    };
    p.clamp(0.0, 1.0)
}

/// Regularised incomplete beta function `I_x(a, b)` for `a, b > 0`.
///
/// Uses the continued fraction of Numerical Recipes (`betai`/`betacf`). The
/// fraction is evaluated directly when `x < (a + 1) / (a + b + 2)`, where it
/// converges fast. Otherwise the symmetry `I_x(a, b) = 1 − I_{1−x}(b, a)` is
/// applied in closed form rather than by recursion, so the switch point
/// itself cannot loop forever.
///
/// Returns `0.0` for `x ≤ 0` (or non-finite `x`) and `1.0` for `x ≥ 1`.
fn inc_beta_series(a: f64, b: f64, x: f64) -> f64 {
    if !x.is_finite() || x <= 0.0 {
        return 0.0;
    }
    if x >= 1.0 {
        return 1.0;
    }
    // ln[x^a (1 − x)^b / B(a, b)]; invariant under (a, b, x) ↔ (b, a, 1 − x).
    let log_prefix = a * x.ln() + b * (1.0 - x).ln() - log_beta(a, b);
    if !log_prefix.is_finite() {
        // Only reachable with invalid shape parameters (callers pass a, b > 0).
        return 0.5;
    }
    let prefix = log_prefix.exp();
    let value = if x < (a + 1.0) / (a + b + 2.0) {
        prefix * beta_continued_fraction(a, b, x) / a
    } else {
        1.0 - prefix * beta_continued_fraction(b, a, 1.0 - x) / b
    };
    value.clamp(0.0, 1.0)
}

/// Continued-fraction factor of the incomplete beta function, evaluated with
/// the modified Lentz method (Numerical Recipes `betacf`).
///
/// Converges rapidly for `x < (a + 1) / (a + b + 2)`.
fn beta_continued_fraction(a: f64, b: f64, x: f64) -> f64 {
    const MAX_ITER: usize = 5_000;
    const EPS: f64 = 1e-14;
    // Replaces exact zeros in Lentz's recurrences to avoid division by zero.
    const TINY: f64 = 1e-300;
    let nonzero = |v: f64| if v.abs() < TINY { TINY } else { v };

    let qab = a + b;
    let qap = a + 1.0;
    let qam = a - 1.0;
    let mut c = 1.0_f64;
    let mut d = nonzero(1.0 - qab * x / qap).recip();
    let mut h = d;
    for m in 1..=MAX_ITER {
        let m = m as f64;
        let m2 = 2.0 * m;
        // Even step of the recurrence.
        let aa = m * (b - m) * x / ((qam + m2) * (a + m2));
        d = nonzero(1.0 + aa * d).recip();
        c = nonzero(1.0 + aa / c);
        h *= d * c;
        // Odd step of the recurrence.
        let aa = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2));
        d = nonzero(1.0 + aa * d).recip();
        c = nonzero(1.0 + aa / c);
        let delta = d * c;
        h *= delta;
        if (delta - 1.0).abs() < EPS {
            break;
        }
    }
    h
}

/// Natural logarithm of the beta function, `ln B(a, b)`.
fn log_beta(a: f64, b: f64) -> f64 {
    lgamma(a) + lgamma(b) - lgamma(a + b)
}

/// Standard normal CDF `Φ(x)`, computed from [`erf`].
fn normal_cdf(x: f64) -> f64 {
    0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
}

/// Error function via Abramowitz & Stegun 7.1.26 (absolute error ≲ 1.5e-7),
/// extended to negative arguments by oddness.
fn erf(x: f64) -> f64 {
    let t = 1.0 / (1.0 + 0.3275911 * x.abs());
    let poly = t
        * (0.254_829_592
            + t * (-0.284_496_736
                + t * (1.421_413_741 + t * (-1.453_152_027 + t * 1.061_405_429))));
    if x >= 0.0 {
        1.0 - poly * (-x * x).exp()
    } else {
        -(1.0 - poly * (-x * x).exp())
    }
}

/// `ln |Γ(x)|` via the Lanczos approximation (g = 7, nine coefficients), with
/// the reflection formula for `x < 0.5`.
fn lgamma(x: f64) -> f64 {
    if x < 0.5 {
        // Reflection: Γ(x) Γ(1 − x) = π / sin(πx).
        std::f64::consts::PI.ln() - (std::f64::consts::PI * x).sin().abs().ln() - lgamma(1.0 - x)
    } else {
        let z = x - 1.0;
        let t = z + 7.5;
        let coeffs = [
            0.999_999_999_999_809_9,
            676.520_368_121_885_1,
            -1_259.139_216_722_402_8,
            771.323_428_777_653_1,
            -176.615_029_162_140_6,
            12.507_343_278_686_905,
            -0.138_571_095_265_720_12,
            9.984_369_578_019_572e-6,
            1.505_632_735_149_312e-7,
        ];
        let mut x_part = coeffs[0];
        for (i, &c) in coeffs[1..].iter().enumerate() {
            x_part += c / (z + 1.0 + i as f64);
        }
        0.5 * (2.0 * std::f64::consts::PI).ln() + (z + 0.5) * t.ln() - t + x_part.ln()
    }
}
219
220/// Compute OLS residuals of `target ~ z_set`.
221fn ols_residuals(
222    data: ArrayView2<f64>,
223    target: usize,
224    predictors: &[usize],
225) -> StatsResult<Array1<f64>> {
226    let n = data.nrows();
227    let p = predictors.len();
228    let mut design = Array2::<f64>::ones((n, p + 1));
229    for (j, &pred) in predictors.iter().enumerate() {
230        for i in 0..n {
231            design[[i, j + 1]] = data[[i, pred]];
232        }
233    }
234    let y: Array1<f64> = data.column(target).to_owned();
235    // Normal equations
236    let coef = ols_solve(design.view(), y.view())?;
237    let mut residuals = y.clone();
238    for i in 0..n {
239        let pred: f64 = (0..=p).map(|j| design[[i, j]] * coef[j]).sum();
240        residuals[i] -= pred;
241    }
242    Ok(residuals)
243}
244
245fn ols_solve(x: ArrayView2<f64>, y: ArrayView1<f64>) -> StatsResult<Array1<f64>> {
246    let (n, p) = x.dim();
247    let mut xtx = Array2::<f64>::zeros((p, p));
248    let mut xty = Array1::<f64>::zeros(p);
249    for i in 0..n {
250        for j in 0..p {
251            xty[j] += x[[i, j]] * y[i];
252            for k in 0..p {
253                xtx[[j, k]] += x[[i, j]] * x[[i, k]];
254            }
255        }
256    }
257    // Add small ridge for stability
258    for j in 0..p {
259        xtx[[j, j]] += 1e-8;
260    }
261    gauss_jordan_solve(xtx, xty)
262}
263
264fn gauss_jordan_solve(mut a: Array2<f64>, mut b: Array1<f64>) -> StatsResult<Array1<f64>> {
265    let n = b.len();
266    for col in 0..n {
267        let pivot_row = (col..n)
268            .max_by(|&i, &j| {
269                a[[i, col]]
270                    .abs()
271                    .partial_cmp(&a[[j, col]].abs())
272                    .unwrap_or(std::cmp::Ordering::Equal)
273            })
274            .ok_or_else(|| StatsError::ComputationError("Singular matrix".to_owned()))?;
275        // Swap rows in a
276        for k in 0..n {
277            let tmp = a[[col, k]];
278            a[[col, k]] = a[[pivot_row, k]];
279            a[[pivot_row, k]] = tmp;
280        }
281        let tmp = b[col];
282        b[col] = b[pivot_row];
283        b[pivot_row] = tmp;
284
285        let pivot = a[[col, col]];
286        if pivot.abs() < 1e-12 {
287            return Err(StatsError::ComputationError(
288                "Singular OLS system".to_owned(),
289            ));
290        }
291        for k in col..n {
292            a[[col, k]] /= pivot;
293        }
294        b[col] /= pivot;
295        for row in 0..n {
296            if row != col {
297                let factor = a[[row, col]];
298                for k in col..n {
299                    let av = a[[col, k]];
300                    a[[row, k]] -= factor * av;
301                }
302                b[row] -= factor * b[col];
303            }
304        }
305    }
306    Ok(b)
307}
308
309// ---------------------------------------------------------------------------
310// 1. PC Algorithm
311// ---------------------------------------------------------------------------
312
/// Peter-Clark (PC) algorithm for constraint-based causal discovery.
///
/// Proceeds in two phases:
/// 1. **Skeleton discovery**: starting from the complete undirected graph,
///    remove edges whose endpoints are judged conditionally independent given
///    some subset of their current adjacencies, for conditioning sets of size
///    `0..=max_cond_set_size`.
/// 2. **Orientation**: orient colliders (v-structures) and propagate the
///    orientations with Meek's rules.
///
/// Conditional independence is assessed with a partial-correlation t-test,
/// which implicitly assumes (approximately) linear-Gaussian data.
pub struct PcAlgorithm {
    /// Significance level α: an edge is removed when a conditional
    /// independence test's p-value exceeds `alpha`.
    pub alpha: f64,
    /// Maximum conditioning set size.
    pub max_cond_set_size: usize,
    /// Intended to select Fisher's z-transform for Gaussian data.
    ///
    /// NOTE: currently not consulted by `fit`; every test uses the
    /// partial-correlation t-test regardless of this flag.
    pub gaussian: bool,
}

impl Default for PcAlgorithm {
    /// `alpha = 0.05`, conditioning sets up to size 3, `gaussian = true`.
    fn default() -> Self {
        Self {
            alpha: 0.05,
            max_cond_set_size: 3,
            gaussian: true,
        }
    }
}
337
338impl PcAlgorithm {
339    /// Run the PC algorithm on the data matrix (rows = observations, cols = variables).
340    ///
341    /// Returns a CPDAG (Completed Partially Directed Acyclic Graph).
342    pub fn fit(
343        &self,
344        data: ArrayView2<f64>,
345        var_names: &[&str],
346    ) -> StatsResult<StructureLearningResult> {
347        let p = data.ncols();
348        if var_names.len() != p {
349            return Err(StatsError::DimensionMismatch(
350                "var_names length must equal number of columns in data".to_owned(),
351            ));
352        }
353
354        // Phase 1: skeleton discovery
355        // Start with fully connected skeleton
356        let mut adj: Vec<Vec<bool>> = vec![vec![true; p]; p];
357        for i in 0..p {
358            adj[i][i] = false;
359        }
360
361        // Separation sets
362        let mut sep_sets: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
363        let mut n_tests = 0usize;
364
365        for ord in 0..=self.max_cond_set_size {
366            let edges: Vec<(usize, usize)> = (0..p)
367                .flat_map(|i| {
368                    (0..p)
369                        .filter(move |&j| i < j)
370                        .collect::<Vec<_>>()
371                        .into_iter()
372                        .map(move |j| (i, j))
373                        .collect::<Vec<_>>()
374                })
375                .filter(|&(i, j)| adj[i][j])
376                .collect();
377            for (x, y) in edges {
378                // Collect adjacent nodes of x, excluding y
379                let z_candidates: Vec<usize> =
380                    (0..p).filter(|&k| k != x && k != y && adj[x][k]).collect();
381                if z_candidates.len() < ord {
382                    continue;
383                }
384                // Test all conditioning sets of size `ord`
385                let mut found_sep = false;
386                'cond: for z_set in subsets(&z_candidates, ord) {
387                    n_tests += 1;
388                    let p_val = partial_correlation_test(data, x, y, &z_set).unwrap_or(1.0);
389                    if p_val > self.alpha {
390                        // Conditionally independent → remove edge
391                        adj[x][y] = false;
392                        adj[y][x] = false;
393                        sep_sets.insert((x.min(y), x.max(y)), z_set);
394                        found_sep = true;
395                        break 'cond;
396                    }
397                }
398                if found_sep {
399                    break;
400                }
401            }
402        }
403
404        // Phase 2: orient v-structures
405        let mut directed: HashMap<(usize, usize), EdgeType> = HashMap::new();
406        // For each X - Z - Y where X - Y is absent, if Z not in sep(X,Y), orient X → Z ← Y
407        for z in 0..p {
408            let neighbours: Vec<usize> = (0..p).filter(|&k| k != z && adj[z][k]).collect();
409            for i in 0..neighbours.len() {
410                for j in (i + 1)..neighbours.len() {
411                    let x = neighbours[i];
412                    let y = neighbours[j];
413                    if adj[x][y] {
414                        continue;
415                    } // x - y edge exists, no v-structure
416                    let key = (x.min(y), x.max(y));
417                    let sep = sep_sets.get(&key).cloned().unwrap_or_default();
418                    if !sep.contains(&z) {
419                        // Orient X → Z ← Y
420                        directed.insert((x, z), EdgeType::Directed);
421                        directed.insert((y, z), EdgeType::Directed);
422                    }
423                }
424            }
425        }
426
427        // Apply Meek's orientation rules R1-R3
428        meek_rules(p, &adj, &mut directed);
429
430        // Build DAG
431        let mut dag = CausalDAG::new();
432        for name in var_names {
433            dag.add_node(name);
434        }
435        let mut edge_info: HashMap<(usize, usize), EdgeType> = HashMap::new();
436
437        for i in 0..p {
438            for j in 0..p {
439                if i == j || !adj[i][j] {
440                    continue;
441                }
442                let et = directed.get(&(i, j)).cloned();
443                match et {
444                    Some(EdgeType::Directed) => {
445                        // Only add if not already in dag (directed i→j)
446                        let _ = dag.add_edge(var_names[i], var_names[j]);
447                        edge_info.insert((i, j), EdgeType::Directed);
448                    }
449                    _ => {
450                        // Undirected: add one direction (i < j to avoid duplicates)
451                        if i < j {
452                            let _ = dag.add_edge(var_names[i], var_names[j]);
453                            edge_info.insert((i, j), EdgeType::Undirected);
454                        }
455                    }
456                }
457            }
458        }
459
460        Ok(StructureLearningResult {
461            dag,
462            score: f64::NAN,
463            algorithm: "PC".to_owned(),
464            n_tests,
465            edge_info,
466        })
467    }
468}
469
470/// Apply Meek's orientation rules to propagate orientations.
471fn meek_rules(p: usize, adj: &[Vec<bool>], directed: &mut HashMap<(usize, usize), EdgeType>) {
472    let mut changed = true;
473    let mut iters = 0;
474    while changed && iters < 100 {
475        changed = false;
476        iters += 1;
477        // R1: if a → b - c and a - c absent, orient b → c
478        for b in 0..p {
479            for a in 0..p {
480                if !adj[a][b] {
481                    continue;
482                }
483                if directed.get(&(a, b)) != Some(&EdgeType::Directed) {
484                    continue;
485                }
486                for c in 0..p {
487                    if c == a || !adj[b][c] {
488                        continue;
489                    }
490                    if directed.contains_key(&(b, c)) {
491                        continue;
492                    }
493                    if !adj[a][c] {
494                        directed.insert((b, c), EdgeType::Directed);
495                        changed = true;
496                    }
497                }
498            }
499        }
500        // R2: if a → b → c and a - c, orient a → c
501        for a in 0..p {
502            for b in 0..p {
503                if directed.get(&(a, b)) != Some(&EdgeType::Directed) {
504                    continue;
505                }
506                for c in 0..p {
507                    if directed.get(&(b, c)) != Some(&EdgeType::Directed) {
508                        continue;
509                    }
510                    if adj[a][c] && !directed.contains_key(&(a, c)) {
511                        directed.insert((a, c), EdgeType::Directed);
512                        changed = true;
513                    }
514                }
515            }
516        }
517    }
518}
519
520// ---------------------------------------------------------------------------
521// 2. FCI Algorithm
522// ---------------------------------------------------------------------------
523
/// Fast Causal Inference (FCI) algorithm — simplified.
///
/// FCI targets settings with latent common causes and produces a Partial
/// Ancestral Graph (PAG). This implementation runs the PC skeleton and
/// orientation phases with the same significance level and conditioning-set
/// limit, then post-processes the edge marks. The augmented skeleton search,
/// the discriminating-path rules and [`EdgeType::Bidirected`] marks of full
/// FCI are not implemented.
pub struct FciAlgorithm {
    /// Significance level for conditional independence tests.
    pub alpha: f64,
    /// Maximum conditioning set size.
    pub max_cond_set_size: usize,
}

impl Default for FciAlgorithm {
    /// `alpha = 0.05`, conditioning sets up to size 3.
    fn default() -> Self {
        Self {
            alpha: 0.05,
            max_cond_set_size: 3,
        }
    }
}
544
545impl FciAlgorithm {
546    /// Run FCI on the data, returning a PAG-like structure.
547    pub fn fit(
548        &self,
549        data: ArrayView2<f64>,
550        var_names: &[&str],
551    ) -> StatsResult<StructureLearningResult> {
552        // Phase 1: same skeleton discovery as PC
553        let pc = PcAlgorithm {
554            alpha: self.alpha,
555            max_cond_set_size: self.max_cond_set_size,
556            gaussian: true,
557        };
558        let mut result = pc.fit(data, var_names)?;
559        result.algorithm = "FCI".to_owned();
560
561        // Phase 2: FCI-specific discriminating path orientation
562        // In a full FCI implementation, we would also run the augmented
563        // skeleton discovery (Spirtes 1993 Alg. 4.5). Here we add
564        // potential bidirected edges for ambiguous colliders.
565        let p = var_names.len();
566        let directed_clone = result.edge_info.clone();
567        for i in 0..p {
568            for j in 0..p {
569                if i == j {
570                    continue;
571                }
572                // If both i→j and j→i are NOT in directed, mark as bidirected candidate
573                let ij = directed_clone.get(&(i, j));
574                let ji = directed_clone.get(&(j, i));
575                if ij.is_none() && ji.is_none() {
576                    // Undirected edge — FCI marks as o-o (partially oriented)
577                    if i < j {
578                        result.edge_info.insert((i, j), EdgeType::PartiallyDirected);
579                    }
580                }
581            }
582        }
583
584        Ok(result)
585    }
586}
587
588// ---------------------------------------------------------------------------
589// 3. BIC Greedy Search (GES-style)
590// ---------------------------------------------------------------------------
591
592/// BIC score for a variable given its parents.
593fn bic_score(data: ArrayView2<f64>, node: usize, parents: &[usize], bic_penalty: f64) -> f64 {
594    let n = data.nrows() as f64;
595    let k = parents.len() as f64;
596
597    // Compute residuals: (y - predicted) for each observation
598    let residuals = if parents.is_empty() {
599        let mean = data.column(node).mean().unwrap_or(0.0);
600        data.column(node)
601            .iter()
602            .map(|&y| y - mean)
603            .collect::<Vec<_>>()
604    } else {
605        match ols_residuals(data, node, parents) {
606            Ok(r) => r.to_vec(),
607            Err(_) => return f64::NEG_INFINITY,
608        }
609    };
610
611    let rss: f64 = residuals.iter().map(|r| r * r).sum();
612    let sigma2 = rss / n;
613    if sigma2 < 1e-12 {
614        return 0.0;
615    }
616    // BIC = -2 log L + k log n = n log(σ²) + k log(n)
617    -(n * sigma2.ln() + bic_penalty * (k + 1.0) * n.ln())
618}
619
/// Greedy hill-climbing score-based structure learning using the BIC score.
///
/// The search starts from the empty graph and moves through single-edge
/// additions and removals.
pub struct BicGreedySearch {
    /// BIC penalty multiplier (default 1.0; higher → sparser graphs).
    pub penalty: f64,
    /// Maximum parents per node.
    pub max_parents: usize,
    /// Maximum greedy iterations.
    pub max_iter: usize,
    /// Number of random restarts.
    ///
    /// NOTE: currently unused by `fit`, which performs a single deterministic
    /// climb from the empty graph.
    pub n_restarts: usize,
}

impl Default for BicGreedySearch {
    /// Penalty 1.0, at most 4 parents, 500 iterations, 1 restart.
    fn default() -> Self {
        Self {
            penalty: 1.0,
            max_parents: 4,
            max_iter: 500,
            n_restarts: 1,
        }
    }
}
642
643impl BicGreedySearch {
644    /// Fit the structure by greedy BIC hill climbing.
645    pub fn fit(
646        &self,
647        data: ArrayView2<f64>,
648        var_names: &[&str],
649    ) -> StatsResult<StructureLearningResult> {
650        let p = data.ncols();
651        if var_names.len() != p {
652            return Err(StatsError::DimensionMismatch(
653                "var_names length mismatch".to_owned(),
654            ));
655        }
656
657        let mut best_dag = CausalDAG::new();
658        for name in var_names {
659            best_dag.add_node(name);
660        }
661        let mut best_score = self.compute_total_bic(data, &vec![vec![]; p]);
662        let mut best_parents = vec![vec![]; p];
663
664        let mut iters = 0usize;
665        let mut current_parents = vec![vec![]; p];
666
667        let mut improved = true;
668        while improved && iters < self.max_iter {
669            improved = false;
670            iters += 1;
671
672            // Try adding each edge not already present
673            for i in 0..p {
674                for j in 0..p {
675                    if i == j {
676                        continue;
677                    }
678                    if current_parents[j].contains(&i) {
679                        continue;
680                    }
681                    if current_parents[j].len() >= self.max_parents {
682                        continue;
683                    }
684                    // Check acyclicity: i should not be a descendant of j
685                    if self.creates_cycle(&current_parents, i, j, p) {
686                        continue;
687                    }
688
689                    let mut trial = current_parents.clone();
690                    trial[j].push(i);
691                    let score = self.compute_total_bic(data, &trial);
692                    if score > best_score {
693                        best_score = score;
694                        best_parents = trial;
695                        improved = true;
696                    }
697                }
698            }
699
700            if improved {
701                current_parents = best_parents.clone();
702            }
703
704            // Try removing each edge
705            improved = false;
706            for j in 0..p {
707                let pa = current_parents[j].clone();
708                for (k, &pi) in pa.iter().enumerate() {
709                    let mut trial = current_parents.clone();
710                    trial[j].remove(k);
711                    let score = self.compute_total_bic(data, &trial);
712                    if score > best_score {
713                        best_score = score;
714                        best_parents = trial;
715                        improved = true;
716                    }
717                    let _ = pi;
718                }
719            }
720            if improved {
721                current_parents = best_parents.clone();
722            }
723        }
724
725        // Build DAG from best parents
726        let mut dag = CausalDAG::new();
727        for name in var_names {
728            dag.add_node(name);
729        }
730        for (j, parents) in best_parents.iter().enumerate() {
731            for &i in parents {
732                let _ = dag.add_edge(var_names[i], var_names[j]);
733            }
734        }
735
736        Ok(StructureLearningResult {
737            dag,
738            score: best_score,
739            algorithm: "BIC Greedy".to_owned(),
740            n_tests: iters,
741            edge_info: HashMap::new(),
742        })
743    }
744
745    fn compute_total_bic(&self, data: ArrayView2<f64>, parents: &[Vec<usize>]) -> f64 {
746        (0..data.ncols())
747            .map(|j| bic_score(data, j, &parents[j], self.penalty))
748            .sum()
749    }
750
751    /// Simple cycle check via DFS on the parent-set representation.
752    fn creates_cycle(
753        &self,
754        parents: &[Vec<usize>],
755        new_parent: usize,
756        child: usize,
757        p: usize,
758    ) -> bool {
759        // Check if `child` is an ancestor of `new_parent` in current parents graph
760        let mut visited = HashSet::new();
761        let mut stack = vec![new_parent];
762        while let Some(cur) = stack.pop() {
763            if cur == child {
764                return true;
765            }
766            if !visited.insert(cur) {
767                continue;
768            }
769            for &pa in &parents[cur] {
770                stack.push(pa);
771            }
772        }
773        let _ = p;
774        false
775    }
776}
777
778// ---------------------------------------------------------------------------
779// 4. LiNGAM
780// ---------------------------------------------------------------------------
781
/// Linear Non-Gaussian Acyclic Model (LiNGAM).
///
/// Estimates the causal ordering and connection strengths of a linear
/// structural equation model with non-Gaussian errors, following the
/// ICA-based approach of Shimizu et al. (2006):
/// 1. whiten the data;
/// 2. estimate an unmixing matrix with a FastICA variant;
/// 3. normalise the resulting matrix and derive a causal order from it.
pub struct LiNGAM {
    /// Maximum number of FastICA iterations.
    pub max_iter: usize,
    /// FastICA convergence tolerance.
    pub tol: f64,
    /// Edge threshold: an edge `j → i` is added to the learned DAG only when
    /// `|B[i, j]| > threshold`. The returned `b_matrix` itself is not pruned.
    pub threshold: f64,
}

impl Default for LiNGAM {
    /// 500 ICA iterations, tolerance `1e-6`, edge threshold `0.1`.
    fn default() -> Self {
        Self {
            max_iter: 500,
            tol: 1e-6,
            threshold: 0.1,
        }
    }
}
808
809/// Result of LiNGAM.
810#[derive(Debug, Clone)]
811pub struct LiNGAMResult {
812    /// Estimated causal ordering of variables (topological sort).
813    pub causal_order: Vec<usize>,
814    /// Estimated connection strength matrix B (`B[i,j]` = effect of j on i).
815    pub b_matrix: Array2<f64>,
816    /// Learned DAG.
817    pub dag: CausalDAG,
818}
819
820impl LiNGAM {
821    /// Fit LiNGAM.
822    pub fn fit(&self, data: ArrayView2<f64>, var_names: &[&str]) -> StatsResult<LiNGAMResult> {
823        let (n, p) = data.dim();
824        if var_names.len() != p {
825            return Err(StatsError::DimensionMismatch(
826                "var_names must equal ncols".to_owned(),
827            ));
828        }
829
830        // Centre the data
831        let means: Array1<f64> = (0..p)
832            .map(|j| data.column(j).mean().unwrap_or(0.0))
833            .collect();
834        let mut xc = data.to_owned();
835        for i in 0..n {
836            for j in 0..p {
837                xc[[i, j]] -= means[j];
838            }
839        }
840
841        // Whiten: X ← W X  where W = Σ^{-1/2}
842        let (xw, whitening_matrix) = whiten(xc.view())?;
843
844        // FastICA to estimate unmixing matrix
845        let w_ica = fast_ica(xw.view(), self.max_iter, self.tol)?;
846
847        // Combined unmixing: A_hat = W^{-1} W_ICA^{-1}
848        // The mixing matrix is A = W^{-1} (W_ICA)^{-1}
849        let a_hat = pseudo_inverse_2x2_general(&w_ica, p)?;
850
851        // Scale rows so diagonal is 1 (Doolabh & Kaliath normalisation)
852        let b_matrix = normalise_lingam(a_hat, p);
853
854        // Determine causal order: prune and search for permutation
855        let causal_order = lingam_order(&b_matrix, p);
856
857        // Build DAG
858        let mut dag = CausalDAG::new();
859        for name in var_names {
860            dag.add_node(name);
861        }
862        for j in 0..p {
863            for i in 0..p {
864                if i == j {
865                    continue;
866                }
867                if b_matrix[[i, j]].abs() > self.threshold {
868                    // j causes i (b[i,j] = effect of j on i)
869                    let _ = dag.add_edge(var_names[j], var_names[i]);
870                }
871            }
872        }
873        let _ = whitening_matrix;
874
875        Ok(LiNGAMResult {
876            causal_order,
877            b_matrix,
878            dag,
879        })
880    }
881}
882
883/// Whiten data (zero mean, identity covariance).
884fn whiten(data: ArrayView2<f64>) -> StatsResult<(Array2<f64>, Array2<f64>)> {
885    let (n, p) = data.dim();
886    // Covariance matrix
887    let mut cov = Array2::<f64>::zeros((p, p));
888    for i in 0..n {
889        for j in 0..p {
890            for k in 0..p {
891                cov[[j, k]] += data[[i, j]] * data[[i, k]];
892            }
893        }
894    }
895    cov.mapv_inplace(|x| x / n as f64);
896
897    // Eigendecomposition via Jacobi iteration (simple, correct for moderate p)
898    let (eigvals, eigvecs) = jacobi_eigen(cov.view(), 100)?;
899
900    // W = D^{-1/2} V'  (whitening matrix)
901    let mut w = Array2::<f64>::zeros((p, p));
902    for i in 0..p {
903        let scale = if eigvals[i] > 1e-10 {
904            eigvals[i].sqrt().recip()
905        } else {
906            0.0
907        };
908        for j in 0..p {
909            w[[i, j]] = scale * eigvecs[[j, i]]; // eigvecs[:,i] is i-th eigenvector
910        }
911    }
912
913    // Apply whitening
914    let mut xw = Array2::<f64>::zeros((n, p));
915    for i in 0..n {
916        for j in 0..p {
917            for k in 0..p {
918                xw[[i, j]] += w[[j, k]] * data[[i, k]];
919            }
920        }
921    }
922    Ok((xw, w))
923}
924
925/// One-sided Jacobi eigendecomposition (symmetric matrix).
926fn jacobi_eigen(a: ArrayView2<f64>, max_iter: usize) -> StatsResult<(Array1<f64>, Array2<f64>)> {
927    let n = a.nrows();
928    let mut d = a.to_owned();
929    let mut v = Array2::<f64>::eye(n);
930    for _ in 0..max_iter {
931        // Find largest off-diagonal
932        let mut max_val = 0.0_f64;
933        let (mut p, mut q) = (0, 1);
934        for i in 0..n {
935            for j in (i + 1)..n {
936                if d[[i, j]].abs() > max_val {
937                    max_val = d[[i, j]].abs();
938                    p = i;
939                    q = j;
940                }
941            }
942        }
943        if max_val < 1e-12 {
944            break;
945        }
946        let theta = if (d[[p, p]] - d[[q, q]]).abs() < 1e-12 {
947            std::f64::consts::FRAC_PI_4
948        } else {
949            0.5 * ((2.0 * d[[p, q]]) / (d[[q, q]] - d[[p, p]])).atan()
950        };
951        let (s, c) = theta.sin_cos();
952        // Update D and V
953        let (dpp, dqq, dpq) = (d[[p, p]], d[[q, q]], d[[p, q]]);
954        d[[p, p]] = c * c * dpp - 2.0 * s * c * dpq + s * s * dqq;
955        d[[q, q]] = s * s * dpp + 2.0 * s * c * dpq + c * c * dqq;
956        d[[p, q]] = 0.0;
957        d[[q, p]] = 0.0;
958        for k in 0..n {
959            if k != p && k != q {
960                let dpk = d[[p, k]];
961                let dqk = d[[q, k]];
962                d[[p, k]] = c * dpk - s * dqk;
963                d[[k, p]] = d[[p, k]];
964                d[[q, k]] = s * dpk + c * dqk;
965                d[[k, q]] = d[[q, k]];
966            }
967            let vpk = v[[k, p]];
968            let vqk = v[[k, q]];
969            v[[k, p]] = c * vpk - s * vqk;
970            v[[k, q]] = s * vpk + c * vqk;
971        }
972    }
973    let eigvals: Array1<f64> = (0..n).map(|i| d[[i, i]]).collect();
974    Ok((eigvals, v))
975}
976
977/// Simplified FastICA (deflation, neg-entropy approximation).
978fn fast_ica(xw: ArrayView2<f64>, max_iter: usize, tol: f64) -> StatsResult<Array2<f64>> {
979    let (n, p) = xw.dim();
980    let mut w_mat = Array2::<f64>::eye(p);
981
982    for comp in 0..p {
983        let mut w = Array1::<f64>::from_shape_fn(p, |i| if i == comp { 1.0 } else { 0.0 });
984
985        for _ in 0..max_iter {
986            // Project
987            let wx: Vec<f64> = (0..n)
988                .map(|i| {
989                    w.iter()
990                        .zip(xw.row(i).iter())
991                        .map(|(a, b)| a * b)
992                        .sum::<f64>()
993                })
994                .collect();
995
996            // Non-linearity g(u) = tanh(u), g'(u) = 1 - tanh(u)^2
997            let g: Vec<f64> = wx.iter().map(|&u| u.tanh()).collect();
998            let gp: Vec<f64> = wx.iter().map(|&u| 1.0 - u.tanh().powi(2)).collect();
999
1000            let mut w_new = Array1::<f64>::zeros(p);
1001            for i in 0..n {
1002                for j in 0..p {
1003                    w_new[j] += g[i] * xw[[i, j]];
1004                }
1005            }
1006            w_new.mapv_inplace(|x| x / n as f64);
1007            let gp_mean = gp.iter().sum::<f64>() / n as f64;
1008            for j in 0..p {
1009                w_new[j] -= gp_mean * w[j];
1010            }
1011
1012            // Orthogonalise against previous components
1013            for prev in 0..comp {
1014                let w_prev = w_mat.row(prev);
1015                let dot: f64 = w_new.iter().zip(w_prev.iter()).map(|(a, b)| a * b).sum();
1016                for j in 0..p {
1017                    w_new[j] -= dot * w_prev[j];
1018                }
1019            }
1020
1021            // Normalise
1022            let norm: f64 = w_new
1023                .iter()
1024                .map(|x| x * x)
1025                .sum::<f64>()
1026                .sqrt()
1027                .max(f64::EPSILON);
1028            w_new.mapv_inplace(|x| x / norm);
1029
1030            let diff: f64 = w
1031                .iter()
1032                .zip(w_new.iter())
1033                .map(|(a, b)| (a - b).powi(2))
1034                .sum::<f64>()
1035                .sqrt();
1036            w = w_new;
1037            if diff < tol {
1038                break;
1039            }
1040        }
1041        for j in 0..p {
1042            w_mat[[comp, j]] = w[j];
1043        }
1044    }
1045    Ok(w_mat)
1046}
1047
1048fn pseudo_inverse_2x2_general(w: &Array2<f64>, p: usize) -> StatsResult<Array2<f64>> {
1049    // Compute pseudo-inverse via SVD (Jacobi) or direct inversion
1050    // For small p, direct Gauss-Jordan inversion works
1051    let mut aug = Array2::<f64>::zeros((p, 2 * p));
1052    for i in 0..p {
1053        for j in 0..p {
1054            aug[[i, j]] = w[[i, j]];
1055        }
1056        aug[[i, p + i]] = 1.0;
1057    }
1058    for col in 0..p {
1059        let pivot = (col..p)
1060            .max_by(|&i, &j| {
1061                aug[[i, col]]
1062                    .abs()
1063                    .partial_cmp(&aug[[j, col]].abs())
1064                    .unwrap_or(std::cmp::Ordering::Equal)
1065            })
1066            .ok_or_else(|| {
1067                StatsError::ComputationError("Singular ICA unmixing matrix".to_owned())
1068            })?;
1069        for k in 0..(2 * p) {
1070            let tmp = aug[[col, k]];
1071            aug[[col, k]] = aug[[pivot, k]];
1072            aug[[pivot, k]] = tmp;
1073        }
1074        let piv_val = aug[[col, col]];
1075        if piv_val.abs() < 1e-12 {
1076            return Err(StatsError::ComputationError("Singular".to_owned()));
1077        }
1078        for k in 0..(2 * p) {
1079            aug[[col, k]] /= piv_val;
1080        }
1081        for row in 0..p {
1082            if row != col {
1083                let factor = aug[[row, col]];
1084                for k in 0..(2 * p) {
1085                    let av = aug[[col, k]];
1086                    aug[[row, k]] -= factor * av;
1087                }
1088            }
1089        }
1090    }
1091    let mut inv = Array2::<f64>::zeros((p, p));
1092    for i in 0..p {
1093        for j in 0..p {
1094            inv[[i, j]] = aug[[i, p + j]];
1095        }
1096    }
1097    Ok(inv)
1098}
1099
1100fn normalise_lingam(mut b: Array2<f64>, p: usize) -> Array2<f64> {
1101    for i in 0..p {
1102        let diag = b[[i, i]];
1103        if diag.abs() > 1e-10 {
1104            for j in 0..p {
1105                b[[i, j]] /= diag;
1106            }
1107        }
1108    }
1109    for i in 0..p {
1110        b[[i, i]] = 0.0;
1111    }
1112    b
1113}
1114
1115fn lingam_order(b: &Array2<f64>, p: usize) -> Vec<usize> {
1116    // Simple: find a permutation where B is lower triangular
1117    // Use the row with smallest L1 norm (most zeros) as first causal variable
1118    let mut remaining: Vec<usize> = (0..p).collect();
1119    let mut order = Vec::with_capacity(p);
1120    while !remaining.is_empty() {
1121        let best = remaining
1122            .iter()
1123            .min_by(|&&i, &&j| {
1124                let li: f64 = remaining
1125                    .iter()
1126                    .filter(|&&k| k != i)
1127                    .map(|&k| b[[i, k]].abs())
1128                    .sum();
1129                let lj: f64 = remaining
1130                    .iter()
1131                    .filter(|&&k| k != j)
1132                    .map(|&k| b[[j, k]].abs())
1133                    .sum();
1134                li.partial_cmp(&lj).unwrap_or(std::cmp::Ordering::Equal)
1135            })
1136            .copied()
1137            .unwrap_or(remaining[0]);
1138        order.push(best);
1139        remaining.retain(|&x| x != best);
1140    }
1141    order
1142}
1143
1144// ---------------------------------------------------------------------------
1145// 5. NOTEARS
1146// ---------------------------------------------------------------------------
1147
/// NOTEARS: gradient-based DAG learning via a smooth acyclicity constraint.
///
/// Minimises  ½||X - XW||²_F / n + λ||W||₁  subject to
/// h(W) = tr(e^{W◦W}) - p = 0, where W is the weighted adjacency matrix of
/// the DAG. A non-zero `W[i, j]` encodes an edge `i → j`.
///
/// Reference: Zheng et al. (2018) *DAGs with NO TEARS*, NeurIPS.
#[derive(Debug, Clone)]
pub struct Notears {
    /// Regularisation strength λ (L1 penalty on edge weights).
    pub lambda: f64,
    /// Maximum number of augmented Lagrangian outer iterations.
    pub max_iter: usize,
    /// Maximum steps per inner optimisation.
    pub max_inner_iter: usize,
    /// Acyclicity tolerance: the outer loop stops once h(W) falls below it.
    pub h_tol: f64,
    /// Edge weight threshold (edges with |w| ≤ threshold are pruned).
    pub w_threshold: f64,
}
1166
1167impl Default for Notears {
1168    fn default() -> Self {
1169        Self {
1170            lambda: 0.1,
1171            max_iter: 100,
1172            max_inner_iter: 300,
1173            h_tol: 1e-8,
1174            w_threshold: 0.3,
1175        }
1176    }
1177}
1178
1179impl Notears {
1180    /// Fit NOTEARS on the data matrix.
1181    pub fn fit(
1182        &self,
1183        data: ArrayView2<f64>,
1184        var_names: &[&str],
1185    ) -> StatsResult<StructureLearningResult> {
1186        let (n, p) = data.dim();
1187        if var_names.len() != p {
1188            return Err(StatsError::DimensionMismatch(
1189                "var_names mismatch".to_owned(),
1190            ));
1191        }
1192
1193        // Centre the data
1194        let means: Array1<f64> = (0..p)
1195            .map(|j| data.column(j).mean().unwrap_or(0.0))
1196            .collect();
1197        let mut xc = data.to_owned();
1198        for i in 0..n {
1199            for j in 0..p {
1200                xc[[i, j]] -= means[j];
1201            }
1202        }
1203
1204        // Augmented Lagrangian with penalty ρ
1205        let mut w = Array2::<f64>::zeros((p, p));
1206        let mut alpha = 0.0_f64; // Lagrange multiplier
1207        let mut rho = 1.0_f64;
1208        let rho_max = 1e16_f64;
1209        let mut h_prev = f64::INFINITY;
1210        let mut outer_iters = 0usize;
1211
1212        for _ in 0..self.max_iter {
1213            outer_iters += 1;
1214            // Inner optimisation (gradient descent on augmented Lagrangian)
1215            w = self.inner_optim(xc.view(), &w, alpha, rho, n, p)?;
1216            let h_val = notears_h(&w, p);
1217
1218            if h_val.abs() < self.h_tol {
1219                break;
1220            }
1221
1222            // Update multiplier and penalty
1223            alpha += rho * h_val;
1224            if h_val > 0.25 * h_prev {
1225                rho = (rho * 10.0).min(rho_max);
1226            }
1227            h_prev = h_val;
1228        }
1229
1230        // Threshold and build DAG
1231        let mut dag = CausalDAG::new();
1232        for name in var_names {
1233            dag.add_node(name);
1234        }
1235        let mut edge_info = HashMap::new();
1236        for i in 0..p {
1237            for j in 0..p {
1238                if i == j {
1239                    continue;
1240                }
1241                if w[[i, j]].abs() > self.w_threshold {
1242                    let _ = dag.add_edge(var_names[i], var_names[j]);
1243                    edge_info.insert((i, j), EdgeType::Directed);
1244                }
1245            }
1246        }
1247
1248        Ok(StructureLearningResult {
1249            dag,
1250            score: -notears_loss(xc.view(), &w, n, p),
1251            algorithm: "NOTEARS".to_owned(),
1252            n_tests: outer_iters,
1253            edge_info,
1254        })
1255    }
1256
1257    fn inner_optim(
1258        &self,
1259        x: ArrayView2<f64>,
1260        w_init: &Array2<f64>,
1261        alpha: f64,
1262        rho: f64,
1263        n: usize,
1264        p: usize,
1265    ) -> StatsResult<Array2<f64>> {
1266        let mut w = w_init.clone();
1267        let lr = 1e-3;
1268
1269        for _step in 0..self.max_inner_iter {
1270            let grad = self.aug_lagrangian_gradient(x, &w, alpha, rho, n, p);
1271            // Proximal gradient step for L1
1272            let mut w_new = Array2::<f64>::zeros((p, p));
1273            for i in 0..p {
1274                for j in 0..p {
1275                    if i == j {
1276                        continue;
1277                    }
1278                    let u = w[[i, j]] - lr * grad[[i, j]];
1279                    // Soft thresholding
1280                    w_new[[i, j]] = if u > lr * self.lambda {
1281                        u - lr * self.lambda
1282                    } else if u < -lr * self.lambda {
1283                        u + lr * self.lambda
1284                    } else {
1285                        0.0
1286                    };
1287                }
1288            }
1289            let diff: f64 = {
1290                let mut d = 0.0_f64;
1291                for ii in 0..p {
1292                    for jj in 0..p {
1293                        d += (w_new[[ii, jj]] - w[[ii, jj]]).powi(2);
1294                    }
1295                }
1296                d.sqrt()
1297            };
1298            w = w_new;
1299            if diff < 1e-6 {
1300                break;
1301            }
1302        }
1303        Ok(w)
1304    }
1305
1306    fn aug_lagrangian_gradient(
1307        &self,
1308        x: ArrayView2<f64>,
1309        w: &Array2<f64>,
1310        alpha: f64,
1311        rho: f64,
1312        n: usize,
1313        p: usize,
1314    ) -> Array2<f64> {
1315        // Gradient of ½||X - XW||² / n + (α + ρ h(W)/2) ∂h/∂W
1316        let mut grad = Array2::<f64>::zeros((p, p));
1317
1318        // Least squares gradient: -X' (X - XW) / n = (X'X W - X'X) / n
1319        // = X'(XW - X) / n
1320        let xw = x_times_w(x, w, n, p);
1321        for i in 0..p {
1322            for j in 0..p {
1323                if i == j {
1324                    continue;
1325                }
1326                let mut g = 0.0_f64;
1327                for k in 0..n {
1328                    g += x[[k, i]] * (xw[[k, j]] - x[[k, j]]);
1329                }
1330                grad[[i, j]] = g / n as f64;
1331            }
1332        }
1333
1334        // Acyclicity gradient: ∂h/∂W = (e^{W◦W})' ◦ 2W
1335        let exp_ww = notears_exp_ww(w, p);
1336        let h = exp_ww
1337            .iter()
1338            .enumerate()
1339            .filter(|(i, _)| i / p == i % p)
1340            .map(|(_, &v)| v)
1341            .sum::<f64>()
1342            - p as f64;
1343        let dh_dw = notears_dh_dw(&exp_ww, w, p);
1344        for i in 0..p {
1345            for j in 0..p {
1346                grad[[i, j]] += (alpha + rho * h) * dh_dw[[i, j]];
1347            }
1348        }
1349        grad
1350    }
1351}
1352
1353fn x_times_w(x: ArrayView2<f64>, w: &Array2<f64>, n: usize, p: usize) -> Array2<f64> {
1354    let mut xw = Array2::<f64>::zeros((n, p));
1355    for i in 0..n {
1356        for j in 0..p {
1357            for k in 0..p {
1358                xw[[i, j]] += x[[i, k]] * w[[k, j]];
1359            }
1360        }
1361    }
1362    xw
1363}
1364
1365fn notears_h(w: &Array2<f64>, p: usize) -> f64 {
1366    // h(W) = tr(e^{W◦W}) - p
1367    let exp_ww = notears_exp_ww(w, p);
1368    (0..p).map(|i| exp_ww[[i, i]]).sum::<f64>() - p as f64
1369}
1370
1371/// Compute matrix exponential of W◦W (element-wise squared adjacency).
1372fn notears_exp_ww(w: &Array2<f64>, p: usize) -> Array2<f64> {
1373    // W◦W
1374    let ww: Array2<f64> = w.mapv(|x| x * x);
1375    // Matrix exponential via Taylor series (10 terms)
1376    let mut result = Array2::<f64>::eye(p);
1377    let mut term = Array2::<f64>::eye(p);
1378    let mut factorial = 1.0_f64;
1379    for k in 1..=15_usize {
1380        factorial *= k as f64;
1381        // term = term * ww
1382        let mut new_term = Array2::<f64>::zeros((p, p));
1383        for i in 0..p {
1384            for j in 0..p {
1385                for l in 0..p {
1386                    new_term[[i, j]] += term[[i, l]] * ww[[l, j]];
1387                }
1388            }
1389        }
1390        term = new_term;
1391        for i in 0..p {
1392            for j in 0..p {
1393                result[[i, j]] += term[[i, j]] / factorial;
1394            }
1395        }
1396        if term.iter().map(|x| x.abs()).fold(0.0_f64, f64::max) < 1e-12 {
1397            break;
1398        }
1399    }
1400    result
1401}
1402
1403fn notears_dh_dw(exp_ww: &Array2<f64>, w: &Array2<f64>, p: usize) -> Array2<f64> {
1404    // ∂h/∂W_ij = [e^{W◦W}]_ij' × 2 W_ij  (transpose of exp_ww × 2W)
1405    let mut dh = Array2::<f64>::zeros((p, p));
1406    for i in 0..p {
1407        for j in 0..p {
1408            dh[[i, j]] = exp_ww[[j, i]] * 2.0 * w[[i, j]];
1409        }
1410    }
1411    dh
1412}
1413
1414fn notears_loss(x: ArrayView2<f64>, w: &Array2<f64>, n: usize, p: usize) -> f64 {
1415    let xw = x_times_w(x, w, n, p);
1416    let mut loss = 0.0_f64;
1417    for i in 0..n {
1418        for j in 0..p {
1419            loss += (xw[[i, j]] - x[[i, j]]).powi(2);
1420        }
1421    }
1422    loss / (2.0 * n as f64)
1423}
1424
1425// ---------------------------------------------------------------------------
1426// Helpers
1427// ---------------------------------------------------------------------------
1428
/// All `k`-element combinations of `items`.
///
/// Combinations are produced in lexicographic order of item position, and
/// each one preserves the original item order. `k == 0` yields a single empty
/// combination; `k > items.len()` yields none.
fn subsets<T: Copy>(items: &[T], k: usize) -> Vec<Vec<T>> {
    let total = items.len();
    if k > total {
        return Vec::new();
    }

    // Strictly increasing indices of the current combination.
    let mut cursor: Vec<usize> = (0..k).collect();
    let mut combos: Vec<Vec<T>> = Vec::new();
    loop {
        combos.push(cursor.iter().map(|&ix| items[ix]).collect());

        // Rightmost slot that has not yet reached its final position.
        match (0..k).rev().find(|&slot| cursor[slot] < total - k + slot) {
            Some(slot) => {
                cursor[slot] += 1;
                let base = cursor[slot];
                for (offset, ix) in cursor[slot + 1..].iter_mut().enumerate() {
                    *ix = base + 1 + offset;
                }
            }
            None => return combos,
        }
    }
}
1445
1446// ---------------------------------------------------------------------------
1447// Unit tests
1448// ---------------------------------------------------------------------------
1449
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Synthetic chain `X -> Y -> Z` (n = 100) with independent Gaussian
    /// noise.
    ///
    /// The noise comes from a fixed-seed LCG fed through Box–Muller, so the
    /// data is reproducible.
    fn chain_data() -> Array2<f64> {
        let n = 100;
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut lcg: u64 = 12345;
        let next = |s: &mut u64| -> f64 {
            // Advance LCG twice to get two independent uniform samples
            *s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
            let u = (*s >> 33) as f64 / (1u64 << 31) as f64;
            *s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
            // Clamp away from 0 so the Box-Muller logarithm stays finite
            let v = ((*s >> 33) as f64 / (1u64 << 31) as f64).max(1e-10);
            // Box-Muller transform
            (-2.0 * v.ln()).sqrt() * (2.0 * std::f64::consts::PI * u).cos()
        };
        for i in 0..n {
            data[[i, 0]] = next(&mut lcg);
            data[[i, 1]] = 0.8 * data[[i, 0]] + next(&mut lcg) * 0.5;
            data[[i, 2]] = 0.8 * data[[i, 1]] + next(&mut lcg) * 0.5;
        }
        data
    }

    #[test]
    fn test_pc_runs() {
        let data = chain_data();
        let pc = PcAlgorithm::default();
        let res = pc.fit(data.view(), &["X", "Y", "Z"]).unwrap();
        assert_eq!(res.algorithm, "PC");
        assert!(res.dag.n_nodes() == 3);
    }

    #[test]
    fn test_fci_runs() {
        let data = chain_data();
        let fci = FciAlgorithm::default();
        let res = fci.fit(data.view(), &["X", "Y", "Z"]).unwrap();
        assert_eq!(res.algorithm, "FCI");
    }

    #[test]
    fn test_bic_greedy() {
        let data = chain_data();
        let learner = BicGreedySearch {
            max_iter: 50,
            ..Default::default()
        };
        let res = learner.fit(data.view(), &["X", "Y", "Z"]).unwrap();
        // The learned edge count is data-dependent; only the score is checked.
        assert!(!res.score.is_nan());
    }

    #[test]
    fn test_lingam_runs() {
        let data = chain_data();
        let ling = LiNGAM::default();
        let res = ling.fit(data.view(), &["X", "Y", "Z"]).unwrap();
        assert_eq!(res.causal_order.len(), 3);
        assert_eq!(res.b_matrix.nrows(), 3);
    }

    #[test]
    fn test_notears_runs() {
        let data = chain_data();
        let nt = Notears {
            max_iter: 5,
            max_inner_iter: 10,
            ..Default::default()
        };
        let res = nt.fit(data.view(), &["X", "Y", "Z"]).unwrap();
        assert_eq!(res.dag.n_nodes(), 3);
    }

    #[test]
    fn test_partial_correlation_independence() {
        // X and Z independent given Y in chain X→Y→Z
        let data = chain_data();
        let p_val = partial_correlation_test(data.view(), 0, 2, &[1]).unwrap();
        // Should have large p-value (not reject independence)
        assert!(p_val > 0.01, "p={p_val}");
    }

    #[test]
    fn test_subsets_enumeration() {
        assert_eq!(
            subsets(&[0usize, 1, 2], 2),
            vec![vec![0, 1], vec![0, 2], vec![1, 2]]
        );
        // k == 0 yields exactly one (empty) conditioning set.
        assert_eq!(subsets(&[7usize, 8], 0), vec![Vec::<usize>::new()]);
        // k larger than the pool yields nothing.
        assert!(subsets(&[1usize, 2], 3).is_empty());
    }

    #[test]
    fn test_notears_h_acyclicity() {
        // Chain 0 → 1 → 2 is acyclic: h must vanish.
        let mut dag_w = Array2::<f64>::zeros((3, 3));
        dag_w[[0, 1]] = 0.8;
        dag_w[[1, 2]] = 0.8;
        assert!(notears_h(&dag_w, 3).abs() < 1e-12);

        // A 2-cycle must be penalised.
        let mut cyc_w = Array2::<f64>::zeros((2, 2));
        cyc_w[[0, 1]] = 0.5;
        cyc_w[[1, 0]] = 0.5;
        assert!(notears_h(&cyc_w, 2) > 0.0);
    }
}