use crate::multi_label::{BinaryRelevance, BinaryRelevanceTrained};
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

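/// Splitting criterion for classification trees.
///
/// For class probabilities `p_k` within a node, the Gini term is
/// `sum_k p_k * (1 - p_k)` and the entropy is `-sum_k p_k * ln(p_k)`.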
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClassificationCriterion {
    /// Gini impurity.
    Gini,
    /// Shannon entropy.
    Entropy,
}

/// Inference strategy for DAG-structured label spaces.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DAGInferenceMethod {
    /// Greedy, locally optimal decoding.
    Greedy,
    /// Approximate inference via belief propagation.
    BeliefPropagation,
    /// Exact decoding via integer linear programming.
    ExactILP,
}

/// Internal node of a multi-target regression tree.
#[derive(Debug, Clone)]
struct DecisionNode {
    is_leaf: bool,
    prediction: Option<Array1<Float>>,
    feature_idx: Option<usize>,
    threshold: Option<Float>,
    left: Option<Box<DecisionNode>>,
    right: Option<Box<DecisionNode>>,
    n_samples: usize,
    variance: Float,
}

/// Internal node of a multi-target classification tree.
#[derive(Debug, Clone)]
pub struct ClassificationDecisionNode {
    is_leaf: bool,
    prediction: Option<Array1<i32>>,
    probabilities: Option<Array2<Float>>,
    feature_idx: Option<usize>,
    threshold: Option<Float>,
    left: Option<Box<ClassificationDecisionNode>>,
    right: Option<Box<ClassificationDecisionNode>>,
    n_samples: usize,
    impurity: Float,
}

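/// A decision tree regressor that predicts several continuous targets at
/// once, choosing splits that maximize the reduction in summed per-target
/// variance.
///
/// A minimal usage sketch (the `x`/`y` arrays here are illustrative, not
/// from the original source):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[0.0, 1.0], [1.0, 0.0], [2.0, 1.0], [3.0, 0.0]];
/// let y = array![[0.0, 1.0], [1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
/// let fitted = MultiTargetRegressionTree::new()
///     .max_depth(Some(3))
///     .fit(&x.view(), &y)?;
/// let preds = fitted.predict(&x.view())?;
/// ```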
#[derive(Debug, Clone)]
pub struct MultiTargetRegressionTree<S = Untrained> {
    state: S,
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
    random_state: Option<u64>,
}

/// State held by a fitted [`MultiTargetRegressionTree`].
#[derive(Debug, Clone)]
pub struct MultiTargetRegressionTreeTrained {
    tree: DecisionNode,
    n_features: usize,
    n_targets: usize,
    feature_importances: Array1<Float>,
}

impl MultiTargetRegressionTree<Untrained> {
    /// Creates a tree with default hyperparameters.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            max_depth: Some(5),
            min_samples_split: 2,
            min_samples_leaf: 1,
            random_state: None,
        }
    }

    /// Sets the maximum depth of the tree (`None` means unlimited).
    pub fn max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Sets the minimum number of samples required to split a node.
    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Sets the minimum number of samples required in each leaf.
    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Sets the random seed.
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    pub fn get_max_depth(&self) -> Option<usize> {
        self.max_depth
    }

    pub fn get_min_samples_split(&self) -> usize {
        self.min_samples_split
    }

    pub fn get_min_samples_leaf(&self) -> usize {
        self.min_samples_leaf
    }

    pub fn get_random_state(&self) -> Option<u64> {
        self.random_state
    }
}

impl Default for MultiTargetRegressionTree<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MultiTargetRegressionTree<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<Float>> for MultiTargetRegressionTree<Untrained> {
    type Fitted = MultiTargetRegressionTree<MultiTargetRegressionTreeTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<Float>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let n_targets = y.ncols();
        if n_targets == 0 {
            return Err(SklearsError::InvalidInput(
                "y must have at least one target".to_string(),
            ));
        }

        if n_samples < self.min_samples_split {
            return Err(SklearsError::InvalidInput(
                "Number of samples is less than min_samples_split".to_string(),
            ));
        }

        // Build the tree recursively over all sample indices.
        let indices: Vec<usize> = (0..n_samples).collect();
        let tree = self.build_tree(&X, y, &indices, 0)?;

        // Accumulate and normalize feature importances.
        let mut feature_importances = Array1::<Float>::zeros(n_features);
        self.calculate_feature_importances(&tree, &mut feature_importances, n_samples as Float);

        let sum_importances: Float = feature_importances.sum();
        if sum_importances > 0.0 {
            feature_importances /= sum_importances;
        }

        Ok(MultiTargetRegressionTree {
            state: MultiTargetRegressionTreeTrained {
                tree,
                n_features,
                n_targets,
                feature_importances,
            },
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
            random_state: self.random_state,
        })
    }
}

impl MultiTargetRegressionTree<Untrained> {
    fn build_tree(
        &self,
        X: &Array2<Float>,
        y: &Array2<Float>,
        indices: &[usize],
        depth: usize,
    ) -> SklResult<DecisionNode> {
        let n_samples = indices.len();
        let n_targets = y.ncols();

        // Mean of each target over the samples reaching this node.
        let mut prediction = Array1::<Float>::zeros(n_targets);
        for &idx in indices {
            for j in 0..n_targets {
                prediction[j] += y[[idx, j]];
            }
        }
        prediction /= n_samples as Float;

        // Total variance summed over targets.
        let mut variance = 0.0;
        for &idx in indices {
            for j in 0..n_targets {
                let diff = y[[idx, j]] - prediction[j];
                variance += diff * diff;
            }
        }
        variance /= n_samples as Float;

        // Stop when the node is too small, too deep, or already pure.
        let should_stop = n_samples < self.min_samples_split
            || n_samples < self.min_samples_leaf
            || self.max_depth.is_some_and(|max_d| depth >= max_d)
            || variance < 1e-10;

        if should_stop {
            return Ok(DecisionNode {
                is_leaf: true,
                prediction: Some(prediction),
                feature_idx: None,
                threshold: None,
                left: None,
                right: None,
                n_samples,
                variance,
            });
        }

        let (best_feature, best_threshold, best_variance_reduction) =
            self.find_best_split(X, y, indices)?;

        if best_variance_reduction <= 0.0 {
            return Ok(DecisionNode {
                is_leaf: true,
                prediction: Some(prediction),
                feature_idx: None,
                threshold: None,
                left: None,
                right: None,
                n_samples,
                variance,
            });
        }

        let (left_indices, right_indices) =
            self.split_data(X, indices, best_feature, best_threshold);

        if left_indices.len() < self.min_samples_leaf || right_indices.len() < self.min_samples_leaf
        {
            return Ok(DecisionNode {
                is_leaf: true,
                prediction: Some(prediction),
                feature_idx: None,
                threshold: None,
                left: None,
                right: None,
                n_samples,
                variance,
            });
        }

        let left_child = self.build_tree(X, y, &left_indices, depth + 1)?;
        let right_child = self.build_tree(X, y, &right_indices, depth + 1)?;

        Ok(DecisionNode {
            is_leaf: false,
            prediction: None,
            feature_idx: Some(best_feature),
            threshold: Some(best_threshold),
            left: Some(Box::new(left_child)),
            right: Some(Box::new(right_child)),
            n_samples,
            variance,
        })
    }

    fn find_best_split(
        &self,
        X: &Array2<Float>,
        y: &Array2<Float>,
        indices: &[usize],
    ) -> SklResult<(usize, Float, Float)> {
        let n_features = X.ncols();
        let mut best_feature = 0;
        let mut best_threshold = 0.0;
        let mut best_variance_reduction = 0.0;

        let current_variance = self.calculate_variance(y, indices);

        for feature_idx in 0..n_features {
            // Candidate thresholds are midpoints between consecutive unique values.
            let mut feature_values: Vec<Float> =
                indices.iter().map(|&idx| X[[idx, feature_idx]]).collect();
            feature_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
            feature_values.dedup();

            for i in 0..feature_values.len().saturating_sub(1) {
                let threshold = (feature_values[i] + feature_values[i + 1]) / 2.0;

                let (left_indices, right_indices) =
                    self.split_data(X, indices, feature_idx, threshold);

                if left_indices.is_empty() || right_indices.is_empty() {
                    continue;
                }

                let left_variance = self.calculate_variance(y, &left_indices);
                let right_variance = self.calculate_variance(y, &right_indices);

                let weighted_variance = (left_indices.len() as Float * left_variance
                    + right_indices.len() as Float * right_variance)
                    / indices.len() as Float;

                let variance_reduction = current_variance - weighted_variance;

                if variance_reduction > best_variance_reduction {
                    best_variance_reduction = variance_reduction;
                    best_feature = feature_idx;
                    best_threshold = threshold;
                }
            }
        }

        Ok((best_feature, best_threshold, best_variance_reduction))
    }

    fn calculate_variance(&self, y: &Array2<Float>, indices: &[usize]) -> Float {
        if indices.is_empty() {
            return 0.0;
        }

        let n_targets = y.ncols();
        let n_samples = indices.len();

        // Per-target means over the given indices.
        let mut means = Array1::<Float>::zeros(n_targets);
        for &idx in indices {
            for j in 0..n_targets {
                means[j] += y[[idx, j]];
            }
        }
        means /= n_samples as Float;

        // Summed squared deviations across all targets.
        let mut variance = 0.0;
        for &idx in indices {
            for j in 0..n_targets {
                let diff = y[[idx, j]] - means[j];
                variance += diff * diff;
            }
        }
        variance / n_samples as Float
    }

    fn split_data(
        &self,
        X: &Array2<Float>,
        indices: &[usize],
        feature_idx: usize,
        threshold: Float,
    ) -> (Vec<usize>, Vec<usize>) {
        let mut left_indices = Vec::new();
        let mut right_indices = Vec::new();

        for &idx in indices {
            if X[[idx, feature_idx]] <= threshold {
                left_indices.push(idx);
            } else {
                right_indices.push(idx);
            }
        }

        (left_indices, right_indices)
    }

    fn calculate_feature_importances(
        &self,
        node: &DecisionNode,
        importances: &mut Array1<Float>,
        total_samples: Float,
    ) {
        if let (Some(feature_idx), Some(left), Some(right)) =
            (node.feature_idx, &node.left, &node.right)
        {
            let importance = (node.n_samples as Float / total_samples) * node.variance;
            importances[feature_idx] += importance;

            self.calculate_feature_importances(left, importances, total_samples);
            self.calculate_feature_importances(right, importances, total_samples);
        }
    }
}

impl MultiTargetRegressionTree<MultiTargetRegressionTreeTrained> {
    /// Returns the normalized feature importances.
    pub fn feature_importances(&self) -> &Array1<Float> {
        &self.state.feature_importances
    }

    /// Returns the number of features seen during fitting.
    pub fn n_features(&self) -> usize {
        self.state.n_features
    }

    /// Returns the number of targets seen during fitting.
    pub fn n_targets(&self) -> usize {
        self.state.n_targets
    }

    /// Predicts all targets for each row of `X`.
    #[allow(non_snake_case)]
    pub fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = *X;
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features doesn't match training data".to_string(),
            ));
        }

        let mut predictions = Array2::<Float>::zeros((n_samples, self.state.n_targets));

        for i in 0..n_samples {
            let sample = X.slice(s![i, ..]);
            let prediction = self.predict_single(&self.state.tree, &sample)?;
            for j in 0..self.state.n_targets {
                predictions[[i, j]] = prediction[j];
            }
        }

        Ok(predictions)
    }

    fn predict_single(
        &self,
        node: &DecisionNode,
        sample: &ArrayView1<'_, Float>,
    ) -> SklResult<Array1<Float>> {
        if node.is_leaf {
            if let Some(ref prediction) = node.prediction {
                Ok(prediction.clone())
            } else {
                Err(SklearsError::InvalidInput(
                    "Leaf node without prediction".to_string(),
                ))
            }
        } else {
            let feature_idx = node.feature_idx.ok_or(SklearsError::InvalidInput(
                "Non-leaf node without feature index".to_string(),
            ))?;
            let threshold = node.threshold.ok_or(SklearsError::InvalidInput(
                "Non-leaf node without threshold".to_string(),
            ))?;

            if sample[feature_idx] <= threshold {
                if let Some(ref left) = node.left {
                    self.predict_single(left, sample)
                } else {
                    Err(SklearsError::InvalidInput(
                        "Non-leaf node without left child".to_string(),
                    ))
                }
            } else if let Some(ref right) = node.right {
                self.predict_single(right, sample)
            } else {
                Err(SklearsError::InvalidInput(
                    "Non-leaf node without right child".to_string(),
                ))
            }
        }
    }
}

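/// A decision tree classifier that predicts a discrete label for each of
/// several targets simultaneously, averaging the chosen impurity criterion
/// across targets when scoring splits.
///
/// A minimal usage sketch (arrays are illustrative, not from the original
/// source):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[0.0, 1.0], [1.0, 0.0], [2.0, 1.0], [3.0, 0.0]];
/// let y = array![[0, 1], [0, 1], [1, 0], [1, 0]];
/// let fitted = MultiTargetDecisionTreeClassifier::new()
///     .criterion(ClassificationCriterion::Entropy)
///     .fit(&x.view(), &y)?;
/// let labels = fitted.predict(&x.view())?;
/// let probas = fitted.predict_proba(&x.view())?;
/// ```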
#[derive(Debug, Clone)]
pub struct MultiTargetDecisionTreeClassifier<S = Untrained> {
    state: S,
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
    criterion: ClassificationCriterion,
    random_state: Option<u64>,
}

/// State held by a fitted [`MultiTargetDecisionTreeClassifier`].
#[derive(Debug, Clone)]
pub struct MultiTargetDecisionTreeClassifierTrained {
    tree: ClassificationDecisionNode,
    n_features: usize,
    n_targets: usize,
    feature_importances: Array1<Float>,
    classes_per_target: Vec<Vec<i32>>,
}

impl MultiTargetDecisionTreeClassifier<Untrained> {
    /// Creates a classifier with default hyperparameters.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            max_depth: Some(5),
            min_samples_split: 2,
            min_samples_leaf: 1,
            criterion: ClassificationCriterion::Gini,
            random_state: None,
        }
    }

    /// Sets the maximum depth of the tree (`None` means unlimited).
    pub fn max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Sets the minimum number of samples required to split a node.
    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Sets the minimum number of samples required in each leaf.
    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Sets the impurity criterion used to score splits.
    pub fn criterion(mut self, criterion: ClassificationCriterion) -> Self {
        self.criterion = criterion;
        self
    }

    /// Sets the random seed.
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }
}

impl Default for MultiTargetDecisionTreeClassifier<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for MultiTargetDecisionTreeClassifier<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<i32>> for MultiTargetDecisionTreeClassifier<Untrained> {
    type Fitted = MultiTargetDecisionTreeClassifier<MultiTargetDecisionTreeClassifierTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let n_targets = y.ncols();
        if n_targets == 0 {
            return Err(SklearsError::InvalidInput(
                "y must have at least one target".to_string(),
            ));
        }

        // Collect the sorted unique classes for each target.
        let mut classes_per_target = Vec::new();
        for target_idx in 0..n_targets {
            let target_column = y.column(target_idx);
            let mut unique_classes: Vec<i32> = target_column.iter().cloned().collect();
            unique_classes.sort_unstable();
            unique_classes.dedup();
            classes_per_target.push(unique_classes);
        }

        let mut feature_importances = Array1::<Float>::zeros(n_features);

        // Build the tree recursively over all sample indices.
        let indices: Vec<usize> = (0..n_samples).collect();
        let tree = build_classification_tree(
            &X,
            y,
            &indices,
            &mut feature_importances,
            0,
            self.max_depth,
            self.min_samples_split,
            self.min_samples_leaf,
            self.criterion,
            &classes_per_target,
        )?;

        // Normalize feature importances to sum to one.
        let importance_sum = feature_importances.sum();
        if importance_sum > 0.0 {
            feature_importances /= importance_sum;
        }

        let trained_state = MultiTargetDecisionTreeClassifierTrained {
            tree,
            n_features,
            n_targets,
            feature_importances,
            classes_per_target,
        };

        Ok(MultiTargetDecisionTreeClassifier {
            state: trained_state,
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
            criterion: self.criterion,
            random_state: self.random_state,
        })
    }
}

impl Predict<ArrayView2<'_, Float>, Array2<i32>>
    for MultiTargetDecisionTreeClassifier<MultiTargetDecisionTreeClassifierTrained>
{
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
        let (n_samples, n_features) = X.dim();
        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_targets));

        for i in 0..n_samples {
            let sample = X.row(i);
            let prediction = predict_classification_sample(&self.state.tree, &sample);
            for j in 0..self.state.n_targets {
                predictions[[i, j]] = prediction[j];
            }
        }

        Ok(predictions)
    }
}

impl MultiTargetDecisionTreeClassifier<MultiTargetDecisionTreeClassifierTrained> {
    /// Returns the normalized feature importances.
    pub fn feature_importances(&self) -> &Array1<Float> {
        &self.state.feature_importances
    }

    /// Returns per-target class probabilities, one `(n_samples, n_classes)`
    /// matrix per target.
    pub fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Vec<Array2<Float>>> {
        let (n_samples, n_features) = X.dim();
        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut all_probabilities = Vec::new();

        // One probability matrix per target, sized by that target's class count.
        for target_idx in 0..self.state.n_targets {
            let n_classes = self.state.classes_per_target[target_idx].len();
            all_probabilities.push(Array2::<Float>::zeros((n_samples, n_classes)));
        }

        for i in 0..n_samples {
            let sample = X.row(i);
            let probabilities = predict_classification_probabilities(
                &self.state.tree,
                &sample,
                &self.state.classes_per_target,
            );

            for (target_idx, target_probs) in probabilities.iter().enumerate() {
                for (class_idx, &prob) in target_probs.iter().enumerate() {
                    all_probabilities[target_idx][[i, class_idx]] = prob;
                }
            }
        }

        Ok(all_probabilities)
    }
}

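/// A bagged ensemble of [`MultiTargetRegressionTree`]s whose predictions are
/// averaged across trees.
///
/// A minimal usage sketch (arrays are illustrative, not from the original
/// source):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[0.0, 1.0], [1.0, 0.0], [2.0, 1.0], [3.0, 0.0]];
/// let y = array![[0.0, 1.0], [1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
/// let forest = RandomForestMultiOutput::new()
///     .n_estimators(25)
///     .bootstrap(true)
///     .random_state(Some(42))
///     .fit(&x.view(), &y)?;
/// let preds = forest.predict(&x.view())?;
/// ```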
#[derive(Debug, Clone)]
pub struct RandomForestMultiOutput<S = Untrained> {
    state: S,
    n_estimators: usize,
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
    max_features: Option<usize>,
    bootstrap: bool,
    random_state: Option<u64>,
}

/// State held by a fitted [`RandomForestMultiOutput`].
#[derive(Debug, Clone)]
pub struct RandomForestMultiOutputTrained {
    trees: Vec<MultiTargetRegressionTree<MultiTargetRegressionTreeTrained>>,
    n_features: usize,
    n_targets: usize,
    feature_importances: Array1<Float>,
}

impl RandomForestMultiOutput<Untrained> {
    /// Creates a forest with default hyperparameters.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_estimators: 10,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: None,
            bootstrap: true,
            random_state: None,
        }
    }

    /// Sets the number of trees in the ensemble.
    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.n_estimators = n_estimators;
        self
    }

    /// Sets the maximum depth of each tree (`None` means unlimited).
    pub fn max_depth(mut self, max_depth: Option<usize>) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Sets the minimum number of samples required to split a node.
    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Sets the minimum number of samples required in each leaf.
    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Sets the number of features considered per split (`None` means all;
    /// stored but not yet applied during fitting).
    pub fn max_features(mut self, max_features: Option<usize>) -> Self {
        self.max_features = max_features;
        self
    }

    /// Enables or disables bootstrap sampling.
    pub fn bootstrap(mut self, bootstrap: bool) -> Self {
        self.bootstrap = bootstrap;
        self
    }

    /// Sets the random seed.
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    pub fn get_n_estimators(&self) -> usize {
        self.n_estimators
    }

    pub fn get_max_depth(&self) -> Option<usize> {
        self.max_depth
    }

    pub fn get_min_samples_split(&self) -> usize {
        self.min_samples_split
    }

    pub fn get_min_samples_leaf(&self) -> usize {
        self.min_samples_leaf
    }

    pub fn get_max_features(&self) -> Option<usize> {
        self.max_features
    }

    pub fn get_bootstrap(&self) -> bool {
        self.bootstrap
    }

    pub fn get_random_state(&self) -> Option<u64> {
        self.random_state
    }
}

impl Default for RandomForestMultiOutput<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for RandomForestMultiOutput<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, Array2<Float>> for RandomForestMultiOutput<Untrained> {
    type Fitted = RandomForestMultiOutput<RandomForestMultiOutputTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, y: &Array2<Float>) -> SklResult<Self::Fitted> {
        let X = X.to_owned();
        let (n_samples, n_features) = X.dim();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        let n_targets = y.ncols();
        if n_targets == 0 {
            return Err(SklearsError::InvalidInput(
                "y must have at least one target".to_string(),
            ));
        }

        let mut trees = Vec::new();
        let mut feature_importances = Array1::<Float>::zeros(n_features);

        for i in 0..self.n_estimators {
            // Each tree sees a bootstrap resample (or the full data).
            let (X_sample, y_sample) = if self.bootstrap {
                self.create_bootstrap_sample(&X, y, i)?
            } else {
                (X.clone(), y.clone())
            };

            // Offset the seed per tree so the trees are decorrelated.
            let tree = MultiTargetRegressionTree::new()
                .max_depth(self.max_depth)
                .min_samples_split(self.min_samples_split)
                .min_samples_leaf(self.min_samples_leaf)
                .random_state(self.random_state.map(|s| s.wrapping_add(i as u64)));

            let trained_tree = tree.fit(&X_sample.view(), &y_sample)?;

            feature_importances += trained_tree.feature_importances();

            trees.push(trained_tree);
        }

        // Average importances over the ensemble.
        feature_importances /= self.n_estimators as Float;

        Ok(RandomForestMultiOutput {
            state: RandomForestMultiOutputTrained {
                trees,
                n_features,
                n_targets,
                feature_importances,
            },
            n_estimators: self.n_estimators,
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
            max_features: self.max_features,
            bootstrap: self.bootstrap,
            random_state: self.random_state,
        })
    }
}

impl RandomForestMultiOutput<Untrained> {
    fn create_bootstrap_sample(
        &self,
        X: &Array2<Float>,
        y: &Array2<Float>,
        seed: usize,
    ) -> SklResult<(Array2<Float>, Array2<Float>)> {
        let n_samples = X.nrows();
        let mut rng_state = self.random_state.unwrap_or(42).wrapping_add(seed as u64);

        let mut X_sample = Array2::<Float>::zeros(X.raw_dim());
        let mut y_sample = Array2::<Float>::zeros(y.raw_dim());

        for i in 0..n_samples {
            // Simple linear congruential generator (the classic ANSI C
            // constants) for dependency-free, reproducible resampling.
            rng_state = rng_state.wrapping_mul(1103515245).wrapping_add(12345);
            let idx = (rng_state / 65536) % (n_samples as u64);

            X_sample
                .slice_mut(s![i, ..])
                .assign(&X.slice(s![idx as usize, ..]));
            y_sample
                .slice_mut(s![i, ..])
                .assign(&y.slice(s![idx as usize, ..]));
        }

        Ok((X_sample, y_sample))
    }
}

impl RandomForestMultiOutput<RandomForestMultiOutputTrained> {
    /// Returns the ensemble-averaged feature importances.
    pub fn feature_importances(&self) -> &Array1<Float> {
        &self.state.feature_importances
    }

    /// Returns the number of fitted trees.
    pub fn n_estimators(&self) -> usize {
        self.state.trees.len()
    }

    /// Returns the number of features seen during fitting.
    pub fn n_features(&self) -> usize {
        self.state.n_features
    }

    /// Returns the number of targets seen during fitting.
    pub fn n_targets(&self) -> usize {
        self.state.n_targets
    }

    /// Predicts by averaging the per-tree predictions.
    #[allow(non_snake_case)]
    pub fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = *X;
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features doesn't match training data".to_string(),
            ));
        }

        let mut predictions = Array2::<Float>::zeros((n_samples, self.state.n_targets));

        for tree in &self.state.trees {
            let tree_predictions = tree.predict(&X)?;
            predictions += &tree_predictions;
        }

        predictions /= self.state.trees.len() as Float;
        Ok(predictions)
    }
}

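/// A hierarchical predictor for tree-structured label paths: each internal
/// node of a fixed tree gets its own [`BinaryRelevance`] classifier that
/// routes samples to one of its children, and prediction walks the tree from
/// the root.
///
/// A minimal usage sketch (the arrays encode root-to-leaf node paths and are
/// illustrative, not from the original source):
///
/// ```ignore
/// use scirs2_core::ndarray::array;
///
/// let x = array![[0.0, 1.0], [1.0, 0.0]];
/// // Each row of `y` is a node path, e.g. root (0) -> node 1 -> node 3.
/// let y = array![[0, 1, 3], [0, 2, 5]];
/// let predictor = TreeStructuredPredictor::new()
///     .max_depth(3)
///     .branching_factor(2)
///     .fit(&x, &y)?;
/// let paths = predictor.predict(&x)?;
/// ```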
#[derive(Debug, Clone)]
pub struct TreeStructuredPredictor<State = Untrained> {
    max_depth: usize,
    branching_factor: usize,
    tree_structure: Vec<Vec<usize>>,
    // Placeholder in the untrained state; the fitted classifiers live in
    // `TreeStructuredPredictorTrained`.
    node_classifiers: HashMap<usize, String>,
    state: State,
}

/// State held by a fitted [`TreeStructuredPredictor`].
#[derive(Debug, Clone)]
pub struct TreeStructuredPredictorTrained {
    node_classifiers: HashMap<usize, BinaryRelevance<BinaryRelevanceTrained>>,
    tree_structure: Vec<Vec<usize>>,
    max_depth: usize,
    n_nodes: usize,
}

impl Default for TreeStructuredPredictor<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl TreeStructuredPredictor<Untrained> {
    /// Creates a predictor with default settings (depth 5, binary branching).
    pub fn new() -> Self {
        Self {
            max_depth: 5,
            branching_factor: 2,
            tree_structure: Vec::new(),
            node_classifiers: HashMap::new(),
            state: Untrained,
        }
    }

    /// Sets the maximum depth of the label tree.
    pub fn max_depth(mut self, depth: usize) -> Self {
        self.max_depth = depth;
        self
    }

    /// Sets the branching factor of the label tree.
    pub fn branching_factor(mut self, factor: usize) -> Self {
        self.branching_factor = factor;
        self
    }

    /// Supplies an explicit tree structure as an adjacency list
    /// (`tree_structure[node]` holds that node's children).
    pub fn tree_structure(mut self, structure: Vec<Vec<usize>>) -> Self {
        self.tree_structure = structure;
        self
    }
}

impl Estimator for TreeStructuredPredictor<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<Array2<Float>, Array2<i32>> for TreeStructuredPredictor<Untrained> {
    type Fitted = TreeStructuredPredictor<TreeStructuredPredictorTrained>;

    fn fit(self, X: &Array2<Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
        let (n_samples, _n_features) = X.dim();
        let (y_samples, max_path_length) = y.dim();

        if n_samples != y_samples {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        // Use the supplied structure, or build a complete tree by default.
        let tree_structure = if self.tree_structure.is_empty() {
            self.build_default_tree_structure()?
        } else {
            self.tree_structure.clone()
        };

        let n_nodes = tree_structure.len();
        let mut node_classifiers = HashMap::new();

        // Train one routing classifier per internal node.
        for node_id in 0..n_nodes {
            if !tree_structure[node_id].is_empty() {
                let (node_X, node_y) = self.create_node_training_data(
                    &X.view(),
                    &y.view(),
                    node_id,
                    &tree_structure,
                    max_path_length,
                )?;

                if !node_y.is_empty() {
                    let classifier = BinaryRelevance::new();
                    let trained_classifier = classifier.fit(&node_X.view(), &node_y)?;
                    node_classifiers.insert(node_id, trained_classifier);
                }
            }
        }

        Ok(TreeStructuredPredictor {
            max_depth: self.max_depth,
            branching_factor: self.branching_factor,
            tree_structure: tree_structure.clone(),
            node_classifiers: HashMap::new(),
            state: TreeStructuredPredictorTrained {
                node_classifiers,
                tree_structure,
                max_depth: self.max_depth,
                n_nodes,
            },
        })
    }
}

impl TreeStructuredPredictor<Untrained> {
    /// Builds a complete tree with `branching_factor`-ary fan-out, numbered
    /// in level order.
    fn build_default_tree_structure(&self) -> SklResult<Vec<Vec<usize>>> {
        // Total nodes in a complete tree of the requested depth.
        let mut total_nodes = 0;
        for depth in 0..self.max_depth {
            total_nodes += self.branching_factor.pow(depth as u32);
        }

        let mut tree_structure = vec![Vec::new(); total_nodes];
        let mut node_id = 0;

        for depth in 0..(self.max_depth - 1) {
            let nodes_at_depth = self.branching_factor.pow(depth as u32);

            for _ in 0..nodes_at_depth {
                for child in 0..self.branching_factor {
                    // In level order, the children of node `i` are
                    // `b*i + 1 ..= b*i + b` for branching factor `b`.
                    let child_id = node_id * self.branching_factor + 1 + child;
                    if child_id < total_nodes {
                        tree_structure[node_id].push(child_id);
                    }
                }
                node_id += 1;
            }
        }

        Ok(tree_structure)
    }

    /// Extracts the training subset for one internal node: samples whose
    /// label path passes through `node_id`, labeled with the index of the
    /// child they continue to.
    fn create_node_training_data(
        &self,
        X: &ArrayView2<Float>,
        y: &ArrayView2<i32>,
        node_id: usize,
        tree_structure: &Vec<Vec<usize>>,
        max_path_length: usize,
    ) -> SklResult<(Array2<Float>, Array2<i32>)> {
        let n_samples = X.nrows();
        let mut valid_samples = Vec::new();
        let mut node_labels = Vec::new();

        for sample_idx in 0..n_samples {
            let path = y.row(sample_idx);

            for pos in 0..max_path_length {
                if path[pos] as usize == node_id && pos + 1 < max_path_length {
                    let next_node = path[pos + 1] as usize;

                    if let Some(child_idx) = tree_structure[node_id]
                        .iter()
                        .position(|&child| child == next_node)
                    {
                        valid_samples.push(sample_idx);
                        node_labels.push(child_idx as i32);
                        break;
                    }
                }
            }
        }

        // No sample routed through this node: return empty training data.
        let n_valid = valid_samples.len();
        if n_valid == 0 {
            return Ok((
                Array2::<Float>::zeros((0, X.ncols())),
                Array2::<i32>::zeros((0, 1)),
            ));
        }

        let mut node_X = Array2::<Float>::zeros((n_valid, X.ncols()));
        let mut node_y = Array2::<i32>::zeros((n_valid, 1));

        for (i, &sample_idx) in valid_samples.iter().enumerate() {
            for j in 0..X.ncols() {
                node_X[[i, j]] = X[[sample_idx, j]];
            }
            node_y[[i, 0]] = node_labels[i];
        }

        Ok((node_X, node_y))
    }
}

impl Predict<Array2<Float>, Array2<i32>>
    for TreeStructuredPredictor<TreeStructuredPredictorTrained>
{
    fn predict(&self, X: &Array2<Float>) -> SklResult<Array2<i32>> {
        let n_samples = X.nrows();
        let mut predictions = Array2::<i32>::zeros((n_samples, self.state.max_depth));

        for sample_idx in 0..n_samples {
            let sample = X.row(sample_idx);
            let path = self.predict_tree_path(&sample)?;

            for (pos, &node) in path.iter().enumerate() {
                if pos < self.state.max_depth {
                    predictions[[sample_idx, pos]] = node as i32;
                }
            }
        }

        Ok(predictions)
    }
}

impl TreeStructuredPredictor<TreeStructuredPredictorTrained> {
    /// Walks the tree from the root, letting each node's classifier pick the
    /// next child, and returns the visited node ids.
    fn predict_tree_path(&self, sample: &ArrayView1<Float>) -> SklResult<Vec<usize>> {
        let mut path = Vec::new();
        let mut current_node = 0;
        path.push(current_node);

        while !self.state.tree_structure[current_node].is_empty() {
            if let Some(classifier) = self.state.node_classifiers.get(&current_node) {
                let sample_2d = sample.to_owned().insert_axis(scirs2_core::ndarray::Axis(0));
                let prediction = classifier.predict(&sample_2d.view())?;
                let child_idx = prediction[[0, 0]] as usize;

                if child_idx < self.state.tree_structure[current_node].len() {
                    current_node = self.state.tree_structure[current_node][child_idx];
                    path.push(current_node);
                } else {
                    // Predicted child index is out of range; stop here.
                    break;
                }
            } else {
                // No classifier was trained for this node; stop here.
                break;
            }
        }

        Ok(path)
    }

    /// Returns the tree structure used by the fitted predictor.
    pub fn tree_structure(&self) -> &Vec<Vec<usize>> {
        &self.state.tree_structure
    }
}

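/// Recursively grows a multi-target classification tree: at each node it
/// evaluates midpoint thresholds on every feature, keeps the split with the
/// largest impurity reduction (averaged over targets), and otherwise emits a
/// leaf holding the per-target majority classes and class probabilities.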
pub fn build_classification_tree(
    X: &Array2<Float>,
    y: &Array2<i32>,
    indices: &[usize],
    feature_importances: &mut Array1<Float>,
    depth: usize,
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
    criterion: ClassificationCriterion,
    classes_per_target: &[Vec<i32>],
) -> SklResult<ClassificationDecisionNode> {
    let n_samples = indices.len();

    let (current_impurity, prediction, probabilities) =
        calculate_classification_metrics(y, indices, classes_per_target, criterion);

    // Stop when the node is too small, too deep, or already pure.
    let should_stop = n_samples < min_samples_split
        || max_depth.is_some_and(|max_d| depth >= max_d)
        || current_impurity == 0.0;

    if should_stop {
        return Ok(ClassificationDecisionNode {
            is_leaf: true,
            prediction: Some(prediction),
            probabilities: Some(probabilities),
            feature_idx: None,
            threshold: None,
            left: None,
            right: None,
            n_samples,
            impurity: current_impurity,
        });
    }

    // Search all features and midpoint thresholds for the best split.
    let mut best_impurity_reduction = 0.0;
    let mut best_feature = None;
    let mut best_threshold = None;
    let mut best_left_indices = Vec::new();
    let mut best_right_indices = Vec::new();

    for feature_idx in 0..X.ncols() {
        let mut feature_values: Vec<Float> =
            indices.iter().map(|&i| X[[i, feature_idx]]).collect();
        feature_values.sort_by(|a, b| a.partial_cmp(b).unwrap());
        feature_values.dedup();

        for i in 0..feature_values.len().saturating_sub(1) {
            let threshold = (feature_values[i] + feature_values[i + 1]) / 2.0;

            let (left_indices, right_indices): (Vec<usize>, Vec<usize>) = indices
                .iter()
                .partition(|&&idx| X[[idx, feature_idx]] <= threshold);

            if left_indices.len() < min_samples_leaf || right_indices.len() < min_samples_leaf {
                continue;
            }

            let (left_impurity, _, _) =
                calculate_classification_metrics(y, &left_indices, classes_per_target, criterion);
            let (right_impurity, _, _) =
                calculate_classification_metrics(y, &right_indices, classes_per_target, criterion);

            let weighted_impurity = (left_indices.len() as Float * left_impurity
                + right_indices.len() as Float * right_impurity)
                / n_samples as Float;
            let impurity_reduction = current_impurity - weighted_impurity;

            if impurity_reduction > best_impurity_reduction {
                best_impurity_reduction = impurity_reduction;
                best_feature = Some(feature_idx);
                best_threshold = Some(threshold);
                best_left_indices = left_indices;
                best_right_indices = right_indices;
            }
        }
    }

    // No useful split found: emit a leaf.
    if best_feature.is_none() || best_impurity_reduction <= 0.0 {
        return Ok(ClassificationDecisionNode {
            is_leaf: true,
            prediction: Some(prediction),
            probabilities: Some(probabilities),
            feature_idx: None,
            threshold: None,
            left: None,
            right: None,
            n_samples,
            impurity: current_impurity,
        });
    }

    feature_importances[best_feature.unwrap()] += best_impurity_reduction * n_samples as Float;

    let left_child = build_classification_tree(
        X,
        y,
        &best_left_indices,
        feature_importances,
        depth + 1,
        max_depth,
        min_samples_split,
        min_samples_leaf,
        criterion,
        classes_per_target,
    )?;

    let right_child = build_classification_tree(
        X,
        y,
        &best_right_indices,
        feature_importances,
        depth + 1,
        max_depth,
        min_samples_split,
        min_samples_leaf,
        criterion,
        classes_per_target,
    )?;

    Ok(ClassificationDecisionNode {
        is_leaf: false,
        prediction: Some(prediction),
        probabilities: Some(probabilities),
        feature_idx: best_feature,
        threshold: best_threshold,
        left: Some(Box::new(left_child)),
        right: Some(Box::new(right_child)),
        n_samples,
        impurity: current_impurity,
    })
}

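/// Computes the impurity, per-target majority-class prediction, and class
/// probability table for a set of sample indices.
///
/// A small worked example (values illustrative): with one target and class
/// counts `[2, 2]` over 4 samples, each probability is 0.5, so the Gini term
/// is `0.5 * 0.5 + 0.5 * 0.5 = 0.5` (doubled to 1.0 by the scaling below) and
/// the entropy is `-2 * 0.5 * ln(0.5) ≈ 0.693`.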
pub fn calculate_classification_metrics(
    y: &Array2<i32>,
    indices: &[usize],
    classes_per_target: &[Vec<i32>],
    criterion: ClassificationCriterion,
) -> (Float, Array1<i32>, Array2<Float>) {
    let n_targets = y.ncols();
    let n_samples = indices.len();

    let mut prediction = Array1::<i32>::zeros(n_targets);
    let mut total_impurity = 0.0;

    // Probability table padded to the largest class count across targets.
    let max_classes = classes_per_target
        .iter()
        .map(|classes| classes.len())
        .max()
        .unwrap_or(0);
    let mut probabilities = Array2::<Float>::zeros((n_targets, max_classes));

    for target_idx in 0..n_targets {
        let classes = &classes_per_target[target_idx];
        let n_classes = classes.len();

        // Count class occurrences for this target.
        let mut class_counts = vec![0; n_classes];
        for &sample_idx in indices {
            let class_label = y[[sample_idx, target_idx]];
            if let Some(class_idx) = classes.iter().position(|&c| c == class_label) {
                class_counts[class_idx] += 1;
            }
        }

        // The majority class becomes the node prediction for this target.
        let majority_class_idx = class_counts
            .iter()
            .enumerate()
            .max_by_key(|(_, &count)| count)
            .map(|(idx, _)| idx)
            .unwrap_or(0);

        prediction[target_idx] = classes[majority_class_idx];

        let mut target_impurity = 0.0;
        for (class_idx, &count) in class_counts.iter().enumerate() {
            let prob = count as Float / n_samples as Float;
            probabilities[[target_idx, class_idx]] = prob;

            if prob > 0.0 {
                target_impurity += match criterion {
                    ClassificationCriterion::Gini => prob * (1.0 - prob),
                    ClassificationCriterion::Entropy => -prob * prob.ln(),
                };
            }
        }

        // Gini is scaled by 2 here; a constant factor does not change which
        // split maximizes the impurity reduction.
        if matches!(criterion, ClassificationCriterion::Gini) {
            target_impurity *= 2.0;
        }

        total_impurity += target_impurity;
    }

    // Average the impurity across targets.
    total_impurity /= n_targets as Float;

    (total_impurity, prediction, probabilities)
}

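/// Routes one sample down the tree and returns the per-target class labels
/// stored at the leaf it reaches.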
pub fn predict_classification_sample(
    node: &ClassificationDecisionNode,
    sample: &ArrayView1<Float>,
) -> Array1<i32> {
    if node.is_leaf {
        return node.prediction.as_ref().unwrap().clone();
    }

    let feature_value = sample[node.feature_idx.unwrap()];
    let threshold = node.threshold.unwrap();

    if feature_value <= threshold {
        predict_classification_sample(node.left.as_ref().unwrap(), sample)
    } else {
        predict_classification_sample(node.right.as_ref().unwrap(), sample)
    }
}

/// Routes one sample down the tree and returns the per-target class
/// probability vectors stored at the leaf it reaches.
pub fn predict_classification_probabilities(
    node: &ClassificationDecisionNode,
    sample: &ArrayView1<Float>,
    classes_per_target: &[Vec<i32>],
) -> Vec<Array1<Float>> {
    if node.is_leaf {
        let mut result = Vec::new();
        for target_idx in 0..classes_per_target.len() {
            let n_classes = classes_per_target[target_idx].len();
            let mut target_probs = Array1::<Float>::zeros(n_classes);
            for class_idx in 0..n_classes {
                target_probs[class_idx] =
                    node.probabilities.as_ref().unwrap()[[target_idx, class_idx]];
            }
            result.push(target_probs);
        }
        return result;
    }

    let feature_value = sample[node.feature_idx.unwrap()];
    let threshold = node.threshold.unwrap();

    if feature_value <= threshold {
        predict_classification_probabilities(
            node.left.as_ref().unwrap(),
            sample,
            classes_per_target,
        )
    } else {
        predict_classification_probabilities(
            node.right.as_ref().unwrap(),
            sample,
            classes_per_target,
        )
    }
}