sklears_kernel_approximation/causal_kernels.rs

//! Causal Inference Kernel Methods
//!
//! This module implements kernel methods for causal inference from observational
//! data, including propensity score estimation, inverse-propensity-weighted
//! treatment effect estimation, and counterfactual reasoning for individual
//! treatment effects.
//!
//! # References
//! - Pearl (2009): "Causality: Models, Reasoning and Inference"
//! - Schölkopf et al. (2021): "Toward Causal Representation Learning"
//! - Peters et al. (2017): "Elements of Causal Inference"
//! - Zhang et al. (2011): "Kernel-based Conditional Independence Test and Application in Causal Discovery"
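//!
//! # Data Layout
//!
//! The estimators in this module expect an input matrix of shape
//! `(n_samples, n_covariates + 2)` whose last two columns hold the (binary)
//! treatment indicator and the observed outcome, in that order.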

use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::essentials::Normal;
use scirs2_core::random::thread_rng;
use serde::{Deserialize, Serialize};
use sklears_core::{
    error::{Result, SklearsError},
    prelude::{Fit, Transform},
    traits::{Estimator, Trained, Untrained},
    types::Float,
};
use std::collections::HashMap;
use std::marker::PhantomData;

/// Configuration for causal kernel methods
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalKernelConfig {
    /// Type of causal analysis
    pub causal_method: CausalMethod,
    /// Kernel bandwidth for treatment variables
    pub treatment_bandwidth: Float,
    /// Kernel bandwidth for outcome variables
    pub outcome_bandwidth: Float,
    /// Number of random features
    pub n_components: usize,
    /// Regularization parameter
    pub regularization: Float,
}

impl Default for CausalKernelConfig {
    fn default() -> Self {
        Self {
            causal_method: CausalMethod::TreatmentEffect,
            treatment_bandwidth: 1.0,
            outcome_bandwidth: 1.0,
            n_components: 100,
            regularization: 1e-5,
        }
    }
}

/// Types of causal analysis methods
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum CausalMethod {
    /// Average Treatment Effect (ATE) estimation
    TreatmentEffect,
    /// Conditional Average Treatment Effect (CATE)
    ConditionalTreatmentEffect,
    /// Instrumental variable method
    InstrumentalVariable,
    /// Regression discontinuity
    RegressionDiscontinuity,
    /// Difference-in-differences
    DifferenceInDifferences,
}

/// Causal Kernel for Treatment Effect Estimation
///
/// Implements kernel-based methods for estimating causal effects from
/// observational data, using kernel propensity score estimation and inverse
/// propensity weighting.
///
/// # Mathematical Background
///
/// For treatment T, outcome Y, and covariates X:
/// - ATE = E[Y(1) - Y(0)] = E[Y|do(T=1)] - E[Y|do(T=0)]
/// - Uses the conditional-effect representation: τ(x) = E[Y|T=1,X=x] - E[Y|T=0,X=x]
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::causal_kernels::{CausalKernel, CausalKernelConfig};
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Transform};
///
/// let config = CausalKernelConfig::default();
/// let causal = CausalKernel::new(config);
///
/// // X: covariates, T: treatment, Y: outcome
/// let data = array![[1.0, 0.0, 2.0], [2.0, 1.0, 5.0]];
/// let fitted = causal.fit(&data, &()).unwrap();
/// let features = fitted.transform(&data).unwrap();
/// ```
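///
/// Once fitted, the estimated quantities can be read off directly (illustrative
/// sketch; real analyses need far more samples than shown above):
///
/// ```rust,ignore
/// let ate = fitted.ate();
/// let scores = fitted.propensity_scores();
/// println!("estimated ATE: {ate}, propensity scores: {scores:?}");
/// ```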
#[derive(Debug, Clone)]
pub struct CausalKernel<State = Untrained> {
    config: CausalKernelConfig,

    // Fitted attributes
    treatment_weights: Option<Array2<Float>>,
    outcome_weights: Option<Array2<Float>>,
    propensity_scores: Option<Array1<Float>>,
    treatment_effects: Option<HashMap<String, Float>>,

    _state: PhantomData<State>,
}

impl CausalKernel<Untrained> {
    /// Create a new causal kernel
    pub fn new(config: CausalKernelConfig) -> Self {
        Self {
            config,
            treatment_weights: None,
            outcome_weights: None,
            propensity_scores: None,
            treatment_effects: None,
            _state: PhantomData,
        }
    }

    /// Create with the given number of random features and default settings otherwise
    pub fn with_components(n_components: usize) -> Self {
        Self {
            config: CausalKernelConfig {
                n_components,
                ..Default::default()
            },
            treatment_weights: None,
            outcome_weights: None,
            propensity_scores: None,
            treatment_effects: None,
            _state: PhantomData,
        }
    }

    /// Set the causal method
    pub fn method(mut self, method: CausalMethod) -> Self {
        self.config.causal_method = method;
        self
    }

    /// Set the treatment kernel bandwidth
    pub fn treatment_bandwidth(mut self, bandwidth: Float) -> Self {
        self.config.treatment_bandwidth = bandwidth;
        self
    }

    /// Estimate propensity scores (probability of treatment given covariates)
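    ///
    /// Uses a Nadaraya-Watson style kernel estimator over the covariates:
    /// e(x_i) = sum_j K(x_i, x_j) t_j / sum_j K(x_i, x_j), where
    /// K(x_i, x_j) = exp(-||x_i - x_j||^2 / (2 h^2)) and h is the treatment bandwidth.
    /// Scores are clipped to [0.01, 0.99] so the downstream IPW weights stay bounded.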
    fn estimate_propensity_scores(
        &self,
        x: &Array2<Float>,
        treatment: &Array1<Float>,
    ) -> Array1<Float> {
        let n_samples = x.nrows();
        let mut scores = Array1::zeros(n_samples);

        // Nadaraya-Watson kernel regression of treatment on covariates
        for i in 0..n_samples {
            let mut score = 0.0;
            let mut weight_sum = 0.0;

            for j in 0..n_samples {
                // Compute kernel similarity
                let mut dist_sq = 0.0;
                for k in 0..x.ncols() {
                    let diff = x[[i, k]] - x[[j, k]];
                    dist_sq += diff * diff;
                }

                let weight = (-dist_sq / (2.0 * self.config.treatment_bandwidth.powi(2))).exp();
                score += weight * treatment[j];
                weight_sum += weight;
            }

            scores[i] = if weight_sum > 1e-10 {
                (score / weight_sum).clamp(0.01, 0.99) // Clip for stability
            } else {
                0.5
            };
        }

        scores
    }

    /// Estimate treatment effect using inverse propensity weighting
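    ///
    /// Computes the self-normalized (Hájek) IPW estimate
    /// ATE = sum_{i: t_i=1} w_i y_i / sum_{i: t_i=1} w_i - sum_{i: t_i=0} v_i y_i / sum_{i: t_i=0} v_i,
    /// where w_i = 1 / e(x_i), v_i = 1 / (1 - e(x_i)), and e(x) is the estimated
    /// propensity score, alongside the naive treated-vs-control mean difference.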
    fn estimate_treatment_effect(
        &self,
        x: &Array2<Float>,
        treatment: &Array1<Float>,
        outcome: &Array1<Float>,
        propensity_scores: &Array1<Float>,
    ) -> HashMap<String, Float> {
        let n_samples = x.nrows() as Float;

        // Average Treatment Effect (ATE) using IPW
        let mut ate_numerator_treated = 0.0;
        let mut ate_numerator_control = 0.0;
        let mut weight_sum_treated = 0.0;
        let mut weight_sum_control = 0.0;

        for i in 0..treatment.len() {
            if treatment[i] > 0.5 {
                // Treated group
                let weight = 1.0 / propensity_scores[i];
                ate_numerator_treated += weight * outcome[i];
                weight_sum_treated += weight;
            } else {
                // Control group
                let weight = 1.0 / (1.0 - propensity_scores[i]);
                ate_numerator_control += weight * outcome[i];
                weight_sum_control += weight;
            }
        }

        let ate = if weight_sum_treated > 0.0 && weight_sum_control > 0.0 {
            (ate_numerator_treated / weight_sum_treated)
                - (ate_numerator_control / weight_sum_control)
        } else {
            0.0
        };

        // Naive difference (for comparison)
        let treated_outcomes: Vec<Float> = treatment
            .iter()
            .zip(outcome.iter())
            .filter_map(|(&t, &y)| if t > 0.5 { Some(y) } else { None })
            .collect();

        let control_outcomes: Vec<Float> = treatment
            .iter()
            .zip(outcome.iter())
            .filter_map(|(&t, &y)| if t <= 0.5 { Some(y) } else { None })
            .collect();

        let naive_diff = if !treated_outcomes.is_empty() && !control_outcomes.is_empty() {
            let treated_mean =
                treated_outcomes.iter().sum::<Float>() / treated_outcomes.len() as Float;
            let control_mean =
                control_outcomes.iter().sum::<Float>() / control_outcomes.len() as Float;
            treated_mean - control_mean
        } else {
            0.0
        };

        let mut effects = HashMap::new();
        effects.insert("ate".to_string(), ate);
        effects.insert("naive_difference".to_string(), naive_diff);
        effects.insert("n_samples".to_string(), n_samples);
        effects.insert("n_treated".to_string(), treated_outcomes.len() as Float);
        effects.insert("n_control".to_string(), control_outcomes.len() as Float);

        effects
    }
}

impl Estimator for CausalKernel<Untrained> {
    type Config = CausalKernelConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, ()> for CausalKernel<Untrained> {
    type Fitted = CausalKernel<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        if x.nrows() < 2 || x.ncols() < 3 {
            return Err(SklearsError::InvalidInput(
                "Input must have at least 2 samples and 3 columns (covariates, treatment, outcome)"
                    .to_string(),
            ));
        }

        // Assume last 2 columns are treatment and outcome
        let n_covariates = x.ncols() - 2;
        let covariates = x.slice_axis(Axis(1), (0..n_covariates).into()).to_owned();
        let treatment = x.column(n_covariates).to_owned();
        let outcome = x.column(n_covariates + 1).to_owned();

        // Estimate propensity scores
        let propensity_scores = self.estimate_propensity_scores(&covariates, &treatment);

        // Estimate treatment effects
        let treatment_effects =
            self.estimate_treatment_effect(&covariates, &treatment, &outcome, &propensity_scores);

        // Draw random projection weights for the kernel feature map
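        // (Simplified feature map: cosine projections only, without the random
        // phase offset used in standard random Fourier features.)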
        let mut rng = thread_rng();
        let normal = Normal::new(0.0, 1.0).unwrap();

        let treatment_weights =
            Array2::from_shape_fn((n_covariates, self.config.n_components), |_| {
                rng.sample(normal) * (2.0 * self.config.treatment_bandwidth).sqrt()
            });

        let outcome_weights =
            Array2::from_shape_fn((n_covariates, self.config.n_components), |_| {
                rng.sample(normal) * (2.0 * self.config.outcome_bandwidth).sqrt()
            });

        Ok(CausalKernel {
            config: self.config,
            treatment_weights: Some(treatment_weights),
            outcome_weights: Some(outcome_weights),
            propensity_scores: Some(propensity_scores),
            treatment_effects: Some(treatment_effects),
            _state: PhantomData,
        })
    }
}

impl Transform<Array2<Float>, Array2<Float>> for CausalKernel<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let treatment_weights = self.treatment_weights.as_ref().unwrap();
        let outcome_weights = self.outcome_weights.as_ref().unwrap();

        // Assume input has same structure as training data
        let n_covariates = treatment_weights.nrows();

        if x.ncols() < n_covariates {
            return Err(SklearsError::InvalidInput(format!(
                "Input must have at least {} columns",
                n_covariates
            )));
        }

        let covariates = x.slice_axis(Axis(1), (0..n_covariates).into());

        // Compute treatment and outcome features
        let treatment_projection = covariates.dot(treatment_weights);
        let outcome_projection = covariates.dot(outcome_weights);

        let n_samples = x.nrows();
        let n_features = self.config.n_components * 2;
        let mut output = Array2::zeros((n_samples, n_features));

        let normalizer = (2.0 / self.config.n_components as Float).sqrt();

        for i in 0..n_samples {
            for j in 0..self.config.n_components {
                // Treatment features
                output[[i, j]] = normalizer * treatment_projection[[i, j]].cos();
                // Outcome features
                output[[i, j + self.config.n_components]] =
                    normalizer * outcome_projection[[i, j]].cos();
            }
        }

        Ok(output)
    }
}

impl CausalKernel<Trained> {
    /// Get estimated propensity scores
    pub fn propensity_scores(&self) -> &Array1<Float> {
        self.propensity_scores.as_ref().unwrap()
    }

    /// Get estimated treatment effects
    pub fn treatment_effects(&self) -> &HashMap<String, Float> {
        self.treatment_effects.as_ref().unwrap()
    }

    /// Get average treatment effect
    pub fn ate(&self) -> Float {
        self.treatment_effects
            .as_ref()
            .unwrap()
            .get("ate")
            .copied()
            .unwrap_or(0.0)
    }
}

/// Counterfactual Kernel Approximation
///
/// Implements kernel methods for counterfactual reasoning: "What would have
/// happened if the treatment had been different?"
///
/// # Mathematical Background
///
/// Estimates the counterfactual outcome Y(t) for a unit observed with
/// X = x, T = t', and Y = y. Uses kernel-weighted nearest-neighbor matching
/// on the covariates; propensity scores are estimated during fitting and
/// exposed for diagnostics.
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::causal_kernels::{CounterfactualKernel, CausalKernelConfig};
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Transform};
///
/// let config = CausalKernelConfig::default();
/// let cf = CounterfactualKernel::new(config);
///
/// let data = array![[1.0, 0.0, 2.0], [2.0, 1.0, 5.0]];
/// let fitted = cf.fit(&data, &()).unwrap();
/// ```
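///
/// Individual treatment effects for new samples can then be estimated from the
/// matched counterfactual outcomes (illustrative; the sample must supply the
/// same covariate columns as the training data):
///
/// ```rust,ignore
/// let ite = fitted.estimate_ite(&array![[1.5]]).unwrap();
/// ```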
#[derive(Debug, Clone)]
pub struct CounterfactualKernel<State = Untrained> {
    config: CausalKernelConfig,

    // Fitted attributes
    training_data: Option<Array2<Float>>,
    kernel_features: Option<Array2<Float>>,
    propensity_scores: Option<Array1<Float>>,

    _state: PhantomData<State>,
}

impl CounterfactualKernel<Untrained> {
    /// Create a new counterfactual kernel
    pub fn new(config: CausalKernelConfig) -> Self {
        Self {
            config,
            training_data: None,
            kernel_features: None,
            propensity_scores: None,
            _state: PhantomData,
        }
    }

    /// Create with the given number of random features and default settings otherwise
    pub fn with_components(n_components: usize) -> Self {
        Self::new(CausalKernelConfig {
            n_components,
            ..Default::default()
        })
    }
}

impl Estimator for CounterfactualKernel<Untrained> {
    type Config = CausalKernelConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Fit<Array2<Float>, ()> for CounterfactualKernel<Untrained> {
    type Fitted = CounterfactualKernel<Trained>;

    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
        if x.nrows() < 2 || x.ncols() < 3 {
            return Err(SklearsError::InvalidInput(
                "Input must have at least 2 samples and 3 columns".to_string(),
            ));
        }

        let training_data = x.clone();

        // Extract covariates and treatment
        let n_covariates = x.ncols() - 2;
        let covariates = x.slice_axis(Axis(1), (0..n_covariates).into()).to_owned();
        let treatment = x.column(n_covariates).to_owned();

        // Estimate propensity scores via Nadaraya-Watson kernel regression
        let n_samples = x.nrows();
        let mut propensity_scores = Array1::zeros(n_samples);

        for i in 0..n_samples {
            let mut score = 0.0;
            let mut weight_sum = 0.0;

            for j in 0..n_samples {
                let mut dist_sq = 0.0;
                for k in 0..n_covariates {
                    let diff = covariates[[i, k]] - covariates[[j, k]];
                    dist_sq += diff * diff;
                }

                let weight = (-dist_sq / (2.0 * self.config.treatment_bandwidth.powi(2))).exp();
                score += weight * treatment[j];
                weight_sum += weight;
            }

            propensity_scores[i] = if weight_sum > 1e-10 {
                (score / weight_sum).clamp(0.01, 0.99)
            } else {
                0.5
            };
        }

        // Generate kernel features for matching
        let mut rng = thread_rng();
        let normal = Normal::new(0.0, 1.0).unwrap();

        let random_weights =
            Array2::from_shape_fn((n_covariates, self.config.n_components), |_| {
                rng.sample(normal) * (2.0 * self.config.treatment_bandwidth).sqrt()
            });

        let projection = covariates.dot(&random_weights);
        let mut kernel_features = Array2::zeros((n_samples, self.config.n_components));

        for i in 0..n_samples {
            for j in 0..self.config.n_components {
                kernel_features[[i, j]] = projection[[i, j]].cos();
            }
        }

        Ok(CounterfactualKernel {
            config: self.config,
            training_data: Some(training_data),
            kernel_features: Some(kernel_features),
            propensity_scores: Some(propensity_scores),
            _state: PhantomData,
        })
    }
}

impl Transform<Array2<Float>, Array2<Float>> for CounterfactualKernel<Trained> {
    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
        let training_data = self.training_data.as_ref().unwrap();
        let kernel_features = self.kernel_features.as_ref().unwrap();

        let n_covariates = training_data.ncols() - 2;

        if x.ncols() < n_covariates {
            return Err(SklearsError::InvalidInput(format!(
                "Input must have at least {} columns",
                n_covariates
            )));
        }

        // For each test sample, find its nearest training neighbors in covariate space
        // and form kernel-weighted counterfactual outcome estimates.
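        // Output layout: column 0 holds the matched treated-outcome estimate, column 1
        // the matched control-outcome estimate, and the remaining columns copy the
        // kernel features of the nearest training neighbor.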
        let n_samples = x.nrows();
        let mut output = Array2::zeros((n_samples, self.config.n_components + 2));

        for i in 0..n_samples {
            // Find the k nearest neighbors (k = 5, or fewer if the training set is smaller)
            let k = 5.min(kernel_features.nrows());
            let mut distances = Vec::new();

            for j in 0..kernel_features.nrows() {
                let mut dist = 0.0;
                for l in 0..n_covariates {
                    let diff = x[[i, l]] - training_data[[j, l]];
                    dist += diff * diff;
                }
                distances.push((dist, j));
            }

            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

            // Compute weighted average of outcomes from nearest neighbors
            let mut treated_outcome = 0.0;
            let mut control_outcome = 0.0;
            let mut treated_weight = 0.0;
            let mut control_weight = 0.0;

            for &(dist, idx) in distances.iter().take(k) {
                let weight = (-dist / self.config.treatment_bandwidth).exp();
                let treatment_val = training_data[[idx, n_covariates]];
                let outcome_val = training_data[[idx, n_covariates + 1]];

                if treatment_val > 0.5 {
                    treated_outcome += weight * outcome_val;
                    treated_weight += weight;
                } else {
                    control_outcome += weight * outcome_val;
                    control_weight += weight;
                }
            }

            // Store counterfactual estimates
            output[[i, 0]] = if treated_weight > 0.0 {
                treated_outcome / treated_weight
            } else {
                0.0
            };

            output[[i, 1]] = if control_weight > 0.0 {
                control_outcome / control_weight
            } else {
                0.0
            };

            // Copy the kernel features of the nearest training neighbor
            for j in 0..self.config.n_components {
                if j < kernel_features.ncols() {
                    output[[i, j + 2]] = kernel_features[[distances[0].1, j]];
                }
            }
        }

        Ok(output)
    }
}

impl CounterfactualKernel<Trained> {
    /// Get propensity scores from training data
    pub fn propensity_scores(&self) -> &Array1<Float> {
        self.propensity_scores.as_ref().unwrap()
    }

    /// Estimate individual treatment effect for a sample
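    ///
    /// Returns the estimated Y(1) - Y(0) difference for the first row of `sample`.
    /// The sample must contain at least the covariate columns, in the same order
    /// as the training data; any extra columns are ignored.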
    pub fn estimate_ite(&self, sample: &Array2<Float>) -> Result<Float> {
        let counterfactuals = self.transform(sample)?;

        if counterfactuals.nrows() > 0 {
            // ITE = E[Y(1)] - E[Y(0)]
            Ok(counterfactuals[[0, 0]] - counterfactuals[[0, 1]])
        } else {
            Ok(0.0)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_causal_kernel_basic() {
        let config = CausalKernelConfig {
            n_components: 20,
            treatment_bandwidth: 1.0,
            outcome_bandwidth: 1.0,
            ..Default::default()
        };

        let causal = CausalKernel::new(config);

        // Data: [covariate1, covariate2, treatment, outcome]
        let data = array![
            [1.0, 2.0, 0.0, 1.0],
            [2.0, 3.0, 1.0, 5.0],
            [1.5, 2.5, 0.0, 2.0],
            [2.5, 3.5, 1.0, 6.0],
        ];

        let fitted = causal.fit(&data, &()).unwrap();
        let features = fitted.transform(&data).unwrap();

        assert_eq!(features.nrows(), 4);
        assert_eq!(features.ncols(), 40); // 2 * n_components
    }

    #[test]
    fn test_propensity_score_estimation() {
        let config = CausalKernelConfig::default();
        let causal = CausalKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = causal.fit(&data, &()).unwrap();
        let scores = fitted.propensity_scores();

        // Propensity scores should be between 0 and 1
        assert!(scores.iter().all(|&s| s >= 0.0 && s <= 1.0));
    }

    #[test]
    fn test_treatment_effect_estimation() {
        let config = CausalKernelConfig::default();
        let causal = CausalKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = causal.fit(&data, &()).unwrap();
        let effects = fitted.treatment_effects();

        assert!(effects.contains_key("ate"));
        assert!(effects.contains_key("naive_difference"));
        assert!(effects["ate"].is_finite());
    }

    #[test]
    fn test_counterfactual_kernel() {
        let config = CausalKernelConfig {
            n_components: 10,
            ..Default::default()
        };

        let cf = CounterfactualKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = cf.fit(&data, &()).unwrap();
        let test_data = array![[1.2], [2.3]];
        let counterfactuals = fitted.transform(&test_data).unwrap();

        assert_eq!(counterfactuals.nrows(), 2);
        // First 2 columns are treated and control outcomes, rest are features
        assert_eq!(counterfactuals.ncols(), 12);
    }

    #[test]
    fn test_individual_treatment_effect() {
        let config = CausalKernelConfig {
            n_components: 10,
            ..Default::default()
        };

        let cf = CounterfactualKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = cf.fit(&data, &()).unwrap();
        let test_sample = array![[1.5]];
        let ite = fitted.estimate_ite(&test_sample).unwrap();

        assert!(ite.is_finite());
    }
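
    // Illustrative extra check: after fitting, the counterfactual kernel exposes
    // propensity scores clipped to the stable range used internally.
    #[test]
    fn test_counterfactual_propensity_scores() {
        let cf = CounterfactualKernel::with_components(10);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = cf.fit(&data, &()).unwrap();
        let scores = fitted.propensity_scores();

        assert_eq!(scores.len(), 4);
        assert!(scores.iter().all(|&s| (0.01..=0.99).contains(&s)));
    }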

    #[test]
    fn test_empty_input_error() {
        let causal = CausalKernel::with_components(20);
        let empty_data: Array2<Float> = Array2::zeros((0, 0));

        assert!(causal.fit(&empty_data, &()).is_err());
    }

    #[test]
    fn test_insufficient_columns_error() {
        let causal = CausalKernel::with_components(20);
        let data = array![[1.0, 2.0]]; // Only 2 columns, need at least 3

        assert!(causal.fit(&data, &()).is_err());
    }

    #[test]
    fn test_different_causal_methods() {
        let methods = vec![
            CausalMethod::TreatmentEffect,
            CausalMethod::ConditionalTreatmentEffect,
            CausalMethod::InstrumentalVariable,
        ];

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        for method in methods {
            let causal = CausalKernel::with_components(20).method(method);
            let fitted = causal.fit(&data, &()).unwrap();
            let features = fitted.transform(&data).unwrap();

            assert_eq!(features.nrows(), 4);
        }
    }
}
773}