// sklears_kernel_approximation/causal_kernels.rs
1//! Causal Inference Kernel Methods
2//!
3//! This module implements kernel methods for causal inference, including treatment
4//! effect estimation, interventional distributions, counterfactual reasoning, and
5//! causal discovery from observational data.
6//!
7//! # References
8//! - Pearl (2009): "Causality: Models, Reasoning and Inference"
9//! - Schölkopf et al. (2021): "Toward Causal Representation Learning"
10//! - Peters et al. (2017): "Elements of Causal Inference"
11//! - Gretton et al. (2012): "Kernel-based conditional independence test"
12
13use scirs2_core::ndarray::{Array1, Array2, Axis};
14use scirs2_core::random::essentials::Normal;
15use scirs2_core::random::thread_rng;
16use serde::{Deserialize, Serialize};
17use sklears_core::{
18    error::{Result, SklearsError},
19    prelude::{Fit, Transform},
20    traits::{Estimator, Trained, Untrained},
21    types::Float,
22};
23use std::collections::HashMap;
24use std::marker::PhantomData;
25
/// Configuration for causal kernel methods.
///
/// Shared by [`CausalKernel`] and [`CounterfactualKernel`]; controls the
/// analysis variant, the Gaussian-kernel bandwidths, and the size of the
/// random feature map.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CausalKernelConfig {
    /// Type of causal analysis (see [`CausalMethod`]).
    pub causal_method: CausalMethod,
    /// Gaussian kernel bandwidth for treatment variables; used both for
    /// propensity-score smoothing and for scaling treatment-side random
    /// features.
    pub treatment_bandwidth: Float,
    /// Gaussian kernel bandwidth for outcome variables (outcome-side
    /// random features).
    pub outcome_bandwidth: Float,
    /// Number of random features per feature group.
    pub n_components: usize,
    /// Regularization parameter.
    /// NOTE(review): not referenced by any code in this module — confirm
    /// whether it is consumed elsewhere or reserved for future use.
    pub regularization: Float,
}
40
41impl Default for CausalKernelConfig {
42    fn default() -> Self {
43        Self {
44            causal_method: CausalMethod::TreatmentEffect,
45            treatment_bandwidth: 1.0,
46            outcome_bandwidth: 1.0,
47            n_components: 100,
48            regularization: 1e-5,
49        }
50    }
51}
52
/// Types of causal analysis methods.
///
/// Selected via [`CausalKernelConfig::causal_method`].
/// NOTE(review): no code in this module branches on the variant — confirm
/// whether downstream components consume it.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)]
pub enum CausalMethod {
    /// Average Treatment Effect (ATE) estimation
    TreatmentEffect,
    /// Conditional Average Treatment Effect (CATE)
    ConditionalTreatmentEffect,
    /// Instrumental variable method
    InstrumentalVariable,
    /// Regression discontinuity
    RegressionDiscontinuity,
    /// Difference-in-differences
    DifferenceInDifferences,
}
67
/// Causal Kernel for Treatment Effect Estimation
///
/// Implements kernel-based methods for estimating causal effects from
/// observational data, including propensity score weighting and doubly robust
/// estimation.
///
/// # Mathematical Background
///
/// For treatment T, outcome Y, and covariates X:
/// - ATE = E[Y(1) - Y(0)] = E[Y|do(T=1)] - E[Y|do(T=0)]
/// - Uses kernel-based representation: τ(x) = E[Y|T=1,X=x] - E[Y|T=0,X=x]
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::causal_kernels::{CausalKernel, CausalKernelConfig};
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Transform};
///
/// let config = CausalKernelConfig::default();
/// let causal = CausalKernel::new(config);
///
/// // X: covariates, T: treatment, Y: outcome
/// let data = array![[1.0, 0.0, 2.0], [2.0, 1.0, 5.0]];
/// let fitted = causal.fit(&data, &()).unwrap();
/// let features = fitted.transform(&data).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct CausalKernel<State = Untrained> {
    config: CausalKernelConfig,

    // Fitted attributes — all `None` until `fit` succeeds.
    /// Random projection weights for treatment-side features
    /// (shape: n_covariates x n_components).
    treatment_weights: Option<Array2<Float>>,
    /// Random projection weights for outcome-side features.
    outcome_weights: Option<Array2<Float>>,
    /// Propensity scores estimated on the training samples.
    propensity_scores: Option<Array1<Float>>,
    /// Effect summary statistics keyed by name ("ate", "naive_difference", ...).
    treatment_effects: Option<HashMap<String, Float>>,

    // Zero-sized type-state marker (Untrained / Trained).
    _state: PhantomData<State>,
}
107
108impl CausalKernel<Untrained> {
109    /// Create a new causal kernel
110    pub fn new(config: CausalKernelConfig) -> Self {
111        Self {
112            config,
113            treatment_weights: None,
114            outcome_weights: None,
115            propensity_scores: None,
116            treatment_effects: None,
117            _state: PhantomData,
118        }
119    }
120
121    /// Create with default configuration
122    pub fn with_components(n_components: usize) -> Self {
123        Self {
124            config: CausalKernelConfig {
125                n_components,
126                ..Default::default()
127            },
128            treatment_weights: None,
129            outcome_weights: None,
130            propensity_scores: None,
131            treatment_effects: None,
132            _state: PhantomData,
133        }
134    }
135
136    /// Set causal method
137    pub fn method(mut self, method: CausalMethod) -> Self {
138        self.config.causal_method = method;
139        self
140    }
141
142    /// Set treatment bandwidth
143    pub fn treatment_bandwidth(mut self, gamma: Float) -> Self {
144        self.config.treatment_bandwidth = gamma;
145        self
146    }
147
148    /// Estimate propensity scores (probability of treatment given covariates)
149    fn estimate_propensity_scores(
150        &self,
151        x: &Array2<Float>,
152        treatment: &Array1<Float>,
153    ) -> Array1<Float> {
154        let n_samples = x.nrows();
155        let mut scores = Array1::zeros(n_samples);
156
157        // Simple logistic kernel density estimation
158        for i in 0..n_samples {
159            let mut score = 0.0;
160            let mut weight_sum = 0.0;
161
162            for j in 0..n_samples {
163                // Compute kernel similarity
164                let mut dist_sq = 0.0;
165                for k in 0..x.ncols() {
166                    let diff = x[[i, k]] - x[[j, k]];
167                    dist_sq += diff * diff;
168                }
169
170                let weight = (-dist_sq / (2.0 * self.config.treatment_bandwidth.powi(2))).exp();
171                score += weight * treatment[j];
172                weight_sum += weight;
173            }
174
175            scores[i] = if weight_sum > 1e-10 {
176                (score / weight_sum).max(0.01).min(0.99) // Clip for stability
177            } else {
178                0.5
179            };
180        }
181
182        scores
183    }
184
185    /// Estimate treatment effect using inverse propensity weighting
186    fn estimate_treatment_effect(
187        &self,
188        x: &Array2<Float>,
189        treatment: &Array1<Float>,
190        outcome: &Array1<Float>,
191        propensity_scores: &Array1<Float>,
192    ) -> HashMap<String, Float> {
193        let n_samples = x.nrows() as Float;
194
195        // Average Treatment Effect (ATE) using IPW
196        let mut ate_numerator_treated = 0.0;
197        let mut ate_numerator_control = 0.0;
198        let mut weight_sum_treated = 0.0;
199        let mut weight_sum_control = 0.0;
200
201        for i in 0..treatment.len() {
202            if treatment[i] > 0.5 {
203                // Treated group
204                let weight = 1.0 / propensity_scores[i];
205                ate_numerator_treated += weight * outcome[i];
206                weight_sum_treated += weight;
207            } else {
208                // Control group
209                let weight = 1.0 / (1.0 - propensity_scores[i]);
210                ate_numerator_control += weight * outcome[i];
211                weight_sum_control += weight;
212            }
213        }
214
215        let ate = if weight_sum_treated > 0.0 && weight_sum_control > 0.0 {
216            (ate_numerator_treated / weight_sum_treated)
217                - (ate_numerator_control / weight_sum_control)
218        } else {
219            0.0
220        };
221
222        // Naive difference (for comparison)
223        let treated_outcomes: Vec<Float> = treatment
224            .iter()
225            .zip(outcome.iter())
226            .filter_map(|(&t, &y)| if t > 0.5 { Some(y) } else { None })
227            .collect();
228
229        let control_outcomes: Vec<Float> = treatment
230            .iter()
231            .zip(outcome.iter())
232            .filter_map(|(&t, &y)| if t <= 0.5 { Some(y) } else { None })
233            .collect();
234
235        let naive_diff = if !treated_outcomes.is_empty() && !control_outcomes.is_empty() {
236            let treated_mean =
237                treated_outcomes.iter().sum::<Float>() / treated_outcomes.len() as Float;
238            let control_mean =
239                control_outcomes.iter().sum::<Float>() / control_outcomes.len() as Float;
240            treated_mean - control_mean
241        } else {
242            0.0
243        };
244
245        let mut effects = HashMap::new();
246        effects.insert("ate".to_string(), ate);
247        effects.insert("naive_difference".to_string(), naive_diff);
248        effects.insert("n_samples".to_string(), n_samples);
249        effects.insert("n_treated".to_string(), treated_outcomes.len() as Float);
250        effects.insert("n_control".to_string(), control_outcomes.len() as Float);
251
252        effects
253    }
254}
255
// Estimator plumbing: exposes the configuration and the shared error/float
// types expected by the sklears `Estimator` trait.
impl Estimator for CausalKernel<Untrained> {
    type Config = CausalKernelConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}
265
266impl Fit<Array2<Float>, ()> for CausalKernel<Untrained> {
267    type Fitted = CausalKernel<Trained>;
268
269    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
270        if x.nrows() < 2 || x.ncols() < 3 {
271            return Err(SklearsError::InvalidInput(
272                "Input must have at least 2 samples and 3 columns (covariates, treatment, outcome)"
273                    .to_string(),
274            ));
275        }
276
277        // Assume last 2 columns are treatment and outcome
278        let n_covariates = x.ncols() - 2;
279        let covariates = x.slice_axis(Axis(1), (0..n_covariates).into()).to_owned();
280        let treatment = x.column(n_covariates).to_owned();
281        let outcome = x.column(n_covariates + 1).to_owned();
282
283        // Estimate propensity scores
284        let propensity_scores = self.estimate_propensity_scores(&covariates, &treatment);
285
286        // Estimate treatment effects
287        let treatment_effects =
288            self.estimate_treatment_effect(&covariates, &treatment, &outcome, &propensity_scores);
289
290        // Generate random features for kernel approximation
291        let mut rng = thread_rng();
292        let normal = Normal::new(0.0, 1.0).expect("operation should succeed");
293
294        let treatment_weights =
295            Array2::from_shape_fn((n_covariates, self.config.n_components), |_| {
296                rng.sample(normal) * (2.0 * self.config.treatment_bandwidth).sqrt()
297            });
298
299        let outcome_weights =
300            Array2::from_shape_fn((n_covariates, self.config.n_components), |_| {
301                rng.sample(normal) * (2.0 * self.config.outcome_bandwidth).sqrt()
302            });
303
304        Ok(CausalKernel {
305            config: self.config,
306            treatment_weights: Some(treatment_weights),
307            outcome_weights: Some(outcome_weights),
308            propensity_scores: Some(propensity_scores),
309            treatment_effects: Some(treatment_effects),
310            _state: PhantomData,
311        })
312    }
313}
314
315impl Transform<Array2<Float>, Array2<Float>> for CausalKernel<Trained> {
316    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
317        let treatment_weights = self
318            .treatment_weights
319            .as_ref()
320            .expect("operation should succeed");
321        let outcome_weights = self
322            .outcome_weights
323            .as_ref()
324            .expect("operation should succeed");
325
326        // Assume input has same structure as training data
327        let n_covariates = treatment_weights.nrows();
328
329        if x.ncols() < n_covariates {
330            return Err(SklearsError::InvalidInput(format!(
331                "Input must have at least {} columns",
332                n_covariates
333            )));
334        }
335
336        let covariates = x.slice_axis(Axis(1), (0..n_covariates).into());
337
338        // Compute treatment and outcome features
339        let treatment_projection = covariates.dot(treatment_weights);
340        let outcome_projection = covariates.dot(outcome_weights);
341
342        let n_samples = x.nrows();
343        let n_features = self.config.n_components * 2;
344        let mut output = Array2::zeros((n_samples, n_features));
345
346        let normalizer = (2.0 / self.config.n_components as Float).sqrt();
347
348        for i in 0..n_samples {
349            for j in 0..self.config.n_components {
350                // Treatment features
351                output[[i, j]] = normalizer * treatment_projection[[i, j]].cos();
352                // Outcome features
353                output[[i, j + self.config.n_components]] =
354                    normalizer * outcome_projection[[i, j]].cos();
355            }
356        }
357
358        Ok(output)
359    }
360}
361
362impl CausalKernel<Trained> {
363    /// Get estimated propensity scores
364    pub fn propensity_scores(&self) -> &Array1<Float> {
365        self.propensity_scores
366            .as_ref()
367            .expect("operation should succeed")
368    }
369
370    /// Get estimated treatment effects
371    pub fn treatment_effects(&self) -> &HashMap<String, Float> {
372        self.treatment_effects
373            .as_ref()
374            .expect("operation should succeed")
375    }
376
377    /// Get average treatment effect
378    pub fn ate(&self) -> Float {
379        self.treatment_effects
380            .as_ref()
381            .expect("operation should succeed")
382            .get("ate")
383            .copied()
384            .unwrap_or(0.0)
385    }
386}
387
/// Counterfactual Kernel Approximation
///
/// Implements kernel methods for counterfactual reasoning: "What would have
/// happened if the treatment had been different?"
///
/// # Mathematical Background
///
/// Counterfactual: Y(t) | X=x, T=t', Y=y
/// Uses nearest-neighbor matching in kernel feature space with propensity
/// score adjustment.
///
/// # Examples
///
/// ```rust,ignore
/// use sklears_kernel_approximation::causal_kernels::{CounterfactualKernel, CausalKernelConfig};
/// use scirs2_core::ndarray::array;
/// use sklears_core::traits::{Fit, Transform};
///
/// let config = CausalKernelConfig::default();
/// let cf = CounterfactualKernel::new(config);
///
/// let data = array![[1.0, 0.0, 2.0], [2.0, 1.0, 5.0]];
/// let fitted = cf.fit(&data, &()).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct CounterfactualKernel<State = Untrained> {
    config: CausalKernelConfig,

    // Fitted attributes — all `None` until `fit` succeeds.
    /// Full training matrix `[covariates..., treatment, outcome]`, retained
    /// for nearest-neighbor matching at transform time.
    training_data: Option<Array2<Float>>,
    /// Random cosine features of the training covariates
    /// (shape: n_samples x n_components).
    kernel_features: Option<Array2<Float>>,
    /// Propensity scores estimated on the training samples.
    propensity_scores: Option<Array1<Float>>,

    // Zero-sized type-state marker (Untrained / Trained).
    _state: PhantomData<State>,
}
423
424impl CounterfactualKernel<Untrained> {
425    /// Create a new counterfactual kernel
426    pub fn new(config: CausalKernelConfig) -> Self {
427        Self {
428            config,
429            training_data: None,
430            kernel_features: None,
431            propensity_scores: None,
432            _state: PhantomData,
433        }
434    }
435
436    /// Create with default configuration
437    pub fn with_components(n_components: usize) -> Self {
438        Self::new(CausalKernelConfig {
439            n_components,
440            ..Default::default()
441        })
442    }
443}
444
// Estimator plumbing: exposes the configuration and the shared error/float
// types expected by the sklears `Estimator` trait.
impl Estimator for CounterfactualKernel<Untrained> {
    type Config = CausalKernelConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}
454
455impl Fit<Array2<Float>, ()> for CounterfactualKernel<Untrained> {
456    type Fitted = CounterfactualKernel<Trained>;
457
458    fn fit(self, x: &Array2<Float>, _y: &()) -> Result<Self::Fitted> {
459        if x.nrows() < 2 || x.ncols() < 3 {
460            return Err(SklearsError::InvalidInput(
461                "Input must have at least 2 samples and 3 columns".to_string(),
462            ));
463        }
464
465        let training_data = x.clone();
466
467        // Extract covariates and treatment
468        let n_covariates = x.ncols() - 2;
469        let covariates = x.slice_axis(Axis(1), (0..n_covariates).into()).to_owned();
470        let treatment = x.column(n_covariates).to_owned();
471
472        // Compute propensity scores using kernel density estimation
473        let n_samples = x.nrows();
474        let mut propensity_scores = Array1::zeros(n_samples);
475
476        for i in 0..n_samples {
477            let mut score = 0.0;
478            let mut weight_sum = 0.0;
479
480            for j in 0..n_samples {
481                let mut dist_sq = 0.0;
482                for k in 0..n_covariates {
483                    let diff = covariates[[i, k]] - covariates[[j, k]];
484                    dist_sq += diff * diff;
485                }
486
487                let weight = (-dist_sq / (2.0 * self.config.treatment_bandwidth.powi(2))).exp();
488                score += weight * treatment[j];
489                weight_sum += weight;
490            }
491
492            propensity_scores[i] = if weight_sum > 1e-10 {
493                (score / weight_sum).max(0.01).min(0.99)
494            } else {
495                0.5
496            };
497        }
498
499        // Generate kernel features for matching
500        let mut rng = thread_rng();
501        let normal = Normal::new(0.0, 1.0).expect("operation should succeed");
502
503        let random_weights =
504            Array2::from_shape_fn((n_covariates, self.config.n_components), |_| {
505                rng.sample(normal) * (2.0 * self.config.treatment_bandwidth).sqrt()
506            });
507
508        let projection = covariates.dot(&random_weights);
509        let mut kernel_features = Array2::zeros((n_samples, self.config.n_components));
510
511        for i in 0..n_samples {
512            for j in 0..self.config.n_components {
513                kernel_features[[i, j]] = projection[[i, j]].cos();
514            }
515        }
516
517        Ok(CounterfactualKernel {
518            config: self.config,
519            training_data: Some(training_data),
520            kernel_features: Some(kernel_features),
521            propensity_scores: Some(propensity_scores),
522            _state: PhantomData,
523        })
524    }
525}
526
527impl Transform<Array2<Float>, Array2<Float>> for CounterfactualKernel<Trained> {
528    fn transform(&self, x: &Array2<Float>) -> Result<Array2<Float>> {
529        let training_data = self
530            .training_data
531            .as_ref()
532            .expect("operation should succeed");
533        let kernel_features = self
534            .kernel_features
535            .as_ref()
536            .expect("operation should succeed");
537
538        let n_covariates = training_data.ncols() - 2;
539
540        if x.ncols() < n_covariates {
541            return Err(SklearsError::InvalidInput(format!(
542                "Input must have at least {} columns",
543                n_covariates
544            )));
545        }
546
547        // For each test sample, find nearest neighbors in kernel space
548        // and compute counterfactual outcomes
549        let n_samples = x.nrows();
550        let mut output = Array2::zeros((n_samples, self.config.n_components + 2));
551
552        for i in 0..n_samples {
553            // Find k nearest neighbors (k=5)
554            let k = 5.min(kernel_features.nrows());
555            let mut distances = Vec::new();
556
557            for j in 0..kernel_features.nrows() {
558                let mut dist = 0.0;
559                for l in 0..n_covariates {
560                    let diff = x[[i, l]] - training_data[[j, l]];
561                    dist += diff * diff;
562                }
563                distances.push((dist, j));
564            }
565
566            distances.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("operation should succeed"));
567
568            // Compute weighted average of outcomes from nearest neighbors
569            let mut treated_outcome = 0.0;
570            let mut control_outcome = 0.0;
571            let mut treated_weight = 0.0;
572            let mut control_weight = 0.0;
573
574            for &(dist, idx) in distances.iter().take(k) {
575                let weight = (-dist / self.config.treatment_bandwidth).exp();
576                let treatment_val = training_data[[idx, n_covariates]];
577                let outcome_val = training_data[[idx, n_covariates + 1]];
578
579                if treatment_val > 0.5 {
580                    treated_outcome += weight * outcome_val;
581                    treated_weight += weight;
582                } else {
583                    control_outcome += weight * outcome_val;
584                    control_weight += weight;
585                }
586            }
587
588            // Store counterfactual estimates
589            output[[i, 0]] = if treated_weight > 0.0 {
590                treated_outcome / treated_weight
591            } else {
592                0.0
593            };
594
595            output[[i, 1]] = if control_weight > 0.0 {
596                control_outcome / control_weight
597            } else {
598                0.0
599            };
600
601            // Store kernel features
602            for j in 0..self.config.n_components {
603                if j < kernel_features.ncols() {
604                    output[[i, j + 2]] = kernel_features[[distances[0].1, j]];
605                }
606            }
607        }
608
609        Ok(output)
610    }
611}
612
613impl CounterfactualKernel<Trained> {
614    /// Get propensity scores from training data
615    pub fn propensity_scores(&self) -> &Array1<Float> {
616        self.propensity_scores
617            .as_ref()
618            .expect("operation should succeed")
619    }
620
621    /// Estimate individual treatment effect for a sample
622    pub fn estimate_ite(&self, sample: &Array2<Float>) -> Result<Float> {
623        let counterfactuals = self.transform(sample)?;
624
625        if counterfactuals.nrows() > 0 {
626            // ITE = E[Y(1)] - E[Y(0)]
627            Ok(counterfactuals[[0, 0]] - counterfactuals[[0, 1]])
628        } else {
629            Ok(0.0)
630        }
631    }
632}
633
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // End-to-end fit/transform: output must have 2 * n_components columns.
    #[test]
    fn test_causal_kernel_basic() {
        let config = CausalKernelConfig {
            n_components: 20,
            treatment_bandwidth: 1.0,
            outcome_bandwidth: 1.0,
            ..Default::default()
        };

        let causal = CausalKernel::new(config);

        // Data: [covariate1, covariate2, treatment, outcome]
        let data = array![
            [1.0, 2.0, 0.0, 1.0],
            [2.0, 3.0, 1.0, 5.0],
            [1.5, 2.5, 0.0, 2.0],
            [2.5, 3.5, 1.0, 6.0],
        ];

        let fitted = causal.fit(&data, &()).expect("operation should succeed");
        let features = fitted.transform(&data).expect("operation should succeed");

        assert_eq!(features.nrows(), 4);
        assert_eq!(features.ncols(), 40); // 2 * n_components
    }

    // Fitted propensity scores must be valid probabilities (clipping keeps
    // them inside [0.01, 0.99], so this is a superset check).
    #[test]
    fn test_propensity_score_estimation() {
        let config = CausalKernelConfig::default();
        let causal = CausalKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = causal.fit(&data, &()).expect("operation should succeed");
        let scores = fitted.propensity_scores();

        // Propensity scores should be between 0 and 1
        assert!(scores.iter().all(|&s| s >= 0.0 && s <= 1.0));
    }

    // Fitting must populate the effect map with a finite ATE and the
    // comparison naive difference.
    #[test]
    fn test_treatment_effect_estimation() {
        let config = CausalKernelConfig::default();
        let causal = CausalKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = causal.fit(&data, &()).expect("operation should succeed");
        let effects = fitted.treatment_effects();

        assert!(effects.contains_key("ate"));
        assert!(effects.contains_key("naive_difference"));
        assert!(effects["ate"].is_finite());
    }

    // Counterfactual transform output: 2 outcome columns + n_components
    // feature columns.
    #[test]
    fn test_counterfactual_kernel() {
        let config = CausalKernelConfig {
            n_components: 10,
            ..Default::default()
        };

        let cf = CounterfactualKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = cf.fit(&data, &()).expect("operation should succeed");
        let test_data = array![[1.2], [2.3]];
        let counterfactuals = fitted
            .transform(&test_data)
            .expect("operation should succeed");

        assert_eq!(counterfactuals.nrows(), 2);
        // First 2 columns are treated and control outcomes, rest are features
        assert_eq!(counterfactuals.ncols(), 12);
    }

    // estimate_ite must return a finite value for a single-sample query.
    #[test]
    fn test_individual_treatment_effect() {
        let config = CausalKernelConfig {
            n_components: 10,
            ..Default::default()
        };

        let cf = CounterfactualKernel::new(config);

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        let fitted = cf.fit(&data, &()).expect("operation should succeed");
        let test_sample = array![[1.5]];
        let ite = fitted
            .estimate_ite(&test_sample)
            .expect("operation should succeed");

        assert!(ite.is_finite());
    }

    // Empty input is rejected by the fit-time validation.
    #[test]
    fn test_empty_input_error() {
        let causal = CausalKernel::with_components(20);
        let empty_data: Array2<Float> = Array2::zeros((0, 0));

        assert!(causal.fit(&empty_data, &()).is_err());
    }

    // Fewer than 3 columns (covariates + treatment + outcome) is rejected.
    #[test]
    fn test_insufficient_columns_error() {
        let causal = CausalKernel::with_components(20);
        let data = array![[1.0, 2.0]]; // Only 2 columns, need at least 3

        assert!(causal.fit(&data, &()).is_err());
    }

    // The builder accepts every method variant and fit/transform still works.
    #[test]
    fn test_different_causal_methods() {
        let methods = vec![
            CausalMethod::TreatmentEffect,
            CausalMethod::ConditionalTreatmentEffect,
            CausalMethod::InstrumentalVariable,
        ];

        let data = array![
            [1.0, 0.0, 1.0],
            [2.0, 1.0, 5.0],
            [1.5, 0.0, 2.0],
            [2.5, 1.0, 6.0],
        ];

        for method in methods {
            let causal = CausalKernel::with_components(20).method(method);
            let fitted = causal.fit(&data, &()).expect("operation should succeed");
            let features = fitted.transform(&data).expect("operation should succeed");

            assert_eq!(features.nrows(), 4);
        }
    }
}