vexus/neural_network.rs

use rand::distr::{Distribution, Uniform};
use rand::rng;

/// A single fully connected layer. `inputs` and `outputs` cache the values
/// from the most recent forward pass; `error_derivative` holds the per-neuron
/// deltas computed during backpropagation.
#[derive(Clone)]
struct Layer {
    inputs: Vec<f32>,
    outputs: Vec<f32>,
    biases: Vec<f32>,
    /// `weights[j][k]` connects input `k` to neuron `j`.
    weights: Vec<Vec<f32>>,
    error_derivative: Vec<f32>,
}

/// A fully connected feed-forward network trained by gradient descent.
#[derive(Clone)]
pub struct NeuralNetwork {
    layers: Vec<Layer>,
    learning_rate: f32,
    activation: Box<dyn Activation>,
}

impl NeuralNetwork {
    pub fn new(
        layer_sizes: Vec<usize>,
        learning_rate: f32,
        activation: Box<dyn Activation>,
    ) -> Self {
        let mut rng = rng();
        let between = Uniform::try_from(-1.0..1.0).unwrap();

        let mut layers = Vec::new();
        for layer_index in 0..layer_sizes.len() {
            // Layer 0 reads the raw network input, so its fan-in equals its
            // own size; every later layer's fan-in is the previous layer's size.
            let fan_in = layer_sizes[layer_index.saturating_sub(1)];
            let size = layer_sizes[layer_index];
            // Weights and biases start as uniform random values in [-1, 1).
            let weights = (0..size)
                .map(|_| (0..fan_in).map(|_| between.sample(&mut rng)).collect())
                .collect();
            layers.push(Layer {
                inputs: vec![0.0; fan_in],
                outputs: vec![0.0; size],
                biases: (0..size).map(|_| between.sample(&mut rng)).collect(),
                weights,
                error_derivative: vec![0.0; size],
            });
        }
        Self {
            layers,
            learning_rate,
            activation,
        }
    }

    pub fn forward(&mut self, inputs: &[f32]) -> Vec<f32> {
        if inputs.len() != self.layers[0].inputs.len() {
            panic!(
                "'inputs' has length {} but the first layer expects {}",
                inputs.len(),
                self.layers[0].inputs.len()
            );
        }

        for layer_index in 0..self.layers.len() {
            self.layers[layer_index].outputs.fill(0.0);
            // Each layer consumes either the raw input or the previous
            // layer's activations, and caches a copy for backpropagation.
            self.layers[layer_index].inputs = match layer_index {
                0 => inputs.to_vec(),
                _ => self.layers[layer_index - 1].outputs.clone(),
            };
            for j in 0..self.layers[layer_index].outputs.len() {
                // Weighted sum of the inputs, plus bias, through the activation.
                for k in 0..self.layers[layer_index].inputs.len() {
                    self.layers[layer_index].outputs[j] +=
                        self.layers[layer_index].inputs[k] * self.layers[layer_index].weights[j][k];
                }
                self.layers[layer_index].outputs[j] += self.layers[layer_index].biases[j];
                self.layers[layer_index].outputs[j] = self
                    .activation
                    .function(self.layers[layer_index].outputs[j]);
            }
        }
        self.layers.last().unwrap().outputs.clone()
    }

    /// Sum of squared errors between the most recent forward pass and `expected`.
    pub fn errors(&self, expected: &[f32]) -> f32 {
        let outputs = &self.layers.last().unwrap().outputs;
        if expected.len() != outputs.len() {
            panic!(
                "'expected' has length {} but the output layer has {} neurons",
                expected.len(),
                outputs.len()
            );
        }
        outputs
            .iter()
            .zip(expected)
            .map(|(actual, expected)| (actual - expected).powi(2))
            .sum()
    }

    pub fn backpropagate(&mut self, expected: &[f32]) {
        // First pass, output layer back to the first: compute every delta
        // before touching any weight, so hidden layers propagate error through
        // the same weights that produced the forward pass.
        for layer_index in (0..self.layers.len()).rev() {
            for k in 0..self.layers[layer_index].outputs.len() {
                let error = if layer_index == self.layers.len() - 1 {
                    // Output layer: difference from the expected value.
                    self.layers[layer_index].outputs[k] - expected[k]
                } else {
                    // Hidden layer: sum of the next layer's deltas, weighted by
                    // the connections leaving this neuron.
                    let mut error = 0.0;
                    for j in 0..self.layers[layer_index + 1].outputs.len() {
                        error += self.layers[layer_index + 1].weights[j][k]
                            * self.layers[layer_index + 1].error_derivative[j];
                    }
                    error
                };
                // The activation derivative is evaluated at the cached output.
                self.layers[layer_index].error_derivative[k] = error
                    * self
                        .activation
                        .derivative(self.layers[layer_index].outputs[k]);
            }
        }
        // Second pass: one gradient-descent step on every weight and bias.
        for layer_index in 0..self.layers.len() {
            for j in 0..self.layers[layer_index].outputs.len() {
                for k in 0..self.layers[layer_index].inputs.len() {
                    self.layers[layer_index].weights[j][k] -= self.learning_rate
                        * self.layers[layer_index].error_derivative[j]
                        * self.layers[layer_index].inputs[k];
                }
                self.layers[layer_index].biases[j] -=
                    self.learning_rate * self.layers[layer_index].error_derivative[j];
            }
        }
    }
}

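// Usage sketch: one training pass couples `forward`, `errors`, and
// `backpropagate` per sample. The names `train_epoch` and `samples` are
// illustrative, not part of the original API.
#[allow(dead_code)]
fn train_epoch(net: &mut NeuralNetwork, samples: &[(Vec<f32>, Vec<f32>)]) -> f32 {
    let mut total_error = 0.0;
    for (inputs, expected) in samples {
        net.forward(inputs); // caches each layer's inputs and outputs
        total_error += net.errors(expected); // squared error before the update
        net.backpropagate(expected); // gradient step using the cached values
    }
    total_error
}
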
/// An activation function and its derivative. Note the convention used by
/// `backpropagate`: `derivative` receives the neuron's *activated output*,
/// not its pre-activation input.
pub trait Activation: ActivationClone {
    fn function(&self, x: f32) -> f32;
    fn derivative(&self, x: f32) -> f32;
}

/// Helper trait so that `Box<dyn Activation>` can be cloned.
pub trait ActivationClone {
    fn clone_box(&self) -> Box<dyn Activation>;
}

impl<T> ActivationClone for T
where
    T: 'static + Activation + Clone,
{
    fn clone_box(&self) -> Box<dyn Activation> {
        Box::new(self.clone())
    }
}

// Clone for the boxed trait object delegates to `clone_box`.
impl Clone for Box<dyn Activation> {
    fn clone(&self) -> Box<dyn Activation> {
        self.clone_box()
    }
}

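// A logistic sigmoid, shown as an illustrative alternative to `Tanh` below.
// It follows the same convention: `backpropagate` passes the activated
// output y = sigma(x) to `derivative`, so the derivative is y * (1 - y).
#[derive(Clone)]
pub struct Sigmoid;
impl Activation for Sigmoid {
    fn function(&self, x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    fn derivative(&self, y: f32) -> f32 {
        y * (1.0 - y)
    }
}
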
/// Hyperbolic tangent activation (previously misnamed `Sigmoid`; the math is
/// tanh). `derivative` takes the activated output y = tanh(x), for which
/// d/dx tanh(x) = 1 - y^2.
#[derive(Clone)]
pub struct Tanh;
impl Activation for Tanh {
    fn function(&self, x: f32) -> f32 {
        x.tanh()
    }

    fn derivative(&self, x: f32) -> f32 {
        1.0 - x.powi(2)
    }
}
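
// Smoke-test sketch (assumed; the layer sizes, learning rate, and epoch count
// are illustrative guesses): a few hundred epochs on XOR-style data should
// shrink the squared error from its randomly initialized starting point.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn training_reduces_error() {
        let mut net = NeuralNetwork::new(vec![2, 4, 1], 0.1, Box::new(Tanh));
        let samples = vec![
            (vec![0.0, 0.0], vec![0.0]),
            (vec![0.0, 1.0], vec![1.0]),
            (vec![1.0, 0.0], vec![1.0]),
            (vec![1.0, 1.0], vec![0.0]),
        ];
        let initial = train_epoch(&mut net, &samples);
        for _ in 0..500 {
            train_epoch(&mut net, &samples);
        }
        let trained = train_epoch(&mut net, &samples);
        assert!(
            trained < initial,
            "error should shrink: {initial} -> {trained}"
        );
    }
}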