1use crate::error::{StatsError, StatsResult};
22use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
23
/// Treatment-effect estimates produced by the propensity-score estimators in this file.
#[derive(Debug, Clone)]
pub struct PSResult {
    /// Estimated average treatment effect (ATE).
    pub ate: f64,
    /// Standard error of `ate`.
    pub ate_se: f64,
    /// Estimated average treatment effect on the treated (ATT).
    pub att: f64,
    /// Standard error of `att`.
    pub att_se: f64,
    /// Estimated average treatment effect on the controls (ATC).
    pub atc: f64,
    /// Standard error reported for `atc` (`IPW::estimate` reuses `ate_se` here).
    pub atc_se: f64,
    /// Two-sided normal-approximation p-value for `ate`.
    pub ate_p: f64,
    /// Two-sided normal-approximation p-value for `att`.
    pub att_p: f64,
    /// Two-sided normal-approximation p-value for `atc`.
    pub atc_p: f64,
    /// Propensity scores the estimate was based on (`IPW::estimate` stores the clamped scores).
    pub propensity_scores: Array1<f64>,
    /// Name of the estimator that produced this result (e.g. `"IPW"`).
    pub estimator: String,
}
54
/// Common-support diagnostics returned by `OverlapCheck::check`.
#[derive(Debug, Clone)]
pub struct OverlapResult {
    /// Copy of the propensity scores that were checked.
    pub ps: Array1<f64>,
    /// Indices of units whose score lies in `[ps_lower, ps_upper]`.
    pub common_support_idx: Vec<usize>,
    /// Lower bound of the common-support region.
    pub ps_lower: f64,
    /// Upper bound of the common-support region.
    pub ps_upper: f64,
    /// Share of treated units whose score lies inside the support bounds.
    pub frac_treated_in_support: f64,
    /// Share of control units whose score lies inside the support bounds.
    pub frac_control_in_support: f64,
    /// Histogram-based overlap of the treated and control score distributions
    /// (sum over bins of the smaller of the two bin shares).
    pub overlap_coefficient: f64,
}
73
/// ATT estimate produced by `PSMatching::estimate_att`.
#[derive(Debug, Clone)]
pub struct MatchingResult {
    /// Mean matched treated-minus-control outcome difference (ATT estimate).
    pub att: f64,
    /// Standard error of `att` (0.0 when only one treated unit was matched).
    pub att_se: f64,
    /// Two-sided normal-approximation p-value for `att`.
    pub p_value: f64,
    /// Interval `[att - 1.96 * att_se, att + 1.96 * att_se]`.
    pub conf_interval: [f64; 2],
    /// Number of treated units that contributed a matched difference.
    pub n_matched_treated: usize,
    /// Label of the matching method (e.g. `"NN-matching"`, `"Kernel-matching"`).
    pub method: String,
}
90
91fn normal_p_value(z: f64) -> f64 {
96 2.0 * (1.0 - normal_cdf(z.abs()))
97}
98
99fn normal_cdf(x: f64) -> f64 {
100 0.5 * (1.0 + erf_approx(x / std::f64::consts::SQRT_2))
101}
102
/// Polynomial approximation of the error function.
///
/// Evaluates `1 - (a1 t + a2 t² + a3 t³ + a4 t⁴ + a5 t⁵) · exp(-x²)` with
/// `t = 1 / (1 + p |x|)` for the magnitude and restores the sign of `x`
/// afterwards, so the result is odd in `x`.
fn erf_approx(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const A1: f64 = 0.254829592;
    const A2: f64 = -0.284496736;
    const A3: f64 = 1.421413741;
    const A4: f64 = -1.453152027;
    const A5: f64 = 1.061405429;

    let t = 1.0 / (1.0 + P * x.abs());
    // Horner evaluation, innermost coefficient first.
    let horner = A4 + A5 * t;
    let horner = A3 + horner * t;
    let horner = A2 + horner * t;
    let horner = A1 + horner * t;
    let magnitude = 1.0 - horner * t * (-x * x).exp();
    if x >= 0.0 {
        magnitude
    } else {
        -magnitude
    }
}
116
/// Logistic-regression propensity-score model fitted by Newton iterations.
pub struct PropensityScoreModel {
    /// Maximum number of Newton iterations performed by `fit`.
    pub max_iter: usize,
    /// Convergence threshold on the Euclidean norm of the Newton step.
    pub tol: f64,
    /// Ridge penalty applied to every coefficient except the intercept.
    pub lambda: f64,
}
132
133impl PropensityScoreModel {
134 pub fn new() -> Self {
136 Self {
137 max_iter: 200,
138 tol: 1e-8,
139 lambda: 1e-4,
140 }
141 }
142
143 pub fn fit(&self, x: &ArrayView2<f64>, w: &ArrayView1<f64>) -> StatsResult<Array1<f64>> {
152 let n = x.nrows();
153 let k = x.ncols();
154 if w.len() != n {
155 return Err(StatsError::DimensionMismatch(
156 "x rows must equal w length".into(),
157 ));
158 }
159 let mut xmat = Array2::<f64>::zeros((n, k + 1));
161 for i in 0..n {
162 xmat[[i, 0]] = 1.0;
163 for j in 0..k {
164 xmat[[i, j + 1]] = x[[i, j]];
165 }
166 }
167 let k1 = k + 1;
168 let mut beta = Array1::<f64>::zeros(k1);
169
170 for _iter in 0..self.max_iter {
171 let eta: Array1<f64> = xmat.dot(&beta);
173 let mu: Array1<f64> = eta.mapv(sigmoid);
174 let v: Array1<f64> = mu.mapv(|m| (m * (1.0 - m)).max(1e-8));
176 let grad_data = xmat.t().dot(&(w.to_owned() - &mu));
178 let mut grad = grad_data;
179 for j in 1..k1 {
180 grad[j] -= self.lambda * beta[j];
181 }
182 let sqrt_v: Array1<f64> = v.mapv(|vi| vi.sqrt());
185 let mut wxmat = Array2::<f64>::zeros((n, k1));
186 for i in 0..n {
187 for j in 0..k1 {
188 wxmat[[i, j]] = sqrt_v[i] * xmat[[i, j]];
189 }
190 }
191 let mut hess = wxmat.t().dot(&wxmat);
192 for j in 1..k1 {
193 hess[[j, j]] += self.lambda;
194 }
195 let h_inv = cholesky_invert_ps(&hess.view())?;
196 let delta = h_inv.dot(&grad);
197 let step_norm: f64 = delta.iter().map(|&d| d * d).sum::<f64>().sqrt();
198 beta = &beta + δ
199 if step_norm < self.tol {
200 break;
201 }
202 }
203 Ok(beta)
204 }
205
206 pub fn predict(&self, x: &ArrayView2<f64>, beta: &ArrayView1<f64>) -> StatsResult<Array1<f64>> {
212 let n = x.nrows();
213 let k = x.ncols();
214 if beta.len() != k + 1 {
215 return Err(StatsError::DimensionMismatch(format!(
216 "beta length {} != k+1 = {}",
217 beta.len(),
218 k + 1
219 )));
220 }
221 let mut eta = Array1::<f64>::zeros(n);
222 for i in 0..n {
223 eta[i] = beta[0];
224 for j in 0..k {
225 eta[i] += beta[j + 1] * x[[i, j]];
226 }
227 }
228 Ok(eta.mapv(sigmoid))
229 }
230}
231
232impl Default for PropensityScoreModel {
233 fn default() -> Self {
234 Self::new()
235 }
236}
237
/// Logistic function `1 / (1 + e^(-x))`.
///
/// Inputs above 500 return exactly 1.0 and inputs below -500 return exactly 0.0,
/// which keeps the exponential away from extreme arguments.
fn sigmoid(x: f64) -> f64 {
    const SATURATION: f64 = 500.0;
    if x > SATURATION {
        1.0
    } else if x < -SATURATION {
        0.0
    } else {
        1.0 / (1.0 + (-x).exp())
    }
}
247
248fn cholesky_invert_ps(a: &scirs2_core::ndarray::ArrayView2<f64>) -> StatsResult<Array2<f64>> {
249 let n = a.nrows();
250 let mut l = Array2::<f64>::zeros((n, n));
251 for i in 0..n {
252 for j in 0..=i {
253 let mut s = a[[i, j]];
254 for p in 0..j {
255 s -= l[[i, p]] * l[[j, p]];
256 }
257 if i == j {
258 if s <= 0.0 {
259 return Err(StatsError::ComputationError(
260 "Hessian not positive definite (PS logistic)".into(),
261 ));
262 }
263 l[[i, j]] = s.sqrt();
264 } else {
265 l[[i, j]] = s / l[[j, j]];
266 }
267 }
268 }
269 let mut linv = Array2::<f64>::zeros((n, n));
270 for j in 0..n {
271 linv[[j, j]] = 1.0 / l[[j, j]];
272 for i in (j + 1)..n {
273 let mut s = 0.0_f64;
274 for p in j..i {
275 s += l[[i, p]] * linv[[p, j]];
276 }
277 linv[[i, j]] = -s / l[[i, i]];
278 }
279 }
280 Ok(linv.t().dot(&linv))
281}
282
/// Inverse-probability-weighting estimator; a stateless namespace for `IPW::estimate`.
pub struct IPW;
295
296impl IPW {
297 pub fn estimate(
305 y: &ArrayView1<f64>,
306 w: &ArrayView1<f64>,
307 ps: &ArrayView1<f64>,
308 trim_eps: f64,
309 ) -> StatsResult<PSResult> {
310 let n = y.len();
311 if w.len() != n || ps.len() != n {
312 return Err(StatsError::DimensionMismatch(
313 "y, w, ps must all have the same length".into(),
314 ));
315 }
316 let eps = trim_eps.max(0.0).min(0.49);
317
318 let ps_trim: Array1<f64> = ps.mapv(|p| p.clamp(eps, 1.0 - eps));
320
321 let ate_terms: Array1<f64> = (0..n)
323 .map(|i| {
324 let wi = w[i];
325 let yi = y[i];
326 let pi = ps_trim[i];
327 wi * yi / pi - (1.0 - wi) * yi / (1.0 - pi)
328 })
329 .collect();
330 let ate = ate_terms.iter().sum::<f64>() / n as f64;
331
332 let n_treated: usize = w.iter().filter(|&&wi| wi > 0.5).count();
334 let att_num: f64 = (0..n).filter(|&i| w[i] > 0.5).map(|i| y[i]).sum::<f64>();
335 let att_denom_ctrl_num: f64 = (0..n)
336 .filter(|&i| w[i] <= 0.5)
337 .map(|i| y[i] * ps_trim[i] / (1.0 - ps_trim[i]))
338 .sum::<f64>();
339 let att_denom_ctrl_den: f64 = (0..n)
340 .filter(|&i| w[i] <= 0.5)
341 .map(|i| ps_trim[i] / (1.0 - ps_trim[i]))
342 .sum::<f64>();
343 let att = if n_treated > 0 && att_denom_ctrl_den > 1e-10 {
344 att_num / n_treated as f64 - att_denom_ctrl_num / att_denom_ctrl_den
345 } else {
346 0.0
347 };
348
349 let n_control = n - n_treated;
351 let atc_ctrl_mean = if n_control > 0 {
352 (0..n).filter(|&i| w[i] <= 0.5).map(|i| y[i]).sum::<f64>() / n_control as f64
353 } else {
354 0.0
355 };
356 let atc_trt_num: f64 = (0..n)
357 .filter(|&i| w[i] > 0.5)
358 .map(|i| y[i] * (1.0 - ps_trim[i]) / ps_trim[i])
359 .sum::<f64>();
360 let atc_trt_den: f64 = (0..n)
361 .filter(|&i| w[i] > 0.5)
362 .map(|i| (1.0 - ps_trim[i]) / ps_trim[i])
363 .sum::<f64>();
364 let atc = if atc_trt_den > 1e-10 {
365 atc_trt_num / atc_trt_den - atc_ctrl_mean
366 } else {
367 0.0
368 };
369
370 let ate_se = bootstrap_se_ipw_ate(y, w, &ps_trim.view(), ate, n)?;
372 let att_se = bootstrap_se_ipw_att(y, w, &ps_trim.view(), att, n)?;
373 let atc_se = ate_se; let ate_p = normal_p_value(if ate_se > 0.0 { ate / ate_se } else { 0.0 });
376 let att_p = normal_p_value(if att_se > 0.0 { att / att_se } else { 0.0 });
377 let atc_p = normal_p_value(if atc_se > 0.0 { atc / atc_se } else { 0.0 });
378
379 Ok(PSResult {
380 ate,
381 ate_se,
382 att,
383 att_se,
384 atc,
385 atc_se,
386 ate_p,
387 att_p,
388 atc_p,
389 propensity_scores: ps_trim,
390 estimator: "IPW".into(),
391 })
392 }
393}
394
395fn bootstrap_se_ipw_ate(
397 y: &ArrayView1<f64>,
398 w: &ArrayView1<f64>,
399 ps: &ArrayView1<f64>,
400 ate: f64,
401 n: usize,
402) -> StatsResult<f64> {
403 let psi: Array1<f64> = (0..n)
404 .map(|i| {
405 let wi = w[i];
406 let yi = y[i];
407 let pi = ps[i];
408 wi * yi / pi - (1.0 - wi) * yi / (1.0 - pi) - ate
409 })
410 .collect();
411 let var_psi: f64 = psi.iter().map(|&p| p * p).sum::<f64>() / (n * (n - 1).max(1)) as f64;
412 Ok(var_psi.sqrt())
413}
414
415fn bootstrap_se_ipw_att(
417 y: &ArrayView1<f64>,
418 w: &ArrayView1<f64>,
419 ps: &ArrayView1<f64>,
420 att: f64,
421 n: usize,
422) -> StatsResult<f64> {
423 let n_treated: f64 = w.iter().filter(|&&wi| wi > 0.5).count() as f64;
424 if n_treated < 1.0 {
425 return Ok(0.0);
426 }
427 let psi: Array1<f64> = (0..n)
428 .map(|i| {
429 let wi = w[i];
430 let yi = y[i];
431 let pi = ps[i];
432 (wi * yi - (1.0 - wi) * pi * yi / (1.0 - pi)) / (n_treated / n as f64) - att
434 })
435 .collect();
436 let var_psi: f64 = psi.iter().map(|&p| p * p).sum::<f64>() / (n * (n - 1).max(1)) as f64;
437 Ok(var_psi.sqrt())
438}
439
/// Matching strategy used by `PSMatching::estimate_att`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatchingMethod {
    /// Nearest-neighbour matching on the logit of the propensity score, with no
    /// limit on the match distance.
    NearestNeighbour,
    /// Nearest-neighbour matching that only accepts controls whose logit-score
    /// distance is within the caliper.
    Caliper,
    /// Kernel matching: every control inside the bandwidth contributes with an
    /// Epanechnikov weight `0.75 · (1 − u²)`.
    Kernel,
}
454
/// Configuration for propensity-score matching estimators of the ATT.
pub struct PSMatching {
    /// Matching strategy.
    pub method: MatchingMethod,
    /// Maximum logit-score distance for `MatchingMethod::Caliper`; when `None`,
    /// `estimate_att` uses 0.2 × the standard deviation of the logit scores.
    pub caliper: Option<f64>,
    /// Requested number of control neighbours per treated unit (`new` sets 1).
    pub n_neighbours: usize,
    /// Bandwidth for `MatchingMethod::Kernel`; when `None`, `estimate_att` uses
    /// 0.1 × the standard deviation of the logit scores.
    pub kernel_bandwidth: Option<f64>,
}
466
467impl PSMatching {
468 pub fn new(method: MatchingMethod) -> Self {
470 Self {
471 method,
472 caliper: None,
473 n_neighbours: 1,
474 kernel_bandwidth: None,
475 }
476 }
477
478 pub fn estimate_att(
485 &self,
486 y: &ArrayView1<f64>,
487 w: &ArrayView1<f64>,
488 ps: &ArrayView1<f64>,
489 ) -> StatsResult<MatchingResult> {
490 let n = y.len();
491 if w.len() != n || ps.len() != n {
492 return Err(StatsError::DimensionMismatch(
493 "y, w, ps must have equal length".into(),
494 ));
495 }
496
497 let treated_idx: Vec<usize> = (0..n).filter(|&i| w[i] > 0.5).collect();
498 let control_idx: Vec<usize> = (0..n).filter(|&i| w[i] <= 0.5).collect();
499
500 if treated_idx.is_empty() {
501 return Err(StatsError::InsufficientData("No treated units".into()));
502 }
503 if control_idx.is_empty() {
504 return Err(StatsError::InsufficientData("No control units".into()));
505 }
506
507 let logit_ps: Array1<f64> = ps.mapv(|p| logit(p.clamp(1e-8, 1.0 - 1e-8)));
509 let logit_sd = std_dev_vec(&logit_ps.to_vec());
510 let caliper_val = self.caliper.unwrap_or(0.2 * logit_sd);
511 let bw = self.kernel_bandwidth.unwrap_or(0.1 * logit_sd);
512
513 match self.method {
514 MatchingMethod::NearestNeighbour | MatchingMethod::Caliper => {
515 self.nn_matching_att(y, &treated_idx, &control_idx, &logit_ps.view(), caliper_val)
516 }
517 MatchingMethod::Kernel => {
518 self.kernel_matching_att(y, &treated_idx, &control_idx, ps, bw)
519 }
520 }
521 }
522
523 fn nn_matching_att(
524 &self,
525 y: &ArrayView1<f64>,
526 treated_idx: &[usize],
527 control_idx: &[usize],
528 logit_ps: &ArrayView1<f64>,
529 caliper: f64,
530 ) -> StatsResult<MatchingResult> {
531 let mut matched_diffs: Vec<f64> = Vec::new();
532 let use_caliper = self.method == MatchingMethod::Caliper;
533
534 for &ti in treated_idx {
535 let lps_t = logit_ps[ti];
536 let best = control_idx
537 .iter()
538 .map(|&ci| (ci, (logit_ps[ci] - lps_t).abs()))
539 .filter(|(_, dist)| !use_caliper || *dist <= caliper)
540 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
541 if let Some((best_ci, _)) = best {
542 matched_diffs.push(y[ti] - y[best_ci]);
543 }
544 }
545
546 if matched_diffs.is_empty() {
547 return Err(StatsError::InsufficientData(
548 "No matches found; try increasing the caliper".into(),
549 ));
550 }
551
552 let n_m = matched_diffs.len();
553 let att = matched_diffs.iter().sum::<f64>() / n_m as f64;
554 let se = if n_m > 1 {
555 let var = matched_diffs
556 .iter()
557 .map(|&d| (d - att).powi(2))
558 .sum::<f64>()
559 / (n_m * (n_m - 1)) as f64;
560 var.sqrt()
561 } else {
562 0.0
563 };
564 let t = if se > 0.0 { att / se } else { 0.0 };
565 let p = normal_p_value(t);
566 let ci = [att - 1.96 * se, att + 1.96 * se];
567
568 let method_name = if self.method == MatchingMethod::Caliper {
569 "Caliper-matching"
570 } else {
571 "NN-matching"
572 };
573
574 Ok(MatchingResult {
575 att,
576 att_se: se,
577 p_value: p,
578 conf_interval: ci,
579 n_matched_treated: n_m,
580 method: method_name.into(),
581 })
582 }
583
584 fn kernel_matching_att(
585 &self,
586 y: &ArrayView1<f64>,
587 treated_idx: &[usize],
588 control_idx: &[usize],
589 ps: &ArrayView1<f64>,
590 bw: f64,
591 ) -> StatsResult<MatchingResult> {
592 let mut diffs: Vec<f64> = Vec::with_capacity(treated_idx.len());
593 for &ti in treated_idx {
594 let psi = ps[ti];
595 let weights: Vec<f64> = control_idx
597 .iter()
598 .map(|&ci| {
599 let u = (ps[ci] - psi) / bw;
600 if u.abs() < 1.0 {
601 0.75 * (1.0 - u * u)
602 } else {
603 0.0
604 }
605 })
606 .collect();
607 let total_w: f64 = weights.iter().sum();
608 if total_w < 1e-10 {
609 continue;
610 }
611 let y_ctrl_wt: f64 = control_idx
612 .iter()
613 .zip(weights.iter())
614 .map(|(&ci, &wi)| y[ci] * wi)
615 .sum::<f64>()
616 / total_w;
617 diffs.push(y[ti] - y_ctrl_wt);
618 }
619 if diffs.is_empty() {
620 return Err(StatsError::InsufficientData(
621 "No matches with positive kernel weight; reduce bandwidth".into(),
622 ));
623 }
624 let n_m = diffs.len();
625 let att = diffs.iter().sum::<f64>() / n_m as f64;
626 let se = if n_m > 1 {
627 let var =
628 diffs.iter().map(|&d| (d - att).powi(2)).sum::<f64>() / (n_m * (n_m - 1)) as f64;
629 var.sqrt()
630 } else {
631 0.0
632 };
633 let t = if se > 0.0 { att / se } else { 0.0 };
634 let p = normal_p_value(t);
635 let ci = [att - 1.96 * se, att + 1.96 * se];
636 Ok(MatchingResult {
637 att,
638 att_se: se,
639 p_value: p,
640 conf_interval: ci,
641 n_matched_treated: n_m,
642 method: "Kernel-matching".into(),
643 })
644 }
645}
646
/// Log-odds `ln(p / (1 − p))` of a probability `p`.
///
/// No range check is performed; callers clamp `p` away from 0 and 1 first.
fn logit(p: f64) -> f64 {
    let odds = p / (1.0 - p);
    odds.ln()
}
650
/// Sample standard deviation (denominator `n − 1`) of `v`.
///
/// Returns 1.0 for fewer than two values, and never less than `1e-15`, so the
/// result can safely be used as a scale factor.
fn std_dev_vec(v: &[f64]) -> f64 {
    let count = v.len();
    if count <= 1 {
        return 1.0;
    }
    let centre = v.iter().sum::<f64>() / count as f64;
    let sum_sq_dev: f64 = v.iter().map(|&x| (x - centre).powi(2)).sum();
    let sample_var = sum_sq_dev / (count - 1) as f64;
    sample_var.sqrt().max(1e-15)
}
660
/// Common-support (overlap) diagnostic for propensity scores.
pub struct OverlapCheck {
    /// Rule used to derive the common-support bounds.
    pub trim_method: TrimMethod,
}
670
/// Rule for choosing the common-support bounds in `OverlapCheck::check`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrimMethod {
    /// Fixed bounds `[0.1, 0.9]`.
    Crump,
    /// Bounds `[max(min treated, min control), min(max treated, max control)]`.
    MinMax,
    /// Bounds at the 5% and 95% quantiles of the pooled scores.
    Percentile,
}
681
682impl OverlapCheck {
683 pub fn new(trim_method: TrimMethod) -> Self {
685 Self { trim_method }
686 }
687
688 pub fn check(&self, ps: &ArrayView1<f64>, w: &ArrayView1<f64>) -> StatsResult<OverlapResult> {
694 let n = ps.len();
695 if w.len() != n {
696 return Err(StatsError::DimensionMismatch(
697 "ps and w must have equal length".into(),
698 ));
699 }
700
701 let treated_ps: Vec<f64> = (0..n).filter(|&i| w[i] > 0.5).map(|i| ps[i]).collect();
702 let control_ps: Vec<f64> = (0..n).filter(|&i| w[i] <= 0.5).map(|i| ps[i]).collect();
703
704 if treated_ps.is_empty() || control_ps.is_empty() {
705 return Err(StatsError::InsufficientData(
706 "Need both treated and control units".into(),
707 ));
708 }
709
710 let (ps_lower, ps_upper) = match self.trim_method {
711 TrimMethod::Crump => {
712 (0.1_f64, 0.9_f64)
714 }
715 TrimMethod::MinMax => {
716 let min_t = treated_ps.iter().cloned().fold(f64::INFINITY, f64::min);
717 let max_t = treated_ps.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
718 let min_c = control_ps.iter().cloned().fold(f64::INFINITY, f64::min);
719 let max_c = control_ps.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
720 (min_t.max(min_c), max_t.min(max_c))
721 }
722 TrimMethod::Percentile => {
723 let alpha = 0.05_f64;
725 let all_ps: Vec<f64> = ps.to_vec();
726 let lower = quantile_val(&all_ps, alpha);
727 let upper = quantile_val(&all_ps, 1.0 - alpha);
728 (lower, upper)
729 }
730 };
731
732 let common_support_idx: Vec<usize> = (0..n)
733 .filter(|&i| ps[i] >= ps_lower && ps[i] <= ps_upper)
734 .collect();
735
736 let n_t = treated_ps.len() as f64;
737 let n_c = control_ps.len() as f64;
738 let frac_t = treated_ps
739 .iter()
740 .filter(|&&p| p >= ps_lower && p <= ps_upper)
741 .count() as f64
742 / n_t.max(1.0);
743 let frac_c = control_ps
744 .iter()
745 .filter(|&&p| p >= ps_lower && p <= ps_upper)
746 .count() as f64
747 / n_c.max(1.0);
748
749 let overlap_coefficient = overlap_coef(&treated_ps, &control_ps);
751
752 Ok(OverlapResult {
753 ps: ps.to_owned(),
754 common_support_idx,
755 ps_lower,
756 ps_upper,
757 frac_treated_in_support: frac_t,
758 frac_control_in_support: frac_c,
759 overlap_coefficient,
760 })
761 }
762}
763
/// Nearest-rank empirical quantile of `v` at level `q`.
///
/// The values are sorted and the element at index `round(q · (len − 1))`
/// (clamped to the valid range) is returned. An empty slice yields 0.5.
///
/// Sorting uses `f64::total_cmp`, a total order that places positive NaN after
/// every other value, so NaN inputs can neither make the sort panic nor leave
/// the data partially ordered.
fn quantile_val(v: &[f64], q: f64) -> f64 {
    if v.is_empty() {
        return 0.5;
    }
    let mut sorted = v.to_vec();
    sorted.sort_unstable_by(f64::total_cmp);
    let last = sorted.len() - 1;
    let idx = ((q * last as f64).round() as usize).min(last);
    sorted[idx]
}
773
/// Histogram-based overlap coefficient of two samples of propensity scores.
///
/// The pooled range `[min, max]` is split into 100 equal-width bins; the result
/// is the sum over bins of `min(share of ps_t in bin, share of ps_c in bin)`,
/// so 1.0 means identical binned distributions and 0.0 means no shared bin.
///
/// The pooled maximum is counted in the last bin, so values at the upper edge
/// contribute like any other value. NaN scores are not assigned to any bin but
/// still count towards the sample size.
///
/// Returns 0.0 when either sample is empty and 1.0 when the pooled range is
/// narrower than `1e-10`.
fn overlap_coef(ps_t: &[f64], ps_c: &[f64]) -> f64 {
    const N_BINS: usize = 100;

    if ps_t.is_empty() || ps_c.is_empty() {
        return 0.0;
    }
    let pooled = || ps_t.iter().chain(ps_c.iter()).copied();
    let all_min = pooled().fold(f64::INFINITY, f64::min);
    let all_max = pooled().fold(f64::NEG_INFINITY, f64::max);
    if (all_max - all_min).abs() < 1e-10 {
        return 1.0;
    }
    let step = (all_max - all_min) / N_BINS as f64;

    // One pass per sample: map every score straight to its bin index.
    let histogram = |values: &[f64]| -> Vec<usize> {
        let mut counts = vec![0_usize; N_BINS];
        for &p in values {
            if p.is_nan() {
                continue;
            }
            // The float-to-usize cast truncates; the clamp puts the maximum
            // (and any rounding overshoot) into the last bin.
            let bin = (((p - all_min) / step) as usize).min(N_BINS - 1);
            counts[bin] += 1;
        }
        counts
    };
    let hist_t = histogram(ps_t);
    let hist_c = histogram(ps_c);
    let n_t = ps_t.len() as f64;
    let n_c = ps_c.len() as f64;

    hist_t
        .iter()
        .zip(hist_c.iter())
        .map(|(&ct, &cc)| (ct as f64 / n_t).min(cc as f64 / n_c))
        .sum::<f64>()
}
805
806pub fn ps_estimate(
820 y: &ArrayView1<f64>,
821 w: &ArrayView1<f64>,
822 x: &ArrayView2<f64>,
823 trim_eps: f64,
824) -> StatsResult<PSResult> {
825 let ps_model = PropensityScoreModel::new();
826 let beta = ps_model.fit(x, w)?;
827 let ps = ps_model.predict(x, &beta.view())?;
828 IPW::estimate(y, w, &ps.view(), trim_eps)
829}
830
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};

    /// 0/1 treatment indicator: treated exactly when the score exceeds 0.5.
    fn treat_above_half(scores: &Array1<f64>) -> Array1<f64> {
        scores.mapv(|p| if p > 0.5 { 1.0 } else { 0.0 })
    }

    /// The fitted slope is positive and the scores separate the two halves.
    #[test]
    fn test_logistic_regression_ps() {
        // One covariate taking the values 0..=9; the upper half is treated.
        let mut x = Array2::<f64>::zeros((10, 1));
        for i in 0..10 {
            x[[i, 0]] = i as f64;
        }
        let w: Array1<f64> = (0..10).map(|i| if i >= 5 { 1.0 } else { 0.0 }).collect();

        let model = PropensityScoreModel::new();
        let beta = model
            .fit(&x.view(), &w.view())
            .expect("Logistic fit should succeed");
        assert_eq!(beta.len(), 2);
        assert!(
            beta[1] > 0.0,
            "Coefficient should be positive, got {}",
            beta[1]
        );

        let ps = model
            .predict(&x.view(), &beta.view())
            .expect("Predict should succeed");
        assert!(ps[9] > 0.5, "ps for x=9 should be > 0.5, got {}", ps[9]);
        assert!(ps[0] < 0.5, "ps for x=0 should be < 0.5, got {}", ps[0]);
    }

    /// A constant outcome must give an ATE close to zero.
    #[test]
    fn test_ipw_zero_effect() {
        let n = 100_usize;
        let ps: Array1<f64> = (0..n).map(|i| 0.3 + 0.4 * (i as f64 / n as f64)).collect();
        let w = treat_above_half(&ps);
        let y: Array1<f64> = Array1::ones(n);

        let res =
            IPW::estimate(&y.view(), &w.view(), &ps.view(), 0.05).expect("IPW should succeed");
        assert!(
            res.ate.abs() < 0.1,
            "ATE should be ~0 when no effect, got {}",
            res.ate
        );
    }

    /// Treated outcomes of 5 against control outcomes of 3 give an ATT near 2.
    #[test]
    fn test_ps_matching_nn() {
        let n = 40_usize;
        let ps: Array1<f64> = (0..n).map(|i| 0.1 + 0.8 * i as f64 / n as f64).collect();
        let w = treat_above_half(&ps);
        let y: Array1<f64> = w.mapv(|wi| if wi > 0.5 { 5.0 } else { 3.0 });

        let matcher = PSMatching::new(MatchingMethod::NearestNeighbour);
        let res = matcher
            .estimate_att(&y.view(), &w.view(), &ps.view())
            .expect("NN matching should succeed");
        assert!(
            (res.att - 2.0).abs() < 0.5,
            "ATT should be ~2.0, got {}",
            res.att
        );
    }

    /// Min–max trimming on interleaved groups yields a non-empty support region.
    #[test]
    fn test_overlap_check_minmax() {
        let ps = array![0.1, 0.3, 0.4, 0.5, 0.5, 0.2, 0.6, 0.7, 0.8, 0.9];
        let w = array![0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0];

        let checker = OverlapCheck::new(TrimMethod::MinMax);
        let res = checker
            .check(&ps.view(), &w.view())
            .expect("Overlap check should succeed");
        assert!(
            res.ps_lower < res.ps_upper,
            "lower={} >= upper={}",
            res.ps_lower,
            res.ps_upper
        );
        assert!(!res.common_support_idx.is_empty());
    }

    /// The fit → predict → IPW pipeline runs and reports a non-zero ATE.
    #[test]
    fn test_ps_estimate_pipeline() {
        let n = 60_usize;
        let mut x_data = Array2::<f64>::zeros((n, 1));
        let mut w_data = Array1::<f64>::zeros(n);
        let mut y_data = Array1::<f64>::zeros(n);
        for i in 0..n {
            let xi = i as f64 / n as f64;
            let treated = xi > 0.5;
            x_data[[i, 0]] = xi;
            w_data[i] = if treated { 1.0 } else { 0.0 };
            y_data[i] = xi + if treated { 3.0 } else { 1.0 };
        }

        let res = ps_estimate(&y_data.view(), &w_data.view(), &x_data.view(), 0.05)
            .expect("PS estimate pipeline should succeed");
        assert!(res.ate.abs() > 0.0, "ATE should be non-zero");
    }
}