scirs2_stats/causal/mod.rs
1//! Causal Inference Methods
2//!
//! This module provides causal inference estimators, structural equation models, and graph types for constraint-based causal discovery:
4//!
5//! ## Sub-modules
6//!
7//! | Module | Methods |
8//! |--------|---------|
9//! | [`instrumental_variables`] | 2SLS, LIML, Hausman test, weak-instrument diagnostics |
10//! | [`difference_in_differences`] | DiD (TWFE), synthetic control, event study, staggered DiD |
11//! | [`regression_discontinuity`] | Sharp RDD, fuzzy RDD, bandwidth selection, RD plots |
12//! | [`propensity_score`] | Logistic PS model, IPW, nearest-neighbour / kernel matching |
13//!
14//! ## Quick start
15//!
16//! ```rust
17//! use scirs2_stats::causal::instrumental_variables::{IVEstimator, WeakInstrumentTest};
18//! use scirs2_stats::causal::propensity_score::{
19//! PropensityScoreModel, IPW, PSMatching, MatchingMethod,
20//! };
21//! ```
22//!
23//! ## References
24//!
25//! - Angrist, J.D. & Pischke, J.-S. (2009). Mostly Harmless Econometrics.
26//! - Callaway, B. & Sant'Anna, P.H.C. (2021). Difference-in-Differences with
27//! Multiple Time Periods.
28//! - Imbens, G.W. & Kalyanaraman, K. (2012). Optimal Bandwidth Choice for
29//! the Regression Discontinuity Estimator.
30//! - Rosenbaum, P.R. & Rubin, D.B. (1983). The Central Role of the Propensity
31//! Score in Observational Studies for Causal Effects.
32
33pub mod difference_in_differences;
34pub mod instrumental_variables;
35pub mod propensity_score;
36pub mod regression_discontinuity;
37
38// ---------------------------------------------------------------------------
39// Re-exports — instrumental variables
40// ---------------------------------------------------------------------------
41
42pub use instrumental_variables::{
43 HausmanResult, HausmanTest, IVEstimator, IVResult, WeakInstrumentResult, WeakInstrumentTest,
44 LIML,
45};
46
47// ---------------------------------------------------------------------------
48// Re-exports — difference-in-differences
49// ---------------------------------------------------------------------------
50
51pub use difference_in_differences::{
52 AttGt, DiD, DiDResult, EventCoefficient, EventStudy, EventStudyResult, StaggeredDiD,
53 StaggeredDiDResult, SyntheticControl,
54};
55
56// ---------------------------------------------------------------------------
57// Re-exports — regression discontinuity
58// ---------------------------------------------------------------------------
59
60pub use regression_discontinuity::{
61 BandwidthMethod, BandwidthSelector, FuzzyRDD, RDDPlot, RDDResult, RDD,
62};
63
64// ---------------------------------------------------------------------------
65// Re-exports — propensity score
66// ---------------------------------------------------------------------------
67
68pub use propensity_score::{
69 MatchingMethod, MatchingResult, OverlapCheck, OverlapResult, PSMatching, PSResult,
70 PropensityScoreModel, TrimMethod, IPW,
71};
72
73/// Convenience function: fit a propensity score model and estimate ATE/ATT/ATC via IPW.
74pub use propensity_score::ps_estimate;
75
76// ---------------------------------------------------------------------------
77// Structural Equation Models
78// ---------------------------------------------------------------------------
79
80pub mod sem;
81
82pub use sem::{satisfies_backdoor, IdentificationResult, LinearEquation, SEMWithIntercepts, SEM};
83
84// ---------------------------------------------------------------------------
// Causal discovery (PC, FCI), identification (ID algorithm, hedges), and linear SEM with ndarray interface
86// ---------------------------------------------------------------------------
87
88pub mod conditional_independence;
89pub mod fci_algorithm;
90pub mod hedge;
91pub mod id_algorithm;
92pub mod linear_sem;
93pub mod pc_algorithm;
94pub mod semi_markov_graph;
95pub mod symbolic_prob;
96
97pub use linear_sem::{LinearSEM, LinearSEMWithIntercepts};
98
99// ---------------------------------------------------------------------------
100// Causal graph types for constraint-based algorithms (PC, FCI, etc.)
101// ---------------------------------------------------------------------------
102
/// Mark on one endpoint of an edge in a mixed graph.
///
/// Constraint-based causal discovery algorithms (PC, FCI) combine two of these
/// marks per edge to represent the edge types that appear in CPDAGs and PAGs:
///
/// - `Tail` — definite tail (no arrowhead). In a directed edge `i → j` the tail
///   sits at `i`; an undirected edge `i — j` has a tail at both ends.
/// - `Arrow` — definite arrowhead, as at `j` in `i → j` or at both ends of `i ↔ j`.
/// - `Circle` — undetermined endpoint (the `o` in `i o→ j`), used by FCI for
///   partial ancestral graphs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EdgeMark {
    /// Definite tail: the endpoint carries no arrowhead.
    Tail,
    /// Definite arrowhead.
    Arrow,
    /// Undetermined endpoint mark (FCI/PAG only).
    Circle,
}
120
121/// Mixed graph for causal discovery algorithms.
122///
123/// Each edge `(i, j)` is stored as a pair `(mark_at_i, mark_at_j)`, where
124/// `mark_at_i` is the mark at node `i` (endpoint facing node `i`) and
125/// `mark_at_j` is the mark at node `j`.
126///
127/// - Directed edge `i → j`: `(Tail, Arrow)` stored at entry `(i, j)`.
128/// - Undirected edge `i — j`: `(Tail, Tail)`.
129/// - Bidirected edge `i ↔ j`: `(Arrow, Arrow)`.
130/// - Circle endpoint `i o→ j`: `(Circle, Arrow)`.
131#[derive(Debug, Clone)]
132pub struct CausalGraph {
133 /// Variable names.
134 pub var_names: Vec<String>,
135 /// Adjacency: `edges[i][j] = Some((mark_at_i_from_j, mark_at_j_from_i))`.
136 /// If `edges[i][j].is_some()` then `edges[j][i].is_some()` as well.
137 edges: Vec<Vec<Option<(EdgeMark, EdgeMark)>>>,
138 /// Separation sets: `sep[i][j]` is the set that d-separates i and j.
139 pub sep_sets: Vec<Vec<Option<Vec<usize>>>>,
140}
141
142impl CausalGraph {
143 /// Create a new graph with the given variable names, initially fully connected.
144 pub fn new(var_names: &[&str]) -> Self {
145 let p = var_names.len();
146 Self {
147 var_names: var_names.iter().map(|s| s.to_string()).collect(),
148 edges: vec![vec![None; p]; p],
149 sep_sets: vec![vec![None; p]; p],
150 }
151 }
152
153 /// Number of variables (nodes).
154 pub fn num_vars(&self) -> usize {
155 self.var_names.len()
156 }
157
158 /// Set or update an edge between `i` and `j`.
159 ///
160 /// `mark_at_i` is the endpoint mark at node `i`; `mark_at_j` is the mark at `j`.
161 /// Setting an edge is symmetric: `edges[i][j]` and `edges[j][i]` are both updated.
162 pub fn set_edge(&mut self, i: usize, j: usize, mark_at_i: EdgeMark, mark_at_j: EdgeMark) {
163 self.edges[i][j] = Some((mark_at_i, mark_at_j));
164 self.edges[j][i] = Some((mark_at_j, mark_at_i));
165 }
166
167 /// Remove an edge between `i` and `j`.
168 pub fn remove_edge(&mut self, i: usize, j: usize) {
169 self.edges[i][j] = None;
170 self.edges[j][i] = None;
171 }
172
173 /// Whether there is any edge between `i` and `j`.
174 pub fn is_adjacent(&self, i: usize, j: usize) -> bool {
175 self.edges[i][j].is_some()
176 }
177
178 /// Get the mark at node `to` on the edge from `from` to `to`.
179 ///
180 /// Returns `None` if the edge doesn't exist.
181 /// The returned mark is the one facing node `to` (i.e., the arrowhead/tail at `to`).
182 pub fn get_mark_at(&self, from: usize, to: usize) -> Option<EdgeMark> {
183 self.edges[from][to].map(|(_, mark_at_to)| mark_at_to)
184 }
185
186 /// Get the mark at node `from` on the edge between `from` and `to`.
187 ///
188 /// Returns `None` if the edge doesn't exist.
189 pub fn get_mark_from(&self, from: usize, to: usize) -> Option<EdgeMark> {
190 self.edges[from][to].map(|(mark_at_from, _)| mark_at_from)
191 }
192
193 /// Whether there is a directed edge `i → j` (tail at `i`, arrow at `j`).
194 pub fn is_directed(&self, i: usize, j: usize) -> bool {
195 matches!(self.edges[i][j], Some((EdgeMark::Tail, EdgeMark::Arrow)))
196 }
197
198 /// Whether there is an undirected edge `i — j` (tail at both ends).
199 pub fn is_undirected(&self, i: usize, j: usize) -> bool {
200 matches!(self.edges[i][j], Some((EdgeMark::Tail, EdgeMark::Tail)))
201 }
202
203 /// Whether there is a bidirected edge `i ↔ j` (arrow at both ends).
204 pub fn is_bidirected(&self, i: usize, j: usize) -> bool {
205 matches!(self.edges[i][j], Some((EdgeMark::Arrow, EdgeMark::Arrow)))
206 }
207
208 /// Return an iterator over the neighbours of node `i`.
209 pub fn neighbors(&self, i: usize) -> impl Iterator<Item = usize> + '_ {
210 (0..self.num_vars()).filter(move |&j| j != i && self.is_adjacent(i, j))
211 }
212
213 /// Return the separation set for nodes `i` and `j`, if any.
214 pub fn get_sep_set(&self, i: usize, j: usize) -> Option<&Vec<usize>> {
215 self.sep_sets[i][j].as_ref()
216 }
217
218 /// Set the separation set for nodes `i` and `j`.
219 pub fn set_sep_set(&mut self, i: usize, j: usize, sep: Vec<usize>) {
220 self.sep_sets[i][j] = Some(sep.clone());
221 self.sep_sets[j][i] = Some(sep);
222 }
223
224 /// Initialize the graph as a complete undirected graph (all edges `i — j`).
225 pub fn make_complete(&mut self) {
226 let p = self.num_vars();
227 for i in 0..p {
228 for j in 0..p {
229 if i != j {
230 self.edges[i][j] = Some((EdgeMark::Tail, EdgeMark::Tail));
231 }
232 }
233 }
234 }
235}