sklears_tree/
builder.rs

1//! Tree building algorithms and utilities
2//!
3//! This module contains algorithms for building decision trees, including
4//! split finding, impurity calculations, and feature grouping utilities.
5
6use crate::config::*;
7use crate::node::*;
8use crate::SplitCriterion;
9use scirs2_core::ndarray::{Array1, Array2};
10use sklears_core::{
11    error::{Result, SklearsError},
12    types::Float,
13};
14use std::collections::BinaryHeap;
15
16/// Handle missing values in the data based on the specified strategy
17pub fn handle_missing_values<T: Clone>(
18    x: &Array2<f64>,
19    y: &Array1<T>,
20    strategy: MissingValueStrategy,
21) -> Result<(Array2<f64>, Array1<T>)> {
22    // Check for missing values (NaN)
23    let mut has_missing = false;
24    for value in x.iter() {
25        if value.is_nan() {
26            has_missing = true;
27            break;
28        }
29    }
30    if !has_missing {
31        // No missing values, return original data
32        return Ok((x.clone(), y.clone()));
33    }
34    match strategy {
35        MissingValueStrategy::Skip => {
36            // Remove rows with any missing values
37            let mut valid_indices = Vec::new();
38            for (row_idx, row) in x.outer_iter().enumerate() {
39                let mut row_valid = true;
40                for &value in row.iter() {
41                    if value.is_nan() {
42                        row_valid = false;
43                        break;
44                    }
45                }
46                if row_valid {
47                    valid_indices.push(row_idx);
48                }
49            }
50            if valid_indices.is_empty() {
51                return Err(SklearsError::InvalidData {
52                    reason: "All samples contain missing values".to_string(),
53                });
54            }
55            // Create new arrays with only valid rows
56            let n_valid = valid_indices.len();
57            let n_features = x.ncols();
58            let mut x_clean = Array2::zeros((n_valid, n_features));
59            let mut y_clean = Vec::with_capacity(n_valid);
60            for (new_idx, &orig_idx) in valid_indices.iter().enumerate() {
61                x_clean.row_mut(new_idx).assign(&x.row(orig_idx));
62                y_clean.push(y[orig_idx].clone());
63            }
64            Ok((x_clean, Array1::from_vec(y_clean)))
65        }
66        MissingValueStrategy::Majority => {
67            // Replace missing values with column means (for continuous) or mode (for discrete)
68            let mut x_imputed = x.clone();
69            for col_idx in 0..x.ncols() {
70                let column = x.column(col_idx);
71                // Calculate mean of non-missing values
72                let mut sum = 0.0;
73                let mut count = 0;
74                for &value in column.iter() {
75                    if !value.is_nan() {
76                        sum += value;
77                        count += 1;
78                    }
79                }
80                if count > 0 {
81                    let mean = sum / count as f64;
82                    // Replace missing values with mean
83                    for row_idx in 0..x.nrows() {
84                        if x_imputed[[row_idx, col_idx]].is_nan() {
85                            x_imputed[[row_idx, col_idx]] = mean;
86                        }
87                    }
88                } else {
89                    // All values in this column are missing, use 0.0
90                    for row_idx in 0..x.nrows() {
91                        x_imputed[[row_idx, col_idx]] = 0.0;
92                    }
93                }
94            }
95            Ok((x_imputed, y.clone()))
96        }
97        MissingValueStrategy::Surrogate => {
98            // TODO: Implement proper surrogate splits for missing value handling
99            // For now, fall back to mean imputation
100            let mut x_imputed = x.clone();
101            for col_idx in 0..x.ncols() {
102                let mut sum = 0.0;
103                let mut count = 0;
104                // Calculate mean of non-missing values
105                for row_idx in 0..x.nrows() {
106                    let value = x[[row_idx, col_idx]];
107                    if !value.is_nan() {
108                        sum += value;
109                        count += 1;
110                    }
111                }
112                if count > 0 {
113                    let mean = sum / count as f64;
114                    // Replace missing values with mean
115                    for row_idx in 0..x.nrows() {
116                        if x_imputed[[row_idx, col_idx]].is_nan() {
117                            x_imputed[[row_idx, col_idx]] = mean;
118                        }
119                    }
120                } else {
121                    // All values in this column are missing, use 0.0
122                    for row_idx in 0..x.nrows() {
123                        x_imputed[[row_idx, col_idx]] = 0.0;
124                    }
125                }
126            }
127            Ok((x_imputed, y.clone()))
128        }
129    }
130}
131
132/// Best-first tree builder
133#[derive(Debug)]
134pub struct BestFirstTreeBuilder {
135    /// Nodes in the tree
136    pub nodes: Vec<TreeNode>,
137    /// Priority queue of nodes to expand (ordered by potential decrease)
138    pub node_queue: BinaryHeap<NodePriority>,
139    /// Next node ID
140    pub next_node_id: usize,
141    /// Current number of leaves
142    pub n_leaves: usize,
143}
144
145impl BestFirstTreeBuilder {
146    /// Create a new best-first tree builder
147    pub fn new(
148        x: &Array2<f64>,
149        y: &Array1<i32>,
150        config: &DecisionTreeConfig,
151        n_classes: usize,
152    ) -> Self {
153        let n_samples = x.nrows();
154        let sample_indices: Vec<usize> = (0..n_samples).collect();
155
156        // Calculate root impurity and prediction
157        let mut class_counts = vec![0; n_classes];
158        for &sample_idx in &sample_indices {
159            let class = y[sample_idx] as usize;
160            if class < n_classes {
161                class_counts[class] += 1;
162            }
163        }
164
165        let impurity = gini_impurity(&class_counts, n_samples as i32);
166        let prediction = class_counts
167            .iter()
168            .enumerate()
169            .max_by_key(|(_, &count)| count)
170            .map(|(class, _)| class as f64)
171            .unwrap_or(0.0);
172
173        // Find best split for root
174        let best_split = find_best_split_for_node(x, y, &sample_indices, config, n_classes);
175        let potential_decrease = best_split
176            .as_ref()
177            .map(|s| s.impurity_decrease)
178            .unwrap_or(0.0);
179
180        let root_node = TreeNode {
181            id: 0,
182            depth: 0,
183            sample_indices,
184            impurity,
185            prediction,
186            potential_decrease,
187            best_split,
188            parent_id: None,
189            is_leaf: false,
190        };
191
192        let mut node_queue = BinaryHeap::new();
193        if potential_decrease > 0.0 {
194            node_queue.push(NodePriority {
195                node_id: 0,
196                priority: -potential_decrease, // Negative for max-heap
197            });
198        }
199
200        Self {
201            nodes: vec![root_node],
202            node_queue,
203            next_node_id: 1,
204            n_leaves: 1,
205        }
206    }
207
208    /// Build the tree using best-first strategy
209    pub fn build_tree(
210        &mut self,
211        x: &Array2<f64>,
212        y: &Array1<i32>,
213        config: &DecisionTreeConfig,
214        n_classes: usize,
215    ) -> Result<()> {
216        let max_leaves = match config.growing_strategy {
217            TreeGrowingStrategy::BestFirst { max_leaves } => max_leaves,
218            _ => None,
219        };
220
221        while let Some(node_priority) = self.node_queue.pop() {
222            let node_id = node_priority.node_id;
223
224            // Check stopping criteria
225            if let Some(max_leaves) = max_leaves {
226                if self.n_leaves >= max_leaves {
227                    break;
228                }
229            }
230
231            if let Some(max_depth) = config.max_depth {
232                if self.nodes[node_id].depth >= max_depth {
233                    continue;
234                }
235            }
236
237            // Check if we can split this node
238            if self.nodes[node_id].sample_indices.len() < config.min_samples_split {
239                continue;
240            }
241
242            // Split the node
243            if self.split_node(node_id, x, y, config, n_classes).is_err() {
244                continue;
245            }
246        }
247
248        Ok(())
249    }
250
251    /// Split a node and add children to the queue
252    fn split_node(
253        &mut self,
254        node_id: usize,
255        x: &Array2<f64>,
256        y: &Array1<i32>,
257        config: &DecisionTreeConfig,
258        n_classes: usize,
259    ) -> Result<()> {
260        let node = &self.nodes[node_id].clone();
261        let best_split = match &node.best_split {
262            Some(split) => split.clone(),
263            None => {
264                return Err(SklearsError::InvalidInput(
265                    "No valid split found".to_string(),
266                ))
267            }
268        };
269
270        // Split samples
271        let (left_indices, right_indices) = split_samples_by_threshold(
272            x,
273            &node.sample_indices,
274            best_split.feature_idx,
275            best_split.threshold,
276        );
277
278        if left_indices.len() < config.min_samples_leaf
279            || right_indices.len() < config.min_samples_leaf
280        {
281            return Err(SklearsError::InvalidInput(
282                "Split would create undersized leaves".to_string(),
283            ));
284        }
285
286        // Create left child
287        let left_node_id = self.next_node_id;
288        self.next_node_id += 1;
289
290        let left_node = self.create_child_node(
291            left_node_id,
292            node.id,
293            node.depth + 1,
294            left_indices,
295            x,
296            y,
297            config,
298            n_classes,
299        );
300
301        // Create right child
302        let right_node_id = self.next_node_id;
303        self.next_node_id += 1;
304
305        let right_node = self.create_child_node(
306            right_node_id,
307            node.id,
308            node.depth + 1,
309            right_indices,
310            x,
311            y,
312            config,
313            n_classes,
314        );
315
316        // Add children to queue if they can be split
317        if left_node.potential_decrease > config.min_impurity_decrease {
318            self.node_queue.push(NodePriority {
319                node_id: left_node_id,
320                priority: -left_node.potential_decrease,
321            });
322        }
323
324        if right_node.potential_decrease > config.min_impurity_decrease {
325            self.node_queue.push(NodePriority {
326                node_id: right_node_id,
327                priority: -right_node.potential_decrease,
328            });
329        }
330
331        self.nodes.push(left_node);
332        self.nodes.push(right_node);
333
334        // Mark parent as non-leaf and increment leaf count
335        self.nodes[node_id].is_leaf = false;
336        self.n_leaves += 1; // +2 children -1 parent = +1 net leaves
337
338        Ok(())
339    }
340
341    /// Create a child node
342    #[allow(clippy::too_many_arguments)]
343    fn create_child_node(
344        &self,
345        node_id: usize,
346        parent_id: usize,
347        depth: usize,
348        sample_indices: Vec<usize>,
349        x: &Array2<f64>,
350        y: &Array1<i32>,
351        config: &DecisionTreeConfig,
352        n_classes: usize,
353    ) -> TreeNode {
354        // Calculate impurity and prediction
355        let mut class_counts = vec![0; n_classes];
356        for &sample_idx in &sample_indices {
357            let class = y[sample_idx] as usize;
358            if class < n_classes {
359                class_counts[class] += 1;
360            }
361        }
362
363        let impurity = gini_impurity(&class_counts, sample_indices.len() as i32);
364        let prediction = class_counts
365            .iter()
366            .enumerate()
367            .max_by_key(|(_, &count)| count)
368            .map(|(class, _)| class as f64)
369            .unwrap_or(0.0);
370
371        // Find best split for this node
372        let best_split = find_best_split_for_node(x, y, &sample_indices, config, n_classes);
373        let potential_decrease = best_split
374            .as_ref()
375            .map(|s| s.impurity_decrease)
376            .unwrap_or(0.0);
377
378        TreeNode {
379            id: node_id,
380            depth,
381            sample_indices,
382            impurity,
383            prediction,
384            potential_decrease,
385            best_split,
386            parent_id: Some(parent_id),
387            is_leaf: true,
388        }
389    }
390}
391
392/// Find best split for a node given sample indices
393pub fn find_best_split_for_node(
394    x: &Array2<f64>,
395    y: &Array1<i32>,
396    sample_indices: &[usize],
397    config: &DecisionTreeConfig,
398    n_classes: usize,
399) -> Option<CustomSplit> {
400    if sample_indices.len() < config.min_samples_split {
401        return None;
402    }
403
404    // Create subset of data for this node
405    let n_samples = sample_indices.len();
406    let n_features = x.ncols();
407
408    let mut node_x = Array2::zeros((n_samples, n_features));
409    let mut node_y = Array1::zeros(n_samples);
410
411    for (new_idx, &orig_idx) in sample_indices.iter().enumerate() {
412        for j in 0..n_features {
413            node_x[[new_idx, j]] = x[[orig_idx, j]];
414        }
415        node_y[new_idx] = y[orig_idx];
416    }
417
418    // Find best split using existing logic
419    let feature_indices: Vec<usize> = (0..n_features).collect();
420
421    match config.criterion {
422        SplitCriterion::Gini | SplitCriterion::Entropy => {
423            find_best_twoing_split(&node_x, &node_y, &feature_indices, n_classes)
424        }
425        SplitCriterion::LogLoss => {
426            find_best_logloss_split(&node_x, &node_y, &feature_indices, n_classes)
427        }
428        _ => None, // For regression criteria, would need separate implementation
429    }
430}
431
432/// Split samples by threshold
433pub fn split_samples_by_threshold(
434    x: &Array2<f64>,
435    sample_indices: &[usize],
436    feature_idx: usize,
437    threshold: f64,
438) -> (Vec<usize>, Vec<usize>) {
439    let mut left_indices = Vec::new();
440    let mut right_indices = Vec::new();
441
442    for &sample_idx in sample_indices {
443        if x[[sample_idx, feature_idx]] <= threshold {
444            left_indices.push(sample_idx);
445        } else {
446            right_indices.push(sample_idx);
447        }
448    }
449
450    (left_indices, right_indices)
451}
452
453/// Find best split using MAE criterion for regression
454pub fn find_best_mae_split(
455    x: &Array2<f64>,
456    y: &Array1<f64>,
457    feature_indices: &[usize],
458) -> Option<CustomSplit> {
459    let n_samples = x.nrows();
460    let mut best_split: Option<CustomSplit> = None;
461    let mut best_impurity_decrease = f64::NEG_INFINITY;
462
463    // Calculate initial impurity
464    let y_values: Vec<f64> = y.iter().cloned().collect();
465    let initial_impurity = mae_impurity(&y_values);
466
467    for &feature_idx in feature_indices {
468        let feature_values = x.column(feature_idx);
469
470        // Create (value, target) pairs and sort by feature value
471        let mut pairs: Vec<(f64, f64)> = feature_values
472            .iter()
473            .zip(y.iter())
474            .map(|(&x_val, &y_val)| (x_val, y_val))
475            .collect();
476
477        pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
478
479        // Try each potential split point
480        for i in 1..pairs.len() {
481            if pairs[i - 1].0 >= pairs[i].0 {
482                continue; // Skip identical values
483            }
484
485            let threshold = (pairs[i - 1].0 + pairs[i].0) / 2.0;
486
487            // Split data
488            let left_values: Vec<f64> = pairs[..i].iter().map(|(_, y)| *y).collect();
489            let right_values: Vec<f64> = pairs[i..].iter().map(|(_, y)| *y).collect();
490
491            if left_values.is_empty() || right_values.is_empty() {
492                continue;
493            }
494
495            // Calculate weighted impurity
496            let left_impurity = mae_impurity(&left_values);
497            let right_impurity = mae_impurity(&right_values);
498            let left_weight = left_values.len() as f64 / n_samples as f64;
499            let right_weight = right_values.len() as f64 / n_samples as f64;
500            let weighted_impurity = left_weight * left_impurity + right_weight * right_impurity;
501
502            let impurity_decrease = initial_impurity - weighted_impurity;
503
504            if impurity_decrease > best_impurity_decrease {
505                best_impurity_decrease = impurity_decrease;
506                best_split = Some(CustomSplit {
507                    feature_idx,
508                    threshold,
509                    impurity_decrease,
510                    left_count: left_values.len(),
511                    right_count: right_values.len(),
512                });
513            }
514        }
515    }
516
517    best_split
518}
519
520/// Find best split using Twoing criterion for classification
521pub fn find_best_twoing_split(
522    x: &Array2<f64>,
523    y: &Array1<i32>,
524    feature_indices: &[usize],
525    n_classes: usize,
526) -> Option<CustomSplit> {
527    let _n_samples = x.nrows();
528    let mut best_split: Option<CustomSplit> = None;
529    let mut best_impurity_decrease = f64::NEG_INFINITY;
530
531    for &feature_idx in feature_indices {
532        let feature_values = x.column(feature_idx);
533
534        // Create (value, class) pairs and sort by feature value
535        let mut pairs: Vec<(f64, i32)> = feature_values
536            .iter()
537            .zip(y.iter())
538            .map(|(&x_val, &y_val)| (x_val, y_val))
539            .collect();
540
541        pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
542
543        // Try each potential split point
544        for i in 1..pairs.len() {
545            if pairs[i - 1].0 >= pairs[i].0 {
546                continue; // Skip identical values
547            }
548
549            let threshold = (pairs[i - 1].0 + pairs[i].0) / 2.0;
550
551            // Count classes in left and right splits
552            let mut left_counts = vec![0; n_classes];
553            let mut right_counts = vec![0; n_classes];
554
555            for (j, (_, class)) in pairs.iter().enumerate() {
556                let class_idx = *class as usize;
557                if j < i {
558                    left_counts[class_idx] += 1;
559                } else {
560                    right_counts[class_idx] += 1;
561                }
562            }
563
564            let left_total: usize = left_counts.iter().sum();
565            let right_total: usize = right_counts.iter().sum();
566
567            if left_total == 0 || right_total == 0 {
568                continue;
569            }
570
571            let impurity_decrease = twoing_impurity(&left_counts, &right_counts);
572
573            if impurity_decrease > best_impurity_decrease {
574                best_impurity_decrease = impurity_decrease;
575                best_split = Some(CustomSplit {
576                    feature_idx,
577                    threshold,
578                    impurity_decrease,
579                    left_count: left_total,
580                    right_count: right_total,
581                });
582            }
583        }
584    }
585
586    best_split
587}
588
589/// Find best split using Log-loss criterion for classification
590pub fn find_best_logloss_split(
591    x: &Array2<f64>,
592    y: &Array1<i32>,
593    feature_indices: &[usize],
594    n_classes: usize,
595) -> Option<CustomSplit> {
596    let n_samples = x.nrows();
597    let mut best_split: Option<CustomSplit> = None;
598    let mut best_impurity_decrease = f64::NEG_INFINITY;
599
600    // Calculate initial impurity
601    let mut initial_counts = vec![0; n_classes];
602    for &class in y.iter() {
603        initial_counts[class as usize] += 1;
604    }
605    let initial_impurity = log_loss_impurity(&initial_counts);
606
607    for &feature_idx in feature_indices {
608        let feature_values = x.column(feature_idx);
609
610        // Create (value, class) pairs and sort by feature value
611        let mut pairs: Vec<(f64, i32)> = feature_values
612            .iter()
613            .zip(y.iter())
614            .map(|(&x_val, &y_val)| (x_val, y_val))
615            .collect();
616
617        pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
618
619        // Try each potential split point
620        for i in 1..pairs.len() {
621            if pairs[i - 1].0 >= pairs[i].0 {
622                continue; // Skip identical values
623            }
624
625            let threshold = (pairs[i - 1].0 + pairs[i].0) / 2.0;
626
627            // Count classes in left and right splits
628            let mut left_counts = vec![0; n_classes];
629            let mut right_counts = vec![0; n_classes];
630
631            for (j, (_, class)) in pairs.iter().enumerate() {
632                let class_idx = *class as usize;
633                if j < i {
634                    left_counts[class_idx] += 1;
635                } else {
636                    right_counts[class_idx] += 1;
637                }
638            }
639
640            let left_total: usize = left_counts.iter().sum();
641            let right_total: usize = right_counts.iter().sum();
642
643            if left_total == 0 || right_total == 0 {
644                continue;
645            }
646
647            // Calculate weighted impurity
648            let left_impurity = log_loss_impurity(&left_counts);
649            let right_impurity = log_loss_impurity(&right_counts);
650            let left_weight = left_total as f64 / n_samples as f64;
651            let right_weight = right_total as f64 / n_samples as f64;
652            let weighted_impurity = left_weight * left_impurity + right_weight * right_impurity;
653
654            let impurity_decrease = initial_impurity - weighted_impurity;
655
656            if impurity_decrease > best_impurity_decrease {
657                best_impurity_decrease = impurity_decrease;
658                best_split = Some(CustomSplit {
659                    feature_idx,
660                    threshold,
661                    impurity_decrease,
662                    left_count: left_total,
663                    right_count: right_total,
664                });
665            }
666        }
667    }
668
669    best_split
670}
671
/// Mean absolute deviation of `values` around their median (MAE impurity).
///
/// For an even number of values the median is the mean of the two middle
/// values. An empty slice has impurity `0.0`.
pub fn mae_impurity(values: &[f64]) -> f64 {
    let n = values.len();
    if n == 0 {
        return 0.0;
    }

    let mut ordered = values.to_vec();
    ordered.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let mid = n / 2;
    let median = if n % 2 == 1 {
        ordered[mid]
    } else {
        (ordered[mid - 1] + ordered[mid]) / 2.0
    };

    let total_deviation: f64 = values.iter().map(|&v| (v - median).abs()).sum();
    total_deviation / n as f64
}
691
/// Twoing criterion for a candidate split, given per-class counts on each side.
///
/// Computes `0.25 · p_L · p_R · (Σ_k |p(k|L) - p(k|R)|)²`, where `p_L` and
/// `p_R` are the fractions of samples sent left and right. Works for any
/// number of classes. Returns `0.0` when either side is empty.
pub fn twoing_impurity(left_counts: &[usize], right_counts: &[usize]) -> f64 {
    let n_left: usize = left_counts.iter().sum();
    let n_right: usize = right_counts.iter().sum();
    let n_total = n_left + n_right;

    if n_left == 0 || n_right == 0 || n_total == 0 {
        return 0.0;
    }

    let (left_f, right_f) = (n_left as f64, n_right as f64);
    let abs_diff_sum: f64 = left_counts
        .iter()
        .enumerate()
        .map(|(k, &left)| (left as f64 / left_f - right_counts[k] as f64 / right_f).abs())
        .sum();

    let share_left = left_f / n_total as f64;
    let share_right = right_f / n_total as f64;
    0.25 * share_left * share_right * abs_diff_sum.powi(2)
}
714
/// Entropy (natural log) of a class-count histogram, used as the log-loss impurity.
///
/// Computes `-Σ p_k · ln(p_k)` over classes with a non-zero count. Returns
/// `0.0` when all counts are zero.
pub fn log_loss_impurity(class_counts: &[usize]) -> f64 {
    let total = class_counts.iter().sum::<usize>();
    if total == 0 {
        return 0.0;
    }

    let total = total as f64;
    class_counts
        .iter()
        .copied()
        .filter(|&count| count != 0)
        .map(|count| {
            let p = count as f64 / total;
            -p * p.ln()
        })
        .sum()
}
731
/// Gini impurity `1 - Σ_k p_k²` of a class-count histogram.
///
/// Each `p_k` is `class_counts[k] / total_samples`. Returns `0.0` when
/// `total_samples` is zero.
pub fn gini_impurity(class_counts: &[i32], total_samples: i32) -> f64 {
    if total_samples == 0 {
        return 0.0;
    }

    let denominator = f64::from(total_samples);
    class_counts.iter().fold(1.0, |acc, &count| {
        let p = f64::from(count) / denominator;
        acc - p * p
    })
}
745
746/// Apply feature grouping to training data
747pub fn apply_feature_grouping(
748    grouping: &FeatureGrouping,
749    x: &Array2<Float>,
750    y: &Array1<Float>,
751) -> Result<(Array2<Float>, FeatureGroupInfo)> {
752    match grouping {
753        FeatureGrouping::None => {
754            // No grouping - return original data
755            let n_features = x.ncols();
756            let info = FeatureGroupInfo {
757                groups: (0..n_features).map(|i| vec![i]).collect(),
758                representatives: (0..n_features).collect(),
759                correlation_matrix: None,
760                group_correlations: vec![1.0; n_features],
761            };
762            Ok((x.clone(), info))
763        }
764        FeatureGrouping::AutoCorrelation {
765            threshold,
766            selection_method,
767        } => apply_auto_correlation_grouping(x, y, *threshold, *selection_method),
768        FeatureGrouping::Manual {
769            groups,
770            selection_method,
771        } => apply_manual_grouping(x, y, groups, *selection_method),
772        FeatureGrouping::Hierarchical {
773            n_clusters,
774            linkage,
775            selection_method,
776        } => apply_hierarchical_grouping(x, y, *n_clusters, *linkage, *selection_method),
777    }
778}
779
780/// Apply automatic correlation-based feature grouping
781pub fn apply_auto_correlation_grouping(
782    x: &Array2<Float>,
783    y: &Array1<Float>,
784    threshold: Float,
785    selection_method: GroupSelectionMethod,
786) -> Result<(Array2<Float>, FeatureGroupInfo)> {
787    let n_features = x.ncols();
788
789    if n_features == 0 {
790        return Err(SklearsError::InvalidInput(
791            "Cannot apply feature grouping to empty feature set".to_string(),
792        ));
793    }
794
795    // Calculate feature correlation matrix
796    let correlation_matrix = calculate_correlation_matrix(x)?;
797
798    // Find groups of correlated features
799    let mut groups = Vec::new();
800    let mut assigned = vec![false; n_features];
801
802    for i in 0..n_features {
803        if assigned[i] {
804            continue;
805        }
806
807        let mut group = vec![i];
808        assigned[i] = true;
809
810        // Find features correlated with feature i above threshold
811        for j in (i + 1)..n_features {
812            if !assigned[j] && correlation_matrix[[i, j]].abs() >= threshold {
813                group.push(j);
814                assigned[j] = true;
815            }
816        }
817
818        groups.push(group);
819    }
820
821    // Select representative features from each group
822    let mut representatives = Vec::new();
823    let mut group_correlations = Vec::new();
824
825    for group in &groups {
826        let (representative, avg_correlation) =
827            select_group_representative(x, y, group, selection_method)?;
828        representatives.push(representative);
829        group_correlations.push(avg_correlation);
830    }
831
832    // Create reduced feature matrix with only representatives
833    let reduced_x = create_reduced_feature_matrix(x, &representatives)?;
834
835    let info = FeatureGroupInfo {
836        groups,
837        representatives,
838        correlation_matrix: Some(correlation_matrix),
839        group_correlations,
840    };
841
842    Ok((reduced_x, info))
843}
844
845/// Apply manual feature grouping specified by user
846pub fn apply_manual_grouping(
847    x: &Array2<Float>,
848    y: &Array1<Float>,
849    groups: &[Vec<usize>],
850    selection_method: GroupSelectionMethod,
851) -> Result<(Array2<Float>, FeatureGroupInfo)> {
852    let n_features = x.ncols();
853
854    // Validate that groups don't overlap and cover all features
855    let mut assigned = vec![false; n_features];
856    for group in groups {
857        for &feature_idx in group {
858            if feature_idx >= n_features {
859                return Err(SklearsError::InvalidInput(format!(
860                    "Feature index {} out of bounds",
861                    feature_idx
862                )));
863            }
864            if assigned[feature_idx] {
865                return Err(SklearsError::InvalidInput(format!(
866                    "Feature {} appears in multiple groups",
867                    feature_idx
868                )));
869            }
870            assigned[feature_idx] = true;
871        }
872    }
873
874    // Add ungrouped features as singleton groups
875    let mut complete_groups = groups.to_vec();
876    for (i, &is_assigned) in assigned.iter().enumerate() {
877        if !is_assigned {
878            complete_groups.push(vec![i]);
879        }
880    }
881
882    // Select representatives for each group
883    let mut representatives = Vec::new();
884    let mut group_correlations = Vec::new();
885
886    for group in &complete_groups {
887        let (representative, avg_correlation) =
888            select_group_representative(x, y, group, selection_method)?;
889        representatives.push(representative);
890        group_correlations.push(avg_correlation);
891    }
892
893    // Create reduced feature matrix
894    let reduced_x = create_reduced_feature_matrix(x, &representatives)?;
895
896    let info = FeatureGroupInfo {
897        groups: complete_groups,
898        representatives,
899        correlation_matrix: None,
900        group_correlations,
901    };
902
903    Ok((reduced_x, info))
904}
905
906/// Apply hierarchical clustering-based feature grouping
907pub fn apply_hierarchical_grouping(
908    x: &Array2<Float>,
909    y: &Array1<Float>,
910    n_clusters: usize,
911    linkage: LinkageMethod,
912    selection_method: GroupSelectionMethod,
913) -> Result<(Array2<Float>, FeatureGroupInfo)> {
914    let n_features = x.ncols();
915
916    if n_clusters == 0 || n_clusters > n_features {
917        return Err(SklearsError::InvalidInput(format!(
918            "n_clusters must be between 1 and {} (number of features)",
919            n_features
920        )));
921    }
922
923    // Calculate distance matrix (1 - |correlation|)
924    let correlation_matrix = calculate_correlation_matrix(x)?;
925    let mut distance_matrix = Array2::<Float>::zeros((n_features, n_features));
926
927    for i in 0..n_features {
928        for j in 0..n_features {
929            distance_matrix[[i, j]] = 1.0 - correlation_matrix[[i, j]].abs();
930        }
931    }
932
933    // Perform hierarchical clustering (simplified implementation)
934    let groups = hierarchical_clustering(&distance_matrix, n_clusters, linkage)?;
935
936    // Select representatives for each group
937    let mut representatives = Vec::new();
938    let mut group_correlations = Vec::new();
939
940    for group in &groups {
941        let (representative, avg_correlation) =
942            select_group_representative(x, y, group, selection_method)?;
943        representatives.push(representative);
944        group_correlations.push(avg_correlation);
945    }
946
947    // Create reduced feature matrix
948    let reduced_x = create_reduced_feature_matrix(x, &representatives)?;
949
950    let info = FeatureGroupInfo {
951        groups,
952        representatives,
953        correlation_matrix: Some(correlation_matrix),
954        group_correlations,
955    };
956
957    Ok((reduced_x, info))
958}
959
960/// Calculate correlation matrix for features
961pub fn calculate_correlation_matrix(x: &Array2<Float>) -> Result<Array2<Float>> {
962    let n_features = x.ncols();
963    let n_samples = x.nrows();
964
965    if n_samples < 2 {
966        return Err(SklearsError::InvalidInput(
967            "Need at least 2 samples to calculate correlations".to_string(),
968        ));
969    }
970
971    let mut correlation_matrix = Array2::<Float>::zeros((n_features, n_features));
972
973    // Calculate means
974    let means: Vec<Float> = (0..n_features)
975        .map(|j| x.column(j).mean().unwrap_or(0.0))
976        .collect();
977
978    // Calculate correlation for each pair of features
979    for i in 0..n_features {
980        for j in i..n_features {
981            if i == j {
982                correlation_matrix[[i, j]] = 1.0;
983            } else {
984                let corr = calculate_pearson_correlation(
985                    &x.column(i).to_owned(),
986                    &x.column(j).to_owned(),
987                    means[i],
988                    means[j],
989                )?;
990                correlation_matrix[[i, j]] = corr;
991                correlation_matrix[[j, i]] = corr;
992            }
993        }
994    }
995
996    Ok(correlation_matrix)
997}
998
999/// Calculate Pearson correlation between two feature vectors
1000pub fn calculate_pearson_correlation(
1001    x: &Array1<Float>,
1002    y: &Array1<Float>,
1003    mean_x: Float,
1004    mean_y: Float,
1005) -> Result<Float> {
1006    let n = x.len();
1007
1008    if n != y.len() {
1009        return Err(SklearsError::InvalidInput(format!(
1010            "Feature vectors must have same length: {} vs {}",
1011            n,
1012            y.len()
1013        )));
1014    }
1015
1016    let mut sum_xy = 0.0;
1017    let mut sum_x2 = 0.0;
1018    let mut sum_y2 = 0.0;
1019
1020    for i in 0..n {
1021        let dx = x[i] - mean_x;
1022        let dy = y[i] - mean_y;
1023        sum_xy += dx * dy;
1024        sum_x2 += dx * dx;
1025        sum_y2 += dy * dy;
1026    }
1027
1028    let denominator = (sum_x2 * sum_y2).sqrt();
1029
1030    if denominator.abs() < Float::EPSILON {
1031        Ok(0.0) // No correlation if one or both variables have no variance
1032    } else {
1033        Ok(sum_xy / denominator)
1034    }
1035}
1036
1037/// Select representative feature from a group
1038pub fn select_group_representative(
1039    x: &Array2<Float>,
1040    y: &Array1<Float>,
1041    group: &[usize],
1042    method: GroupSelectionMethod,
1043) -> Result<(usize, Float)> {
1044    if group.is_empty() {
1045        return Err(SklearsError::InvalidInput(
1046            "Cannot select representative from empty group".to_string(),
1047        ));
1048    }
1049
1050    if group.len() == 1 {
1051        return Ok((group[0], 1.0));
1052    }
1053
1054    match method {
1055        GroupSelectionMethod::MaxVariance => {
1056            let mut max_variance = f64::NEG_INFINITY;
1057            let mut best_feature = group[0];
1058
1059            for &feature_idx in group {
1060                let column = x.column(feature_idx);
1061                let mean = column.mean().unwrap_or(0.0);
1062                let variance =
1063                    column.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / column.len() as f64;
1064
1065                if variance > max_variance {
1066                    max_variance = variance;
1067                    best_feature = feature_idx;
1068                }
1069            }
1070
1071            Ok((best_feature, max_variance))
1072        }
1073        GroupSelectionMethod::MaxTargetCorrelation => {
1074            let mut max_correlation = f64::NEG_INFINITY;
1075            let mut best_feature = group[0];
1076
1077            let y_mean = y.mean().unwrap_or(0.0);
1078
1079            for &feature_idx in group {
1080                let x_col = x.column(feature_idx).to_owned();
1081                let x_mean = x_col.mean().unwrap_or(0.0);
1082                let correlation = calculate_pearson_correlation(&x_col, y, x_mean, y_mean)?;
1083
1084                if correlation.abs() > max_correlation {
1085                    max_correlation = correlation.abs();
1086                    best_feature = feature_idx;
1087                }
1088            }
1089
1090            Ok((best_feature, max_correlation))
1091        }
1092        GroupSelectionMethod::First => Ok((group[0], 1.0)),
1093        GroupSelectionMethod::Random => {
1094            use scirs2_core::random::thread_rng;
1095            let mut rng = thread_rng();
1096            let idx = rng.gen_range(0..group.len());
1097            Ok((group[idx], 1.0))
1098        }
1099        GroupSelectionMethod::WeightedAll => {
1100            // For now, just return the first feature
1101            // In a full implementation, this would modify the training to use all features
1102            Ok((group[0], 1.0))
1103        }
1104    }
1105}
1106
1107/// Create reduced feature matrix with only representative features
1108pub fn create_reduced_feature_matrix(
1109    x: &Array2<Float>,
1110    representatives: &[usize],
1111) -> Result<Array2<Float>> {
1112    let n_samples = x.nrows();
1113    let n_representatives = representatives.len();
1114
1115    let mut reduced_x = Array2::zeros((n_samples, n_representatives));
1116
1117    for (new_col, &orig_col) in representatives.iter().enumerate() {
1118        if orig_col >= x.ncols() {
1119            return Err(SklearsError::InvalidInput(format!(
1120                "Representative feature index {} out of bounds",
1121                orig_col
1122            )));
1123        }
1124
1125        reduced_x.column_mut(new_col).assign(&x.column(orig_col));
1126    }
1127
1128    Ok(reduced_x)
1129}
1130
1131/// Simple hierarchical clustering implementation
1132pub fn hierarchical_clustering(
1133    distance_matrix: &Array2<Float>,
1134    n_clusters: usize,
1135    linkage: LinkageMethod,
1136) -> Result<Vec<Vec<usize>>> {
1137    let n_features = distance_matrix.nrows();
1138
1139    if n_features != distance_matrix.ncols() {
1140        return Err(SklearsError::InvalidInput(
1141            "Distance matrix must be square".to_string(),
1142        ));
1143    }
1144
1145    // Start with each feature in its own cluster
1146    let mut clusters: Vec<Vec<usize>> = (0..n_features).map(|i| vec![i]).collect();
1147
1148    // Merge clusters until we have the desired number
1149    while clusters.len() > n_clusters {
1150        // Find the two closest clusters
1151        let mut min_distance = Float::INFINITY;
1152        let mut merge_i = 0;
1153        let mut merge_j = 1;
1154
1155        for i in 0..clusters.len() {
1156            for j in (i + 1)..clusters.len() {
1157                let distance =
1158                    cluster_distance(&clusters[i], &clusters[j], distance_matrix, linkage);
1159                if distance < min_distance {
1160                    min_distance = distance;
1161                    merge_i = i;
1162                    merge_j = j;
1163                }
1164            }
1165        }
1166
1167        // Merge the closest clusters
1168        let cluster_j = clusters.remove(merge_j);
1169        clusters[merge_i].extend(cluster_j);
1170    }
1171
1172    Ok(clusters)
1173}
1174
1175/// Calculate distance between two clusters
1176fn cluster_distance(
1177    cluster1: &[usize],
1178    cluster2: &[usize],
1179    distance_matrix: &Array2<Float>,
1180    linkage: LinkageMethod,
1181) -> Float {
1182    match linkage {
1183        LinkageMethod::Single => {
1184            // Minimum distance between any two points
1185            let mut min_dist = Float::INFINITY;
1186            for &i in cluster1 {
1187                for &j in cluster2 {
1188                    let dist = distance_matrix[[i, j]];
1189                    if dist < min_dist {
1190                        min_dist = dist;
1191                    }
1192                }
1193            }
1194            min_dist
1195        }
1196        LinkageMethod::Complete => {
1197            // Maximum distance between any two points
1198            let mut max_dist = Float::NEG_INFINITY;
1199            for &i in cluster1 {
1200                for &j in cluster2 {
1201                    let dist = distance_matrix[[i, j]];
1202                    if dist > max_dist {
1203                        max_dist = dist;
1204                    }
1205                }
1206            }
1207            max_dist
1208        }
1209        LinkageMethod::Average => {
1210            // Average distance between all pairs of points
1211            let mut total_dist = 0.0;
1212            let mut count = 0;
1213            for &i in cluster1 {
1214                for &j in cluster2 {
1215                    total_dist += distance_matrix[[i, j]];
1216                    count += 1;
1217                }
1218            }
1219            if count > 0 {
1220                total_dist / count as Float
1221            } else {
1222                0.0
1223            }
1224        }
1225        LinkageMethod::Ward => {
1226            // For simplicity, use average linkage
1227            // A full Ward implementation would require centroid calculations
1228            let mut total_dist = 0.0;
1229            let mut count = 0;
1230            for &i in cluster1 {
1231                for &j in cluster2 {
1232                    total_dist += distance_matrix[[i, j]];
1233                    count += 1;
1234                }
1235            }
1236            if count > 0 {
1237                total_dist / count as Float
1238            } else {
1239                0.0
1240            }
1241        }
1242    }
1243}