sklears_tree/
criteria.rs

//! Split criteria and constraints for decision trees.
//!
//! This module contains the enums and structs that define how decision trees
//! make splitting decisions and handle constraints.
5
6use scirs2_core::ndarray::Array2;
7use sklears_core::types::Float;
8
9// Import types from config module
10
/// Split criterion for decision trees.
///
/// Each variant names the measure or statistical procedure used to evaluate
/// candidate splits. This enum only carries configuration; no scoring logic
/// lives in this type.
///
/// Equality is derived, so variants that carry an `f64` compare with
/// IEEE-754 semantics: a `NaN` significance level never equals itself.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SplitCriterion {
    /// Gini impurity (for classification)
    Gini,
    /// Information gain / entropy (for classification)
    Entropy,
    /// Mean squared error (for regression)
    MSE,
    /// Mean absolute error (for regression)
    MAE,
    /// Twoing criterion for binary splits (classification)
    Twoing,
    /// Log-loss criterion for probability-based splitting (classification)
    LogLoss,
    /// Chi-squared automatic interaction detection (CHAID)
    CHAID {
        /// Significance level for the CHAID test.
        ///
        /// Presumably a probability in `(0, 1)`; this type does not validate it.
        significance_level: f64,
    },
    /// Conditional inference trees with statistical testing
    ConditionalInference {
        /// Significance level for the statistical test.
        ///
        /// Presumably a probability in `(0, 1)`; this type does not validate it.
        significance_level: f64,
        /// Type of statistical test to use
        test_type: ConditionalTestType,
    },
}

/// Statistical test types for conditional inference trees.
///
/// Used as the `test_type` of [`SplitCriterion::ConditionalInference`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConditionalTestType {
    /// Quadratic-form test variant
    QuadraticForm,
    /// Maximum-type test variant
    MaxType,
    /// Monte Carlo test variant, parameterized by a permutation count
    MonteCarlo {
        /// Number of permutations requested
        n_permutations: usize,
    },
    /// Asymptotic chi-squared test
    AsymptoticChiSquared,
}
47
/// Monotonicity requirement placed on a single feature.
///
/// The variant states which direction, if any, the relationship between the
/// feature and the target is required to follow.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum MonotonicConstraint {
    /// The relationship is left unconstrained
    None,
    /// Positive monotonicity: the feature must have an increasing relationship with the target
    Increasing,
    /// Negative monotonicity: the feature must have a decreasing relationship with the target
    Decreasing,
}
58
/// Interaction constraint between features.
///
/// Features are referred to by `usize` index. This type only stores the
/// constraint; it performs no validation of the indices it holds.
///
/// Derived equality is structural: two constraints are equal only when their
/// groups/pairs appear in the same order with the same contents.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InteractionConstraint {
    /// No constraints on feature interactions
    None,
    /// Allow interactions only within specified groups
    /// (each inner `Vec` is one group of feature indices)
    Groups(Vec<Vec<usize>>),
    /// Forbid specific feature pairs from interacting
    Forbidden(Vec<(usize, usize)>),
    /// Allow only specific feature pairs to interact
    Allowed(Vec<(usize, usize)>),
}
71
72/// Feature grouping strategy for handling correlated features
73#[derive(Debug, Clone)]
74pub enum FeatureGrouping {
75    /// No feature grouping (default)
76    None,
77    /// Automatic grouping based on correlation threshold
78    AutoCorrelation {
79        /// Correlation threshold above which features are grouped together
80        threshold: Float,
81        /// Method to select representative feature from each group
82        selection_method: GroupSelectionMethod,
83    },
84    /// Manual feature groups specified by user
85    Manual {
86        /// List of feature groups, each group contains feature indices
87        groups: Vec<Vec<usize>>,
88        /// Method to select representative feature from each group
89        selection_method: GroupSelectionMethod,
90    },
91    /// Hierarchical clustering-based grouping
92    Hierarchical {
93        /// Number of clusters to create
94        n_clusters: usize,
95        /// Linkage method for hierarchical clustering
96        linkage: LinkageMethod,
97        /// Method to select representative feature from each group
98        selection_method: GroupSelectionMethod,
99    },
100}
101
/// How a representative feature is chosen from a feature group.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum GroupSelectionMethod {
    /// Pick the feature with the highest variance within the group
    MaxVariance,
    /// Pick the feature with the highest correlation to the target
    MaxTargetCorrelation,
    /// Pick the first feature in the group (index order)
    First,
    /// Pick a random feature from the group
    Random,
    /// Keep all features from the group, but with reduced weight
    WeightedAll,
}
116
/// Linkage rule used by hierarchical clustering.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LinkageMethod {
    /// Single linkage: minimum distance
    Single,
    /// Complete linkage: maximum distance
    Complete,
    /// Average linkage
    Average,
    /// Ward linkage: minimize within-cluster variance
    Ward,
}
129
130/// Information about feature groups discovered or specified
131#[derive(Debug, Clone)]
132pub struct FeatureGroupInfo {
133    /// Groups of correlated features
134    pub groups: Vec<Vec<usize>>,
135    /// Representative feature index for each group
136    pub representatives: Vec<usize>,
137    /// Correlation matrix used for grouping (if applicable)
138    pub correlation_matrix: Option<Array2<Float>>,
139    /// Within-group correlations for each group
140    pub group_correlations: Vec<Float>,
141}
142
/// Pruning strategy for decision trees.
///
/// Equality is derived; because `alpha` is an `f64`, a `NaN` alpha never
/// compares equal to itself.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PruningStrategy {
    /// No pruning
    None,
    /// Cost-complexity pruning (post-pruning)
    CostComplexity {
        /// Complexity parameter for cost-complexity pruning.
        ///
        /// Presumably non-negative; this type does not validate it.
        alpha: f64,
    },
    /// Reduced error pruning
    ReducedError,
}
153
/// Missing value handling strategy.
///
/// This enum only names the strategy; the handling itself is implemented by
/// whatever code consumes it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MissingValueStrategy {
    /// Skip samples with missing values
    Skip,
    /// Use majority class/mean for splits
    Majority,
    /// Use surrogate splits
    Surrogate,
}
164
/// Feature type specification for multiway splits.
///
/// Derived equality on `Categorical` is order-sensitive: the category lists
/// must match element by element.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FeatureType {
    /// Continuous numerical feature (binary splits)
    Continuous,
    /// Categorical feature with specified categories (multiway splits)
    Categorical(Vec<String>),
}
173
/// Information about a multiway split.
///
/// Equality is derived field by field; because `impurity_decrease` is an
/// `f64`, a split whose decrease is `NaN` never compares equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiwaySplit {
    /// Feature index
    pub feature_idx: usize,
    /// Category assignments for each branch
    /// (one inner `Vec` of category names per branch)
    pub category_branches: Vec<Vec<String>>,
    /// Impurity decrease achieved by this split
    pub impurity_decrease: f64,
}
184
/// Tree growing strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeGrowingStrategy {
    /// Depth-first growing (traditional CART)
    DepthFirst,
    /// Best-first growing (expand node with highest impurity decrease)
    BestFirst {
        /// Optional leaf-count setting for best-first growth.
        ///
        /// Presumably `None` means no limit on the number of leaves — confirm
        /// against the tree builder that consumes this value.
        max_leaves: Option<usize>,
    },
}
193
/// Split type for decision trees.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SplitType {
    /// Traditional axis-aligned splits (threshold on single feature)
    AxisAligned,
    /// Linear hyperplane splits (linear combination of features)
    Oblique {
        /// Number of random hyperplanes to evaluate per split
        n_hyperplanes: usize,
        /// Use ridge regression to find optimal hyperplane
        use_ridge: bool,
    },
}
206}