1use scirs2_core::ndarray::{Array1, Array2, Axis};
14use scirs2_core::random::essentials::Normal;
15use scirs2_core::random::thread_rng;
16use serde::{Deserialize, Serialize};
17use sklears_core::{
18 error::{Result, SklearsError},
19 prelude::{Fit, Transform},
20 traits::{Estimator, Trained, Untrained},
21 types::Float,
22};
23use std::collections::HashMap;
24use std::marker::PhantomData;
25
/// Configuration shared by the kernel-based causal transformers in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalKernelConfig {
    /// Causal identification strategy this transformer is configured for.
    pub causal_method: CausalMethod,
    /// Kernel bandwidth used on the treatment side (Gaussian smoothing of
    /// treatment indicators and scaling of the treatment random features).
    pub treatment_bandwidth: Float,
    /// Kernel bandwidth used to scale the outcome-side random features.
    pub outcome_bandwidth: Float,
    /// Number of random Fourier components generated per kernel.
    pub n_components: usize,
    /// Regularization strength.
    /// NOTE(review): stored but not referenced anywhere in this file —
    /// presumably consumed downstream; confirm.
    pub regularization: Float,
}
40
41impl Default for CausalKernelConfig {
42 fn default() -> Self {
43 Self {
44 causal_method: CausalMethod::TreatmentEffect,
45 treatment_bandwidth: 1.0,
46 outcome_bandwidth: 1.0,
47 n_components: 100,
48 regularization: 1e-5,
49 }
50 }
51}
52
// Identification strategies a causal kernel can be configured with.
// NOTE(review): the selected variant is only stored in the config; no branch
// in this file changes behavior per variant — confirm downstream use.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum CausalMethod {
    /// Average treatment effect estimation.
    TreatmentEffect,
    /// Treatment effect conditional on covariates (CATE).
    ConditionalTreatmentEffect,
    /// Instrumental-variable identification.
    InstrumentalVariable,
    /// Regression-discontinuity design.
    RegressionDiscontinuity,
    /// Difference-in-differences design.
    DifferenceInDifferences,
}
67
/// Kernel-based causal feature transformer using the typestate pattern:
/// the `State` parameter (`Untrained` / `Trained`) is tracked at compile
/// time via a zero-sized `PhantomData` field.
#[derive(Debug, Clone)]
pub struct CausalKernel<State = Untrained> {
    config: CausalKernelConfig,

    // Random projection weights for the treatment/outcome feature maps;
    // populated by `fit`, `None` while untrained.
    treatment_weights: Option<Array2<Float>>,
    outcome_weights: Option<Array2<Float>>,
    // Per-training-sample propensity estimates; populated by `fit`.
    propensity_scores: Option<Array1<Float>>,
    // Named effect statistics ("ate", "naive_difference", ...); populated by `fit`.
    treatment_effects: Option<HashMap<String, Float>>,

    _state: PhantomData<State>,
}
107
108impl CausalKernel<Untrained> {
109 pub fn new(config: CausalKernelConfig) -> Self {
111 Self {
112 config,
113 treatment_weights: None,
114 outcome_weights: None,
115 propensity_scores: None,
116 treatment_effects: None,
117 _state: PhantomData,
118 }
119 }
120
121 pub fn with_components(n_components: usize) -> Self {
123 Self {
124 config: CausalKernelConfig {
125 n_components,
126 ..Default::default()
127 },
128 treatment_weights: None,
129 outcome_weights: None,
130 propensity_scores: None,
131 treatment_effects: None,
132 _state: PhantomData,
133 }
134 }
135
136 pub fn method(mut self, method: CausalMethod) -> Self {
138 self.config.causal_method = method;
139 self
140 }
141
142 pub fn treatment_bandwidth(mut self, gamma: Float) -> Self {
144 self.config.treatment_bandwidth = gamma;
145 self
146 }
147
148 fn estimate_propensity_scores(
150 &self,
151 x: &Array2<Float>,
152 treatment: &Array1<Float>,
153 ) -> Array1<Float> {
154 let n_samples = x.nrows();
155 let mut scores = Array1::zeros(n_samples);
156
157 for i in 0..n_samples {
159 let mut score = 0.0;
160 let mut weight_sum = 0.0;
161
162 for j in 0..n_samples {
163 let mut dist_sq = 0.0;
165 for k in 0..x.ncols() {
166 let diff = x[[i, k]] - x[[j, k]];
167 dist_sq += diff * diff;
168 }
169
170 let weight = (-dist_sq / (2.0 * self.config.treatment_bandwidth.powi(2))).exp();
171 score += weight * treatment[j];
172 weight_sum += weight;
173 }
174
175 scores[i] = if weight_sum > 1e-10 {
176 (score / weight_sum).max(0.01).min(0.99) } else {
178 0.5
179 };
180 }
181
182 scores
183 }
184
185 fn estimate_treatment_effect(
187 &self,
188 x: &Array2<Float>,
189 treatment: &Array1<Float>,
190 outcome: &Array1<Float>,
191 propensity_scores: &Array1<Float>,
192 ) -> HashMap<String, Float> {
193 let n_samples = x.nrows() as Float;
194
195 let mut ate_numerator_treated = 0.0;
197 let mut ate_numerator_control = 0.0;
198 let mut weight_sum_treated = 0.0;
199 let mut weight_sum_control = 0.0;
200
201 for i in 0..treatment.len() {
202 if treatment[i] > 0.5 {
203 let weight = 1.0 / propensity_scores[i];
205 ate_numerator_treated += weight * outcome[i];
206 weight_sum_treated += weight;
207 } else {
208 let weight = 1.0 / (1.0 - propensity_scores[i]);
210 ate_numerator_control += weight * outcome[i];
211 weight_sum_control += weight;
212 }
213 }
214
215 let ate = if weight_sum_treated > 0.0 && weight_sum_control > 0.0 {
216 (ate_numerator_treated / weight_sum_treated)
217 - (ate_numerator_control / weight_sum_control)
218 } else {
219 0.0
220 };
221
222 let treated_outcomes: Vec<Float> = treatment
224 .iter()
225 .zip(outcome.iter())
226 .filter_map(|(&t, &y)| if t > 0.5 { Some(y) } else { None })
227 .collect();
228
229 let control_outcomes: Vec<Float> = treatment
230 .iter()
231 .zip(outcome.iter())
232 .filter_map(|(&t, &y)| if t <= 0.5 { Some(y) } else { None })
233 .collect();
234
235 let naive_diff = if !treated_outcomes.is_empty() && !control_outcomes.is_empty() {
236 let treated_mean =
237 treated_outcomes.iter().sum::<Float>() / treated_outcomes.len() as Float;
238 let control_mean =
239 control_outcomes.iter().sum::<Float>() / control_outcomes.len() as Float;
240 treated_mean - control_mean
241 } else {
242 0.0
243 };
244
245 let mut effects = HashMap::new();
246 effects.insert("ate".to_string(), ate);
247 effects.insert("naive_difference".to_string(), naive_diff);
248 effects.insert("n_samples".to_string(), n_samples);
249 effects.insert("n_treated".to_string(), treated_outcomes.len() as Float);
250 effects.insert("n_control".to_string(), control_outcomes.len() as Float);
251
252 effects
253 }
254}
255
// Trait wiring: exposes the configuration on the untrained estimator.
impl Estimator for CausalKernel<Untrained> {
    type Config = CausalKernelConfig;
    type Error = SklearsError;
    type Float = Float;

    // Returns the configuration this estimator was constructed with.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
265
266impl Fit<Array2<Float>, ()> for CausalKernel<Untrained> {
267 type Fitted = CausalKernel<Trained>;
268
269 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
270 if x.nrows() < 2 || x.ncols() < 3 {
271 return Err(SklearsError::InvalidInput(
272 "Input must have at least 2 samples and 3 columns (covariates, treatment, outcome)"
273 .to_string(),
274 ));
275 }
276
277 let n_covariates = x.ncols() - 2;
279 let covariates = x.slice_axis(Axis(1), (0..n_covariates).into()).to_owned();
280 let treatment = x.column(n_covariates).to_owned();
281 let outcome = x.column(n_covariates + 1).to_owned();
282
283 let propensity_scores = self.estimate_propensity_scores(&covariates, &treatment);
285
286 let treatment_effects =
288 self.estimate_treatment_effect(&covariates, &treatment, &outcome, &propensity_scores);
289
290 let mut rng = thread_rng();
292 let normal = Normal::new(0.0, 1.0).expect("operation should succeed");
293
294 let treatment_weights =
295 Array2::from_shape_fn((n_covariates, self.config.n_components), |_| {
296 rng.sample(normal) * (2.0 * self.config.treatment_bandwidth).sqrt()
297 });
298
299 let outcome_weights =
300 Array2::from_shape_fn((n_covariates, self.config.n_components), |_| {
301 rng.sample(normal) * (2.0 * self.config.outcome_bandwidth).sqrt()
302 });
303
304 Ok(CausalKernel {
305 config: self.config,
306 treatment_weights: Some(treatment_weights),
307 outcome_weights: Some(outcome_weights),
308 propensity_scores: Some(propensity_scores),
309 treatment_effects: Some(treatment_effects),
310 _state: PhantomData,
311 })
312 }
313}
314
315impl Transform<Array2<Float>, Array2<Float>> for CausalKernel<Trained> {
316 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
317 let treatment_weights = self
318 .treatment_weights
319 .as_ref()
320 .expect("operation should succeed");
321 let outcome_weights = self
322 .outcome_weights
323 .as_ref()
324 .expect("operation should succeed");
325
326 let n_covariates = treatment_weights.nrows();
328
329 if x.ncols() < n_covariates {
330 return Err(SklearsError::InvalidInput(format!(
331 "Input must have at least {} columns",
332 n_covariates
333 )));
334 }
335
336 let covariates = x.slice_axis(Axis(1), (0..n_covariates).into());
337
338 let treatment_projection = covariates.dot(treatment_weights);
340 let outcome_projection = covariates.dot(outcome_weights);
341
342 let n_samples = x.nrows();
343 let n_features = self.config.n_components * 2;
344 let mut output = Array2::zeros((n_samples, n_features));
345
346 let normalizer = (2.0 / self.config.n_components as Float).sqrt();
347
348 for i in 0..n_samples {
349 for j in 0..self.config.n_components {
350 output[[i, j]] = normalizer * treatment_projection[[i, j]].cos();
352 output[[i, j + self.config.n_components]] =
354 normalizer * outcome_projection[[i, j]].cos();
355 }
356 }
357
358 Ok(output)
359 }
360}
361
362impl CausalKernel<Trained> {
363 pub fn propensity_scores(&self) -> &Array1<Float> {
365 self.propensity_scores
366 .as_ref()
367 .expect("operation should succeed")
368 }
369
370 pub fn treatment_effects(&self) -> &HashMap<String, Float> {
372 self.treatment_effects
373 .as_ref()
374 .expect("operation should succeed")
375 }
376
377 pub fn ate(&self) -> Float {
379 self.treatment_effects
380 .as_ref()
381 .expect("operation should succeed")
382 .get("ate")
383 .copied()
384 .unwrap_or(0.0)
385 }
386}
387
/// Nearest-neighbor counterfactual estimator with the same typestate scheme
/// as `CausalKernel` (`State` is `Untrained` or `Trained`).
#[derive(Debug, Clone)]
pub struct CounterfactualKernel<State = Untrained> {
    config: CausalKernelConfig,

    // Full copy of the fitted input matrix; `transform` searches it for
    // nearest neighbors. Populated by `fit`, `None` while untrained.
    training_data: Option<Array2<Float>>,
    // Precomputed random Fourier features of the training covariates.
    kernel_features: Option<Array2<Float>>,
    // Per-training-sample propensity estimates.
    propensity_scores: Option<Array1<Float>>,

    _state: PhantomData<State>,
}
423
424impl CounterfactualKernel<Untrained> {
425 pub fn new(config: CausalKernelConfig) -> Self {
427 Self {
428 config,
429 training_data: None,
430 kernel_features: None,
431 propensity_scores: None,
432 _state: PhantomData,
433 }
434 }
435
436 pub fn with_components(n_components: usize) -> Self {
438 Self::new(CausalKernelConfig {
439 n_components,
440 ..Default::default()
441 })
442 }
443}
444
// Trait wiring: exposes the configuration on the untrained estimator.
impl Estimator for CounterfactualKernel<Untrained> {
    type Config = CausalKernelConfig;
    type Error = SklearsError;
    type Float = Float;

    // Returns the configuration this estimator was constructed with.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
454
455impl Fit<Array2<Float>, ()> for CounterfactualKernel<Untrained> {
456 type Fitted = CounterfactualKernel<Trained>;
457
458 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
459 if x.nrows() < 2 || x.ncols() < 3 {
460 return Err(SklearsError::InvalidInput(
461 "Input must have at least 2 samples and 3 columns".to_string(),
462 ));
463 }
464
465 let training_data = x.clone();
466
467 let n_covariates = x.ncols() - 2;
469 let covariates = x.slice_axis(Axis(1), (0..n_covariates).into()).to_owned();
470 let treatment = x.column(n_covariates).to_owned();
471
472 let n_samples = x.nrows();
474 let mut propensity_scores = Array1::zeros(n_samples);
475
476 for i in 0..n_samples {
477 let mut score = 0.0;
478 let mut weight_sum = 0.0;
479
480 for j in 0..n_samples {
481 let mut dist_sq = 0.0;
482 for k in 0..n_covariates {
483 let diff = covariates[[i, k]] - covariates[[j, k]];
484 dist_sq += diff * diff;
485 }
486
487 let weight = (-dist_sq / (2.0 * self.config.treatment_bandwidth.powi(2))).exp();
488 score += weight * treatment[j];
489 weight_sum += weight;
490 }
491
492 propensity_scores[i] = if weight_sum > 1e-10 {
493 (score / weight_sum).max(0.01).min(0.99)
494 } else {
495 0.5
496 };
497 }
498
499 let mut rng = thread_rng();
501 let normal = Normal::new(0.0, 1.0).expect("operation should succeed");
502
503 let random_weights =
504 Array2::from_shape_fn((n_covariates, self.config.n_components), |_| {
505 rng.sample(normal) * (2.0 * self.config.treatment_bandwidth).sqrt()
506 });
507
508 let projection = covariates.dot(&random_weights);
509 let mut kernel_features = Array2::zeros((n_samples, self.config.n_components));
510
511 for i in 0..n_samples {
512 for j in 0..self.config.n_components {
513 kernel_features[[i, j]] = projection[[i, j]].cos();
514 }
515 }
516
517 Ok(CounterfactualKernel {
518 config: self.config,
519 training_data: Some(training_data),
520 kernel_features: Some(kernel_features),
521 propensity_scores: Some(propensity_scores),
522 _state: PhantomData,
523 })
524 }
525}
526
527impl Transform<Array2<Float>, Array2<Float>> for CounterfactualKernel<Trained> {
528 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
529 let training_data = self
530 .training_data
531 .as_ref()
532 .expect("operation should succeed");
533 let kernel_features = self
534 .kernel_features
535 .as_ref()
536 .expect("operation should succeed");
537
538 let n_covariates = training_data.ncols() - 2;
539
540 if x.ncols() < n_covariates {
541 return Err(SklearsError::InvalidInput(format!(
542 "Input must have at least {} columns",
543 n_covariates
544 )));
545 }
546
547 let n_samples = x.nrows();
550 let mut output = Array2::zeros((n_samples, self.config.n_components + 2));
551
552 for i in 0..n_samples {
553 let k = 5.min(kernel_features.nrows());
555 let mut distances = Vec::new();
556
557 for j in 0..kernel_features.nrows() {
558 let mut dist = 0.0;
559 for l in 0..n_covariates {
560 let diff = x[[i, l]] - training_data[[j, l]];
561 dist += diff * diff;
562 }
563 distances.push((dist, j));
564 }
565
566 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("operation should succeed"));
567
568 let mut treated_outcome = 0.0;
570 let mut control_outcome = 0.0;
571 let mut treated_weight = 0.0;
572 let mut control_weight = 0.0;
573
574 for &(dist, idx) in distances.iter().take(k) {
575 let weight = (-dist / self.config.treatment_bandwidth).exp();
576 let treatment_val = training_data[[idx, n_covariates]];
577 let outcome_val = training_data[[idx, n_covariates + 1]];
578
579 if treatment_val > 0.5 {
580 treated_outcome += weight * outcome_val;
581 treated_weight += weight;
582 } else {
583 control_outcome += weight * outcome_val;
584 control_weight += weight;
585 }
586 }
587
588 output[[i, 0]] = if treated_weight > 0.0 {
590 treated_outcome / treated_weight
591 } else {
592 0.0
593 };
594
595 output[[i, 1]] = if control_weight > 0.0 {
596 control_outcome / control_weight
597 } else {
598 0.0
599 };
600
601 for j in 0..self.config.n_components {
603 if j < kernel_features.ncols() {
604 output[[i, j + 2]] = kernel_features[[distances[0].1, j]];
605 }
606 }
607 }
608
609 Ok(output)
610 }
611}
612
613impl CounterfactualKernel<Trained> {
614 pub fn propensity_scores(&self) -> &Array1<Float> {
616 self.propensity_scores
617 .as_ref()
618 .expect("operation should succeed")
619 }
620
621 pub fn estimate_ite(&self, sample: &Array2<Float>) -> Result<Float> {
623 let counterfactuals = self.transform(sample)?;
624
625 if counterfactuals.nrows() > 0 {
626 Ok(counterfactuals[[0, 0]] - counterfactuals[[0, 1]])
628 } else {
629 Ok(0.0)
630 }
631 }
632}
633
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// End-to-end fit/transform with 2 covariates + treatment + outcome.
    #[test]
    fn test_causal_kernel_basic() {
        let config = CausalKernelConfig {
            n_components: 20,
            treatment_bandwidth: 1.0,
            outcome_bandwidth: 1.0,
            ..Default::default()
        };

        let causal = CausalKernel::new(config);

        // Columns: [covariate_1, covariate_2, treatment, outcome].
        let data = array![
            [1.0, 2.0, 0.0, 1.0],
            [2.0, 3.0, 1.0, 5.0],
            [1.5, 2.5, 0.0, 2.0],
            [2.5, 3.5, 1.0, 6.0],
        ];

        let fitted = causal.fit(&data, &()).expect("fit should succeed");
        let features = fitted.transform(&data).expect("transform should succeed");

        assert_eq!(features.nrows(), 4);
        // 2 * n_components: one cosine block per kernel (treatment + outcome).
        assert_eq!(features.ncols(), 40);
    }

    /// Propensity scores must be valid probabilities.
    #[test]
    fn test_propensity_score_estimation() {
        let config = CausalKernelConfig::default();
        let causal = CausalKernel::new(config);

        // Columns: [covariate, treatment, outcome].
        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = causal.fit(&data, &()).expect("fit should succeed");
        let scores = fitted.propensity_scores();

        assert!(scores.iter().all(|&s| (0.0..=1.0).contains(&s)));
    }

    /// The fitted effect map exposes a finite ATE and the naive difference.
    #[test]
    fn test_treatment_effect_estimation() {
        let config = CausalKernelConfig::default();
        let causal = CausalKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = causal.fit(&data, &()).expect("fit should succeed");
        let effects = fitted.treatment_effects();

        assert!(effects.contains_key("ate"));
        assert!(effects.contains_key("naive_difference"));
        assert!(effects["ate"].is_finite());
    }

    /// Counterfactual transform output shape: n_components + 2 columns
    /// (treated estimate, control estimate, nearest-neighbor features).
    #[test]
    fn test_counterfactual_kernel() {
        let config = CausalKernelConfig {
            n_components: 10,
            ..Default::default()
        };

        let cf = CounterfactualKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = cf.fit(&data, &()).expect("fit should succeed");
        // Transform takes covariates only (one column here).
        let test_data = array![[1.2], [2.3]];
        let counterfactuals = fitted
            .transform(&test_data)
            .expect("transform should succeed");

        assert_eq!(counterfactuals.nrows(), 2);
        assert_eq!(counterfactuals.ncols(), 12);
    }

    /// ITE for a single sample is a finite difference of the two estimates.
    #[test]
    fn test_individual_treatment_effect() {
        let config = CausalKernelConfig {
            n_components: 10,
            ..Default::default()
        };

        let cf = CounterfactualKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = cf.fit(&data, &()).expect("fit should succeed");
        let test_sample = array![[1.5]];
        let ite = fitted
            .estimate_ite(&test_sample)
            .expect("ITE estimation should succeed");

        assert!(ite.is_finite());
    }

    /// Fitting an empty matrix must be rejected.
    #[test]
    fn test_empty_input_error() {
        let causal = CausalKernel::with_components(20);
        let empty_data: Array2<Float> = Array2::zeros((0, 0));

        assert!(causal.fit(&empty_data, &()).is_err());
    }

    /// Fewer than 3 columns (no room for treatment + outcome) must be rejected.
    #[test]
    fn test_insufficient_columns_error() {
        let causal = CausalKernel::with_components(20);
        let data = array![[1.0, 2.0]];

        assert!(causal.fit(&data, &()).is_err());
    }

    /// Every configured method fits and transforms without error.
    #[test]
    fn test_different_causal_methods() {
        let methods = vec![
            CausalMethod::TreatmentEffect,
            CausalMethod::ConditionalTreatmentEffect,
            CausalMethod::InstrumentalVariable,
        ];

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        for method in methods {
            let causal = CausalKernel::with_components(20).method(method);
            let fitted = causal.fit(&data, &()).expect("fit should succeed");
            let features = fitted.transform(&data).expect("transform should succeed");

            assert_eq!(features.nrows(), 4);
        }
    }
}