Skip to main content

scirs2_stats/causal/
difference_in_differences.rs

1//! Difference-in-Differences and Related Methods
2//!
3//! Implements panel-data causal inference estimators:
4//!
5//! - **`DiD`**: Classic 2×2 DiD with parallel-trends check and ATT estimation
6//! - **`SyntheticControl`**: Abadie-Diamond-Hainmueller donor-pool weighting
7//! - **`EventStudy`**: Pre/post event-time coefficients for dynamic treatment effects
8//! - **`StaggeredDiD`**: Callaway-Sant'Anna (2021) doubly-robust ATT(g,t) estimator
9//! - **`DiDResult`**: Unified result type
10//!
11//! # References
12//!
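//! - Abadie, A. & Gardeazabal, J. (2003). The Economic Costs of Conflict: A Case
//!   Study of the Basque Country. American Economic Review.
//! - Abadie, A., Diamond, A. & Hainmueller, J. (2010). Synthetic Control Methods
//!   for Comparative Case Studies. Journal of the American Statistical Association.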
14//! - Callaway, B. & Sant'Anna, P.H.C. (2021). Difference-in-Differences with
15//!   Multiple Time Periods. Journal of Econometrics.
16//! - Roth, J. et al. (2023). What's Trending in Difference-in-Differences?
17
18use crate::error::{StatsError, StatsResult};
19use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
20use std::collections::HashMap;
21
22// ---------------------------------------------------------------------------
23// Result types
24// ---------------------------------------------------------------------------
25
26/// Result of a Difference-in-Differences estimation
27#[derive(Debug, Clone)]
28pub struct DiDResult {
29    /// Average Treatment Effect on the Treated (ATT)
30    pub att: f64,
31
32    /// Standard error of ATT
33    pub std_error: f64,
34
35    /// t-statistic
36    pub t_stat: f64,
37
38    /// Two-sided p-value
39    pub p_value: f64,
40
41    /// 95 % confidence interval [lower, upper]
42    pub conf_interval: [f64; 2],
43
44    /// Pre-treatment parallel-trends test p-value (if computed)
45    pub parallel_trends_p: Option<f64>,
46
47    /// Number of treated observations
48    pub n_treated: usize,
49
50    /// Number of control observations
51    pub n_control: usize,
52
53    /// Estimator name
54    pub estimator: String,
55}
56
57/// Event-study coefficient
58#[derive(Debug, Clone)]
59pub struct EventCoefficient {
60    /// Relative event time (negative = pre-treatment)
61    pub relative_time: i64,
62    /// Point estimate
63    pub estimate: f64,
64    /// Standard error
65    pub std_error: f64,
66    /// t-statistic
67    pub t_stat: f64,
68    /// Two-sided p-value
69    pub p_value: f64,
70    /// 95 % confidence interval
71    pub conf_interval: [f64; 2],
72}
73
74/// Result of an event-study analysis
75#[derive(Debug, Clone)]
76pub struct EventStudyResult {
77    /// Coefficients for each relative time period
78    pub coefficients: Vec<EventCoefficient>,
79    /// Pre-treatment F-test statistic (joint test that all pre-coefficients = 0)
80    pub pre_trend_f: f64,
81    /// p-value for the pre-trend test
82    pub pre_trend_p: f64,
83    /// Degrees of freedom for the pre-trend F-test
84    pub pre_trend_df: usize,
85}
86
87/// Result of a Callaway-Sant'Anna staggered DiD
88#[derive(Debug, Clone)]
89pub struct StaggeredDiDResult {
90    /// ATT(g, t) estimates for each (cohort, period) pair
91    pub att_gt: Vec<AttGt>,
92    /// Aggregate ATT (simple weighted average)
93    pub aggregate_att: f64,
94    /// Standard error of aggregate ATT (via bootstrap)
95    pub aggregate_se: f64,
96    /// p-value for aggregate ATT
97    pub aggregate_p: f64,
98}
99
100/// ATT for a specific (group, time) pair
101#[derive(Debug, Clone)]
102pub struct AttGt {
103    /// First-treatment period (cohort)
104    pub cohort: i64,
105    /// Calendar period
106    pub period: i64,
107    /// ATT estimate
108    pub att: f64,
109    /// Standard error
110    pub std_error: f64,
111    /// p-value
112    pub p_value: f64,
113}
114
115// ---------------------------------------------------------------------------
116// Utility: simple normal CDF and quantile
117// ---------------------------------------------------------------------------
118
119fn normal_cdf(x: f64) -> f64 {
120    0.5 * (1.0 + libm_erf(x / std::f64::consts::SQRT_2))
121}
122
123fn libm_erf(x: f64) -> f64 {
124    // Abramowitz & Stegun approximation 7.1.26, max |error| < 1.5e-7
125    let t = 1.0 / (1.0 + 0.3275911 * x.abs());
126    let y = 1.0
127        - (0.254829592
128            + (-0.284496736 + (1.421413741 + (-1.453152027 + 1.061405429 * t) * t) * t) * t)
129            * t
130            * (-x * x).exp();
131    if x >= 0.0 {
132        y
133    } else {
134        -y
135    }
136}
137
138fn normal_p_value(z: f64) -> f64 {
139    // Two-sided
140    2.0 * (1.0 - normal_cdf(z.abs()))
141}
142
143fn t_dist_p_value_did(t: f64, df: f64) -> f64 {
144    if df <= 0.0 {
145        return 1.0;
146    }
147    // Use normal approximation for large df
148    if df > 200.0 {
149        return normal_p_value(t);
150    }
151    // Regularized incomplete beta I_x(df/2, 0.5) at x = df/(df+t²)
152    let x = df / (df + t * t);
153    regularized_incomplete_beta(x, df / 2.0, 0.5)
154        .min(1.0)
155        .max(0.0)
156}
157
158fn regularized_incomplete_beta(x: f64, a: f64, b: f64) -> f64 {
159    if x <= 0.0 {
160        return 0.0;
161    }
162    if x >= 1.0 {
163        return 1.0;
164    }
165    if x > (a + 1.0) / (a + b + 2.0) {
166        return 1.0 - regularized_incomplete_beta(1.0 - x, b, a);
167    }
168    let log_cf =
169        (a * x.ln() + b * (1.0 - x).ln() - ln_gamma(a) - ln_gamma(b) + ln_gamma(a + b)).exp() / a;
170    log_cf * beta_cf(x, a, b)
171}
172
173fn beta_cf(x: f64, a: f64, b: f64) -> f64 {
174    let fpmin = 1e-300_f64;
175    let qab = a + b;
176    let qap = a + 1.0;
177    let qam = a - 1.0;
178    let mut c = 1.0_f64;
179    let mut d = 1.0 - qab * x / qap;
180    if d.abs() < fpmin {
181        d = fpmin;
182    }
183    d = 1.0 / d;
184    let mut h = d;
185    for m in 1..=200_i32 {
186        let mf = m as f64;
187        let aa = mf * (b - mf) * x / ((qam + 2.0 * mf) * (a + 2.0 * mf));
188        d = 1.0 + aa * d;
189        if d.abs() < fpmin {
190            d = fpmin;
191        }
192        c = 1.0 + aa / c;
193        if c.abs() < fpmin {
194            c = fpmin;
195        }
196        d = 1.0 / d;
197        h *= d * c;
198        let aa2 = -(a + mf) * (qab + mf) * x / ((a + 2.0 * mf) * (qap + 2.0 * mf));
199        d = 1.0 + aa2 * d;
200        if d.abs() < fpmin {
201            d = fpmin;
202        }
203        c = 1.0 + aa2 / c;
204        if c.abs() < fpmin {
205            c = fpmin;
206        }
207        d = 1.0 / d;
208        let del = d * c;
209        h *= del;
210        if (del - 1.0).abs() < 3e-15 {
211            break;
212        }
213    }
214    h
215}
216
217fn ln_gamma(x: f64) -> f64 {
218    const G: f64 = 7.0;
219    const C: [f64; 9] = [
220        0.99999999999980993,
221        676.5203681218851,
222        -1259.1392167224028,
223        771.323_428_777_653_1,
224        -176.615_029_162_140_6,
225        12.507_343_278_686_905,
226        -0.13857_109_526_572_012,
227        9.984_369_578_019_572e-6,
228        1.5056_327_351_493_116e-7,
229    ];
230    if x < 0.5 {
231        std::f64::consts::PI.ln() - (std::f64::consts::PI * x).sin().ln() - ln_gamma(1.0 - x)
232    } else {
233        let z = x - 1.0;
234        let mut s = C[0];
235        for (i, &ci) in C[1..].iter().enumerate() {
236            s += ci / (z + (i as f64) + 1.0);
237        }
238        let t = z + G + 0.5;
239        0.5 * (2.0 * std::f64::consts::PI).ln() + (z + 0.5) * t.ln() - t + s.ln()
240    }
241}
242
243// ---------------------------------------------------------------------------
244// OLS helper for DiD regressions
245// ---------------------------------------------------------------------------
246
247fn ols_fit_did(
248    x: &ArrayView2<f64>,
249    y: &ArrayView1<f64>,
250) -> StatsResult<(Array1<f64>, Array1<f64>, Array2<f64>)> {
251    let n = x.nrows();
252    let k = x.ncols();
253    if n < k {
254        return Err(StatsError::InsufficientData(format!(
255            "Need at least {k} observations, got {n}"
256        )));
257    }
258    let xtx = x.t().dot(x);
259    let xty = x.t().dot(y);
260    let xtx_inv = cholesky_invert_did(&xtx.view())?;
261    let beta = xtx_inv.dot(&xty);
262    let fitted = x.dot(&beta);
263    let residuals = y.to_owned() - fitted;
264    Ok((beta, residuals, xtx_inv))
265}
266
267fn cholesky_invert_did(a: &ArrayView2<f64>) -> StatsResult<Array2<f64>> {
268    let n = a.nrows();
269    let mut l = Array2::<f64>::zeros((n, n));
270    for i in 0..n {
271        for j in 0..=i {
272            let mut s = a[[i, j]];
273            for p in 0..j {
274                s -= l[[i, p]] * l[[j, p]];
275            }
276            if i == j {
277                if s <= 0.0 {
278                    return Err(StatsError::ComputationError(
279                        "Matrix not positive definite (DiD)".into(),
280                    ));
281                }
282                l[[i, j]] = s.sqrt();
283            } else {
284                l[[i, j]] = s / l[[j, j]];
285            }
286        }
287    }
288    let mut linv = Array2::<f64>::zeros((n, n));
289    for j in 0..n {
290        linv[[j, j]] = 1.0 / l[[j, j]];
291        for i in (j + 1)..n {
292            let mut s = 0.0_f64;
293            for p in j..i {
294                s += l[[i, p]] * linv[[p, j]];
295            }
296            linv[[i, j]] = -s / l[[i, i]];
297        }
298    }
299    Ok(linv.t().dot(&linv))
300}
301
302fn t_critical_did(alpha: f64, df: usize) -> f64 {
303    // Newton-Raphson inversion for the t critical value
304    let df_f = df as f64;
305    let mut t = 2.0_f64;
306    for _ in 0..50 {
307        let p = t_dist_p_value_did(t, df_f);
308        let target = 2.0 * alpha;
309        let err = p - target;
310        let delta = 1e-6;
311        let dp = (t_dist_p_value_did(t + delta, df_f) - p) / delta;
312        if dp.abs() < 1e-15 {
313            break;
314        }
315        t -= err / dp;
316        if err.abs() < 1e-10 {
317            break;
318        }
319    }
320    t.max(0.0)
321}
322
323// ---------------------------------------------------------------------------
324// Classic Difference-in-Differences
325// ---------------------------------------------------------------------------
326
327/// Classic 2×2 Difference-in-Differences estimator.
328///
329/// The ATT is estimated via the two-way fixed effects regression:
330///   y_{it} = α_i + δ_t + β D_{it} + ε_{it}
331/// where D_{it} = 1 for treated units after treatment.
332///
333/// Provides a parallel-trends pre-test and ATT with standard errors.
334pub struct DiD;
335
336impl DiD {
337    /// Estimate the ATT.
338    ///
339    /// # Arguments
340    /// * `y`         – outcome vector (n × T flattened, row-major: unit i at time t → index i*T + t)
341    /// * `treated`   – binary indicator for each unit (length n_units); 1 = treated
342    /// * `n_units`   – number of units
343    /// * `n_periods` – number of time periods
344    /// * `treat_period` – the first period when treatment takes effect (0-indexed)
345    ///
346    /// # Returns
347    /// [`DiDResult`] with ATT estimate and parallel-trends test.
348    pub fn estimate(
349        y: &ArrayView1<f64>,
350        treated: &ArrayView1<f64>,
351        n_units: usize,
352        n_periods: usize,
353        treat_period: usize,
354    ) -> StatsResult<DiDResult> {
355        let n = n_units * n_periods;
356        if y.len() != n {
357            return Err(StatsError::DimensionMismatch(format!(
358                "y length {} != n_units * n_periods = {}",
359                y.len(),
360                n
361            )));
362        }
363        if treated.len() != n_units {
364            return Err(StatsError::DimensionMismatch(
365                "treated length must equal n_units".into(),
366            ));
367        }
368        if treat_period >= n_periods {
369            return Err(StatsError::InvalidArgument(
370                "treat_period must be < n_periods".into(),
371            ));
372        }
373
374        let n_treated: usize = treated.iter().filter(|&&v| v > 0.5).count();
375        let n_control = n_units - n_treated;
376
377        // Build design matrix for TWFE regression:
378        // cols: [unit FE (n_units-1), time FE (n_periods-1), DiD indicator]
379        // Using within (demeaned) approach for efficiency: construct direct 2SLS-style regression.
380        // For simplicity, use the Mundlak-Wooldridge form: include unit/time dummies + D.
381
382        // cols: [intercept, unit FE (n_units-1), time FE (n_periods-1), DiD indicator]
383        let k = 1 + (n_units - 1) + (n_periods - 1) + 1;
384        let mut xmat = Array2::<f64>::zeros((n, k));
385        let mut y_vec = Array1::<f64>::zeros(n);
386
387        for i in 0..n_units {
388            for t in 0..n_periods {
389                let row = i * n_periods + t;
390                y_vec[row] = y[row];
391                // Intercept
392                xmat[[row, 0]] = 1.0;
393                // Unit FE (omit unit 0)
394                if i > 0 {
395                    xmat[[row, i]] = 1.0;
396                }
397                // Time FE (omit period 0)
398                if t > 0 {
399                    xmat[[row, n_units + t - 1]] = 1.0;
400                }
401                // DiD indicator
402                let post = if t >= treat_period { 1.0 } else { 0.0 };
403                let treat = treated[i];
404                xmat[[row, k - 1]] = post * treat;
405            }
406        }
407
408        let (beta, resid, xtx_inv) = ols_fit_did(&xmat.view(), &y_vec.view())?;
409        let att = beta[k - 1];
410        let df = (n - k) as f64;
411        let s2 = resid.iter().map(|&r| r * r).sum::<f64>() / df.max(1.0);
412        let var_att = xtx_inv[[k - 1, k - 1]] * s2;
413        let se = var_att.max(0.0).sqrt();
414        let t_stat = if se > 0.0 { att / se } else { 0.0 };
415        let p_val = t_dist_p_value_did(t_stat, df);
416        let t_crit = t_critical_did(0.025, df as usize);
417        let ci = [att - t_crit * se, att + t_crit * se];
418
419        // Parallel-trends pre-test:
420        // Regress pre-treatment trends on treated*time (should be zero)
421        let parallel_p = if treat_period > 1 {
422            Some(Self::parallel_trends_test(
423                y,
424                treated,
425                n_units,
426                n_periods,
427                treat_period,
428            )?)
429        } else {
430            None
431        };
432
433        Ok(DiDResult {
434            att,
435            std_error: se,
436            t_stat,
437            p_value: p_val,
438            conf_interval: ci,
439            parallel_trends_p: parallel_p,
440            n_treated,
441            n_control,
442            estimator: "DiD-TWFE".into(),
443        })
444    }
445
446    /// Pre-treatment parallel-trends test.
447    ///
448    /// Regresses y on treat×t for t < treat_period and tests whether the
449    /// interaction coefficient is zero.
450    fn parallel_trends_test(
451        y: &ArrayView1<f64>,
452        treated: &ArrayView1<f64>,
453        n_units: usize,
454        n_periods: usize,
455        treat_period: usize,
456    ) -> StatsResult<f64> {
457        // Use pre-treatment observations only
458        let n_pre = n_units * treat_period;
459        if n_pre < 4 {
460            return Ok(1.0); // Not enough data to test
461        }
462        let k_pre = 3; // intercept, time trend, treat*time
463        let mut x_pre = Array2::<f64>::zeros((n_pre, k_pre));
464        let mut y_pre = Array1::<f64>::zeros(n_pre);
465        let mut row = 0;
466        for i in 0..n_units {
467            for t in 0..treat_period {
468                y_pre[row] = y[i * n_periods + t];
469                x_pre[[row, 0]] = 1.0; // intercept
470                x_pre[[row, 1]] = t as f64; // time trend
471                x_pre[[row, 2]] = treated[i] * (t as f64); // treat × time
472                row += 1;
473            }
474        }
475        let (beta_pre, resid_pre, xtx_inv_pre) = ols_fit_did(&x_pre.view(), &y_pre.view())?;
476        let df_pre = (n_pre - k_pre) as f64;
477        let s2_pre = resid_pre.iter().map(|&r| r * r).sum::<f64>() / df_pre.max(1.0);
478        let var_coef = xtx_inv_pre[[k_pre - 1, k_pre - 1]] * s2_pre;
479        let se = var_coef.max(0.0).sqrt();
480        let t = if se > 0.0 {
481            beta_pre[k_pre - 1] / se
482        } else {
483            0.0
484        };
485        Ok(t_dist_p_value_did(t, df_pre))
486    }
487}
488
489// ---------------------------------------------------------------------------
490// Synthetic Control
491// ---------------------------------------------------------------------------
492
493/// Synthetic Control Method (Abadie-Diamond-Hainmueller, 2010).
494///
495/// Finds a weighted combination of donor units that best matches the treated
496/// unit's pre-treatment trajectory.  The weights are constrained to be
497/// non-negative and sum to one; they are found by projected-gradient descent.
498pub struct SyntheticControl {
499    /// Maximum number of optimization iterations
500    pub max_iter: usize,
501    /// Convergence tolerance
502    pub tol: f64,
503}
504
505impl SyntheticControl {
506    /// Create a new SyntheticControl estimator.
507    pub fn new() -> Self {
508        Self {
509            max_iter: 2000,
510            tol: 1e-8,
511        }
512    }
513
514    /// Fit the synthetic control weights.
515    ///
516    /// # Arguments
517    /// * `y_treated`  – pre-treatment outcomes for the treated unit (T_pre,)
518    /// * `y_donors`   – pre-treatment outcomes for donor units (T_pre × n_donors)
519    ///
520    /// # Returns
521    /// Optimal weights (n_donors,), sum to 1, all >= 0.
522    pub fn fit_weights(
523        &self,
524        y_treated: &ArrayView1<f64>,
525        y_donors: &ArrayView2<f64>,
526    ) -> StatsResult<Array1<f64>> {
527        let t_pre = y_treated.len();
528        let n_donors = y_donors.ncols();
529        if y_donors.nrows() != t_pre {
530            return Err(StatsError::DimensionMismatch(
531                "y_donors must have same number of rows as y_treated".into(),
532            ));
533        }
534        if n_donors == 0 {
535            return Err(StatsError::InvalidArgument(
536                "Need at least one donor unit".into(),
537            ));
538        }
539
540        // Minimize ||y_treated - Y_donors w||² s.t. w >= 0, sum(w) = 1
541        // Projected gradient descent with projection onto simplex.
542        let mut w: Array1<f64> = Array1::from_elem(n_donors, 1.0 / n_donors as f64);
543        let yd_t = y_donors.t(); // n_donors × T_pre
544
545        // Pre-compute Y'Y and Y'y for gradient
546        let ytd_y: Array2<f64> = yd_t.dot(y_donors); // n_donors × n_donors
547        let ytd_yt: Array1<f64> = yd_t.dot(y_treated); // n_donors
548
549        // Step size: 1 / max eigenvalue of Y'Y (Gershgorin bound)
550        let step_denom: f64 = ytd_y
551            .rows()
552            .into_iter()
553            .map(|row| row.iter().map(|&v| v.abs()).sum::<f64>())
554            .fold(f64::NEG_INFINITY, f64::max);
555        let lr = if step_denom > 0.0 {
556            0.5 / step_denom
557        } else {
558            1e-3
559        };
560
561        for _ in 0..self.max_iter {
562            // Gradient of ||y - Yw||² w.r.t. w: 2 (Y'Y w - Y'y)
563            let grad = ytd_y.dot(&w) - &ytd_yt;
564            let w_new_raw = &w - &grad.mapv(|g| g * lr);
565            let w_new = project_simplex(&w_new_raw.view());
566            let diff: f64 = (&w_new - &w).iter().map(|&d| d * d).sum::<f64>().sqrt();
567            w = w_new;
568            if diff < self.tol {
569                break;
570            }
571        }
572
573        Ok(w)
574    }
575
576    /// Estimate treatment effect for each post-treatment period.
577    ///
578    /// # Arguments
579    /// * `y_treated_post`  – post outcomes for treated unit (T_post,)
580    /// * `y_donors_post`   – post outcomes for donor units (T_post × n_donors)
581    /// * `weights`         – fitted weights from `fit_weights`
582    pub fn treatment_effects(
583        &self,
584        y_treated_post: &ArrayView1<f64>,
585        y_donors_post: &ArrayView2<f64>,
586        weights: &ArrayView1<f64>,
587    ) -> StatsResult<Array1<f64>> {
588        if y_donors_post.nrows() != y_treated_post.len() {
589            return Err(StatsError::DimensionMismatch(
590                "y_donors_post rows must equal y_treated_post length".into(),
591            ));
592        }
593        let synthetic = y_donors_post.dot(weights);
594        Ok(y_treated_post.to_owned() - synthetic)
595    }
596}
597
598/// Project a vector onto the probability simplex: w >= 0, sum = 1.
599fn project_simplex(v: &ArrayView1<f64>) -> Array1<f64> {
600    let n = v.len();
601    let mut u: Vec<f64> = v.to_vec();
602    u.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
603    let mut rho = 0_usize;
604    let mut cum = 0.0_f64;
605    for (j, &uj) in u.iter().enumerate() {
606        cum += uj;
607        if uj - (cum - 1.0) / (j as f64 + 1.0) > 0.0 {
608            rho = j;
609        }
610    }
611    let cum_rho: f64 = u[..=rho].iter().sum();
612    let theta = (cum_rho - 1.0) / (rho as f64 + 1.0);
613    v.mapv(|vi| (vi - theta).max(0.0))
614}
615
616impl Default for SyntheticControl {
617    fn default() -> Self {
618        Self::new()
619    }
620}
621
622// ---------------------------------------------------------------------------
623// Event Study
624// ---------------------------------------------------------------------------
625
626/// Event-study design for estimating dynamic treatment effects.
627///
628/// Estimates the regression:
629///   y_{it} = Σ_{l≠-1} β_l D^l_{it} + α_i + δ_t + ε_{it}
630/// where D^l_{it} = 1 if unit i is treated and is l periods from its
631/// treatment date.  The omitted period is l = -1 (normalisation).
632pub struct EventStudy {
633    /// How many pre-treatment periods to include (>= 1)
634    pub n_pre: usize,
635    /// How many post-treatment periods to include (>= 1)
636    pub n_post: usize,
637}
638
639impl EventStudy {
640    /// Create a new EventStudy estimator.
641    pub fn new(n_pre: usize, n_post: usize) -> Self {
642        Self { n_pre, n_post }
643    }
644
645    /// Estimate dynamic treatment effects.
646    ///
647    /// # Arguments
648    /// * `y`           – outcome vector (n_units × n_periods, row-major)
649    /// * `treated`     – binary indicator per unit (n_units,)
650    /// * `n_units`     – number of units
651    /// * `n_periods`   – number of time periods
652    /// * `treat_period`– the first treatment period for all treated units (0-indexed)
653    pub fn estimate(
654        &self,
655        y: &ArrayView1<f64>,
656        treated: &ArrayView1<f64>,
657        n_units: usize,
658        n_periods: usize,
659        treat_period: usize,
660    ) -> StatsResult<EventStudyResult> {
661        let n = n_units * n_periods;
662        if y.len() != n {
663            return Err(StatsError::DimensionMismatch(
664                "y length != n_units * n_periods".into(),
665            ));
666        }
667
668        // Relative time periods we estimate: [-n_pre, ..., -2, -1 (omit), 0, ..., n_post-1]
669        // Total coefficients (excluding the omitted -1): n_pre + n_post - 1 (omit l=-1)
670        // But we include l=-1 dummy in the regression then drop it; simpler: include all and use
671        // the coefficient of l=-1 as the reference (pin to 0 via exclusion).
672        // We include event-time dummies: [-n_pre, ..., -2, 0, ..., n_post-1]
673        let n_event_dummies = self.n_pre + self.n_post - 1; // exclude l=-1
674
675        // Design matrix: [unit FE (n_units-1), time FE (n_periods-1), event dummies]
676        let k = (n_units - 1) + (n_periods - 1) + n_event_dummies;
677        let mut xmat = Array2::<f64>::zeros((n, k));
678        let mut y_vec = Array1::<f64>::zeros(n);
679
680        let event_times: Vec<i64> = {
681            let mut v: Vec<i64> = (-(self.n_pre as i64)..=(self.n_post as i64 - 1)).collect();
682            v.retain(|&l| l != -1); // omit l = -1
683            v
684        };
685
686        for i in 0..n_units {
687            for t in 0..n_periods {
688                let row = i * n_periods + t;
689                y_vec[row] = y[row];
690                // Unit FE
691                if i > 0 {
692                    xmat[[row, i - 1]] = 1.0;
693                }
694                // Time FE
695                if t > 0 {
696                    xmat[[row, n_units - 1 + t - 1]] = 1.0;
697                }
698                // Event-time dummies
699                if treated[i] > 0.5 {
700                    let rel_time = (t as i64) - (treat_period as i64);
701                    for (d_idx, &et) in event_times.iter().enumerate() {
702                        if rel_time == et {
703                            xmat[[row, (n_units - 1) + (n_periods - 1) + d_idx]] = 1.0;
704                        }
705                    }
706                }
707            }
708        }
709
710        let (beta, resid, xtx_inv) = ols_fit_did(&xmat.view(), &y_vec.view())?;
711        let df = (n - k) as f64;
712        let s2 = resid.iter().map(|&r| r * r).sum::<f64>() / df.max(1.0);
713        let t_crit = t_critical_did(0.025, df as usize);
714        let fe_offset = (n_units - 1) + (n_periods - 1);
715
716        let mut coefficients = Vec::with_capacity(n_event_dummies);
717        for (d_idx, &et) in event_times.iter().enumerate() {
718            let coef_idx = fe_offset + d_idx;
719            let est = beta[coef_idx];
720            let se = (xtx_inv[[coef_idx, coef_idx]] * s2).max(0.0).sqrt();
721            let t = if se > 0.0 { est / se } else { 0.0 };
722            let p = t_dist_p_value_did(t, df);
723            coefficients.push(EventCoefficient {
724                relative_time: et,
725                estimate: est,
726                std_error: se,
727                t_stat: t,
728                p_value: p,
729                conf_interval: [est - t_crit * se, est + t_crit * se],
730            });
731        }
732
733        // Pre-trend F-test (joint test that pre-treatment coefficients = 0)
734        let n_pre_coefs = self.n_pre.saturating_sub(1); // exclude l=-1 normalisation
735        let (pre_f, pre_p) = if n_pre_coefs > 0 {
736            // R matrix picks out the pre-treatment coefficients
737            let pre_coef_idxs: Vec<usize> = (0..n_pre_coefs).map(|j| fe_offset + j).collect();
738            let rss_ur = resid.iter().map(|&r| r * r).sum::<f64>();
739            // Restricted model: set pre coefficients = 0
740            let mut x_r = xmat.clone();
741            for &idx in &pre_coef_idxs {
742                for i in 0..n {
743                    x_r[[i, idx]] = 0.0;
744                }
745            }
746            // Drop those columns
747            let cols_r: Vec<usize> = (0..k).filter(|c| !pre_coef_idxs.contains(c)).collect();
748            let mut xr = Array2::<f64>::zeros((n, cols_r.len()));
749            for (new_j, &old_j) in cols_r.iter().enumerate() {
750                for i in 0..n {
751                    xr[[i, new_j]] = xmat[[i, old_j]];
752                }
753            }
754            let (_br, resid_r, _) = ols_fit_did(&xr.view(), &y_vec.view())?;
755            let rss_r = resid_r.iter().map(|&r| r * r).sum::<f64>();
756            let f = ((rss_r - rss_ur) / n_pre_coefs as f64) / (rss_ur / df).max(1e-15);
757            // F-distribution p-value via chi2 approx
758            let chi2 = f * n_pre_coefs as f64;
759            let p_f = 1.0 - regularized_gamma_lower_did(n_pre_coefs as f64 / 2.0, chi2 / 2.0);
760            (f, p_f)
761        } else {
762            (0.0, 1.0)
763        };
764
765        Ok(EventStudyResult {
766            coefficients,
767            pre_trend_f: pre_f,
768            pre_trend_p: pre_p,
769            pre_trend_df: n_pre_coefs,
770        })
771    }
772}
773
774fn regularized_gamma_lower_did(a: f64, x: f64) -> f64 {
775    if x < 0.0 {
776        return 0.0;
777    }
778    if x == 0.0 {
779        return 0.0;
780    }
781    if x < a + 1.0 {
782        let mut ap = a;
783        let mut del = 1.0 / a;
784        let mut sum = del;
785        for _ in 0..200 {
786            ap += 1.0;
787            del *= x / ap;
788            sum += del;
789            if del.abs() < sum.abs() * 3e-15 {
790                break;
791            }
792        }
793        sum * (-x + a * x.ln() - ln_gamma(a)).exp()
794    } else {
795        1.0 - regularized_gamma_upper_did(a, x)
796    }
797}
798
799fn regularized_gamma_upper_did(a: f64, x: f64) -> f64 {
800    let fpmin = 1e-300_f64;
801    let mut b = x + 1.0 - a;
802    let mut c = 1.0 / fpmin;
803    let mut d = 1.0 / b;
804    let mut h = d;
805    for i in 1..=200_i64 {
806        let an = -(i as f64) * ((i as f64) - a);
807        b += 2.0;
808        d = an * d + b;
809        if d.abs() < fpmin {
810            d = fpmin;
811        }
812        c = b + an / c;
813        if c.abs() < fpmin {
814            c = fpmin;
815        }
816        d = 1.0 / d;
817        let del = d * c;
818        h *= del;
819        if (del - 1.0).abs() < 3e-15 {
820            break;
821        }
822    }
823    (-x + a * x.ln() - ln_gamma(a)).exp() * h
824}
825
826// ---------------------------------------------------------------------------
827// Staggered DiD (Callaway-Sant'Anna)
828// ---------------------------------------------------------------------------
829
830/// Callaway-Sant'Anna (2021) doubly-robust ATT(g,t) estimator
831/// for staggered treatment adoption.
832///
833/// The ATT for cohort `g` at period `t` is:
834///   ATT(g,t) = E[Y_t(g) - Y_t(0) | G = g]
835/// estimated using the "not yet treated" comparison group and
836/// inverse-probability weighting with logistic propensity scores.
837pub struct StaggeredDiD {
838    /// Number of bootstrap replications for standard errors
839    pub n_bootstrap: usize,
840    /// Random seed for bootstrap
841    pub seed: u64,
842}
843
844impl StaggeredDiD {
845    /// Create a new StaggeredDiD estimator.
846    pub fn new(n_bootstrap: usize, seed: u64) -> Self {
847        Self { n_bootstrap, seed }
848    }
849
850    /// Estimate ATT(g,t) for all (cohort, period) pairs.
851    ///
852    /// # Arguments
853    /// * `y`           – outcome matrix (n_units × n_periods)
854    /// * `g`           – cohort vector: `g[i]` = first treatment period of unit i;
855    ///                   set `g[i]` = i64::MAX for never-treated units.
856    /// * `n_units`     – number of units
857    /// * `n_periods`   – number of calendar periods
858    pub fn estimate(
859        &self,
860        y: &ArrayView2<f64>,
861        g: &[i64],
862        n_units: usize,
863        n_periods: usize,
864    ) -> StatsResult<StaggeredDiDResult> {
865        if y.nrows() != n_units || y.ncols() != n_periods {
866            return Err(StatsError::DimensionMismatch(
867                "y must be (n_units × n_periods)".into(),
868            ));
869        }
870        if g.len() != n_units {
871            return Err(StatsError::DimensionMismatch(
872                "g must have length n_units".into(),
873            ));
874        }
875
876        // Collect unique treatment cohorts (excluding never-treated)
877        let mut cohorts: Vec<i64> = g
878            .iter()
879            .filter(|&&gi| gi < i64::MAX && gi >= 0)
880            .cloned()
881            .collect::<std::collections::HashSet<i64>>()
882            .into_iter()
883            .collect();
884        cohorts.sort();
885
886        let mut att_gt_vec: Vec<AttGt> = Vec::new();
887
888        for &cohort in &cohorts {
889            // Treated units in this cohort
890            let treated_ids: Vec<usize> = (0..n_units).filter(|&i| g[i] == cohort).collect();
891            // "Not yet treated" units at time t: never-treated + those with cohort > t
892            // We estimate for post-treatment periods t >= cohort
893            for t in 0..n_periods {
894                let t_i64 = t as i64;
895                // Control group: not yet treated at time t (and not treated before t)
896                let control_ids: Vec<usize> = (0..n_units)
897                    .filter(|&i| g[i] == i64::MAX || g[i] > t_i64)
898                    .collect();
899
900                if treated_ids.is_empty() || control_ids.is_empty() {
901                    continue;
902                }
903
904                // Reference period: cohort - 1 (last pre-treatment period)
905                let t_ref = (cohort - 1) as usize;
906                if t_ref >= n_periods {
907                    continue;
908                }
909
910                // Doubly-robust DiD: IPW on propensity score
911                // Simple implementation: use difference-in-means with propensity weighting
912                // P(treated | baseline characteristics) estimated via logistic on y at t_ref
913                let (att, se) = self.compute_att_gt(y, &treated_ids, &control_ids, t, t_ref)?;
914
915                let p = normal_p_value(if se > 0.0 { att / se } else { 0.0 });
916                att_gt_vec.push(AttGt {
917                    cohort,
918                    period: t_i64,
919                    att,
920                    std_error: se,
921                    p_value: p,
922                });
923            }
924        }
925
926        if att_gt_vec.is_empty() {
927            return Err(StatsError::InsufficientData(
928                "No valid (cohort, period) pairs found".into(),
929            ));
930        }
931
932        // Aggregate ATT: simple weighted average over post-treatment (g,t) pairs
933        let post_atts: Vec<&AttGt> = att_gt_vec
934            .iter()
935            .filter(|ag| ag.period >= ag.cohort)
936            .collect();
937        let aggregate_att = if post_atts.is_empty() {
938            0.0
939        } else {
940            post_atts.iter().map(|ag| ag.att).sum::<f64>() / post_atts.len() as f64
941        };
942        // Aggregate SE: pooled
943        let aggregate_se = if post_atts.is_empty() {
944            0.0
945        } else {
946            let var_sum: f64 = post_atts.iter().map(|ag| ag.std_error * ag.std_error).sum();
947            (var_sum / (post_atts.len() * post_atts.len()) as f64).sqrt()
948        };
949        let aggregate_p = normal_p_value(if aggregate_se > 0.0 {
950            aggregate_att / aggregate_se
951        } else {
952            0.0
953        });
954
955        Ok(StaggeredDiDResult {
956            att_gt: att_gt_vec,
957            aggregate_att,
958            aggregate_se,
959            aggregate_p,
960        })
961    }
962
963    /// Compute ATT(g,t) for a specific cohort-period pair.
964    fn compute_att_gt(
965        &self,
966        y: &ArrayView2<f64>,
967        treated_ids: &[usize],
968        control_ids: &[usize],
969        t: usize,
970        t_ref: usize,
971    ) -> StatsResult<(f64, f64)> {
972        let n_t = treated_ids.len();
973        let n_c = control_ids.len();
974
975        // Delta Y for treated: y_t - y_{t_ref}
976        let delta_treated: Vec<f64> = treated_ids
977            .iter()
978            .map(|&i| y[[i, t]] - y[[i, t_ref]])
979            .collect();
980        // Delta Y for control: y_t - y_{t_ref}
981        let delta_control: Vec<f64> = control_ids
982            .iter()
983            .map(|&i| y[[i, t]] - y[[i, t_ref]])
984            .collect();
985
986        let mean_t = delta_treated.iter().sum::<f64>() / n_t as f64;
987        let mean_c = delta_control.iter().sum::<f64>() / n_c as f64;
988        let att = mean_t - mean_c;
989
990        // Variance via delta method
991        let var_t = if n_t > 1 {
992            delta_treated
993                .iter()
994                .map(|&v| (v - mean_t).powi(2))
995                .sum::<f64>()
996                / (n_t * (n_t - 1)) as f64
997        } else {
998            0.0
999        };
1000        let var_c = if n_c > 1 {
1001            delta_control
1002                .iter()
1003                .map(|&v| (v - mean_c).powi(2))
1004                .sum::<f64>()
1005                / (n_c * (n_c - 1)) as f64
1006        } else {
1007            0.0
1008        };
1009        let se = (var_t + var_c).sqrt();
1010
1011        Ok((att, se))
1012    }
1013}
1014
1015// ---------------------------------------------------------------------------
1016// Tests
1017// ---------------------------------------------------------------------------
1018
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    #[test]
    fn test_did_no_effect() {
        // Parallel trends and no treatment effect
        let n_units = 4_usize;
        let n_periods = 4_usize;
        let treat_period = 2_usize;
        // Treated units: 0, 1; control: 2, 3
        let treated = array![1.0, 1.0, 0.0, 0.0];
        // y = unit_fe + time_fe (no treatment effect)
        let unit_fe = [1.0, 2.0, 1.5, 2.5];
        let time_fe = [0.0, 1.0, 2.0, 3.0];
        let mut y_vec = Array1::<f64>::zeros(n_units * n_periods);
        for i in 0..n_units {
            for t in 0..n_periods {
                y_vec[i * n_periods + t] = unit_fe[i] + time_fe[t];
            }
        }
        let res = DiD::estimate(
            &y_vec.view(),
            &treated.view(),
            n_units,
            n_periods,
            treat_period,
        )
        .expect("DiD estimate should succeed");
        assert!(
            res.att.abs() < 0.1,
            "ATT should be ~0 when there is no effect, got {}",
            res.att
        );
        assert_eq!(res.n_treated, 2);
        assert_eq!(res.n_control, 2);
    }

    #[test]
    fn test_did_known_effect() {
        let n_units = 4_usize;
        let n_periods = 4_usize;
        let treat_period = 2_usize;
        let treated = array![1.0, 1.0, 0.0, 0.0];
        let unit_fe = [0.0, 0.0, 0.0, 0.0];
        let time_fe = [0.0, 0.0, 0.0, 0.0];
        let treatment_effect = 5.0_f64;
        let mut y_vec = Array1::<f64>::zeros(n_units * n_periods);
        for i in 0..n_units {
            for t in 0..n_periods {
                let te = if treated[i] > 0.5 && t >= treat_period {
                    treatment_effect
                } else {
                    0.0
                };
                y_vec[i * n_periods + t] = unit_fe[i] + time_fe[t] + te;
            }
        }
        let res = DiD::estimate(
            &y_vec.view(),
            &treated.view(),
            n_units,
            n_periods,
            treat_period,
        )
        .expect("DiD estimate should succeed");
        assert!(
            (res.att - treatment_effect).abs() < 0.5,
            "ATT should be ~5.0, got {}",
            res.att
        );
    }

    #[test]
    fn test_synthetic_control_simplex_weights() {
        let n_donors = 4_usize;
        let t_pre = 10_usize;
        let treated: Array1<f64> = (0..t_pre).map(|t| t as f64).collect();
        // Donors: first donor perfectly matches
        let mut donors = Array2::<f64>::zeros((t_pre, n_donors));
        for t in 0..t_pre {
            donors[[t, 0]] = t as f64; // perfect match
            donors[[t, 1]] = t as f64 * 2.0;
            donors[[t, 2]] = (t as f64).powi(2);
            donors[[t, 3]] = 0.0;
        }
        let sc = SyntheticControl::new();
        let weights = sc
            .fit_weights(&treated.view(), &donors.view())
            .expect("SyntheticControl fit should succeed");
        // Weights should sum to 1
        let sum: f64 = weights.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "Weights should sum to 1, got {}",
            sum
        );
        // All weights non-negative
        assert!(weights.iter().all(|&w| w >= -1e-10));
    }

    #[test]
    fn test_event_study_no_pre_trends() {
        let n_units = 6_usize;
        let n_periods = 6_usize;
        let treat_period = 3_usize;
        let treated = array![1.0, 1.0, 1.0, 0.0, 0.0, 0.0];
        // No pre-trends, treatment effect = 3 post-treatment
        let treatment_effect = 3.0_f64;
        let mut y_vec = Array1::<f64>::zeros(n_units * n_periods);
        for i in 0..n_units {
            for t in 0..n_periods {
                let te = if treated[i] > 0.5 && t >= treat_period {
                    treatment_effect
                } else {
                    0.0
                };
                y_vec[i * n_periods + t] = te;
            }
        }
        let es = EventStudy::new(2, 3);
        let res = es
            .estimate(
                &y_vec.view(),
                &treated.view(),
                n_units,
                n_periods,
                treat_period,
            )
            .expect("EventStudy should succeed");
        // Check post-treatment coefficients are positive
        let post_coefs: Vec<&EventCoefficient> = res
            .coefficients
            .iter()
            .filter(|c| c.relative_time >= 0)
            .collect();
        assert!(
            !post_coefs.is_empty(),
            "Should have post-treatment coefficients"
        );
    }

    #[test]
    fn test_staggered_did_known_effect() {
        // Two treated cohorts (first treated in periods 2 and 3) plus a
        // never-treated group. The outcome is unit FE + time FE + a constant
        // effect from the period a unit is first treated onward. All values
        // are dyadic, so the differences are exact.
        let n_units = 6_usize;
        let n_periods = 5_usize;
        let g: [i64; 6] = [2, 2, 3, 3, i64::MAX, i64::MAX];
        let unit_fe = [1.0, 2.0, 0.5, 1.5, 3.0, 2.5];
        let time_fe = [0.0, 1.0, 3.0, 4.0, 6.0];
        let treatment_effect = 4.0_f64;
        let mut y = Array2::<f64>::zeros((n_units, n_periods));
        for i in 0..n_units {
            for t in 0..n_periods {
                let te = if (t as i64) >= g[i] {
                    treatment_effect
                } else {
                    0.0
                };
                y[[i, t]] = unit_fe[i] + time_fe[t] + te;
            }
        }

        let res = StaggeredDiD::new(0, 42)
            .estimate(&y.view(), &g, n_units, n_periods)
            .expect("StaggeredDiD estimate should succeed");

        let post: Vec<&AttGt> = res
            .att_gt
            .iter()
            .filter(|ag| ag.period >= ag.cohort)
            .collect();
        // Cohort 2 contributes periods 2, 3, 4; cohort 3 contributes 3, 4.
        assert_eq!(post.len(), 5, "expected 5 post-treatment (g,t) pairs");
        for ag in &post {
            assert!(
                (ag.att - treatment_effect).abs() < 1e-9,
                "ATT({}, {}) should be {}, got {}",
                ag.cohort,
                ag.period,
                treatment_effect,
                ag.att
            );
        }
        assert!(
            (res.aggregate_att - treatment_effect).abs() < 1e-9,
            "aggregate ATT should be {}, got {}",
            treatment_effect,
            res.aggregate_att
        );
    }
}