1use super::tree::*;
3use rand;
4use rand::Rng;
5use serde::Serialize;
6use serde::de::DeserializeOwned;
7use rayon::prelude::*;
8
/// Strategy for combining the results of the individual trees of a forest
/// into one score.
///
/// `L` is the leaf/result type produced by the trees.
pub trait VotingMethod<L> {
    /// Reduces the per-tree results in `tree_results` to a single `f64`.
    fn voting(&self, tree_results: &[&L]) -> f64;
}
13
/// Types that can report a probability as an `f64`.
///
/// Used by voting methods (e.g. averaging) to read a numeric value out of
/// each tree's result.
pub trait FromGetProbability {
    /// The probability carried by this value.
    fn probability(&self) -> f64;
}
18
19impl<T> FromGetProbability for T
20where
21 f64: From<T>,
22 T: Copy,
23{
24 fn probability(&self) -> f64 {
25 return f64::from(*self);
26 }
27}
28
29pub struct AverageVoting;
31impl<L> VotingMethod<L> for AverageVoting
32where
33 L: FromGetProbability,
34{
35 fn voting(&self, tree_results: &[&L]) -> f64 {
36 let sum = tree_results.iter().fold(
37 0f64,
38 |sum, l| sum + l.probability(),
39 );
40 return (sum as f64) / (tree_results.len() as f64);
41 }
42}
43
44#[derive(Serialize, Deserialize)]
48#[serde(bound(serialize = "DecisionTree<L,F>: Serialize"))]
49#[serde(bound(deserialize = "DecisionTree<L,F>: DeserializeOwned"))]
50pub struct RandomForest<L, F>
51where
52 F: TreeFunction,
53{
54 subtrees: Vec<DecisionTree<L, F>>,
55}
56
57impl<L, F> RandomForest<L, F>
58where
59 F: TreeFunction,
60{
61 pub fn forest_predictions(&self, input: &F::Data) -> Vec<&L> {
64 self.subtrees
65 .iter()
66 .filter_map(|tree| tree.predict(input))
67 .collect()
68 }
69
70 pub fn predict<V>(&self, input: &F::Data, voting_method: V) -> Option<f64>
72 where
73 V: VotingMethod<L>,
74 {
75 let predictions: Vec<_> = self.forest_predictions(input);
76 Some(voting_method.voting(&predictions[..]))
77 }
78}
79
80impl<L, F> RandomForest<L, F>
81where
82 F: TreeFunction + Send + Sync,
83 <F as TreeFunction>::Param: Send + Sync,
84 <F as TreeFunction>::Data: Send + Sync,
85 L: Send + Sync,
86{
87 pub fn forest_predictions_parallel(&self, input: &F::Data) -> Vec<&L> {
90 self.subtrees
91 .par_iter()
92 .filter_map(|tree| tree.predict(input))
93 .fold(|| Vec::with_capacity(self.subtrees.len()), |mut v, x| {
94 v.push(x);
95 v
96 })
97 .reduce(|| Vec::with_capacity(self.subtrees.len()), |mut v, mut x| {
98 v.append(&mut x);
99 v
100 })
101 }
102}
103
104
105pub struct RandomForestLearnParam<LearnF>
107where
108 LearnF: TreeLearnFunctions,
109{
110 pub tree_param: TreeParameters,
112 pub number_of_trees: usize,
114 pub size_of_subset_per_training: usize,
116 pub learn_function: LearnF,
118}
119
120impl<LearnF> RandomForestLearnParam<LearnF>
121where
122 LearnF: TreeLearnFunctions + Copy,
123{
124 pub fn new(
129 number_of_trees: usize,
130 size_of_subset_per_training: usize,
131 learnf: LearnF,
132 ) -> RandomForestLearnParam<LearnF> {
133 RandomForestLearnParam {
134 tree_param: TreeParameters::new(),
135 number_of_trees: number_of_trees,
136 size_of_subset_per_training: size_of_subset_per_training,
137 learn_function: learnf,
138 }
139 }
140
141 pub fn train_forest(
143 self,
144 train_set: &[(&LearnF::Data, &LearnF::Truth)],
145 ) -> Option<RandomForest<LearnF::LeafParam, LearnF::PredictFunction>> {
146 let mut res = vec![];
147 let mut rng = rand::thread_rng();
148 let mut subset = Vec::with_capacity(self.size_of_subset_per_training);
149 for _ in 0..self.number_of_trees {
150 subset.clear();
151 for _ in 0..self.size_of_subset_per_training {
152 subset.push(train_set[rng.gen_range(0, train_set.len())]);
153 }
154 let tree = self.tree_param.learn_tree(self.learn_function, &subset[..]);
155 res.push(tree);
156 }
157 Some(RandomForest { subtrees: res })
158
159 }
160}
161impl<LearnF> RandomForestLearnParam<LearnF>
162where
163 LearnF: TreeLearnFunctions + Copy + Send + Sync,
164 LearnF::PredictFunction: Send + Sync,
165 LearnF::Truth: Send + Sync,
166 LearnF::LeafParam: Send + Sync,
167 LearnF::Data: Send + Sync,
168 LearnF::Param: Send + Sync,
169{
170 pub fn train_forest_parallel(
173 self,
174 train_set: &[(&LearnF::Data, &LearnF::Truth)],
175 ) -> Option<RandomForest<LearnF::LeafParam, LearnF::PredictFunction>> {
176
177 let subset_size = self.size_of_subset_per_training;
178 let trees = (0..self.number_of_trees)
179 .into_par_iter()
180 .map(|_| {
181 let mut rng = rand::thread_rng();
182 let mut subset = Vec::with_capacity(self.size_of_subset_per_training);
183 for _ in 0..subset_size {
184 subset.push(train_set[rng.gen_range(0, train_set.len())]);
185 }
186 let tree = self.tree_param.learn_tree(self.learn_function, &subset[..]);
187 tree
188 })
189 .fold(|| Vec::with_capacity(subset_size), |mut v, x| {
190 v.push(x);
191 v
192 })
193 .reduce(|| Vec::with_capacity(subset_size), |mut v, mut x| {
194 v.append(&mut x);
195 v
196 });
197 Some(RandomForest { subtrees: trees })
198 }
199}