1use core::panic;
2use rand::distr::{Distribution, Uniform};
3use rand::rng;
4
/// One fully connected layer, together with the buffers it caches for training.
#[derive(Clone)]
struct Layer {
    /// `weights[j][k]` is the weight from input `k` to neuron `j`.
    weights: Vec<Vec<f32>>,
    /// One bias per neuron, added to the weighted sum before the activation.
    biases: Vec<f32>,
    /// Values fed into this layer on the most recent forward pass.
    inputs: Vec<f32>,
    /// Activated values produced by this layer on the most recent forward pass.
    outputs: Vec<f32>,
    /// Per-neuron delta (error times activation derivative) from the most recent
    /// backward pass.
    error_derivative: Vec<f32>,
}
13
14#[derive(Clone)]
15pub struct NeuralNetwork {
16 layers: Vec<Layer>,
17 learning_rate: f32,
18 activation: Box<dyn Activation>,
19}
20
21impl NeuralNetwork {
22 pub fn new(
23 layer_sizes: Vec<usize>,
24 learning_rate: f32,
25 activation: Box<dyn Activation>,
26 ) -> Self {
27 let mut rng = rng();
28 let between = Uniform::try_from(-1.0..1.0).unwrap();
29
30 let mut layers = Vec::new();
31 for layer_index in 0..layer_sizes.len() {
32 let inputs = vec![
33 0.0;
34 layer_sizes[match layer_index {
35 0 => layer_index,
36 _ => layer_index - 1,
37 }]
38 ];
39 dbg!(&inputs);
40 let weights = (0..layer_sizes[layer_index])
41 .into_iter()
42 .map(|_| {
43 (0..layer_sizes[match layer_index {
44 0 => layer_index,
45 _ => layer_index - 1,
46 }])
47 .into_iter()
48 .map(|_| between.sample(&mut rng))
49 .collect()
50 })
51 .collect();
52 layers.push(Layer {
53 inputs,
54 outputs: vec![0.0; layer_sizes[layer_index]],
55 biases: (0..layer_sizes[layer_index])
56 .into_iter()
57 .map(|_| between.sample(&mut rng))
58 .collect(),
59 weights,
60 error_derivative: vec![0.0; layer_sizes[layer_index]],
61 });
62 }
63 return Self {
64 layers,
65 learning_rate,
66 activation,
67 };
68 }
69
70 pub fn forward(&mut self, inputs: &Vec<f32>) -> Vec<f32> {
71 if inputs.len() != self.layers[0].inputs.len() {
72 dbg!(inputs.len(), self.layers[0].inputs.len());
73 panic!(
74 "The given arguement: 'inputs' in the 'forward' method must have the same length as the first 'Layer' inputs defined previously"
75 );
76 }
77
78 for layer_index in 0..self.layers.len() {
79 self.layers[layer_index].outputs.fill(0.0);
80 self.layers[layer_index].inputs = match layer_index {
81 0 => inputs.clone(),
82 _ => self.layers[layer_index - 1].outputs.clone(),
83 };
84 for j in 0..self.layers[layer_index].outputs.len() {
85 for k in 0..self.layers[layer_index].inputs.len() {
86 self.layers[layer_index].outputs[j] +=
87 self.layers[layer_index].inputs[k] * self.layers[layer_index].weights[j][k];
88 }
89 self.layers[layer_index].outputs[j] += self.layers[layer_index].biases[j];
90 self.layers[layer_index].outputs[j] = self
91 .activation
92 .function(self.layers[layer_index].outputs[j]);
93 }
94 }
95 self.layers.last().unwrap().outputs.clone()
96 }
97
98 pub fn errors(&self, expected: &Vec<f32>) -> f32 {
99 if expected.len() != self.layers.last().unwrap().outputs.len() {
100 panic!(
101 "The given arguement: 'expected' in the 'errors' method must have the same length as the last 'Layer
102 outputs defined previously"
103 );
104 }
105 let mut error = 0.0;
106 for (actual, expected) in self.layers.last().unwrap().outputs.iter().zip(expected) {
107 error += (actual - expected).powi(2);
108 }
109 error
110 }
111 pub fn backpropagate(&mut self, expected: &Vec<f32>) {
112 for layer_index in (0..self.layers.len()).rev() {
113 for k in 0..self.layers[layer_index].outputs.len() {
114 let delta = if layer_index == self.layers.len() - 1 {
115 let error = self.layers[layer_index].outputs[k] - expected[k];
117 error
118 * self
119 .activation
120 .derivative(self.layers[layer_index].outputs[k])
121 } else {
122 let mut error = 0.0;
124 for j in 0..self.layers[layer_index + 1].outputs.len() {
125 error += self.layers[layer_index + 1].weights[j][k]
126 * self.layers[layer_index + 1].error_derivative[j];
127 }
128 error
129 * self
130 .activation
131 .derivative(self.layers[layer_index].outputs[k])
132 };
133 self.layers[layer_index].error_derivative[k] = delta;
134 }
135 for j in 0..self.layers[layer_index].outputs.len() {
136 for k in 0..self.layers[layer_index].inputs.len() {
137 self.layers[layer_index].weights[j][k] -= self.learning_rate
138 * self.layers[layer_index].error_derivative[j]
139 * self.layers[layer_index].inputs[k];
140 }
141 self.layers[layer_index].biases[j] -=
142 self.learning_rate * self.layers[layer_index].error_derivative[j];
143 }
144 }
145 }
146}
147
/// An element-wise activation function shared by every layer of a `NeuralNetwork`.
///
/// Implementors must also provide [`ActivationClone`]; any `'static + Clone` type gets it
/// for free from the blanket implementation in this block.
pub trait Activation: ActivationClone {
    /// Applies the activation to a neuron's pre-activation value (weighted sum plus bias).
    fn function(&self, z: f32) -> f32;

    /// Returns the activation's slope. `NeuralNetwork::backpropagate` calls this with the
    /// neuron's *activated output*, so implementations must express the derivative in
    /// terms of that output rather than the pre-activation value.
    fn derivative(&self, y: f32) -> f32;
}

/// Object-safe cloning for activations, which lets `Box<dyn Activation>` implement `Clone`.
pub trait ActivationClone {
    /// Returns a heap-allocated copy of `self` as an `Activation` trait object.
    fn clone_box(&self) -> Box<dyn Activation>;
}

impl<A: Activation + Clone + 'static> ActivationClone for A {
    fn clone_box(&self) -> Box<dyn Activation> {
        let duplicate: A = A::clone(self);
        Box::new(duplicate)
    }
}

impl Clone for Box<dyn Activation> {
    fn clone(&self) -> Self {
        // Dispatch through the trait object so the concrete type's `clone_box` runs.
        (**self).clone_box()
    }
}
172
173#[derive(Clone)]
174pub struct Sigmoid;
175impl Activation for Sigmoid {
176 fn function(&self, x: f32) -> f32 {
177 x.tanh()
178 }
179
180 fn derivative(&self, x: f32) -> f32 {
181 1.0 - x.powi(2)
182 }
183}