sklears_tree/config.rs

//! Configuration types and enums for decision trees
//!
//! This module contains all the configuration enums, structs, and parameters
//! used by decision tree classifiers and regressors.

use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::types::Float;
use smartcore::linalg::basic::matrix::DenseMatrix;

// Assumed to be the crate's `Result<T, SklearsError>` alias; the signature of
// `HyperplaneSplit::from_ridge_regression` below does not compile without it.
#[cfg(feature = "oblique")]
use sklears_core::error::Result;

// Import types from criteria module
use crate::criteria::SplitCriterion;

/// Monotonic constraint for a feature
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MonotonicConstraint {
    /// No constraint on the relationship
    None,
    /// Feature must have an increasing relationship with the target (positive monotonicity)
    Increasing,
    /// Feature must have a decreasing relationship with the target (negative monotonicity)
    Decreasing,
}

/// Interaction constraint between features
#[derive(Debug, Clone)]
pub enum InteractionConstraint {
    /// No constraints on feature interactions
    None,
    /// Allow interactions only within specified groups
    Groups(Vec<Vec<usize>>),
    /// Forbid specific feature pairs from interacting
    Forbidden(Vec<(usize, usize)>),
    /// Allow only specific feature pairs to interact
    Allowed(Vec<(usize, usize)>),
}
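
// Illustrative sketch (test-only): combining the two constraint types for a
// hypothetical four-feature dataset. The values are examples, not defaults.
#[cfg(test)]
mod constraint_examples {
    use super::*;

    #[test]
    fn build_constraints() {
        // One entry per feature: feature 0 must increase with the target,
        // feature 3 must decrease, and the rest are unconstrained.
        let monotonic = vec![
            MonotonicConstraint::Increasing,
            MonotonicConstraint::None,
            MonotonicConstraint::None,
            MonotonicConstraint::Decreasing,
        ];
        assert_eq!(monotonic[0], MonotonicConstraint::Increasing);

        // Features may only interact within {0, 1} and {2, 3}.
        let interactions = InteractionConstraint::Groups(vec![vec![0, 1], vec![2, 3]]);
        match interactions {
            InteractionConstraint::Groups(groups) => assert_eq!(groups.len(), 2),
            _ => unreachable!("constructed as Groups above"),
        }
    }
}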

/// Feature grouping strategy for handling correlated features
#[derive(Debug, Clone)]
pub enum FeatureGrouping {
    /// No feature grouping (default)
    None,
    /// Automatic grouping based on correlation threshold
    AutoCorrelation {
        /// Correlation threshold above which features are grouped together
        threshold: Float,
        /// Method to select representative feature from each group
        selection_method: GroupSelectionMethod,
    },
    /// Manual feature groups specified by the user
    Manual {
        /// List of feature groups, each group contains feature indices
        groups: Vec<Vec<usize>>,
        /// Method to select representative feature from each group
        selection_method: GroupSelectionMethod,
    },
    /// Hierarchical clustering-based grouping
    Hierarchical {
        /// Number of clusters to create
        n_clusters: usize,
        /// Linkage method for hierarchical clustering
        linkage: LinkageMethod,
        /// Method to select representative feature from each group
        selection_method: GroupSelectionMethod,
    },
}

/// Method for selecting representative feature from a group
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GroupSelectionMethod {
    /// Select feature with highest variance within the group
    MaxVariance,
    /// Select feature with highest correlation to target
    MaxTargetCorrelation,
    /// Select first feature in the group (index order)
    First,
    /// Select random feature from the group
    Random,
    /// Use all features from the group but with reduced weight
    WeightedAll,
}

/// Linkage method for hierarchical clustering
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LinkageMethod {
    /// Single linkage (minimum distance)
    Single,
    /// Complete linkage (maximum distance)
    Complete,
    /// Average linkage
    Average,
    /// Ward linkage (minimize within-cluster variance)
    Ward,
}
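
// Illustrative sketch (test-only, not exported API): constructing the three
// grouping strategies. Values are arbitrary examples, and the snippet assumes
// the crate's `Float` alias is an ordinary float type so plain literals work.
#[cfg(test)]
mod feature_grouping_examples {
    use super::*;

    #[test]
    fn build_grouping_strategies() {
        // Group features whose pairwise correlation exceeds 0.9; keep the
        // highest-variance member of each group.
        let auto = FeatureGrouping::AutoCorrelation {
            threshold: 0.9,
            selection_method: GroupSelectionMethod::MaxVariance,
        };

        // Hand-specified groups, keeping the first feature of each.
        let manual = FeatureGrouping::Manual {
            groups: vec![vec![0, 1, 2], vec![3, 4]],
            selection_method: GroupSelectionMethod::First,
        };

        // Ward-linkage clustering into three groups.
        let hierarchical = FeatureGrouping::Hierarchical {
            n_clusters: 3,
            linkage: LinkageMethod::Ward,
            selection_method: GroupSelectionMethod::MaxTargetCorrelation,
        };

        for strategy in [&auto, &manual, &hierarchical] {
            assert!(!matches!(strategy, FeatureGrouping::None));
        }
    }
}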

/// Information about feature groups discovered or specified
#[derive(Debug, Clone)]
pub struct FeatureGroupInfo {
    /// Groups of correlated features
    pub groups: Vec<Vec<usize>>,
    /// Representative feature index for each group
    pub representatives: Vec<usize>,
    /// Correlation matrix used for grouping (if applicable)
    pub correlation_matrix: Option<Array2<Float>>,
    /// Within-group correlations for each group
    pub group_correlations: Vec<Float>,
}

/// Strategy for choosing how many features to consider when searching for splits
#[derive(Debug, Clone)]
pub enum MaxFeatures {
    /// Use all features
    All,
    /// Use sqrt(n_features)
    Sqrt,
    /// Use log2(n_features)
    Log2,
    /// Use a specific number of features
    Number(usize),
    /// Use a fraction of features
    Fraction(f64),
}
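
// Illustrative sketch (test-only): how the variants could map to an effective
// per-split feature count, assuming semantics like scikit-learn's
// `max_features`. The `resolve` helper is hypothetical, not part of this module.
#[cfg(test)]
mod max_features_examples {
    use super::MaxFeatures;

    // Hypothetical resolver from a strategy to a concrete feature count.
    fn resolve(mf: &MaxFeatures, n_features: usize) -> usize {
        match mf {
            MaxFeatures::All => n_features,
            MaxFeatures::Sqrt => (n_features as f64).sqrt().floor() as usize,
            MaxFeatures::Log2 => (n_features as f64).log2().floor() as usize,
            MaxFeatures::Number(n) => (*n).min(n_features),
            MaxFeatures::Fraction(f) => ((n_features as f64 * f).floor() as usize).max(1),
        }
    }

    #[test]
    fn effective_feature_counts() {
        assert_eq!(resolve(&MaxFeatures::All, 100), 100);
        assert_eq!(resolve(&MaxFeatures::Sqrt, 100), 10);
        assert_eq!(resolve(&MaxFeatures::Log2, 100), 6); // floor(log2(100)) = 6
        assert_eq!(resolve(&MaxFeatures::Number(25), 100), 25);
        assert_eq!(resolve(&MaxFeatures::Fraction(0.3), 100), 30);
    }
}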

/// Pruning strategy for decision trees
#[derive(Debug, Clone, Copy)]
pub enum PruningStrategy {
    /// No pruning
    None,
    /// Cost-complexity pruning (post-pruning)
    CostComplexity {
        /// Complexity parameter alpha; larger values prune more aggressively
        alpha: f64,
    },
    /// Reduced error pruning
    ReducedError,
}

/// Missing value handling strategy
#[derive(Debug, Clone, Copy)]
pub enum MissingValueStrategy {
    /// Skip samples with missing values
    Skip,
    /// Use majority class/mean for splits
    Majority,
    /// Use surrogate splits
    Surrogate,
}

/// Feature type specification for multiway splits
#[derive(Debug, Clone)]
pub enum FeatureType {
    /// Continuous numerical feature (binary splits)
    Continuous,
    /// Categorical feature with specified categories (multiway splits)
    Categorical(Vec<String>),
}

/// Information about a multiway split
#[derive(Debug, Clone)]
pub struct MultiwaySplit {
    /// Feature index
    pub feature_idx: usize,
    /// Category assignments for each branch
    pub category_branches: Vec<Vec<String>>,
    /// Impurity decrease achieved by this split
    pub impurity_decrease: f64,
}
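
// Illustrative sketch (test-only): a categorical feature declaration and the
// multiway split it could produce, one branch per category subset. All values
// are made up for the example.
#[cfg(test)]
mod multiway_split_examples {
    use super::*;

    #[test]
    fn categorical_split_construction() {
        let feature = FeatureType::Categorical(vec![
            "red".to_string(),
            "green".to_string(),
            "blue".to_string(),
        ]);
        assert!(matches!(feature, FeatureType::Categorical(ref c) if c.len() == 3));

        // A split on feature 2 sending "red" one way and {"green", "blue"} the other.
        let split = MultiwaySplit {
            feature_idx: 2,
            category_branches: vec![
                vec!["red".to_string()],
                vec!["green".to_string(), "blue".to_string()],
            ],
            impurity_decrease: 0.12,
        };
        assert_eq!(split.category_branches.len(), 2);
    }
}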

/// Tree growing strategy
#[derive(Debug, Clone, Copy)]
pub enum TreeGrowingStrategy {
    /// Depth-first growing (traditional CART)
    DepthFirst,
    /// Best-first growing (expand the node with the highest impurity decrease)
    BestFirst {
        /// Maximum number of leaves to grow; `None` means unlimited
        max_leaves: Option<usize>,
    },
}

/// Split type for decision trees
#[derive(Debug, Clone, Copy)]
pub enum SplitType {
    /// Traditional axis-aligned splits (threshold on a single feature)
    AxisAligned,
    /// Linear hyperplane splits (linear combination of features)
    Oblique {
        /// Number of random hyperplanes to evaluate per split
        n_hyperplanes: usize,
        /// Use ridge regression to find the optimal hyperplane
        use_ridge: bool,
    },
}

/// Hyperplane split information for oblique trees
#[derive(Debug, Clone)]
pub struct HyperplaneSplit {
    /// Feature coefficients `w` for the hyperplane (split test: `w·x + bias >= threshold`)
    pub coefficients: Array1<f64>,
    /// Threshold for the hyperplane split
    pub threshold: f64,
    /// Bias term for the hyperplane
    pub bias: f64,
    /// Impurity decrease achieved by this split
    pub impurity_decrease: f64,
}

impl HyperplaneSplit {
    /// Evaluate the hyperplane split for a sample
    pub fn evaluate(&self, sample: &Array1<f64>) -> bool {
        let dot_product = self.coefficients.dot(sample) + self.bias;
        dot_product >= self.threshold
    }

    /// Create a random hyperplane with normalized coefficients
    pub fn random(n_features: usize, rng: &mut scirs2_core::CoreRandom) -> Self {
        let mut coefficients = Array1::zeros(n_features);
        for i in 0..n_features {
            coefficients[i] = rng.gen_range(-1.0..1.0);
        }

        // Normalize coefficients to unit length (skip near-zero vectors)
        let dot_product: f64 = coefficients.dot(&coefficients);
        let norm = dot_product.sqrt();
        if norm > 1e-10_f64 {
            coefficients /= norm;
        }

        Self {
            coefficients,
            threshold: rng.gen_range(-1.0..1.0),
            bias: rng.gen_range(-0.1..0.1),
            impurity_decrease: 0.0,
        }
    }

    /// Find optimal hyperplane using ridge regression
    #[cfg(feature = "oblique")]
    pub fn from_ridge_regression(x: &Array2<f64>, y: &Array1<f64>, alpha: f64) -> Result<Self> {
        use scirs2_core::ndarray::s;
        use sklears_core::error::SklearsError;

        let n_features = x.ncols();
        if x.nrows() < 2 {
            return Err(SklearsError::InvalidInput(
                "Need at least 2 samples for ridge regression".to_string(),
            ));
        }

        // Add bias column to X
        let mut x_bias = Array2::ones((x.nrows(), n_features + 1));
        x_bias.slice_mut(s![.., ..n_features]).assign(x);

        // Ridge regression: w = (X^T X + α I)^(-1) X^T y
        let xtx = x_bias.t().dot(&x_bias);
        let ridge_matrix = xtx + Array2::<f64>::eye(n_features + 1) * alpha;
        let xty = x_bias.t().dot(y);

        // Simple matrix inverse using Gauss-Jordan elimination
        match gauss_jordan_inverse(&ridge_matrix) {
            Ok(inv_matrix) => {
                let coefficients_full = inv_matrix.dot(&xty);

                let coefficients = coefficients_full.slice(s![..n_features]).to_owned();
                let bias = coefficients_full[n_features];

                Ok(Self {
                    coefficients,
                    threshold: 0.0, // Will be set during split evaluation
                    bias,
                    impurity_decrease: 0.0,
                })
            }
            Err(_) => {
                // Fallback to random hyperplane if matrix is singular
                let mut rng = scirs2_core::random::thread_rng();
                Ok(Self::random(n_features, &mut rng))
            }
        }
    }
}
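
// Illustrative sketch (test-only): evaluating a hand-built hyperplane split.
// With w = (1, -1), bias = 0, and threshold = 0, a sample goes to the "true"
// branch exactly when x0 >= x1.
#[cfg(test)]
mod hyperplane_split_examples {
    use super::*;

    #[test]
    fn evaluate_hand_built_hyperplane() {
        let split = HyperplaneSplit {
            coefficients: Array1::from_vec(vec![1.0, -1.0]),
            threshold: 0.0,
            bias: 0.0,
            impurity_decrease: 0.0,
        };
        // 1*2 + (-1)*1 + 0 = 1 >= 0 -> true branch
        assert!(split.evaluate(&Array1::from_vec(vec![2.0, 1.0])));
        // 1*1 + (-1)*2 + 0 = -1 < 0 -> false branch
        assert!(!split.evaluate(&Array1::from_vec(vec![1.0, 2.0])));
    }
}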

/// Configuration for Decision Trees
#[derive(Debug, Clone)]
pub struct DecisionTreeConfig {
    /// Split criterion
    pub criterion: SplitCriterion,
    /// Maximum depth of the tree
    pub max_depth: Option<usize>,
    /// Minimum samples required to split an internal node
    pub min_samples_split: usize,
    /// Minimum samples required to be at a leaf node
    pub min_samples_leaf: usize,
    /// Maximum number of features to consider for splits
    pub max_features: MaxFeatures,
    /// Random seed for reproducibility
    pub random_state: Option<u64>,
    /// Minimum weighted fraction of samples required to be at a leaf
    pub min_weight_fraction_leaf: f64,
    /// Minimum impurity decrease required for a split
    pub min_impurity_decrease: f64,
    /// Pruning strategy to apply
    pub pruning: PruningStrategy,
    /// Strategy for handling missing values
    pub missing_values: MissingValueStrategy,
    /// Feature types for each feature (enables multiway splits for categorical features)
    pub feature_types: Option<Vec<FeatureType>>,
    /// Tree growing strategy
    pub growing_strategy: TreeGrowingStrategy,
    /// Split type (axis-aligned or oblique)
    pub split_type: SplitType,
    /// Monotonic constraints for each feature
    pub monotonic_constraints: Option<Vec<MonotonicConstraint>>,
    /// Interaction constraints between features
    pub interaction_constraints: InteractionConstraint,
    /// Feature grouping strategy for handling correlated features
    pub feature_grouping: FeatureGrouping,
}

impl Default for DecisionTreeConfig {
    fn default() -> Self {
        Self {
            criterion: SplitCriterion::Gini,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::All,
            random_state: None,
            min_weight_fraction_leaf: 0.0,
            min_impurity_decrease: 0.0,
            pruning: PruningStrategy::None,
            missing_values: MissingValueStrategy::Skip,
            feature_types: None,
            growing_strategy: TreeGrowingStrategy::DepthFirst,
            split_type: SplitType::AxisAligned,
            monotonic_constraints: None,
            interaction_constraints: InteractionConstraint::None,
            feature_grouping: FeatureGrouping::None,
        }
    }
}
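
// Illustrative sketch (test-only): customizing a configuration with struct
// update syntax so every unmentioned field keeps its default. The specific
// values are arbitrary examples.
#[cfg(test)]
mod config_examples {
    use super::*;

    #[test]
    fn customized_config() {
        let config = DecisionTreeConfig {
            max_depth: Some(8),
            min_samples_leaf: 5,
            max_features: MaxFeatures::Sqrt,
            pruning: PruningStrategy::CostComplexity { alpha: 0.01 },
            growing_strategy: TreeGrowingStrategy::BestFirst {
                max_leaves: Some(31),
            },
            ..Default::default()
        };
        assert_eq!(config.max_depth, Some(8));
        assert_eq!(config.min_samples_leaf, 5);
    }
}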

/// Helper function to convert ndarray to DenseMatrix
pub fn ndarray_to_dense_matrix(arr: &Array2<f64>) -> DenseMatrix<f64> {
    let mut data = Vec::with_capacity(arr.nrows());
    for row in arr.outer_iter() {
        data.push(row.to_vec());
    }
    DenseMatrix::from_2d_vec(&data)
}
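
// Illustrative sketch (test-only): converting a small 2x3 ndarray matrix. The
// test only exercises the call; inspecting the result would need smartcore's
// array traits, which are left out to keep the sketch minimal.
#[cfg(test)]
mod conversion_examples {
    use super::*;

    #[test]
    fn convert_small_matrix() {
        let arr = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape matches data length");
        // Succeeds as long as `from_2d_vec` accepts the row-major Vec<Vec<f64>>.
        let _dense = ndarray_to_dense_matrix(&arr);
    }
}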

/// Simple Gauss-Jordan elimination for matrix inversion
#[cfg(feature = "oblique")]
fn gauss_jordan_inverse(matrix: &Array2<f64>) -> std::result::Result<Array2<f64>, &'static str> {
    let n = matrix.nrows();
    if n != matrix.ncols() {
        return Err("Matrix must be square");
    }

    // Create augmented matrix [A | I]
    let mut augmented = Array2::zeros((n, 2 * n));
    for i in 0..n {
        for j in 0..n {
            augmented[[i, j]] = matrix[[i, j]];
            if i == j {
                augmented[[i, j + n]] = 1.0;
            }
        }
    }

    // Gauss-Jordan elimination with partial pivoting: reduce [A | I] to [I | A^(-1)]
    for i in 0..n {
        // Find pivot (largest absolute value in column i)
        let mut max_row = i;
        for k in i + 1..n {
            if augmented[[k, i]].abs() > augmented[[max_row, i]].abs() {
                max_row = k;
            }
        }

        // Swap rows if needed
        if max_row != i {
            for j in 0..2 * n {
                augmented.swap([i, j], [max_row, j]);
            }
        }

        // Check for singular matrix
        if augmented[[i, i]].abs() < 1e-10 {
            return Err("Matrix is singular");
        }

        // Make diagonal element 1
        let pivot = augmented[[i, i]];
        for j in 0..2 * n {
            augmented[[i, j]] /= pivot;
        }

        // Eliminate column i from all other rows
        for k in 0..n {
            if k != i {
                let factor = augmented[[k, i]];
                for j in 0..2 * n {
                    augmented[[k, j]] -= factor * augmented[[i, j]];
                }
            }
        }
    }

    // Extract inverse matrix from the right half
    let mut inverse = Array2::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            inverse[[i, j]] = augmented[[i, j + n]];
        }
    }

    Ok(inverse)
}
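
// Illustrative sketch (test-only): sanity-check the Gauss-Jordan inverse on a
// 2x2 matrix by multiplying back to the identity within a small tolerance.
#[cfg(all(test, feature = "oblique"))]
mod gauss_jordan_examples {
    use super::*;

    #[test]
    fn inverse_of_invertible_2x2() {
        // det([[4, 7], [2, 6]]) = 10, so the matrix is invertible.
        let a = Array2::from_shape_vec((2, 2), vec![4.0, 7.0, 2.0, 6.0])
            .expect("shape matches data length");
        let inv = gauss_jordan_inverse(&a).expect("matrix is invertible");
        let product = a.dot(&inv);
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((product[[i, j]] - expected).abs() < 1e-9);
            }
        }
    }
}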