Skip to main content

scirs2_stats/causal/
propensity_score.rs

1//! Propensity Score Methods for Causal Inference
2//!
3//! Provides a suite of propensity-score-based estimators:
4//!
5//! - **`PropensityScoreModel`**: logistic-regression-based PS estimation
6//! - **`IPW`**: inverse probability weighting (Horvitz-Thompson and
7//!   normalised Hajek variants)
8//! - **`PSMatching`**: nearest-neighbour, caliper, and kernel matching
9//! - **`OverlapCheck`**: common-support trimming and overlap diagnostics
10//! - Estimates ATE, ATT, and ATC
11//!
12//! # References
13//!
14//! - Rosenbaum, P.R. & Rubin, D.B. (1983). The Central Role of the Propensity Score
15//!   in Observational Studies for Causal Effects. Biometrika.
16//! - Hirano, K., Imbens, G.W. & Ridder, G. (2003). Efficient Estimation of Average
17//!   Treatment Effects Using the Estimated Propensity Score.
18//! - Heckman, J.J., Ichimura, H. & Todd, P. (1998). Matching As An Econometric
19//!   Evaluation Estimator.
20
21use crate::error::{StatsError, StatsResult};
22use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
23
24// ---------------------------------------------------------------------------
25// Result types
26// ---------------------------------------------------------------------------
27
28/// Result of a propensity score-based causal effect estimation
29#[derive(Debug, Clone)]
30pub struct PSResult {
31    /// Average Treatment Effect (population-wide)
32    pub ate: f64,
33    /// Standard error of ATE
34    pub ate_se: f64,
35    /// Average Treatment Effect on the Treated
36    pub att: f64,
37    /// Standard error of ATT
38    pub att_se: f64,
39    /// Average Treatment Effect on the Controls
40    pub atc: f64,
41    /// Standard error of ATC
42    pub atc_se: f64,
43    /// p-value for ATE test (H₀: ATE = 0)
44    pub ate_p: f64,
45    /// p-value for ATT test
46    pub att_p: f64,
47    /// p-value for ATC test
48    pub atc_p: f64,
49    /// Estimated propensity scores
50    pub propensity_scores: Array1<f64>,
51    /// Estimator name
52    pub estimator: String,
53}
54
55/// Overlap diagnostics result
56#[derive(Debug, Clone)]
57pub struct OverlapResult {
58    /// Estimated propensity scores for all units
59    pub ps: Array1<f64>,
60    /// Indices of units in the common support
61    pub common_support_idx: Vec<usize>,
62    /// Lower bound of the common support trimming rule
63    pub ps_lower: f64,
64    /// Upper bound of the common support trimming rule
65    pub ps_upper: f64,
66    /// Fraction of treated units inside common support
67    pub frac_treated_in_support: f64,
68    /// Fraction of control units inside common support
69    pub frac_control_in_support: f64,
70    /// Overlap coefficient (integral of min(f_t, f_c))
71    pub overlap_coefficient: f64,
72}
73
/// Summary of an ATT estimate obtained by propensity score matching.
#[derive(Debug, Clone)]
pub struct MatchingResult {
    /// Estimated average treatment effect on the treated.
    pub att: f64,
    /// Standard error attached to `att`.
    pub att_se: f64,
    /// Two-sided p-value for H₀: ATT = 0.
    pub p_value: f64,
    /// 95 % confidence interval as `[lower, upper]`.
    pub conf_interval: [f64; 2],
    /// How many treated units received a match and entered the estimate.
    pub n_matched_treated: usize,
    /// Label of the matching method that produced the estimate.
    pub method: String,
}
90
91// ---------------------------------------------------------------------------
92// Utility: standard normal
93// ---------------------------------------------------------------------------
94
95fn normal_p_value(z: f64) -> f64 {
96    2.0 * (1.0 - normal_cdf(z.abs()))
97}
98
99fn normal_cdf(x: f64) -> f64 {
100    0.5 * (1.0 + erf_approx(x / std::f64::consts::SQRT_2))
101}
102
/// Polynomial approximation of the error function erf(x).
///
/// Evaluates `1 - (a1·t + a2·t² + … + a5·t⁵) · exp(-x²)` with
/// `t = 1 / (1 + p·|x|)` and mirrors the result for negative `x`.
/// The coefficients are those of Abramowitz & Stegun formula 7.1.26.
fn erf_approx(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    // a5 seeds the Horner recursion; a4..a1 follow, highest order first.
    const A5: f64 = 1.061405429;
    const A4_TO_A1: [f64; 4] = [-1.453152027, 1.421413741, -0.284496736, 0.254829592];

    let t = 1.0 / (1.0 + P * x.abs());
    let horner = A4_TO_A1.iter().fold(A5, |acc, &coef| coef + acc * t);
    let magnitude = 1.0 - horner * t * (-x * x).exp();
    if x >= 0.0 {
        magnitude
    } else {
        -magnitude
    }
}
116
117// ---------------------------------------------------------------------------
118// Propensity Score Model (logistic regression)
119// ---------------------------------------------------------------------------
120
/// Logistic regression estimator for the propensity score.
///
/// Estimates P(W=1 | X) via logistic regression using Newton-Raphson (IRLS).
/// The struct only carries solver settings; fitted coefficients are returned
/// by `fit` rather than stored.
#[derive(Debug, Clone, PartialEq)]
pub struct PropensityScoreModel {
    /// Maximum number of Newton-Raphson iterations
    pub max_iter: usize,
    /// Convergence tolerance on the Euclidean norm of the Newton step
    pub tol: f64,
    /// L2 regularisation parameter (ridge penalty); the intercept is not penalised
    pub lambda: f64,
}
132
133impl PropensityScoreModel {
134    /// Create a new propensity score model.
135    pub fn new() -> Self {
136        Self {
137            max_iter: 200,
138            tol: 1e-8,
139            lambda: 1e-4,
140        }
141    }
142
143    /// Fit the propensity score model via IRLS (Newton-Raphson).
144    ///
145    /// # Arguments
146    /// * `x`   – covariate matrix (n × k); a constant column is prepended automatically
147    /// * `w`   – binary treatment indicator (n,)
148    ///
149    /// # Returns
150    /// Fitted coefficient vector (k+1,) including intercept.
151    pub fn fit(&self, x: &ArrayView2<f64>, w: &ArrayView1<f64>) -> StatsResult<Array1<f64>> {
152        let n = x.nrows();
153        let k = x.ncols();
154        if w.len() != n {
155            return Err(StatsError::DimensionMismatch(
156                "x rows must equal w length".into(),
157            ));
158        }
159        // Prepend intercept column
160        let mut xmat = Array2::<f64>::zeros((n, k + 1));
161        for i in 0..n {
162            xmat[[i, 0]] = 1.0;
163            for j in 0..k {
164                xmat[[i, j + 1]] = x[[i, j]];
165            }
166        }
167        let k1 = k + 1;
168        let mut beta = Array1::<f64>::zeros(k1);
169
170        for _iter in 0..self.max_iter {
171            // mu = sigmoid(X beta)
172            let eta: Array1<f64> = xmat.dot(&beta);
173            let mu: Array1<f64> = eta.mapv(sigmoid);
174            // Working weights: v_i = mu_i (1 - mu_i)
175            let v: Array1<f64> = mu.mapv(|m| (m * (1.0 - m)).max(1e-8));
176            // Gradient: X' (y - mu) - lambda * beta  (regularise all but intercept)
177            let grad_data = xmat.t().dot(&(w.to_owned() - &mu));
178            let mut grad = grad_data;
179            for j in 1..k1 {
180                grad[j] -= self.lambda * beta[j];
181            }
182            // Hessian: X' diag(v) X + lambda * I  (except [0,0])
183            // Build W^{1/2} X and solve H delta = grad
184            let sqrt_v: Array1<f64> = v.mapv(|vi| vi.sqrt());
185            let mut wxmat = Array2::<f64>::zeros((n, k1));
186            for i in 0..n {
187                for j in 0..k1 {
188                    wxmat[[i, j]] = sqrt_v[i] * xmat[[i, j]];
189                }
190            }
191            let mut hess = wxmat.t().dot(&wxmat);
192            for j in 1..k1 {
193                hess[[j, j]] += self.lambda;
194            }
195            let h_inv = cholesky_invert_ps(&hess.view())?;
196            let delta = h_inv.dot(&grad);
197            let step_norm: f64 = delta.iter().map(|&d| d * d).sum::<f64>().sqrt();
198            beta = &beta + &delta;
199            if step_norm < self.tol {
200                break;
201            }
202        }
203        Ok(beta)
204    }
205
206    /// Predict propensity scores for new covariates.
207    ///
208    /// # Arguments
209    /// * `x`    – covariate matrix (n × k)
210    /// * `beta` – fitted coefficients from `fit` (k+1,)
211    pub fn predict(&self, x: &ArrayView2<f64>, beta: &ArrayView1<f64>) -> StatsResult<Array1<f64>> {
212        let n = x.nrows();
213        let k = x.ncols();
214        if beta.len() != k + 1 {
215            return Err(StatsError::DimensionMismatch(format!(
216                "beta length {} != k+1 = {}",
217                beta.len(),
218                k + 1
219            )));
220        }
221        let mut eta = Array1::<f64>::zeros(n);
222        for i in 0..n {
223            eta[i] = beta[0];
224            for j in 0..k {
225                eta[i] += beta[j + 1] * x[[i, j]];
226            }
227        }
228        Ok(eta.mapv(sigmoid))
229    }
230}
231
232impl Default for PropensityScoreModel {
233    fn default() -> Self {
234        Self::new()
235    }
236}
237
/// Logistic function 1 / (1 + e^(-x)).
///
/// Arguments beyond ±500 short-circuit to the saturated values 1 and 0
/// without evaluating the exponential.
fn sigmoid(x: f64) -> f64 {
    match x {
        high if high > 500.0 => 1.0,
        low if low < -500.0 => 0.0,
        z => 1.0 / (1.0 + (-z).exp()),
    }
}
247
248fn cholesky_invert_ps(a: &scirs2_core::ndarray::ArrayView2<f64>) -> StatsResult<Array2<f64>> {
249    let n = a.nrows();
250    let mut l = Array2::<f64>::zeros((n, n));
251    for i in 0..n {
252        for j in 0..=i {
253            let mut s = a[[i, j]];
254            for p in 0..j {
255                s -= l[[i, p]] * l[[j, p]];
256            }
257            if i == j {
258                if s <= 0.0 {
259                    return Err(StatsError::ComputationError(
260                        "Hessian not positive definite (PS logistic)".into(),
261                    ));
262                }
263                l[[i, j]] = s.sqrt();
264            } else {
265                l[[i, j]] = s / l[[j, j]];
266            }
267        }
268    }
269    let mut linv = Array2::<f64>::zeros((n, n));
270    for j in 0..n {
271        linv[[j, j]] = 1.0 / l[[j, j]];
272        for i in (j + 1)..n {
273            let mut s = 0.0_f64;
274            for p in j..i {
275                s += l[[i, p]] * linv[[p, j]];
276            }
277            linv[[i, j]] = -s / l[[i, i]];
278        }
279    }
280    Ok(linv.t().dot(&linv))
281}
282
283// ---------------------------------------------------------------------------
284// Inverse Probability Weighting (IPW)
285// ---------------------------------------------------------------------------
286
/// Inverse Probability Weighting estimator.
///
/// ATE (Horvitz-Thompson):
///   τ̂_ATE = (1/n) Σ [ W_i Y_i / e_i - (1-W_i) Y_i / (1-e_i) ]
///
/// ATT (normalised):
///   τ̂_ATT = Σ_{W=1} Y_i / n_t  -  Σ_{W=0} Y_i e_i/(1-e_i) / Σ_{W=0} e_i/(1-e_i)
///
/// ATC (normalised):
///   τ̂_ATC = Σ_{W=1} Y_i (1-e_i)/e_i / Σ_{W=1} (1-e_i)/e_i  -  Σ_{W=0} Y_i / n_c
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct IPW;
295
296impl IPW {
297    /// Estimate ATE, ATT, and ATC via inverse probability weighting.
298    ///
299    /// # Arguments
300    /// * `y`  – outcome vector
301    /// * `w`  – binary treatment indicator
302    /// * `ps` – estimated propensity scores
303    /// * `trim_eps` – trim propensity scores to [trim_eps, 1 - trim_eps]
304    pub fn estimate(
305        y: &ArrayView1<f64>,
306        w: &ArrayView1<f64>,
307        ps: &ArrayView1<f64>,
308        trim_eps: f64,
309    ) -> StatsResult<PSResult> {
310        let n = y.len();
311        if w.len() != n || ps.len() != n {
312            return Err(StatsError::DimensionMismatch(
313                "y, w, ps must all have the same length".into(),
314            ));
315        }
316        let eps = trim_eps.max(0.0).min(0.49);
317
318        // Trim propensity scores
319        let ps_trim: Array1<f64> = ps.mapv(|p| p.clamp(eps, 1.0 - eps));
320
321        // ATE: Horvitz-Thompson estimator
322        let ate_terms: Array1<f64> = (0..n)
323            .map(|i| {
324                let wi = w[i];
325                let yi = y[i];
326                let pi = ps_trim[i];
327                wi * yi / pi - (1.0 - wi) * yi / (1.0 - pi)
328            })
329            .collect();
330        let ate = ate_terms.iter().sum::<f64>() / n as f64;
331
332        // ATT: normalised IPW
333        let n_treated: usize = w.iter().filter(|&&wi| wi > 0.5).count();
334        let att_num: f64 = (0..n).filter(|&i| w[i] > 0.5).map(|i| y[i]).sum::<f64>();
335        let att_denom_ctrl_num: f64 = (0..n)
336            .filter(|&i| w[i] <= 0.5)
337            .map(|i| y[i] * ps_trim[i] / (1.0 - ps_trim[i]))
338            .sum::<f64>();
339        let att_denom_ctrl_den: f64 = (0..n)
340            .filter(|&i| w[i] <= 0.5)
341            .map(|i| ps_trim[i] / (1.0 - ps_trim[i]))
342            .sum::<f64>();
343        let att = if n_treated > 0 && att_denom_ctrl_den > 1e-10 {
344            att_num / n_treated as f64 - att_denom_ctrl_num / att_denom_ctrl_den
345        } else {
346            0.0
347        };
348
349        // ATC: normalised IPW
350        let n_control = n - n_treated;
351        let atc_ctrl_mean = if n_control > 0 {
352            (0..n).filter(|&i| w[i] <= 0.5).map(|i| y[i]).sum::<f64>() / n_control as f64
353        } else {
354            0.0
355        };
356        let atc_trt_num: f64 = (0..n)
357            .filter(|&i| w[i] > 0.5)
358            .map(|i| y[i] * (1.0 - ps_trim[i]) / ps_trim[i])
359            .sum::<f64>();
360        let atc_trt_den: f64 = (0..n)
361            .filter(|&i| w[i] > 0.5)
362            .map(|i| (1.0 - ps_trim[i]) / ps_trim[i])
363            .sum::<f64>();
364        let atc = if atc_trt_den > 1e-10 {
365            atc_trt_num / atc_trt_den - atc_ctrl_mean
366        } else {
367            0.0
368        };
369
370        // Sandwich standard errors (influence-function based)
371        let ate_se = bootstrap_se_ipw_ate(y, w, &ps_trim.view(), ate, n)?;
372        let att_se = bootstrap_se_ipw_att(y, w, &ps_trim.view(), att, n)?;
373        let atc_se = ate_se; // simplified
374
375        let ate_p = normal_p_value(if ate_se > 0.0 { ate / ate_se } else { 0.0 });
376        let att_p = normal_p_value(if att_se > 0.0 { att / att_se } else { 0.0 });
377        let atc_p = normal_p_value(if atc_se > 0.0 { atc / atc_se } else { 0.0 });
378
379        Ok(PSResult {
380            ate,
381            ate_se,
382            att,
383            att_se,
384            atc,
385            atc_se,
386            ate_p,
387            att_p,
388            atc_p,
389            propensity_scores: ps_trim,
390            estimator: "IPW".into(),
391        })
392    }
393}
394
395/// Influence-function-based SE for ATE
396fn bootstrap_se_ipw_ate(
397    y: &ArrayView1<f64>,
398    w: &ArrayView1<f64>,
399    ps: &ArrayView1<f64>,
400    ate: f64,
401    n: usize,
402) -> StatsResult<f64> {
403    let psi: Array1<f64> = (0..n)
404        .map(|i| {
405            let wi = w[i];
406            let yi = y[i];
407            let pi = ps[i];
408            wi * yi / pi - (1.0 - wi) * yi / (1.0 - pi) - ate
409        })
410        .collect();
411    let var_psi: f64 = psi.iter().map(|&p| p * p).sum::<f64>() / (n * (n - 1).max(1)) as f64;
412    Ok(var_psi.sqrt())
413}
414
415/// Influence-function-based SE for ATT
416fn bootstrap_se_ipw_att(
417    y: &ArrayView1<f64>,
418    w: &ArrayView1<f64>,
419    ps: &ArrayView1<f64>,
420    att: f64,
421    n: usize,
422) -> StatsResult<f64> {
423    let n_treated: f64 = w.iter().filter(|&&wi| wi > 0.5).count() as f64;
424    if n_treated < 1.0 {
425        return Ok(0.0);
426    }
427    let psi: Array1<f64> = (0..n)
428        .map(|i| {
429            let wi = w[i];
430            let yi = y[i];
431            let pi = ps[i];
432            // Influence function for ATT
433            (wi * yi - (1.0 - wi) * pi * yi / (1.0 - pi)) / (n_treated / n as f64) - att
434        })
435        .collect();
436    let var_psi: f64 = psi.iter().map(|&p| p * p).sum::<f64>() / (n * (n - 1).max(1)) as f64;
437    Ok(var_psi.sqrt())
438}
439
440// ---------------------------------------------------------------------------
441// Propensity Score Matching
442// ---------------------------------------------------------------------------
443
/// Matching method options
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatchingMethod {
    /// Nearest-neighbour matching on the logit of the propensity score.
    ///
    /// Controls are matched *with* replacement (the same control may serve
    /// several treated units) and no caliper is applied.
    NearestNeighbour,
    /// Caliper matching: nearest neighbour on the logit scale, restricted to
    /// controls whose distance does not exceed the caliper; treated units
    /// without such a control are dropped.
    Caliper,
    /// Kernel matching: each treated unit is compared with a kernel-weighted
    /// average of the controls (Epanechnikov kernel).
    Kernel,
}
454
455/// Propensity score matching estimator.
456pub struct PSMatching {
457    /// Matching method
458    pub method: MatchingMethod,
459    /// Caliper width (for NN and Caliper methods); if `None`, defaults to 0.2 * sd(logit(ps))
460    pub caliper: Option<f64>,
461    /// Number of nearest neighbours (k:1 matching)
462    pub n_neighbours: usize,
463    /// Kernel bandwidth for kernel matching
464    pub kernel_bandwidth: Option<f64>,
465}
466
467impl PSMatching {
468    /// Create a new PSMatching estimator.
469    pub fn new(method: MatchingMethod) -> Self {
470        Self {
471            method,
472            caliper: None,
473            n_neighbours: 1,
474            kernel_bandwidth: None,
475        }
476    }
477
478    /// Estimate ATT via propensity score matching.
479    ///
480    /// # Arguments
481    /// * `y`  – outcome
482    /// * `w`  – binary treatment indicator
483    /// * `ps` – estimated propensity scores
484    pub fn estimate_att(
485        &self,
486        y: &ArrayView1<f64>,
487        w: &ArrayView1<f64>,
488        ps: &ArrayView1<f64>,
489    ) -> StatsResult<MatchingResult> {
490        let n = y.len();
491        if w.len() != n || ps.len() != n {
492            return Err(StatsError::DimensionMismatch(
493                "y, w, ps must have equal length".into(),
494            ));
495        }
496
497        let treated_idx: Vec<usize> = (0..n).filter(|&i| w[i] > 0.5).collect();
498        let control_idx: Vec<usize> = (0..n).filter(|&i| w[i] <= 0.5).collect();
499
500        if treated_idx.is_empty() {
501            return Err(StatsError::InsufficientData("No treated units".into()));
502        }
503        if control_idx.is_empty() {
504            return Err(StatsError::InsufficientData("No control units".into()));
505        }
506
507        // Caliper in logit(ps) scale
508        let logit_ps: Array1<f64> = ps.mapv(|p| logit(p.clamp(1e-8, 1.0 - 1e-8)));
509        let logit_sd = std_dev_vec(&logit_ps.to_vec());
510        let caliper_val = self.caliper.unwrap_or(0.2 * logit_sd);
511        let bw = self.kernel_bandwidth.unwrap_or(0.1 * logit_sd);
512
513        match self.method {
514            MatchingMethod::NearestNeighbour | MatchingMethod::Caliper => {
515                self.nn_matching_att(y, &treated_idx, &control_idx, &logit_ps.view(), caliper_val)
516            }
517            MatchingMethod::Kernel => {
518                self.kernel_matching_att(y, &treated_idx, &control_idx, ps, bw)
519            }
520        }
521    }
522
523    fn nn_matching_att(
524        &self,
525        y: &ArrayView1<f64>,
526        treated_idx: &[usize],
527        control_idx: &[usize],
528        logit_ps: &ArrayView1<f64>,
529        caliper: f64,
530    ) -> StatsResult<MatchingResult> {
531        let mut matched_diffs: Vec<f64> = Vec::new();
532        let use_caliper = self.method == MatchingMethod::Caliper;
533
534        for &ti in treated_idx {
535            let lps_t = logit_ps[ti];
536            let best = control_idx
537                .iter()
538                .map(|&ci| (ci, (logit_ps[ci] - lps_t).abs()))
539                .filter(|(_, dist)| !use_caliper || *dist <= caliper)
540                .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
541            if let Some((best_ci, _)) = best {
542                matched_diffs.push(y[ti] - y[best_ci]);
543            }
544        }
545
546        if matched_diffs.is_empty() {
547            return Err(StatsError::InsufficientData(
548                "No matches found; try increasing the caliper".into(),
549            ));
550        }
551
552        let n_m = matched_diffs.len();
553        let att = matched_diffs.iter().sum::<f64>() / n_m as f64;
554        let se = if n_m > 1 {
555            let var = matched_diffs
556                .iter()
557                .map(|&d| (d - att).powi(2))
558                .sum::<f64>()
559                / (n_m * (n_m - 1)) as f64;
560            var.sqrt()
561        } else {
562            0.0
563        };
564        let t = if se > 0.0 { att / se } else { 0.0 };
565        let p = normal_p_value(t);
566        let ci = [att - 1.96 * se, att + 1.96 * se];
567
568        let method_name = if self.method == MatchingMethod::Caliper {
569            "Caliper-matching"
570        } else {
571            "NN-matching"
572        };
573
574        Ok(MatchingResult {
575            att,
576            att_se: se,
577            p_value: p,
578            conf_interval: ci,
579            n_matched_treated: n_m,
580            method: method_name.into(),
581        })
582    }
583
584    fn kernel_matching_att(
585        &self,
586        y: &ArrayView1<f64>,
587        treated_idx: &[usize],
588        control_idx: &[usize],
589        ps: &ArrayView1<f64>,
590        bw: f64,
591    ) -> StatsResult<MatchingResult> {
592        let mut diffs: Vec<f64> = Vec::with_capacity(treated_idx.len());
593        for &ti in treated_idx {
594            let psi = ps[ti];
595            // Epanechnikov kernel weights
596            let weights: Vec<f64> = control_idx
597                .iter()
598                .map(|&ci| {
599                    let u = (ps[ci] - psi) / bw;
600                    if u.abs() < 1.0 {
601                        0.75 * (1.0 - u * u)
602                    } else {
603                        0.0
604                    }
605                })
606                .collect();
607            let total_w: f64 = weights.iter().sum();
608            if total_w < 1e-10 {
609                continue;
610            }
611            let y_ctrl_wt: f64 = control_idx
612                .iter()
613                .zip(weights.iter())
614                .map(|(&ci, &wi)| y[ci] * wi)
615                .sum::<f64>()
616                / total_w;
617            diffs.push(y[ti] - y_ctrl_wt);
618        }
619        if diffs.is_empty() {
620            return Err(StatsError::InsufficientData(
621                "No matches with positive kernel weight; reduce bandwidth".into(),
622            ));
623        }
624        let n_m = diffs.len();
625        let att = diffs.iter().sum::<f64>() / n_m as f64;
626        let se = if n_m > 1 {
627            let var =
628                diffs.iter().map(|&d| (d - att).powi(2)).sum::<f64>() / (n_m * (n_m - 1)) as f64;
629            var.sqrt()
630        } else {
631            0.0
632        };
633        let t = if se > 0.0 { att / se } else { 0.0 };
634        let p = normal_p_value(t);
635        let ci = [att - 1.96 * se, att + 1.96 * se];
636        Ok(MatchingResult {
637            att,
638            att_se: se,
639            p_value: p,
640            conf_interval: ci,
641            n_matched_treated: n_m,
642            method: "Kernel-matching".into(),
643        })
644    }
645}
646
/// Log-odds transform: ln(p / (1 - p)).
fn logit(p: f64) -> f64 {
    let odds = p / (1.0 - p);
    odds.ln()
}
650
/// Sample standard deviation (n - 1 denominator) of a slice.
///
/// Fewer than two values yield `1.0`; otherwise the result is floored at
/// `1e-15` so it is never exactly zero.
fn std_dev_vec(v: &[f64]) -> f64 {
    match v {
        [] | [_] => 1.0,
        samples => {
            let count = samples.len();
            let centre = samples.iter().sum::<f64>() / count as f64;
            let sum_sq = samples
                .iter()
                .map(|&x| (x - centre).powi(2))
                .sum::<f64>();
            let sd = (sum_sq / (count - 1) as f64).sqrt();
            sd.max(1e-15)
        }
    }
}
660
661// ---------------------------------------------------------------------------
662// Overlap Check / Common Support
663// ---------------------------------------------------------------------------
664
665/// Common-support and overlap diagnostics for propensity score analysis.
666pub struct OverlapCheck {
667    /// Trimming rule: exclude units with PS outside [min_treated + epsilon, max_control - epsilon]
668    pub trim_method: TrimMethod,
669}
670
/// Rule used to choose the common-support interval.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrimMethod {
    /// Crump et al. (2009) style trimming: keep α ≤ ps ≤ 1-α.
    /// The overlap check uses the fixed approximation α = 0.1.
    Crump,
    /// Min-max trimming: keep the range in which treated and control scores
    /// overlap, i.e. from the larger group minimum to the smaller group maximum.
    MinMax,
    /// Percentile trimming: drop the most extreme scores on each side.
    /// The overlap check trims a fixed 5 % per side.
    Percentile,
}
681
682impl OverlapCheck {
683    /// Create a new OverlapCheck analyser.
684    pub fn new(trim_method: TrimMethod) -> Self {
685        Self { trim_method }
686    }
687
688    /// Compute overlap diagnostics.
689    ///
690    /// # Arguments
691    /// * `ps` – propensity scores for all units
692    /// * `w`  – binary treatment indicator
693    pub fn check(&self, ps: &ArrayView1<f64>, w: &ArrayView1<f64>) -> StatsResult<OverlapResult> {
694        let n = ps.len();
695        if w.len() != n {
696            return Err(StatsError::DimensionMismatch(
697                "ps and w must have equal length".into(),
698            ));
699        }
700
701        let treated_ps: Vec<f64> = (0..n).filter(|&i| w[i] > 0.5).map(|i| ps[i]).collect();
702        let control_ps: Vec<f64> = (0..n).filter(|&i| w[i] <= 0.5).map(|i| ps[i]).collect();
703
704        if treated_ps.is_empty() || control_ps.is_empty() {
705            return Err(StatsError::InsufficientData(
706                "Need both treated and control units".into(),
707            ));
708        }
709
710        let (ps_lower, ps_upper) = match self.trim_method {
711            TrimMethod::Crump => {
712                // Crump optimal α ≈ 0.1 (simple approximation)
713                (0.1_f64, 0.9_f64)
714            }
715            TrimMethod::MinMax => {
716                let min_t = treated_ps.iter().cloned().fold(f64::INFINITY, f64::min);
717                let max_t = treated_ps.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
718                let min_c = control_ps.iter().cloned().fold(f64::INFINITY, f64::min);
719                let max_c = control_ps.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
720                (min_t.max(min_c), max_t.min(max_c))
721            }
722            TrimMethod::Percentile => {
723                // Trim 5% on each side
724                let alpha = 0.05_f64;
725                let all_ps: Vec<f64> = ps.to_vec();
726                let lower = quantile_val(&all_ps, alpha);
727                let upper = quantile_val(&all_ps, 1.0 - alpha);
728                (lower, upper)
729            }
730        };
731
732        let common_support_idx: Vec<usize> = (0..n)
733            .filter(|&i| ps[i] >= ps_lower && ps[i] <= ps_upper)
734            .collect();
735
736        let n_t = treated_ps.len() as f64;
737        let n_c = control_ps.len() as f64;
738        let frac_t = treated_ps
739            .iter()
740            .filter(|&&p| p >= ps_lower && p <= ps_upper)
741            .count() as f64
742            / n_t.max(1.0);
743        let frac_c = control_ps
744            .iter()
745            .filter(|&&p| p >= ps_lower && p <= ps_upper)
746            .count() as f64
747            / n_c.max(1.0);
748
749        // Overlap coefficient: approximate as fraction of total PS range covered by both groups
750        let overlap_coefficient = overlap_coef(&treated_ps, &control_ps);
751
752        Ok(OverlapResult {
753            ps: ps.to_owned(),
754            common_support_idx,
755            ps_lower,
756            ps_upper,
757            frac_treated_in_support: frac_t,
758            frac_control_in_support: frac_c,
759            overlap_coefficient,
760        })
761    }
762}
763
/// Nearest-rank quantile of `v` at level `q`.
///
/// Sorts a copy of the data and returns the element at index
/// `round(q · (len - 1))`, limited to the valid range, so `q <= 0` gives the
/// minimum and `q >= 1` the maximum. An empty slice yields `0.5`.
///
/// Values are ordered with `f64::total_cmp`, a total order: NaNs can neither
/// scramble the sort nor trip the standard library's total-order check, and
/// positive NaNs end up after every number.
fn quantile_val(v: &[f64], q: f64) -> f64 {
    if v.is_empty() {
        return 0.5;
    }
    let mut sorted = v.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);
    let last = sorted.len() - 1;
    // The float-to-int cast saturates, so negative or NaN positions map to 0.
    let idx = ((q * last as f64).round() as usize).min(last);
    sorted[idx]
}
773
/// Overlap coefficient: 1 - total variation distance
///
/// Histogram approximation of ∫ min(f_t, f_c): the pooled score range is cut
/// into 100 equal-width bins and the smaller of the two per-bin relative
/// frequencies is summed. The last bin is closed on the right so that the
/// largest score is counted; identical samples therefore score exactly 1.
///
/// Returns `0.0` if either sample is empty or the pooled range is not finite,
/// and `1.0` if all scores (numerically) coincide. NaN scores are ignored.
fn overlap_coef(ps_t: &[f64], ps_c: &[f64]) -> f64 {
    const N_BINS: usize = 100;

    if ps_t.is_empty() || ps_c.is_empty() {
        return 0.0;
    }
    let (all_min, all_max) = ps_t
        .iter()
        .chain(ps_c)
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &p| {
            (lo.min(p), hi.max(p))
        });
    let range = all_max - all_min;
    if !range.is_finite() {
        return 0.0;
    }
    if range.abs() < 1e-10 {
        return 1.0;
    }
    let step = range / N_BINS as f64;

    // One pass per sample: drop each score into its bin.
    let histogram = |scores: &[f64]| -> [usize; N_BINS] {
        let mut counts = [0_usize; N_BINS];
        for &p in scores {
            if p.is_nan() {
                continue;
            }
            // Clamping folds the right edge (p == all_max) into the last bin.
            let bin = (((p - all_min) / step) as usize).min(N_BINS - 1);
            counts[bin] += 1;
        }
        counts
    };
    let hist_t = histogram(ps_t);
    let hist_c = histogram(ps_c);

    let n_t = ps_t.len() as f64;
    let n_c = ps_c.len() as f64;
    hist_t
        .iter()
        .zip(hist_c.iter())
        .map(|(&ct, &cc)| (ct as f64 / n_t).min(cc as f64 / n_c))
        .sum()
}
805
806// ---------------------------------------------------------------------------
807// Full estimation pipeline
808// ---------------------------------------------------------------------------
809
810/// Convenience function: estimate ATE/ATT/ATC using propensity score methods.
811///
812/// Fits a logistic propensity score model and applies IPW.
813///
814/// # Arguments
815/// * `y`  – outcome vector (n,)
816/// * `w`  – binary treatment indicator (n,)
817/// * `x`  – covariate matrix (n × k)
818/// * `trim_eps` – propensity score trimming threshold
819pub fn ps_estimate(
820    y: &ArrayView1<f64>,
821    w: &ArrayView1<f64>,
822    x: &ArrayView2<f64>,
823    trim_eps: f64,
824) -> StatsResult<PSResult> {
825    let ps_model = PropensityScoreModel::new();
826    let beta = ps_model.fit(x, w)?;
827    let ps = ps_model.predict(x, &beta.view())?;
828    IPW::estimate(y, w, &ps.view(), trim_eps)
829}
830
831// ---------------------------------------------------------------------------
832// Tests
833// ---------------------------------------------------------------------------
834
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Logistic fit on a single covariate whose larger values are all treated.
    #[test]
    fn test_logistic_regression_ps() {
        // x = 0, 1, …, 9; the upper half is treated.
        let x = Array2::<f64>::from_shape_fn((10, 1), |(row, _)| row as f64);
        let w: Array1<f64> = (0..10).map(|i| if i >= 5 { 1.0 } else { 0.0 }).collect();

        let model = PropensityScoreModel::new();
        let beta = model
            .fit(&x.view(), &w.view())
            .expect("Logistic fit should succeed");
        assert_eq!(beta.len(), 2);
        // Larger x → more likely treated, so the slope must be positive.
        let slope = beta[1];
        assert!(slope > 0.0, "Coefficient should be positive, got {}", slope);

        // Scores at the extremes must fall on the expected side of 0.5.
        let ps = model
            .predict(&x.view(), &beta.view())
            .expect("Predict should succeed");
        let (first, last) = (ps[0], ps[9]);
        assert!(last > 0.5, "ps for x=9 should be > 0.5, got {}", last);
        assert!(first < 0.5, "ps for x=0 should be < 0.5, got {}", first);
    }

    /// A constant outcome carries no treatment effect, so the ATE is near zero.
    #[test]
    fn test_ipw_zero_effect() {
        let n = 100_usize;
        let ps: Array1<f64> = (0..n).map(|i| 0.3 + 0.4 * (i as f64 / n as f64)).collect();
        let w: Array1<f64> = ps.mapv(|p| if p > 0.5 { 1.0 } else { 0.0 });
        // Every unit has the same outcome.
        let y: Array1<f64> = Array1::ones(n);

        let res =
            IPW::estimate(&y.view(), &w.view(), &ps.view(), 0.05).expect("IPW should succeed");
        let ate = res.ate;
        assert!(
            ate.abs() < 0.1,
            "ATE should be ~0 when no effect, got {}",
            ate
        );
    }

    /// Nearest-neighbour matching recovers a constant effect of 2.
    #[test]
    fn test_ps_matching_nn() {
        let n = 40_usize;
        let ps: Array1<f64> = (0..n).map(|i| 0.1 + 0.8 * i as f64 / n as f64).collect();
        let w: Array1<f64> = ps.mapv(|p| if p > 0.5 { 1.0 } else { 0.0 });
        // Treated outcome 5, control outcome 3 → effect of 2.
        let y: Array1<f64> = w.mapv(|wi| if wi > 0.5 { 5.0 } else { 3.0 });

        let res = PSMatching::new(MatchingMethod::NearestNeighbour)
            .estimate_att(&y.view(), &w.view(), &ps.view())
            .expect("NN matching should succeed");
        let att = res.att;
        assert!((att - 2.0).abs() < 0.5, "ATT should be ~2.0, got {}", att);
    }

    /// Min-max trimming on overlapping groups yields a non-empty support.
    #[test]
    fn test_overlap_check_minmax() {
        // treated PS: 0.3, 0.5, 0.6, 0.7, 0.8  (min=0.3, max=0.8)
        // control PS: 0.1, 0.2, 0.4, 0.5, 0.9  (min=0.1, max=0.9)
        // expected common support: [max(0.3,0.1), min(0.8,0.9)] = [0.3, 0.8]
        let ps = array![0.1, 0.3, 0.4, 0.5, 0.5, 0.2, 0.6, 0.7, 0.8, 0.9];
        let w = array![0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0];

        let res = OverlapCheck::new(TrimMethod::MinMax)
            .check(&ps.view(), &w.view())
            .expect("Overlap check should succeed");
        let (lower, upper) = (res.ps_lower, res.ps_upper);
        assert!(lower < upper, "lower={} >= upper={}", lower, upper);
        assert!(!res.common_support_idx.is_empty());
    }

    /// End-to-end run: logistic propensity model followed by IPW.
    #[test]
    fn test_ps_estimate_pipeline() {
        let n = 60_usize;
        let covariate: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
        let x_data = Array2::<f64>::from_shape_fn((n, 1), |(i, _)| covariate[i]);
        let w_data: Array1<f64> = covariate
            .iter()
            .map(|&xi| if xi > 0.5 { 1.0 } else { 0.0 })
            .collect();
        // Treated units sit 2 above the control outcome line.
        let y_data: Array1<f64> = covariate
            .iter()
            .map(|&xi| if xi > 0.5 { 3.0 + xi } else { 1.0 + xi })
            .collect();

        let res = ps_estimate(&y_data.view(), &w_data.view(), &x_data.view(), 0.05)
            .expect("PS estimate pipeline should succeed");
        // Treatment effect ≈ 2
        assert!(res.ate.abs() > 0.0, "ATE should be non-zero");
    }
}