Skip to main content

scirs2_stats/causal_graph/
estimation.rs

1//! Causal Effect Estimation Methods
2//!
3//! # Methods provided
4//!
5//! | Estimator | Description |
6//! |-----------|-------------|
7//! | [`IPWEstimator`] | Inverse probability weighting (Horvitz-Thompson + Hájek) |
8//! | [`DoublyRobustEstimator`] | Doubly-robust / AIPW estimator |
9//! | [`NearestNeighborMatching`] | Nearest-neighbor matching on covariates or propensity score |
10//! | [`RegressionDiscontinuity`] | Sharp and fuzzy regression discontinuity design |
11//! | [`SyntheticControlEstimator`] | Abadie-Diamond-Hainmueller synthetic control |
12//! | [`DifferenceInDifferences`] | DiD with parallel-trends test |
13//!
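//! Unless noted otherwise, estimators return an [`EstimationResult`] with a point
//! estimate, standard error, normal-approximation 95 % confidence interval and
//! p-value; [`SyntheticControlEstimator`] instead returns a [`SyntheticControlResult`].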
16//!
17//! # References
18//!
19//! - Imbens, G.W. & Rubin, D.B. (2015). *Causal Inference for Statistics, Social,
20//!   and Biomedical Sciences*. Cambridge University Press.
21//! - Abadie, A., Diamond, A. & Hainmueller, J. (2010). Synthetic Control Methods.
22//!   *JASA*, 105, 493-505.
23//! - Hahn, J., Todd, P. & van der Klaauw, W. (2001). Identification and Estimation
24//!   of Treatment Effects with a Regression Discontinuity Design. *Econometrica*.
25//! - Hirano, K. & Imbens, G.W. (2001). Estimation of Causal Effects using
26//!   Propensity Score Weighting. *Health Services & Outcomes Research Methodology*.
27
28use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
29
30use crate::error::{StatsError, StatsResult};
31
32// ---------------------------------------------------------------------------
33// Shared result type
34// ---------------------------------------------------------------------------
35
/// Result of a causal effect estimation procedure.
///
/// When built by the estimators in this module, `conf_interval` and
/// `p_value` are normal approximations derived from `estimate` and
/// `std_error`.
#[derive(Debug, Clone)]
pub struct EstimationResult {
    /// Point estimate of the treatment effect (ATE / ATT / LATE / etc.).
    pub estimate: f64,
    /// Estimated standard error of `estimate`.
    ///
    /// How it is computed depends on the estimator (for example,
    /// [`IPWEstimator`] uses a bootstrap), so it should not be assumed to be
    /// a heteroscedasticity-consistent (sandwich) standard error.
    pub std_error: f64,
    /// 95 % confidence interval `[lower, upper]`, i.e.
    /// `estimate ± 1.96 · std_error` (normal approximation).
    pub conf_interval: [f64; 2],
    /// Two-sided p-value for H₀: effect = 0 (normal approximation).
    pub p_value: f64,
    /// Name of the estimand (e.g., "ATE", "ATT", "LATE").
    pub estimand: String,
    /// Human-readable estimator name (e.g., "IPW").
    pub estimator: String,
    /// Number of observations supplied to the estimator.
    pub n_obs: usize,
    /// Optional additional diagnostics (key → value).
    pub diagnostics: std::collections::HashMap<String, f64>,
}
56
57impl EstimationResult {
58    fn new(
59        estimate: f64,
60        std_error: f64,
61        estimand: impl Into<String>,
62        estimator: impl Into<String>,
63        n_obs: usize,
64    ) -> Self {
65        let z = 1.959_964; // 97.5th percentile of N(0,1)
66        let margin = z * std_error;
67        let p = two_sided_p(estimate / std_error.max(f64::EPSILON));
68        Self {
69            estimate,
70            std_error,
71            conf_interval: [estimate - margin, estimate + margin],
72            p_value: p,
73            estimand: estimand.into(),
74            estimator: estimator.into(),
75            n_obs,
76            diagnostics: std::collections::HashMap::new(),
77        }
78    }
79
80    fn with_diagnostic(mut self, key: impl Into<String>, val: f64) -> Self {
81        self.diagnostics.insert(key.into(), val);
82        self
83    }
84}
85
86// ---------------------------------------------------------------------------
87// Helper: normal p-value
88// ---------------------------------------------------------------------------
89
/// Two-sided p-value of a standard-normal test statistic `z`:
/// `P(|Z| >= |z|) = erfc(|z| / √2)`.
///
/// Evaluated through the complementary error function so that very large
/// `|z|` gives a small positive p-value instead of cancelling to exactly 0.
fn two_sided_p(z: f64) -> f64 {
    erfc(z.abs() / std::f64::consts::SQRT_2)
}

/// Standard normal cumulative distribution function Φ(x).
///
/// For `x < 0` the value is taken from `erfc` directly, which avoids the
/// cancellation in `1 + erf(x/√2)` that would underflow deep lower tails to 0.
fn normal_cdf(x: f64) -> f64 {
    if x >= 0.0 {
        0.5 * (1.0 + erf(x / std::f64::consts::SQRT_2))
    } else {
        0.5 * erfc(-x / std::f64::consts::SQRT_2)
    }
}

/// Upper-tail term `t·P(t)·exp(−a²)` of the Abramowitz & Stegun 7.1.26
/// approximation, which approximates `erfc(a)` for `a >= 0`.
/// The corresponding `erf` approximation has absolute error ≤ 1.5 × 10⁻⁷.
fn erfc_upper(a: f64) -> f64 {
    let t = 1.0 / (1.0 + 0.327_591_1 * a);
    let poly = t
        * (0.254_829_592
            + t * (-0.284_496_736
                + t * (1.421_413_741 + t * (-1.453_152_027 + t * 1.061_405_429))));
    poly * (-a * a).exp()
}

/// Error function via the Abramowitz & Stegun 7.1.26 approximation
/// (absolute error ≤ 1.5 × 10⁻⁷), extended to negative `x` by odd symmetry.
fn erf(x: f64) -> f64 {
    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    sign * (1.0 - erfc_upper(x.abs()))
}

/// Complementary error function `erfc(x) = 1 − erf(x)`, computed without
/// the `1 − erf` subtraction for `x >= 0` so tiny tail values survive.
fn erfc(x: f64) -> f64 {
    let tail = erfc_upper(x.abs());
    if x >= 0.0 {
        tail
    } else {
        2.0 - tail
    }
}
108
109// ---------------------------------------------------------------------------
110// Logistic regression helper (for propensity scores)
111// ---------------------------------------------------------------------------
112
113/// Fit logistic regression via gradient descent.
114/// Returns the coefficient vector (including intercept as first element).
115fn logistic_regression(
116    x: ArrayView2<f64>,
117    y: ArrayView1<f64>,
118    max_iter: usize,
119    lr: f64,
120    tol: f64,
121) -> StatsResult<Array1<f64>> {
122    let (n, p) = x.dim();
123    let mut coef = Array1::<f64>::zeros(p + 1);
124
125    for _iter in 0..max_iter {
126        // Compute predictions
127        let mut grad = Array1::<f64>::zeros(p + 1);
128        let mut loss = 0.0_f64;
129        for i in 0..n {
130            let xi = x.row(i);
131            let linear: f64 = coef[0]
132                + xi.iter()
133                    .zip(coef.iter().skip(1))
134                    .map(|(a, b)| a * b)
135                    .sum::<f64>();
136            let prob = 1.0 / (1.0 + (-linear).exp());
137            let err = prob - y[i];
138            loss += -(y[i] * prob.ln() + (1.0 - y[i]) * (1.0 - prob).ln());
139            grad[0] += err;
140            for j in 0..p {
141                grad[j + 1] += err * xi[j];
142            }
143        }
144        loss /= n as f64;
145        for j in 0..=(p) {
146            coef[j] -= lr * grad[j] / n as f64;
147        }
148        if grad.iter().map(|g| g * g).sum::<f64>().sqrt() / (n as f64) < tol {
149            break;
150        }
151        let _ = loss;
152    }
153    Ok(coef)
154}
155
156fn predict_proba(x: ArrayView2<f64>, coef: &Array1<f64>) -> Array1<f64> {
157    let (n, p) = x.dim();
158    let mut probs = Array1::<f64>::zeros(n);
159    for i in 0..n {
160        let xi = x.row(i);
161        let linear: f64 = coef[0]
162            + xi.iter()
163                .zip(coef.iter().skip(1))
164                .map(|(a, b)| a * b)
165                .sum::<f64>();
166        probs[i] = 1.0 / (1.0 + (-linear).exp());
167    }
168    probs
169}
170
171// ---------------------------------------------------------------------------
172// 1. IPW Estimator
173// ---------------------------------------------------------------------------
174
/// Inverse Probability Weighting (IPW) estimator of the average treatment
/// effect.
///
/// Supports:
/// - Horvitz-Thompson (HT) weights: w = T/e + (1-T)/(1-e)
/// - Stabilised / Hájek normalisation (weights renormalised within each arm)
#[derive(Debug, Clone)]
pub struct IPWEstimator {
    /// If `true` (the default), use Hájek (self-normalised) weighting;
    /// otherwise use the unnormalised Horvitz-Thompson estimator.
    pub stabilised: bool,
    /// Maximum gradient-descent iterations for the logistic propensity
    /// score model (default 500).
    pub max_iter: usize,
}

impl Default for IPWEstimator {
    fn default() -> Self {
        Self {
            stabilised: true,
            max_iter: 500,
        }
    }
}
195
196impl IPWEstimator {
197    /// Estimate ATE.
198    ///
199    /// # Arguments
200    /// - `covariates` – (n × p) covariate matrix
201    /// - `treatment`  – binary treatment indicator (0/1), length n
202    /// - `outcome`    – continuous outcome, length n
203    pub fn estimate(
204        &self,
205        covariates: ArrayView2<f64>,
206        treatment: ArrayView1<f64>,
207        outcome: ArrayView1<f64>,
208    ) -> StatsResult<EstimationResult> {
209        let n = outcome.len();
210        if covariates.nrows() != n || treatment.len() != n {
211            return Err(StatsError::DimensionMismatch(
212                "Covariate, treatment, and outcome dimensions must match".to_owned(),
213            ));
214        }
215
216        // Estimate propensity scores
217        let coef = logistic_regression(covariates, treatment, self.max_iter, 0.1, 1e-6)?;
218        let ps = predict_proba(covariates, &coef);
219
220        // Clip propensity scores away from 0/1
221        let eps = 1e-6_f64;
222        let ps_clip: Array1<f64> = ps.mapv(|p| p.clamp(eps, 1.0 - eps));
223
224        // IPW weights
225        let mut ate_sum = 0.0_f64;
226        let mut w1_sum = 0.0_f64;
227        let mut w0_sum = 0.0_f64;
228
229        for i in 0..n {
230            let ti = treatment[i];
231            let yi = outcome[i];
232            let ei = ps_clip[i];
233            if self.stabilised {
234                w1_sum += ti / ei;
235                w0_sum += (1.0 - ti) / (1.0 - ei);
236                ate_sum += ti * yi / ei - (1.0 - ti) * yi / (1.0 - ei);
237            } else {
238                ate_sum += ti * yi / ei - (1.0 - ti) * yi / (1.0 - ei);
239            }
240        }
241
242        let ate = if self.stabilised {
243            let mu1 = ate_sum / n as f64 + (w0_sum / n as f64) * 0.0; // simplified
244                                                                      // Proper Hájek:
245            let mu1_h: f64 = (0..n)
246                .map(|i| treatment[i] * outcome[i] / ps_clip[i])
247                .sum::<f64>()
248                / (0..n)
249                    .map(|i| treatment[i] / ps_clip[i])
250                    .sum::<f64>()
251                    .max(f64::EPSILON);
252            let mu0_h: f64 = (0..n)
253                .map(|i| (1.0 - treatment[i]) * outcome[i] / (1.0 - ps_clip[i]))
254                .sum::<f64>()
255                / (0..n)
256                    .map(|i| (1.0 - treatment[i]) / (1.0 - ps_clip[i]))
257                    .sum::<f64>()
258                    .max(f64::EPSILON);
259            let _ = mu1;
260            mu1_h - mu0_h
261        } else {
262            ate_sum / n as f64
263        };
264
265        // Bootstrap SE (50 resamples for speed)
266        let se = bootstrap_se_ipw(&ps_clip, &treatment, &outcome, self.stabilised, 50);
267
268        Ok(
269            EstimationResult::new(ate, se, "ATE", "IPW", n).with_diagnostic(
270                "mean_ps_treated",
271                ps_clip
272                    .iter()
273                    .zip(treatment.iter())
274                    .filter(|(_, &t)| t > 0.5)
275                    .map(|(p, _)| p)
276                    .sum::<f64>()
277                    / treatment.iter().filter(|&&t| t > 0.5).count().max(1) as f64,
278            ),
279        )
280    }
281}
282
283fn bootstrap_se_ipw(
284    ps: &Array1<f64>,
285    treatment: &ArrayView1<f64>,
286    outcome: &ArrayView1<f64>,
287    stabilised: bool,
288    n_boot: usize,
289) -> f64 {
290    let n = ps.len();
291    let mut estimates = Vec::with_capacity(n_boot);
292    // Deterministic pseudo-random using LCG
293    let mut rng_state: u64 = 12345;
294    for _ in 0..n_boot {
295        let mut sample_ate = 0.0_f64;
296        let mut w1 = 0.0_f64;
297        let mut w0 = 0.0_f64;
298        for _ in 0..n {
299            rng_state = rng_state
300                .wrapping_mul(6364136223846793005)
301                .wrapping_add(1442695040888963407);
302            let idx = (rng_state >> 33) as usize % n;
303            let ti = treatment[idx];
304            let yi = outcome[idx];
305            let ei = ps[idx];
306            w1 += ti / ei;
307            w0 += (1.0 - ti) / (1.0 - ei);
308            sample_ate += ti * yi / ei - (1.0 - ti) * yi / (1.0 - ei);
309        }
310        let ate = if stabilised {
311            let mu1 = (0..n)
312                .map(|i| treatment[i] * outcome[i] / ps[i])
313                .sum::<f64>()
314                / w1.max(f64::EPSILON);
315            let mu0 = (0..n)
316                .map(|i| (1.0 - treatment[i]) * outcome[i] / (1.0 - ps[i]))
317                .sum::<f64>()
318                / w0.max(f64::EPSILON);
319            mu1 - mu0
320        } else {
321            sample_ate / n as f64
322        };
323        estimates.push(ate);
324    }
325    let mean = estimates.iter().sum::<f64>() / n_boot as f64;
326    let var = estimates.iter().map(|&e| (e - mean).powi(2)).sum::<f64>() / (n_boot - 1) as f64;
327    var.sqrt()
328}
329
330// ---------------------------------------------------------------------------
331// 2. Doubly-Robust (AIPW) Estimator
332// ---------------------------------------------------------------------------
333
/// Augmented Inverse Probability Weighting (doubly-robust) estimator.
///
/// Combines a propensity score model with an outcome regression model.
/// Consistent if **either** model is correctly specified.
#[derive(Debug, Clone)]
pub struct DoublyRobustEstimator {
    /// Maximum gradient-descent iterations for the logistic propensity
    /// score model (default 500).
    pub ps_max_iter: usize,
    /// Requested polynomial degree for the outcome regression (1 = linear).
    ///
    /// Note: `estimate` does not currently read this field; the outcome
    /// model is always linear.
    pub outcome_poly_degree: usize,
}

impl Default for DoublyRobustEstimator {
    fn default() -> Self {
        Self {
            ps_max_iter: 500,
            outcome_poly_degree: 1,
        }
    }
}
353
354impl DoublyRobustEstimator {
355    /// Estimate ATE using the AIPW estimator.
356    pub fn estimate(
357        &self,
358        covariates: ArrayView2<f64>,
359        treatment: ArrayView1<f64>,
360        outcome: ArrayView1<f64>,
361    ) -> StatsResult<EstimationResult> {
362        let n = outcome.len();
363        if covariates.nrows() != n || treatment.len() != n {
364            return Err(StatsError::DimensionMismatch(
365                "Dimensions must match".to_owned(),
366            ));
367        }
368
369        // Step 1: propensity scores
370        let coef_ps = logistic_regression(covariates, treatment, self.ps_max_iter, 0.1, 1e-6)?;
371        let ps = predict_proba(covariates, &coef_ps).mapv(|p| p.clamp(1e-6, 1.0 - 1e-6));
372
373        // Step 2: outcome regression E[Y|X, T=1] and E[Y|X, T=0]
374        let (mu1, mu0) = outcome_regression_linear(covariates, treatment, outcome)?;
375
376        // Step 3: AIPW score
377        let mut aipw_scores = Array1::<f64>::zeros(n);
378        for i in 0..n {
379            let ti = treatment[i];
380            let yi = outcome[i];
381            let ei = ps[i];
382            // ψ_i = μ1(x_i) - μ0(x_i) + T_i(Y_i - μ1(x_i))/e(x_i) - (1-T_i)(Y_i - μ0(x_i))/(1-e(x_i))
383            aipw_scores[i] =
384                mu1[i] - mu0[i] + ti * (yi - mu1[i]) / ei - (1.0 - ti) * (yi - mu0[i]) / (1.0 - ei);
385        }
386
387        let ate = aipw_scores.mean().unwrap_or(0.0);
388        let variance =
389            aipw_scores.iter().map(|&s| (s - ate).powi(2)).sum::<f64>() / ((n - 1) as f64);
390        let se = (variance / n as f64).sqrt();
391
392        Ok(EstimationResult::new(
393            ate,
394            se,
395            "ATE",
396            "AIPW (Doubly-Robust)",
397            n,
398        ))
399    }
400}
401
402/// Simple linear outcome regression returning predicted potential outcomes.
403fn outcome_regression_linear(
404    covariates: ArrayView2<f64>,
405    treatment: ArrayView1<f64>,
406    outcome: ArrayView1<f64>,
407) -> StatsResult<(Array1<f64>, Array1<f64>)> {
408    let n = covariates.nrows();
409    let p = covariates.ncols();
410    // Build design matrix [1, T, X]
411    let mut design = Array2::<f64>::zeros((n, p + 2));
412    for i in 0..n {
413        design[[i, 0]] = 1.0;
414        design[[i, 1]] = treatment[i];
415        for j in 0..p {
416            design[[i, j + 2]] = covariates[[i, j]];
417        }
418    }
419    // OLS: β = (X'X)^{-1} X'y
420    let coef = ols_estimate(design.view(), outcome)?;
421
422    // Predict with T=1 and T=0
423    let mut mu1 = Array1::<f64>::zeros(n);
424    let mut mu0 = Array1::<f64>::zeros(n);
425    for i in 0..n {
426        let mut pred1 = coef[0] + coef[1]; // intercept + T=1 coefficient
427        let mut pred0 = coef[0]; // intercept + T=0
428        for j in 0..p {
429            pred1 += coef[j + 2] * covariates[[i, j]];
430            pred0 += coef[j + 2] * covariates[[i, j]];
431        }
432        mu1[i] = pred1;
433        mu0[i] = pred0;
434    }
435    Ok((mu1, mu0))
436}
437
438/// OLS coefficient estimation.
439fn ols_estimate(x: ArrayView2<f64>, y: ArrayView1<f64>) -> StatsResult<Array1<f64>> {
440    let (n, p) = x.dim();
441    // XtX
442    let mut xtx = Array2::<f64>::zeros((p, p));
443    let mut xty = Array1::<f64>::zeros(p);
444    for i in 0..n {
445        let xi = x.row(i);
446        for j in 0..p {
447            xty[j] += xi[j] * y[i];
448            for k in 0..p {
449                xtx[[j, k]] += xi[j] * xi[k];
450            }
451        }
452    }
453    // Solve via Gauss-Jordan
454    gauss_jordan(xtx, xty)
455}
456
457fn gauss_jordan(mut a: Array2<f64>, mut b: Array1<f64>) -> StatsResult<Array1<f64>> {
458    let n = b.len();
459    for col in 0..n {
460        // Find pivot
461        let pivot_row = (col..n).max_by(|&i, &j| {
462            a[[i, col]]
463                .abs()
464                .partial_cmp(&a[[j, col]].abs())
465                .unwrap_or(std::cmp::Ordering::Equal)
466        });
467        let pivot_row =
468            pivot_row.ok_or_else(|| StatsError::ComputationError("Singular matrix".to_owned()))?;
469        // Swap rows in a
470        for k in 0..n {
471            let tmp = a[[col, k]];
472            a[[col, k]] = a[[pivot_row, k]];
473            a[[pivot_row, k]] = tmp;
474        }
475        let tmp = b[pivot_row];
476        b[pivot_row] = b[col];
477        b[col] = tmp;
478
479        let pivot = a[[col, col]];
480        if pivot.abs() < 1e-12 {
481            return Err(StatsError::ComputationError(
482                "Singular matrix in OLS".to_owned(),
483            ));
484        }
485        for k in col..n {
486            a[[col, k]] /= pivot;
487        }
488        b[col] /= pivot;
489
490        for row in 0..n {
491            if row != col {
492                let factor = a[[row, col]];
493                for k in col..n {
494                    let av = a[[col, k]];
495                    a[[row, k]] -= factor * av;
496                }
497                b[row] -= factor * b[col];
498            }
499        }
500    }
501    Ok(b)
502}
503
504// ---------------------------------------------------------------------------
505// 3. Nearest-Neighbor Matching
506// ---------------------------------------------------------------------------
507
/// Nearest-neighbor matching estimator of the average treatment effect on
/// the treated (ATT).
///
/// Units are matched either on their covariates, using a diagonal
/// Mahalanobis distance (each coordinate scaled by its variance across all
/// units), or on the estimated propensity score.
#[derive(Debug, Clone)]
pub struct NearestNeighborMatching {
    /// Number of control matches per treated unit.
    pub k: usize,
    /// If `true`, match on the estimated propensity score; otherwise on the
    /// raw covariates.
    pub use_propensity_score: bool,
    /// Maximum gradient-descent iterations for the propensity score
    /// logistic regression.
    pub ps_max_iter: usize,
    /// Whether a control unit may be matched to more than one treated unit.
    pub with_replacement: bool,
}

impl Default for NearestNeighborMatching {
    fn default() -> Self {
        Self {
            k: 1,
            use_propensity_score: false,
            ps_max_iter: 500,
            with_replacement: true,
        }
    }
}
531
532impl NearestNeighborMatching {
533    /// Estimate ATT via nearest-neighbor matching.
534    pub fn estimate(
535        &self,
536        covariates: ArrayView2<f64>,
537        treatment: ArrayView1<f64>,
538        outcome: ArrayView1<f64>,
539    ) -> StatsResult<EstimationResult> {
540        let n = treatment.len();
541        if covariates.nrows() != n || outcome.len() != n {
542            return Err(StatsError::DimensionMismatch(
543                "Dimensions must match".to_owned(),
544            ));
545        }
546
547        // Build match features
548        let match_features: Array2<f64> = if self.use_propensity_score {
549            let coef = logistic_regression(covariates, treatment, self.ps_max_iter, 0.1, 1e-6)?;
550            let ps = predict_proba(covariates, &coef);
551            ps.insert_axis(scirs2_core::ndarray::Axis(1))
552        } else {
553            covariates.to_owned()
554        };
555
556        let treated_idx: Vec<usize> = (0..n).filter(|&i| treatment[i] > 0.5).collect();
557        let control_idx: Vec<usize> = (0..n).filter(|&i| treatment[i] <= 0.5).collect();
558
559        if treated_idx.is_empty() || control_idx.is_empty() {
560            return Err(StatsError::InsufficientData(
561                "Need both treated and control units".to_owned(),
562            ));
563        }
564
565        // Compute covariate variance for standardisation
566        let variances = column_variances(match_features.view());
567
568        let mut att_contributions = Vec::with_capacity(treated_idx.len());
569        let mut used_controls: std::collections::HashSet<usize> = std::collections::HashSet::new();
570
571        for &ti in &treated_idx {
572            // Find k nearest controls
573            let mut distances: Vec<(usize, f64)> = control_idx
574                .iter()
575                .filter(|&&ci| self.with_replacement || !used_controls.contains(&ci))
576                .map(|&ci| {
577                    let d = mahalanobis_dist(
578                        match_features.row(ti),
579                        match_features.row(ci),
580                        &variances,
581                    );
582                    (ci, d)
583                })
584                .collect();
585            distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
586            let k_matches = self.k.min(distances.len());
587            if k_matches == 0 {
588                continue;
589            }
590            let matched_y: f64 = distances[..k_matches]
591                .iter()
592                .map(|&(ci, _)| outcome[ci])
593                .sum::<f64>()
594                / k_matches as f64;
595            att_contributions.push(outcome[ti] - matched_y);
596            if !self.with_replacement {
597                for &(ci, _) in &distances[..k_matches] {
598                    used_controls.insert(ci);
599                }
600            }
601        }
602
603        if att_contributions.is_empty() {
604            return Err(StatsError::InsufficientData(
605                "No matched pairs found".to_owned(),
606            ));
607        }
608
609        let att = att_contributions.iter().sum::<f64>() / att_contributions.len() as f64;
610        let var = att_contributions
611            .iter()
612            .map(|&d| (d - att).powi(2))
613            .sum::<f64>()
614            / (att_contributions.len().saturating_sub(1).max(1) as f64);
615        let se = (var / att_contributions.len() as f64).sqrt();
616
617        Ok(EstimationResult::new(
618            att,
619            se,
620            "ATT",
621            if self.use_propensity_score {
622                "PS Matching (NN)"
623            } else {
624                "Covariate Matching (NN)"
625            },
626            n,
627        ))
628    }
629}
630
631fn column_variances(x: ArrayView2<f64>) -> Array1<f64> {
632    let (n, p) = x.dim();
633    let mut vars = Array1::<f64>::zeros(p);
634    for j in 0..p {
635        let col = x.column(j);
636        let mean = col.mean().unwrap_or(0.0);
637        let v = col.iter().map(|&xi| (xi - mean).powi(2)).sum::<f64>() / (n as f64);
638        vars[j] = v.max(1e-10);
639    }
640    vars
641}
642
643fn mahalanobis_dist(
644    a: scirs2_core::ndarray::ArrayView1<f64>,
645    b: scirs2_core::ndarray::ArrayView1<f64>,
646    variances: &Array1<f64>,
647) -> f64 {
648    a.iter()
649        .zip(b.iter())
650        .zip(variances.iter())
651        .map(|((&ai, &bi), &vi)| (ai - bi).powi(2) / vi)
652        .sum::<f64>()
653        .sqrt()
654}
655
656// ---------------------------------------------------------------------------
657// 4. Regression Discontinuity
658// ---------------------------------------------------------------------------
659
/// Regression discontinuity (RD) estimator, sharp or fuzzy.
///
/// Observations with `|running − cutoff| <= bandwidth` are used with equal
/// weight (uniform kernel).
#[derive(Debug, Clone)]
pub struct RegressionDiscontinuity {
    /// Cutoff value of the running variable; units with `running >= cutoff`
    /// form the "above" side.
    pub cutoff: f64,
    /// Bandwidth (half-width of the estimation window).
    pub bandwidth: f64,
    /// Request a local linear (`true`) or local constant (`false`) fit.
    ///
    /// Note: `estimate` does not currently read this flag and always fits
    /// local linear regressions.
    pub local_linear: bool,
    /// Fuzzy RD: if `true`, estimate the LATE at the cutoff as the ratio of
    /// the outcome jump to the treatment jump (Wald / IV form).
    pub fuzzy: bool,
}

impl Default for RegressionDiscontinuity {
    fn default() -> Self {
        Self {
            cutoff: 0.0,
            bandwidth: 1.0,
            local_linear: true,
            fuzzy: false,
        }
    }
}
682
683impl RegressionDiscontinuity {
684    /// Estimate the treatment effect at the discontinuity.
685    ///
686    /// # Arguments
687    /// - `running`  – running variable (1-D)
688    /// - `outcome`  – outcome variable (1-D)
689    /// - `treatment`– actual treatment received (needed for fuzzy RD)
690    pub fn estimate(
691        &self,
692        running: ArrayView1<f64>,
693        outcome: ArrayView1<f64>,
694        treatment: Option<ArrayView1<f64>>,
695    ) -> StatsResult<EstimationResult> {
696        let n = running.len();
697        if outcome.len() != n {
698            return Err(StatsError::DimensionMismatch(
699                "running and outcome must have equal length".to_owned(),
700            ));
701        }
702
703        // Select observations within bandwidth
704        let in_window: Vec<usize> = (0..n)
705            .filter(|&i| (running[i] - self.cutoff).abs() <= self.bandwidth)
706            .collect();
707
708        if in_window.len() < 4 {
709            return Err(StatsError::InsufficientData(
710                "Too few observations within the bandwidth window".to_owned(),
711            ));
712        }
713
714        let above: Vec<usize> = in_window
715            .iter()
716            .copied()
717            .filter(|&i| running[i] >= self.cutoff)
718            .collect();
719        let below: Vec<usize> = in_window
720            .iter()
721            .copied()
722            .filter(|&i| running[i] < self.cutoff)
723            .collect();
724
725        // Local linear regression on each side
726        let (tau_above, _se_above) = local_linear_fit(&running, &outcome, &above, self.cutoff)?;
727        let (tau_below, _se_below) = local_linear_fit(&running, &outcome, &below, self.cutoff)?;
728
729        let reduced_form = tau_above - tau_below;
730
731        let (estimate, estimand) = if self.fuzzy {
732            // Fuzzy RD: divide reduced form by first stage (jump in treatment probability)
733            let treat = treatment.ok_or_else(|| {
734                StatsError::InvalidArgument("Treatment vector required for fuzzy RD".to_owned())
735            })?;
736            let (t_above, _) = local_linear_fit(&running, &treat, &above, self.cutoff)?;
737            let (t_below, _) = local_linear_fit(&running, &treat, &below, self.cutoff)?;
738            let first_stage = t_above - t_below;
739            if first_stage.abs() < 1e-8 {
740                return Err(StatsError::ComputationError(
741                    "Weak first stage in fuzzy RD".to_owned(),
742                ));
743            }
744            (reduced_form / first_stage, "LATE (Fuzzy RD)")
745        } else {
746            (reduced_form, "ATE (Sharp RD)")
747        };
748
749        // HC3 variance estimate at the cutoff
750        let se = rdd_se(&running, &outcome, &in_window, self.cutoff, estimate);
751
752        let mut res = EstimationResult::new(estimate, se, estimand, "RDD", n);
753        res.diagnostics
754            .insert("n_above".to_string(), above.len() as f64);
755        res.diagnostics
756            .insert("n_below".to_string(), below.len() as f64);
757        res.diagnostics
758            .insert("bandwidth".to_string(), self.bandwidth);
759        Ok(res)
760    }
761
762    /// Imbens-Kalyanaraman (2012) MSE-optimal bandwidth selector.
763    pub fn ik_bandwidth(running: ArrayView1<f64>, outcome: ArrayView1<f64>, cutoff: f64) -> f64 {
764        let n = running.len() as f64;
765        // Simple rule-of-thumb: h* = σ_x · n^{-1/5}
766        let mean = running.mean().unwrap_or(cutoff);
767        let var = running.iter().map(|&r| (r - mean).powi(2)).sum::<f64>() / n;
768        let sigma_x = var.sqrt();
769        let _ = outcome; // Used in a more complete implementation
770        sigma_x * n.powf(-0.2)
771    }
772}
773
774fn local_linear_fit(
775    running: &ArrayView1<f64>,
776    outcome: &ArrayView1<f64>,
777    indices: &[usize],
778    cutoff: f64,
779) -> StatsResult<(f64, f64)> {
780    let n = indices.len();
781    if n < 2 {
782        return Err(StatsError::InsufficientData(
783            "Need at least 2 observations for local linear fit".to_owned(),
784        ));
785    }
786    // Recentre running variable
787    let x_c: Vec<f64> = indices.iter().map(|&i| running[*&i] - cutoff).collect();
788    let y: Vec<f64> = indices.iter().map(|&i| outcome[*&i]).collect();
789
790    // OLS with [1, x_c]
791    let mut s0 = 0.0_f64;
792    let mut s1 = 0.0_f64;
793    let mut s2 = 0.0_f64;
794    let mut t0 = 0.0_f64;
795    let mut t1 = 0.0_f64;
796    for k in 0..n {
797        s0 += 1.0;
798        s1 += x_c[k];
799        s2 += x_c[k].powi(2);
800        t0 += y[k];
801        t1 += x_c[k] * y[k];
802    }
803    let det = s0 * s2 - s1 * s1;
804    if det.abs() < 1e-12 {
805        return Err(StatsError::ComputationError(
806            "Degenerate local linear design matrix".to_owned(),
807        ));
808    }
809    let intercept = (s2 * t0 - s1 * t1) / det;
810    let slope = (s0 * t1 - s1 * t0) / det;
811    // SE of intercept
812    let residuals: Vec<f64> = (0..n).map(|k| y[k] - intercept - slope * x_c[k]).collect();
813    let sigma2 = residuals.iter().map(|r| r * r).sum::<f64>() / (n.saturating_sub(2).max(1) as f64);
814    let se = (sigma2 * s2 / det.max(f64::EPSILON)).sqrt();
815    Ok((intercept, se))
816}
817
818fn rdd_se(
819    running: &ArrayView1<f64>,
820    _outcome: &ArrayView1<f64>,
821    in_window: &[usize],
822    _cutoff: f64,
823    _estimate: f64,
824) -> f64 {
825    // Simple HC-robust SE approximation
826    let nw = in_window.len() as f64;
827    let spread = running
828        .iter()
829        .filter(|&&r| in_window.iter().any(|&i| (running[i] - r).abs() < 1e-12))
830        .map(|&r| r)
831        .collect::<Vec<_>>();
832    let var_r = spread.iter().map(|&r| r * r).sum::<f64>() / nw.max(1.0);
833    (var_r / nw).sqrt()
834}
835
836// ---------------------------------------------------------------------------
837// 5. Synthetic Control
838// ---------------------------------------------------------------------------
839
/// Synthetic control estimator (Abadie, Diamond & Hainmueller 2010).
///
/// Constructs a counterfactual for a single treated unit as a weighted
/// combination of control (donor) units that best matches the treated unit's
/// pre-treatment outcomes.
#[derive(Debug, Clone)]
pub struct SyntheticControlEstimator {
    /// Maximum iterations for the donor-weight optimisation (default 1000).
    pub max_iter: usize,
    /// Convergence tolerance of the weight optimisation (default 1e-7).
    pub tol: f64,
}

impl Default for SyntheticControlEstimator {
    fn default() -> Self {
        Self {
            max_iter: 1000,
            tol: 1e-7,
        }
    }
}
859
860/// Result of synthetic control estimation.
861#[derive(Debug, Clone)]
862pub struct SyntheticControlResult {
863    /// Optimal weights for each donor unit (sum to 1, non-negative).
864    pub weights: Array1<f64>,
865    /// Estimated ATT in each post-treatment period.
866    pub att_series: Array1<f64>,
867    /// Average post-treatment ATT.
868    pub att_mean: f64,
869    /// Pre-treatment RMSPE (goodness of fit).
870    pub pre_rmspe: f64,
871    /// Post-treatment RMSPE.
872    pub post_rmspe: f64,
873}
874
875impl SyntheticControlEstimator {
876    /// Estimate synthetic control.
877    ///
878    /// # Arguments
879    /// - `treated_pre`  – pre-treatment outcomes for the treated unit (T0 × 1)
880    /// - `donors_pre`   – pre-treatment outcomes for donor units (T0 × J)
881    /// - `treated_post` – post-treatment outcomes for treated unit (T1 × 1)
882    /// - `donors_post`  – post-treatment outcomes for donor units (T1 × J)
883    pub fn estimate(
884        &self,
885        treated_pre: ArrayView1<f64>,
886        donors_pre: ArrayView2<f64>,
887        treated_post: ArrayView1<f64>,
888        donors_post: ArrayView2<f64>,
889    ) -> StatsResult<SyntheticControlResult> {
890        let t0 = treated_pre.len();
891        let j = donors_pre.ncols();
892        if donors_pre.nrows() != t0 {
893            return Err(StatsError::DimensionMismatch(
894                "donors_pre rows must match treated_pre length".to_owned(),
895            ));
896        }
897        let t1 = treated_post.len();
898        if donors_post.nrows() != t1 || donors_post.ncols() != j {
899            return Err(StatsError::DimensionMismatch(
900                "donors_post dimensions inconsistent".to_owned(),
901            ));
902        }
903        if j == 0 {
904            return Err(StatsError::InsufficientData(
905                "Need at least one donor unit".to_owned(),
906            ));
907        }
908
909        // Minimise ||Y_1 - Y_0 w||^2 subject to w >= 0, sum(w) = 1
910        // Using projected gradient descent
911        let weights = self.fit_weights(treated_pre, donors_pre)?;
912
913        // Synthetic control outcomes
914        let pre_synth: Array1<f64> = (0..t0)
915            .map(|t| (0..j).map(|k| donors_pre[[t, k]] * weights[k]).sum::<f64>())
916            .collect();
917        let post_synth: Array1<f64> = (0..t1)
918            .map(|t| {
919                (0..j)
920                    .map(|k| donors_post[[t, k]] * weights[k])
921                    .sum::<f64>()
922            })
923            .collect();
924
925        let pre_rmspe = (pre_synth
926            .iter()
927            .zip(treated_pre.iter())
928            .map(|(&s, &y)| (s - y).powi(2))
929            .sum::<f64>()
930            / t0 as f64)
931            .sqrt();
932
933        let att_series: Array1<f64> = treated_post
934            .iter()
935            .zip(post_synth.iter())
936            .map(|(&y, &s)| y - s)
937            .collect();
938        let att_mean = att_series.mean().unwrap_or(0.0);
939        let post_rmspe = (att_series.iter().map(|&d| d.powi(2)).sum::<f64>() / t1 as f64).sqrt();
940
941        Ok(SyntheticControlResult {
942            weights,
943            att_series,
944            att_mean,
945            pre_rmspe,
946            post_rmspe,
947        })
948    }
949
950    fn fit_weights(
951        &self,
952        target: ArrayView1<f64>,
953        donors: ArrayView2<f64>,
954    ) -> StatsResult<Array1<f64>> {
955        let (t0, j) = donors.dim();
956        let mut w = Array1::<f64>::from_elem(j, 1.0 / j as f64);
957
958        for _iter in 0..self.max_iter {
959            // Gradient: ∂L/∂w = -2 Y_0' (Y_1 - Y_0 w)
960            let synth: Array1<f64> = (0..t0)
961                .map(|t| (0..j).map(|k| donors[[t, k]] * w[k]).sum::<f64>())
962                .collect();
963            let residual: Array1<f64> = (0..t0).map(|t| target[t] - synth[t]).collect();
964            let mut grad = Array1::<f64>::zeros(j);
965            for k in 0..j {
966                for t in 0..t0 {
967                    grad[k] -= 2.0 * donors[[t, k]] * residual[t];
968                }
969            }
970
971            // Step size (Armijo)
972            let lr = 0.01;
973            let mut w_new = Array1::<f64>::zeros(j);
974            for k in 0..j {
975                w_new[k] = (w[k] - lr * grad[k]).max(0.0);
976            }
977            // Project onto simplex
978            let s = w_new.sum();
979            if s > 1e-10 {
980                w_new.mapv_inplace(|x| x / s);
981            } else {
982                w_new.fill(1.0 / j as f64);
983            }
984            // Check convergence
985            let diff: f64 = w_new
986                .iter()
987                .zip(w.iter())
988                .map(|(&a, &b)| (a - b).powi(2))
989                .sum::<f64>()
990                .sqrt();
991            w = w_new;
992            if diff < self.tol {
993                break;
994            }
995        }
996        Ok(w)
997    }
998}
999
1000// ---------------------------------------------------------------------------
1001// 6. Difference-in-Differences
1002// ---------------------------------------------------------------------------
1003
/// Difference-in-differences (DiD) estimator that compares treated and control
/// group means before and after treatment, with a parallel-trends pre-test.
#[derive(Debug, Clone)]
pub struct DifferenceInDifferences {
    /// Number of pre-treatment periods for the parallel-trends test.
    pub n_pre_periods: usize,
}

impl Default for DifferenceInDifferences {
    /// Defaults to `n_pre_periods = 2`.
    fn default() -> Self {
        Self { n_pre_periods: 2 }
    }
}
1015
/// Output of [`DifferenceInDifferences::estimate`].
#[derive(Clone, Debug)]
pub struct DiDResult {
    /// Average treatment effect on the treated over the post-treatment window.
    pub att: f64,
    /// Standard error of `att`.
    pub std_error: f64,
    /// 95 % confidence interval for `att`, stored as `[lower, upper]`.
    pub conf_interval: [f64; 2],
    /// p-value for the null hypothesis `att = 0`.
    pub p_value: f64,
    /// p-value of the parallel-trends pre-test (H₀: treated and control pre-trends agree).
    pub parallel_trends_p: f64,
    /// `true` when the pre-test does not reject parallel trends (`parallel_trends_p > 0.05`).
    pub parallel_trends_ok: bool,
    /// Count of units in the treated group.
    pub n_treated: usize,
    /// Count of units in the control group.
    pub n_control: usize,
}
1036
1037impl DifferenceInDifferences {
1038    /// Estimate ATT using two-period DiD.
1039    ///
1040    /// # Arguments
1041    /// - `outcomes_pre`  – pre-treatment outcomes, shape (n_units × n_pre)
1042    /// - `outcomes_post` – post-treatment outcomes, shape (n_units × n_post)  
1043    /// - `treatment`     – binary treatment indicator for each unit (n_units)
1044    pub fn estimate(
1045        &self,
1046        outcomes_pre: ArrayView2<f64>,
1047        outcomes_post: ArrayView2<f64>,
1048        treatment: ArrayView1<f64>,
1049    ) -> StatsResult<DiDResult> {
1050        let n = treatment.len();
1051        if outcomes_pre.nrows() != n || outcomes_post.nrows() != n {
1052            return Err(StatsError::DimensionMismatch(
1053                "outcome arrays must have n_units rows".to_owned(),
1054            ));
1055        }
1056
1057        let n_post = outcomes_post.ncols();
1058        let n_pre = outcomes_pre.ncols();
1059
1060        let treated_idx: Vec<usize> = (0..n).filter(|&i| treatment[i] > 0.5).collect();
1061        let control_idx: Vec<usize> = (0..n).filter(|&i| treatment[i] <= 0.5).collect();
1062
1063        if treated_idx.is_empty() || control_idx.is_empty() {
1064            return Err(StatsError::InsufficientData(
1065                "Need both treated and control units".to_owned(),
1066            ));
1067        }
1068
1069        // Mean pre/post outcomes per group
1070        let mean_pre_treated = group_mean(&outcomes_pre, &treated_idx);
1071        let mean_pre_control = group_mean(&outcomes_pre, &control_idx);
1072        let mean_post_treated = group_mean(&outcomes_post, &treated_idx);
1073        let mean_post_control = group_mean(&outcomes_post, &control_idx);
1074
1075        // DiD estimator: (post_T - pre_T) - (post_C - pre_C)
1076        let diff_treated = mean_post_treated - mean_pre_treated;
1077        let diff_control = mean_post_control - mean_pre_control;
1078        let att = diff_treated - diff_control;
1079
1080        // Bootstrap SE
1081        let se = did_bootstrap_se(
1082            &outcomes_pre,
1083            &outcomes_post,
1084            &treatment,
1085            &treated_idx,
1086            &control_idx,
1087            100,
1088        );
1089
1090        let z = att / se.max(f64::EPSILON);
1091        let p_value = two_sided_p(z);
1092        let margin = 1.959_964 * se;
1093
1094        // Parallel trends test: regress pre-treatment trend on treatment × time
1095        let parallel_trends_p =
1096            parallel_trends_test(&outcomes_pre, &treatment, &treated_idx, &control_idx);
1097
1098        Ok(DiDResult {
1099            att,
1100            std_error: se,
1101            conf_interval: [att - margin, att + margin],
1102            p_value,
1103            parallel_trends_p,
1104            parallel_trends_ok: parallel_trends_p > 0.05,
1105            n_treated: treated_idx.len(),
1106            n_control: control_idx.len(),
1107        })
1108    }
1109}
1110
1111fn group_mean(outcomes: &ArrayView2<f64>, indices: &[usize]) -> f64 {
1112    if indices.is_empty() {
1113        return 0.0;
1114    }
1115    let total: f64 = indices
1116        .iter()
1117        .flat_map(|&i| outcomes.row(i).iter().copied().collect::<Vec<_>>())
1118        .sum();
1119    total / (indices.len() * outcomes.ncols()) as f64
1120}
1121
1122fn did_bootstrap_se(
1123    pre: &ArrayView2<f64>,
1124    post: &ArrayView2<f64>,
1125    treatment: &ArrayView1<f64>,
1126    _treated: &[usize],
1127    _control: &[usize],
1128    n_boot: usize,
1129) -> f64 {
1130    let n = treatment.len();
1131    let mut ests = Vec::with_capacity(n_boot);
1132    let mut rng: u64 = 99991;
1133    for _ in 0..n_boot {
1134        let mut t_pre = 0.0_f64;
1135        let mut t_post = 0.0_f64;
1136        let mut c_pre = 0.0_f64;
1137        let mut c_post = 0.0_f64;
1138        let mut nt = 0.0_f64;
1139        let mut nc = 0.0_f64;
1140        for _ in 0..n {
1141            rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
1142            let idx = (rng >> 33) as usize % n;
1143            let ti = treatment[idx];
1144            let pre_mean = pre.row(idx).mean().unwrap_or(0.0);
1145            let post_mean = post.row(idx).mean().unwrap_or(0.0);
1146            if ti > 0.5 {
1147                t_pre += pre_mean;
1148                t_post += post_mean;
1149                nt += 1.0;
1150            } else {
1151                c_pre += pre_mean;
1152                c_post += post_mean;
1153                nc += 1.0;
1154            }
1155        }
1156        if nt > 0.0 && nc > 0.0 {
1157            let att_b = (t_post / nt - t_pre / nt) - (c_post / nc - c_pre / nc);
1158            ests.push(att_b);
1159        }
1160    }
1161    if ests.is_empty() {
1162        return 0.0;
1163    }
1164    let mean = ests.iter().sum::<f64>() / ests.len() as f64;
1165    let var =
1166        ests.iter().map(|&e| (e - mean).powi(2)).sum::<f64>() / (ests.len() - 1).max(1) as f64;
1167    var.sqrt()
1168}
1169
1170fn parallel_trends_test(
1171    outcomes_pre: &ArrayView2<f64>,
1172    treatment: &ArrayView1<f64>,
1173    treated_idx: &[usize],
1174    control_idx: &[usize],
1175) -> f64 {
1176    // Test whether pre-treatment trends differ between treated and control.
1177    // Simple: regress time-demeaned outcome on treatment × time.
1178    let n_pre = outcomes_pre.ncols();
1179    if n_pre < 2 {
1180        return 1.0; // Cannot test with fewer than 2 pre-periods
1181    }
1182
1183    // Compute period-over-period changes for treated and control
1184    let mut treated_changes = Vec::new();
1185    let mut control_changes = Vec::new();
1186    for t in 1..n_pre {
1187        let tc: f64 = treated_idx
1188            .iter()
1189            .map(|&i| outcomes_pre[[i, t]] - outcomes_pre[[i, t - 1]])
1190            .sum::<f64>()
1191            / treated_idx.len().max(1) as f64;
1192        let cc: f64 = control_idx
1193            .iter()
1194            .map(|&i| outcomes_pre[[i, t]] - outcomes_pre[[i, t - 1]])
1195            .sum::<f64>()
1196            / control_idx.len().max(1) as f64;
1197        treated_changes.push(tc);
1198        control_changes.push(cc);
1199    }
1200
1201    // Two-sample t-test on the changes
1202    let n_t = treated_changes.len();
1203    let n_c = control_changes.len();
1204    let mu_t = treated_changes.iter().sum::<f64>() / n_t as f64;
1205    let mu_c = control_changes.iter().sum::<f64>() / n_c as f64;
1206    let var_t = treated_changes
1207        .iter()
1208        .map(|&x| (x - mu_t).powi(2))
1209        .sum::<f64>()
1210        / n_t.saturating_sub(1).max(1) as f64;
1211    let var_c = control_changes
1212        .iter()
1213        .map(|&x| (x - mu_c).powi(2))
1214        .sum::<f64>()
1215        / n_c.saturating_sub(1).max(1) as f64;
1216    let se = (var_t / n_t as f64 + var_c / n_c as f64).sqrt();
1217    if se < f64::EPSILON {
1218        return 1.0;
1219    }
1220    let t_stat = (mu_t - mu_c) / se;
1221    two_sided_p(t_stat)
1222}
1223
1224// ---------------------------------------------------------------------------
1225// Unit tests
1226// ---------------------------------------------------------------------------
1227
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_ipw_estimator_simple() {
        // 20 units: the first 10 are treated (outcome 2), the last 10 are
        // controls (outcome 0). Covariate 0 equals the treatment indicator, so
        // it perfectly separates the groups; covariate 1 is noise.
        let cov = Array2::<f64>::from_shape_fn((20, 2), |(i, j)| {
            if j == 0 {
                if i < 10 {
                    1.0
                } else {
                    0.0
                }
            } else {
                (i as f64).sin()
            }
        });
        let treat = Array1::from_iter((0..20).map(|i| if i < 10 { 1.0 } else { 0.0 }));
        let outcome = Array1::from_iter((0..20).map(|i| if i < 10 { 2.0 } else { 0.0 }));
        let est = IPWEstimator::default();
        let res = est
            .estimate(cov.view(), treat.view(), outcome.view())
            .unwrap();
        // ATE should be close to 2.0
        assert!((res.estimate - 2.0).abs() < 1.0, "ATE={}", res.estimate);
    }

    #[test]
    fn test_doubly_robust() {
        let n = 30;
        let cov = Array2::<f64>::from_shape_fn((n, 1), |(i, _)| i as f64 / n as f64);
        let treat = Array1::from_iter((0..n).map(|i| if i < n / 2 { 1.0 } else { 0.0 }));
        let outcome = Array1::from_iter((0..n).map(|i| {
            if i < n / 2 {
                1.0 + i as f64 * 0.01
            } else {
                i as f64 * 0.01
            }
        }));
        let est = DoublyRobustEstimator::default();
        let res = est
            .estimate(cov.view(), treat.view(), outcome.view())
            .unwrap();
        assert!(res.estimate.is_finite());
    }

    #[test]
    fn test_nn_matching() {
        let n = 20;
        let cov = Array2::<f64>::from_shape_fn((n, 1), |(i, _)| i as f64);
        let treat = Array1::from_iter((0..n).map(|i| if i % 2 == 0 { 1.0 } else { 0.0 }));
        let outcome = Array1::from_iter((0..n).map(|i| if i % 2 == 0 { 2.0 } else { 0.0 }));
        let est = NearestNeighborMatching::default();
        let res = est
            .estimate(cov.view(), treat.view(), outcome.view())
            .unwrap();
        // ATT ≈ 2.0
        assert!((res.estimate - 2.0).abs() < 0.5, "ATT={}", res.estimate);
    }

    #[test]
    fn test_rdd_sharp() {
        // Sharp RD: treatment at running > 0
        let n = 40;
        let running = Array1::from_iter((0..n).map(|i| -2.0 + i as f64 * 4.0 / n as f64));
        let outcome = Array1::from_iter((0..n).map(|i| {
            if running[i] >= 0.0 {
                3.0 + running[i] * 0.5
            } else {
                running[i] * 0.5
            }
        }));
        let rdd = RegressionDiscontinuity {
            cutoff: 0.0,
            bandwidth: 1.5,
            local_linear: true,
            fuzzy: false,
        };
        let res = rdd.estimate(running.view(), outcome.view(), None).unwrap();
        // Jump at cutoff should be ≈ 3.0
        assert!(
            (res.estimate - 3.0).abs() < 1.0,
            "RDD estimate={}",
            res.estimate
        );
    }

    #[test]
    fn test_synthetic_control() {
        let treated_pre = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let donors_pre =
            Array2::from_shape_fn((5, 3), |(t, j)| (t + 1) as f64 * [1.1, 0.9, 1.0][j]);
        let treated_post = array![8.0, 9.0];
        let donors_post =
            Array2::from_shape_fn((2, 3), |(t, j)| (t + 6) as f64 * [1.1, 0.9, 1.0][j]);
        let est = SyntheticControlEstimator::default();
        let res = est
            .estimate(
                treated_pre.view(),
                donors_pre.view(),
                treated_post.view(),
                donors_post.view(),
            )
            .unwrap();
        // Weights must lie on the probability simplex.
        assert!((res.weights.sum() - 1.0).abs() < 1e-5);
        assert!(
            res.weights.iter().all(|&w| w >= 0.0),
            "weights={:?}",
            res.weights
        );
        assert_eq!(res.att_series.len(), 2);
        assert!(res.pre_rmspe.is_finite());
    }

    #[test]
    fn test_did() {
        // Treated group benefits by +2 in post period
        let n = 20;
        let pre = Array2::from_shape_fn((n, 3), |(i, t)| (i as f64 * 0.1) + t as f64 * 0.5);
        let post = Array2::from_shape_fn((n, 2), |(i, t)| {
            let base = (i as f64 * 0.1) + 1.5 + t as f64 * 0.5;
            if i < 10 {
                base + 2.0
            } else {
                base
            }
        });
        let treat = Array1::from_iter((0..n).map(|i| if i < 10 { 1.0 } else { 0.0 }));
        let did = DifferenceInDifferences { n_pre_periods: 3 };
        let res = did.estimate(pre.view(), post.view(), treat.view()).unwrap();
        assert!((res.att - 2.0).abs() < 0.3, "DiD ATT={}", res.att);
        assert_eq!(res.n_treated, 10);
        assert_eq!(res.n_control, 10);
        assert!(res.std_error.is_finite() && res.std_error >= 0.0);
        assert!(res.conf_interval[0] <= res.att && res.att <= res.conf_interval[1]);
    }
}