1use crate::Network;
12use num_traits::Float;
13use std::collections::HashMap;
14use thiserror::Error;
15
/// A set of input/output training pairs.
///
/// `inputs[i]` is the input vector for sample `i` and `outputs[i]` is the
/// corresponding desired output vector. The two lists are expected to have
/// the same length — not enforced here; TODO confirm callers validate.
#[derive(Debug, Clone)]
pub struct TrainingData<T: Float> {
    /// One input vector per training sample.
    pub inputs: Vec<Vec<T>>,
    /// One desired-output vector per training sample, parallel to `inputs`.
    pub outputs: Vec<Vec<T>>,
}
24
/// Options controlling how training work is parallelized.
#[derive(Debug, Clone)]
pub struct ParallelTrainingOptions {
    /// Number of worker threads. 0 presumably means "auto / use all
    /// available" — TODO confirm against the consuming trainer.
    pub num_threads: usize,
    /// Number of samples processed per mini-batch.
    pub batch_size: usize,
    /// Whether gradient computation runs in parallel.
    pub parallel_gradients: bool,
    /// Whether error calculation runs in parallel.
    pub parallel_error_calc: bool,
}
37
38impl Default for ParallelTrainingOptions {
39 fn default() -> Self {
40 Self {
41 num_threads: 0, batch_size: 32,
43 parallel_gradients: true,
44 parallel_error_calc: true,
45 }
46 }
47}
48
/// Errors that can occur while configuring or running training.
#[derive(Error, Debug)]
pub enum TrainingError {
    /// The provided training data is invalid; the message describes why.
    #[error("Invalid training data: {0}")]
    InvalidData(String),

    /// The network is not compatible with the requested training run.
    #[error("Network configuration error: {0}")]
    NetworkError(String),

    /// Training started but could not complete.
    #[error("Training failed: {0}")]
    TrainingFailed(String),
}
61
/// A differentiable error (loss) function used during training.
///
/// Bounded `Send + Sync` so implementations can be shared across
/// training threads.
pub trait ErrorFunction<T: Float>: Send + Sync {
    /// Computes the aggregate error between the `actual` network outputs
    /// and the `desired` target values.
    fn calculate(&self, actual: &[T], desired: &[T]) -> T;

    /// Computes the partial derivative of the error with respect to a
    /// single `actual` output value.
    fn derivative(&self, actual: T, desired: T) -> T;
}
70
71#[derive(Clone)]
73pub struct MseError;
74
75impl<T: Float> ErrorFunction<T> for MseError {
76 fn calculate(&self, actual: &[T], desired: &[T]) -> T {
77 let sum = actual
78 .iter()
79 .zip(desired.iter())
80 .map(|(&a, &d)| {
81 let diff = a - d;
82 diff * diff
83 })
84 .fold(T::zero(), |acc, x| acc + x);
85 sum / T::from(actual.len()).unwrap()
86 }
87
88 fn derivative(&self, actual: T, desired: T) -> T {
89 T::from(2.0).unwrap() * (actual - desired)
90 }
91}
92
93#[derive(Clone)]
95pub struct MaeError;
96
97impl<T: Float> ErrorFunction<T> for MaeError {
98 fn calculate(&self, actual: &[T], desired: &[T]) -> T {
99 let sum = actual
100 .iter()
101 .zip(desired.iter())
102 .map(|(&a, &d)| (a - d).abs())
103 .fold(T::zero(), |acc, x| acc + x);
104 sum / T::from(actual.len()).unwrap()
105 }
106
107 fn derivative(&self, actual: T, desired: T) -> T {
108 if actual > desired {
109 T::one()
110 } else if actual < desired {
111 -T::one()
112 } else {
113 T::zero()
114 }
115 }
116}
117
118#[derive(Clone)]
120pub struct TanhError;
121
122impl<T: Float> ErrorFunction<T> for TanhError {
123 fn calculate(&self, actual: &[T], desired: &[T]) -> T {
124 let sum = actual
125 .iter()
126 .zip(desired.iter())
127 .map(|(&a, &d)| {
128 let diff = a - d;
129 let tanh_diff = diff.tanh();
130 tanh_diff * tanh_diff
131 })
132 .fold(T::zero(), |acc, x| acc + x);
133 sum / T::from(actual.len()).unwrap()
134 }
135
136 fn derivative(&self, actual: T, desired: T) -> T {
137 let diff = actual - desired;
138 let tanh_diff = diff.tanh();
139 T::from(2.0).unwrap() * tanh_diff * (T::one() - tanh_diff * tanh_diff)
140 }
141}
142
/// A schedule mapping the current epoch to a learning rate.
pub trait LearningRateSchedule<T: Float> {
    /// Returns the learning rate to use for `epoch` (0-based).
    /// Takes `&mut self` so stateful schedules can update internal state.
    fn get_rate(&mut self, epoch: usize) -> T;
}
147
148pub struct ExponentialDecay<T: Float> {
150 initial_rate: T,
151 decay_rate: T,
152}
153
154impl<T: Float> ExponentialDecay<T> {
155 pub fn new(initial_rate: T, decay_rate: T) -> Self {
156 Self {
157 initial_rate,
158 decay_rate,
159 }
160 }
161}
162
163impl<T: Float> LearningRateSchedule<T> for ExponentialDecay<T> {
164 fn get_rate(&mut self, epoch: usize) -> T {
165 self.initial_rate * self.decay_rate.powi(epoch as i32)
166 }
167}
168
169pub struct StepDecay<T: Float> {
171 initial_rate: T,
172 drop_rate: T,
173 epochs_per_drop: usize,
174}
175
176impl<T: Float> StepDecay<T> {
177 pub fn new(initial_rate: T, drop_rate: T, epochs_per_drop: usize) -> Self {
178 Self {
179 initial_rate,
180 drop_rate,
181 epochs_per_drop,
182 }
183 }
184}
185
186impl<T: Float> LearningRateSchedule<T> for StepDecay<T> {
187 fn get_rate(&mut self, epoch: usize) -> T {
188 let drops = epoch / self.epochs_per_drop;
189 self.initial_rate * self.drop_rate.powi(drops as i32)
190 }
191}
192
/// Snapshot of a training algorithm's progress, produced by
/// `TrainingAlgorithm::save_state` and consumed by `restore_state`.
#[derive(Clone, Debug)]
pub struct TrainingState<T: Float> {
    /// Epoch the algorithm has reached.
    pub epoch: usize,
    /// Lowest error observed so far.
    pub best_error: T,
    /// Algorithm-specific state vectors keyed by name; contents are
    /// defined by each `TrainingAlgorithm` implementation.
    pub algorithm_specific: HashMap<String, Vec<T>>,
}
200
/// A pluggable condition deciding when training should terminate.
pub trait StopCriteria<T: Float> {
    /// Returns `true` when training should stop at the given `epoch`.
    fn should_stop(
        &self,
        trainer: &dyn TrainingAlgorithm<T>,
        network: &Network<T>,
        data: &TrainingData<T>,
        epoch: usize,
    ) -> bool;
}
211
212pub struct MseStopCriteria<T: Float> {
214 pub target_error: T,
215}
216
217impl<T: Float> StopCriteria<T> for MseStopCriteria<T> {
218 fn should_stop(
219 &self,
220 trainer: &dyn TrainingAlgorithm<T>,
221 network: &Network<T>,
222 data: &TrainingData<T>,
223 _epoch: usize,
224 ) -> bool {
225 let error = trainer.calculate_error(network, data);
226 error <= self.target_error
227 }
228}
229
230pub struct BitFailStopCriteria<T: Float> {
232 pub target_bit_fail: usize,
233 pub bit_fail_limit: T,
234}
235
236impl<T: Float> StopCriteria<T> for BitFailStopCriteria<T> {
237 fn should_stop(
238 &self,
239 trainer: &dyn TrainingAlgorithm<T>,
240 network: &Network<T>,
241 data: &TrainingData<T>,
242 _epoch: usize,
243 ) -> bool {
244 let bit_fails = trainer.count_bit_fails(network, data, self.bit_fail_limit);
245 bit_fails <= self.target_bit_fail
246 }
247}
248
/// Progress callback invoked with `(epoch, error)`. The `bool` return
/// presumably signals whether training should continue — TODO confirm
/// against the algorithm implementations.
pub type TrainingCallback<T> = Box<dyn FnMut(usize, T) -> bool + Send>;

/// Common interface implemented by every training algorithm in this
/// module (backprop, Quickprop, RPROP).
pub trait TrainingAlgorithm<T: Float>: Send {
    /// Runs one training epoch over `data`, mutating the network's
    /// weights, and returns the epoch's error.
    fn train_epoch(
        &mut self,
        network: &mut Network<T>,
        data: &TrainingData<T>,
    ) -> Result<T, TrainingError>;

    /// Computes the current error of `network` on `data` without training.
    fn calculate_error(&self, network: &Network<T>, data: &TrainingData<T>) -> T;

    /// Counts outputs deviating from their target by more than
    /// `bit_fail_limit`; exact semantics are defined by each
    /// implementation.
    fn count_bit_fails(
        &self,
        network: &Network<T>,
        data: &TrainingData<T>,
        bit_fail_limit: T,
    ) -> usize;

    /// Captures the algorithm's internal state for later `restore_state`.
    fn save_state(&self) -> TrainingState<T>;

    /// Restores state previously produced by `save_state`.
    fn restore_state(&mut self, state: TrainingState<T>);

    /// Installs the progress callback (see `TrainingCallback`).
    fn set_callback(&mut self, callback: TrainingCallback<T>);

    /// Invokes the installed callback for `epoch` and returns its result.
    fn call_callback(&mut self, epoch: usize, network: &Network<T>, data: &TrainingData<T>)
        -> bool;
}
285
286mod backprop;
288mod quickprop;
289mod rprop;
290
291pub use backprop::{BatchBackprop, IncrementalBackprop};
293pub use quickprop::Quickprop;
294pub use rprop::Rprop;
295
296pub(crate) mod helpers {
298 use super::*;
299
    /// A flattened snapshot of a `Network` used by the training helpers:
    /// per-layer weights and biases stored as contiguous vectors.
    #[derive(Debug, Clone)]
    pub struct SimpleNetwork<T: Float> {
        /// Number of regular (non-bias) neurons in each layer, input first.
        pub layer_sizes: Vec<usize>,
        /// One entry per connection layer: incoming weights, laid out
        /// row-major by target neuron (see `forward_propagate`).
        pub weights: Vec<Vec<T>>,
        /// One entry per connection layer: one bias per target neuron.
        pub biases: Vec<Vec<T>>,
    }
307
308 pub fn network_to_simple<T: Float + Default>(network: &Network<T>) -> SimpleNetwork<T> {
310 let layer_sizes: Vec<usize> = network
311 .layers
312 .iter()
313 .map(|layer| layer.num_regular_neurons())
314 .collect();
315
316 let mut weights = Vec::new();
318 let mut biases = Vec::new();
319
320 for layer_idx in 1..network.layers.len() {
321 let current_layer = &network.layers[layer_idx];
322 let _prev_layer_size = network.layers[layer_idx - 1].size(); let mut layer_weights = Vec::new();
325 let mut layer_biases = Vec::new();
326
327 for neuron in ¤t_layer.neurons {
328 if !neuron.is_bias {
329 let bias = if !neuron.connections.is_empty() {
331 neuron.connections[0].weight
332 } else {
333 T::zero()
334 };
335 layer_biases.push(bias);
336
337 for connection in neuron.connections.iter().skip(1) {
339 layer_weights.push(connection.weight);
340 }
341 }
342 }
343
344 weights.push(layer_weights);
345 biases.push(layer_biases);
346 }
347
348 SimpleNetwork {
349 layer_sizes,
350 weights,
351 biases,
352 }
353 }
354
    /// Adds `weight_updates` and `bias_updates` (deltas, not replacement
    /// values) into the network's connection weights.
    ///
    /// Layout must match `network_to_simple`: `weight_updates[l]` and
    /// `bias_updates[l]` correspond to network layer `l + 1`; for each
    /// non-bias neuron, connection 0 is its bias and the remaining
    /// connections are regular weights, consumed in flattened order.
    pub fn apply_updates_to_network<T: Float>(
        network: &mut Network<T>,
        weight_updates: &[Vec<T>],
        bias_updates: &[Vec<T>],
    ) {
        for layer_idx in 1..network.layers.len() {
            let current_layer = &mut network.layers[layer_idx];
            // Update vectors are 0-indexed while layers start at 1.
            let weight_layer_idx = layer_idx - 1;

            // Flat cursors into this layer's bias/weight update vectors.
            let mut neuron_idx = 0;
            let mut weight_idx = 0;

            for neuron in &mut current_layer.neurons {
                if !neuron.is_bias {
                    // Connection 0 carries the bias weight.
                    if !neuron.connections.is_empty() {
                        neuron.connections[0].weight = neuron.connections[0].weight
                            + bias_updates[weight_layer_idx][neuron_idx];
                    }

                    // Remaining connections are regular weights; `weight_idx`
                    // advances across all neurons of the layer.
                    for connection in neuron.connections.iter_mut().skip(1) {
                        connection.weight =
                            connection.weight + weight_updates[weight_layer_idx][weight_idx];
                        weight_idx += 1;
                    }

                    neuron_idx += 1;
                }
            }
        }
    }
388
389 pub fn sigmoid<T: Float>(x: T) -> T {
391 T::one() / (T::one() + (-x).exp())
392 }
393
394 pub fn sigmoid_derivative<T: Float>(output: T) -> T {
396 output * (T::one() - output)
397 }
398
    /// Runs a forward pass through `network`, returning the activation
    /// vector of every layer (index 0 is the input itself).
    ///
    /// Weights for a layer are laid out row-major: the weights feeding
    /// neuron `n` start at `n * prev_layer_len`. Indices past the end of
    /// the weight vector are silently skipped — NOTE(review): confirm this
    /// is intended handling of ragged layers rather than masking a
    /// construction bug.
    pub fn forward_propagate<T: Float>(network: &SimpleNetwork<T>, input: &[T]) -> Vec<Vec<T>> {
        let mut activations = vec![input.to_vec()];

        for layer_idx in 1..network.layer_sizes.len() {
            let prev_activations = &activations[layer_idx - 1];
            let weights = &network.weights[layer_idx - 1];
            let biases = &network.biases[layer_idx - 1];

            let mut layer_activations = Vec::with_capacity(network.layer_sizes[layer_idx]);

            for neuron_idx in 0..network.layer_sizes[layer_idx] {
                // Start from the neuron's bias, then accumulate weighted inputs.
                let mut sum = biases[neuron_idx];
                let weight_start = neuron_idx * prev_activations.len();

                for (input_idx, &input_val) in prev_activations.iter().enumerate() {
                    if weight_start + input_idx < weights.len() {
                        sum = sum + input_val * weights[weight_start + input_idx];
                    }
                }

                // Every layer uses the logistic sigmoid activation.
                layer_activations.push(sigmoid(sum));
            }

            activations.push(layer_activations);
        }

        activations
    }
428
    /// Backpropagates through `network` for one sample, returning
    /// `(weight_gradients, bias_gradients)` shaped exactly like
    /// `network.weights` / `network.biases`.
    ///
    /// `activations` must come from `forward_propagate` for the same input;
    /// the sigmoid derivative used here assumes that activation function.
    pub fn calculate_gradients<T: Float>(
        network: &SimpleNetwork<T>,
        activations: &[Vec<T>],
        desired_output: &[T],
        error_function: &dyn ErrorFunction<T>,
    ) -> (Vec<Vec<T>>, Vec<Vec<T>>) {
        // Zero-initialized gradient buffers mirroring the network layout.
        let mut weight_gradients = network
            .weights
            .iter()
            .map(|w| vec![T::zero(); w.len()])
            .collect::<Vec<_>>();
        let mut bias_gradients = network
            .biases
            .iter()
            .map(|b| vec![T::zero(); b.len()])
            .collect::<Vec<_>>();

        let output_idx = activations.len() - 1;
        // `errors` holds the per-neuron delta for each non-input layer.
        // Layers are inserted at the front as we walk backwards, so during
        // the backward pass `errors[0]` is always the most recently
        // computed layer — i.e. the "next" layer relative to the one being
        // processed.
        let mut errors = vec![];

        // Output-layer delta: dE/d(activation) * sigmoid'(activation).
        let output_errors: Vec<T> = activations[output_idx]
            .iter()
            .zip(desired_output.iter())
            .map(|(&actual, &desired)| {
                error_function.derivative(actual, desired) * sigmoid_derivative(actual)
            })
            .collect();

        errors.push(output_errors);

        // Hidden layers, from the last hidden layer back to the first.
        for layer_idx in (1..network.layer_sizes.len() - 1).rev() {
            let mut layer_errors = vec![T::zero(); network.layer_sizes[layer_idx]];

            for neuron_idx in 0..network.layer_sizes[layer_idx] {
                let mut error_sum = T::zero();

                // Sum next-layer deltas weighted by the connections leaving
                // this neuron (row-major layout; out-of-range indices are
                // skipped, matching `forward_propagate`).
                for next_neuron_idx in 0..network.layer_sizes[layer_idx + 1] {
                    let weight_idx = next_neuron_idx * network.layer_sizes[layer_idx] + neuron_idx;
                    if weight_idx < network.weights[layer_idx].len() {
                        error_sum = error_sum
                            + errors[0][next_neuron_idx] * network.weights[layer_idx][weight_idx];
                    }
                }

                layer_errors[neuron_idx] =
                    error_sum * sigmoid_derivative(activations[layer_idx][neuron_idx]);
            }

            errors.insert(0, layer_errors);
        }

        // Gradients: dE/dw = delta * upstream activation; dE/db = delta.
        // `errors[layer_idx]` holds the deltas of network layer
        // `layer_idx + 1`, i.e. the targets of weight layer `layer_idx`.
        for layer_idx in 0..network.weights.len() {
            let prev_activations = &activations[layer_idx];
            let layer_errors = &errors[layer_idx];

            for neuron_idx in 0..layer_errors.len() {
                bias_gradients[layer_idx][neuron_idx] = layer_errors[neuron_idx];

                let weight_start = neuron_idx * prev_activations.len();
                for (input_idx, &activation) in prev_activations.iter().enumerate() {
                    if weight_start + input_idx < weight_gradients[layer_idx].len() {
                        weight_gradients[layer_idx][weight_start + input_idx] =
                            layer_errors[neuron_idx] * activation;
                    }
                }
            }
        }

        (weight_gradients, bias_gradients)
    }
506}
507
#[cfg(test)]
mod tests {
    use super::helpers::sigmoid;

    /// Midpoint and saturation sanity checks for the logistic function.
    #[test]
    fn test_sigmoid() {
        let midpoint: f64 = sigmoid(0.0);
        assert!((midpoint - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }
}