1use std::collections::{HashMap, HashSet, VecDeque};
36
37use scirs2_core::ndarray::ArrayView2;
38
39use super::conditional_independence::{ConditionalIndependenceTest, PartialCorrelationTest};
40use super::pc_algorithm::subsets;
41use super::{CausalGraph, EdgeMark};
42use crate::error::{StatsError, StatsResult};
43
/// Configuration for an FCI (Fast Causal Inference) run.
///
/// All fields are public, so a configuration can be built with a struct
/// literal as well as through the constructors.
#[derive(Debug, Clone, PartialEq)]
pub struct FciAlgorithm {
    /// Significance level handed to the conditional-independence test.
    pub alpha: f64,
    /// Largest conditioning-set size tried during skeleton discovery.
    pub max_cond_set_size: usize,
    /// Largest conditioning-set size tried during the Possible-D-SEP phase.
    pub max_pdsep_size: usize,
}
58
59impl Default for FciAlgorithm {
60 fn default() -> Self {
61 Self {
62 alpha: 0.05,
63 max_cond_set_size: 4,
64 max_pdsep_size: 4,
65 }
66 }
67}
68
69#[derive(Debug, Clone)]
71pub struct FciResult {
72 pub graph: CausalGraph,
74 pub sep_sets: HashMap<(usize, usize), Vec<usize>>,
76 pub n_tests: usize,
78 pub has_latent_confounders: bool,
80}
81
82impl FciAlgorithm {
83 pub fn new(alpha: f64) -> Self {
85 Self {
86 alpha,
87 ..Default::default()
88 }
89 }
90
91 pub fn with_params(alpha: f64, max_cond_set_size: usize, max_pdsep_size: usize) -> Self {
93 Self {
94 alpha,
95 max_cond_set_size,
96 max_pdsep_size,
97 }
98 }
99
100 pub fn fit(&self, data: ArrayView2<f64>, var_names: &[&str]) -> StatsResult<FciResult> {
102 let ci_test = PartialCorrelationTest::new(self.alpha);
103 self.fit_with_test(data, var_names, &ci_test)
104 }
105
106 pub fn fit_with_test<T: ConditionalIndependenceTest>(
108 &self,
109 data: ArrayView2<f64>,
110 var_names: &[&str],
111 ci_test: &T,
112 ) -> StatsResult<FciResult> {
113 let p = data.ncols();
114 if var_names.len() != p {
115 return Err(StatsError::DimensionMismatch(
116 "var_names length must match number of columns".to_owned(),
117 ));
118 }
119 if p == 0 {
120 return Ok(FciResult {
121 graph: CausalGraph::new(var_names),
122 sep_sets: HashMap::new(),
123 n_tests: 0,
124 has_latent_confounders: false,
125 });
126 }
127
128 let (mut adj, mut sep_sets, mut n_tests) =
130 skeleton_discovery(data, p, self.alpha, self.max_cond_set_size, ci_test)?;
131
132 let mut graph = CausalGraph::new(var_names);
134 for i in 0..p {
136 for j in (i + 1)..p {
137 if adj[i][j] {
138 graph.set_edge(i, j, EdgeMark::Circle, EdgeMark::Circle);
139 }
140 }
141 }
142 orient_unshielded_colliders(&mut graph, &adj, &sep_sets, p);
143
144 let pdsep_removals = possible_dsep_phase(
146 &graph,
147 data,
148 &adj,
149 p,
150 self.alpha,
151 self.max_pdsep_size,
152 ci_test,
153 &mut n_tests,
154 )?;
155
156 for (x, y, z_set) in pdsep_removals {
158 adj[x][y] = false;
159 adj[y][x] = false;
160 graph.remove_edge(x, y);
161 let key = (x.min(y), x.max(y));
162 sep_sets.insert(key, z_set);
163 }
164
165 for i in 0..p {
168 for j in (i + 1)..p {
169 if adj[i][j] {
170 graph.set_edge(i, j, EdgeMark::Circle, EdgeMark::Circle);
171 }
172 }
173 }
174 orient_unshielded_colliders(&mut graph, &adj, &sep_sets, p);
175
176 apply_fci_rules(&mut graph, &adj, &sep_sets, p);
178
179 let has_latent_confounders =
181 (0..p).any(|i| (0..p).any(|j| i != j && graph.is_bidirected(i, j)));
182
183 Ok(FciResult {
184 graph,
185 sep_sets,
186 n_tests,
187 has_latent_confounders,
188 })
189 }
190}
191
192fn skeleton_discovery<T: ConditionalIndependenceTest>(
198 data: ArrayView2<f64>,
199 p: usize,
200 alpha: f64,
201 max_cond_set_size: usize,
202 ci_test: &T,
203) -> StatsResult<(Vec<Vec<bool>>, HashMap<(usize, usize), Vec<usize>>, usize)> {
204 let mut adj = vec![vec![true; p]; p];
205 for i in 0..p {
206 adj[i][i] = false;
207 }
208 let mut sep_sets: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
209 let mut n_tests = 0usize;
210
211 for ord in 0..=max_cond_set_size {
212 let adj_snapshot = adj.clone();
213 let edges: Vec<(usize, usize)> = (0..p)
214 .flat_map(|i| ((i + 1)..p).map(move |j| (i, j)))
215 .filter(|&(i, j)| adj_snapshot[i][j])
216 .collect();
217
218 let mut removals = Vec::new();
219
220 for (x, y) in edges {
221 let z_x: Vec<usize> = (0..p)
222 .filter(|&k| k != x && k != y && adj_snapshot[x][k])
223 .collect();
224 let z_y: Vec<usize> = (0..p)
225 .filter(|&k| k != x && k != y && adj_snapshot[y][k])
226 .collect();
227
228 let mut found = false;
229 if z_x.len() >= ord {
230 for z_set in subsets(&z_x, ord) {
231 n_tests += 1;
232 if ci_test.is_independent(x, y, &z_set, data, alpha)? {
233 removals.push((x, y, z_set));
234 found = true;
235 break;
236 }
237 }
238 }
239 if !found && z_y.len() >= ord {
240 for z_set in subsets(&z_y, ord) {
241 n_tests += 1;
242 if ci_test.is_independent(x, y, &z_set, data, alpha)? {
243 removals.push((x, y, z_set));
244 break;
245 }
246 }
247 }
248 }
249
250 for (x, y, z_set) in removals {
251 adj[x][y] = false;
252 adj[y][x] = false;
253 let key = (x.min(y), x.max(y));
254 sep_sets.insert(key, z_set);
255 }
256 }
257
258 Ok((adj, sep_sets, n_tests))
259}
260
261fn orient_unshielded_colliders(
268 graph: &mut CausalGraph,
269 adj: &[Vec<bool>],
270 sep_sets: &HashMap<(usize, usize), Vec<usize>>,
271 p: usize,
272) {
273 for z in 0..p {
274 let neighbours: Vec<usize> = (0..p).filter(|&k| k != z && adj[z][k]).collect();
275 for i in 0..neighbours.len() {
276 for j in (i + 1)..neighbours.len() {
277 let x = neighbours[i];
278 let y = neighbours[j];
279 if adj[x][y] {
280 continue; }
282 let key = (x.min(y), x.max(y));
283 let sep = sep_sets.get(&key).cloned().unwrap_or_default();
284 if !sep.contains(&z) {
285 let mark_xz_from = graph.get_mark_from(x, z).unwrap_or(EdgeMark::Circle);
287 graph.set_edge(x, z, mark_xz_from, EdgeMark::Arrow);
288 let mark_yz_from = graph.get_mark_from(y, z).unwrap_or(EdgeMark::Circle);
289 graph.set_edge(y, z, mark_yz_from, EdgeMark::Arrow);
290 }
291 }
292 }
293 }
294}
295
296fn possible_dsep(graph: &CausalGraph, a: usize, b: usize, p: usize) -> HashSet<usize> {
306 let mut pdsep = HashSet::new();
307 let mut visited = HashSet::new();
308 let mut queue = VecDeque::new();
309
310 for k in 0..p {
312 if k != a && k != b && graph.is_adjacent(a, k) {
313 queue.push_back((k, a)); }
315 }
316
317 while let Some((cur, prev)) = queue.pop_front() {
318 if !visited.insert((cur, prev)) {
319 continue;
320 }
321 pdsep.insert(cur);
322
323 for next in 0..p {
326 if next == prev || next == a || !graph.is_adjacent(cur, next) {
327 continue;
328 }
329 let mark_at_cur_from_prev = graph.get_mark_at(prev, cur);
334 let is_possible_collider = match mark_at_cur_from_prev {
335 Some(EdgeMark::Arrow) | Some(EdgeMark::Circle) => true,
336 _ => false,
337 };
338
339 if is_possible_collider {
340 queue.push_back((next, cur));
341 }
342 }
343 }
344
345 pdsep
346}
347
348fn possible_dsep_phase<T: ConditionalIndependenceTest>(
351 graph: &CausalGraph,
352 data: ArrayView2<f64>,
353 adj: &[Vec<bool>],
354 p: usize,
355 alpha: f64,
356 max_pdsep_size: usize,
357 ci_test: &T,
358 n_tests: &mut usize,
359) -> StatsResult<Vec<(usize, usize, Vec<usize>)>> {
360 let mut removals = Vec::new();
361
362 for x in 0..p {
363 for y in (x + 1)..p {
364 if !adj[x][y] {
365 continue;
366 }
367
368 let pdsep_x = possible_dsep(graph, x, y, p);
369 let pdsep_y = possible_dsep(graph, y, x, p);
370 let combined: Vec<usize> = pdsep_x
371 .union(&pdsep_y)
372 .copied()
373 .filter(|&k| k != x && k != y)
374 .collect();
375
376 if combined.is_empty() {
377 continue;
378 }
379
380 let max_size = max_pdsep_size.min(combined.len());
381 let mut found = false;
382 for ord in 0..=max_size {
383 if found {
384 break;
385 }
386 for z_set in subsets(&combined, ord) {
387 *n_tests += 1;
388 if ci_test.is_independent(x, y, &z_set, data, alpha)? {
389 removals.push((x, y, z_set));
390 found = true;
391 break;
392 }
393 }
394 }
395 }
396 }
397
398 Ok(removals)
399}
400
401fn apply_fci_rules(
409 graph: &mut CausalGraph,
410 adj: &[Vec<bool>],
411 sep_sets: &HashMap<(usize, usize), Vec<usize>>,
412 p: usize,
413) {
414 let max_iterations = p * p * 2 + 10;
415 let mut changed = true;
416 let mut iterations = 0;
417
418 while changed && iterations < max_iterations {
419 changed = false;
420 iterations += 1;
421
422 changed |= fci_r1(graph, p);
423 changed |= fci_r2(graph, p);
424 changed |= fci_r3(graph, adj, p);
425 changed |= fci_r4(graph, adj, sep_sets, p);
426 changed |= fci_r5(graph, adj, p);
427 changed |= fci_r6(graph, p);
428 changed |= fci_r7(graph, p);
429 changed |= fci_r8(graph, p);
430 changed |= fci_r9(graph, p);
431 changed |= fci_r10(graph, p);
432 }
433}
434
435fn fci_r1(graph: &mut CausalGraph, p: usize) -> bool {
437 let mut changed = false;
438 for b in 0..p {
439 for a in 0..p {
440 if a == b {
441 continue;
442 }
443 if graph.get_mark_at(a, b) != Some(EdgeMark::Arrow) {
445 continue;
446 }
447 for c in 0..p {
448 if c == a || c == b {
449 continue;
450 }
451 if !graph.is_adjacent(b, c) {
452 continue;
453 }
454 if graph.is_adjacent(a, c) {
455 continue;
456 }
457 if graph.get_mark_from(b, c) != Some(EdgeMark::Circle) {
459 continue;
460 }
461 let mark_at_c = graph.get_mark_at(b, c).unwrap_or(EdgeMark::Circle);
463 graph.set_edge(b, c, EdgeMark::Tail, mark_at_c);
464 changed = true;
465 }
466 }
467 }
468 changed
469}
470
471fn fci_r2(graph: &mut CausalGraph, p: usize) -> bool {
473 let mut changed = false;
474 for a in 0..p {
475 for c in 0..p {
476 if a == c || !graph.is_adjacent(a, c) {
477 continue;
478 }
479 if graph.get_mark_at(a, c) != Some(EdgeMark::Circle) {
481 continue;
482 }
483 for b in 0..p {
484 if b == a || b == c {
485 continue;
486 }
487 let case1 = graph.get_mark_from(a, b) == Some(EdgeMark::Tail)
489 && graph.get_mark_at(a, b) == Some(EdgeMark::Arrow)
490 && graph.get_mark_at(b, c) == Some(EdgeMark::Arrow);
491 let case2 = graph.get_mark_at(a, b) == Some(EdgeMark::Arrow)
493 && graph.get_mark_from(b, c) == Some(EdgeMark::Tail)
494 && graph.get_mark_at(b, c) == Some(EdgeMark::Arrow);
495
496 if case1 || case2 {
497 let mark_from_a = graph.get_mark_from(a, c).unwrap_or(EdgeMark::Circle);
499 graph.set_edge(a, c, mark_from_a, EdgeMark::Arrow);
500 changed = true;
501 break;
502 }
503 }
504 }
505 }
506 changed
507}
508
509fn fci_r3(graph: &mut CausalGraph, adj: &[Vec<bool>], p: usize) -> bool {
512 let mut changed = false;
513 for d in 0..p {
514 for b in 0..p {
515 if d == b || !graph.is_adjacent(d, b) {
516 continue;
517 }
518 if graph.get_mark_at(d, b) != Some(EdgeMark::Circle) {
520 continue;
521 }
522 let parents_b: Vec<usize> = (0..p)
524 .filter(|&k| {
525 k != b
526 && k != d
527 && graph.is_adjacent(k, b)
528 && graph.get_mark_at(k, b) == Some(EdgeMark::Arrow)
529 })
530 .collect();
531 let mut orient = false;
532 for i in 0..parents_b.len() {
533 for j in (i + 1)..parents_b.len() {
534 let a = parents_b[i];
535 let c = parents_b[j];
536 if adj[a][c] {
537 continue;
538 }
539 if !graph.is_adjacent(a, d) {
541 continue;
542 }
543 if graph.get_mark_at(a, d) != Some(EdgeMark::Circle) {
544 continue;
545 }
546 if !graph.is_adjacent(c, d) {
548 continue;
549 }
550 if graph.get_mark_at(c, d) != Some(EdgeMark::Circle) {
551 continue;
552 }
553 orient = true;
554 break;
555 }
556 if orient {
557 break;
558 }
559 }
560 if orient {
561 let mark_from = graph.get_mark_from(d, b).unwrap_or(EdgeMark::Circle);
562 graph.set_edge(d, b, mark_from, EdgeMark::Arrow);
563 changed = true;
564 }
565 }
566 }
567 changed
568}
569
570fn fci_r4(
576 graph: &mut CausalGraph,
577 _adj: &[Vec<bool>],
578 sep_sets: &HashMap<(usize, usize), Vec<usize>>,
579 p: usize,
580) -> bool {
581 let mut changed = false;
582 for c in 0..p {
584 for b in 0..p {
585 if b == c || !graph.is_adjacent(b, c) {
586 continue;
587 }
588 if graph.get_mark_at(b, c) != Some(EdgeMark::Arrow) {
589 continue;
590 }
591 if graph.get_mark_from(b, c) != Some(EdgeMark::Circle) {
593 continue;
594 }
595
596 for a in 0..p {
601 if a == b || a == c || !graph.is_adjacent(a, c) {
602 continue;
603 }
604 }
609
610 for a in 0..p {
614 if a == b || a == c {
615 continue;
616 }
617 if graph.is_adjacent(a, c) {
618 continue; }
620 if !graph.is_adjacent(a, b) {
621 continue;
622 }
623 if graph.get_mark_at(a, b) != Some(EdgeMark::Arrow) {
625 continue;
626 }
627
628 let key = (a.min(c), a.max(c));
630 let sep = sep_sets.get(&key).cloned().unwrap_or_default();
631
632 if sep.contains(&b) {
633 let mark_from_b = graph.get_mark_from(b, c).unwrap_or(EdgeMark::Circle);
635 let _mark_at_c = EdgeMark::Arrow;
636 graph.set_edge(b, c, EdgeMark::Tail, EdgeMark::Arrow);
638 let _ = mark_from_b;
639 } else {
640 graph.set_edge(b, c, EdgeMark::Arrow, EdgeMark::Arrow);
642 }
643 changed = true;
644 break;
645 }
646 }
647 }
648 changed
649}
650
651fn fci_r5(graph: &mut CausalGraph, _adj: &[Vec<bool>], p: usize) -> bool {
655 let mut changed = false;
656 for a in 0..p {
657 for b in (a + 1)..p {
658 if !graph.is_adjacent(a, b) {
659 continue;
660 }
661 if graph.get_mark_from(a, b) != Some(EdgeMark::Circle)
663 || graph.get_mark_at(a, b) != Some(EdgeMark::Circle)
664 {
665 continue;
666 }
667 if has_uncovered_circle_path(graph, a, b, p) {
669 graph.set_edge(a, b, EdgeMark::Tail, EdgeMark::Tail);
670 changed = true;
671 }
672 }
673 }
674 changed
675}
676
677fn fci_r6(graph: &mut CausalGraph, p: usize) -> bool {
679 let mut changed = false;
680 for b in 0..p {
681 for a in 0..p {
682 if a == b || !graph.is_adjacent(a, b) {
683 continue;
684 }
685 if graph.get_mark_from(a, b) != Some(EdgeMark::Tail)
687 || graph.get_mark_at(a, b) != Some(EdgeMark::Tail)
688 {
689 continue;
690 }
691 for c in 0..p {
692 if c == a || c == b || !graph.is_adjacent(b, c) {
693 continue;
694 }
695 if graph.get_mark_from(b, c) != Some(EdgeMark::Circle) {
697 continue;
698 }
699 let mark_at_c = graph.get_mark_at(b, c).unwrap_or(EdgeMark::Circle);
701 graph.set_edge(b, c, EdgeMark::Tail, mark_at_c);
702 changed = true;
703 }
704 }
705 }
706 changed
707}
708
709fn fci_r7(graph: &mut CausalGraph, p: usize) -> bool {
711 let mut changed = false;
712 for b in 0..p {
713 for a in 0..p {
714 if a == b || !graph.is_adjacent(a, b) {
715 continue;
716 }
717 if graph.get_mark_from(a, b) != Some(EdgeMark::Tail)
719 || graph.get_mark_at(a, b) != Some(EdgeMark::Circle)
720 {
721 continue;
722 }
723 for c in 0..p {
724 if c == a || c == b || !graph.is_adjacent(b, c) {
725 continue;
726 }
727 if graph.is_adjacent(a, c) {
729 continue;
730 }
731 if graph.get_mark_from(b, c) != Some(EdgeMark::Circle) {
733 continue;
734 }
735 let mark_at_c = graph.get_mark_at(b, c).unwrap_or(EdgeMark::Circle);
736 graph.set_edge(b, c, EdgeMark::Tail, mark_at_c);
737 changed = true;
738 }
739 }
740 }
741 changed
742}
743
744fn fci_r8(graph: &mut CausalGraph, p: usize) -> bool {
746 let mut changed = false;
747 for a in 0..p {
748 for c in 0..p {
749 if a == c || !graph.is_adjacent(a, c) {
750 continue;
751 }
752 if graph.get_mark_from(a, c) != Some(EdgeMark::Circle)
754 || graph.get_mark_at(a, c) != Some(EdgeMark::Arrow)
755 {
756 continue;
757 }
758 for b in 0..p {
759 if b == a || b == c {
760 continue;
761 }
762 if graph.get_mark_from(b, c) != Some(EdgeMark::Tail)
764 || graph.get_mark_at(b, c) != Some(EdgeMark::Arrow)
765 {
766 continue;
767 }
768 let mark_at_b = graph.get_mark_at(a, b);
770 let mark_from_a_to_b = graph.get_mark_from(a, b);
771 let valid = match (mark_from_a_to_b, mark_at_b) {
772 (Some(EdgeMark::Tail), Some(EdgeMark::Arrow)) => true, (Some(EdgeMark::Tail), Some(EdgeMark::Circle)) => true, _ => false,
775 };
776 if valid {
777 graph.set_edge(a, c, EdgeMark::Tail, EdgeMark::Arrow);
778 changed = true;
779 break;
780 }
781 }
782 }
783 }
784 changed
785}
786
787fn fci_r9(graph: &mut CausalGraph, p: usize) -> bool {
790 let mut changed = false;
791 for a in 0..p {
792 for c in 0..p {
793 if a == c || !graph.is_adjacent(a, c) {
794 continue;
795 }
796 if graph.get_mark_from(a, c) != Some(EdgeMark::Circle)
798 || graph.get_mark_at(a, c) != Some(EdgeMark::Arrow)
799 {
800 continue;
801 }
802 if has_directed_path_excluding_direct(graph, a, c, p) {
804 graph.set_edge(a, c, EdgeMark::Tail, EdgeMark::Arrow);
805 changed = true;
806 }
807 }
808 }
809 changed
810}
811
812fn fci_r10(graph: &mut CausalGraph, p: usize) -> bool {
816 let mut changed = false;
817 for a in 0..p {
818 for c in 0..p {
819 if a == c || !graph.is_adjacent(a, c) {
820 continue;
821 }
822 if graph.get_mark_from(a, c) != Some(EdgeMark::Circle)
824 || graph.get_mark_at(a, c) != Some(EdgeMark::Arrow)
825 {
826 continue;
827 }
828 let parents_c: Vec<usize> = (0..p)
830 .filter(|&k| {
831 k != a
832 && k != c
833 && graph.get_mark_from(k, c) == Some(EdgeMark::Tail)
834 && graph.get_mark_at(k, c) == Some(EdgeMark::Arrow)
835 })
836 .collect();
837
838 let mut orient = false;
839 for i in 0..parents_c.len() {
840 for j in (i + 1)..parents_c.len() {
841 let b = parents_c[i];
842 let d = parents_c[j];
843 let a_oo_b = graph.get_mark_from(a, b) == Some(EdgeMark::Circle)
845 && graph.get_mark_at(a, b) == Some(EdgeMark::Circle);
846 let a_oo_d = graph.get_mark_from(a, d) == Some(EdgeMark::Circle)
847 && graph.get_mark_at(a, d) == Some(EdgeMark::Circle);
848 if !a_oo_b || !a_oo_d {
849 continue;
850 }
851 if has_directed_path_general(graph, b, a, p)
853 || has_directed_path_general(graph, d, a, p)
854 {
855 orient = true;
856 break;
857 }
858 }
859 if orient {
860 break;
861 }
862 }
863 if orient {
864 graph.set_edge(a, c, EdgeMark::Tail, EdgeMark::Arrow);
865 changed = true;
866 }
867 }
868 }
869 changed
870}
871
872fn has_uncovered_circle_path(graph: &CausalGraph, src: usize, dst: usize, p: usize) -> bool {
878 let mut visited = vec![false; p];
880 visited[src] = true;
881 let mut queue = VecDeque::new();
882
883 for k in 0..p {
885 if k == dst || k == src {
886 continue;
887 }
888 if graph.is_adjacent(src, k)
889 && graph.get_mark_from(src, k) == Some(EdgeMark::Circle)
890 && graph.get_mark_at(src, k) == Some(EdgeMark::Circle)
891 {
892 queue.push_back((k, 2usize)); }
894 }
895
896 while let Some((cur, len)) = queue.pop_front() {
897 if visited[cur] {
898 continue;
899 }
900 visited[cur] = true;
901
902 if graph.is_adjacent(cur, dst)
904 && graph.get_mark_from(cur, dst) == Some(EdgeMark::Circle)
905 && graph.get_mark_at(cur, dst) == Some(EdgeMark::Circle)
906 && len + 1 >= 3
907 {
908 return true;
909 }
910
911 for next in 0..p {
913 if visited[next] || next == src || next == dst {
914 continue;
915 }
916 if graph.is_adjacent(cur, next)
917 && graph.get_mark_from(cur, next) == Some(EdgeMark::Circle)
918 && graph.get_mark_at(cur, next) == Some(EdgeMark::Circle)
919 {
920 queue.push_back((next, len + 1));
921 }
922 }
923 }
924 false
925}
926
927fn has_directed_path_excluding_direct(
930 graph: &CausalGraph,
931 src: usize,
932 dst: usize,
933 p: usize,
934) -> bool {
935 let mut visited = vec![false; p];
936 let mut stack = Vec::new();
937 for k in 0..p {
939 if k != dst
940 && graph.get_mark_from(src, k) == Some(EdgeMark::Tail)
941 && graph.get_mark_at(src, k) == Some(EdgeMark::Arrow)
942 {
943 stack.push(k);
944 }
945 }
946
947 while let Some(cur) = stack.pop() {
948 if cur == dst {
949 return true;
950 }
951 if visited[cur] {
952 continue;
953 }
954 visited[cur] = true;
955 for next in 0..p {
956 if !visited[next]
957 && graph.get_mark_from(cur, next) == Some(EdgeMark::Tail)
958 && graph.get_mark_at(cur, next) == Some(EdgeMark::Arrow)
959 {
960 stack.push(next);
961 }
962 }
963 }
964 false
965}
966
967fn has_directed_path_general(graph: &CausalGraph, src: usize, dst: usize, p: usize) -> bool {
969 let mut visited = vec![false; p];
970 let mut stack = vec![src];
971 while let Some(cur) = stack.pop() {
972 if cur == dst && cur != src {
973 return true;
974 }
975 if visited[cur] {
976 continue;
977 }
978 visited[cur] = true;
979 for next in 0..p {
980 if !visited[next]
981 && graph.get_mark_from(cur, next) == Some(EdgeMark::Tail)
982 && graph.get_mark_at(cur, next) == Some(EdgeMark::Arrow)
983 {
984 stack.push(next);
985 }
986 }
987 }
988 false
989}
990
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Variable names shared by the three-column fixtures.
    const XYZ: [&str; 3] = ["X", "Y", "Z"];

    /// Advances the 64-bit LCG state and maps it to a float in `[0, 1)`.
    fn lcg_uniform(s: &mut u64) -> f64 {
        let next = s
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        *s = next;
        (next >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Draws one standard-normal sample with the Box-Muller transform.
    fn lcg_normal(s: &mut u64) -> f64 {
        // Clamp away from zero so the logarithm stays finite.
        let radius_seed = lcg_uniform(s).max(1e-15);
        let angle_seed = lcg_uniform(s);
        let radius = (-2.0 * radius_seed.ln()).sqrt();
        radius * (2.0 * std::f64::consts::PI * angle_seed).cos()
    }

    /// Fills an `n x 3` matrix row by row from `make_row`, which receives the
    /// running LCG state.
    fn three_column_data(
        n: usize,
        seed: u64,
        mut make_row: impl FnMut(&mut u64) -> [f64; 3],
    ) -> Array2<f64> {
        let mut data = Array2::<f64>::zeros((n, 3));
        let mut state = seed;
        for i in 0..n {
            let row = make_row(&mut state);
            for (j, value) in row.iter().enumerate() {
                data[[i, j]] = *value;
            }
        }
        data
    }

    /// Samples from the chain X -> Y -> Z.
    fn chain_data(n: usize, seed: u64) -> Array2<f64> {
        three_column_data(n, seed, |state| {
            let x = lcg_normal(state);
            let y = 0.9 * x + lcg_normal(state) * 0.3;
            let z = 0.9 * y + lcg_normal(state) * 0.3;
            [x, y, z]
        })
    }

    /// Samples X and Y from a shared hidden cause, with Z depending on both.
    fn latent_confounder_data(n: usize, seed: u64) -> Array2<f64> {
        three_column_data(n, seed, |state| {
            let latent = lcg_normal(state);
            let x = 0.8 * latent + lcg_normal(state) * 0.3;
            let y = 0.8 * latent + lcg_normal(state) * 0.3;
            let z = 0.5 * x + 0.5 * y + lcg_normal(state) * 0.3;
            [x, y, z]
        })
    }

    /// Samples from the collider X -> Z <- Y with independent X and Y.
    fn collider_data(n: usize, seed: u64) -> Array2<f64> {
        three_column_data(n, seed, |state| {
            let x = lcg_normal(state);
            let y = lcg_normal(state);
            let z = 0.7 * x + 0.7 * y + lcg_normal(state) * 0.3;
            [x, y, z]
        })
    }

    /// Runs FCI with `alpha = 0.05` on a three-column data set.
    fn fit_xyz(data: &Array2<f64>) -> FciResult {
        FciAlgorithm::new(0.05)
            .fit(data.view(), &XYZ)
            .expect("FCI failed")
    }

    #[test]
    fn test_fci_chain() {
        let result = fit_xyz(&chain_data(300, 12345));
        let graph = &result.graph;
        assert!(graph.is_adjacent(0, 1), "X-Y should be adjacent in chain");
        assert!(graph.is_adjacent(1, 2), "Y-Z should be adjacent in chain");
        assert!(!graph.is_adjacent(0, 2), "X-Z should not be adjacent");
    }

    #[test]
    fn test_fci_latent_confounder() {
        let result = fit_xyz(&latent_confounder_data(500, 54321));
        let found_edge = result.graph.is_adjacent(0, 1) || result.graph.is_adjacent(0, 2);
        assert!(found_edge, "Should find some adjacency");
        assert!(result.n_tests > 0, "Should perform CI tests");
    }

    #[test]
    fn test_fci_produces_pag() {
        let result = fit_xyz(&chain_data(200, 99999));
        assert_eq!(result.graph.num_vars(), 3);
    }

    #[test]
    fn test_fci_collider_detection() {
        let result = fit_xyz(&collider_data(300, 77777));
        let graph = &result.graph;
        assert!(graph.is_adjacent(0, 2), "X-Z should be adjacent");
        assert!(graph.is_adjacent(1, 2), "Y-Z should be adjacent");
        assert!(!graph.is_adjacent(0, 1), "X-Y should not be adjacent");

        let arrow_into_z = graph.get_mark_at(0, 2) == Some(EdgeMark::Arrow)
            || graph.get_mark_at(1, 2) == Some(EdgeMark::Arrow);
        assert!(arrow_into_z, "Should detect v-structure at Z");
    }

    #[test]
    fn test_fci_possible_dsep() {
        // A o-> B o-o C o-> D
        let mut graph = CausalGraph::new(&["A", "B", "C", "D"]);
        graph.set_edge(0, 1, EdgeMark::Circle, EdgeMark::Arrow);
        graph.set_edge(1, 2, EdgeMark::Circle, EdgeMark::Circle);
        graph.set_edge(2, 3, EdgeMark::Circle, EdgeMark::Arrow);

        let pdsep = possible_dsep(&graph, 0, 3, 4);
        assert!(
            pdsep.contains(&1) || pdsep.contains(&2),
            "Possible-D-SEP should contain intermediate nodes"
        );
    }

    #[test]
    fn test_fci_r1_orientation() {
        // A -> B o-o C with A and C not adjacent.
        let mut graph = CausalGraph::new(&["A", "B", "C"]);
        graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrow);
        graph.set_edge(1, 2, EdgeMark::Circle, EdgeMark::Circle);

        let changed = fci_r1(&mut graph, 3);
        assert!(changed, "R1 should make a change");
        assert_eq!(
            graph.get_mark_from(1, 2),
            Some(EdgeMark::Tail),
            "R1: b side should be tail"
        );
    }

    #[test]
    fn test_fci_edge_marks() {
        let mut graph = CausalGraph::new(&["A", "B", "C"]);
        graph.set_edge(0, 1, EdgeMark::Tail, EdgeMark::Arrow);
        graph.set_edge(1, 2, EdgeMark::Arrow, EdgeMark::Arrow);

        assert!(graph.is_directed(0, 1), "A -> B");
        assert!(graph.is_bidirected(1, 2), "B <-> C");
        assert!(!graph.is_undirected(0, 1), "A -> B is not undirected");
    }

    #[test]
    fn test_fci_empty_graph() {
        let data = Array2::<f64>::zeros((10, 0));
        let result = FciAlgorithm::new(0.05)
            .fit(data.view(), &[])
            .expect("FCI should handle empty");
        assert_eq!(result.graph.num_vars(), 0);
        assert_eq!(result.n_tests, 0);
    }

    #[test]
    fn test_fci_two_vars() {
        let n = 200;
        let mut data = Array2::<f64>::zeros((n, 2));
        let mut state: u64 = 11111;
        for i in 0..n {
            let x = lcg_normal(&mut state);
            let y = 0.9 * x + lcg_normal(&mut state) * 0.3;
            data[[i, 0]] = x;
            data[[i, 1]] = y;
        }
        let result = FciAlgorithm::new(0.05)
            .fit(data.view(), &["X", "Y"])
            .expect("FCI with 2 vars");
        assert!(result.graph.is_adjacent(0, 1), "X-Y should be adjacent");
    }
}