stamm/
randforest.rs

1//! Implements a random forest using `stamm::tree`.
2use super::tree::*;
3use rand;
4use rand::Rng;
5use serde::Serialize;
6use serde::de::DeserializeOwned;
7use rayon::prelude::*;
8
/// Strategy for merging the per-tree results of a random forest into one number.
///
/// `L` is the type of the leaf data the trees hand back.
pub trait VotingMethod<L> {
    /// Combines the leaf data collected from the individual trees
    /// (`tree_results`) into a single `f64`.
    fn voting(&self, tree_results: &[&L]) -> f64;
}
13
/// Leaf data from which a probability can be read.
///
/// Leaf types that should take part in probability based voting implement
/// this trait to expose their value as an `f64`.
pub trait FromGetProbability {
    /// Returns the probability represented by `self`.
    fn probability(&self) -> f64;
}

/// Blanket implementation for every `Copy` type that `f64` can be built from
/// via `From`: the probability is simply that conversion.
impl<T: Copy> FromGetProbability for T
where
    f64: From<T>,
{
    fn probability(&self) -> f64 {
        let value: T = *self;
        f64::from(value)
    }
}
28
29/// Implementing average voting
30pub struct AverageVoting;
31impl<L> VotingMethod<L> for AverageVoting
32where
33    L: FromGetProbability,
34{
35    fn voting(&self, tree_results: &[&L]) -> f64 {
36        let sum = tree_results.iter().fold(
37            0f64,
38            |sum, l| sum + l.probability(),
39        );
40        return (sum as f64) / (tree_results.len() as f64);
41    }
42}
43
44/// A random forest to combine some decision trees.
45/// The decision trees have leafs with data from type `L`
46/// and use a TreeFunction from type `F`.
47#[derive(Serialize, Deserialize)]
48#[serde(bound(serialize = "DecisionTree<L,F>: Serialize"))]
49#[serde(bound(deserialize = "DecisionTree<L,F>: DeserializeOwned"))]
50pub struct RandomForest<L, F>
51where
52    F: TreeFunction,
53{
54    subtrees: Vec<DecisionTree<L, F>>,
55}
56
57impl<L, F> RandomForest<L, F>
58where
59    F: TreeFunction,
60{
61    /// Let every tree predict a result of the `input`.
62    /// Returns the results in form a Vec of the leaf data.
63    pub fn forest_predictions(&self, input: &F::Data) -> Vec<&L> {
64        self.subtrees
65            .iter()
66            .filter_map(|tree| tree.predict(input))
67            .collect()
68    }
69
70    /// Let every tree predict a result and combine them using a vote method.
71    pub fn predict<V>(&self, input: &F::Data, voting_method: V) -> Option<f64>
72    where
73        V: VotingMethod<L>,
74    {
75        let predictions: Vec<_> = self.forest_predictions(input);
76        Some(voting_method.voting(&predictions[..]))
77    }
78}
79
80impl<L, F> RandomForest<L, F>
81where
82    F: TreeFunction + Send + Sync,
83    <F as TreeFunction>::Param: Send + Sync,
84    <F as TreeFunction>::Data: Send + Sync,
85    L: Send + Sync,
86{
87    /// Like [`forest_predictions`](#method.forest_predictions)
88    /// but use rayon to parallelize the computation.
89    pub fn forest_predictions_parallel(&self, input: &F::Data) -> Vec<&L> {
90        self.subtrees
91            .par_iter()
92            .filter_map(|tree| tree.predict(input))
93            .fold(|| Vec::with_capacity(self.subtrees.len()), |mut v, x| {
94                v.push(x);
95                v
96            })
97            .reduce(|| Vec::with_capacity(self.subtrees.len()), |mut v, mut x| {
98                v.append(&mut x);
99                v
100            })
101    }
102}
103
104
105/// Parameter describes the way to train a random forest
106pub struct RandomForestLearnParam<LearnF>
107where
108    LearnF: TreeLearnFunctions,
109{
110    /// parameter used for every tree
111    pub tree_param: TreeParameters,
112    /// number of trees
113    pub number_of_trees: usize,
114    /// size of a random training subset used for train one tree
115    pub size_of_subset_per_training: usize,
116    /// TreeLearnFunction
117    pub learn_function: LearnF,
118}
119
120impl<LearnF> RandomForestLearnParam<LearnF>
121where
122    LearnF: TreeLearnFunctions + Copy,
123{
124    /// Creates a new RandomForestLearnParam.
125    /// `number_of_trees` is the number of trees used in this random forest.
126    /// Every tree will be trained using a random subset of the training data. `size_of_subset_per_training` is the size of this subset.
127    /// `learnf` is the TreeLearnFunction for every tree
128    pub fn new(
129        number_of_trees: usize,
130        size_of_subset_per_training: usize,
131        learnf: LearnF,
132    ) -> RandomForestLearnParam<LearnF> {
133        RandomForestLearnParam {
134            tree_param: TreeParameters::new(),
135            number_of_trees: number_of_trees,
136            size_of_subset_per_training: size_of_subset_per_training,
137            learn_function: learnf,
138        }
139    }
140
141    /// Trains a random forest using the ground truth data `train_set`.
142    pub fn train_forest(
143        self,
144        train_set: &[(&LearnF::Data, &LearnF::Truth)],
145    ) -> Option<RandomForest<LearnF::LeafParam, LearnF::PredictFunction>> {
146        let mut res = vec![];
147        let mut rng = rand::thread_rng();
148        let mut subset = Vec::with_capacity(self.size_of_subset_per_training);
149        for _ in 0..self.number_of_trees {
150            subset.clear();
151            for _ in 0..self.size_of_subset_per_training {
152                subset.push(train_set[rng.gen_range(0, train_set.len())]);
153            }
154            let tree = self.tree_param.learn_tree(self.learn_function, &subset[..]);
155            res.push(tree);
156        }
157        Some(RandomForest { subtrees: res })
158
159    }
160}
161impl<LearnF> RandomForestLearnParam<LearnF>
162where
163    LearnF: TreeLearnFunctions + Copy + Send + Sync,
164    LearnF::PredictFunction: Send + Sync,
165    LearnF::Truth: Send + Sync,
166    LearnF::LeafParam: Send + Sync,
167    LearnF::Data: Send + Sync,
168    LearnF::Param: Send + Sync,
169{
170    /// Like [`train_forest`](#method.train_forest)
171    /// but use rayon to parallelize the training.
172    pub fn train_forest_parallel(
173        self,
174        train_set: &[(&LearnF::Data, &LearnF::Truth)],
175    ) -> Option<RandomForest<LearnF::LeafParam, LearnF::PredictFunction>> {
176
177        let subset_size = self.size_of_subset_per_training;
178        let trees = (0..self.number_of_trees)
179            .into_par_iter()
180            .map(|_| {
181                let mut rng = rand::thread_rng();
182                let mut subset = Vec::with_capacity(self.size_of_subset_per_training);
183                for _ in 0..subset_size {
184                    subset.push(train_set[rng.gen_range(0, train_set.len())]);
185                }
186                let tree = self.tree_param.learn_tree(self.learn_function, &subset[..]);
187                tree
188            })
189            .fold(|| Vec::with_capacity(subset_size), |mut v, x| {
190                v.push(x);
191                v
192            })
193            .reduce(|| Vec::with_capacity(subset_size), |mut v, mut x| {
194                v.append(&mut x);
195                v
196            });
197        Some(RandomForest { subtrees: trees })
198    }
199}