ruv_fann/training/mod.rs

//! Training algorithms for neural networks
//!
//! This module implements various training algorithms including:
//! - Incremental (online) backpropagation
//! - Batch backpropagation
//! - RPROP (Resilient Propagation)
//! - Quickprop
//!
//! All training algorithms implement the `TrainingAlgorithm` trait for extensibility.
10
11use crate::Network;
12use num_traits::Float;
13use std::collections::HashMap;
14use thiserror::Error;
15
16// #[cfg(feature = "parallel")]
17// use rayon::prelude::*;
18
/// A supervised training set: parallel lists of input and expected-output
/// vectors, where `inputs[i]` pairs with `outputs[i]`.
#[derive(Debug, Clone)]
pub struct TrainingData<T: Float> {
    /// One input vector per training sample.
    pub inputs: Vec<Vec<T>>,
    /// Desired network output for the corresponding input.
    pub outputs: Vec<Vec<T>>,
}
24
/// Options for parallel training
///
/// NOTE(review): the rayon import above is commented out, so these options
/// may currently be inert — confirm against the algorithm implementations
/// in `backprop`/`rprop`/`quickprop`.
#[derive(Debug, Clone)]
pub struct ParallelTrainingOptions {
    /// Number of threads to use (0 = use all available cores)
    pub num_threads: usize,
    /// Batch size for parallel processing
    pub batch_size: usize,
    /// Whether to use parallel gradient computation
    pub parallel_gradients: bool,
    /// Whether to use parallel error calculation
    pub parallel_error_calc: bool,
}
37
impl Default for ParallelTrainingOptions {
    /// Defaults: all available cores, batch size 32, all parallelism enabled.
    fn default() -> Self {
        Self {
            num_threads: 0, // Use all available cores
            batch_size: 32,
            parallel_gradients: true,
            parallel_error_calc: true,
        }
    }
}
48
/// Error types for training operations
#[derive(Error, Debug)]
pub enum TrainingError {
    /// The supplied training data is malformed; the message holds details.
    #[error("Invalid training data: {0}")]
    InvalidData(String),

    /// The network's configuration is incompatible with the requested training.
    #[error("Network configuration error: {0}")]
    NetworkError(String),

    /// Training started but could not complete successfully.
    #[error("Training failed: {0}")]
    TrainingFailed(String),
}
61
/// Trait for error/loss functions
///
/// Implementors supply both the aggregate loss over an output vector
/// (`calculate`) and the per-component derivative used during
/// backpropagation (`derivative`).
pub trait ErrorFunction<T: Float>: Send + Sync {
    /// Calculate the error between actual and desired outputs
    fn calculate(&self, actual: &[T], desired: &[T]) -> T;

    /// Calculate the derivative of the error function with respect to a
    /// single actual output component.
    fn derivative(&self, actual: T, desired: T) -> T;
}
70
71/// Mean Squared Error (MSE)
72#[derive(Clone)]
73pub struct MseError;
74
75impl<T: Float> ErrorFunction<T> for MseError {
76    fn calculate(&self, actual: &[T], desired: &[T]) -> T {
77        let sum = actual
78            .iter()
79            .zip(desired.iter())
80            .map(|(&a, &d)| {
81                let diff = a - d;
82                diff * diff
83            })
84            .fold(T::zero(), |acc, x| acc + x);
85        sum / T::from(actual.len()).unwrap()
86    }
87
88    fn derivative(&self, actual: T, desired: T) -> T {
89        T::from(2.0).unwrap() * (actual - desired)
90    }
91}
92
93/// Mean Absolute Error (MAE)
94#[derive(Clone)]
95pub struct MaeError;
96
97impl<T: Float> ErrorFunction<T> for MaeError {
98    fn calculate(&self, actual: &[T], desired: &[T]) -> T {
99        let sum = actual
100            .iter()
101            .zip(desired.iter())
102            .map(|(&a, &d)| (a - d).abs())
103            .fold(T::zero(), |acc, x| acc + x);
104        sum / T::from(actual.len()).unwrap()
105    }
106
107    fn derivative(&self, actual: T, desired: T) -> T {
108        if actual > desired {
109            T::one()
110        } else if actual < desired {
111            -T::one()
112        } else {
113            T::zero()
114        }
115    }
116}
117
118/// Tanh Error Function
119#[derive(Clone)]
120pub struct TanhError;
121
122impl<T: Float> ErrorFunction<T> for TanhError {
123    fn calculate(&self, actual: &[T], desired: &[T]) -> T {
124        let sum = actual
125            .iter()
126            .zip(desired.iter())
127            .map(|(&a, &d)| {
128                let diff = a - d;
129                let tanh_diff = diff.tanh();
130                tanh_diff * tanh_diff
131            })
132            .fold(T::zero(), |acc, x| acc + x);
133        sum / T::from(actual.len()).unwrap()
134    }
135
136    fn derivative(&self, actual: T, desired: T) -> T {
137        let diff = actual - desired;
138        let tanh_diff = diff.tanh();
139        T::from(2.0).unwrap() * tanh_diff * (T::one() - tanh_diff * tanh_diff)
140    }
141}
142
/// Learning rate schedule trait
///
/// Takes `&mut self` so stateful schedules may update internal bookkeeping
/// on each call.
pub trait LearningRateSchedule<T: Float> {
    /// Learning rate to use for the given (0-based) epoch.
    fn get_rate(&mut self, epoch: usize) -> T;
}
147
148/// Exponential decay learning rate schedule
149pub struct ExponentialDecay<T: Float> {
150    initial_rate: T,
151    decay_rate: T,
152}
153
154impl<T: Float> ExponentialDecay<T> {
155    pub fn new(initial_rate: T, decay_rate: T) -> Self {
156        Self {
157            initial_rate,
158            decay_rate,
159        }
160    }
161}
162
163impl<T: Float> LearningRateSchedule<T> for ExponentialDecay<T> {
164    fn get_rate(&mut self, epoch: usize) -> T {
165        self.initial_rate * self.decay_rate.powi(epoch as i32)
166    }
167}
168
169/// Step decay learning rate schedule
170pub struct StepDecay<T: Float> {
171    initial_rate: T,
172    drop_rate: T,
173    epochs_per_drop: usize,
174}
175
176impl<T: Float> StepDecay<T> {
177    pub fn new(initial_rate: T, drop_rate: T, epochs_per_drop: usize) -> Self {
178        Self {
179            initial_rate,
180            drop_rate,
181            epochs_per_drop,
182        }
183    }
184}
185
186impl<T: Float> LearningRateSchedule<T> for StepDecay<T> {
187    fn get_rate(&mut self, epoch: usize) -> T {
188        let drops = epoch / self.epochs_per_drop;
189        self.initial_rate * self.drop_rate.powi(drops as i32)
190    }
191}
192
/// Training state that can be saved and restored
#[derive(Clone, Debug)]
pub struct TrainingState<T: Float> {
    // Number of epochs completed when the state was captured.
    pub epoch: usize,
    // Best (presumably lowest) error observed so far — confirm semantics
    // against the algorithm implementations.
    pub best_error: T,
    // Opaque per-algorithm buffers (e.g. previous gradients or step sizes)
    // keyed by name; the exact keys are defined by each algorithm.
    pub algorithm_specific: HashMap<String, Vec<T>>,
}
200
/// Stop criteria trait
///
/// Implementors decide whether training should terminate after a given
/// epoch, based on the trainer's current error/bit-fail metrics.
pub trait StopCriteria<T: Float> {
    /// Returns `true` when training should stop.
    fn should_stop(
        &self,
        trainer: &dyn TrainingAlgorithm<T>,
        network: &Network<T>,
        data: &TrainingData<T>,
        epoch: usize,
    ) -> bool;
}
211
212/// MSE-based stop criteria
213pub struct MseStopCriteria<T: Float> {
214    pub target_error: T,
215}
216
217impl<T: Float> StopCriteria<T> for MseStopCriteria<T> {
218    fn should_stop(
219        &self,
220        trainer: &dyn TrainingAlgorithm<T>,
221        network: &Network<T>,
222        data: &TrainingData<T>,
223        _epoch: usize,
224    ) -> bool {
225        let error = trainer.calculate_error(network, data);
226        error <= self.target_error
227    }
228}
229
230/// Bit fail based stop criteria
231pub struct BitFailStopCriteria<T: Float> {
232    pub target_bit_fail: usize,
233    pub bit_fail_limit: T,
234}
235
236impl<T: Float> StopCriteria<T> for BitFailStopCriteria<T> {
237    fn should_stop(
238        &self,
239        trainer: &dyn TrainingAlgorithm<T>,
240        network: &Network<T>,
241        data: &TrainingData<T>,
242        _epoch: usize,
243    ) -> bool {
244        let bit_fails = trainer.count_bit_fails(network, data, self.bit_fail_limit);
245        bit_fails <= self.target_bit_fail
246    }
247}
248
/// Callback function type for training progress, invoked with the current
/// epoch and error. Returns a `bool` — presumably `false` requests early
/// stop; confirm against the trainers' `call_callback` implementations.
pub type TrainingCallback<T> = Box<dyn FnMut(usize, T) -> bool + Send>;
251
/// Main trait for training algorithms
///
/// Implemented by `IncrementalBackprop`, `BatchBackprop`, `Rprop` and
/// `Quickprop` (re-exported below). All methods operate on a `Network`
/// plus a `TrainingData` set.
pub trait TrainingAlgorithm<T: Float>: Send {
    /// Train for one epoch over the full data set; returns the epoch error
    /// or a `TrainingError` on failure.
    fn train_epoch(
        &mut self,
        network: &mut Network<T>,
        data: &TrainingData<T>,
    ) -> Result<T, TrainingError>;

    /// Calculate the current error of `network` over `data` without training.
    fn calculate_error(&self, network: &Network<T>, data: &TrainingData<T>) -> T;

    /// Count output values whose error exceeds `bit_fail_limit`.
    fn count_bit_fails(
        &self,
        network: &Network<T>,
        data: &TrainingData<T>,
        bit_fail_limit: T,
    ) -> usize;

    /// Snapshot the algorithm's internal state for later restoration.
    fn save_state(&self) -> TrainingState<T>;

    /// Restore a previously saved training state.
    fn restore_state(&mut self, state: TrainingState<T>);

    /// Set a progress callback (see [`TrainingCallback`]).
    fn set_callback(&mut self, callback: TrainingCallback<T>);

    /// Invoke the callback, if one is set; the returned `bool` forwards the
    /// callback's result.
    fn call_callback(&mut self, epoch: usize, network: &Network<T>, data: &TrainingData<T>)
        -> bool;
}
285
286// Module declarations for specific algorithms
287mod backprop;
288mod quickprop;
289mod rprop;
290
291// Re-export main types
292pub use backprop::{BatchBackprop, IncrementalBackprop};
293pub use quickprop::Quickprop;
294pub use rprop::Rprop;
295
296/// Helper functions for forward propagation and gradient calculation
297pub(crate) mod helpers {
298    use super::*;
299
    /// Simple network representation for training algorithms
    ///
    /// Weights for each layer are stored flat, row-per-neuron: neuron `n`'s
    /// incoming weights occupy `weights[l][n*prev .. (n+1)*prev]` (see
    /// `forward_propagate`). `weights[0]`/`biases[0]` belong to the first
    /// hidden layer (layer index 1 in `layer_sizes`).
    #[derive(Debug, Clone)]
    pub struct SimpleNetwork<T: Float> {
        // Regular (non-bias) neuron count per layer, input layer included.
        pub layer_sizes: Vec<usize>,
        // Flattened incoming weights, one Vec per non-input layer.
        pub weights: Vec<Vec<T>>,
        // One bias per neuron, one Vec per non-input layer.
        pub biases: Vec<Vec<T>>,
    }
307
    /// Convert a real Network to a simplified representation for training
    ///
    /// Walks each non-input layer and, for every regular (non-bias) neuron,
    /// splits its connection list into a bias term and a flat weight row.
    ///
    /// NOTE(review): this assumes each neuron's connection 0 is the bias
    /// connection and the remainder are weights ordered to match the
    /// previous layer's neurons — verify against the `Network` builder.
    pub fn network_to_simple<T: Float + Default>(network: &Network<T>) -> SimpleNetwork<T> {
        let layer_sizes: Vec<usize> = network
            .layers
            .iter()
            .map(|layer| layer.num_regular_neurons())
            .collect();

        // Extract weights and biases from the complex structure
        let mut weights = Vec::new();
        let mut biases = Vec::new();

        // Layer 0 is the input layer and has no incoming weights.
        for layer_idx in 1..network.layers.len() {
            let current_layer = &network.layers[layer_idx];
            let _prev_layer_size = network.layers[layer_idx - 1].size(); // Include bias neurons

            let mut layer_weights = Vec::new();
            let mut layer_biases = Vec::new();

            for neuron in &current_layer.neurons {
                if !neuron.is_bias {
                    // Extract bias (connection index 0 should be bias)
                    let bias = if !neuron.connections.is_empty() {
                        neuron.connections[0].weight
                    } else {
                        T::zero()
                    };
                    layer_biases.push(bias);

                    // Extract weights (skip bias connection)
                    for connection in neuron.connections.iter().skip(1) {
                        layer_weights.push(connection.weight);
                    }
                }
            }

            weights.push(layer_weights);
            biases.push(layer_biases);
        }

        SimpleNetwork {
            layer_sizes,
            weights,
            biases,
        }
    }
354
    /// Apply weight and bias updates back to the real Network
    ///
    /// Inverse of `network_to_simple`: adds `bias_updates[l][n]` to neuron
    /// `n`'s connection 0 and walks `weight_updates[l]` in the same flat
    /// order the extraction produced.
    ///
    /// NOTE(review): indexes the update slices without bounds checks, so it
    /// will panic if their shapes don't match the network — callers must
    /// pass updates produced for this exact topology.
    pub fn apply_updates_to_network<T: Float>(
        network: &mut Network<T>,
        weight_updates: &[Vec<T>],
        bias_updates: &[Vec<T>],
    ) {
        // Layer 0 is the input layer; updates are indexed from layer 1.
        for layer_idx in 1..network.layers.len() {
            let current_layer = &mut network.layers[layer_idx];
            let weight_layer_idx = layer_idx - 1;

            // neuron_idx counts only regular neurons; weight_idx runs flat
            // across the whole layer, mirroring network_to_simple's order.
            let mut neuron_idx = 0;
            let mut weight_idx = 0;

            for neuron in &mut current_layer.neurons {
                if !neuron.is_bias {
                    // Update bias (connection index 0)
                    if !neuron.connections.is_empty() {
                        neuron.connections[0].weight = neuron.connections[0].weight
                            + bias_updates[weight_layer_idx][neuron_idx];
                    }

                    // Update weights (skip bias connection)
                    for connection in neuron.connections.iter_mut().skip(1) {
                        connection.weight =
                            connection.weight + weight_updates[weight_layer_idx][weight_idx];
                        weight_idx += 1;
                    }

                    neuron_idx += 1;
                }
            }
        }
    }
388
389    /// Activation function that works with our simplified representation
390    pub fn sigmoid<T: Float>(x: T) -> T {
391        T::one() / (T::one() + (-x).exp())
392    }
393
394    /// Sigmoid derivative
395    pub fn sigmoid_derivative<T: Float>(output: T) -> T {
396        output * (T::one() - output)
397    }
398
399    /// Forward propagation through the simplified network
400    pub fn forward_propagate<T: Float>(network: &SimpleNetwork<T>, input: &[T]) -> Vec<Vec<T>> {
401        let mut activations = vec![input.to_vec()];
402
403        for layer_idx in 1..network.layer_sizes.len() {
404            let prev_activations = &activations[layer_idx - 1];
405            let weights = &network.weights[layer_idx - 1];
406            let biases = &network.biases[layer_idx - 1];
407
408            let mut layer_activations = Vec::with_capacity(network.layer_sizes[layer_idx]);
409
410            for neuron_idx in 0..network.layer_sizes[layer_idx] {
411                let mut sum = biases[neuron_idx];
412                let weight_start = neuron_idx * prev_activations.len();
413
414                for (input_idx, &input_val) in prev_activations.iter().enumerate() {
415                    if weight_start + input_idx < weights.len() {
416                        sum = sum + input_val * weights[weight_start + input_idx];
417                    }
418                }
419
420                layer_activations.push(sigmoid(sum));
421            }
422
423            activations.push(layer_activations);
424        }
425
426        activations
427    }
428
    /// Calculate gradients using backpropagation on simplified network
    ///
    /// `activations` must be the full per-layer output of
    /// `forward_propagate` for one sample. Returns `(weight_gradients,
    /// bias_gradients)` shaped exactly like `network.weights` /
    /// `network.biases`. Hard-codes the sigmoid derivative, matching
    /// `forward_propagate`'s fixed sigmoid activation.
    pub fn calculate_gradients<T: Float>(
        network: &SimpleNetwork<T>,
        activations: &[Vec<T>],
        desired_output: &[T],
        error_function: &dyn ErrorFunction<T>,
    ) -> (Vec<Vec<T>>, Vec<Vec<T>>) {
        // Zero-initialised gradient buffers mirroring the network's shape.
        let mut weight_gradients = network
            .weights
            .iter()
            .map(|w| vec![T::zero(); w.len()])
            .collect::<Vec<_>>();
        let mut bias_gradients = network
            .biases
            .iter()
            .map(|b| vec![T::zero(); b.len()])
            .collect::<Vec<_>>();

        // Calculate output layer errors: dE/dout * sigmoid'(out) per neuron.
        let output_idx = activations.len() - 1;
        let mut errors = vec![];

        let output_errors: Vec<T> = activations[output_idx]
            .iter()
            .zip(desired_output.iter())
            .map(|(&actual, &desired)| {
                error_function.derivative(actual, desired) * sigmoid_derivative(actual)
            })
            .collect();

        errors.push(output_errors);

        // Backpropagate errors through hidden layers, deepest first.
        // Invariant: `errors[0]` always holds the errors of the layer just
        // above the one being processed, because each finished layer is
        // inserted at the front below.
        for layer_idx in (1..network.layer_sizes.len() - 1).rev() {
            let mut layer_errors = vec![T::zero(); network.layer_sizes[layer_idx]];

            for neuron_idx in 0..network.layer_sizes[layer_idx] {
                let mut error_sum = T::zero();

                // Sum weighted errors from next layer; weights[layer_idx]
                // connects layer_idx -> layer_idx + 1, flat row-per-neuron.
                for next_neuron_idx in 0..network.layer_sizes[layer_idx + 1] {
                    let weight_idx = next_neuron_idx * network.layer_sizes[layer_idx] + neuron_idx;
                    if weight_idx < network.weights[layer_idx].len() {
                        error_sum = error_sum
                            + errors[0][next_neuron_idx] * network.weights[layer_idx][weight_idx];
                    }
                }

                layer_errors[neuron_idx] =
                    error_sum * sigmoid_derivative(activations[layer_idx][neuron_idx]);
            }

            errors.insert(0, layer_errors);
        }

        // Calculate gradients: after the loop above, `errors[l]` aligns
        // with weight/bias layer `l` (input layer has no error entry).
        for layer_idx in 0..network.weights.len() {
            let prev_activations = &activations[layer_idx];
            let layer_errors = &errors[layer_idx];

            for neuron_idx in 0..layer_errors.len() {
                // Bias gradient: just the neuron's error (bias input is 1).
                bias_gradients[layer_idx][neuron_idx] = layer_errors[neuron_idx];

                // Weight gradients: error * incoming activation, written
                // into the same flat layout as the weights themselves.
                let weight_start = neuron_idx * prev_activations.len();
                for (input_idx, &activation) in prev_activations.iter().enumerate() {
                    if weight_start + input_idx < weight_gradients[layer_idx].len() {
                        weight_gradients[layer_idx][weight_start + input_idx] =
                            layer_errors[neuron_idx] * activation;
                    }
                }
            }
        }

        (weight_gradients, bias_gradients)
    }
506}
507
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-check the sigmoid: 0 maps to 0.5 and the tails saturate.
    #[test]
    fn test_sigmoid() {
        use helpers::sigmoid;

        let midpoint: f64 = sigmoid(0.0);
        assert!((midpoint - 0.5).abs() < 1e-6);
        assert!(sigmoid(10.0) > 0.99);
        assert!(sigmoid(-10.0) < 0.01);
    }
}