1use scirs2_core::ndarray::{Array1, Array2, Axis};
14use scirs2_core::random::essentials::Normal;
15use scirs2_core::random::thread_rng;
16use serde::{Deserialize, Serialize};
17use sklears_core::{
18 error::{Result, SklearsError},
19 prelude::{Fit, Transform},
20 traits::{Estimator, Trained, Untrained},
21 types::Float,
22};
23use std::collections::HashMap;
24use std::marker::PhantomData;
25
/// Configuration shared by the causal kernel transformers in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalKernelConfig {
    /// Which causal identification strategy to use.
    /// NOTE(review): the `fit` implementations visible in this file never
    /// branch on this field — confirm whether non-IPW methods are implemented
    /// elsewhere before relying on this selector.
    pub causal_method: CausalMethod,
    /// RBF bandwidth (sigma) used for propensity-score smoothing and the
    /// treatment feature map.
    pub treatment_bandwidth: Float,
    /// RBF bandwidth (sigma) used for the outcome feature map.
    pub outcome_bandwidth: Float,
    /// Number of random Fourier components per feature map.
    pub n_components: usize,
    /// Regularization strength.
    /// NOTE(review): unused in the code visible here — verify callers.
    pub regularization: Float,
}
40
impl Default for CausalKernelConfig {
    /// Defaults: IPW treatment-effect estimation, unit bandwidths,
    /// 100 random components, and a small (1e-5) regularizer.
    fn default() -> Self {
        Self {
            causal_method: CausalMethod::TreatmentEffect,
            treatment_bandwidth: 1.0,
            outcome_bandwidth: 1.0,
            n_components: 100,
            regularization: 1e-5,
        }
    }
}
52
/// Causal identification strategy selector.
///
/// NOTE(review): only the inverse-propensity-weighted average treatment
/// effect is computed by the code visible in this file; the remaining
/// variants appear to be placeholders — confirm before relying on them.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum CausalMethod {
    /// Average treatment effect via inverse propensity weighting.
    TreatmentEffect,
    /// Treatment effect conditioned on covariates.
    ConditionalTreatmentEffect,
    /// Instrumental-variable estimation.
    InstrumentalVariable,
    /// Regression-discontinuity design.
    RegressionDiscontinuity,
    /// Difference-in-differences design.
    DifferenceInDifferences,
}
67
/// Kernel-based causal feature transformer with typestate tracking
/// (`Untrained` -> `Trained` via `Fit::fit`).
#[derive(Debug, Clone)]
pub struct CausalKernel<State = Untrained> {
    config: CausalKernelConfig,

    // All Option fields below are None until `fit` succeeds.
    /// Random projection weights for the treatment feature map.
    treatment_weights: Option<Array2<Float>>,
    /// Random projection weights for the outcome feature map.
    outcome_weights: Option<Array2<Float>>,
    /// Estimated P(T = 1 | x) for each training sample.
    propensity_scores: Option<Array1<Float>>,
    /// Named effect estimates ("ate", "naive_difference", ...).
    treatment_effects: Option<HashMap<String, Float>>,

    /// Zero-sized typestate marker.
    _state: PhantomData<State>,
}
107
108impl CausalKernel<Untrained> {
109 pub fn new(config: CausalKernelConfig) -> Self {
111 Self {
112 config,
113 treatment_weights: None,
114 outcome_weights: None,
115 propensity_scores: None,
116 treatment_effects: None,
117 _state: PhantomData,
118 }
119 }
120
121 pub fn with_components(n_components: usize) -> Self {
123 Self {
124 config: CausalKernelConfig {
125 n_components,
126 ..Default::default()
127 },
128 treatment_weights: None,
129 outcome_weights: None,
130 propensity_scores: None,
131 treatment_effects: None,
132 _state: PhantomData,
133 }
134 }
135
136 pub fn method(mut self, method: CausalMethod) -> Self {
138 self.config.causal_method = method;
139 self
140 }
141
142 pub fn treatment_bandwidth(mut self, gamma: Float) -> Self {
144 self.config.treatment_bandwidth = gamma;
145 self
146 }
147
148 fn estimate_propensity_scores(
150 &self,
151 x: &Array2<Float>,
152 treatment: &Array1<Float>,
153 ) -> Array1<Float> {
154 let n_samples = x.nrows();
155 let mut scores = Array1::zeros(n_samples);
156
157 for i in 0..n_samples {
159 let mut score = 0.0;
160 let mut weight_sum = 0.0;
161
162 for j in 0..n_samples {
163 let mut dist_sq = 0.0;
165 for k in 0..x.ncols() {
166 let diff = x[[i, k]] - x[[j, k]];
167 dist_sq += diff * diff;
168 }
169
170 let weight = (-dist_sq / (2.0 * self.config.treatment_bandwidth.powi(2))).exp();
171 score += weight * treatment[j];
172 weight_sum += weight;
173 }
174
175 scores[i] = if weight_sum > 1e-10 {
176 (score / weight_sum).max(0.01).min(0.99) } else {
178 0.5
179 };
180 }
181
182 scores
183 }
184
185 fn estimate_treatment_effect(
187 &self,
188 x: &Array2<Float>,
189 treatment: &Array1<Float>,
190 outcome: &Array1<Float>,
191 propensity_scores: &Array1<Float>,
192 ) -> HashMap<String, Float> {
193 let n_samples = x.nrows() as Float;
194
195 let mut ate_numerator_treated = 0.0;
197 let mut ate_numerator_control = 0.0;
198 let mut weight_sum_treated = 0.0;
199 let mut weight_sum_control = 0.0;
200
201 for i in 0..treatment.len() {
202 if treatment[i] > 0.5 {
203 let weight = 1.0 / propensity_scores[i];
205 ate_numerator_treated += weight * outcome[i];
206 weight_sum_treated += weight;
207 } else {
208 let weight = 1.0 / (1.0 - propensity_scores[i]);
210 ate_numerator_control += weight * outcome[i];
211 weight_sum_control += weight;
212 }
213 }
214
215 let ate = if weight_sum_treated > 0.0 && weight_sum_control > 0.0 {
216 (ate_numerator_treated / weight_sum_treated)
217 - (ate_numerator_control / weight_sum_control)
218 } else {
219 0.0
220 };
221
222 let treated_outcomes: Vec<Float> = treatment
224 .iter()
225 .zip(outcome.iter())
226 .filter_map(|(&t, &y)| if t > 0.5 { Some(y) } else { None })
227 .collect();
228
229 let control_outcomes: Vec<Float> = treatment
230 .iter()
231 .zip(outcome.iter())
232 .filter_map(|(&t, &y)| if t <= 0.5 { Some(y) } else { None })
233 .collect();
234
235 let naive_diff = if !treated_outcomes.is_empty() && !control_outcomes.is_empty() {
236 let treated_mean =
237 treated_outcomes.iter().sum::<Float>() / treated_outcomes.len() as Float;
238 let control_mean =
239 control_outcomes.iter().sum::<Float>() / control_outcomes.len() as Float;
240 treated_mean - control_mean
241 } else {
242 0.0
243 };
244
245 let mut effects = HashMap::new();
246 effects.insert("ate".to_string(), ate);
247 effects.insert("naive_difference".to_string(), naive_diff);
248 effects.insert("n_samples".to_string(), n_samples);
249 effects.insert("n_treated".to_string(), treated_outcomes.len() as Float);
250 effects.insert("n_control".to_string(), control_outcomes.len() as Float);
251
252 effects
253 }
254}
255
impl Estimator for CausalKernel<Untrained> {
    type Config = CausalKernelConfig;
    type Error = SklearsError;
    type Float = Float;

    /// Read-only access to the configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
265
266impl Fit<Array2<Float>, ()> for CausalKernel<Untrained> {
267 type Fitted = CausalKernel<Trained>;
268
269 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
270 if x.nrows() < 2 || x.ncols() < 3 {
271 return Err(SklearsError::InvalidInput(
272 "Input must have at least 2 samples and 3 columns (covariates, treatment, outcome)"
273 .to_string(),
274 ));
275 }
276
277 let n_covariates = x.ncols() - 2;
279 let covariates = x.slice_axis(Axis(1), (0..n_covariates).into()).to_owned();
280 let treatment = x.column(n_covariates).to_owned();
281 let outcome = x.column(n_covariates + 1).to_owned();
282
283 let propensity_scores = self.estimate_propensity_scores(&covariates, &treatment);
285
286 let treatment_effects =
288 self.estimate_treatment_effect(&covariates, &treatment, &outcome, &propensity_scores);
289
290 let mut rng = thread_rng();
292 let normal = Normal::new(0.0, 1.0).unwrap();
293
294 let treatment_weights =
295 Array2::from_shape_fn((n_covariates, self.config.n_components), |_| {
296 rng.sample(normal) * (2.0 * self.config.treatment_bandwidth).sqrt()
297 });
298
299 let outcome_weights =
300 Array2::from_shape_fn((n_covariates, self.config.n_components), |_| {
301 rng.sample(normal) * (2.0 * self.config.outcome_bandwidth).sqrt()
302 });
303
304 Ok(CausalKernel {
305 config: self.config,
306 treatment_weights: Some(treatment_weights),
307 outcome_weights: Some(outcome_weights),
308 propensity_scores: Some(propensity_scores),
309 treatment_effects: Some(treatment_effects),
310 _state: PhantomData,
311 })
312 }
313}
314
315impl Transform<Array2<Float>, Array2<Float>> for CausalKernel<Trained> {
316 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
317 let treatment_weights = self.treatment_weights.as_ref().unwrap();
318 let outcome_weights = self.outcome_weights.as_ref().unwrap();
319
320 let n_covariates = treatment_weights.nrows();
322
323 if x.ncols() < n_covariates {
324 return Err(SklearsError::InvalidInput(format!(
325 "Input must have at least {} columns",
326 n_covariates
327 )));
328 }
329
330 let covariates = x.slice_axis(Axis(1), (0..n_covariates).into());
331
332 let treatment_projection = covariates.dot(treatment_weights);
334 let outcome_projection = covariates.dot(outcome_weights);
335
336 let n_samples = x.nrows();
337 let n_features = self.config.n_components * 2;
338 let mut output = Array2::zeros((n_samples, n_features));
339
340 let normalizer = (2.0 / self.config.n_components as Float).sqrt();
341
342 for i in 0..n_samples {
343 for j in 0..self.config.n_components {
344 output[[i, j]] = normalizer * treatment_projection[[i, j]].cos();
346 output[[i, j + self.config.n_components]] =
348 normalizer * outcome_projection[[i, j]].cos();
349 }
350 }
351
352 Ok(output)
353 }
354}
355
356impl CausalKernel<Trained> {
357 pub fn propensity_scores(&self) -> &Array1<Float> {
359 self.propensity_scores.as_ref().unwrap()
360 }
361
362 pub fn treatment_effects(&self) -> &HashMap<String, Float> {
364 self.treatment_effects.as_ref().unwrap()
365 }
366
367 pub fn ate(&self) -> Float {
369 self.treatment_effects
370 .as_ref()
371 .unwrap()
372 .get("ate")
373 .copied()
374 .unwrap_or(0.0)
375 }
376}
377
/// Counterfactual-outcome transformer: imputes treated/control outcomes for
/// new samples from kernel-weighted nearest neighbors in the training data.
/// Uses the same typestate pattern as `CausalKernel`.
#[derive(Debug, Clone)]
pub struct CounterfactualKernel<State = Untrained> {
    config: CausalKernelConfig,

    // All Option fields below are None until `fit` succeeds.
    /// Full training matrix [covariates..., treatment, outcome], kept for
    /// neighbor lookups at transform time.
    training_data: Option<Array2<Float>>,
    /// Precomputed cosine random features of the training covariates.
    kernel_features: Option<Array2<Float>>,
    /// Estimated P(T = 1 | x) for each training sample.
    propensity_scores: Option<Array1<Float>>,

    /// Zero-sized typestate marker.
    _state: PhantomData<State>,
}
413
414impl CounterfactualKernel<Untrained> {
415 pub fn new(config: CausalKernelConfig) -> Self {
417 Self {
418 config,
419 training_data: None,
420 kernel_features: None,
421 propensity_scores: None,
422 _state: PhantomData,
423 }
424 }
425
426 pub fn with_components(n_components: usize) -> Self {
428 Self::new(CausalKernelConfig {
429 n_components,
430 ..Default::default()
431 })
432 }
433}
434
impl Estimator for CounterfactualKernel<Untrained> {
    type Config = CausalKernelConfig;
    type Error = SklearsError;
    type Float = Float;

    /// Read-only access to the configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
444
445impl Fit<Array2<Float>, ()> for CounterfactualKernel<Untrained> {
446 type Fitted = CounterfactualKernel<Trained>;
447
448 fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
449 if x.nrows() < 2 || x.ncols() < 3 {
450 return Err(SklearsError::InvalidInput(
451 "Input must have at least 2 samples and 3 columns".to_string(),
452 ));
453 }
454
455 let training_data = x.clone();
456
457 let n_covariates = x.ncols() - 2;
459 let covariates = x.slice_axis(Axis(1), (0..n_covariates).into()).to_owned();
460 let treatment = x.column(n_covariates).to_owned();
461
462 let n_samples = x.nrows();
464 let mut propensity_scores = Array1::zeros(n_samples);
465
466 for i in 0..n_samples {
467 let mut score = 0.0;
468 let mut weight_sum = 0.0;
469
470 for j in 0..n_samples {
471 let mut dist_sq = 0.0;
472 for k in 0..n_covariates {
473 let diff = covariates[[i, k]] - covariates[[j, k]];
474 dist_sq += diff * diff;
475 }
476
477 let weight = (-dist_sq / (2.0 * self.config.treatment_bandwidth.powi(2))).exp();
478 score += weight * treatment[j];
479 weight_sum += weight;
480 }
481
482 propensity_scores[i] = if weight_sum > 1e-10 {
483 (score / weight_sum).max(0.01).min(0.99)
484 } else {
485 0.5
486 };
487 }
488
489 let mut rng = thread_rng();
491 let normal = Normal::new(0.0, 1.0).unwrap();
492
493 let random_weights =
494 Array2::from_shape_fn((n_covariates, self.config.n_components), |_| {
495 rng.sample(normal) * (2.0 * self.config.treatment_bandwidth).sqrt()
496 });
497
498 let projection = covariates.dot(&random_weights);
499 let mut kernel_features = Array2::zeros((n_samples, self.config.n_components));
500
501 for i in 0..n_samples {
502 for j in 0..self.config.n_components {
503 kernel_features[[i, j]] = projection[[i, j]].cos();
504 }
505 }
506
507 Ok(CounterfactualKernel {
508 config: self.config,
509 training_data: Some(training_data),
510 kernel_features: Some(kernel_features),
511 propensity_scores: Some(propensity_scores),
512 _state: PhantomData,
513 })
514 }
515}
516
517impl Transform<Array2<Float>, Array2<Float>> for CounterfactualKernel<Trained> {
518 fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
519 let training_data = self.training_data.as_ref().unwrap();
520 let kernel_features = self.kernel_features.as_ref().unwrap();
521
522 let n_covariates = training_data.ncols() - 2;
523
524 if x.ncols() < n_covariates {
525 return Err(SklearsError::InvalidInput(format!(
526 "Input must have at least {} columns",
527 n_covariates
528 )));
529 }
530
531 let n_samples = x.nrows();
534 let mut output = Array2::zeros((n_samples, self.config.n_components + 2));
535
536 for i in 0..n_samples {
537 let k = 5.min(kernel_features.nrows());
539 let mut distances = Vec::new();
540
541 for j in 0..kernel_features.nrows() {
542 let mut dist = 0.0;
543 for l in 0..n_covariates {
544 let diff = x[[i, l]] - training_data[[j, l]];
545 dist += diff * diff;
546 }
547 distances.push((dist, j));
548 }
549
550 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
551
552 let mut treated_outcome = 0.0;
554 let mut control_outcome = 0.0;
555 let mut treated_weight = 0.0;
556 let mut control_weight = 0.0;
557
558 for &(dist, idx) in distances.iter().take(k) {
559 let weight = (-dist / self.config.treatment_bandwidth).exp();
560 let treatment_val = training_data[[idx, n_covariates]];
561 let outcome_val = training_data[[idx, n_covariates + 1]];
562
563 if treatment_val > 0.5 {
564 treated_outcome += weight * outcome_val;
565 treated_weight += weight;
566 } else {
567 control_outcome += weight * outcome_val;
568 control_weight += weight;
569 }
570 }
571
572 output[[i, 0]] = if treated_weight > 0.0 {
574 treated_outcome / treated_weight
575 } else {
576 0.0
577 };
578
579 output[[i, 1]] = if control_weight > 0.0 {
580 control_outcome / control_weight
581 } else {
582 0.0
583 };
584
585 for j in 0..self.config.n_components {
587 if j < kernel_features.ncols() {
588 output[[i, j + 2]] = kernel_features[[distances[0].1, j]];
589 }
590 }
591 }
592
593 Ok(output)
594 }
595}
596
597impl CounterfactualKernel<Trained> {
598 pub fn propensity_scores(&self) -> &Array1<Float> {
600 self.propensity_scores.as_ref().unwrap()
601 }
602
603 pub fn estimate_ite(&self, sample: &Array2<Float>) -> Result<Float> {
605 let counterfactuals = self.transform(sample)?;
606
607 if counterfactuals.nrows() > 0 {
608 Ok(counterfactuals[[0, 0]] - counterfactuals[[0, 1]])
610 } else {
611 Ok(0.0)
612 }
613 }
614}
615
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Smoke test: fit + transform on a tiny 4-sample dataset with
    // layout [cov1, cov2, treatment, outcome].
    #[test]
    fn test_causal_kernel_basic() {
        let config = CausalKernelConfig {
            n_components: 20,
            treatment_bandwidth: 1.0,
            outcome_bandwidth: 1.0,
            ..Default::default()
        };

        let causal = CausalKernel::new(config);

        let data = array![
            [1.0, 2.0, 0.0, 1.0],
            [2.0, 3.0, 1.0, 5.0],
            [1.5, 2.5, 0.0, 2.0],
            [2.5, 3.5, 1.0, 6.0],
        ];

        let fitted = causal.fit(&data, &()).unwrap();
        let features = fitted.transform(&data).unwrap();

        assert_eq!(features.nrows(), 4);
        // 2 * n_components: treatment + outcome feature halves.
        assert_eq!(features.ncols(), 40);
    }

    // Propensity scores must stay within [0, 1] (the implementation clamps
    // to [0.01, 0.99]).
    #[test]
    fn test_propensity_score_estimation() {
        let config = CausalKernelConfig::default();
        let causal = CausalKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = causal.fit(&data, &()).unwrap();
        let scores = fitted.propensity_scores();

        assert!(scores.iter().all(|&s| s >= 0.0 && s <= 1.0));
    }

    // The effects map must contain the expected keys with finite values.
    #[test]
    fn test_treatment_effect_estimation() {
        let config = CausalKernelConfig::default();
        let causal = CausalKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = causal.fit(&data, &()).unwrap();
        let effects = fitted.treatment_effects();

        assert!(effects.contains_key("ate"));
        assert!(effects.contains_key("naive_difference"));
        assert!(effects["ate"].is_finite());
    }

    // Counterfactual transform output: 2 imputed-outcome columns plus
    // n_components feature columns.
    #[test]
    fn test_counterfactual_kernel() {
        let config = CausalKernelConfig {
            n_components: 10,
            ..Default::default()
        };

        let cf = CounterfactualKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = cf.fit(&data, &()).unwrap();
        // Transform accepts covariate-only inputs (1 covariate here).
        let test_data = array![[1.2], [2.3]];
        let counterfactuals = fitted.transform(&test_data).unwrap();

        assert_eq!(counterfactuals.nrows(), 2);
        // n_components + 2 = 12 columns.
        assert_eq!(counterfactuals.ncols(), 12);
    }

    // ITE = imputed treated outcome minus imputed control outcome; only
    // finiteness is asserted since the estimate is data-dependent.
    #[test]
    fn test_individual_treatment_effect() {
        let config = CausalKernelConfig {
            n_components: 10,
            ..Default::default()
        };

        let cf = CounterfactualKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = cf.fit(&data, &()).unwrap();
        let test_sample = array![[1.5]];
        let ite = fitted.estimate_ite(&test_sample).unwrap();

        assert!(ite.is_finite());
    }

    // Empty input must be rejected by the fit-time validation.
    #[test]
    fn test_empty_input_error() {
        let causal = CausalKernel::with_components(20);
        let empty_data: Array2<Float> = Array2::zeros((0, 0));

        assert!(causal.fit(&empty_data, &()).is_err());
    }

    // Fewer than 3 columns cannot encode [covariate, treatment, outcome].
    #[test]
    fn test_insufficient_columns_error() {
        let causal = CausalKernel::with_components(20);
        let data = array![[1.0, 2.0]];
        assert!(causal.fit(&data, &()).is_err());
    }

    // All method variants should fit and transform without error.
    #[test]
    fn test_different_causal_methods() {
        let methods = vec![
            CausalMethod::TreatmentEffect,
            CausalMethod::ConditionalTreatmentEffect,
            CausalMethod::InstrumentalVariable,
        ];

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        for method in methods {
            let causal = CausalKernel::with_components(20).method(method);
            let fitted = causal.fit(&data, &()).unwrap();
            let features = fitted.transform(&data).unwrap();

            assert_eq!(features.nrows(), 4);
        }
    }
}