1use super::{cpd::TabularCPD, dag::DAG};
9use crate::StatsError;
10use std::collections::{HashMap, HashSet, VecDeque};
11
/// Infers the number of states of each discrete variable from integer-coded data.
///
/// Every value is rounded to the nearest integer and treated as a state index, so a
/// column's cardinality is `1 + max(rounded value)`. Because Rust's float-to-int `as`
/// cast saturates, negative values and NaN map to state 0, and `+inf` maps to
/// `usize::MAX`. Every variable is reported with at least 2 states.
///
/// The number of variables is taken from the first row. Extra trailing entries in
/// longer rows are ignored, and shorter rows only contribute the columns they have.
/// Returns an empty vector for empty `data`.
pub fn count_cardinalities(data: &[Vec<f64>]) -> Vec<usize> {
    if data.is_empty() {
        return Vec::new();
    }
    let n_vars = data[0].len();
    let mut cards = vec![0usize; n_vars];
    for row in data {
        for (card, &val) in cards.iter_mut().zip(row) {
            // `saturating_add` keeps a saturated cast (`usize::MAX`) from overflowing.
            let states = (val.round() as usize).saturating_add(1);
            *card = (*card).max(states);
        }
    }
    cards.into_iter().map(|c| c.max(2)).collect()
}
34
/// Pearson correlation coefficient between columns `x` and `y` of `data`.
///
/// Uses population (1/n) moments; the normalisation cancels in the ratio. Returns
/// `0.0` when `data` is empty or either column has (near-)zero variance, so callers
/// treat degenerate columns as uncorrelated. The result is clamped to `[-1, 1]` to
/// absorb rounding error.
///
/// # Panics
/// Panics if `x` or `y` is out of bounds for some row.
fn sample_corr(data: &[Vec<f64>], x: usize, y: usize) -> f64 {
    // Without this guard every moment below would be 0/0 = NaN.
    if data.is_empty() {
        return 0.0;
    }
    let n = data.len() as f64;
    let mean_x = data.iter().map(|r| r[x]).sum::<f64>() / n;
    let mean_y = data.iter().map(|r| r[y]).sum::<f64>() / n;
    let cov: f64 = data
        .iter()
        .map(|r| (r[x] - mean_x) * (r[y] - mean_y))
        .sum::<f64>()
        / n;
    let var_x: f64 = data.iter().map(|r| (r[x] - mean_x).powi(2)).sum::<f64>() / n;
    let var_y: f64 = data.iter().map(|r| (r[y] - mean_y).powi(2)).sum::<f64>() / n;
    if var_x < 1e-15 || var_y < 1e-15 {
        return 0.0;
    }
    (cov / (var_x.sqrt() * var_y.sqrt())).clamp(-1.0, 1.0)
}
52
53pub fn partial_correlation(data: &[Vec<f64>], x: usize, y: usize, z: &[usize]) -> f64 {
57 if z.is_empty() {
58 return sample_corr(data, x, y);
59 }
60 let mut vars = vec![x, y];
62 vars.extend_from_slice(z);
63 vars.sort_unstable();
64 vars.dedup();
65 let idx_x = vars.iter().position(|&v| v == x).unwrap_or(0);
66 let idx_y = vars.iter().position(|&v| v == y).unwrap_or(0);
67 let m = vars.len();
68 let mut corr = vec![vec![0.0f64; m]; m];
70 for i in 0..m {
71 corr[i][i] = 1.0;
72 for j in (i + 1)..m {
73 let c = sample_corr(data, vars[i], vars[j]);
74 corr[i][j] = c;
75 corr[j][i] = c;
76 }
77 }
78 let inv = invert_matrix(&corr).unwrap_or_else(|| vec![vec![0.0; m]; m]);
80 let px = inv[idx_x][idx_x];
81 let py = inv[idx_y][idx_y];
82 let pxy = inv[idx_x][idx_y];
83 if px < 1e-15 || py < 1e-15 {
84 return 0.0;
85 }
86 (-pxy / (px * py).sqrt()).clamp(-1.0, 1.0)
87}
88
/// Inverts a square matrix by Gauss–Jordan elimination with partial pivoting.
///
/// Works on a copy, so `mat` is left untouched. Returns `None` when the chosen pivot of
/// some column has magnitude below `1e-15`, i.e. the matrix is singular to working
/// precision. An empty matrix inverts to an empty matrix.
///
/// # Panics
/// Panics if a row is shorter than the number of rows (non-square input).
fn invert_matrix(mat: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let size = mat.len();
    let mut lhs: Vec<Vec<f64>> = mat.to_vec();
    // Starts as the identity; the same row operations turn it into the inverse.
    let mut rhs: Vec<Vec<f64>> = (0..size)
        .map(|r| (0..size).map(|c| if r == c { 1.0 } else { 0.0 }).collect())
        .collect();

    for k in 0..size {
        // Partial pivoting: largest |entry| in column k at or below the diagonal.
        // `max_by` keeps the last of equal maxima; NaN comparisons count as ties.
        let best = (k..size).max_by(|&p, &q| {
            lhs[p][k]
                .abs()
                .partial_cmp(&lhs[q][k].abs())
                .unwrap_or(std::cmp::Ordering::Equal)
        })?;
        lhs.swap(k, best);
        rhs.swap(k, best);

        let pivot = lhs[k][k];
        if pivot.abs() < 1e-15 {
            return None;
        }
        for v in lhs[k][..size].iter_mut() {
            *v /= pivot;
        }
        for v in rhs[k][..size].iter_mut() {
            *v /= pivot;
        }

        // Snapshot the normalised pivot row so the other rows can be updated
        // through mutable borrows.
        let pivot_lhs = lhs[k][..size].to_vec();
        let pivot_rhs = rhs[k][..size].to_vec();
        for r in (0..size).filter(|&r| r != k) {
            let factor = lhs[r][k];
            for (cell, &p) in lhs[r][..size].iter_mut().zip(&pivot_lhs) {
                *cell -= factor * p;
            }
            for (cell, &p) in rhs[r][..size].iter_mut().zip(&pivot_rhs) {
                *cell -= factor * p;
            }
        }
    }
    Some(rhs)
}
133
134pub fn fisherz_test(data: &[Vec<f64>], x: usize, y: usize, z: &[usize]) -> f64 {
138 let n = data.len() as f64;
139 let r = partial_correlation(data, x, y, z);
140 let r_clamped = r.clamp(-1.0 + 1e-10, 1.0 - 1e-10);
141 let fisher_z = 0.5 * ((1.0 + r_clamped) / (1.0 - r_clamped)).ln();
142 let dof = (n - z.len() as f64 - 3.0).max(1.0);
143 let stat = fisher_z.abs() * dof.sqrt();
144 2.0 * normal_sf(stat)
147}
148
149fn normal_sf(x: f64) -> f64 {
151 0.5 * erfc_approx(x / std::f64::consts::SQRT_2)
152}
153
/// Complementary error function `erfc(x)` via Abramowitz & Stegun formula 7.1.26,
/// which has an absolute error of about 1.5e-7 or less.
///
/// The approximation is evaluated at `|x|`. Negative arguments use the reflection
/// `erfc(-x) = 2 - erfc(x)`.
fn erfc_approx(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // a1..a5 of A&S 7.1.26, in increasing power of t.
    const A: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];

    let t = 1.0 / (1.0 + P * x.abs());
    // Horner evaluation of a1 + a2·t + a3·t² + a4·t³ + a5·t⁴.
    let inner = A.iter().rev().fold(0.0, |acc, &a| a + t * acc);
    let tail = t * inner * (-x * x).exp();
    if x < 0.0 {
        2.0 - tail
    } else {
        tail
    }
}
168
/// Configuration for the PC constraint-based structure-learning algorithm, which uses
/// Fisher-z conditional-independence tests (see `PCAlgorithm::fit`).
#[derive(Debug, Clone)]
pub struct PCAlgorithm {
    /// Significance level. A conditional-independence test whose p-value exceeds
    /// `alpha` is taken as independence, and the corresponding edge is removed.
    pub alpha: f64,
    /// Largest conditioning-set size tried during the skeleton search.
    pub max_cond_set: usize,
}
186
187impl Default for PCAlgorithm {
188 fn default() -> Self {
189 Self {
190 alpha: 0.05,
191 max_cond_set: 3,
192 }
193 }
194}
195
196impl PCAlgorithm {
197 pub fn new(alpha: f64, max_cond_set: usize) -> Self {
199 Self {
200 alpha,
201 max_cond_set,
202 }
203 }
204
205 pub fn fit(&self, data: &[Vec<f64>]) -> Result<DAG, StatsError> {
207 if data.is_empty() {
208 return Err(StatsError::InvalidInput("Empty data".to_string()));
209 }
210 let n = data[0].len();
211 if n < 2 {
212 return Err(StatsError::InvalidInput(
213 "Need at least 2 variables".to_string(),
214 ));
215 }
216
217 let mut adj: Vec<HashSet<usize>> = (0..n)
220 .map(|i| (0..n).filter(|&j| j != i).collect())
221 .collect();
222
223 let mut sep: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
225
226 let mut cond_size = 0usize;
227 loop {
228 let mut removed = false;
229 let edges: Vec<(usize, usize)> = (0..n)
230 .flat_map(|i| adj[i].iter().map(move |&j| (i, j)))
231 .filter(|&(i, j)| i < j)
232 .collect();
233
234 for (x, y) in edges {
235 if !adj[x].contains(&y) {
236 continue;
237 }
238 let adj_x: Vec<usize> = adj[x].iter().copied().filter(|&v| v != y).collect();
240 if adj_x.len() < cond_size {
241 continue;
242 }
243 for cond_set in subsets(&adj_x, cond_size) {
245 let p = fisherz_test(data, x, y, &cond_set);
246 if p > self.alpha {
247 adj[x].remove(&y);
249 adj[y].remove(&x);
250 sep.insert((x, y), cond_set.clone());
251 sep.insert((y, x), cond_set);
252 removed = true;
253 break;
254 }
255 }
256 }
257
258 cond_size += 1;
259 if !removed || cond_size > self.max_cond_set {
260 break;
261 }
262 }
263
264 let mut dag = DAG::new(n);
266 let mut oriented: HashSet<(usize, usize)> = HashSet::new();
270
271 for b in 0..n {
272 let neighbors_b: Vec<usize> = adj[b].iter().copied().collect();
273 for (i, &a) in neighbors_b.iter().enumerate() {
274 for &c in &neighbors_b[(i + 1)..] {
275 if adj[a].contains(&c) {
277 continue;
278 }
279 let is_collider = sep.get(&(a, c)).map(|s| !s.contains(&b)).unwrap_or(true);
281 if is_collider {
282 oriented.insert((a, b));
284 oriented.insert((c, b));
285 }
286 }
287 }
288 }
289
290 for &(from, to) in &oriented {
294 let _ = dag.add_edge(from, to); }
297
298 for x in 0..n {
301 for y in adj[x].iter().copied().collect::<Vec<_>>() {
302 if y <= x {
303 continue;
304 }
305 if oriented.contains(&(x, y)) || oriented.contains(&(y, x)) {
306 continue;
307 }
308 if dag.add_edge(x, y).is_ok() {
310 } else if dag.add_edge(y, x).is_ok() {
312 }
314 }
315 }
316
317 Ok(dag)
318 }
319
320 pub fn conditional_independence_test(
322 &self,
323 data: &[Vec<f64>],
324 x: usize,
325 y: usize,
326 z: &[usize],
327 ) -> bool {
328 fisherz_test(data, x, y, z) > self.alpha
329 }
330}
331
/// Configuration for greedy, score-based structure learning that maximises the BIC
/// score with a tabu list (see `HillClimbing::fit`).
#[derive(Debug, Clone)]
pub struct HillClimbing {
    /// Maximum number of search iterations; each iteration applies at most one move.
    pub max_iter: usize,
    /// Maximum number of entries kept on the tabu list (oldest evicted first).
    /// Operators on the list are skipped during the search.
    pub tabu_length: usize,
}
347
348impl Default for HillClimbing {
349 fn default() -> Self {
350 Self {
351 max_iter: 100,
352 tabu_length: 10,
353 }
354 }
355}
356
/// A single-edge modification considered by [`HillClimbing`] search.
///
/// Each variant carries the `(from, to)` node indices of the edge it acts on. The type
/// is `Copy`, since it is just a tag and two indices.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Operator {
    /// Insert the edge `from → to`.
    AddEdge(usize, usize),
    /// Delete the existing edge `from → to`.
    RemoveEdge(usize, usize),
    /// Replace the existing edge `from → to` with `to → from`.
    ReverseEdge(usize, usize),
}
364
365impl HillClimbing {
366 pub fn new(max_iter: usize, tabu_length: usize) -> Self {
368 Self {
369 max_iter,
370 tabu_length,
371 }
372 }
373
374 pub fn fit(&self, data: &[Vec<f64>], cards: &[usize]) -> Result<DAG, StatsError> {
376 if data.is_empty() {
377 return Err(StatsError::InvalidInput("Empty data".to_string()));
378 }
379 let n = data[0].len();
380 if cards.len() != n {
381 return Err(StatsError::InvalidInput(format!(
382 "cards length {} != n_vars {n}",
383 cards.len()
384 )));
385 }
386
387 let mut dag = DAG::new(n);
388 let mut current_score = BIC::score(data, &dag, cards);
389 let mut tabu: VecDeque<Operator> = VecDeque::new();
390
391 for _iter in 0..self.max_iter {
392 let mut best_op: Option<Operator> = None;
393 let mut best_delta = 0.0f64;
394
395 let ops = self.enumerate_operators(&dag, n);
397 for op in ops {
398 if tabu.contains(&op) {
399 continue;
400 }
401 let new_dag = self.apply_op(&dag, &op);
402 if new_dag.is_none() {
403 continue;
404 }
405 let new_dag = new_dag.expect("apply_op returned Some after is_none() check");
406 if !new_dag.is_dag() {
407 continue;
408 }
409 let new_score = BIC::score(data, &new_dag, cards);
410 let delta = new_score - current_score;
411 if delta > best_delta {
412 best_delta = delta;
413 best_op = Some(op);
414 }
415 }
416
417 if let Some(op) = best_op {
418 let new_dag = self.apply_op(&dag, &op).expect(
419 "apply_op with best_op guaranteed to succeed since it passed earlier checks",
420 );
421 current_score += best_delta;
422 dag = new_dag;
423 tabu.push_back(op);
424 if tabu.len() > self.tabu_length {
425 tabu.pop_front();
426 }
427 } else {
428 break; }
430 }
431
432 Ok(dag)
433 }
434
435 fn enumerate_operators(&self, dag: &DAG, n: usize) -> Vec<Operator> {
436 let mut ops = Vec::new();
437 for i in 0..n {
438 for j in 0..n {
439 if i == j {
440 continue;
441 }
442 if dag.has_edge(i, j) {
443 ops.push(Operator::RemoveEdge(i, j));
444 ops.push(Operator::ReverseEdge(i, j));
446 } else if !dag.has_edge(j, i) {
447 ops.push(Operator::AddEdge(i, j));
448 }
449 }
450 }
451 ops
452 }
453
454 fn apply_op(&self, dag: &DAG, op: &Operator) -> Option<DAG> {
455 let mut new_dag = dag.clone();
456 match op {
457 Operator::AddEdge(i, j) => {
458 new_dag.add_edge(*i, *j).ok()?;
459 }
460 Operator::RemoveEdge(i, j) => {
461 new_dag.remove_edge(*i, *j);
462 }
463 Operator::ReverseEdge(i, j) => {
464 new_dag.remove_edge(*i, *j);
465 new_dag.add_edge(*j, *i).ok()?;
466 }
467 }
468 Some(new_dag)
469 }
470}
471
/// Bayesian Information Criterion scorer for discrete Bayesian networks.
///
/// Stateless marker type. All functionality lives in associated functions such as
/// `BIC::score` and `BIC::mle_cpd`. Data values are interpreted as state indices by
/// rounding.
pub struct BIC;
481
482impl BIC {
483 pub fn score(data: &[Vec<f64>], dag: &DAG, cards: &[usize]) -> f64 {
485 let n_samples = data.len() as f64;
486 if n_samples < 1.0 {
487 return f64::NEG_INFINITY;
488 }
489 let n = dag.n_nodes;
490 let mut bic = 0.0f64;
491 for node in 0..n {
492 bic += Self::node_score(data, dag, node, cards, n_samples);
493 }
494 bic
495 }
496
497 fn node_score(
498 data: &[Vec<f64>],
499 dag: &DAG,
500 node: usize,
501 cards: &[usize],
502 n_samples: f64,
503 ) -> f64 {
504 let card_node = cards[node];
505 let parents = &dag.parents[node];
506 let parent_cards: Vec<usize> = parents.iter().map(|&p| cards[p]).collect();
507 let n_parent_configs: usize = if parent_cards.is_empty() {
508 1
509 } else {
510 parent_cards.iter().product()
511 };
512 let mut counts = vec![vec![0u64; card_node]; n_parent_configs];
514 let mut pa_counts = vec![0u64; n_parent_configs];
515
516 for row in data {
517 let node_val = (row[node].round() as usize).min(card_node - 1);
518 let pa_config = if parents.is_empty() {
519 0
520 } else {
521 Self::config_index(row, parents, &parent_cards)
522 };
523 if pa_config < n_parent_configs && node_val < card_node {
524 counts[pa_config][node_val] += 1;
525 pa_counts[pa_config] += 1;
526 }
527 }
528
529 let mut ll = 0.0f64;
531 for pa in 0..n_parent_configs {
532 let pa_count = pa_counts[pa] as f64;
533 if pa_count < 1.0 {
534 continue;
535 }
536 for val in 0..card_node {
537 let c = counts[pa][val] as f64;
538 if c > 0.0 {
539 ll += c * (c / pa_count).ln();
540 }
541 }
542 }
543
544 let k = (card_node - 1) * n_parent_configs;
546 ll - 0.5 * k as f64 * n_samples.ln()
547 }
548
549 fn config_index(row: &[f64], parents: &[usize], parent_cards: &[usize]) -> usize {
550 let mut idx = 0usize;
551 let mut stride = 1usize;
552 for (i, &p) in parents.iter().enumerate().rev() {
553 let val = (row[p].round() as usize).min(parent_cards[i] - 1);
554 idx += val * stride;
555 stride *= parent_cards[i];
556 }
557 idx
558 }
559
560 pub fn mle_cpd(
562 data: &[Vec<f64>],
563 node: usize,
564 parents: &[usize],
565 cards: &[usize],
566 ) -> Result<TabularCPD, StatsError> {
567 let card_node = cards[node];
568 let parent_indices = parents.to_vec();
569 let parent_cards: Vec<usize> = parents.iter().map(|&p| cards[p]).collect();
570 let n_rows: usize = if parent_cards.is_empty() {
571 1
572 } else {
573 parent_cards.iter().product()
574 };
575
576 let mut counts = vec![vec![0u64; card_node]; n_rows];
577
578 for row in data {
579 let node_val = (row[node].round() as usize).min(card_node - 1);
580 let pa_config = if parents.is_empty() {
581 0
582 } else {
583 let parent_cards_local = parent_cards.clone();
584 let mut idx = 0usize;
585 let mut stride = 1usize;
586 for (i, &p) in parents.iter().enumerate().rev() {
587 let val = (row[p].round() as usize).min(parent_cards_local[i] - 1);
588 idx += val * stride;
589 stride *= parent_cards_local[i];
590 }
591 idx
592 };
593 if pa_config < n_rows && node_val < card_node {
594 counts[pa_config][node_val] += 1;
595 }
596 }
597
598 let alpha = 1.0f64; let table: Vec<Vec<f64>> = counts
601 .iter()
602 .map(|row_counts| {
603 let total = row_counts.iter().sum::<u64>() as f64 + alpha * card_node as f64;
604 row_counts
605 .iter()
606 .map(|&c| (c as f64 + alpha) / total)
607 .collect()
608 })
609 .collect();
610
611 TabularCPD::new(node, card_node, parent_indices, parent_cards, table)
612 }
613}
614
/// All `k`-element combinations of `items`.
///
/// Each combination keeps the elements in their original relative order. Combinations
/// are listed in lexicographic order of their indices. `k == 0` yields a single empty
/// combination, and `k > items.len()` yields none.
fn subsets<T: Copy>(items: &[T], k: usize) -> Vec<Vec<T>> {
    let n = items.len();
    if k > n {
        return Vec::new();
    }
    let mut picks: Vec<usize> = (0..k).collect();
    let mut combos: Vec<Vec<T>> = Vec::new();
    loop {
        combos.push(picks.iter().map(|&i| items[i]).collect());
        // Advance the rightmost slot that can still move (slot p may reach n - k + p),
        // then pack every later slot directly after it.
        match (0..k).rev().find(|&p| picks[p] < n - k + p) {
            Some(p) => {
                picks[p] += 1;
                let start = picks[p];
                for (offset, slot) in picks[(p + 1)..].iter_mut().enumerate() {
                    *slot = start + offset + 1;
                }
            }
            None => break,
        }
    }
    combos
}
635
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic samples from the linear-Gaussian chain X0 → X1 → X2.
    /// The normals come from Box–Muller over a 64-bit LCG.
    fn continuous_chain_data(n: usize) -> Vec<Vec<f64>> {
        let mut data = Vec::with_capacity(n);
        let mut lcg: u64 = 54321;
        let mut normal = || -> f64 {
            lcg = lcg
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            let u = (lcg >> 12) as f64 / (1u64 << 52) as f64;
            lcg = lcg
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            let v = ((lcg >> 12) as f64 / (1u64 << 52) as f64).max(1e-15);
            (-2.0 * v.ln()).sqrt() * (2.0 * std::f64::consts::PI * u).cos()
        };
        for _ in 0..n {
            let x0 = normal();
            let x1 = 0.8 * x0 + 0.5 * normal();
            let x2 = 0.8 * x1 + 0.5 * normal();
            data.push(vec![x0, x1, x2]);
        }
        data
    }

    /// Deterministic binary samples where X0 is a fair coin and X1 agrees with X0 with
    /// probability 0.8.
    fn discrete_data(n: usize) -> Vec<Vec<f64>> {
        let mut data = Vec::with_capacity(n);
        let mut lcg: u64 = 99887;
        let mut uniform = || -> f64 {
            lcg = lcg
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            (lcg >> 11) as f64 / (1u64 << 53) as f64
        };
        for _ in 0..n {
            let x0 = if uniform() < 0.5 { 0.0 } else { 1.0 };
            let x1 = if x0 == 0.0 {
                if uniform() < 0.8 {
                    0.0
                } else {
                    1.0
                }
            } else if uniform() < 0.2 {
                0.0
            } else {
                1.0
            };
            data.push(vec![x0, x1]);
        }
        data
    }

    #[test]
    fn test_pc_algorithm_chain() {
        let data = continuous_chain_data(200);
        let pc = PCAlgorithm {
            alpha: 0.05,
            max_cond_set: 2,
        };
        let dag = pc.fit(&data).unwrap();
        assert_eq!(dag.n_nodes, 3);
        assert!(dag.n_edges() > 0, "PC should learn at least one edge");
    }

    #[test]
    fn test_pc_independence_test() {
        let data = continuous_chain_data(500);
        let pc = PCAlgorithm::default();
        let indep = pc.conditional_independence_test(&data, 0, 2, &[1]);
        assert!(
            indep,
            "X0 and X2 should be conditionally independent given X1"
        );
        let dep = pc.conditional_independence_test(&data, 0, 1, &[]);
        assert!(!dep, "X0 and X1 should be dependent marginally");
    }

    #[test]
    fn test_hill_climbing_discrete() {
        let data = discrete_data(200);
        let cards = count_cardinalities(&data);
        let hc = HillClimbing::default();
        let dag = hc.fit(&data, &cards).unwrap();
        assert_eq!(dag.n_nodes, 2);
        // X0 and X1 are strongly dependent, so exactly one directed edge is learned.
        assert_eq!(dag.n_edges(), 1);
    }

    #[test]
    fn test_bic_score() {
        let data = discrete_data(100);
        let cards = count_cardinalities(&data);
        let dag_empty = DAG::new(2);
        let mut dag_edge = DAG::new(2);
        dag_edge.add_edge(0, 1).unwrap();
        let score_empty = BIC::score(&data, &dag_empty, &cards);
        let score_edge = BIC::score(&data, &dag_edge, &cards);
        assert!(score_empty.is_finite() && score_edge.is_finite());
        // The likelihood gain of the X0 → X1 edge far exceeds the penalty for its one
        // extra free parameter.
        assert!(
            score_edge > score_empty,
            "BIC edge={score_edge}, BIC empty={score_empty}"
        );
    }

    #[test]
    fn test_mle_cpd() {
        let data = discrete_data(200);
        let cards = count_cardinalities(&data);
        let cpd = BIC::mle_cpd(&data, 0, &[], &cards).unwrap();
        let sum: f64 = cpd.table[0].iter().sum();
        assert!((sum - 1.0).abs() < 1e-9);

        // One row per parent state, each a probability distribution over node states.
        let cond = BIC::mle_cpd(&data, 1, &[0], &cards).unwrap();
        assert_eq!(cond.table.len(), 2);
        for row in &cond.table {
            let row_sum: f64 = row.iter().sum();
            assert!((row_sum - 1.0).abs() < 1e-9);
        }
    }

    #[test]
    fn test_partial_correlation() {
        let data = continuous_chain_data(500);
        let pc = partial_correlation(&data, 0, 2, &[1]);
        assert!(pc.abs() < 0.2, "Partial corr(X0,X2|X1) ≈ 0, got {pc}");
    }

    #[test]
    fn test_count_cardinalities() {
        assert_eq!(
            count_cardinalities(&[vec![0.0, 2.0], vec![1.0, 0.4]]),
            vec![2, 3]
        );
        assert!(count_cardinalities(&[]).is_empty());
    }

    #[test]
    fn test_subsets() {
        assert_eq!(
            subsets(&[1, 2, 3], 2),
            vec![vec![1, 2], vec![1, 3], vec![2, 3]]
        );
        assert_eq!(subsets(&[1, 2, 3], 0), vec![Vec::<i32>::new()]);
        assert!(subsets(&[1, 2], 3).is_empty());
    }

    #[test]
    fn test_invert_matrix() {
        let inv = invert_matrix(&[vec![4.0, 7.0], vec![2.0, 6.0]]).unwrap();
        let expected = [[0.6, -0.7], [-0.2, 0.4]];
        for i in 0..2 {
            for j in 0..2 {
                assert!((inv[i][j] - expected[i][j]).abs() < 1e-12);
            }
        }
        assert!(invert_matrix(&[vec![1.0, 2.0], vec![2.0, 4.0]]).is_none());
    }

    #[test]
    fn test_erfc_approx() {
        assert!((erfc_approx(0.0) - 1.0).abs() < 1e-6);
        assert!((erfc_approx(1.0) - 0.157_299_207_050_285).abs() < 2e-7);
        assert!((erfc_approx(-1.0) - 1.842_700_792_949_715).abs() < 2e-7);
    }
}