//! Decision tree implementation for the `sklears_tree` crate.
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
7use sklears_core::error::{Result, SklearsError};
8use sklears_core::traits::{Estimator, Fit, Predict, Trained, Untrained};
9use std::marker::PhantomData;
10
11pub use crate::config::DecisionTreeConfig;
13pub use crate::criteria::{ConditionalTestType, FeatureType, MonotonicConstraint, SplitCriterion};
14pub use crate::node::{CompactTreeNode, CustomSplit, SurrogateSplit, TreeNode};
15pub use crate::splits::{ChaidSplit, HyperplaneSplit};
16
/// A decision tree estimator whose fitted status is tracked at compile time by
/// the `State` type parameter (`Untrained` by default, `Trained` after `fit`).
#[derive(Debug, Clone)]
pub struct DecisionTree<State = Untrained> {
    // Hyperparameters controlling splitting and stopping behavior.
    config: DecisionTreeConfig,
    // Root of the learned tree; `None` until a tree has been built.
    root: Option<TreeNode>,
    // Per-feature importance scores; `None` until computed.
    feature_importances: Option<Array1<f64>>,
    // Number of features seen during fit (0 before fitting).
    n_features: usize,
    // Number of samples seen during fit (0 before fitting).
    n_samples: usize,
    // Zero-sized marker carrying the typestate; stores no data.
    state: PhantomData<State>,
}
27
/// Convenience alias for a classification tree in its untrained state.
pub type DecisionTreeClassifier = DecisionTree<Untrained>;

/// Convenience alias for a regression tree in its untrained state.
// NOTE(review): identical to `DecisionTreeClassifier`; presumably the task
// (classification vs regression) is selected via the config — confirm.
pub type DecisionTreeRegressor = DecisionTree<Untrained>;
33
34impl<State> Default for DecisionTree<State> {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl<State> DecisionTree<State> {
41 pub fn new() -> Self {
43 Self {
44 config: DecisionTreeConfig::default(),
45 root: None,
46 feature_importances: None,
47 n_features: 0,
48 n_samples: 0,
49 state: PhantomData,
50 }
51 }
52
53 pub fn with_config(config: DecisionTreeConfig) -> Self {
55 Self {
56 config,
57 root: None,
58 feature_importances: None,
59 n_features: 0,
60 n_samples: 0,
61 state: PhantomData,
62 }
63 }
64
65 pub fn builder() -> DecisionTreeBuilder<State> {
67 DecisionTreeBuilder::new()
68 }
69
70 pub fn config(&self) -> &DecisionTreeConfig {
72 &self.config
73 }
74
75 pub fn root(&self) -> Option<&TreeNode> {
77 self.root.as_ref()
78 }
79
80 pub fn feature_importances(&self) -> Option<&Array1<f64>> {
82 self.feature_importances.as_ref()
83 }
84
85 pub fn n_features(&self) -> usize {
87 self.n_features
88 }
89
90 pub fn n_samples(&self) -> usize {
92 self.n_samples
93 }
94
95 pub fn depth(&self) -> usize {
97 match &self.root {
98 Some(root) => root.depth,
99 None => 0,
100 }
101 }
102
103 pub fn criterion(mut self, criterion: SplitCriterion) -> Self {
105 self.config.criterion = criterion;
106 self
107 }
108
109 pub fn max_depth(mut self, max_depth: usize) -> Self {
111 self.config.max_depth = Some(max_depth);
112 self
113 }
114
115 pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
117 self.config.min_samples_split = min_samples_split;
118 self
119 }
120
121 pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
123 self.config.min_samples_leaf = min_samples_leaf;
124 self
125 }
126
127 pub fn missing_values(mut self, strategy: crate::config::MissingValueStrategy) -> Self {
129 self.config.missing_values = strategy;
130 self
131 }
132
133 pub fn random_state(mut self, random_state: Option<u64>) -> Self {
135 self.config.random_state = random_state;
136 self
137 }
138}
139
/// Fluent builder that accumulates a [`DecisionTreeConfig`] and produces a
/// [`DecisionTree`] via [`DecisionTreeBuilder::build`].
#[derive(Debug)]
pub struct DecisionTreeBuilder<State> {
    // Configuration assembled by the setter methods.
    config: DecisionTreeConfig,
    // Zero-sized marker propagating the target tree's typestate.
    _marker: PhantomData<State>,
}
146
147impl<State> DecisionTreeBuilder<State> {
148 pub fn new() -> Self {
150 Self {
151 config: DecisionTreeConfig::default(),
152 _marker: PhantomData,
153 }
154 }
155
156 pub fn criterion(mut self, criterion: SplitCriterion) -> Self {
158 self.config.criterion = criterion;
159 self
160 }
161
162 pub fn max_depth(mut self, max_depth: Option<usize>) -> Self {
164 self.config.max_depth = max_depth;
165 self
166 }
167
168 pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
170 self.config.min_samples_split = min_samples_split;
171 self
172 }
173
174 pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
176 self.config.min_samples_leaf = min_samples_leaf;
177 self
178 }
179
180 pub fn min_impurity_decrease(mut self, min_impurity_decrease: f64) -> Self {
182 self.config.min_impurity_decrease = min_impurity_decrease;
183 self
184 }
185
186 pub fn max_features(self, _max_features: Option<usize>) -> Self {
188 self
191 }
192
193 pub fn random_state(mut self, random_state: Option<u64>) -> Self {
195 self.config.random_state = random_state;
196 self
197 }
198
199 pub fn build(self) -> DecisionTree<State> {
201 DecisionTree::with_config(self.config)
202 }
203}
204
205impl<State> Default for DecisionTreeBuilder<State> {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211impl DecisionTree<Untrained> {
214 pub fn is_fitted(&self) -> bool {
216 false
217 }
218
219 pub fn n_classes(&self) -> usize {
221 0 }
223}
224
/// Stateless namespace for input-validation routines used by fit and predict.
pub struct TreeValidator;
227
228impl TreeValidator {
229 pub fn validate_input(x: &ArrayView2<'_, f64>, y: &ArrayView1<'_, f64>) -> Result<()> {
231 if x.nrows() != y.len() {
232 return Err(SklearsError::InvalidInput(format!(
233 "Number of samples in X ({}) does not match number of targets in y ({})",
234 x.nrows(),
235 y.len()
236 )));
237 }
238
239 if x.nrows() == 0 {
240 return Err(SklearsError::InvalidInput(
241 "Input arrays must contain at least one sample".to_string(),
242 ));
243 }
244
245 if x.ncols() == 0 {
246 return Err(SklearsError::InvalidInput(
247 "Input arrays must contain at least one feature".to_string(),
248 ));
249 }
250
251 Ok(())
252 }
253
254 pub fn validate_fitted(tree: &DecisionTree<Trained>) -> Result<()> {
256 Ok(())
258 }
259
260 pub fn validate_prediction_input(
262 tree: &DecisionTree<Trained>,
263 x: &ArrayView2<'_, f64>,
264 ) -> Result<()> {
265 Self::validate_fitted(tree)?;
266
267 if x.ncols() != tree.n_features() {
268 return Err(SklearsError::InvalidInput(format!(
269 "Number of features in X ({}) does not match number of features seen during fit ({})",
270 x.ncols(),
271 tree.n_features()
272 )));
273 }
274
275 Ok(())
276 }
277}
278
/// Estimator trait wiring for the untrained typestate.
impl Estimator<Untrained> for DecisionTree<Untrained> {
    type Config = DecisionTreeConfig;
    type Error = SklearsError;
    type Float = f64;

    // Exposes the stored configuration through the trait interface.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
290
/// Estimator trait wiring for the trained typestate.
impl Estimator<Trained> for DecisionTree<Trained> {
    type Config = DecisionTreeConfig;
    type Error = SklearsError;
    type Float = f64;

    // Exposes the stored configuration through the trait interface.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
300
301impl Fit<Array2<f64>, Array1<f64>, Untrained> for DecisionTree<Untrained> {
302 type Fitted = DecisionTree<Trained>;
303
304 fn fit(self, x: &Array2<f64>, y: &Array1<f64>) -> Result<Self::Fitted> {
305 TreeValidator::validate_input(&x.view(), &y.view())?;
307
308 let fitted_tree = DecisionTree::<Trained> {
311 config: self.config,
312 root: None, feature_importances: None,
314 n_features: x.ncols(),
315 n_samples: x.nrows(),
316 state: PhantomData,
317 };
318
319 Ok(fitted_tree)
320 }
321}
322
323impl Predict<Array2<f64>, Array1<f64>> for DecisionTree<Trained> {
324 fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
325 TreeValidator::validate_prediction_input(self, &x.view())?;
326
327 let predictions = Array1::zeros(x.nrows());
330
331 Ok(predictions)
332 }
333}
334
335impl DecisionTree<Trained> {
336 pub fn is_fitted(&self) -> bool {
338 true
339 }
340
341 pub fn n_classes(&self) -> usize {
343 2
346 }
347}