sklears_tree/splits.rs

//! Split implementations for decision trees
//!
//! This module contains various splitting strategies including hyperplane splits,
//! CHAID splits, and conditional inference splits with their statistical tests.

use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
#[cfg(feature = "oblique")]
use scirs2_core::ndarray::s;
use scirs2_core::random::thread_rng;
use sklears_core::error::{Result, SklearsError};
use std::collections::HashMap;

use crate::criteria::{ConditionalTestType, FeatureType};

/// Hyperplane split information for oblique trees
#[derive(Debug, Clone)]
pub struct HyperplaneSplit {
    /// Feature coefficients for the hyperplane (w^T x + b >= threshold)
    pub coefficients: Array1<f64>,
    /// Threshold for the hyperplane split
    pub threshold: f64,
    /// Bias term for the hyperplane
    pub bias: f64,
    /// Impurity decrease achieved by this split
    pub impurity_decrease: f64,
}

impl HyperplaneSplit {
    /// Evaluate the hyperplane split for a sample
    pub fn evaluate(&self, sample: &Array1<f64>) -> bool {
        let dot_product = self.coefficients.dot(sample) + self.bias;
        dot_product >= self.threshold
    }

    /// Create a random hyperplane with normalized coefficients
    pub fn random(n_features: usize, rng: &mut scirs2_core::CoreRandom) -> Self {
        let mut coefficients = Array1::zeros(n_features);
        for i in 0..n_features {
            coefficients[i] = rng.gen_range(-1.0..1.0);
        }

        // Normalize coefficients to unit length
        let norm = coefficients.dot(&coefficients).sqrt();
        if norm > 1e-10_f64 {
            coefficients /= norm;
        }

        Self {
            coefficients,
            threshold: rng.gen_range(-1.0..1.0),
            bias: rng.gen_range(-0.1..0.1),
            impurity_decrease: 0.0,
        }
    }

    /// Find optimal hyperplane using ridge regression
    #[cfg(feature = "oblique")]
    pub fn from_ridge_regression(x: &Array2<f64>, y: &Array1<f64>, alpha: f64) -> Result<Self> {
        let n_features = x.ncols();
        if x.nrows() < 2 {
            return Err(SklearsError::InvalidInput(
                "Need at least 2 samples for ridge regression".to_string(),
            ));
        }

        // Add bias column to X
        let mut x_bias = Array2::ones((x.nrows(), n_features + 1));
        x_bias.slice_mut(s![.., ..n_features]).assign(x);

        // Ridge regression: w = (X^T X + α I)^(-1) X^T y
        let xtx = x_bias.t().dot(&x_bias);
        let ridge_matrix = xtx + Array2::<f64>::eye(n_features + 1) * alpha;
        let xty = x_bias.t().dot(y);

        // Simple matrix inverse using Gauss-Jordan elimination
        match gauss_jordan_inverse(&ridge_matrix) {
            Ok(inv_matrix) => {
                let coefficients_full = inv_matrix.dot(&xty);

                let coefficients = coefficients_full.slice(s![..n_features]).to_owned();
                let bias = coefficients_full[n_features];

                Ok(Self {
                    coefficients,
                    threshold: 0.0, // Will be set during split evaluation
                    bias,
                    impurity_decrease: 0.0,
                })
            }
            Err(_) => {
                // Fall back to a random hyperplane if the matrix is singular
                let mut rng = thread_rng();
                Ok(Self::random(n_features, &mut rng))
            }
        }
    }
}
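
// Illustrative sketch (not from the original module): a minimal check of
// `HyperplaneSplit::evaluate`, assuming `scirs2_core::ndarray` re-exports
// ndarray's `array!` macro alongside the array types used above.
#[cfg(test)]
mod hyperplane_split_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn evaluate_separates_the_two_sides_of_the_plane() {
        // Hyperplane x0 + x1 >= 1.0 with zero bias
        let split = HyperplaneSplit {
            coefficients: array![1.0, 1.0],
            threshold: 1.0,
            bias: 0.0,
            impurity_decrease: 0.0,
        };
        assert!(split.evaluate(&array![0.8, 0.8])); // 1.6 >= 1.0
        assert!(!split.evaluate(&array![0.2, 0.2])); // 0.4 < 1.0
    }
}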

/// Simple Gauss-Jordan elimination for matrix inversion
#[cfg(feature = "oblique")]
fn gauss_jordan_inverse(matrix: &Array2<f64>) -> Result<Array2<f64>> {
    let n = matrix.nrows();
    if n != matrix.ncols() {
        return Err(SklearsError::InvalidInput(
            "Matrix must be square".to_string(),
        ));
    }

    // Create augmented matrix [A | I]
    let mut augmented = Array2::zeros((n, 2 * n));

    // Copy matrix to the left side
    for i in 0..n {
        for j in 0..n {
            augmented[[i, j]] = matrix[[i, j]];
        }
        // Identity matrix on the right side
        augmented[[i, i + n]] = 1.0;
    }

    // Gauss-Jordan elimination with partial pivoting
    for i in 0..n {
        // Find the pivot row (largest absolute value in column i)
        let mut max_row = i;
        for k in (i + 1)..n {
            if augmented[[k, i]].abs() > augmented[[max_row, i]].abs() {
                max_row = k;
            }
        }

        // Check for a singular matrix
        if augmented[[max_row, i]].abs() < 1e-12 {
            return Err(SklearsError::InvalidInput("Matrix is singular".to_string()));
        }

        // Swap rows
        if max_row != i {
            for j in 0..(2 * n) {
                augmented.swap([i, j], [max_row, j]);
            }
        }

        // Scale the pivot row
        let pivot = augmented[[i, i]];
        for j in 0..(2 * n) {
            augmented[[i, j]] /= pivot;
        }

        // Eliminate column i from all other rows
        for k in 0..n {
            if k != i {
                let factor = augmented[[k, i]];
                for j in 0..(2 * n) {
                    augmented[[k, j]] -= factor * augmented[[i, j]];
                }
            }
        }
    }

    // Extract the inverse from the right side
    let mut inverse = Array2::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            inverse[[i, j]] = augmented[[i, j + n]];
        }
    }

    Ok(inverse)
}
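
// Illustrative sketch (not from the original module): sanity-check the
// Gauss-Jordan inverse on a 2x2 matrix with a known closed-form inverse.
#[cfg(all(test, feature = "oblique"))]
mod gauss_jordan_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn inverts_a_simple_2x2_matrix() {
        // [[4, 7], [2, 6]] has determinant 10 and inverse [[0.6, -0.7], [-0.2, 0.4]]
        let m = array![[4.0, 7.0], [2.0, 6.0]];
        let inv = gauss_jordan_inverse(&m).unwrap();
        let expected = array![[0.6, -0.7], [-0.2, 0.4]];
        for i in 0..2 {
            for j in 0..2 {
                assert!((inv[[i, j]] - expected[[i, j]]).abs() < 1e-9);
            }
        }
    }
}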

/// CHAID (Chi-squared Automatic Interaction Detection) split information
#[derive(Debug, Clone)]
pub struct ChaidSplit {
    /// Feature index
    pub feature_idx: usize,
    /// Category groups after merging based on chi-squared tests
    pub category_groups: Vec<Vec<String>>,
    /// Chi-squared statistic
    pub chi_squared: f64,
    /// P-value of the chi-squared test
    pub p_value: f64,
    /// Degrees of freedom
    pub degrees_of_freedom: usize,
    /// Significance level used
    pub significance_level: f64,
}

impl ChaidSplit {
    /// Perform CHAID splitting analysis for a categorical feature
    pub fn analyze_categorical_split(
        feature_values: &[String],
        target_values: &[i32],
        significance_level: f64,
    ) -> Result<Option<Self>> {
        if feature_values.len() != target_values.len() {
            return Err(SklearsError::InvalidInput(
                "Feature and target arrays must have the same length".to_string(),
            ));
        }

        if feature_values.is_empty() {
            return Ok(None);
        }

        // Build contingency table
        let contingency_table = build_contingency_table(feature_values, target_values)?;

        // Perform iterative category merging based on chi-squared tests
        let merged_categories = merge_categories_chaid(&contingency_table, significance_level)?;

        if merged_categories.len() <= 1 {
            return Ok(None); // No meaningful split possible
        }

        // Calculate the final chi-squared statistic (computed on the original,
        // unmerged contingency table)
        let (chi_squared, p_value, df) = calculate_chi_squared(&contingency_table)?;

        Ok(Some(ChaidSplit {
            feature_idx: 0, // Will be set by the caller
            category_groups: merged_categories,
            chi_squared,
            p_value,
            degrees_of_freedom: df,
            significance_level,
        }))
    }

    /// Check if the split is statistically significant
    pub fn is_significant(&self) -> bool {
        self.p_value < self.significance_level
    }
}
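
// Illustrative sketch (not from the original module): CHAID analysis on a
// strongly associated categorical feature. With 20 "a" samples in class 0 and
// 2 "b" samples in class 1, the pairwise merge test is significant, so the
// groups stay separate and the overall split is significant at the 0.05 level.
#[cfg(test)]
mod chaid_split_sketch {
    use super::*;

    #[test]
    fn detects_a_significant_categorical_split() {
        let mut features = vec!["a".to_string(); 20];
        features.extend(vec!["b".to_string(); 2]);
        let mut targets = vec![0i32; 20];
        targets.extend(vec![1i32; 2]);

        let split = ChaidSplit::analyze_categorical_split(&features, &targets, 0.05)
            .unwrap()
            .expect("perfectly separated categories should yield a split");
        assert_eq!(split.category_groups.len(), 2);
        assert!(split.is_significant());
    }
}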

/// Build contingency table for categorical feature vs target
fn build_contingency_table(
    feature_values: &[String],
    target_values: &[i32],
) -> Result<HashMap<String, HashMap<i32, usize>>> {
    let mut table: HashMap<String, HashMap<i32, usize>> = HashMap::new();

    for (feature_val, target_val) in feature_values.iter().zip(target_values.iter()) {
        let target_counts = table.entry(feature_val.clone()).or_default();
        *target_counts.entry(*target_val).or_insert(0) += 1;
    }

    Ok(table)
}

/// Merge categories based on chi-squared tests (CHAID algorithm)
fn merge_categories_chaid(
    contingency_table: &HashMap<String, HashMap<i32, usize>>,
    significance_level: f64,
) -> Result<Vec<Vec<String>>> {
    let categories: Vec<String> = contingency_table.keys().cloned().collect();
    let mut groups: Vec<Vec<String>> = categories.iter().map(|c| vec![c.clone()]).collect();

    if groups.len() <= 1 {
        return Ok(groups);
    }

    loop {
        let mut best_merge: Option<(usize, usize)> = None;
        let mut min_chi_squared = f64::INFINITY;

        // Find the pair of category groups with the smallest chi-squared statistic
        for i in 0..groups.len() {
            for j in (i + 1)..groups.len() {
                // Create merged contingency table for these two groups
                let merged_table =
                    create_merged_contingency_table(contingency_table, &groups[i], &groups[j])?;

                if let Ok((chi_squared, p_value, _)) =
                    calculate_chi_squared_for_merged(&merged_table)
                {
                    // If not significant (p > significance_level), consider for merging
                    if p_value > significance_level && chi_squared < min_chi_squared {
                        min_chi_squared = chi_squared;
                        best_merge = Some((i, j));
                    }
                }
            }
        }

        // Merge the best pair; stop once every remaining pairwise test is significant
        if let Some((i, j)) = best_merge {
            let mut merged_group = groups[i].clone();
            merged_group.extend(groups[j].clone());

            // Remove the original groups (higher index first) and add the merged group
            if i < j {
                groups.remove(j);
                groups.remove(i);
            } else {
                groups.remove(i);
                groups.remove(j);
            }
            groups.push(merged_group);
        } else {
            break;
        }

        if groups.len() <= 1 {
            break;
        }
    }

    Ok(groups)
}

/// Create merged contingency table for two category groups
fn create_merged_contingency_table(
    original_table: &HashMap<String, HashMap<i32, usize>>,
    group1: &[String],
    group2: &[String],
) -> Result<HashMap<i32, usize>> {
    let mut merged_table = HashMap::new();

    // Sum target counts over all categories in both groups
    for category in group1.iter().chain(group2.iter()) {
        if let Some(target_counts) = original_table.get(category) {
            for (&target, &count) in target_counts {
                *merged_table.entry(target).or_insert(0) += count;
            }
        }
    }

    Ok(merged_table)
}
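
// Illustrative sketch (not from the original module): merging two category
// groups simply sums their per-target counts.
#[cfg(test)]
mod merged_table_sketch {
    use super::*;

    #[test]
    fn merging_sums_target_counts_across_groups() {
        let features = vec!["a".to_string(), "b".to_string(), "b".to_string()];
        let targets = vec![0, 0, 1];
        let table = build_contingency_table(&features, &targets).unwrap();

        let merged =
            create_merged_contingency_table(&table, &["a".to_string()], &["b".to_string()])
                .unwrap();
        assert_eq!(merged[&0], 2); // one "a" and one "b" sample with target 0
        assert_eq!(merged[&1], 1);
    }
}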

/// Calculate chi-squared statistic for contingency table
fn calculate_chi_squared(
    contingency_table: &HashMap<String, HashMap<i32, usize>>,
) -> Result<(f64, f64, usize)> {
    use std::collections::HashSet;

    // Get all unique target values
    let mut all_targets: HashSet<i32> = HashSet::new();
    for target_counts in contingency_table.values() {
        all_targets.extend(target_counts.keys());
    }

    if all_targets.len() <= 1 {
        return Ok((0.0, 1.0, 0));
    }

    let categories: Vec<&String> = contingency_table.keys().collect();
    let targets: Vec<i32> = all_targets.into_iter().collect();

    if categories.len() <= 1 {
        return Ok((0.0, 1.0, 0));
    }

    // Calculate row and column totals
    let mut row_totals: HashMap<&String, usize> = HashMap::new();
    let mut col_totals: HashMap<i32, usize> = HashMap::new();
    let mut grand_total = 0;

    for category in &categories {
        let mut row_total = 0;
        if let Some(target_counts) = contingency_table.get(*category) {
            for (&target, &count) in target_counts {
                row_total += count;
                *col_totals.entry(target).or_insert(0) += count;
                grand_total += count;
            }
        }
        row_totals.insert(*category, row_total);
    }

    if grand_total == 0 {
        return Ok((0.0, 1.0, 0));
    }

    // Chi-squared statistic: sum over cells of (observed - expected)^2 / expected
    let mut chi_squared = 0.0;
    for category in &categories {
        for &target in &targets {
            let observed = contingency_table
                .get(*category)
                .and_then(|counts| counts.get(&target))
                .unwrap_or(&0);

            let expected = (*row_totals.get(category).unwrap_or(&0) as f64)
                * (*col_totals.get(&target).unwrap_or(&0) as f64)
                / (grand_total as f64);

            if expected > 0.0 {
                let diff = (*observed as f64) - expected;
                chi_squared += (diff * diff) / expected;
            }
        }
    }

    let degrees_of_freedom = (categories.len() - 1) * (targets.len() - 1);
    let p_value = chi_squared_p_value(chi_squared, degrees_of_freedom);

    Ok((chi_squared, p_value, degrees_of_freedom))
}
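
// Illustrative sketch (not from the original module): a 2x2 table with perfect
// association has chi-squared equal to the sample size (here 20) with 1 df.
#[cfg(test)]
mod chi_squared_sketch {
    use super::*;

    #[test]
    fn perfect_association_yields_chi_squared_equal_to_n() {
        let features: Vec<String> = (0..20)
            .map(|i| if i < 10 { "a".to_string() } else { "b".to_string() })
            .collect();
        let targets: Vec<i32> = (0..20).map(|i| if i < 10 { 0 } else { 1 }).collect();
        let table = build_contingency_table(&features, &targets).unwrap();

        let (chi_squared, p_value, df) = calculate_chi_squared(&table).unwrap();
        assert!((chi_squared - 20.0).abs() < 1e-9);
        assert_eq!(df, 1);
        assert!(p_value < 0.01);
    }
}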

/// Calculate chi-squared statistic for merged contingency table
fn calculate_chi_squared_for_merged(
    merged_table: &HashMap<i32, usize>,
) -> Result<(f64, f64, usize)> {
    if merged_table.len() <= 1 {
        return Ok((0.0, 1.0, 0));
    }

    let total: usize = merged_table.values().sum();
    if total == 0 {
        return Ok((0.0, 1.0, 0));
    }

    // Simple chi-squared goodness-of-fit test against equal expected frequencies
    let expected = total as f64 / merged_table.len() as f64;
    let mut chi_squared = 0.0;

    for &observed in merged_table.values() {
        let diff = (observed as f64) - expected;
        chi_squared += (diff * diff) / expected;
    }

    let degrees_of_freedom = merged_table.len() - 1;
    let p_value = chi_squared_p_value(chi_squared, degrees_of_freedom);

    Ok((chi_squared, p_value, degrees_of_freedom))
}

/// Calculate approximate p-value for chi-squared statistic
fn chi_squared_p_value(chi_squared: f64, df: usize) -> f64 {
    if df == 0 || chi_squared <= 0.0 {
        return 1.0;
    }

    // Wilson-Hilferty transformation: (X²/df)^(1/3) is approximately normal
    // with mean 1 - 2/(9 df) and variance 2/(9 df).
    let h = 2.0 / (9.0 * df as f64);
    let z = ((chi_squared / df as f64).powf(1.0 / 3.0) - (1.0 - h)) / h.sqrt();

    // Upper-tail probability 1 - Phi(z)
    1.0 - standard_normal_cdf(z)
}

/// Approximate standard normal CDF via the Abramowitz & Stegun erf
/// approximation (formula 7.1.26, absolute error below ~1.5e-7).
/// For more accuracy, consider using a proper statistical library.
fn standard_normal_cdf(z: f64) -> f64 {
    let x = z / std::f64::consts::SQRT_2;
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();

    let t = 1.0 / (1.0 + 0.3275911 * x);
    let poly = t
        * (0.254829592
            + t * (-0.284496736 + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    let erf = 1.0 - poly * (-x * x).exp();

    0.5 * (1.0 + sign * erf)
}
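
// Illustrative sketch (not from the original module): the classic 0.05
// critical value for chi-squared with 1 df is about 3.84; the Wilson-Hilferty
// approximation should land near 0.05 there and shrink as the statistic grows.
#[cfg(test)]
mod p_value_sketch {
    use super::*;

    #[test]
    fn p_value_matches_the_critical_value_and_is_monotone() {
        let p = chi_squared_p_value(3.84, 1);
        assert!((p - 0.05).abs() < 0.02);
        assert!(chi_squared_p_value(10.0, 1) < p);
        assert!(chi_squared_p_value(0.0, 1) == 1.0);
    }
}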

/// Conditional inference tree split information
#[derive(Debug, Clone)]
pub struct ConditionalInferenceSplit {
    /// Feature index that was selected for splitting
    pub feature_idx: usize,
    /// Split value for continuous features
    pub split_value: Option<f64>,
    /// Categories for the left branch (for categorical features)
    pub left_categories: Option<Vec<String>>,
    /// Test statistic value
    pub test_statistic: f64,
    /// P-value of the statistical test
    pub p_value: f64,
    /// Type of test performed
    pub test_type: ConditionalTestType,
    /// Significance level used
    pub significance_level: f64,
}

impl ConditionalInferenceSplit {
    /// Perform conditional inference splitting analysis
    pub fn analyze_conditional_split(
        x: &Array2<f64>,
        y: &Array1<f64>,
        _feature_types: &[FeatureType],
        significance_level: f64,
        test_type: ConditionalTestType,
    ) -> Result<Option<Self>> {
        if x.nrows() != y.len() {
            return Err(SklearsError::InvalidInput(
                "Feature and target arrays must have the same length".to_string(),
            ));
        }

        if x.nrows() < 4 {
            return Ok(None); // Need at least 4 samples for meaningful statistics
        }

        let n_features = x.ncols();
        let mut best_split: Option<ConditionalInferenceSplit> = None;
        let mut best_p_value = 1.0;

        // Test each feature for association with the target
        for feature_idx in 0..n_features {
            let feature_values = x.column(feature_idx);

            let (test_statistic, p_value) = match test_type {
                ConditionalTestType::QuadraticForm => {
                    compute_quadratic_form_test(&feature_values, y)?
                }
                ConditionalTestType::MaxType => compute_maxtype_test(&feature_values, y)?,
                ConditionalTestType::MonteCarlo { n_permutations } => {
                    compute_monte_carlo_test(&feature_values, y, n_permutations)?
                }
                ConditionalTestType::AsymptoticChiSquared => {
                    compute_asymptotic_chi_squared_test(&feature_values, y)?
                }
            };

            // Keep the most significant association seen so far
            if p_value < significance_level && p_value < best_p_value {
                // Find the best split point for this feature
                let split_value = find_best_split_point(&feature_values, y)?;

                best_split = Some(ConditionalInferenceSplit {
                    feature_idx,
                    split_value: Some(split_value),
                    left_categories: None,
                    test_statistic,
                    p_value,
                    test_type,
                    significance_level,
                });
                best_p_value = p_value;
            }
        }

        Ok(best_split)
    }

    /// Check if the split is statistically significant
    pub fn is_significant(&self) -> bool {
        self.p_value < self.significance_level
    }
}
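
// Illustrative sketch (not from the original module): with a single feature
// that is strongly (but not perfectly) correlated with the target, the
// quadratic form test should select it at the 0.05 level.
#[cfg(test)]
mod conditional_inference_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn selects_a_strongly_associated_feature() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]];
        let y = array![1.1, 1.9, 3.2, 3.8, 5.1, 6.0, 6.9, 8.1];

        let split = ConditionalInferenceSplit::analyze_conditional_split(
            &x,
            &y,
            &[],
            0.05,
            ConditionalTestType::QuadraticForm,
        )
        .unwrap()
        .expect("a near-linear relationship should be significant");
        assert_eq!(split.feature_idx, 0);
        assert!(split.is_significant());
    }
}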

/// Compute quadratic form test statistic for continuous features
fn compute_quadratic_form_test(
    feature_values: &ArrayView1<f64>,
    target_values: &Array1<f64>,
) -> Result<(f64, f64)> {
    let n = feature_values.len();
    if n < 4 {
        return Ok((0.0, 1.0));
    }

    // Compute the Pearson correlation coefficient
    let feature_mean = feature_values.mean().unwrap_or(0.0);
    let target_mean = target_values.mean().unwrap_or(0.0);

    let mut numerator = 0.0;
    let mut feature_var = 0.0;
    let mut target_var = 0.0;

    for i in 0..n {
        let feature_diff = feature_values[i] - feature_mean;
        let target_diff = target_values[i] - target_mean;

        numerator += feature_diff * target_diff;
        feature_var += feature_diff * feature_diff;
        target_var += target_diff * target_diff;
    }

    if feature_var == 0.0 || target_var == 0.0 {
        return Ok((0.0, 1.0));
    }

    let correlation = numerator / (feature_var * target_var).sqrt();
    let r_squared = correlation * correlation;

    // Guard against (near-)perfect correlation, where the statistic diverges
    if 1.0 - r_squared < 1e-12 {
        return Ok((f64::INFINITY, 0.0));
    }

    // Transform to the F-type statistic t^2 = r^2 (n - 2) / (1 - r^2)
    let test_statistic = r_squared * (n - 2) as f64 / (1.0 - r_squared);

    // Two-sided p-value via the t-distribution approximation with n - 2 df
    let p_value = 2.0 * (1.0 - student_t_cdf(test_statistic.sqrt(), n - 2));

    Ok((test_statistic, p_value))
}

/// Compute maxtype test statistic for categorical features
fn compute_maxtype_test(
    feature_values: &ArrayView1<f64>,
    target_values: &Array1<f64>,
) -> Result<(f64, f64)> {
    // For simplicity, treat the feature as continuous and reuse the quadratic
    // form test; a full implementation would handle true categorical data.
    compute_quadratic_form_test(feature_values, target_values)
}

/// Compute Monte Carlo permutation test
fn compute_monte_carlo_test(
    feature_values: &ArrayView1<f64>,
    target_values: &Array1<f64>,
    n_permutations: usize,
) -> Result<(f64, f64)> {
    // Compute the test statistic on the original (unpermuted) data
    let (original_statistic, _) = compute_quadratic_form_test(feature_values, target_values)?;

    // Perform permutations
    let mut rng = thread_rng();
    let mut permuted_target = target_values.clone();
    let mut extreme_count = 0;

    for _ in 0..n_permutations {
        // Shuffle target values with the Fisher-Yates algorithm
        let target_slice = permuted_target.as_slice_mut().unwrap();
        for i in (1..target_slice.len()).rev() {
            let j = rng.gen_range(0..=i);
            target_slice.swap(i, j);
        }

        // Compute the test statistic for the permuted data
        let (permuted_statistic, _) =
            compute_quadratic_form_test(feature_values, &permuted_target)?;

        if permuted_statistic >= original_statistic {
            extreme_count += 1;
        }
    }

    // Add-one smoothing keeps the p-value strictly positive
    let p_value = (extreme_count + 1) as f64 / (n_permutations + 1) as f64;

    Ok((original_statistic, p_value))
}
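
// Illustrative sketch (not from the original module): a permutation p-value is
// always in (0, 1] thanks to the add-one smoothing above.
#[cfg(test)]
mod monte_carlo_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn permutation_p_value_is_a_valid_probability() {
        let x = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let y = array![1.2, 1.9, 3.1, 4.2, 4.8, 6.1];
        let (_statistic, p) = compute_monte_carlo_test(&x.view(), &y, 99).unwrap();
        assert!(p > 0.0 && p <= 1.0);
    }
}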

/// Compute asymptotic chi-squared test
fn compute_asymptotic_chi_squared_test(
    feature_values: &ArrayView1<f64>,
    target_values: &Array1<f64>,
) -> Result<(f64, f64)> {
    // Use the quadratic form statistic with a chi-squared approximation
    let (test_statistic, _) = compute_quadratic_form_test(feature_values, target_values)?;

    // Degrees of freedom = 1 for a single-feature test
    let df = 1;
    let p_value = chi_squared_p_value(test_statistic, df);

    Ok((test_statistic, p_value))
}

/// Find the best split point for a feature using conditional inference
fn find_best_split_point(
    feature_values: &ArrayView1<f64>,
    target_values: &Array1<f64>,
) -> Result<f64> {
    if feature_values.is_empty() {
        return Err(SklearsError::InvalidInput(
            "Empty feature values".to_string(),
        ));
    }

    // Find unique sorted values
    let mut values: Vec<f64> = feature_values.to_vec();
    values.sort_by(|a, b| a.partial_cmp(b).unwrap());
    values.dedup();

    if values.len() < 2 {
        return Ok(values[0]);
    }

    let mut best_split = values[0];
    let mut best_statistic = 0.0;

    // Try the midpoint between each pair of consecutive unique values
    for i in 0..(values.len() - 1) {
        let split_candidate = (values[i] + values[i + 1]) / 2.0;

        // Partition the targets at this candidate
        let mut left_targets = Vec::new();
        let mut right_targets = Vec::new();

        for (j, &feature_val) in feature_values.iter().enumerate() {
            if feature_val <= split_candidate {
                left_targets.push(target_values[j]);
            } else {
                right_targets.push(target_values[j]);
            }
        }

        if left_targets.is_empty() || right_targets.is_empty() {
            continue;
        }

        // Separation statistic (simplified): absolute difference of child means
        let left_mean = left_targets.iter().sum::<f64>() / left_targets.len() as f64;
        let right_mean = right_targets.iter().sum::<f64>() / right_targets.len() as f64;
        let separation = (left_mean - right_mean).abs();

        if separation > best_statistic {
            best_statistic = separation;
            best_split = split_candidate;
        }
    }

    Ok(best_split)
}
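
// Illustrative sketch (not from the original module): with two well-separated
// value clusters, the chosen split point is the midpoint between them.
#[cfg(test)]
mod split_point_sketch {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn picks_the_midpoint_between_the_two_clusters() {
        let feature = array![1.0, 1.0, 1.0, 5.0, 5.0, 5.0];
        let target = array![0.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let split = find_best_split_point(&feature.view(), &target).unwrap();
        assert!((split - 3.0).abs() < 1e-12);
    }
}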

/// Approximate Student's t-distribution CDF
fn student_t_cdf(t: f64, df: usize) -> f64 {
    if df == 0 {
        return 0.5;
    }

    // Normal approximation to the t CDF (Abramowitz & Stegun 26.7.8):
    // t (1 - 1/(4 df)) / sqrt(1 + t^2 / (2 df)) is approximately standard
    // normal. For production use, consider a proper statistical library.
    let d = df as f64;
    let z = t * (1.0 - 1.0 / (4.0 * d)) / (1.0 + t * t / (2.0 * d)).sqrt();
    standard_normal_cdf(z)
}
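
// Illustrative sketch (not from the original module): basic shape checks for
// the t CDF approximation: symmetric around 0 and approaching 0 and 1 in the tails.
#[cfg(test)]
mod t_cdf_sketch {
    use super::*;

    #[test]
    fn t_cdf_has_the_expected_shape() {
        assert!((student_t_cdf(0.0, 10) - 0.5).abs() < 1e-6);
        assert!(student_t_cdf(100.0, 10) > 0.99);
        assert!(student_t_cdf(-100.0, 10) < 0.01);
    }
}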