sklears_ensemble/adaboost/
decision_tree.rs1use super::types::*;
4use scirs2_core::ndarray::{Array1, Array2};
5use sklears_core::{
6 error::Result,
7 traits::{Fit, Trained, Untrained},
8 types::{Float, Int},
9};
10use std::marker::PhantomData;
11
12impl DecisionTreeClassifier<Untrained> {
13 pub fn new() -> Self {
14 Self {
15 criterion: SplitCriterion::Gini,
16 max_depth: Some(1),
17 min_samples_split: 2,
18 min_samples_leaf: 1,
19 random_state: None,
20 state: PhantomData,
21 }
22 }
23
24 pub fn criterion(mut self, criterion: SplitCriterion) -> Self {
25 self.criterion = criterion;
26 self
27 }
28
29 pub fn max_depth(mut self, max_depth: usize) -> Self {
30 self.max_depth = Some(max_depth);
31 self
32 }
33
34 pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
35 self.min_samples_split = min_samples_split;
36 self
37 }
38
39 pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
40 self.min_samples_leaf = min_samples_leaf;
41 self
42 }
43
44 pub fn random_state(mut self, random_state: Option<u64>) -> Self {
45 self.random_state = random_state;
46 self
47 }
48}
49
50impl Fit<Array2<Float>, Array1<Int>> for DecisionTreeClassifier<Untrained> {
51 type Fitted = DecisionTreeClassifier<Trained>;
52
53 fn fit(self, _x: &Array2<Float>, _y: &Array1<Int>) -> Result<Self::Fitted> {
54 Ok(DecisionTreeClassifier {
56 criterion: self.criterion,
57 max_depth: self.max_depth,
58 min_samples_split: self.min_samples_split,
59 min_samples_leaf: self.min_samples_leaf,
60 random_state: self.random_state,
61 state: PhantomData,
62 })
63 }
64}
65
66impl DecisionTreeClassifier<Trained> {
67 pub fn predict(&self, x: &Array2<Float>) -> Result<Array1<Int>> {
68 let n_samples = x.nrows();
70 let predictions = x.column(0).mapv(|val| if val > 2.0 { 1 } else { 0 });
71 Ok(predictions)
72 }
73}
74
75impl DecisionTreeRegressor<Untrained> {
76 pub fn new() -> Self {
77 Self {
78 criterion: SplitCriterion::Gini,
79 max_depth: Some(1),
80 min_samples_split: 2,
81 min_samples_leaf: 1,
82 random_state: None,
83 state: PhantomData,
84 }
85 }
86
87 pub fn criterion(mut self, criterion: SplitCriterion) -> Self {
88 self.criterion = criterion;
89 self
90 }
91
92 pub fn max_depth(mut self, max_depth: usize) -> Self {
93 self.max_depth = Some(max_depth);
94 self
95 }
96
97 pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
98 self.min_samples_split = min_samples_split;
99 self
100 }
101
102 pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
103 self.min_samples_leaf = min_samples_leaf;
104 self
105 }
106
107 pub fn random_state(mut self, random_state: Option<u64>) -> Self {
108 self.random_state = random_state;
109 self
110 }
111}
112
113impl Fit<Array2<Float>, Array1<Float>> for DecisionTreeRegressor<Untrained> {
114 type Fitted = DecisionTreeRegressor<Trained>;
115
116 fn fit(self, _x: &Array2<Float>, _y: &Array1<Float>) -> Result<Self::Fitted> {
117 Ok(DecisionTreeRegressor {
119 criterion: self.criterion,
120 max_depth: self.max_depth,
121 min_samples_split: self.min_samples_split,
122 min_samples_leaf: self.min_samples_leaf,
123 random_state: self.random_state,
124 state: PhantomData,
125 })
126 }
127}
128
129impl DecisionTreeRegressor<Trained> {
130 pub fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
131 let n_samples = x.nrows();
133 Ok(Array1::zeros(n_samples))
134 }
135}