tiny_ml/training/
mod.rs

1use std::usize;
2
3#[cfg(feature = "parallelization")]
4use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
5
6use crate::networks::NeuralNetwork;
7
8/// A simple struct for Training Neural Networks
9pub struct BasicTrainer<const N: usize, const O: usize> {
10    /// the data the model is to be trained on
11    training_data: DataSet<N, O>,
12}
13
14impl<const N: usize, const O: usize> BasicTrainer<N, O> {
15    pub fn new(data: DataSet<N, O>) -> Self {
16        Self {
17            training_data: data,
18        }
19    }
20
21    pub fn train(&self, net: &mut NeuralNetwork<N, O>, iterations: usize) {
22        let training_data = &self.training_data;
23        let mut pre_dist = self.compute_total_error(net, training_data);
24        for _ in 0..=iterations {
25            net.random_edit();
26            let after_dist = self.compute_total_error(net, training_data);
27            if pre_dist < after_dist {
28                net.reverse_edit();
29            } else {
30                pre_dist = after_dist;
31            }
32        }
33    }
34
35    pub fn get_total_error(&self, net: &NeuralNetwork<N, O>) -> f32 {
36        self.compute_total_error(net, &self.training_data)
37    }
38
39    fn compute_total_error(&self, net: &NeuralNetwork<N, O>, data: &DataSet<N, O>) -> f32 {
40        #[cfg(feature = "parallelization")]
41        let it = { data.inputs.par_iter() };
42
43        #[cfg(not(feature = "parallelization"))]
44        let it = { data.inputs.iter() };
45
46        it.zip(&data.outputs)
47            .map(|(input, output)| {
48                let result = net.unbuffered_run(input);
49                output
50                    .iter()
51                    .zip(result)
52                    .fold(0.0, |dist, x| dist + (x.0 - x.1).abs())
53            })
54            .sum()
55    }
56}
57
/// A set of input samples together with the outputs expected for them.
///
/// `inputs[i]` is paired with `outputs[i]`; the two vectors are meant to have the same
/// length. `N` is the length of one input sample and `O` the length of one expected
/// output. The default value is an empty data set.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct DataSet<const N: usize, const O: usize> {
    /// Input samples, one fixed-size array of `N` values per sample.
    pub inputs: Vec<[f32; N]>,
    /// Expected outputs, one fixed-size array of `O` values per sample.
    pub outputs: Vec<[f32; O]>,
}