sklears_tree/criteria.rs
1//! Split criteria and constraints for decision trees
2//!
3//! This module contains enums and structs that define how decision trees
4//! make splitting decisions and handle constraints.
5
6use scirs2_core::ndarray::Array2;
7use sklears_core::types::Float;
8
9// Import types from config module
10
/// Criterion used to score candidate splits when growing a decision tree.
///
/// Derives `PartialEq` so that configurations can be compared directly, in
/// line with the other option enums in this module. `Eq` is intentionally
/// not derived because two variants carry `f64` payloads.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SplitCriterion {
    /// Gini impurity (for classification)
    Gini,
    /// Information gain / entropy (for classification)
    Entropy,
    /// Mean squared error (for regression)
    MSE,
    /// Mean absolute error (for regression)
    MAE,
    /// Twoing criterion for binary splits (classification)
    Twoing,
    /// Log-loss criterion for probability-based splitting (classification)
    LogLoss,
    /// Chi-squared automatic interaction detection (CHAID)
    CHAID {
        /// Significance level for the chi-squared test.
        // NOTE(review): the valid range (presumably 0..1) is not enforced
        // here; confirm against the code that consumes this value.
        significance_level: f64,
    },
    /// Conditional inference trees with statistical testing
    ConditionalInference {
        /// Significance level for the statistical test.
        // NOTE(review): the valid range (presumably 0..1) is not enforced
        // here; confirm against the code that consumes this value.
        significance_level: f64,
        /// Type of statistical test to use
        test_type: ConditionalTestType,
    },
}

/// Statistical test types for conditional inference trees
///
/// All payloads are integers, so the type is `Eq` as well as `PartialEq`;
/// `PartialEq` is required for `SplitCriterion` to be comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConditionalTestType {
    /// Quadratic-form test
    QuadraticForm,
    /// Max-type test
    MaxType,
    /// Monte Carlo test
    MonteCarlo {
        /// Permutation count for the Monte Carlo test.
        // NOTE(review): presumably the number of random permutations drawn;
        // confirm against the test implementation.
        n_permutations: usize,
    },
    /// Asymptotic chi-squared test
    AsymptoticChiSquared,
}
47
/// Monotonic constraint for a feature
///
/// A fieldless enum, so equality is total: `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MonotonicConstraint {
    /// No constraint on the relationship
    None,
    /// Feature must have increasing relationship with target (positive monotonicity)
    Increasing,
    /// Feature must have decreasing relationship with target (negative monotonicity)
    Decreasing,
}
58
/// Interaction constraint between features
///
/// Every payload is built from `usize` indices, so the type derives both
/// `PartialEq` and `Eq`. Comparison is structural: the order of groups,
/// pairs, and of the indices inside them is significant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InteractionConstraint {
    /// No constraints on feature interactions
    None,
    /// Allow interactions only within specified groups
    /// (each inner `Vec` holds the feature indices of one group)
    Groups(Vec<Vec<usize>>),
    /// Forbid specific feature pairs from interacting
    Forbidden(Vec<(usize, usize)>),
    /// Allow only specific feature pairs to interact
    Allowed(Vec<(usize, usize)>),
}
71
72/// Feature grouping strategy for handling correlated features
73#[derive(Debug, Clone)]
74pub enum FeatureGrouping {
75 /// No feature grouping (default)
76 None,
77 /// Automatic grouping based on correlation threshold
78 AutoCorrelation {
79 /// Correlation threshold above which features are grouped together
80 threshold: Float,
81 /// Method to select representative feature from each group
82 selection_method: GroupSelectionMethod,
83 },
84 /// Manual feature groups specified by user
85 Manual {
86 /// List of feature groups, each group contains feature indices
87 groups: Vec<Vec<usize>>,
88 /// Method to select representative feature from each group
89 selection_method: GroupSelectionMethod,
90 },
91 /// Hierarchical clustering-based grouping
92 Hierarchical {
93 /// Number of clusters to create
94 n_clusters: usize,
95 /// Linkage method for hierarchical clustering
96 linkage: LinkageMethod,
97 /// Method to select representative feature from each group
98 selection_method: GroupSelectionMethod,
99 },
100}
101
/// Method for selecting representative feature from a group
///
/// A fieldless enum, so equality is total: `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GroupSelectionMethod {
    /// Select feature with highest variance within the group
    MaxVariance,
    /// Select feature with highest correlation to target
    MaxTargetCorrelation,
    /// Select first feature in the group (index order)
    First,
    /// Select random feature from the group
    Random,
    /// Use all features from the group but with reduced weight
    WeightedAll,
}
116
/// Linkage method for hierarchical clustering
///
/// A fieldless enum, so equality is total: `Eq` is derived alongside
/// `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LinkageMethod {
    /// Single linkage (minimum distance)
    Single,
    /// Complete linkage (maximum distance)
    Complete,
    /// Average linkage
    Average,
    /// Ward linkage (minimize within-cluster variance)
    Ward,
}
129
130/// Information about feature groups discovered or specified
131#[derive(Debug, Clone)]
132pub struct FeatureGroupInfo {
133 /// Groups of correlated features
134 pub groups: Vec<Vec<usize>>,
135 /// Representative feature index for each group
136 pub representatives: Vec<usize>,
137 /// Correlation matrix used for grouping (if applicable)
138 pub correlation_matrix: Option<Array2<Float>>,
139 /// Within-group correlations for each group
140 pub group_correlations: Vec<Float>,
141}
142
/// Pruning strategy for decision trees
///
/// Derives `PartialEq` so strategies can be compared directly. `Eq` is not
/// derived because `CostComplexity` carries an `f64`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PruningStrategy {
    /// No pruning
    None,
    /// Cost-complexity pruning (post-pruning)
    CostComplexity {
        /// Cost-complexity parameter.
        // NOTE(review): how `alpha` is applied (and its valid range) is
        // decided by the pruning code, not by this type; confirm there.
        alpha: f64,
    },
    /// Reduced error pruning
    ReducedError,
}
153
/// Missing value handling strategy
///
/// A fieldless enum; derives `PartialEq` and `Eq` so strategies can be
/// compared directly, like the other option enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MissingValueStrategy {
    /// Skip samples with missing values
    Skip,
    /// Use majority class/mean for splits
    Majority,
    /// Use surrogate splits
    Surrogate,
}
164
/// Feature type specification for multiway splits
///
/// Derives `PartialEq` and `Eq`; two `Categorical` values are equal only when
/// their category lists match element-for-element, in the same order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FeatureType {
    /// Continuous numerical feature (binary splits)
    Continuous,
    /// Categorical feature with specified categories (multiway splits)
    Categorical(Vec<String>),
}
173
/// Information about a multiway split
///
/// Derives `PartialEq` for structural comparison. `Eq` is not derived because
/// `impurity_decrease` is an `f64` (so a split holding `NaN` never equals
/// itself).
#[derive(Debug, Clone, PartialEq)]
pub struct MultiwaySplit {
    /// Feature index
    pub feature_idx: usize,
    /// Category assignments for each branch
    /// (one inner list of category names per branch)
    pub category_branches: Vec<Vec<String>>,
    /// Impurity decrease achieved by this split
    pub impurity_decrease: f64,
}
184
/// Tree growing strategy
///
/// Derives `PartialEq` and `Eq` so strategies can be compared directly, like
/// the other option enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeGrowingStrategy {
    /// Depth-first growing (traditional CART)
    DepthFirst,
    /// Best-first growing (expand node with highest impurity decrease)
    BestFirst {
        /// Optional limit on the number of leaves.
        // NOTE(review): `None` presumably means "no limit"; confirm against
        // the tree builder.
        max_leaves: Option<usize>,
    },
}
193
/// Split type for decision trees
///
/// Derives `PartialEq` and `Eq` so split types can be compared directly, like
/// the other option enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SplitType {
    /// Traditional axis-aligned splits (threshold on single feature)
    AxisAligned,
    /// Linear hyperplane splits (linear combination of features)
    Oblique {
        /// Number of random hyperplanes to evaluate per split
        n_hyperplanes: usize,
        /// Use ridge regression to find optimal hyperplane
        use_ridge: bool,
    },
}
206}