1use crate::error::{StatsError, StatsResult};
19use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
20use std::collections::HashMap;
21
/// Result of a two-way fixed-effects difference-in-differences fit, as
/// returned by `DiD::estimate`.
#[derive(Debug, Clone)]
pub struct DiDResult {
    /// Estimated average treatment effect on the treated: the OLS coefficient
    /// on the treated × post regressor.
    pub att: f64,

    /// Standard error of `att`.
    pub std_error: f64,

    /// `att / std_error` (`DiD::estimate` reports 0 when the standard error is 0).
    pub t_stat: f64,

    /// Two-sided p-value for `att` from Student's t distribution.
    pub p_value: f64,

    /// Confidence interval `[lower, upper]` for `att` (`DiD::estimate` uses the
    /// 0.975 t quantile, i.e. a 95% interval).
    pub conf_interval: [f64; 2],

    /// p-value of the pre-treatment parallel-trends test; `None` when there are
    /// fewer than two pre-treatment periods.
    pub parallel_trends_p: Option<f64>,

    /// Number of treated units.
    pub n_treated: usize,

    /// Number of control units.
    pub n_control: usize,

    /// Name of the estimator that produced this result (e.g. `"DiD-TWFE"`).
    pub estimator: String,
}
56
/// One event-study coefficient: the effect at a given time relative to treatment.
#[derive(Debug, Clone)]
pub struct EventCoefficient {
    /// Period relative to treatment onset (`t - treat_period`); negative values
    /// are leads, 0 and above are lags.
    pub relative_time: i64,
    /// Estimated coefficient on this relative-time dummy.
    pub estimate: f64,
    /// Standard error of `estimate`.
    pub std_error: f64,
    /// `estimate / std_error` (0 when the standard error is 0).
    pub t_stat: f64,
    /// Two-sided p-value from Student's t distribution.
    pub p_value: f64,
    /// Confidence interval `[lower, upper]` for `estimate`.
    pub conf_interval: [f64; 2],
}
73
/// Output of an event-study regression.
#[derive(Debug, Clone)]
pub struct EventStudyResult {
    /// Coefficients for each relative time in the event window, in ascending
    /// relative-time order (the reference period -1 is omitted).
    pub coefficients: Vec<EventCoefficient>,
    /// F statistic for the joint null that all pre-treatment coefficients are zero.
    pub pre_trend_f: f64,
    /// p-value of the joint pre-trend test.
    pub pre_trend_p: f64,
    /// Number of restrictions tested in the pre-trend test.
    pub pre_trend_df: usize,
}
86
/// Output of the staggered-adoption DiD estimator.
#[derive(Debug, Clone)]
pub struct StaggeredDiDResult {
    /// Group-time treatment effects `ATT(g, t)` for every estimable pair.
    pub att_gt: Vec<AttGt>,
    /// Aggregate ATT over the post-treatment (`period >= cohort`) pairs.
    pub aggregate_att: f64,
    /// Standard error of `aggregate_att`.
    pub aggregate_se: f64,
    /// Two-sided p-value of `aggregate_att` (normal approximation).
    pub aggregate_p: f64,
}
99
/// A single group-time average treatment effect `ATT(g, t)`.
#[derive(Debug, Clone)]
pub struct AttGt {
    /// Treatment cohort: the first period in which the group is treated.
    pub cohort: i64,
    /// Calendar period at which the effect is evaluated.
    pub period: i64,
    /// Estimated treatment effect for this cohort and period.
    pub att: f64,
    /// Standard error of `att`.
    pub std_error: f64,
    /// Two-sided p-value of `att` (normal approximation).
    pub p_value: f64,
}
114
115fn normal_cdf(x: f64) -> f64 {
120 0.5 * (1.0 + libm_erf(x / std::f64::consts::SQRT_2))
121}
122
/// Error function `erf(x)`.
///
/// Computed as `sign(x) · (1 − erfc(|x|))`, where `erfc` is the Chebyshev fit
/// from *Numerical Recipes* (`erfcc`). That fit bounds the *relative* error of
/// `erfc` below 1.2e-7 for every `x`.
///
/// A fit that only bounds the *absolute* error of `erf` (about 1e-7) makes
/// every upper-tail probability below ~1e-7 meaningless. With this fit,
/// `1 − Φ(z)` stays accurate far into the tail, down to the `f64` rounding of
/// `1 − erf`.
///
/// NaN propagates; `erf(±∞) = ±1`; the result is odd in `x`.
fn libm_erf(x: f64) -> f64 {
    if x.is_nan() {
        return f64::NAN;
    }
    let z = x.abs();
    let t = 1.0 / (1.0 + 0.5 * z);
    let poly = 1.00002368
        + t * (0.37409196
            + t * (0.09678418
                + t * (-0.18628806
                    + t * (0.27886807
                        + t * (-1.13520398 + t * (1.48851587 + t * (-0.82215223 + t * 0.17087277)))))));
    // The fit overshoots 1 by ~3e-8 at z = 0. Clamping keeps erf(x) >= 0 for
    // x >= 0, so two-sided p-values built on it never exceed 1.
    let erfc = (t * (-z * z - 1.26551223 + t * poly).exp()).min(1.0);
    let magnitude = 1.0 - erfc;
    if x >= 0.0 {
        magnitude
    } else {
        -magnitude
    }
}
137
138fn normal_p_value(z: f64) -> f64 {
139 2.0 * (1.0 - normal_cdf(z.abs()))
141}
142
143fn t_dist_p_value_did(t: f64, df: f64) -> f64 {
144 if df <= 0.0 {
145 return 1.0;
146 }
147 if df > 200.0 {
149 return normal_p_value(t);
150 }
151 let x = df / (df + t * t);
153 regularized_incomplete_beta(x, df / 2.0, 0.5)
154 .min(1.0)
155 .max(0.0)
156}
157
158fn regularized_incomplete_beta(x: f64, a: f64, b: f64) -> f64 {
159 if x <= 0.0 {
160 return 0.0;
161 }
162 if x >= 1.0 {
163 return 1.0;
164 }
165 if x > (a + 1.0) / (a + b + 2.0) {
166 return 1.0 - regularized_incomplete_beta(1.0 - x, b, a);
167 }
168 let log_cf =
169 (a * x.ln() + b * (1.0 - x).ln() - ln_gamma(a) - ln_gamma(b) + ln_gamma(a + b)).exp() / a;
170 log_cf * beta_cf(x, a, b)
171}
172
/// Continued-fraction part of the regularized incomplete beta function.
///
/// Evaluated with the modified Lentz algorithm: at most 200 iterations, each
/// consuming one even and one odd term, stopping once a full iteration changes
/// the value by a relative amount below 3e-15. The caller multiplies the result
/// by `x^a (1 − x)^b / (a · B(a, b))`.
fn beta_cf(x: f64, a: f64, b: f64) -> f64 {
    /// Replaces magnitudes below 1e-300 with 1e-300 so that later divisions stay finite.
    fn not_tiny(v: f64) -> f64 {
        const TINY: f64 = 1e-300;
        if v.abs() < TINY {
            TINY
        } else {
            v
        }
    }

    /// One modified-Lentz update for coefficient `coef`; returns the factor `d · c`.
    fn lentz_step(coef: f64, c: &mut f64, d: &mut f64) -> f64 {
        let denom = not_tiny(1.0 + coef * *d);
        *c = not_tiny(1.0 + coef / *c);
        *d = 1.0 / denom;
        *d * *c
    }

    const MAX_TERMS: i32 = 200;
    const EPS: f64 = 3e-15;

    let (qab, qap, qam) = (a + b, a + 1.0, a - 1.0);
    let mut c = 1.0_f64;
    let mut d = 1.0 / not_tiny(1.0 - qab * x / qap);
    let mut h = d;

    for m in 1..=MAX_TERMS {
        let mf = f64::from(m);
        let even = mf * (b - mf) * x / ((qam + 2.0 * mf) * (a + 2.0 * mf));
        h *= lentz_step(even, &mut c, &mut d);
        let odd = -(a + mf) * (qab + mf) * x / ((a + 2.0 * mf) * (qap + 2.0 * mf));
        let delta = lentz_step(odd, &mut c, &mut d);
        h *= delta;
        if (delta - 1.0).abs() < EPS {
            break;
        }
    }
    h
}
216
/// Natural logarithm of the gamma function, `ln Γ(x)`.
///
/// For `x >= 0.5` this uses the Lanczos approximation with `g = 7` and nine
/// coefficients. Below 0.5 it uses the reflection formula
/// `ln Γ(x) = ln π − ln sin(πx) − ln Γ(1 − x)`.
/// NOTE(review): where `sin(πx) < 0` (some negative `x`) the logarithm yields
/// NaN; the callers in this module only pass positive arguments.
fn ln_gamma(x: f64) -> f64 {
    use std::f64::consts::PI;

    const LANCZOS_G: f64 = 7.0;
    const COEFFS: [f64; 9] = [
        0.99999999999980993,
        676.5203681218851,
        -1259.1392167224028,
        771.323_428_777_653_1,
        -176.615_029_162_140_6,
        12.507_343_278_686_905,
        -0.13857_109_526_572_012,
        9.984_369_578_019_572e-6,
        1.5056_327_351_493_116e-7,
    ];

    if x < 0.5 {
        return PI.ln() - (PI * x).sin().ln() - ln_gamma(1.0 - x);
    }

    let z = x - 1.0;
    let series = COEFFS[1..]
        .iter()
        .enumerate()
        .fold(COEFFS[0], |acc, (i, &coef)| acc + coef / (z + (i as f64) + 1.0));
    let t = z + LANCZOS_G + 0.5;
    0.5 * (2.0 * PI).ln() + (z + 0.5) * t.ln() - t + series.ln()
}
242
243fn ols_fit_did(
248 x: &ArrayView2<f64>,
249 y: &ArrayView1<f64>,
250) -> StatsResult<(Array1<f64>, Array1<f64>, Array2<f64>)> {
251 let n = x.nrows();
252 let k = x.ncols();
253 if n < k {
254 return Err(StatsError::InsufficientData(format!(
255 "Need at least {k} observations, got {n}"
256 )));
257 }
258 let xtx = x.t().dot(x);
259 let xty = x.t().dot(y);
260 let xtx_inv = cholesky_invert_did(&xtx.view())?;
261 let beta = xtx_inv.dot(&xty);
262 let fitted = x.dot(&beta);
263 let residuals = y.to_owned() - fitted;
264 Ok((beta, residuals, xtx_inv))
265}
266
267fn cholesky_invert_did(a: &ArrayView2<f64>) -> StatsResult<Array2<f64>> {
268 let n = a.nrows();
269 let mut l = Array2::<f64>::zeros((n, n));
270 for i in 0..n {
271 for j in 0..=i {
272 let mut s = a[[i, j]];
273 for p in 0..j {
274 s -= l[[i, p]] * l[[j, p]];
275 }
276 if i == j {
277 if s <= 0.0 {
278 return Err(StatsError::ComputationError(
279 "Matrix not positive definite (DiD)".into(),
280 ));
281 }
282 l[[i, j]] = s.sqrt();
283 } else {
284 l[[i, j]] = s / l[[j, j]];
285 }
286 }
287 }
288 let mut linv = Array2::<f64>::zeros((n, n));
289 for j in 0..n {
290 linv[[j, j]] = 1.0 / l[[j, j]];
291 for i in (j + 1)..n {
292 let mut s = 0.0_f64;
293 for p in j..i {
294 s += l[[i, p]] * linv[[p, j]];
295 }
296 linv[[i, j]] = -s / l[[i, i]];
297 }
298 }
299 Ok(linv.t().dot(&linv))
300}
301
302fn t_critical_did(alpha: f64, df: usize) -> f64 {
303 let df_f = df as f64;
305 let mut t = 2.0_f64;
306 for _ in 0..50 {
307 let p = t_dist_p_value_did(t, df_f);
308 let target = 2.0 * alpha;
309 let err = p - target;
310 let delta = 1e-6;
311 let dp = (t_dist_p_value_did(t + delta, df_f) - p) / delta;
312 if dp.abs() < 1e-15 {
313 break;
314 }
315 t -= err / dp;
316 if err.abs() < 1e-10 {
317 break;
318 }
319 }
320 t.max(0.0)
321}
322
/// Two-way fixed-effects difference-in-differences estimator
/// (stateless; see `DiD::estimate`).
pub struct DiD;
335
336impl DiD {
337 pub fn estimate(
349 y: &ArrayView1<f64>,
350 treated: &ArrayView1<f64>,
351 n_units: usize,
352 n_periods: usize,
353 treat_period: usize,
354 ) -> StatsResult<DiDResult> {
355 let n = n_units * n_periods;
356 if y.len() != n {
357 return Err(StatsError::DimensionMismatch(format!(
358 "y length {} != n_units * n_periods = {}",
359 y.len(),
360 n
361 )));
362 }
363 if treated.len() != n_units {
364 return Err(StatsError::DimensionMismatch(
365 "treated length must equal n_units".into(),
366 ));
367 }
368 if treat_period >= n_periods {
369 return Err(StatsError::InvalidArgument(
370 "treat_period must be < n_periods".into(),
371 ));
372 }
373
374 let n_treated: usize = treated.iter().filter(|&&v| v > 0.5).count();
375 let n_control = n_units - n_treated;
376
377 let k = 1 + (n_units - 1) + (n_periods - 1) + 1;
384 let mut xmat = Array2::<f64>::zeros((n, k));
385 let mut y_vec = Array1::<f64>::zeros(n);
386
387 for i in 0..n_units {
388 for t in 0..n_periods {
389 let row = i * n_periods + t;
390 y_vec[row] = y[row];
391 xmat[[row, 0]] = 1.0;
393 if i > 0 {
395 xmat[[row, i]] = 1.0;
396 }
397 if t > 0 {
399 xmat[[row, n_units + t - 1]] = 1.0;
400 }
401 let post = if t >= treat_period { 1.0 } else { 0.0 };
403 let treat = treated[i];
404 xmat[[row, k - 1]] = post * treat;
405 }
406 }
407
408 let (beta, resid, xtx_inv) = ols_fit_did(&xmat.view(), &y_vec.view())?;
409 let att = beta[k - 1];
410 let df = (n - k) as f64;
411 let s2 = resid.iter().map(|&r| r * r).sum::<f64>() / df.max(1.0);
412 let var_att = xtx_inv[[k - 1, k - 1]] * s2;
413 let se = var_att.max(0.0).sqrt();
414 let t_stat = if se > 0.0 { att / se } else { 0.0 };
415 let p_val = t_dist_p_value_did(t_stat, df);
416 let t_crit = t_critical_did(0.025, df as usize);
417 let ci = [att - t_crit * se, att + t_crit * se];
418
419 let parallel_p = if treat_period > 1 {
422 Some(Self::parallel_trends_test(
423 y,
424 treated,
425 n_units,
426 n_periods,
427 treat_period,
428 )?)
429 } else {
430 None
431 };
432
433 Ok(DiDResult {
434 att,
435 std_error: se,
436 t_stat,
437 p_value: p_val,
438 conf_interval: ci,
439 parallel_trends_p: parallel_p,
440 n_treated,
441 n_control,
442 estimator: "DiD-TWFE".into(),
443 })
444 }
445
446 fn parallel_trends_test(
451 y: &ArrayView1<f64>,
452 treated: &ArrayView1<f64>,
453 n_units: usize,
454 n_periods: usize,
455 treat_period: usize,
456 ) -> StatsResult<f64> {
457 let n_pre = n_units * treat_period;
459 if n_pre < 4 {
460 return Ok(1.0); }
462 let k_pre = 3; let mut x_pre = Array2::<f64>::zeros((n_pre, k_pre));
464 let mut y_pre = Array1::<f64>::zeros(n_pre);
465 let mut row = 0;
466 for i in 0..n_units {
467 for t in 0..treat_period {
468 y_pre[row] = y[i * n_periods + t];
469 x_pre[[row, 0]] = 1.0; x_pre[[row, 1]] = t as f64; x_pre[[row, 2]] = treated[i] * (t as f64); row += 1;
473 }
474 }
475 let (beta_pre, resid_pre, xtx_inv_pre) = ols_fit_did(&x_pre.view(), &y_pre.view())?;
476 let df_pre = (n_pre - k_pre) as f64;
477 let s2_pre = resid_pre.iter().map(|&r| r * r).sum::<f64>() / df_pre.max(1.0);
478 let var_coef = xtx_inv_pre[[k_pre - 1, k_pre - 1]] * s2_pre;
479 let se = var_coef.max(0.0).sqrt();
480 let t = if se > 0.0 {
481 beta_pre[k_pre - 1] / se
482 } else {
483 0.0
484 };
485 Ok(t_dist_p_value_did(t, df_pre))
486 }
487}
488
/// Synthetic-control estimator: fits non-negative donor weights that sum to
/// one so that the weighted donors track the treated unit before treatment.
pub struct SyntheticControl {
    /// Maximum number of projected-gradient iterations in `fit_weights`.
    pub max_iter: usize,
    /// Stop once the L2 norm of the change in the weights falls below this value.
    pub tol: f64,
}
504
505impl SyntheticControl {
506 pub fn new() -> Self {
508 Self {
509 max_iter: 2000,
510 tol: 1e-8,
511 }
512 }
513
514 pub fn fit_weights(
523 &self,
524 y_treated: &ArrayView1<f64>,
525 y_donors: &ArrayView2<f64>,
526 ) -> StatsResult<Array1<f64>> {
527 let t_pre = y_treated.len();
528 let n_donors = y_donors.ncols();
529 if y_donors.nrows() != t_pre {
530 return Err(StatsError::DimensionMismatch(
531 "y_donors must have same number of rows as y_treated".into(),
532 ));
533 }
534 if n_donors == 0 {
535 return Err(StatsError::InvalidArgument(
536 "Need at least one donor unit".into(),
537 ));
538 }
539
540 let mut w: Array1<f64> = Array1::from_elem(n_donors, 1.0 / n_donors as f64);
543 let yd_t = y_donors.t(); let ytd_y: Array2<f64> = yd_t.dot(y_donors); let ytd_yt: Array1<f64> = yd_t.dot(y_treated); let step_denom: f64 = ytd_y
551 .rows()
552 .into_iter()
553 .map(|row| row.iter().map(|&v| v.abs()).sum::<f64>())
554 .fold(f64::NEG_INFINITY, f64::max);
555 let lr = if step_denom > 0.0 {
556 0.5 / step_denom
557 } else {
558 1e-3
559 };
560
561 for _ in 0..self.max_iter {
562 let grad = ytd_y.dot(&w) - &ytd_yt;
564 let w_new_raw = &w - &grad.mapv(|g| g * lr);
565 let w_new = project_simplex(&w_new_raw.view());
566 let diff: f64 = (&w_new - &w).iter().map(|&d| d * d).sum::<f64>().sqrt();
567 w = w_new;
568 if diff < self.tol {
569 break;
570 }
571 }
572
573 Ok(w)
574 }
575
576 pub fn treatment_effects(
583 &self,
584 y_treated_post: &ArrayView1<f64>,
585 y_donors_post: &ArrayView2<f64>,
586 weights: &ArrayView1<f64>,
587 ) -> StatsResult<Array1<f64>> {
588 if y_donors_post.nrows() != y_treated_post.len() {
589 return Err(StatsError::DimensionMismatch(
590 "y_donors_post rows must equal y_treated_post length".into(),
591 ));
592 }
593 let synthetic = y_donors_post.dot(weights);
594 Ok(y_treated_post.to_owned() - synthetic)
595 }
596}
597
598fn project_simplex(v: &ArrayView1<f64>) -> Array1<f64> {
600 let n = v.len();
601 let mut u: Vec<f64> = v.to_vec();
602 u.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
603 let mut rho = 0_usize;
604 let mut cum = 0.0_f64;
605 for (j, &uj) in u.iter().enumerate() {
606 cum += uj;
607 if uj - (cum - 1.0) / (j as f64 + 1.0) > 0.0 {
608 rho = j;
609 }
610 }
611 let cum_rho: f64 = u[..=rho].iter().sum();
612 let theta = (cum_rho - 1.0) / (rho as f64 + 1.0);
613 v.mapv(|vi| (vi - theta).max(0.0))
614}
615
616impl Default for SyntheticControl {
617 fn default() -> Self {
618 Self::new()
619 }
620}
621
/// Event-study (dynamic TWFE) estimator with a fixed window of leads and lags
/// around a common treatment period.
pub struct EventStudy {
    /// Number of leads: relative times `-n_pre ..= -1`, with -1 used as the
    /// omitted reference period.
    pub n_pre: usize,
    /// Number of lags: relative times `0 ..= n_post - 1`.
    pub n_post: usize,
}
638
639impl EventStudy {
640 pub fn new(n_pre: usize, n_post: usize) -> Self {
642 Self { n_pre, n_post }
643 }
644
645 pub fn estimate(
654 &self,
655 y: &ArrayView1<f64>,
656 treated: &ArrayView1<f64>,
657 n_units: usize,
658 n_periods: usize,
659 treat_period: usize,
660 ) -> StatsResult<EventStudyResult> {
661 let n = n_units * n_periods;
662 if y.len() != n {
663 return Err(StatsError::DimensionMismatch(
664 "y length != n_units * n_periods".into(),
665 ));
666 }
667
668 let n_event_dummies = self.n_pre + self.n_post - 1; let k = (n_units - 1) + (n_periods - 1) + n_event_dummies;
677 let mut xmat = Array2::<f64>::zeros((n, k));
678 let mut y_vec = Array1::<f64>::zeros(n);
679
680 let event_times: Vec<i64> = {
681 let mut v: Vec<i64> = (-(self.n_pre as i64)..=(self.n_post as i64 - 1)).collect();
682 v.retain(|&l| l != -1); v
684 };
685
686 for i in 0..n_units {
687 for t in 0..n_periods {
688 let row = i * n_periods + t;
689 y_vec[row] = y[row];
690 if i > 0 {
692 xmat[[row, i - 1]] = 1.0;
693 }
694 if t > 0 {
696 xmat[[row, n_units - 1 + t - 1]] = 1.0;
697 }
698 if treated[i] > 0.5 {
700 let rel_time = (t as i64) - (treat_period as i64);
701 for (d_idx, &et) in event_times.iter().enumerate() {
702 if rel_time == et {
703 xmat[[row, (n_units - 1) + (n_periods - 1) + d_idx]] = 1.0;
704 }
705 }
706 }
707 }
708 }
709
710 let (beta, resid, xtx_inv) = ols_fit_did(&xmat.view(), &y_vec.view())?;
711 let df = (n - k) as f64;
712 let s2 = resid.iter().map(|&r| r * r).sum::<f64>() / df.max(1.0);
713 let t_crit = t_critical_did(0.025, df as usize);
714 let fe_offset = (n_units - 1) + (n_periods - 1);
715
716 let mut coefficients = Vec::with_capacity(n_event_dummies);
717 for (d_idx, &et) in event_times.iter().enumerate() {
718 let coef_idx = fe_offset + d_idx;
719 let est = beta[coef_idx];
720 let se = (xtx_inv[[coef_idx, coef_idx]] * s2).max(0.0).sqrt();
721 let t = if se > 0.0 { est / se } else { 0.0 };
722 let p = t_dist_p_value_did(t, df);
723 coefficients.push(EventCoefficient {
724 relative_time: et,
725 estimate: est,
726 std_error: se,
727 t_stat: t,
728 p_value: p,
729 conf_interval: [est - t_crit * se, est + t_crit * se],
730 });
731 }
732
733 let n_pre_coefs = self.n_pre.saturating_sub(1); let (pre_f, pre_p) = if n_pre_coefs > 0 {
736 let pre_coef_idxs: Vec<usize> = (0..n_pre_coefs).map(|j| fe_offset + j).collect();
738 let rss_ur = resid.iter().map(|&r| r * r).sum::<f64>();
739 let mut x_r = xmat.clone();
741 for &idx in &pre_coef_idxs {
742 for i in 0..n {
743 x_r[[i, idx]] = 0.0;
744 }
745 }
746 let cols_r: Vec<usize> = (0..k).filter(|c| !pre_coef_idxs.contains(c)).collect();
748 let mut xr = Array2::<f64>::zeros((n, cols_r.len()));
749 for (new_j, &old_j) in cols_r.iter().enumerate() {
750 for i in 0..n {
751 xr[[i, new_j]] = xmat[[i, old_j]];
752 }
753 }
754 let (_br, resid_r, _) = ols_fit_did(&xr.view(), &y_vec.view())?;
755 let rss_r = resid_r.iter().map(|&r| r * r).sum::<f64>();
756 let f = ((rss_r - rss_ur) / n_pre_coefs as f64) / (rss_ur / df).max(1e-15);
757 let chi2 = f * n_pre_coefs as f64;
759 let p_f = 1.0 - regularized_gamma_lower_did(n_pre_coefs as f64 / 2.0, chi2 / 2.0);
760 (f, p_f)
761 } else {
762 (0.0, 1.0)
763 };
764
765 Ok(EventStudyResult {
766 coefficients,
767 pre_trend_f: pre_f,
768 pre_trend_p: pre_p,
769 pre_trend_df: n_pre_coefs,
770 })
771 }
772}
773
774fn regularized_gamma_lower_did(a: f64, x: f64) -> f64 {
775 if x < 0.0 {
776 return 0.0;
777 }
778 if x == 0.0 {
779 return 0.0;
780 }
781 if x < a + 1.0 {
782 let mut ap = a;
783 let mut del = 1.0 / a;
784 let mut sum = del;
785 for _ in 0..200 {
786 ap += 1.0;
787 del *= x / ap;
788 sum += del;
789 if del.abs() < sum.abs() * 3e-15 {
790 break;
791 }
792 }
793 sum * (-x + a * x.ln() - ln_gamma(a)).exp()
794 } else {
795 1.0 - regularized_gamma_upper_did(a, x)
796 }
797}
798
799fn regularized_gamma_upper_did(a: f64, x: f64) -> f64 {
800 let fpmin = 1e-300_f64;
801 let mut b = x + 1.0 - a;
802 let mut c = 1.0 / fpmin;
803 let mut d = 1.0 / b;
804 let mut h = d;
805 for i in 1..=200_i64 {
806 let an = -(i as f64) * ((i as f64) - a);
807 b += 2.0;
808 d = an * d + b;
809 if d.abs() < fpmin {
810 d = fpmin;
811 }
812 c = b + an / c;
813 if c.abs() < fpmin {
814 c = fpmin;
815 }
816 d = 1.0 / d;
817 let del = d * c;
818 h *= del;
819 if (del - 1.0).abs() < 3e-15 {
820 break;
821 }
822 }
823 (-x + a * x.ln() - ln_gamma(a)).exp() * h
824}
825
/// Difference-in-differences for staggered treatment adoption, based on
/// group-time average treatment effects `ATT(g, t)`.
pub struct StaggeredDiD {
    /// Number of bootstrap replications.
    /// NOTE(review): not read anywhere in this module; standard errors are analytic.
    pub n_bootstrap: usize,
    /// Random seed for resampling.
    /// NOTE(review): not read anywhere in this module.
    pub seed: u64,
}
843
844impl StaggeredDiD {
845 pub fn new(n_bootstrap: usize, seed: u64) -> Self {
847 Self { n_bootstrap, seed }
848 }
849
850 pub fn estimate(
859 &self,
860 y: &ArrayView2<f64>,
861 g: &[i64],
862 n_units: usize,
863 n_periods: usize,
864 ) -> StatsResult<StaggeredDiDResult> {
865 if y.nrows() != n_units || y.ncols() != n_periods {
866 return Err(StatsError::DimensionMismatch(
867 "y must be (n_units × n_periods)".into(),
868 ));
869 }
870 if g.len() != n_units {
871 return Err(StatsError::DimensionMismatch(
872 "g must have length n_units".into(),
873 ));
874 }
875
876 let mut cohorts: Vec<i64> = g
878 .iter()
879 .filter(|&&gi| gi < i64::MAX && gi >= 0)
880 .cloned()
881 .collect::<std::collections::HashSet<i64>>()
882 .into_iter()
883 .collect();
884 cohorts.sort();
885
886 let mut att_gt_vec: Vec<AttGt> = Vec::new();
887
888 for &cohort in &cohorts {
889 let treated_ids: Vec<usize> = (0..n_units).filter(|&i| g[i] == cohort).collect();
891 for t in 0..n_periods {
894 let t_i64 = t as i64;
895 let control_ids: Vec<usize> = (0..n_units)
897 .filter(|&i| g[i] == i64::MAX || g[i] > t_i64)
898 .collect();
899
900 if treated_ids.is_empty() || control_ids.is_empty() {
901 continue;
902 }
903
904 let t_ref = (cohort - 1) as usize;
906 if t_ref >= n_periods {
907 continue;
908 }
909
910 let (att, se) = self.compute_att_gt(y, &treated_ids, &control_ids, t, t_ref)?;
914
915 let p = normal_p_value(if se > 0.0 { att / se } else { 0.0 });
916 att_gt_vec.push(AttGt {
917 cohort,
918 period: t_i64,
919 att,
920 std_error: se,
921 p_value: p,
922 });
923 }
924 }
925
926 if att_gt_vec.is_empty() {
927 return Err(StatsError::InsufficientData(
928 "No valid (cohort, period) pairs found".into(),
929 ));
930 }
931
932 let post_atts: Vec<&AttGt> = att_gt_vec
934 .iter()
935 .filter(|ag| ag.period >= ag.cohort)
936 .collect();
937 let aggregate_att = if post_atts.is_empty() {
938 0.0
939 } else {
940 post_atts.iter().map(|ag| ag.att).sum::<f64>() / post_atts.len() as f64
941 };
942 let aggregate_se = if post_atts.is_empty() {
944 0.0
945 } else {
946 let var_sum: f64 = post_atts.iter().map(|ag| ag.std_error * ag.std_error).sum();
947 (var_sum / (post_atts.len() * post_atts.len()) as f64).sqrt()
948 };
949 let aggregate_p = normal_p_value(if aggregate_se > 0.0 {
950 aggregate_att / aggregate_se
951 } else {
952 0.0
953 });
954
955 Ok(StaggeredDiDResult {
956 att_gt: att_gt_vec,
957 aggregate_att,
958 aggregate_se,
959 aggregate_p,
960 })
961 }
962
963 fn compute_att_gt(
965 &self,
966 y: &ArrayView2<f64>,
967 treated_ids: &[usize],
968 control_ids: &[usize],
969 t: usize,
970 t_ref: usize,
971 ) -> StatsResult<(f64, f64)> {
972 let n_t = treated_ids.len();
973 let n_c = control_ids.len();
974
975 let delta_treated: Vec<f64> = treated_ids
977 .iter()
978 .map(|&i| y[[i, t]] - y[[i, t_ref]])
979 .collect();
980 let delta_control: Vec<f64> = control_ids
982 .iter()
983 .map(|&i| y[[i, t]] - y[[i, t_ref]])
984 .collect();
985
986 let mean_t = delta_treated.iter().sum::<f64>() / n_t as f64;
987 let mean_c = delta_control.iter().sum::<f64>() / n_c as f64;
988 let att = mean_t - mean_c;
989
990 let var_t = if n_t > 1 {
992 delta_treated
993 .iter()
994 .map(|&v| (v - mean_t).powi(2))
995 .sum::<f64>()
996 / (n_t * (n_t - 1)) as f64
997 } else {
998 0.0
999 };
1000 let var_c = if n_c > 1 {
1001 delta_control
1002 .iter()
1003 .map(|&v| (v - mean_c).powi(2))
1004 .sum::<f64>()
1005 / (n_c * (n_c - 1)) as f64
1006 } else {
1007 0.0
1008 };
1009 let se = (var_t + var_c).sqrt();
1010
1011 Ok((att, se))
1012 }
1013}
1014
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// Builds a unit-major panel with `y[i * n_periods + t] = value(i, t)`.
    fn build_panel(
        n_units: usize,
        n_periods: usize,
        value: impl Fn(usize, usize) -> f64,
    ) -> Array1<f64> {
        (0..n_units * n_periods)
            .map(|row| value(row / n_periods, row % n_periods))
            .collect()
    }

    #[test]
    fn test_did_no_effect() {
        let (n_units, n_periods, treat_period) = (4_usize, 4_usize, 2_usize);
        let treated = array![1.0, 1.0, 0.0, 0.0];
        // Purely additive unit and time effects: the true ATT is zero.
        let unit_fe = [1.0, 2.0, 1.5, 2.5];
        let time_fe = [0.0, 1.0, 2.0, 3.0];
        let y_vec = build_panel(n_units, n_periods, |i, t| unit_fe[i] + time_fe[t]);

        let res = DiD::estimate(
            &y_vec.view(),
            &treated.view(),
            n_units,
            n_periods,
            treat_period,
        )
        .expect("DiD estimate should succeed");

        assert!(
            res.att.abs() < 0.1,
            "ATT should be ~0 when there is no effect, got {}",
            res.att
        );
        assert_eq!(res.n_treated, 2);
        assert_eq!(res.n_control, 2);
    }

    #[test]
    fn test_did_known_effect() {
        let (n_units, n_periods, treat_period) = (4_usize, 4_usize, 2_usize);
        let treated = array![1.0, 1.0, 0.0, 0.0];
        let unit_fe = [0.0, 0.0, 0.0, 0.0];
        let time_fe = [0.0, 0.0, 0.0, 0.0];
        let treatment_effect = 5.0_f64;
        let y_vec = build_panel(n_units, n_periods, |i, t| {
            let in_treated_post = treated[i] > 0.5 && t >= treat_period;
            let te = if in_treated_post { treatment_effect } else { 0.0 };
            unit_fe[i] + time_fe[t] + te
        });

        let res = DiD::estimate(
            &y_vec.view(),
            &treated.view(),
            n_units,
            n_periods,
            treat_period,
        )
        .expect("DiD estimate should succeed");

        assert!(
            (res.att - treatment_effect).abs() < 0.5,
            "ATT should be ~5.0, got {}",
            res.att
        );
    }

    #[test]
    fn test_synthetic_control_simplex_weights() {
        let (n_donors, t_pre) = (4_usize, 10_usize);
        let treated: Array1<f64> = (0..t_pre).map(|t| t as f64).collect();
        // Donor 0 reproduces the treated series exactly; the others do not.
        let donors = Array2::from_shape_fn((t_pre, n_donors), |(t, j)| {
            let tf = t as f64;
            match j {
                0 => tf,
                1 => tf * 2.0,
                2 => tf.powi(2),
                _ => 0.0,
            }
        });

        let weights = SyntheticControl::new()
            .fit_weights(&treated.view(), &donors.view())
            .expect("SyntheticControl fit should succeed");

        let sum: f64 = weights.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "Weights should sum to 1, got {}",
            sum
        );
        assert!(weights.iter().all(|&w| w >= -1e-10));
    }

    #[test]
    fn test_event_study_no_pre_trends() {
        let (n_units, n_periods, treat_period) = (6_usize, 6_usize, 3_usize);
        let treated = array![1.0, 1.0, 1.0, 0.0, 0.0, 0.0];
        let treatment_effect = 3.0_f64;
        let y_vec = build_panel(n_units, n_periods, |i, t| {
            if treated[i] > 0.5 && t >= treat_period {
                treatment_effect
            } else {
                0.0
            }
        });

        let res = EventStudy::new(2, 3)
            .estimate(
                &y_vec.view(),
                &treated.view(),
                n_units,
                n_periods,
                treat_period,
            )
            .expect("EventStudy should succeed");

        let post_coefs: Vec<&EventCoefficient> = res
            .coefficients
            .iter()
            .filter(|c| c.relative_time >= 0)
            .collect();
        assert!(
            !post_coefs.is_empty(),
            "Should have post-treatment coefficients"
        );
    }
}