use crate::config::*;
use crate::node::*;
use crate::SplitCriterion;
use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};
use std::collections::BinaryHeap;

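/// Handles `NaN` entries in `x` according to `strategy`: `Skip` drops rows
/// that contain any missing value, while `Majority` and `Surrogate` currently
/// both impute with the per-column mean (`0.0` for all-`NaN` columns).
///
/// A minimal usage sketch (fence marked `ignore` because the crate paths and
/// re-exports are assumed, not verified):
///
/// ```ignore
/// use scirs2_core::ndarray::{array, Array1};
///
/// // One row with a missing feature value.
/// let x = array![[1.0, f64::NAN], [2.0, 3.0]];
/// let y = Array1::from_vec(vec![0_i32, 1]);
///
/// // `Skip` keeps only the fully observed second row.
/// let (x_clean, _y_clean) =
///     handle_missing_values(&x, &y, MissingValueStrategy::Skip).unwrap();
/// assert_eq!(x_clean.nrows(), 1);
/// ```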
pub fn handle_missing_values<T: Clone>(
    x: &Array2<f64>,
    y: &Array1<T>,
    strategy: MissingValueStrategy,
) -> Result<(Array2<f64>, Array1<T>)> {
    // Fast path: nothing to do when no feature value is NaN.
    if !x.iter().any(|v| v.is_nan()) {
        return Ok((x.clone(), y.clone()));
    }

    match strategy {
        MissingValueStrategy::Skip => {
            // Keep only rows without any missing value.
            let valid_indices: Vec<usize> = x
                .outer_iter()
                .enumerate()
                .filter(|(_, row)| !row.iter().any(|v| v.is_nan()))
                .map(|(row_idx, _)| row_idx)
                .collect();

            if valid_indices.is_empty() {
                return Err(SklearsError::InvalidData {
                    reason: "All samples contain missing values".to_string(),
                });
            }

            let n_valid = valid_indices.len();
            let n_features = x.ncols();
            let mut x_clean = Array2::zeros((n_valid, n_features));
            let mut y_clean = Vec::with_capacity(n_valid);
            for (new_idx, &orig_idx) in valid_indices.iter().enumerate() {
                x_clean.row_mut(new_idx).assign(&x.row(orig_idx));
                y_clean.push(y[orig_idx].clone());
            }
            Ok((x_clean, Array1::from_vec(y_clean)))
        }
        MissingValueStrategy::Majority => {
            // Impute each missing value with its column mean. (Despite the
            // name, this strategy uses the mean rather than the mode.)
            let mut x_imputed = x.clone();
            for col_idx in 0..x.ncols() {
                let column = x.column(col_idx);
                let mut sum = 0.0;
                let mut count = 0;
                for &value in column.iter() {
                    if !value.is_nan() {
                        sum += value;
                        count += 1;
                    }
                }
                if count > 0 {
                    let mean = sum / count as f64;
                    for row_idx in 0..x.nrows() {
                        if x_imputed[[row_idx, col_idx]].is_nan() {
                            x_imputed[[row_idx, col_idx]] = mean;
                        }
                    }
                } else {
                    // Entirely missing column: fall back to zeros.
                    for row_idx in 0..x.nrows() {
                        x_imputed[[row_idx, col_idx]] = 0.0;
                    }
                }
            }
            Ok((x_imputed, y.clone()))
        }
        MissingValueStrategy::Surrogate => {
            // Currently identical to `Majority`: missing values are imputed
            // with the per-column mean (zeros for all-`NaN` columns).
            let mut x_imputed = x.clone();
            for col_idx in 0..x.ncols() {
                let mut sum = 0.0;
                let mut count = 0;
                for row_idx in 0..x.nrows() {
                    let value = x[[row_idx, col_idx]];
                    if !value.is_nan() {
                        sum += value;
                        count += 1;
                    }
                }
                if count > 0 {
                    let mean = sum / count as f64;
                    for row_idx in 0..x.nrows() {
                        if x_imputed[[row_idx, col_idx]].is_nan() {
                            x_imputed[[row_idx, col_idx]] = mean;
                        }
                    }
                } else {
                    for row_idx in 0..x.nrows() {
                        x_imputed[[row_idx, col_idx]] = 0.0;
                    }
                }
            }
            Ok((x_imputed, y.clone()))
        }
    }
}

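/// Grows a classification tree best-first: candidate nodes are kept in a
/// priority queue and the node whose best split promises the largest impurity
/// decrease is expanded next.
///
/// A minimal usage sketch (fence marked `ignore`; it assumes `x: Array2<f64>`
/// and `y: Array1<i32>` are in scope and that `DecisionTreeConfig` implements
/// `Default`, which is not verified here):
///
/// ```ignore
/// let config = DecisionTreeConfig::default();
/// let n_classes = 2;
/// let mut builder = BestFirstTreeBuilder::new(&x, &y, &config, n_classes);
/// builder.build_tree(&x, &y, &config, n_classes).unwrap();
/// assert!(builder.n_leaves >= 1);
/// ```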
#[derive(Debug)]
pub struct BestFirstTreeBuilder {
    /// All nodes created so far, indexed by node id.
    pub nodes: Vec<TreeNode>,
    /// Frontier of splittable nodes, ordered by expected impurity decrease.
    pub node_queue: BinaryHeap<NodePriority>,
    /// Id that will be assigned to the next node.
    pub next_node_id: usize,
    /// Current number of leaves in the tree.
    pub n_leaves: usize,
}

impl BestFirstTreeBuilder {
    /// Creates a builder whose root node covers all samples.
    pub fn new(
        x: &Array2<f64>,
        y: &Array1<i32>,
        config: &DecisionTreeConfig,
        n_classes: usize,
    ) -> Self {
        let n_samples = x.nrows();
        let sample_indices: Vec<usize> = (0..n_samples).collect();

        let mut class_counts = vec![0; n_classes];
        for &sample_idx in &sample_indices {
            let class = y[sample_idx] as usize;
            if class < n_classes {
                class_counts[class] += 1;
            }
        }

        let impurity = gini_impurity(&class_counts, n_samples as i32);
        // Predict the majority class of the node.
        let prediction = class_counts
            .iter()
            .enumerate()
            .max_by_key(|(_, &count)| count)
            .map(|(class, _)| class as f64)
            .unwrap_or(0.0);

        let best_split = find_best_split_for_node(x, y, &sample_indices, config, n_classes);
        let potential_decrease = best_split
            .as_ref()
            .map(|s| s.impurity_decrease)
            .unwrap_or(0.0);

        let root_node = TreeNode {
            id: 0,
            depth: 0,
            sample_indices,
            impurity,
            prediction,
            potential_decrease,
            best_split,
            parent_id: None,
            // The root starts out as the only leaf; `split_node` clears this
            // flag once the node is expanded.
            is_leaf: true,
        };

        let mut node_queue = BinaryHeap::new();
        if potential_decrease > 0.0 {
            node_queue.push(NodePriority {
                node_id: 0,
                // Negated so the heap ordering defined on `NodePriority`
                // expands the most promising node first.
                priority: -potential_decrease,
            });
        }

        Self {
            nodes: vec![root_node],
            node_queue,
            next_node_id: 1,
            n_leaves: 1,
        }
    }

    /// Repeatedly pops the most promising node and splits it until the queue
    /// is exhausted or a stopping criterion (leaf/depth/sample limits) fires.
    pub fn build_tree(
        &mut self,
        x: &Array2<f64>,
        y: &Array1<i32>,
        config: &DecisionTreeConfig,
        n_classes: usize,
    ) -> Result<()> {
        let max_leaves = match config.growing_strategy {
            TreeGrowingStrategy::BestFirst { max_leaves } => max_leaves,
            _ => None,
        };

        while let Some(node_priority) = self.node_queue.pop() {
            let node_id = node_priority.node_id;

            if let Some(max_leaves) = max_leaves {
                if self.n_leaves >= max_leaves {
                    break;
                }
            }

            if let Some(max_depth) = config.max_depth {
                if self.nodes[node_id].depth >= max_depth {
                    continue;
                }
            }

            if self.nodes[node_id].sample_indices.len() < config.min_samples_split {
                continue;
            }

            // A failed split (e.g. undersized children) just leaves the node
            // as a leaf; it is not a fatal error.
            if self.split_node(node_id, x, y, config, n_classes).is_err() {
                continue;
            }
        }

        Ok(())
    }

    fn split_node(
        &mut self,
        node_id: usize,
        x: &Array2<f64>,
        y: &Array1<i32>,
        config: &DecisionTreeConfig,
        n_classes: usize,
    ) -> Result<()> {
        // Clone so we can mutate `self.nodes` while reading the parent.
        let node = self.nodes[node_id].clone();
        let best_split = match &node.best_split {
            Some(split) => split.clone(),
            None => {
                return Err(SklearsError::InvalidInput(
                    "No valid split found".to_string(),
                ))
            }
        };

        let (left_indices, right_indices) = split_samples_by_threshold(
            x,
            &node.sample_indices,
            best_split.feature_idx,
            best_split.threshold,
        );

        if left_indices.len() < config.min_samples_leaf
            || right_indices.len() < config.min_samples_leaf
        {
            return Err(SklearsError::InvalidInput(
                "Split would create undersized leaves".to_string(),
            ));
        }

        let left_node_id = self.next_node_id;
        self.next_node_id += 1;

        let left_node = self.create_child_node(
            left_node_id,
            node.id,
            node.depth + 1,
            left_indices,
            x,
            y,
            config,
            n_classes,
        );

        let right_node_id = self.next_node_id;
        self.next_node_id += 1;

        let right_node = self.create_child_node(
            right_node_id,
            node.id,
            node.depth + 1,
            right_indices,
            x,
            y,
            config,
            n_classes,
        );

        // Only queue children whose own best split is worthwhile.
        if left_node.potential_decrease > config.min_impurity_decrease {
            self.node_queue.push(NodePriority {
                node_id: left_node_id,
                priority: -left_node.potential_decrease,
            });
        }

        if right_node.potential_decrease > config.min_impurity_decrease {
            self.node_queue.push(NodePriority {
                node_id: right_node_id,
                priority: -right_node.potential_decrease,
            });
        }

        self.nodes.push(left_node);
        self.nodes.push(right_node);

        self.nodes[node_id].is_leaf = false;
        // The parent stops being a leaf and two new leaves appear: net +1.
        self.n_leaves += 1;

        Ok(())
    }

    /// Builds a child node, computing its impurity, majority-class prediction
    /// and the best split it could make in turn.
    #[allow(clippy::too_many_arguments)]
    fn create_child_node(
        &self,
        node_id: usize,
        parent_id: usize,
        depth: usize,
        sample_indices: Vec<usize>,
        x: &Array2<f64>,
        y: &Array1<i32>,
        config: &DecisionTreeConfig,
        n_classes: usize,
    ) -> TreeNode {
        let mut class_counts = vec![0; n_classes];
        for &sample_idx in &sample_indices {
            let class = y[sample_idx] as usize;
            if class < n_classes {
                class_counts[class] += 1;
            }
        }

        let impurity = gini_impurity(&class_counts, sample_indices.len() as i32);
        let prediction = class_counts
            .iter()
            .enumerate()
            .max_by_key(|(_, &count)| count)
            .map(|(class, _)| class as f64)
            .unwrap_or(0.0);

        let best_split = find_best_split_for_node(x, y, &sample_indices, config, n_classes);
        let potential_decrease = best_split
            .as_ref()
            .map(|s| s.impurity_decrease)
            .unwrap_or(0.0);

        TreeNode {
            id: node_id,
            depth,
            sample_indices,
            impurity,
            prediction,
            potential_decrease,
            best_split,
            parent_id: Some(parent_id),
            is_leaf: true,
        }
    }
}

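/// Finds the best split for the samples in `sample_indices`, dispatching on
/// the configured criterion: `Gini`/`Entropy` currently route to the twoing
/// search and `LogLoss` to the entropy-based search; other criteria yield
/// `None`.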
pub fn find_best_split_for_node(
    x: &Array2<f64>,
    y: &Array1<i32>,
    sample_indices: &[usize],
    config: &DecisionTreeConfig,
    n_classes: usize,
) -> Option<CustomSplit> {
    if sample_indices.len() < config.min_samples_split {
        return None;
    }

    let n_samples = sample_indices.len();
    let n_features = x.ncols();

    // Materialize the node's subset of the data for the split search.
    let mut node_x = Array2::zeros((n_samples, n_features));
    let mut node_y = Array1::zeros(n_samples);

    for (new_idx, &orig_idx) in sample_indices.iter().enumerate() {
        for j in 0..n_features {
            node_x[[new_idx, j]] = x[[orig_idx, j]];
        }
        node_y[new_idx] = y[orig_idx];
    }

    let feature_indices: Vec<usize> = (0..n_features).collect();

    match config.criterion {
        SplitCriterion::Gini | SplitCriterion::Entropy => {
            find_best_twoing_split(&node_x, &node_y, &feature_indices, n_classes)
        }
        SplitCriterion::LogLoss => {
            find_best_logloss_split(&node_x, &node_y, &feature_indices, n_classes)
        }
        _ => None,
    }
}

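/// Partitions `sample_indices` on `x[[i, feature_idx]] <= threshold`:
/// matching samples go left, the rest go right.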
pub fn split_samples_by_threshold(
    x: &Array2<f64>,
    sample_indices: &[usize],
    feature_idx: usize,
    threshold: f64,
) -> (Vec<usize>, Vec<usize>) {
    let mut left_indices = Vec::new();
    let mut right_indices = Vec::new();

    for &sample_idx in sample_indices {
        if x[[sample_idx, feature_idx]] <= threshold {
            left_indices.push(sample_idx);
        } else {
            right_indices.push(sample_idx);
        }
    }

    (left_indices, right_indices)
}

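/// Exhaustive MAE split search for regression targets: every midpoint between
/// consecutive distinct feature values is scored by the weighted mean absolute
/// error around the child medians, and the largest decrease wins.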
pub fn find_best_mae_split(
    x: &Array2<f64>,
    y: &Array1<f64>,
    feature_indices: &[usize],
) -> Option<CustomSplit> {
    let n_samples = x.nrows();
    let mut best_split: Option<CustomSplit> = None;
    let mut best_impurity_decrease = f64::NEG_INFINITY;

    let y_values: Vec<f64> = y.iter().copied().collect();
    let initial_impurity = mae_impurity(&y_values);

    for &feature_idx in feature_indices {
        let feature_values = x.column(feature_idx);

        let mut pairs: Vec<(f64, f64)> = feature_values
            .iter()
            .zip(y.iter())
            .map(|(&x_val, &y_val)| (x_val, y_val))
            .collect();

        pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

        for i in 1..pairs.len() {
            // Only split between distinct feature values.
            if pairs[i - 1].0 >= pairs[i].0 {
                continue;
            }

            let threshold = (pairs[i - 1].0 + pairs[i].0) / 2.0;

            let left_values: Vec<f64> = pairs[..i].iter().map(|(_, y)| *y).collect();
            let right_values: Vec<f64> = pairs[i..].iter().map(|(_, y)| *y).collect();

            if left_values.is_empty() || right_values.is_empty() {
                continue;
            }

            let left_impurity = mae_impurity(&left_values);
            let right_impurity = mae_impurity(&right_values);
            let left_weight = left_values.len() as f64 / n_samples as f64;
            let right_weight = right_values.len() as f64 / n_samples as f64;
            let weighted_impurity = left_weight * left_impurity + right_weight * right_impurity;

            let impurity_decrease = initial_impurity - weighted_impurity;

            if impurity_decrease > best_impurity_decrease {
                best_impurity_decrease = impurity_decrease;
                best_split = Some(CustomSplit {
                    feature_idx,
                    threshold,
                    impurity_decrease,
                    left_count: left_values.len(),
                    right_count: right_values.len(),
                });
            }
        }
    }

    best_split
}

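/// Split search using the twoing rule: each candidate threshold is scored by
/// how strongly it separates the class distributions of the two children (see
/// `twoing_impurity`); the highest-scoring split wins.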
pub fn find_best_twoing_split(
    x: &Array2<f64>,
    y: &Array1<i32>,
    feature_indices: &[usize],
    n_classes: usize,
) -> Option<CustomSplit> {
    let mut best_split: Option<CustomSplit> = None;
    let mut best_impurity_decrease = f64::NEG_INFINITY;

    for &feature_idx in feature_indices {
        let feature_values = x.column(feature_idx);

        let mut pairs: Vec<(f64, i32)> = feature_values
            .iter()
            .zip(y.iter())
            .map(|(&x_val, &y_val)| (x_val, y_val))
            .collect();

        pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

        for i in 1..pairs.len() {
            // Only split between distinct feature values.
            if pairs[i - 1].0 >= pairs[i].0 {
                continue;
            }

            let threshold = (pairs[i - 1].0 + pairs[i].0) / 2.0;

            // Class histograms of the two children for this threshold.
            let mut left_counts = vec![0; n_classes];
            let mut right_counts = vec![0; n_classes];

            for (j, (_, class)) in pairs.iter().enumerate() {
                let class_idx = *class as usize;
                if j < i {
                    left_counts[class_idx] += 1;
                } else {
                    right_counts[class_idx] += 1;
                }
            }

            let left_total: usize = left_counts.iter().sum();
            let right_total: usize = right_counts.iter().sum();

            if left_total == 0 || right_total == 0 {
                continue;
            }

            // The twoing value itself acts as the "decrease" to maximize.
            let impurity_decrease = twoing_impurity(&left_counts, &right_counts);

            if impurity_decrease > best_impurity_decrease {
                best_impurity_decrease = impurity_decrease;
                best_split = Some(CustomSplit {
                    feature_idx,
                    threshold,
                    impurity_decrease,
                    left_count: left_total,
                    right_count: right_total,
                });
            }
        }
    }

    best_split
}

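/// Split search that minimizes the weighted entropy (log-loss) of the
/// children: the decrease is the parent entropy minus the sample-weighted
/// child entropies, maximized over all candidate thresholds.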
pub fn find_best_logloss_split(
    x: &Array2<f64>,
    y: &Array1<i32>,
    feature_indices: &[usize],
    n_classes: usize,
) -> Option<CustomSplit> {
    let n_samples = x.nrows();
    let mut best_split: Option<CustomSplit> = None;
    let mut best_impurity_decrease = f64::NEG_INFINITY;

    // Entropy of the unsplit node, used as the baseline.
    let mut initial_counts = vec![0; n_classes];
    for &class in y.iter() {
        initial_counts[class as usize] += 1;
    }
    let initial_impurity = log_loss_impurity(&initial_counts);

    for &feature_idx in feature_indices {
        let feature_values = x.column(feature_idx);

        let mut pairs: Vec<(f64, i32)> = feature_values
            .iter()
            .zip(y.iter())
            .map(|(&x_val, &y_val)| (x_val, y_val))
            .collect();

        pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

        for i in 1..pairs.len() {
            // Only split between distinct feature values.
            if pairs[i - 1].0 >= pairs[i].0 {
                continue;
            }

            let threshold = (pairs[i - 1].0 + pairs[i].0) / 2.0;

            let mut left_counts = vec![0; n_classes];
            let mut right_counts = vec![0; n_classes];

            for (j, (_, class)) in pairs.iter().enumerate() {
                let class_idx = *class as usize;
                if j < i {
                    left_counts[class_idx] += 1;
                } else {
                    right_counts[class_idx] += 1;
                }
            }

            let left_total: usize = left_counts.iter().sum();
            let right_total: usize = right_counts.iter().sum();

            if left_total == 0 || right_total == 0 {
                continue;
            }

            let left_impurity = log_loss_impurity(&left_counts);
            let right_impurity = log_loss_impurity(&right_counts);
            let left_weight = left_total as f64 / n_samples as f64;
            let right_weight = right_total as f64 / n_samples as f64;
            let weighted_impurity = left_weight * left_impurity + right_weight * right_impurity;

            let impurity_decrease = initial_impurity - weighted_impurity;

            if impurity_decrease > best_impurity_decrease {
                best_impurity_decrease = impurity_decrease;
                best_split = Some(CustomSplit {
                    feature_idx,
                    threshold,
                    impurity_decrease,
                    left_count: left_total,
                    right_count: right_total,
                });
            }
        }
    }

    best_split
}

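/// Mean absolute error around the median of `values`; `0.0` for empty input.
///
/// For example, `[1.0, 2.0, 6.0]` has median `2.0`, so the impurity is
/// `(1 + 0 + 4) / 3 ≈ 1.67`.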
pub fn mae_impurity(values: &[f64]) -> f64 {
    if values.is_empty() {
        return 0.0;
    }

    let median = {
        let mut sorted_values = values.to_vec();
        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let len = sorted_values.len();
        if len % 2 == 0 {
            (sorted_values[len / 2 - 1] + sorted_values[len / 2]) / 2.0
        } else {
            sorted_values[len / 2]
        }
    };

    values.iter().map(|v| (v - median).abs()).sum::<f64>() / values.len() as f64
}

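/// Twoing criterion for a candidate split:
///
/// `0.25 * p_L * p_R * (Σ_k |p(k|L) − p(k|R)|)²`
///
/// where `p_L`/`p_R` are the fractions of samples sent left/right and
/// `p(k|·)` are the per-child class proportions. Larger values indicate a
/// stronger separation of the class distributions.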
pub fn twoing_impurity(left_counts: &[usize], right_counts: &[usize]) -> f64 {
    let left_total: usize = left_counts.iter().sum();
    let right_total: usize = right_counts.iter().sum();
    let total = left_total + right_total;

    if total == 0 || left_total == 0 || right_total == 0 {
        return 0.0;
    }

    let mut twoing_value = 0.0;
    for i in 0..left_counts.len() {
        let left_prob = left_counts[i] as f64 / left_total as f64;
        let right_prob = right_counts[i] as f64 / right_total as f64;
        twoing_value += (left_prob - right_prob).abs();
    }

    let p_left = left_total as f64 / total as f64;
    let p_right = right_total as f64 / total as f64;
    0.25 * p_left * p_right * twoing_value.powi(2)
}

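/// Shannon entropy of the class distribution, `−Σ_k p_k ln p_k`, used here as
/// the log-loss impurity; empty inputs score `0.0`.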
pub fn log_loss_impurity(class_counts: &[usize]) -> f64 {
    let total: usize = class_counts.iter().sum();
    if total == 0 {
        return 0.0;
    }

    class_counts
        .iter()
        .filter(|&&count| count > 0)
        .map(|&count| {
            let prob = count as f64 / total as f64;
            -prob * prob.ln()
        })
        .sum()
}

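/// Gini impurity `1 − Σ_k p_k²` of a class-count vector.
///
/// For example, counts `[2, 2]` give `1 − (0.5² + 0.5²) = 0.5`, while a pure
/// node such as `[4, 0]` gives `0.0`.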
pub fn gini_impurity(class_counts: &[i32], total_samples: i32) -> f64 {
    if total_samples == 0 {
        return 0.0;
    }

    let mut impurity = 1.0;
    for &count in class_counts {
        let probability = count as f64 / total_samples as f64;
        impurity -= probability * probability;
    }
    impurity
}

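/// Reduces `x` to one representative column per feature group according to
/// `grouping`, returning the reduced matrix together with bookkeeping about
/// the groups.
///
/// A minimal usage sketch (fence marked `ignore`; the config variants are
/// taken from this module's `match` arms, and `x`/`y` are assumed in scope):
///
/// ```ignore
/// let (reduced_x, info) = apply_feature_grouping(
///     &FeatureGrouping::AutoCorrelation {
///         threshold: 0.9,
///         selection_method: GroupSelectionMethod::MaxVariance,
///     },
///     &x,
///     &y,
/// )?;
/// // One output column per group representative.
/// assert_eq!(reduced_x.ncols(), info.representatives.len());
/// ```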
pub fn apply_feature_grouping(
    grouping: &FeatureGrouping,
    x: &Array2<Float>,
    y: &Array1<Float>,
) -> Result<(Array2<Float>, FeatureGroupInfo)> {
    match grouping {
        FeatureGrouping::None => {
            // Identity grouping: every feature is its own (singleton) group.
            let n_features = x.ncols();
            let info = FeatureGroupInfo {
                groups: (0..n_features).map(|i| vec![i]).collect(),
                representatives: (0..n_features).collect(),
                correlation_matrix: None,
                group_correlations: vec![1.0; n_features],
            };
            Ok((x.clone(), info))
        }
        FeatureGrouping::AutoCorrelation {
            threshold,
            selection_method,
        } => apply_auto_correlation_grouping(x, y, *threshold, *selection_method),
        FeatureGrouping::Manual {
            groups,
            selection_method,
        } => apply_manual_grouping(x, y, groups, *selection_method),
        FeatureGrouping::Hierarchical {
            n_clusters,
            linkage,
            selection_method,
        } => apply_hierarchical_grouping(x, y, *n_clusters, *linkage, *selection_method),
    }
}

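/// Greedy single-pass grouping: features are scanned in order, and each
/// unassigned feature absorbs every later feature whose absolute Pearson
/// correlation with it is at least `threshold`.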
pub fn apply_auto_correlation_grouping(
    x: &Array2<Float>,
    y: &Array1<Float>,
    threshold: Float,
    selection_method: GroupSelectionMethod,
) -> Result<(Array2<Float>, FeatureGroupInfo)> {
    let n_features = x.ncols();

    if n_features == 0 {
        return Err(SklearsError::InvalidInput(
            "Cannot apply feature grouping to empty feature set".to_string(),
        ));
    }

    let correlation_matrix = calculate_correlation_matrix(x)?;

    // Greedy pass: each unassigned feature seeds a group and absorbs all
    // later features correlated with it above the threshold.
    let mut groups = Vec::new();
    let mut assigned = vec![false; n_features];

    for i in 0..n_features {
        if assigned[i] {
            continue;
        }

        let mut group = vec![i];
        assigned[i] = true;

        for j in (i + 1)..n_features {
            if !assigned[j] && correlation_matrix[[i, j]].abs() >= threshold {
                group.push(j);
                assigned[j] = true;
            }
        }

        groups.push(group);
    }

    let mut representatives = Vec::new();
    let mut group_correlations = Vec::new();

    for group in &groups {
        let (representative, avg_correlation) =
            select_group_representative(x, y, group, selection_method)?;
        representatives.push(representative);
        group_correlations.push(avg_correlation);
    }

    let reduced_x = create_reduced_feature_matrix(x, &representatives)?;

    let info = FeatureGroupInfo {
        groups,
        representatives,
        correlation_matrix: Some(correlation_matrix),
        group_correlations,
    };

    Ok((reduced_x, info))
}

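/// Applies user-provided feature groups after validating that indices are in
/// bounds and disjoint; any feature not mentioned becomes its own group.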
pub fn apply_manual_grouping(
    x: &Array2<Float>,
    y: &Array1<Float>,
    groups: &[Vec<usize>],
    selection_method: GroupSelectionMethod,
) -> Result<(Array2<Float>, FeatureGroupInfo)> {
    let n_features = x.ncols();

    // Validate: indices in bounds and no feature in more than one group.
    let mut assigned = vec![false; n_features];
    for group in groups {
        for &feature_idx in group {
            if feature_idx >= n_features {
                return Err(SklearsError::InvalidInput(format!(
                    "Feature index {} out of bounds",
                    feature_idx
                )));
            }
            if assigned[feature_idx] {
                return Err(SklearsError::InvalidInput(format!(
                    "Feature {} appears in multiple groups",
                    feature_idx
                )));
            }
            assigned[feature_idx] = true;
        }
    }

    // Features not mentioned in any group become singleton groups.
    let mut complete_groups = groups.to_vec();
    for (i, &is_assigned) in assigned.iter().enumerate() {
        if !is_assigned {
            complete_groups.push(vec![i]);
        }
    }

    let mut representatives = Vec::new();
    let mut group_correlations = Vec::new();

    for group in &complete_groups {
        let (representative, avg_correlation) =
            select_group_representative(x, y, group, selection_method)?;
        representatives.push(representative);
        group_correlations.push(avg_correlation);
    }

    let reduced_x = create_reduced_feature_matrix(x, &representatives)?;

    let info = FeatureGroupInfo {
        groups: complete_groups,
        representatives,
        correlation_matrix: None,
        group_correlations,
    };

    Ok((reduced_x, info))
}

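/// Groups features by agglomerative clustering on the distance
/// `1 − |correlation|`, then reduces each cluster to a single representative.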
pub fn apply_hierarchical_grouping(
    x: &Array2<Float>,
    y: &Array1<Float>,
    n_clusters: usize,
    linkage: LinkageMethod,
    selection_method: GroupSelectionMethod,
) -> Result<(Array2<Float>, FeatureGroupInfo)> {
    let n_features = x.ncols();

    if n_clusters == 0 || n_clusters > n_features {
        return Err(SklearsError::InvalidInput(format!(
            "n_clusters must be between 1 and {} (number of features)",
            n_features
        )));
    }

    // Convert correlations to distances: highly correlated features are close.
    let correlation_matrix = calculate_correlation_matrix(x)?;
    let mut distance_matrix = Array2::<Float>::zeros((n_features, n_features));

    for i in 0..n_features {
        for j in 0..n_features {
            distance_matrix[[i, j]] = 1.0 - correlation_matrix[[i, j]].abs();
        }
    }

    let groups = hierarchical_clustering(&distance_matrix, n_clusters, linkage)?;

    let mut representatives = Vec::new();
    let mut group_correlations = Vec::new();

    for group in &groups {
        let (representative, avg_correlation) =
            select_group_representative(x, y, group, selection_method)?;
        representatives.push(representative);
        group_correlations.push(avg_correlation);
    }

    let reduced_x = create_reduced_feature_matrix(x, &representatives)?;

    let info = FeatureGroupInfo {
        groups,
        representatives,
        correlation_matrix: Some(correlation_matrix),
        group_correlations,
    };

    Ok((reduced_x, info))
}

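/// Computes the symmetric matrix of pairwise Pearson correlations between the
/// columns of `x`; requires at least two samples.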
pub fn calculate_correlation_matrix(x: &Array2<Float>) -> Result<Array2<Float>> {
    let n_features = x.ncols();
    let n_samples = x.nrows();

    if n_samples < 2 {
        return Err(SklearsError::InvalidInput(
            "Need at least 2 samples to calculate correlations".to_string(),
        ));
    }

    let mut correlation_matrix = Array2::<Float>::zeros((n_features, n_features));

    let means: Vec<Float> = (0..n_features)
        .map(|j| x.column(j).mean().unwrap_or(0.0))
        .collect();

    for i in 0..n_features {
        for j in i..n_features {
            if i == j {
                correlation_matrix[[i, j]] = 1.0;
            } else {
                let corr = calculate_pearson_correlation(
                    &x.column(i).to_owned(),
                    &x.column(j).to_owned(),
                    means[i],
                    means[j],
                )?;
                correlation_matrix[[i, j]] = corr;
                correlation_matrix[[j, i]] = corr;
            }
        }
    }

    Ok(correlation_matrix)
}

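/// Pearson correlation with precomputed means:
///
/// `r = Σ (x_i − x̄)(y_i − ȳ) / sqrt(Σ (x_i − x̄)² · Σ (y_i − ȳ)²)`
///
/// Returns `0.0` when either vector is (numerically) constant, so the
/// division never blows up.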
pub fn calculate_pearson_correlation(
    x: &Array1<Float>,
    y: &Array1<Float>,
    mean_x: Float,
    mean_y: Float,
) -> Result<Float> {
    let n = x.len();

    if n != y.len() {
        return Err(SklearsError::InvalidInput(format!(
            "Feature vectors must have same length: {} vs {}",
            n,
            y.len()
        )));
    }

    let mut sum_xy = 0.0;
    let mut sum_x2 = 0.0;
    let mut sum_y2 = 0.0;

    for i in 0..n {
        let dx = x[i] - mean_x;
        let dy = y[i] - mean_y;
        sum_xy += dx * dy;
        sum_x2 += dx * dx;
        sum_y2 += dy * dy;
    }

    let denominator = (sum_x2 * sum_y2).sqrt();

    if denominator.abs() < Float::EPSILON {
        Ok(0.0)
    } else {
        Ok(sum_xy / denominator)
    }
}

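/// Picks one feature to represent a group, together with a score: the
/// highest-variance member, the member most correlated with the target, the
/// first member, or a random member. `WeightedAll` currently falls back to
/// the first member.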
pub fn select_group_representative(
    x: &Array2<Float>,
    y: &Array1<Float>,
    group: &[usize],
    method: GroupSelectionMethod,
) -> Result<(usize, Float)> {
    if group.is_empty() {
        return Err(SklearsError::InvalidInput(
            "Cannot select representative from empty group".to_string(),
        ));
    }

    if group.len() == 1 {
        return Ok((group[0], 1.0));
    }

    match method {
        GroupSelectionMethod::MaxVariance => {
            let mut max_variance = f64::NEG_INFINITY;
            let mut best_feature = group[0];

            for &feature_idx in group {
                let column = x.column(feature_idx);
                let mean = column.mean().unwrap_or(0.0);
                let variance =
                    column.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / column.len() as f64;

                if variance > max_variance {
                    max_variance = variance;
                    best_feature = feature_idx;
                }
            }

            Ok((best_feature, max_variance))
        }
        GroupSelectionMethod::MaxTargetCorrelation => {
            let mut max_correlation = f64::NEG_INFINITY;
            let mut best_feature = group[0];

            let y_mean = y.mean().unwrap_or(0.0);

            for &feature_idx in group {
                let x_col = x.column(feature_idx).to_owned();
                let x_mean = x_col.mean().unwrap_or(0.0);
                let correlation = calculate_pearson_correlation(&x_col, y, x_mean, y_mean)?;

                if correlation.abs() > max_correlation {
                    max_correlation = correlation.abs();
                    best_feature = feature_idx;
                }
            }

            Ok((best_feature, max_correlation))
        }
        GroupSelectionMethod::First => Ok((group[0], 1.0)),
        GroupSelectionMethod::Random => {
            use scirs2_core::random::thread_rng;
            let mut rng = thread_rng();
            let idx = rng.gen_range(0..group.len());
            Ok((group[idx], 1.0))
        }
        GroupSelectionMethod::WeightedAll => {
            // Currently behaves like `First`: a weighted combination of all
            // group members is not implemented here.
            Ok((group[0], 1.0))
        }
    }
}

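/// Copies the representative columns of `x` into a new matrix with one column
/// per group, preserving the order of `representatives`.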
pub fn create_reduced_feature_matrix(
    x: &Array2<Float>,
    representatives: &[usize],
) -> Result<Array2<Float>> {
    let n_samples = x.nrows();
    let n_representatives = representatives.len();

    let mut reduced_x = Array2::zeros((n_samples, n_representatives));

    for (new_col, &orig_col) in representatives.iter().enumerate() {
        if orig_col >= x.ncols() {
            return Err(SklearsError::InvalidInput(format!(
                "Representative feature index {} out of bounds",
                orig_col
            )));
        }

        reduced_x.column_mut(new_col).assign(&x.column(orig_col));
    }

    Ok(reduced_x)
}

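/// Naive agglomerative clustering: starting from singleton clusters, the two
/// closest clusters under the chosen linkage are merged until `n_clusters`
/// remain. Each merge rescans all cluster pairs, which is acceptable for
/// modest feature counts.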
pub fn hierarchical_clustering(
    distance_matrix: &Array2<Float>,
    n_clusters: usize,
    linkage: LinkageMethod,
) -> Result<Vec<Vec<usize>>> {
    let n_features = distance_matrix.nrows();

    if n_features != distance_matrix.ncols() {
        return Err(SklearsError::InvalidInput(
            "Distance matrix must be square".to_string(),
        ));
    }

    // Start with one singleton cluster per feature.
    let mut clusters: Vec<Vec<usize>> = (0..n_features).map(|i| vec![i]).collect();

    while clusters.len() > n_clusters {
        // Find the closest pair of clusters under the chosen linkage.
        let mut min_distance = Float::INFINITY;
        let mut merge_i = 0;
        let mut merge_j = 1;

        for i in 0..clusters.len() {
            for j in (i + 1)..clusters.len() {
                let distance =
                    cluster_distance(&clusters[i], &clusters[j], distance_matrix, linkage);
                if distance < min_distance {
                    min_distance = distance;
                    merge_i = i;
                    merge_j = j;
                }
            }
        }

        // Merge the pair (merge_i < merge_j, so removal keeps merge_i valid).
        let cluster_j = clusters.remove(merge_j);
        clusters[merge_i].extend(cluster_j);
    }

    Ok(clusters)
}

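/// Distance between two clusters under the chosen linkage. `Ward` is
/// currently approximated by average linkage over the precomputed distances.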
fn cluster_distance(
    cluster1: &[usize],
    cluster2: &[usize],
    distance_matrix: &Array2<Float>,
    linkage: LinkageMethod,
) -> Float {
    match linkage {
        LinkageMethod::Single => {
            // Nearest pair across the two clusters.
            let mut min_dist = Float::INFINITY;
            for &i in cluster1 {
                for &j in cluster2 {
                    let dist = distance_matrix[[i, j]];
                    if dist < min_dist {
                        min_dist = dist;
                    }
                }
            }
            min_dist
        }
        LinkageMethod::Complete => {
            // Farthest pair across the two clusters.
            let mut max_dist = Float::NEG_INFINITY;
            for &i in cluster1 {
                for &j in cluster2 {
                    let dist = distance_matrix[[i, j]];
                    if dist > max_dist {
                        max_dist = dist;
                    }
                }
            }
            max_dist
        }
        LinkageMethod::Average => {
            let mut total_dist = 0.0;
            let mut count = 0;
            for &i in cluster1 {
                for &j in cluster2 {
                    total_dist += distance_matrix[[i, j]];
                    count += 1;
                }
            }
            if count > 0 {
                total_dist / count as Float
            } else {
                0.0
            }
        }
        LinkageMethod::Ward => {
            // Note: true Ward linkage needs cluster variances; this falls
            // back to average linkage over the precomputed distances.
            let mut total_dist = 0.0;
            let mut count = 0;
            for &i in cluster1 {
                for &j in cluster2 {
                    total_dist += distance_matrix[[i, j]];
                    count += 1;
                }
            }
            if count > 0 {
                total_dist / count as Float
            } else {
                0.0
            }
        }
    }
}