use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::error::Result;
use sklears_core::types::Float;
use smartcore::linalg::basic::matrix::DenseMatrix;

use crate::criteria::SplitCriterion;

/// Monotonic constraint applied to a single feature's effect on the prediction.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MonotonicConstraint {
    /// No constraint.
    None,
    /// Predictions must be non-decreasing in this feature.
    Increasing,
    /// Predictions must be non-increasing in this feature.
    Decreasing,
}

/// Constraints on which features may interact within a single tree path.
#[derive(Debug, Clone)]
pub enum InteractionConstraint {
    /// No interaction constraints.
    None,
    /// Features may only interact with other features in the same listed group.
    Groups(Vec<Vec<usize>>),
    /// The listed feature pairs may not appear together on a path.
    Forbidden(Vec<(usize, usize)>),
    /// Only the listed feature pairs may appear together on a path.
    Allowed(Vec<(usize, usize)>),
}

/// Strategy for grouping correlated features and choosing splits from the groups.
#[derive(Debug, Clone)]
pub enum FeatureGrouping {
    /// No feature grouping.
    None,
    /// Group features automatically by pairwise correlation.
    AutoCorrelation {
        threshold: Float,
        selection_method: GroupSelectionMethod,
    },
    /// Use user-provided feature groups.
    Manual {
        groups: Vec<Vec<usize>>,
        selection_method: GroupSelectionMethod,
    },
    /// Group features by hierarchical clustering.
    Hierarchical {
        n_clusters: usize,
        linkage: LinkageMethod,
        selection_method: GroupSelectionMethod,
    },
}

/// How a representative feature (or weighting) is chosen from a feature group.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GroupSelectionMethod {
    /// Pick the feature with the highest variance.
    MaxVariance,
    /// Pick the feature most correlated with the target.
    MaxTargetCorrelation,
    /// Pick the first feature in the group.
    First,
    /// Pick a random feature from the group.
    Random,
    /// Consider all features in the group, weighted.
    WeightedAll,
}

/// Linkage criterion used for hierarchical clustering of features.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LinkageMethod {
    /// Single linkage (minimum pairwise distance).
    Single,
    /// Complete linkage (maximum pairwise distance).
    Complete,
    /// Average linkage (mean pairwise distance).
    Average,
    /// Ward's minimum-variance linkage.
    Ward,
}

/// Result of feature grouping: the groups, their representatives, and the
/// correlation information used to build them.
#[derive(Debug, Clone)]
pub struct FeatureGroupInfo {
    pub groups: Vec<Vec<usize>>,
    pub representatives: Vec<usize>,
    pub correlation_matrix: Option<Array2<Float>>,
    pub group_correlations: Vec<Float>,
}

/// Number of features to consider when searching for the best split.
#[derive(Debug, Clone)]
pub enum MaxFeatures {
    /// Use all features.
    All,
    /// Use sqrt(n_features).
    Sqrt,
    /// Use log2(n_features).
    Log2,
    /// Use a fixed number of features.
    Number(usize),
    /// Use a fraction of the total number of features.
    Fraction(f64),
}

/// Post-pruning strategy applied after the tree is grown.
#[derive(Debug, Clone, Copy)]
pub enum PruningStrategy {
    /// No pruning.
    None,
    /// Minimal cost-complexity pruning with complexity parameter `alpha`.
    CostComplexity { alpha: f64 },
    /// Reduced-error pruning.
    ReducedError,
}

/// How missing feature values are handled during splitting.
#[derive(Debug, Clone, Copy)]
pub enum MissingValueStrategy {
    /// Skip samples with a missing value for the split feature.
    Skip,
    /// Send missing values to the branch that receives the majority of samples.
    Majority,
    /// Use surrogate splits to route samples with missing values.
    Surrogate,
}

/// Type of a feature column.
#[derive(Debug, Clone)]
pub enum FeatureType {
    /// Numeric feature with a continuous range.
    Continuous,
    /// Categorical feature with the given set of category labels.
    Categorical(Vec<String>),
}

/// A multi-way split on a categorical feature: each branch receives a subset
/// of the feature's categories.
#[derive(Debug, Clone)]
pub struct MultiwaySplit {
    pub feature_idx: usize,
    pub category_branches: Vec<Vec<String>>,
    pub impurity_decrease: f64,
}

/// Order in which nodes are expanded while growing the tree.
#[derive(Debug, Clone, Copy)]
pub enum TreeGrowingStrategy {
    /// Expand nodes depth-first.
    DepthFirst,
    /// Expand the best node first, optionally capping the number of leaves.
    BestFirst { max_leaves: Option<usize> },
}

/// Geometry of the splits used at each node.
#[derive(Debug, Clone, Copy)]
pub enum SplitType {
    /// Standard single-feature threshold splits.
    AxisAligned,
    /// Oblique splits on linear combinations of features.
    Oblique {
        /// Number of candidate hyperplanes to evaluate per node.
        n_hyperplanes: usize,
        /// Fit hyperplanes with ridge regression instead of random sampling.
        use_ridge: bool,
    },
}

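/// A linear (oblique) split defined by a hyperplane: a sample `x` goes to the
/// "true" branch when `coefficients · x + bias >= threshold`.
///
/// Illustrative sketch (the values are made up; `Array1` is the ndarray
/// re-export imported at the top of this module):
///
/// ```ignore
/// let split = HyperplaneSplit {
///     coefficients: Array1::from(vec![1.0, -1.0]),
///     threshold: 0.0,
///     bias: 0.5,
///     impurity_decrease: 0.0,
/// };
/// assert!(split.evaluate(&Array1::from(vec![2.0, 1.0]))); // 2.0 - 1.0 + 0.5 >= 0.0
/// ```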
#[derive(Debug, Clone)]
pub struct HyperplaneSplit {
    pub coefficients: Array1<f64>,
    pub threshold: f64,
    pub bias: f64,
    pub impurity_decrease: f64,
}

impl HyperplaneSplit {
    /// Evaluates the split for a single sample: returns `true` when
    /// `coefficients · sample + bias >= threshold`.
    pub fn evaluate(&self, sample: &Array1<f64>) -> bool {
        let dot_product = self.coefficients.dot(sample) + self.bias;
        dot_product >= self.threshold
    }

    /// Generates a random hyperplane with unit-norm coefficients and a random
    /// threshold and bias, used as a candidate oblique split.
    pub fn random(n_features: usize, rng: &mut scirs2_core::CoreRandom) -> Self {
        let mut coefficients = Array1::zeros(n_features);
        for i in 0..n_features {
            coefficients[i] = rng.gen_range(-1.0..1.0);
        }

        // Normalize the coefficient vector to unit length (skip if degenerate).
        let dot_product: f64 = coefficients.dot(&coefficients);
        let norm = dot_product.sqrt();
        if norm > 1e-10_f64 {
            coefficients /= norm;
        }

        Self {
            coefficients,
            threshold: rng.gen_range(-1.0..1.0),
            bias: rng.gen_range(-0.1..0.1),
            impurity_decrease: 0.0,
        }
    }

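    /// Fits an oblique split direction by ridge regression against the target:
    /// the closed-form solution `w = (XᵀX + αI)⁻¹ Xᵀ y`, with a bias column of
    /// ones appended to `X`. If the regularized normal equations turn out to be
    /// singular, this falls back to a random hyperplane.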
    #[cfg(feature = "oblique")]
    pub fn from_ridge_regression(x: &Array2<f64>, y: &Array1<f64>, alpha: f64) -> Result<Self> {
        use scirs2_core::ndarray::s;
        use sklears_core::error::SklearsError;

        let n_features = x.ncols();
        if x.nrows() < 2 {
            return Err(SklearsError::InvalidInput(
                "Need at least 2 samples for ridge regression".to_string(),
            ));
        }

        // Append a bias column of ones to the design matrix.
        let mut x_bias = Array2::ones((x.nrows(), n_features + 1));
        x_bias.slice_mut(s![.., ..n_features]).assign(x);

        // Regularized normal equations: (XᵀX + αI) w = Xᵀy.
        let xtx = x_bias.t().dot(&x_bias);
        let ridge_matrix = xtx + Array2::<f64>::eye(n_features + 1) * alpha;
        let xty = x_bias.t().dot(y);

        match gauss_jordan_inverse(&ridge_matrix) {
            Ok(inv_matrix) => {
                let coefficients_full = inv_matrix.dot(&xty);

                // Split the solution into feature coefficients and the bias term.
                let coefficients = coefficients_full.slice(s![..n_features]).to_owned();
                let bias = coefficients_full[n_features];

                Ok(Self {
                    coefficients,
                    threshold: 0.0,
                    bias,
                    impurity_decrease: 0.0,
                })
            }
            Err(_) => {
                // Singular system: fall back to a random hyperplane.
                let mut rng = scirs2_core::random::thread_rng();
                Ok(Self::random(n_features, &mut rng))
            }
        }
    }
}

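/// Configuration for growing a decision tree.
///
/// Illustrative sketch of overriding a few fields while keeping the rest at
/// their defaults (struct-update syntax; the values here are arbitrary):
///
/// ```ignore
/// let config = DecisionTreeConfig {
///     max_depth: Some(5),
///     min_samples_leaf: 2,
///     max_features: MaxFeatures::Sqrt,
///     ..DecisionTreeConfig::default()
/// };
/// ```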
#[derive(Debug, Clone)]
pub struct DecisionTreeConfig {
    /// Split quality criterion used to evaluate candidate splits.
    pub criterion: SplitCriterion,
    /// Maximum tree depth; `None` means unlimited.
    pub max_depth: Option<usize>,
    /// Minimum number of samples required to split an internal node.
    pub min_samples_split: usize,
    /// Minimum number of samples required at a leaf node.
    pub min_samples_leaf: usize,
    /// Number of features to consider at each split.
    pub max_features: MaxFeatures,
    /// Seed for reproducible randomness.
    pub random_state: Option<u64>,
    /// Minimum weighted fraction of the total sample weight required at a leaf.
    pub min_weight_fraction_leaf: f64,
    /// Minimum impurity decrease required for a split to be accepted.
    pub min_impurity_decrease: f64,
    /// Post-pruning strategy.
    pub pruning: PruningStrategy,
    /// Handling of missing feature values.
    pub missing_values: MissingValueStrategy,
    /// Optional per-feature type information (continuous vs. categorical).
    pub feature_types: Option<Vec<FeatureType>>,
    /// Node expansion order (depth-first or best-first).
    pub growing_strategy: TreeGrowingStrategy,
    /// Axis-aligned or oblique splits.
    pub split_type: SplitType,
    /// Optional per-feature monotonic constraints.
    pub monotonic_constraints: Option<Vec<MonotonicConstraint>>,
    /// Feature interaction constraints.
    pub interaction_constraints: InteractionConstraint,
    /// Feature grouping strategy.
    pub feature_grouping: FeatureGrouping,
}

impl Default for DecisionTreeConfig {
    fn default() -> Self {
        Self {
            criterion: SplitCriterion::Gini,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: MaxFeatures::All,
            random_state: None,
            min_weight_fraction_leaf: 0.0,
            min_impurity_decrease: 0.0,
            pruning: PruningStrategy::None,
            missing_values: MissingValueStrategy::Skip,
            feature_types: None,
            growing_strategy: TreeGrowingStrategy::DepthFirst,
            split_type: SplitType::AxisAligned,
            monotonic_constraints: None,
            interaction_constraints: InteractionConstraint::None,
            feature_grouping: FeatureGrouping::None,
        }
    }
}

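/// Converts an `ndarray` 2-D array into a `smartcore` `DenseMatrix` by copying
/// its rows into a `Vec<Vec<f64>>`. Note that this allocates a full copy of the data.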
pub fn ndarray_to_dense_matrix(arr: &Array2<f64>) -> DenseMatrix<f64> {
    let mut data = Vec::with_capacity(arr.nrows());
    for row in arr.outer_iter() {
        data.push(row.to_vec());
    }
    DenseMatrix::from_2d_vec(&data)
}

#[cfg(feature = "oblique")]
fn gauss_jordan_inverse(matrix: &Array2<f64>) -> std::result::Result<Array2<f64>, &'static str> {
    let n = matrix.nrows();
    if n != matrix.ncols() {
        return Err("Matrix must be square");
    }

    // Build the augmented matrix [A | I].
    let mut augmented = Array2::zeros((n, 2 * n));
    for i in 0..n {
        for j in 0..n {
            augmented[[i, j]] = matrix[[i, j]];
            if i == j {
                augmented[[i, j + n]] = 1.0;
            }
        }
    }

    for i in 0..n {
        // Partial pivoting: find the row with the largest absolute value in column i.
        let mut max_row = i;
        for k in i + 1..n {
            if augmented[[k, i]].abs() > augmented[[max_row, i]].abs() {
                max_row = k;
            }
        }

        if max_row != i {
            for j in 0..2 * n {
                let temp = augmented[[i, j]];
                augmented[[i, j]] = augmented[[max_row, j]];
                augmented[[max_row, j]] = temp;
            }
        }

        if augmented[[i, i]].abs() < 1e-10 {
            return Err("Matrix is singular");
        }

        // Scale the pivot row so the pivot becomes 1.
        let pivot = augmented[[i, i]];
        for j in 0..2 * n {
            augmented[[i, j]] /= pivot;
        }

        // Eliminate column i from all other rows.
        for k in 0..n {
            if k != i {
                let factor = augmented[[k, i]];
                for j in 0..2 * n {
                    augmented[[k, j]] -= factor * augmented[[i, j]];
                }
            }
        }
    }

    // The right half of the augmented matrix now holds the inverse.
    let mut inverse = Array2::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            inverse[[i, j]] = augmented[[i, j + n]];
        }
    }

    Ok(inverse)
}
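
#[cfg(test)]
mod tests {
    use super::*;

    // The tests below are illustrative sketches added for documentation; the
    // numeric values are arbitrary and the assertions only restate behaviour
    // that is visible in the code above.

    #[test]
    fn hyperplane_split_evaluate_uses_dot_product_plus_bias() {
        let split = HyperplaneSplit {
            coefficients: Array1::from(vec![1.0, -1.0]),
            threshold: 0.0,
            bias: 0.5,
            impurity_decrease: 0.0,
        };
        // 1.0 * 2.0 + (-1.0) * 1.0 + 0.5 = 1.5 >= 0.0
        assert!(split.evaluate(&Array1::from(vec![2.0, 1.0])));
        // 1.0 * 0.0 + (-1.0) * 2.0 + 0.5 = -1.5 < 0.0
        assert!(!split.evaluate(&Array1::from(vec![0.0, 2.0])));
    }

    #[test]
    fn default_config_uses_permissive_baseline() {
        let config = DecisionTreeConfig::default();
        assert!(config.max_depth.is_none());
        assert_eq!(config.min_samples_split, 2);
        assert_eq!(config.min_samples_leaf, 1);
    }

    // Only compiled with the `oblique` feature, mirroring the cfg gate on the
    // function itself.
    #[cfg(feature = "oblique")]
    #[test]
    fn gauss_jordan_inverse_recovers_identity() {
        let m = Array2::from_shape_vec((2, 2), vec![4.0, 7.0, 2.0, 6.0]).unwrap();
        let inv = gauss_jordan_inverse(&m).unwrap();
        let product = m.dot(&inv);
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((product[[i, j]] - expected).abs() < 1e-9);
            }
        }
    }
}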