Skip to main content

tensorlogic_infer/causal/
data.rs

1//! Observational data container and result types for causal inference.
2//!
3//! Defines [`ObservationalData`] plus the small result/spec types
4//! [`Intervention`], [`TreatmentEffect`], and [`BackdoorAdjustment`].
5
6use super::error::CausalError;
7
8// ---------------------------------------------------------------------------
9// ObservationalData
10// ---------------------------------------------------------------------------
11
12/// Container for observational (non-interventional) data.
13///
14/// Data is stored as a matrix: `samples[i]` is the i-th observation,
15/// with one entry per variable in the same order as `variables`.
16#[derive(Debug, Clone)]
17pub struct ObservationalData {
18    variables: Vec<String>,
19    samples: Vec<Vec<f64>>,
20}
21
22impl ObservationalData {
23    /// Create an empty dataset with the given variable names.
24    pub fn new(variables: Vec<String>) -> Self {
25        Self {
26            variables,
27            samples: Vec::new(),
28        }
29    }
30
31    /// Add a single observation. Returns an error if the dimension does not match.
32    pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<(), CausalError> {
33        if sample.len() != self.variables.len() {
34            return Err(CausalError::DimensionMismatch);
35        }
36        self.samples.push(sample);
37        Ok(())
38    }
39
40    /// Number of observations.
41    pub fn n_samples(&self) -> usize {
42        self.samples.len()
43    }
44
45    /// Number of variables.
46    pub fn n_variables(&self) -> usize {
47        self.variables.len()
48    }
49
50    /// Return the column index for a variable name.
51    pub(super) fn var_index(&self, var: &str) -> Option<usize> {
52        self.variables.iter().position(|v| v == var)
53    }
54
55    /// Extract all values for a single variable.
56    pub fn column(&self, var: &str) -> Option<Vec<f64>> {
57        let idx = self.var_index(var)?;
58        Some(self.samples.iter().map(|s| s[idx]).collect())
59    }
60
61    /// Compute the marginal mean of a variable.
62    pub fn mean(&self, var: &str) -> Option<f64> {
63        let col = self.column(var)?;
64        if col.is_empty() {
65            return None;
66        }
67        Some(col.iter().sum::<f64>() / col.len() as f64)
68    }
69
70    /// Compute the mean of `outcome` conditioned on `condition_var == condition_val`.
71    ///
72    /// Equality is checked with a small tolerance (1e-9) to handle floating-point values.
73    pub fn conditional_mean(
74        &self,
75        outcome: &str,
76        condition_var: &str,
77        condition_val: f64,
78    ) -> Option<f64> {
79        let out_idx = self.var_index(outcome)?;
80        let cond_idx = self.var_index(condition_var)?;
81        let filtered: Vec<f64> = self
82            .samples
83            .iter()
84            .filter(|s| (s[cond_idx] - condition_val).abs() < 1e-9)
85            .map(|s| s[out_idx])
86            .collect();
87        if filtered.is_empty() {
88            return None;
89        }
90        Some(filtered.iter().sum::<f64>() / filtered.len() as f64)
91    }
92
93    /// Return a reference to the variable names.
94    pub fn variables(&self) -> &[String] {
95        &self.variables
96    }
97
98    /// Return all samples as a slice of rows.
99    pub fn samples(&self) -> &[Vec<f64>] {
100        &self.samples
101    }
102}
103
104// ---------------------------------------------------------------------------
105// Intervention / TreatmentEffect / BackdoorAdjustment
106// ---------------------------------------------------------------------------
107
/// A hard intervention in the sense of the do-operator: `do(variable = value)`.
#[derive(Clone, Debug)]
pub struct Intervention {
    /// Name of the variable whose value the intervention fixes.
    pub variable: String,
    /// Value assigned to `variable` under the intervention.
    pub value: f64,
}
116
/// Estimated causal effect of a binary treatment `T` on an outcome `Y`.
#[derive(Clone, Debug)]
pub struct TreatmentEffect {
    /// Average treatment effect, E[Y | do(T=1)] − E[Y | do(T=0)].
    pub ate: f64,
    /// Average treatment effect restricted to the treated units (ATT).
    pub ate_treated: f64,
    /// Average treatment effect restricted to the control units (ATC).
    pub ate_control: f64,
    /// Name of the estimator that produced this result: one of
    /// `"backdoor"`, `"frontdoor"`, or `"iv"`.
    pub estimator: String,
    /// How many samples the estimate was computed from.
    pub n_samples: usize,
    /// Bootstrap 95% confidence interval, or `None` when it was not computed.
    pub confidence_interval: Option<(f64, f64)>,
}
133
/// Result of searching for a backdoor adjustment set.
#[derive(Clone, Debug)]
pub struct BackdoorAdjustment {
    /// Names of the variables in the selected adjustment set.
    pub adjustment_set: Vec<String>,
    /// `true` when `adjustment_set` satisfies the backdoor criterion.
    pub valid: bool,
}