use super::*;
use num_traits::Float;
use std::collections::HashMap;

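/// Incremental (online) backpropagation: weights are updated after every
/// training pattern, with an optional momentum term carried between updates.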
pub struct IncrementalBackprop<T: Float + Send + Default> {
    learning_rate: T,
    momentum: T,
    error_function: Box<dyn ErrorFunction<T>>,
    previous_weight_deltas: Vec<Vec<T>>,
    previous_bias_deltas: Vec<Vec<T>>,
    callback: Option<TrainingCallback<T>>,
}

impl<T: Float + Send + Default> IncrementalBackprop<T> {
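    /// Creates a trainer with the given learning rate, zero momentum, and MSE
    /// as the default error function.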
    pub fn new(learning_rate: T) -> Self {
        Self {
            learning_rate,
            momentum: T::zero(),
            error_function: Box::new(MseError),
            previous_weight_deltas: Vec::new(),
            previous_bias_deltas: Vec::new(),
            callback: None,
        }
    }

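    /// Sets the momentum term. With momentum `m`, each update becomes
    /// `delta = learning_rate * gradient + m * previous_delta` (see `train_epoch`).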
    pub fn with_momentum(mut self, momentum: T) -> Self {
        self.momentum = momentum;
        self
    }

    pub fn with_error_function(mut self, error_function: Box<dyn ErrorFunction<T>>) -> Self {
        self.error_function = error_function;
        self
    }

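    /// Lazily allocates the momentum buffers: one flat `Vec` per non-input layer,
    /// sized `num_neurons * num_connections` for weights and `num_neurons` for biases.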
    fn initialize_deltas(&mut self, network: &Network<T>) {
        if self.previous_weight_deltas.is_empty() {
            self.previous_weight_deltas = network
                .layers
                .iter()
                .skip(1)
                .map(|layer| {
                    let num_neurons = layer.neurons.len();
                    let num_connections = if layer.neurons.is_empty() {
                        0
                    } else {
                        layer.neurons[0].connections.len()
                    };
                    vec![T::zero(); num_neurons * num_connections]
                })
                .collect();
            self.previous_bias_deltas = network
                .layers
                .iter()
                .skip(1)
                .map(|layer| vec![T::zero(); layer.neurons.len()])
                .collect();
        }
    }
}

impl<T: Float + Send + Default> TrainingAlgorithm<T> for IncrementalBackprop<T> {
    fn train_epoch(
        &mut self,
        network: &mut Network<T>,
        data: &TrainingData<T>,
    ) -> Result<T, TrainingError> {
        use super::helpers::*;

        self.initialize_deltas(network);

        let mut total_error = T::zero();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            // Re-snapshot the network for every pattern so the weights updated
            // after the previous pattern are used in this forward pass.
            let simple_network = network_to_simple(network);

            let activations = forward_propagate(&simple_network, input);
            let output = &activations[activations.len() - 1];

            total_error = total_error + self.error_function.calculate(output, desired_output);

            let (weight_gradients, bias_gradients) = calculate_gradients(
                &simple_network,
                &activations,
                desired_output,
                self.error_function.as_ref(),
            );

            // Blend the current gradient with the previous delta (momentum) and
            // remember the result for the next pattern.
            for layer_idx in 0..weight_gradients.len() {
                for (i, &grad) in weight_gradients[layer_idx].iter().enumerate() {
                    let delta = self.learning_rate * grad
                        + self.momentum * self.previous_weight_deltas[layer_idx][i];
                    self.previous_weight_deltas[layer_idx][i] = delta;
                }

                for (i, &grad) in bias_gradients[layer_idx].iter().enumerate() {
                    let delta = self.learning_rate * grad
                        + self.momentum * self.previous_bias_deltas[layer_idx][i];
                    self.previous_bias_deltas[layer_idx][i] = delta;
                }
            }

            // The update direction (descent vs. ascent) follows the sign convention
            // shared by calculate_gradients and apply_updates_to_network.
            apply_updates_to_network(
                network,
                &self.previous_weight_deltas,
                &self.previous_bias_deltas,
            );
        }

        Ok(total_error / T::from(data.inputs.len()).unwrap())
    }

    fn calculate_error(&self, network: &Network<T>, data: &TrainingData<T>) -> T {
        let mut total_error = T::zero();
        let mut network_clone = network.clone();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network_clone.run(input);
            total_error = total_error + self.error_function.calculate(&output, desired_output);
        }

        total_error / T::from(data.inputs.len()).unwrap()
    }

    fn count_bit_fails(
        &self,
        network: &Network<T>,
        data: &TrainingData<T>,
        bit_fail_limit: T,
    ) -> usize {
        let mut bit_fails = 0;
        let mut network_clone = network.clone();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network_clone.run(input);
            for (&actual, &desired) in output.iter().zip(desired_output.iter()) {
                if (actual - desired).abs() > bit_fail_limit {
                    bit_fails += 1;
                }
            }
        }

        bit_fails
    }

    fn save_state(&self) -> TrainingState<T> {
        let mut state = HashMap::new();
        state.insert("learning_rate".to_string(), vec![self.learning_rate]);
        state.insert("momentum".to_string(), vec![self.momentum]);

        TrainingState {
            epoch: 0,
            best_error: T::from(f32::MAX).unwrap(),
            algorithm_specific: state,
        }
    }

    fn restore_state(&mut self, state: TrainingState<T>) {
        if let Some(lr) = state.algorithm_specific.get("learning_rate") {
            if !lr.is_empty() {
                self.learning_rate = lr[0];
            }
        }
        if let Some(mom) = state.algorithm_specific.get("momentum") {
            if !mom.is_empty() {
                self.momentum = mom[0];
            }
        }
    }

    fn set_callback(&mut self, callback: TrainingCallback<T>) {
        self.callback = Some(callback);
    }

    fn call_callback(
        &mut self,
        epoch: usize,
        network: &Network<T>,
        data: &TrainingData<T>,
    ) -> bool {
        let error = self.calculate_error(network, data);
        if let Some(ref mut callback) = self.callback {
            callback(epoch, error)
        } else {
            true
        }
    }
}

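/// Batch counterpart of [`IncrementalBackprop`], sharing the same builder-style
/// configuration (learning rate, momentum term, pluggable error function).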
pub struct BatchBackprop<T: Float + Send> {
    learning_rate: T,
    momentum: T,
    error_function: Box<dyn ErrorFunction<T>>,
    previous_weight_deltas: Vec<Vec<T>>,
    previous_bias_deltas: Vec<Vec<T>>,
    callback: Option<TrainingCallback<T>>,
}

impl<T: Float + Send> BatchBackprop<T> {
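    /// Creates a batch trainer with the given learning rate, zero momentum, and
    /// MSE as the default error function.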
    pub fn new(learning_rate: T) -> Self {
        Self {
            learning_rate,
            momentum: T::zero(),
            error_function: Box::new(MseError),
            previous_weight_deltas: Vec::new(),
            previous_bias_deltas: Vec::new(),
            callback: None,
        }
    }

    pub fn with_momentum(mut self, momentum: T) -> Self {
        self.momentum = momentum;
        self
    }

    pub fn with_error_function(mut self, error_function: Box<dyn ErrorFunction<T>>) -> Self {
        self.error_function = error_function;
        self
    }

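    /// Same lazy buffer allocation as `IncrementalBackprop::initialize_deltas`:
    /// one flat delta `Vec` per non-input layer.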
    fn initialize_deltas(&mut self, network: &Network<T>) {
        if self.previous_weight_deltas.is_empty() {
            self.previous_weight_deltas = network
                .layers
                .iter()
                .skip(1)
                .map(|layer| {
                    let num_neurons = layer.neurons.len();
                    let num_connections = if layer.neurons.is_empty() {
                        0
                    } else {
                        layer.neurons[0].connections.len()
                    };
                    vec![T::zero(); num_neurons * num_connections]
                })
                .collect();
            self.previous_bias_deltas = network
                .layers
                .iter()
                .skip(1)
                .map(|layer| vec![T::zero(); layer.neurons.len()])
                .collect();
        }
    }
}

impl<T: Float + Send> TrainingAlgorithm<T> for BatchBackprop<T> {
    fn train_epoch(
        &mut self,
        network: &mut Network<T>,
        data: &TrainingData<T>,
    ) -> Result<T, TrainingError> {
        self.initialize_deltas(network);

        let mut total_error = T::zero();

        // Forward pass and per-pattern error accumulation over the full training set.
        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network.run(input);
            total_error = total_error + self.error_function.calculate(&output, desired_output);
        }

        Ok(total_error / T::from(data.inputs.len()).unwrap())
    }

    fn calculate_error(&self, network: &Network<T>, data: &TrainingData<T>) -> T {
        let mut total_error = T::zero();
        let mut network_clone = network.clone();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network_clone.run(input);
            total_error = total_error + self.error_function.calculate(&output, desired_output);
        }

        total_error / T::from(data.inputs.len()).unwrap()
    }

    fn count_bit_fails(
        &self,
        network: &Network<T>,
        data: &TrainingData<T>,
        bit_fail_limit: T,
    ) -> usize {
        let mut bit_fails = 0;
        let mut network_clone = network.clone();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network_clone.run(input);
            for (&actual, &desired) in output.iter().zip(desired_output.iter()) {
                if (actual - desired).abs() > bit_fail_limit {
                    bit_fails += 1;
                }
            }
        }

        bit_fails
    }

    fn save_state(&self) -> TrainingState<T> {
        let mut state = HashMap::new();
        state.insert("learning_rate".to_string(), vec![self.learning_rate]);
        state.insert("momentum".to_string(), vec![self.momentum]);

        TrainingState {
            epoch: 0,
            best_error: T::from(f32::MAX).unwrap(),
            algorithm_specific: state,
        }
    }

    fn restore_state(&mut self, state: TrainingState<T>) {
        if let Some(lr) = state.algorithm_specific.get("learning_rate") {
            if !lr.is_empty() {
                self.learning_rate = lr[0];
            }
        }
        if let Some(mom) = state.algorithm_specific.get("momentum") {
            if !mom.is_empty() {
                self.momentum = mom[0];
            }
        }
    }

    fn set_callback(&mut self, callback: TrainingCallback<T>) {
        self.callback = Some(callback);
    }

    fn call_callback(
        &mut self,
        epoch: usize,
        network: &Network<T>,
        data: &TrainingData<T>,
    ) -> bool {
        let error = self.calculate_error(network, data);
        if let Some(ref mut callback) = self.callback {
            callback(epoch, error)
        } else {
            true
        }
    }
}