tensorlogic_infer/causal/data.rs
//! Observational data container and result types for causal inference.
//!
//! Defines [`ObservationalData`] plus the small result/spec types
//! [`Intervention`], [`TreatmentEffect`], and [`BackdoorAdjustment`].
5
6use super::error::CausalError;
7
// ---------------------------------------------------------------------------
// ObservationalData
// ---------------------------------------------------------------------------
11
/// Container for observational (non-interventional) data.
///
/// The data is a row-major table: `samples[i]` is the i-th observation and
/// holds one `f64` per entry of `variables`, in the same order.
///
/// Two datasets compare equal when their variable names and all sample
/// values are equal (standard `f64` comparison, so a `NaN` entry makes a
/// dataset unequal to itself).
#[derive(Debug, Clone, PartialEq)]
pub struct ObservationalData {
    /// Column names; position `j` names entry `j` of every row in `samples`.
    variables: Vec<String>,
    /// Observed rows, one inner vector per observation.
    samples: Vec<Vec<f64>>,
}
21
22impl ObservationalData {
23 /// Create an empty dataset with the given variable names.
24 pub fn new(variables: Vec<String>) -> Self {
25 Self {
26 variables,
27 samples: Vec::new(),
28 }
29 }
30
31 /// Add a single observation. Returns an error if the dimension does not match.
32 pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<(), CausalError> {
33 if sample.len() != self.variables.len() {
34 return Err(CausalError::DimensionMismatch);
35 }
36 self.samples.push(sample);
37 Ok(())
38 }
39
40 /// Number of observations.
41 pub fn n_samples(&self) -> usize {
42 self.samples.len()
43 }
44
45 /// Number of variables.
46 pub fn n_variables(&self) -> usize {
47 self.variables.len()
48 }
49
50 /// Return the column index for a variable name.
51 pub(super) fn var_index(&self, var: &str) -> Option<usize> {
52 self.variables.iter().position(|v| v == var)
53 }
54
55 /// Extract all values for a single variable.
56 pub fn column(&self, var: &str) -> Option<Vec<f64>> {
57 let idx = self.var_index(var)?;
58 Some(self.samples.iter().map(|s| s[idx]).collect())
59 }
60
61 /// Compute the marginal mean of a variable.
62 pub fn mean(&self, var: &str) -> Option<f64> {
63 let col = self.column(var)?;
64 if col.is_empty() {
65 return None;
66 }
67 Some(col.iter().sum::<f64>() / col.len() as f64)
68 }
69
70 /// Compute the mean of `outcome` conditioned on `condition_var == condition_val`.
71 ///
72 /// Equality is checked with a small tolerance (1e-9) to handle floating-point values.
73 pub fn conditional_mean(
74 &self,
75 outcome: &str,
76 condition_var: &str,
77 condition_val: f64,
78 ) -> Option<f64> {
79 let out_idx = self.var_index(outcome)?;
80 let cond_idx = self.var_index(condition_var)?;
81 let filtered: Vec<f64> = self
82 .samples
83 .iter()
84 .filter(|s| (s[cond_idx] - condition_val).abs() < 1e-9)
85 .map(|s| s[out_idx])
86 .collect();
87 if filtered.is_empty() {
88 return None;
89 }
90 Some(filtered.iter().sum::<f64>() / filtered.len() as f64)
91 }
92
93 /// Return a reference to the variable names.
94 pub fn variables(&self) -> &[String] {
95 &self.variables
96 }
97
98 /// Return all samples as a slice of rows.
99 pub fn samples(&self) -> &[Vec<f64>] {
100 &self.samples
101 }
102}
103
// ---------------------------------------------------------------------------
// Intervention / TreatmentEffect / BackdoorAdjustment
// ---------------------------------------------------------------------------
107
/// A do-calculus intervention: fix variable `variable` to `value`.
///
/// Interventions compare equal when both the variable name and the value
/// are equal.
#[derive(Debug, Clone, PartialEq)]
pub struct Intervention {
    /// The name of the intervened-upon variable.
    pub variable: String,
    /// The value to which the variable is set.
    pub value: f64,
}
116
/// Result of an average treatment effect estimation.
///
/// Results compare field by field, which lets callers and tests check an
/// estimate against an expected value with `==`.
#[derive(Debug, Clone, PartialEq)]
pub struct TreatmentEffect {
    /// Average treatment effect: E[Y | do(T=1)] − E[Y | do(T=0)].
    pub ate: f64,
    /// Average treatment effect on the treated subgroup (ATT).
    pub ate_treated: f64,
    /// Average treatment effect on the control subgroup (ATC).
    pub ate_control: f64,
    /// Estimation method used: `"backdoor"`, `"frontdoor"`, or `"iv"`.
    pub estimator: String,
    /// Number of samples used.
    pub n_samples: usize,
    /// Bootstrap 95% confidence interval, if computed.
    pub confidence_interval: Option<(f64, f64)>,
}
133
/// Outcome of a backdoor adjustment set search.
///
/// Two outcomes are equal when they list the same variables in the same
/// order and agree on `valid`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BackdoorAdjustment {
    /// The chosen adjustment set (variable names).
    pub adjustment_set: Vec<String>,
    /// Whether the set satisfies the backdoor criterion.
    pub valid: bool,
}