Skip to main content

scirs2_stats/bayesian_network/
structure_learning.rs

1//! Bayesian Network Structure Learning.
2//!
3//! Provides:
4//! - [`PCAlgorithm`] — constraint-based learning (Spirtes-Glymour-Scheines)
5//! - [`HillClimbing`] — score-based greedy search with tabu list
6//! - [`BIC`] — BIC score for discrete Bayesian Networks
7
8use super::{cpd::TabularCPD, dag::DAG};
9use crate::StatsError;
10use std::collections::{HashMap, HashSet, VecDeque};
11
12// ---------------------------------------------------------------------------
13// Utilities: discrete data statistics
14// ---------------------------------------------------------------------------
15
/// Infer the cardinality of every discrete variable from integer-coded data.
///
/// Each row of `data` is one sample and each column one variable. A value is
/// rounded to the nearest integer code and the cardinality of a column is
/// `max_code + 1`, i.e. codes are assumed to be `0, 1, ..., max_code`. This
/// is an upper bound on the number of states, not a count of the distinct
/// values actually observed.
///
/// Details:
/// - The number of variables is taken from the first row; trailing entries of
///   longer rows are ignored and shorter rows simply contribute fewer values.
/// - Negative and NaN entries behave like code `0`, because float-to-integer
///   casts saturate.
/// - Entries whose code does not fit in `usize` (including `+∞`) are ignored
///   instead of overflowing `max_code + 1`.
/// - Every reported cardinality is at least 2.
///
/// Returns an empty vector when `data` is empty.
pub fn count_cardinalities(data: &[Vec<f64>]) -> Vec<usize> {
    let n_vars = match data.first() {
        Some(first) => first.len(),
        None => return Vec::new(),
    };
    // Starting from the guaranteed minimum makes a final clamping pass unnecessary.
    let mut cards = vec![2usize; n_vars];
    for row in data {
        // `zip` stops at the shorter side, so over-long rows are truncated to `n_vars`.
        for (card, &val) in cards.iter_mut().zip(row) {
            let code = val.round() as usize;
            // `checked_add` skips codes saturated at `usize::MAX` rather than
            // panicking (debug) or silently wrapping (release).
            if let Some(needed) = code.checked_add(1) {
                *card = (*card).max(needed);
            }
        }
    }
    cards
}
34
/// Pearson correlation coefficient between columns `x` and `y` of `data`.
///
/// Means, variances and the covariance are all population estimates (divided
/// by the number of rows). When either variance is below `1e-15` the column
/// is treated as constant and `0.0` is returned; otherwise the coefficient is
/// clamped to `[-1, 1]` to absorb rounding error.
fn sample_corr(data: &[Vec<f64>], x: usize, y: usize) -> f64 {
    let count = data.len() as f64;

    let mean_of = |col: usize| data.iter().map(|row| row[col]).sum::<f64>() / count;
    let mu_x = mean_of(x);
    let mu_y = mean_of(y);

    let covariance = data
        .iter()
        .map(|row| (row[x] - mu_x) * (row[y] - mu_y))
        .sum::<f64>()
        / count;

    let variance_of = |col: usize, mu: f64| {
        data.iter().map(|row| (row[col] - mu).powi(2)).sum::<f64>() / count
    };
    let var_x = variance_of(x, mu_x);
    let var_y = variance_of(y, mu_y);

    // A (numerically) constant column has no defined correlation; report none.
    if var_x < 1e-15 || var_y < 1e-15 {
        return 0.0;
    }
    let scale = var_x.sqrt() * var_y.sqrt();
    (covariance / scale).clamp(-1.0, 1.0)
}
52
53/// Compute the partial correlation of X and Y given the set Z.
54///
55/// Uses recursive formula via the Gram matrix.
56pub fn partial_correlation(data: &[Vec<f64>], x: usize, y: usize, z: &[usize]) -> f64 {
57    if z.is_empty() {
58        return sample_corr(data, x, y);
59    }
60    // Build correlation matrix for {x, y} ∪ z
61    let mut vars = vec![x, y];
62    vars.extend_from_slice(z);
63    vars.sort_unstable();
64    vars.dedup();
65    let idx_x = vars.iter().position(|&v| v == x).unwrap_or(0);
66    let idx_y = vars.iter().position(|&v| v == y).unwrap_or(0);
67    let m = vars.len();
68    // Build correlation matrix
69    let mut corr = vec![vec![0.0f64; m]; m];
70    for i in 0..m {
71        corr[i][i] = 1.0;
72        for j in (i + 1)..m {
73            let c = sample_corr(data, vars[i], vars[j]);
74            corr[i][j] = c;
75            corr[j][i] = c;
76        }
77    }
78    // Partial correlation via matrix inversion (Gaussian elimination)
79    let inv = invert_matrix(&corr).unwrap_or_else(|| vec![vec![0.0; m]; m]);
80    let px = inv[idx_x][idx_x];
81    let py = inv[idx_y][idx_y];
82    let pxy = inv[idx_x][idx_y];
83    if px < 1e-15 || py < 1e-15 {
84        return 0.0;
85    }
86    (-pxy / (px * py).sqrt()).clamp(-1.0, 1.0)
87}
88
/// Invert a small square matrix by Gauss-Jordan elimination with partial pivoting.
///
/// `mat` is given as a slice of rows. Returns `None` as soon as the best
/// available pivot of a column has magnitude below `1e-15` (the matrix is
/// treated as singular). An empty matrix inverts to an empty matrix.
fn invert_matrix(mat: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let dim = mat.len();
    let mut work: Vec<Vec<f64>> = mat.to_vec();

    // Start from the identity; it receives every row operation applied to `work`.
    let mut inverse = vec![vec![0.0f64; dim]; dim];
    for (i, row) in inverse.iter_mut().enumerate() {
        row[i] = 1.0;
    }

    for col in 0..dim {
        // Pick the row at or below the diagonal with the largest magnitude in
        // this column; on ties (or incomparable values) the later row wins.
        let mut pivot_row = col;
        for cand in (col + 1)..dim {
            let ordering = work[pivot_row][col]
                .abs()
                .partial_cmp(&work[cand][col].abs())
                .unwrap_or(std::cmp::Ordering::Equal);
            if ordering != std::cmp::Ordering::Greater {
                pivot_row = cand;
            }
        }
        work.swap(col, pivot_row);
        inverse.swap(col, pivot_row);

        let pivot = work[col][col];
        if pivot.abs() < 1e-15 {
            return None;
        }

        // Scale the pivot row so that the pivot becomes 1.
        for value in work[col][..dim].iter_mut() {
            *value /= pivot;
        }
        for value in inverse[col].iter_mut() {
            *value /= pivot;
        }

        // Eliminate this column from every other row. The pivot rows are
        // copied once because they are not modified during elimination.
        let pivot_work = work[col][..dim].to_vec();
        let pivot_inverse = inverse[col].clone();
        for row in (0..dim).filter(|&r| r != col) {
            let factor = work[row][col];
            for j in 0..dim {
                work[row][j] -= factor * pivot_work[j];
                inverse[row][j] -= factor * pivot_inverse[j];
            }
        }
    }
    Some(inverse)
}
133
134/// Fisher's z-transformation partial correlation independence test.
135///
136/// Returns the p-value. A p-value > alpha indicates conditional independence.
137pub fn fisherz_test(data: &[Vec<f64>], x: usize, y: usize, z: &[usize]) -> f64 {
138    let n = data.len() as f64;
139    let r = partial_correlation(data, x, y, z);
140    let r_clamped = r.clamp(-1.0 + 1e-10, 1.0 - 1e-10);
141    let fisher_z = 0.5 * ((1.0 + r_clamped) / (1.0 - r_clamped)).ln();
142    let dof = (n - z.len() as f64 - 3.0).max(1.0);
143    let stat = fisher_z.abs() * dof.sqrt();
144    // Two-sided p-value approximation: 2 * Φ(-|z|) ≈ 2 * erfc(|z| / sqrt(2)) / 2
145    // Using normal approximation
146    2.0 * normal_sf(stat)
147}
148
149/// Approximate survival function of the standard normal: P(Z > x).
150fn normal_sf(x: f64) -> f64 {
151    0.5 * erfc_approx(x / std::f64::consts::SQRT_2)
152}
153
/// Approximate the complementary error function `erfc(x)`.
///
/// Uses the rational approximation 7.1.26 of Abramowitz & Stegun: a degree-5
/// polynomial in `t = 1 / (1 + p * |x|)`, evaluated with Horner's scheme and
/// multiplied by `exp(-x^2)`. Negative arguments are handled through the
/// reflection `erfc(-x) = 2 - erfc(x)`.
fn erfc_approx(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // Polynomial coefficients a1..a5, lowest order first.
    const COEFFS: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];

    let t = 1.0 / (1.0 + P * x.abs());
    // Horner evaluation of a1 + t*(a2 + t*(a3 + t*(a4 + t*a5))).
    let horner = COEFFS.iter().rev().fold(0.0f64, |acc, &c| c + t * acc);
    let tail = t * horner * (-x * x).exp();

    if x >= 0.0 {
        tail
    } else {
        2.0 - tail
    }
}
168
169// ---------------------------------------------------------------------------
170// PCAlgorithm
171// ---------------------------------------------------------------------------
172
/// PC (Peter-Clark) algorithm for constraint-based structure learning.
///
/// Learning proceeds in three phases:
/// 1. Skeleton: start from the complete undirected graph and delete an edge
///    whenever a conditional independence test finds a separating set, using
///    conditioning sets of growing size (bounded by `max_cond_set`).
/// 2. Colliders: orient unshielded triples `a - b - c` as `a → b ← c` when
///    `b` is not in the separating set recorded for `a` and `c`.
/// 3. Completion: give every remaining undirected edge a direction that keeps
///    the graph acyclic. This is a simple heuristic, not the full set of
///    Meek orientation rules.
#[derive(Debug, Clone)]
pub struct PCAlgorithm {
    /// Significance level of the conditional independence tests; an edge is
    /// removed when a test yields a p-value above this threshold.
    pub alpha: f64,
    /// Largest conditioning set size that is tried.
    pub max_cond_set: usize,
}
186
187impl Default for PCAlgorithm {
188    fn default() -> Self {
189        Self {
190            alpha: 0.05,
191            max_cond_set: 3,
192        }
193    }
194}
195
196impl PCAlgorithm {
197    /// Create a new PCAlgorithm.
198    pub fn new(alpha: f64, max_cond_set: usize) -> Self {
199        Self {
200            alpha,
201            max_cond_set,
202        }
203    }
204
205    /// Learn the DAG from continuous data using Fisher's z test.
206    pub fn fit(&self, data: &[Vec<f64>]) -> Result<DAG, StatsError> {
207        if data.is_empty() {
208            return Err(StatsError::InvalidInput("Empty data".to_string()));
209        }
210        let n = data[0].len();
211        if n < 2 {
212            return Err(StatsError::InvalidInput(
213                "Need at least 2 variables".to_string(),
214            ));
215        }
216
217        // Phase 1: Skeleton learning
218        // Start with complete undirected graph
219        let mut adj: Vec<HashSet<usize>> = (0..n)
220            .map(|i| (0..n).filter(|&j| j != i).collect())
221            .collect();
222
223        // Separator sets: sep[i][j] = conditioning set that made i-j independent
224        let mut sep: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
225
226        let mut cond_size = 0usize;
227        loop {
228            let mut removed = false;
229            let edges: Vec<(usize, usize)> = (0..n)
230                .flat_map(|i| adj[i].iter().map(move |&j| (i, j)))
231                .filter(|&(i, j)| i < j)
232                .collect();
233
234            for (x, y) in edges {
235                if !adj[x].contains(&y) {
236                    continue;
237                }
238                // Get adjacents of x (excluding y)
239                let adj_x: Vec<usize> = adj[x].iter().copied().filter(|&v| v != y).collect();
240                if adj_x.len() < cond_size {
241                    continue;
242                }
243                // Enumerate conditioning sets of size `cond_size` from adj_x
244                for cond_set in subsets(&adj_x, cond_size) {
245                    let p = fisherz_test(data, x, y, &cond_set);
246                    if p > self.alpha {
247                        // Remove edge x-y
248                        adj[x].remove(&y);
249                        adj[y].remove(&x);
250                        sep.insert((x, y), cond_set.clone());
251                        sep.insert((y, x), cond_set);
252                        removed = true;
253                        break;
254                    }
255                }
256            }
257
258            cond_size += 1;
259            if !removed || cond_size > self.max_cond_set {
260                break;
261            }
262        }
263
264        // Phase 2: Orient v-structures
265        let mut dag = DAG::new(n);
266        // Edge types: None=undirected, Some(true)=oriented
267        // We use a different representation: directed edges stored in DAG
268        // First, detect and orient v-structures
269        let mut oriented: HashSet<(usize, usize)> = HashSet::new();
270
271        for b in 0..n {
272            let neighbors_b: Vec<usize> = adj[b].iter().copied().collect();
273            for (i, &a) in neighbors_b.iter().enumerate() {
274                for &c in &neighbors_b[(i + 1)..] {
275                    // Check if a and c are NOT adjacent
276                    if adj[a].contains(&c) {
277                        continue;
278                    }
279                    // Check if b is NOT in sep[a][c]
280                    let is_collider = sep.get(&(a, c)).map(|s| !s.contains(&b)).unwrap_or(true);
281                    if is_collider {
282                        // Orient a→b←c
283                        oriented.insert((a, b));
284                        oriented.insert((c, b));
285                    }
286                }
287            }
288        }
289
290        // Phase 3: Build DAG from oriented + remaining undirected edges
291        // Use topological ordering heuristic for remaining undirected edges
292        // Add oriented edges first
293        for &(from, to) in &oriented {
294            // Remove undirected representation: adj no longer needed for rest
295            let _ = dag.add_edge(from, to); // ignore cycle errors from conflicting orientations
296        }
297
298        // Orient remaining undirected edges respecting existing orientation
299        // Heuristic: orient consistently (avoid new v-structures, avoid cycles)
300        for x in 0..n {
301            for y in adj[x].iter().copied().collect::<Vec<_>>() {
302                if y <= x {
303                    continue;
304                }
305                if oriented.contains(&(x, y)) || oriented.contains(&(y, x)) {
306                    continue;
307                }
308                // Neither direction is oriented; try both
309                if dag.add_edge(x, y).is_ok() {
310                    // success
311                } else if dag.add_edge(y, x).is_ok() {
312                    // reversed
313                }
314            }
315        }
316
317        Ok(dag)
318    }
319
320    /// Test conditional independence between x and y given z in the data.
321    pub fn conditional_independence_test(
322        &self,
323        data: &[Vec<f64>],
324        x: usize,
325        y: usize,
326        z: &[usize],
327    ) -> bool {
328        fisherz_test(data, x, y, z) > self.alpha
329    }
330}
331
332// ---------------------------------------------------------------------------
333// HillClimbing
334// ---------------------------------------------------------------------------
335
/// Greedy score-based structure learner for discrete Bayesian Networks.
///
/// Starting from the empty graph, each iteration applies the single edge
/// operation (add, remove or reverse) with the largest BIC improvement and
/// stops when no operation improves the score. Recently applied operations
/// are kept in a tabu list and are not considered again while they remain on it.
#[derive(Debug, Clone)]
pub struct HillClimbing {
    /// Upper bound on the number of greedy iterations.
    pub max_iter: usize,
    /// Number of most recently applied operations kept in the tabu list.
    pub tabu_length: usize,
}
347
348impl Default for HillClimbing {
349    fn default() -> Self {
350        Self {
351            max_iter: 100,
352            tabu_length: 10,
353        }
354    }
355}
356
/// A single structural change considered by the hill-climbing search.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Operator {
    /// Insert the directed edge `from → to`.
    AddEdge(usize, usize),
    /// Delete the directed edge `from → to`.
    RemoveEdge(usize, usize),
    /// Replace the directed edge `from → to` with `to → from`.
    ReverseEdge(usize, usize),
}
364
365impl HillClimbing {
366    /// Create a new HillClimbing learner.
367    pub fn new(max_iter: usize, tabu_length: usize) -> Self {
368        Self {
369            max_iter,
370            tabu_length,
371        }
372    }
373
374    /// Learn the DAG structure from discrete data.
375    pub fn fit(&self, data: &[Vec<f64>], cards: &[usize]) -> Result<DAG, StatsError> {
376        if data.is_empty() {
377            return Err(StatsError::InvalidInput("Empty data".to_string()));
378        }
379        let n = data[0].len();
380        if cards.len() != n {
381            return Err(StatsError::InvalidInput(format!(
382                "cards length {} != n_vars {n}",
383                cards.len()
384            )));
385        }
386
387        let mut dag = DAG::new(n);
388        let mut current_score = BIC::score(data, &dag, cards);
389        let mut tabu: VecDeque<Operator> = VecDeque::new();
390
391        for _iter in 0..self.max_iter {
392            let mut best_op: Option<Operator> = None;
393            let mut best_delta = 0.0f64;
394
395            // Enumerate all operators
396            let ops = self.enumerate_operators(&dag, n);
397            for op in ops {
398                if tabu.contains(&op) {
399                    continue;
400                }
401                let new_dag = self.apply_op(&dag, &op);
402                if new_dag.is_none() {
403                    continue;
404                }
405                let new_dag = new_dag.expect("apply_op returned Some after is_none() check");
406                if !new_dag.is_dag() {
407                    continue;
408                }
409                let new_score = BIC::score(data, &new_dag, cards);
410                let delta = new_score - current_score;
411                if delta > best_delta {
412                    best_delta = delta;
413                    best_op = Some(op);
414                }
415            }
416
417            if let Some(op) = best_op {
418                let new_dag = self.apply_op(&dag, &op).expect(
419                    "apply_op with best_op guaranteed to succeed since it passed earlier checks",
420                );
421                current_score += best_delta;
422                dag = new_dag;
423                tabu.push_back(op);
424                if tabu.len() > self.tabu_length {
425                    tabu.pop_front();
426                }
427            } else {
428                break; // No improvement
429            }
430        }
431
432        Ok(dag)
433    }
434
435    fn enumerate_operators(&self, dag: &DAG, n: usize) -> Vec<Operator> {
436        let mut ops = Vec::new();
437        for i in 0..n {
438            for j in 0..n {
439                if i == j {
440                    continue;
441                }
442                if dag.has_edge(i, j) {
443                    ops.push(Operator::RemoveEdge(i, j));
444                    // Reverse: i→j becomes j→i
445                    ops.push(Operator::ReverseEdge(i, j));
446                } else if !dag.has_edge(j, i) {
447                    ops.push(Operator::AddEdge(i, j));
448                }
449            }
450        }
451        ops
452    }
453
454    fn apply_op(&self, dag: &DAG, op: &Operator) -> Option<DAG> {
455        let mut new_dag = dag.clone();
456        match op {
457            Operator::AddEdge(i, j) => {
458                new_dag.add_edge(*i, *j).ok()?;
459            }
460            Operator::RemoveEdge(i, j) => {
461                new_dag.remove_edge(*i, *j);
462            }
463            Operator::ReverseEdge(i, j) => {
464                new_dag.remove_edge(*i, *j);
465                new_dag.add_edge(*j, *i).ok()?;
466            }
467        }
468        Some(new_dag)
469    }
470}
471
472// ---------------------------------------------------------------------------
473// BIC Score
474// ---------------------------------------------------------------------------
475
/// Bayesian Information Criterion for discrete Bayesian Networks.
///
/// BIC = log-likelihood - (k/2) * log(n)
/// where k = number of free parameters and n = sample size.
///
/// This is a stateless marker type; all functionality is exposed through
/// associated functions.
#[derive(Debug, Clone, Copy)]
pub struct BIC;
481
482impl BIC {
483    /// Compute the BIC score of a DAG given data and cardinalities.
484    pub fn score(data: &[Vec<f64>], dag: &DAG, cards: &[usize]) -> f64 {
485        let n_samples = data.len() as f64;
486        if n_samples < 1.0 {
487            return f64::NEG_INFINITY;
488        }
489        let n = dag.n_nodes;
490        let mut bic = 0.0f64;
491        for node in 0..n {
492            bic += Self::node_score(data, dag, node, cards, n_samples);
493        }
494        bic
495    }
496
497    fn node_score(
498        data: &[Vec<f64>],
499        dag: &DAG,
500        node: usize,
501        cards: &[usize],
502        n_samples: f64,
503    ) -> f64 {
504        let card_node = cards[node];
505        let parents = &dag.parents[node];
506        let parent_cards: Vec<usize> = parents.iter().map(|&p| cards[p]).collect();
507        let n_parent_configs: usize = if parent_cards.is_empty() {
508            1
509        } else {
510            parent_cards.iter().product()
511        };
512        // Count occurrences N[pa_config][node_val]
513        let mut counts = vec![vec![0u64; card_node]; n_parent_configs];
514        let mut pa_counts = vec![0u64; n_parent_configs];
515
516        for row in data {
517            let node_val = (row[node].round() as usize).min(card_node - 1);
518            let pa_config = if parents.is_empty() {
519                0
520            } else {
521                Self::config_index(row, parents, &parent_cards)
522            };
523            if pa_config < n_parent_configs && node_val < card_node {
524                counts[pa_config][node_val] += 1;
525                pa_counts[pa_config] += 1;
526            }
527        }
528
529        // Log-likelihood contribution
530        let mut ll = 0.0f64;
531        for pa in 0..n_parent_configs {
532            let pa_count = pa_counts[pa] as f64;
533            if pa_count < 1.0 {
534                continue;
535            }
536            for val in 0..card_node {
537                let c = counts[pa][val] as f64;
538                if c > 0.0 {
539                    ll += c * (c / pa_count).ln();
540                }
541            }
542        }
543
544        // Penalty: k = (card_node - 1) * n_parent_configs
545        let k = (card_node - 1) * n_parent_configs;
546        ll - 0.5 * k as f64 * n_samples.ln()
547    }
548
549    fn config_index(row: &[f64], parents: &[usize], parent_cards: &[usize]) -> usize {
550        let mut idx = 0usize;
551        let mut stride = 1usize;
552        for (i, &p) in parents.iter().enumerate().rev() {
553            let val = (row[p].round() as usize).min(parent_cards[i] - 1);
554            idx += val * stride;
555            stride *= parent_cards[i];
556        }
557        idx
558    }
559
560    /// Build a TabularCPD by MLE from data for a node given its parents.
561    pub fn mle_cpd(
562        data: &[Vec<f64>],
563        node: usize,
564        parents: &[usize],
565        cards: &[usize],
566    ) -> Result<TabularCPD, StatsError> {
567        let card_node = cards[node];
568        let parent_indices = parents.to_vec();
569        let parent_cards: Vec<usize> = parents.iter().map(|&p| cards[p]).collect();
570        let n_rows: usize = if parent_cards.is_empty() {
571            1
572        } else {
573            parent_cards.iter().product()
574        };
575
576        let mut counts = vec![vec![0u64; card_node]; n_rows];
577
578        for row in data {
579            let node_val = (row[node].round() as usize).min(card_node - 1);
580            let pa_config = if parents.is_empty() {
581                0
582            } else {
583                let parent_cards_local = parent_cards.clone();
584                let mut idx = 0usize;
585                let mut stride = 1usize;
586                for (i, &p) in parents.iter().enumerate().rev() {
587                    let val = (row[p].round() as usize).min(parent_cards_local[i] - 1);
588                    idx += val * stride;
589                    stride *= parent_cards_local[i];
590                }
591                idx
592            };
593            if pa_config < n_rows && node_val < card_node {
594                counts[pa_config][node_val] += 1;
595            }
596        }
597
598        // Normalize (with Laplace smoothing to avoid zeros)
599        let alpha = 1.0f64; // pseudocount
600        let table: Vec<Vec<f64>> = counts
601            .iter()
602            .map(|row_counts| {
603                let total = row_counts.iter().sum::<u64>() as f64 + alpha * card_node as f64;
604                row_counts
605                    .iter()
606                    .map(|&c| (c as f64 + alpha) / total)
607                    .collect()
608            })
609            .collect();
610
611        TabularCPD::new(node, card_node, parent_indices, parent_cards, table)
612    }
613}
614
615// ---------------------------------------------------------------------------
616// Helper: enumerate subsets of size k
617// ---------------------------------------------------------------------------
618
/// Enumerate all `k`-element subsets of `items`.
///
/// Subsets keep the relative order of `items` and are produced in
/// lexicographic order of the chosen positions. `k == 0` yields exactly one
/// empty subset; `k > items.len()` yields none.
fn subsets<T: Copy>(items: &[T], k: usize) -> Vec<Vec<T>> {
    /// Extend `prefix` by `remaining` more elements drawn from `pool`,
    /// pushing every completed subset onto `out`.
    fn extend<T: Copy>(
        pool: &[T],
        remaining: usize,
        prefix: &mut Vec<T>,
        out: &mut Vec<Vec<T>>,
    ) {
        if remaining == 0 {
            out.push(prefix.clone());
            return;
        }
        if remaining > pool.len() {
            return;
        }
        // The first chosen element must leave at least `remaining - 1` behind it.
        for start in 0..=(pool.len() - remaining) {
            prefix.push(pool[start]);
            extend(&pool[start + 1..], remaining - 1, prefix, out);
            prefix.pop();
        }
    }

    let mut out = Vec::new();
    let mut prefix = Vec::new();
    extend(items, k, &mut prefix, &mut out);
    out
}
635
636// ---------------------------------------------------------------------------
637// Unit tests
638// ---------------------------------------------------------------------------
639
#[cfg(test)]
mod tests {
    use super::*;

    /// Samples from the linear-Gaussian chain X0 → X1 → X2.
    ///
    /// Uses a fixed-seed LCG with a Box-Muller transform so that the data is
    /// identical on every run.
    fn continuous_chain_data(n: usize) -> Vec<Vec<f64>> {
        let mut data = Vec::with_capacity(n);
        let mut lcg: u64 = 54321;
        let mut normal = || -> f64 {
            lcg = lcg
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            let u = (lcg >> 12) as f64 / (1u64 << 52) as f64;
            lcg = lcg
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Keep v away from zero so that ln(v) stays finite.
            let v = ((lcg >> 12) as f64 / (1u64 << 52) as f64).max(1e-15);
            (-2.0 * v.ln()).sqrt() * (2.0 * std::f64::consts::PI * u).cos()
        };
        for _ in 0..n {
            let x0 = normal();
            let x1 = 0.8 * x0 + 0.5 * normal();
            let x2 = 0.8 * x1 + 0.5 * normal();
            data.push(vec![x0, x1, x2]);
        }
        data
    }

    /// Binary samples where X0 is a fair coin and X1 equals X0 with
    /// probability 0.8 (fixed-seed LCG, identical on every run).
    fn discrete_data(n: usize) -> Vec<Vec<f64>> {
        let mut data = Vec::with_capacity(n);
        let mut lcg: u64 = 99887;
        let mut uniform = || -> f64 {
            lcg = lcg
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (lcg >> 11) as f64 / (1u64 << 53) as f64
        };
        for _ in 0..n {
            let x0 = if uniform() < 0.5 { 0.0 } else { 1.0 };
            // P(X1 = 0 | X0): 0.8 when X0 = 0, 0.2 when X0 = 1.
            let p_zero = if x0 == 0.0 { 0.8 } else { 0.2 };
            let x1 = if uniform() < p_zero { 0.0 } else { 1.0 };
            data.push(vec![x0, x1]);
        }
        data
    }

    #[test]
    fn test_pc_algorithm_chain() {
        let data = continuous_chain_data(200);
        let pc = PCAlgorithm {
            alpha: 0.05,
            max_cond_set: 2,
        };
        let dag = pc.fit(&data).unwrap();
        assert_eq!(dag.n_nodes, 3);
        // At minimum some edges should be learned
        assert!(dag.n_edges() > 0, "PC should learn at least one edge");
    }

    #[test]
    fn test_pc_independence_test() {
        let data = continuous_chain_data(500);
        let pc = PCAlgorithm::default();
        // X0 ⊥ X2 | X1 in a chain
        let indep = pc.conditional_independence_test(&data, 0, 2, &[1]);
        assert!(
            indep,
            "X0 and X2 should be conditionally independent given X1"
        );
        // X0 is NOT independent of X1 marginally
        let dep = pc.conditional_independence_test(&data, 0, 1, &[]);
        assert!(!dep, "X0 and X1 should be dependent marginally");
    }

    #[test]
    fn test_hill_climbing_discrete() {
        let data = discrete_data(200);
        let cards = count_cardinalities(&data);
        let hc = HillClimbing::default();
        let dag = hc.fit(&data, &cards).unwrap();
        assert_eq!(dag.n_nodes, 2);
        // The two variables are strongly dependent, so exactly one edge
        // (in either direction) must be learned.
        assert_eq!(dag.n_edges(), 1, "dependent pair should be connected");
    }

    #[test]
    fn test_bic_score() {
        let data = discrete_data(100);
        let cards = count_cardinalities(&data);
        let dag_empty = DAG::new(2);
        let mut dag_edge = DAG::new(2);
        dag_edge.add_edge(0, 1).unwrap();
        let score_empty = BIC::score(&data, &dag_empty, &cards);
        let score_edge = BIC::score(&data, &dag_edge, &cards);
        assert!(score_empty.is_finite() && score_edge.is_finite());
        // The likelihood gain from modelling the dependence must outweigh
        // the penalty for the extra parameter.
        assert!(
            score_edge > score_empty,
            "BIC edge={score_edge}, BIC empty={score_empty}"
        );
    }

    #[test]
    fn test_mle_cpd() {
        let data = discrete_data(200);
        let cards = count_cardinalities(&data);
        let cpd = BIC::mle_cpd(&data, 0, &[], &cards).unwrap();
        let sum: f64 = cpd.table[0].iter().sum();
        assert!((sum - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_partial_correlation() {
        let data = continuous_chain_data(500);
        // Partial corr of 0 and 2 given 1 should be near 0
        let pc = partial_correlation(&data, 0, 2, &[1]);
        assert!(pc.abs() < 0.2, "Partial corr(X0,X2|X1) ≈ 0, got {pc}");
    }
}