1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::random::thread_rng;
8use sklears_core::error::{Result, SklearsError};
9use std::collections::HashMap;
10
11use crate::criteria::{ConditionalTestType, FeatureType};
12
13#[derive(Debug, Clone)]
15pub struct HyperplaneSplit {
16 pub coefficients: Array1<f64>,
18 pub threshold: f64,
20 pub bias: f64,
22 pub impurity_decrease: f64,
24}
25
26impl HyperplaneSplit {
27 pub fn evaluate(&self, sample: &Array1<f64>) -> bool {
29 let dot_product = self.coefficients.dot(sample) + self.bias;
30 dot_product >= self.threshold
31 }
32
33 pub fn random(n_features: usize, rng: &mut scirs2_core::CoreRandom) -> Self {
35 let mut coefficients = Array1::zeros(n_features);
36 for i in 0..n_features {
37 coefficients[i] = rng.gen_range(-1.0..1.0);
38 }
39
40 let dot_product: f64 = coefficients.dot(&coefficients);
42 let norm = dot_product.sqrt();
43 if norm > 1e-10_f64 {
44 coefficients /= norm;
45 }
46
47 Self {
48 coefficients,
49 threshold: rng.gen_range(-1.0..1.0),
50 bias: rng.gen_range(-0.1..0.1),
51 impurity_decrease: 0.0,
52 }
53 }
54
55 #[cfg(feature = "oblique")]
57 pub fn from_ridge_regression(x: &Array2<f64>, y: &Array1<f64>, alpha: f64) -> Result<Self> {
58 let n_features = x.ncols();
59 if x.nrows() < 2 {
60 return Err(SklearsError::InvalidInput(
61 "Need at least 2 samples for ridge regression".to_string(),
62 ));
63 }
64
65 let mut x_bias = Array2::ones((x.nrows(), n_features + 1));
67 x_bias.slice_mut(s![.., ..n_features]).assign(x);
68
69 let xtx = x_bias.t().dot(&x_bias);
71 let ridge_matrix = xtx + Array2::<f64>::eye(n_features + 1) * alpha;
72 let xty = x_bias.t().dot(y);
73
74 match gauss_jordan_inverse(&ridge_matrix) {
76 Ok(inv_matrix) => {
77 let coefficients_full = inv_matrix.dot(&xty);
78
79 let coefficients = coefficients_full.slice(s![..n_features]).to_owned();
80 let bias = coefficients_full[n_features];
81
82 Ok(Self {
83 coefficients,
84 threshold: 0.0, bias,
86 impurity_decrease: 0.0,
87 })
88 }
89 Err(_) => {
90 let mut rng = thread_rng();
92 Ok(Self::random(n_features, &mut rng))
93 }
94 }
95 }
96}
97
/// Inverts a square matrix by Gauss–Jordan elimination with partial (row)
/// pivoting on the augmented system `[matrix | I]`.
///
/// # Errors
///
/// Returns `SklearsError::InvalidInput` when `matrix` is not square, or when
/// the largest available pivot in some column has magnitude below `1e-12`
/// (the matrix is treated as singular).
#[cfg(feature = "oblique")]
fn gauss_jordan_inverse(matrix: &Array2<f64>) -> Result<Array2<f64>> {
    let dim = matrix.nrows();
    if matrix.ncols() != dim {
        return Err(SklearsError::InvalidInput(
            "Matrix must be square".to_string(),
        ));
    }
    let width = 2 * dim;

    // Left half holds a copy of `matrix`, right half the identity.
    let mut work: Array2<f64> = Array2::zeros((dim, width));
    for row in 0..dim {
        for col in 0..dim {
            work[[row, col]] = matrix[[row, col]];
        }
        work[[row, dim + row]] = 1.0;
    }

    for col in 0..dim {
        // Pick the row at or below the diagonal with the largest entry in
        // this column.
        let mut pivot_row = col;
        let mut pivot_magnitude = work[[col, col]].abs();
        for candidate in (col + 1)..dim {
            let magnitude = work[[candidate, col]].abs();
            if magnitude > pivot_magnitude {
                pivot_row = candidate;
                pivot_magnitude = magnitude;
            }
        }

        if pivot_magnitude < 1e-12 {
            return Err(SklearsError::InvalidInput("Matrix is singular".to_string()));
        }

        // Bring the pivot row onto the diagonal.
        if pivot_row != col {
            for k in 0..width {
                let (upper, lower) = (work[[col, k]], work[[pivot_row, k]]);
                work[[col, k]] = lower;
                work[[pivot_row, k]] = upper;
            }
        }

        // Scale the pivot row so the diagonal entry becomes 1.
        let pivot = work[[col, col]];
        for k in 0..width {
            work[[col, k]] /= pivot;
        }

        // Clear this column in every other row.
        for row in 0..dim {
            if row == col {
                continue;
            }
            let factor = work[[row, col]];
            for k in 0..width {
                work[[row, k]] -= factor * work[[col, k]];
            }
        }
    }

    // The right half of the augmented matrix now holds the inverse.
    let mut inverse: Array2<f64> = Array2::zeros((dim, dim));
    for row in 0..dim {
        for col in 0..dim {
            inverse[[row, col]] = work[[row, dim + col]];
        }
    }

    Ok(inverse)
}
171
/// Outcome of a CHAID-style analysis of one categorical feature.
#[derive(Clone, Debug)]
pub struct ChaidSplit {
    /// Index of the analysed feature. `analyze_categorical_split` always
    /// stores `0` here, since it is not told which column it was given.
    pub feature_idx: usize,
    /// Categories after merging; each inner vector is one branch of the split.
    pub category_groups: Vec<Vec<String>>,
    /// Chi-squared statistic reported for the split.
    pub chi_squared: f64,
    /// P-value associated with `chi_squared`.
    pub p_value: f64,
    /// Degrees of freedom used for the chi-squared test.
    pub degrees_of_freedom: usize,
    /// Significance level the analysis was run with.
    pub significance_level: f64,
}
188
189impl ChaidSplit {
190 pub fn analyze_categorical_split(
192 feature_values: &[String],
193 target_values: &[i32],
194 significance_level: f64,
195 ) -> Result<Option<Self>> {
196 if feature_values.len() != target_values.len() {
197 return Err(SklearsError::InvalidInput(
198 "Feature and target arrays must have the same length".to_string(),
199 ));
200 }
201
202 if feature_values.is_empty() {
203 return Ok(None);
204 }
205
206 let contingency_table = build_contingency_table(feature_values, target_values)?;
208
209 let merged_categories = merge_categories_chaid(&contingency_table, significance_level)?;
211
212 if merged_categories.len() <= 1 {
213 return Ok(None); }
215
216 let (chi_squared, p_value, df) = calculate_chi_squared(&contingency_table)?;
218
219 Ok(Some(ChaidSplit {
220 feature_idx: 0, category_groups: merged_categories,
222 chi_squared,
223 p_value,
224 degrees_of_freedom: df,
225 significance_level,
226 }))
227 }
228
229 pub fn is_significant(&self) -> bool {
231 self.p_value < self.significance_level
232 }
233}
234
235fn build_contingency_table(
237 feature_values: &[String],
238 target_values: &[i32],
239) -> Result<HashMap<String, HashMap<i32, usize>>> {
240 let mut table: HashMap<String, HashMap<i32, usize>> = HashMap::new();
241
242 for (feature_val, target_val) in feature_values.iter().zip(target_values.iter()) {
243 let target_counts = table.entry(feature_val.clone()).or_default();
244 *target_counts.entry(*target_val).or_insert(0) += 1;
245 }
246
247 Ok(table)
248}
249
250fn merge_categories_chaid(
252 contingency_table: &HashMap<String, HashMap<i32, usize>>,
253 significance_level: f64,
254) -> Result<Vec<Vec<String>>> {
255 let categories: Vec<String> = contingency_table.keys().cloned().collect();
256 let mut groups: Vec<Vec<String>> = categories.iter().map(|c| vec![c.clone()]).collect();
257
258 if groups.len() <= 1 {
259 return Ok(groups);
260 }
261
262 loop {
263 let mut best_merge: Option<(usize, usize, f64)> = None;
264 let mut min_chi_squared = f64::INFINITY;
265
266 for i in 0..groups.len() {
268 for j in (i + 1)..groups.len() {
269 let merged_table =
271 create_merged_contingency_table(contingency_table, &groups[i], &groups[j])?;
272
273 if let Ok((chi_squared, p_value, _)) =
274 calculate_chi_squared_for_merged(&merged_table)
275 {
276 if p_value > significance_level && chi_squared < min_chi_squared {
278 min_chi_squared = chi_squared;
279 best_merge = Some((i, j, chi_squared));
280 }
281 }
282 }
283 }
284
285 if let Some((i, j, _)) = best_merge {
287 let mut merged_group = groups[i].clone();
289 merged_group.extend(groups[j].clone());
290
291 if i < j {
293 groups.remove(j);
294 groups.remove(i);
295 } else {
296 groups.remove(i);
297 groups.remove(j);
298 }
299 groups.push(merged_group);
300 } else {
301 break;
302 }
303
304 if groups.len() <= 1 {
305 break;
306 }
307 }
308
309 Ok(groups)
310}
311
312fn create_merged_contingency_table(
314 original_table: &HashMap<String, HashMap<i32, usize>>,
315 group1: &[String],
316 group2: &[String],
317) -> Result<HashMap<i32, usize>> {
318 let mut merged_table = HashMap::new();
319
320 for category in group1 {
322 if let Some(target_counts) = original_table.get(category) {
323 for (&target, &count) in target_counts {
324 *merged_table.entry(target).or_insert(0) += count;
325 }
326 }
327 }
328
329 for category in group2 {
331 if let Some(target_counts) = original_table.get(category) {
332 for (&target, &count) in target_counts {
333 *merged_table.entry(target).or_insert(0) += count;
334 }
335 }
336 }
337
338 Ok(merged_table)
339}
340
341fn calculate_chi_squared(
343 contingency_table: &HashMap<String, HashMap<i32, usize>>,
344) -> Result<(f64, f64, usize)> {
345 use std::collections::HashSet;
346
347 let mut all_targets: HashSet<i32> = HashSet::new();
349 for target_counts in contingency_table.values() {
350 all_targets.extend(target_counts.keys());
351 }
352
353 if all_targets.len() <= 1 {
354 return Ok((0.0, 1.0, 0));
355 }
356
357 let categories: Vec<&String> = contingency_table.keys().collect();
358 let targets: Vec<i32> = all_targets.into_iter().collect();
359
360 if categories.len() <= 1 {
361 return Ok((0.0, 1.0, 0));
362 }
363
364 let mut row_totals: HashMap<&String, usize> = HashMap::new();
366 let mut col_totals: HashMap<i32, usize> = HashMap::new();
367 let mut grand_total = 0;
368
369 for category in &categories {
370 let mut row_total = 0;
371 if let Some(target_counts) = contingency_table.get(*category) {
372 for (&target, &count) in target_counts {
373 row_total += count;
374 *col_totals.entry(target).or_insert(0) += count;
375 grand_total += count;
376 }
377 }
378 row_totals.insert(category, row_total);
379 }
380
381 if grand_total == 0 {
382 return Ok((0.0, 1.0, 0));
383 }
384
385 let mut chi_squared = 0.0;
387 for category in &categories {
388 for &target in &targets {
389 let observed = contingency_table
390 .get(*category)
391 .and_then(|counts| counts.get(&target))
392 .unwrap_or(&0);
393
394 let expected = (*row_totals.get(category).unwrap_or(&0) as f64)
395 * (*col_totals.get(&target).unwrap_or(&0) as f64)
396 / (grand_total as f64);
397
398 if expected > 0.0 {
399 let diff = (*observed as f64) - expected;
400 chi_squared += (diff * diff) / expected;
401 }
402 }
403 }
404
405 let degrees_of_freedom = (categories.len() - 1) * (targets.len() - 1);
406 let p_value = chi_squared_p_value(chi_squared, degrees_of_freedom);
407
408 Ok((chi_squared, p_value, degrees_of_freedom))
409}
410
411fn calculate_chi_squared_for_merged(
413 merged_table: &HashMap<i32, usize>,
414) -> Result<(f64, f64, usize)> {
415 if merged_table.len() <= 1 {
416 return Ok((0.0, 1.0, 0));
417 }
418
419 let total: usize = merged_table.values().sum();
420 if total == 0 {
421 return Ok((0.0, 1.0, 0));
422 }
423
424 let expected = total as f64 / merged_table.len() as f64;
426 let mut chi_squared = 0.0;
427
428 for &observed in merged_table.values() {
429 let diff = (observed as f64) - expected;
430 chi_squared += (diff * diff) / expected;
431 }
432
433 let degrees_of_freedom = merged_table.len() - 1;
434 let p_value = chi_squared_p_value(chi_squared, degrees_of_freedom);
435
436 Ok((chi_squared, p_value, degrees_of_freedom))
437}
438
/// Upper-tail probability `P(X >= chi_squared)` for a chi-squared
/// distribution with `df` degrees of freedom.
///
/// Computed as the regularised upper incomplete gamma function
/// `Q(df / 2, chi_squared / 2)`: a power series for the lower tail when
/// `x < a + 1`, and a continued fraction for the upper tail otherwise.
///
/// Returns `1.0` when `df == 0`, when `chi_squared <= 0.0`, or when
/// `chi_squared` is NaN (no evidence against the null hypothesis), and `0.0`
/// for `+∞`.
fn chi_squared_p_value(chi_squared: f64, df: usize) -> f64 {
    /// `ln Γ(df / 2)`, built from `Γ(1) = 1` or `Γ(1/2) = √π` with the
    /// recurrence `Γ(a + 1) = a · Γ(a)`.
    fn ln_gamma_of_half(df: usize) -> f64 {
        let target = df as f64 / 2.0;
        let (mut a, mut acc) = if df % 2 == 0 {
            (1.0_f64, 0.0_f64)
        } else {
            (0.5_f64, 0.5 * std::f64::consts::PI.ln())
        };
        // `a` and `target` differ by whole steps; the 0.25 slack only guards
        // against floating-point drift in the comparison.
        while a + 0.25 < target {
            acc += a.ln();
            a += 1.0;
        }
        acc
    }

    const EPSILON: f64 = 1e-15;
    const TINY: f64 = 1e-300;
    const MAX_ITERATIONS: usize = 10_000;

    if df == 0 || chi_squared.is_nan() || chi_squared <= 0.0 {
        return 1.0;
    }
    if chi_squared.is_infinite() {
        return 0.0;
    }

    let a = df as f64 / 2.0;
    let x = chi_squared / 2.0;
    // ln( x^a · e^(-x) / Γ(a) ), the factor shared by both expansions.
    let ln_prefactor = a * x.ln() - x - ln_gamma_of_half(df);

    let upper_tail = if x < a + 1.0 {
        // Series for the lower tail P(a, x); Q = 1 - P.
        let mut denominator = a;
        let mut term = 1.0 / a;
        let mut sum = term;
        for _ in 0..MAX_ITERATIONS {
            denominator += 1.0;
            term *= x / denominator;
            sum += term;
            if term.abs() < sum.abs() * EPSILON {
                break;
            }
        }
        1.0 - sum * ln_prefactor.exp()
    } else {
        // Modified Lentz evaluation of the continued fraction for Q(a, x).
        let mut b = x + 1.0 - a;
        let mut c = 1.0 / TINY;
        let mut d = 1.0 / b;
        let mut h = d;
        for step in 1..=MAX_ITERATIONS {
            let i = step as f64;
            let an = -i * (i - a);
            b += 2.0;
            d = an * d + b;
            if d.abs() < TINY {
                d = TINY;
            }
            c = b + an / c;
            if c.abs() < TINY {
                c = TINY;
            }
            d = 1.0 / d;
            let delta = d * c;
            h *= delta;
            if (delta - 1.0).abs() < EPSILON {
                break;
            }
        }
        h * ln_prefactor.exp()
    };

    // Rounding can push the result marginally outside [0, 1].
    upper_tail.clamp(0.0, 1.0)
}
457
458#[derive(Debug, Clone)]
460pub struct ConditionalInferenceSplit {
461 pub feature_idx: usize,
463 pub split_value: Option<f64>,
465 pub left_categories: Option<Vec<String>>,
467 pub test_statistic: f64,
469 pub p_value: f64,
471 pub test_type: ConditionalTestType,
473 pub significance_level: f64,
475}
476
477impl ConditionalInferenceSplit {
478 pub fn analyze_conditional_split(
480 x: &Array2<f64>,
481 y: &Array1<f64>,
482 _feature_types: &[FeatureType],
483 significance_level: f64,
484 test_type: ConditionalTestType,
485 ) -> Result<Option<Self>> {
486 if x.nrows() != y.len() {
487 return Err(SklearsError::InvalidInput(
488 "Feature and target arrays must have the same length".to_string(),
489 ));
490 }
491
492 if x.nrows() < 4 {
493 return Ok(None); }
495
496 let n_features = x.ncols();
497 let mut best_split: Option<ConditionalInferenceSplit> = None;
498 let mut best_p_value = 1.0;
499
500 for feature_idx in 0..n_features {
502 let feature_values = x.column(feature_idx);
503
504 let (test_statistic, p_value) = match test_type {
505 ConditionalTestType::QuadraticForm => {
506 compute_quadratic_form_test(&feature_values, y)?
507 }
508 ConditionalTestType::MaxType => compute_maxtype_test(&feature_values, y)?,
509 ConditionalTestType::MonteCarlo { n_permutations } => {
510 compute_monte_carlo_test(&feature_values, y, n_permutations)?
511 }
512 ConditionalTestType::AsymptoticChiSquared => {
513 compute_asymptotic_chi_squared_test(&feature_values, y)?
514 }
515 };
516
517 if p_value < significance_level && p_value < best_p_value {
519 let split_value = find_best_split_point(&feature_values, y)?;
521
522 best_split = Some(ConditionalInferenceSplit {
523 feature_idx,
524 split_value: Some(split_value),
525 left_categories: None,
526 test_statistic,
527 p_value,
528 test_type,
529 significance_level,
530 });
531 best_p_value = p_value;
532 }
533 }
534
535 Ok(best_split)
536 }
537
538 pub fn is_significant(&self) -> bool {
540 self.p_value < self.significance_level
541 }
542}
543
544fn compute_quadratic_form_test(
546 feature_values: &ArrayView1<f64>,
547 target_values: &Array1<f64>,
548) -> Result<(f64, f64)> {
549 let n = feature_values.len();
550 if n < 4 {
551 return Ok((0.0, 1.0));
552 }
553
554 let feature_mean = feature_values.mean().unwrap_or(0.0);
556 let target_mean = target_values.mean().unwrap_or(0.0);
557
558 let mut numerator = 0.0;
559 let mut feature_var = 0.0;
560 let mut target_var = 0.0;
561
562 for i in 0..n {
563 let feature_diff = feature_values[i] - feature_mean;
564 let target_diff = target_values[i] - target_mean;
565
566 numerator += feature_diff * target_diff;
567 feature_var += feature_diff * feature_diff;
568 target_var += target_diff * target_diff;
569 }
570
571 if feature_var == 0.0 || target_var == 0.0 {
572 return Ok((0.0, 1.0));
573 }
574
575 let correlation = numerator / (feature_var * target_var).sqrt();
576
577 let test_statistic =
579 correlation * correlation * (n - 2) as f64 / (1.0 - correlation * correlation);
580
581 let p_value = 2.0 * (1.0 - student_t_cdf(test_statistic.sqrt(), n - 2));
583
584 Ok((test_statistic, p_value))
585}
586
587fn compute_maxtype_test(
589 feature_values: &ArrayView1<f64>,
590 target_values: &Array1<f64>,
591) -> Result<(f64, f64)> {
592 compute_quadratic_form_test(feature_values, target_values)
595}
596
597fn compute_monte_carlo_test(
599 feature_values: &ArrayView1<f64>,
600 target_values: &Array1<f64>,
601 n_permutations: usize,
602) -> Result<(f64, f64)> {
603 let (original_statistic, _) = compute_quadratic_form_test(feature_values, target_values)?;
605
606 let mut rng = thread_rng();
608 let mut permuted_target = target_values.clone();
609 let mut extreme_count = 0;
610
611 for _ in 0..n_permutations {
612 let target_slice = permuted_target.as_slice_mut().unwrap();
614 for i in (1..target_slice.len()).rev() {
615 let j = rng.gen_range(0..=i);
616 target_slice.swap(i, j);
617 }
618
619 let (permuted_statistic, _) =
621 compute_quadratic_form_test(feature_values, &permuted_target)?;
622
623 if permuted_statistic >= original_statistic {
624 extreme_count += 1;
625 }
626 }
627
628 let p_value = (extreme_count + 1) as f64 / (n_permutations + 1) as f64;
629
630 Ok((original_statistic, p_value))
631}
632
633fn compute_asymptotic_chi_squared_test(
635 feature_values: &ArrayView1<f64>,
636 target_values: &Array1<f64>,
637) -> Result<(f64, f64)> {
638 let (test_statistic, _) = compute_quadratic_form_test(feature_values, target_values)?;
640
641 let df = 1;
643 let p_value = chi_squared_p_value(test_statistic, df);
644
645 Ok((test_statistic, p_value))
646}
647
648fn find_best_split_point(
650 feature_values: &ArrayView1<f64>,
651 target_values: &Array1<f64>,
652) -> Result<f64> {
653 if feature_values.is_empty() {
654 return Err(SklearsError::InvalidInput(
655 "Empty feature values".to_string(),
656 ));
657 }
658
659 let mut values: Vec<f64> = feature_values.to_vec();
661 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
662 values.dedup();
663
664 if values.len() < 2 {
665 return Ok(values[0]);
666 }
667
668 let mut best_split = values[0];
669 let mut best_statistic = 0.0;
670
671 for i in 0..(values.len() - 1) {
673 let split_candidate = (values[i] + values[i + 1]) / 2.0;
674
675 let mut left_targets = Vec::new();
677 let mut right_targets = Vec::new();
678
679 for (j, &feature_val) in feature_values.iter().enumerate() {
680 if feature_val <= split_candidate {
681 left_targets.push(target_values[j]);
682 } else {
683 right_targets.push(target_values[j]);
684 }
685 }
686
687 if left_targets.is_empty() || right_targets.is_empty() {
688 continue;
689 }
690
691 let left_mean = left_targets.iter().sum::<f64>() / left_targets.len() as f64;
693 let right_mean = right_targets.iter().sum::<f64>() / right_targets.len() as f64;
694 let separation = (left_mean - right_mean).abs();
695
696 if separation > best_statistic {
697 best_statistic = separation;
698 best_split = split_candidate;
699 }
700 }
701
702 Ok(best_split)
703}
704
/// Cumulative distribution function of Student's t-distribution with `df`
/// (integer) degrees of freedom, evaluated at `t`.
///
/// Uses the exact finite trigonometric series in `θ = atan(t / √df)`:
/// - even `df`: `A = sin θ · Σ_k c_k cos^{2k} θ`;
/// - odd `df`: `A = (2/π) · (θ + sin θ · Σ_k d_k cos^{2k+1} θ)`, which is
///   just `2θ/π` for `df == 1`.
///
/// `A` is the probability that `|T| <= |t|`, carrying the sign of `t`, and
/// the result is `0.5 · (1 + A)`. The cost is linear in `df`.
///
/// Returns `0.5` when `df == 0`.
fn student_t_cdf(t: f64, df: usize) -> f64 {
    if df == 0 {
        return 0.5;
    }

    // Terms are non-negative and eventually decreasing; once one is below
    // this fraction of the running sum the rest cannot change the result.
    const NEGLIGIBLE: f64 = 1e-17;

    let theta = (t / (df as f64).sqrt()).atan();
    let (sin_theta, cos_theta) = theta.sin_cos();
    let cos_sq = cos_theta * cos_theta;

    let signed_central_mass = if df % 2 == 0 {
        // 1 + (1/2)cos² + (1·3)/(2·4)cos⁴ + … up to cos^(df-2).
        let mut term = 1.0_f64;
        let mut sum = 1.0_f64;
        for k in 1..=((df - 2) / 2) {
            let k = k as f64;
            term *= (2.0 * k - 1.0) / (2.0 * k) * cos_sq;
            sum += term;
            if term < sum * NEGLIGIBLE {
                break;
            }
        }
        sin_theta * sum
    } else {
        // cos + (2/3)cos³ + (2·4)/(3·5)cos⁵ + … up to cos^(df-2);
        // the bracket is empty for df == 1.
        let mut sum = 0.0_f64;
        if df > 1 {
            let mut term = cos_theta;
            sum = term;
            for k in 1..=((df - 3) / 2) {
                let k = k as f64;
                term *= (2.0 * k) / (2.0 * k + 1.0) * cos_sq;
                sum += term;
                if term < sum * NEGLIGIBLE {
                    break;
                }
            }
        }
        std::f64::consts::FRAC_2_PI * (theta + sin_theta * sum)
    };

    // Rounding can push the result marginally outside [0, 1].
    (0.5 * (1.0 + signed_central_mass)).clamp(0.0, 1.0)
}