1use std::usize;
2
3#[cfg(feature = "parallelization")]
4use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
5
6use crate::networks::NeuralNetwork;
7
8pub struct BasicTrainer<const N: usize, const O: usize> {
10 training_data: DataSet<N, O>,
12}
13
14impl<const N: usize, const O: usize> BasicTrainer<N, O> {
15 pub fn new(data: DataSet<N, O>) -> Self {
16 Self {
17 training_data: data,
18 }
19 }
20
21 pub fn train(&self, net: &mut NeuralNetwork<N, O>, iterations: usize) {
22 let training_data = &self.training_data;
23 let mut pre_dist = self.compute_total_error(net, training_data);
24 for _ in 0..=iterations {
25 net.random_edit();
26 let after_dist = self.compute_total_error(net, training_data);
27 if pre_dist < after_dist {
28 net.reverse_edit();
29 } else {
30 pre_dist = after_dist;
31 }
32 }
33 }
34
35 pub fn get_total_error(&self, net: &NeuralNetwork<N, O>) -> f32 {
36 self.compute_total_error(net, &self.training_data)
37 }
38
39 fn compute_total_error(&self, net: &NeuralNetwork<N, O>, data: &DataSet<N, O>) -> f32 {
40 #[cfg(feature = "parallelization")]
41 let it = { data.inputs.par_iter() };
42
43 #[cfg(not(feature = "parallelization"))]
44 let it = { data.inputs.iter() };
45
46 it.zip(&data.outputs)
47 .map(|(input, output)| {
48 let result = net.unbuffered_run(input);
49 output
50 .iter()
51 .zip(result)
52 .fold(0.0, |dist, x| dist + (x.0 - x.1).abs())
53 })
54 .sum()
55 }
56}
57
/// A supervised training set of paired samples.
///
/// `inputs[i]` is fed to the network and `outputs[i]` is the expected result
/// for that sample. The two vectors are paired by index, so they should have
/// the same length.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct DataSet<const N: usize, const O: usize> {
    /// Network inputs, one array of `N` values per sample.
    pub inputs: Vec<[f32; N]>,
    /// Expected network outputs, one array of `O` values per sample.
    pub outputs: Vec<[f32; O]>,
}