// ruv_fann/training/backprop.rs

//! Backpropagation training algorithms

use super::*;
use num_traits::Float;
use std::collections::HashMap;
/// Incremental (online) backpropagation
/// Updates weights after each training pattern
pub struct IncrementalBackprop<T: Float + Send + Default> {
    // Step size multiplied into each gradient when forming a weight delta.
    learning_rate: T,
    // Fraction of the previous delta carried into the next one (zero by default).
    momentum: T,
    // Pluggable loss function; `new` installs MSE as the default.
    error_function: Box<dyn ErrorFunction<T>>,
    // Per-layer flattened weight deltas from the last update, kept for the momentum term.
    previous_weight_deltas: Vec<Vec<T>>,
    // Per-layer bias deltas from the last update, kept for the momentum term.
    previous_bias_deltas: Vec<Vec<T>>,
    // Optional progress callback invoked with (epoch, error) via `call_callback`.
    callback: Option<TrainingCallback<T>>,
}
17
18impl<T: Float + Send + Default> IncrementalBackprop<T> {
19    pub fn new(learning_rate: T) -> Self {
20        Self {
21            learning_rate,
22            momentum: T::zero(),
23            error_function: Box::new(MseError),
24            previous_weight_deltas: Vec::new(),
25            previous_bias_deltas: Vec::new(),
26            callback: None,
27        }
28    }
29
30    pub fn with_momentum(mut self, momentum: T) -> Self {
31        self.momentum = momentum;
32        self
33    }
34
35    pub fn with_error_function(mut self, error_function: Box<dyn ErrorFunction<T>>) -> Self {
36        self.error_function = error_function;
37        self
38    }
39
40    fn initialize_deltas(&mut self, network: &Network<T>) {
41        if self.previous_weight_deltas.is_empty() {
42            self.previous_weight_deltas = network
43                .layers
44                .iter()
45                .skip(1) // Skip input layer
46                .map(|layer| {
47                    let num_neurons = layer.neurons.len();
48                    let num_connections = if layer.neurons.is_empty() {
49                        0
50                    } else {
51                        layer.neurons[0].connections.len()
52                    };
53                    vec![T::zero(); num_neurons * num_connections]
54                })
55                .collect();
56            self.previous_bias_deltas = network
57                .layers
58                .iter()
59                .skip(1) // Skip input layer
60                .map(|layer| vec![T::zero(); layer.neurons.len()])
61                .collect();
62        }
63    }
64}
65
66impl<T: Float + Send + Default> TrainingAlgorithm<T> for IncrementalBackprop<T> {
67    fn train_epoch(
68        &mut self,
69        network: &mut Network<T>,
70        data: &TrainingData<T>,
71    ) -> Result<T, TrainingError> {
72        use super::helpers::*;
73
74        self.initialize_deltas(network);
75
76        let mut total_error = T::zero();
77
78        // Convert network to simplified form for easier manipulation
79        let simple_network = network_to_simple(network);
80
81        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
82            // Forward propagation to get all layer activations
83            let activations = forward_propagate(&simple_network, input);
84
85            // Get output from last layer
86            let output = &activations[activations.len() - 1];
87
88            // Calculate error
89            total_error = total_error + self.error_function.calculate(output, desired_output);
90
91            // Calculate gradients using backpropagation
92            let (weight_gradients, bias_gradients) = calculate_gradients(
93                &simple_network,
94                &activations,
95                desired_output,
96                self.error_function.as_ref(),
97            );
98
99            // Update weights and biases immediately (incremental/online learning)
100            // Apply momentum
101            for layer_idx in 0..weight_gradients.len() {
102                // Update weight deltas with momentum
103                for (i, &grad) in weight_gradients[layer_idx].iter().enumerate() {
104                    let delta = self.learning_rate * grad
105                        + self.momentum * self.previous_weight_deltas[layer_idx][i];
106                    self.previous_weight_deltas[layer_idx][i] = delta;
107                }
108
109                // Update bias deltas with momentum
110                for (i, &grad) in bias_gradients[layer_idx].iter().enumerate() {
111                    let delta = self.learning_rate * grad
112                        + self.momentum * self.previous_bias_deltas[layer_idx][i];
113                    self.previous_bias_deltas[layer_idx][i] = delta;
114                }
115            }
116
117            // Apply the updates to the actual network
118            apply_updates_to_network(
119                network,
120                &self.previous_weight_deltas,
121                &self.previous_bias_deltas,
122            );
123        }
124
125        Ok(total_error / T::from(data.inputs.len()).unwrap())
126    }
127
128    fn calculate_error(&self, network: &Network<T>, data: &TrainingData<T>) -> T {
129        let mut total_error = T::zero();
130        let mut network_clone = network.clone();
131
132        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
133            let output = network_clone.run(input);
134            total_error = total_error + self.error_function.calculate(&output, desired_output);
135        }
136
137        total_error / T::from(data.inputs.len()).unwrap()
138    }
139
140    fn count_bit_fails(
141        &self,
142        network: &Network<T>,
143        data: &TrainingData<T>,
144        bit_fail_limit: T,
145    ) -> usize {
146        let mut bit_fails = 0;
147        let mut network_clone = network.clone();
148
149        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
150            let output = network_clone.run(input);
151            for (&actual, &desired) in output.iter().zip(desired_output.iter()) {
152                if (actual - desired).abs() > bit_fail_limit {
153                    bit_fails += 1;
154                }
155            }
156        }
157
158        bit_fails
159    }
160
161    fn save_state(&self) -> TrainingState<T> {
162        let mut state = HashMap::new();
163        state.insert("learning_rate".to_string(), vec![self.learning_rate]);
164        state.insert("momentum".to_string(), vec![self.momentum]);
165
166        TrainingState {
167            epoch: 0,
168            best_error: T::from(f32::MAX).unwrap(),
169            algorithm_specific: state,
170        }
171    }
172
173    fn restore_state(&mut self, state: TrainingState<T>) {
174        if let Some(lr) = state.algorithm_specific.get("learning_rate") {
175            if !lr.is_empty() {
176                self.learning_rate = lr[0];
177            }
178        }
179        if let Some(mom) = state.algorithm_specific.get("momentum") {
180            if !mom.is_empty() {
181                self.momentum = mom[0];
182            }
183        }
184    }
185
186    fn set_callback(&mut self, callback: TrainingCallback<T>) {
187        self.callback = Some(callback);
188    }
189
190    fn call_callback(
191        &mut self,
192        epoch: usize,
193        network: &Network<T>,
194        data: &TrainingData<T>,
195    ) -> bool {
196        let error = self.calculate_error(network, data);
197        if let Some(ref mut callback) = self.callback {
198            callback(epoch, error)
199        } else {
200            true
201        }
202    }
203}
204
/// Batch backpropagation
/// Accumulates gradients over entire dataset before updating weights
pub struct BatchBackprop<T: Float + Send> {
    // Step size multiplied into the (averaged) gradient when forming a delta.
    learning_rate: T,
    // Fraction of the previous delta carried into the next one (zero by default).
    momentum: T,
    // Pluggable loss function; `new` installs MSE as the default.
    error_function: Box<dyn ErrorFunction<T>>,
    // Per-layer flattened weight deltas from the last update, kept for the momentum term.
    previous_weight_deltas: Vec<Vec<T>>,
    // Per-layer bias deltas from the last update, kept for the momentum term.
    previous_bias_deltas: Vec<Vec<T>>,
    // Optional progress callback invoked with (epoch, error) via `call_callback`.
    callback: Option<TrainingCallback<T>>,
}
215
216impl<T: Float + Send> BatchBackprop<T> {
217    pub fn new(learning_rate: T) -> Self {
218        Self {
219            learning_rate,
220            momentum: T::zero(),
221            error_function: Box::new(MseError),
222            previous_weight_deltas: Vec::new(),
223            previous_bias_deltas: Vec::new(),
224            callback: None,
225        }
226    }
227
228    pub fn with_momentum(mut self, momentum: T) -> Self {
229        self.momentum = momentum;
230        self
231    }
232
233    pub fn with_error_function(mut self, error_function: Box<dyn ErrorFunction<T>>) -> Self {
234        self.error_function = error_function;
235        self
236    }
237
238    fn initialize_deltas(&mut self, network: &Network<T>) {
239        if self.previous_weight_deltas.is_empty() {
240            self.previous_weight_deltas = network
241                .layers
242                .iter()
243                .skip(1) // Skip input layer
244                .map(|layer| {
245                    let num_neurons = layer.neurons.len();
246                    let num_connections = if layer.neurons.is_empty() {
247                        0
248                    } else {
249                        layer.neurons[0].connections.len()
250                    };
251                    vec![T::zero(); num_neurons * num_connections]
252                })
253                .collect();
254            self.previous_bias_deltas = network
255                .layers
256                .iter()
257                .skip(1) // Skip input layer
258                .map(|layer| vec![T::zero(); layer.neurons.len()])
259                .collect();
260        }
261    }
262}
263
264impl<T: Float + Send> TrainingAlgorithm<T> for BatchBackprop<T> {
265    fn train_epoch(
266        &mut self,
267        network: &mut Network<T>,
268        data: &TrainingData<T>,
269    ) -> Result<T, TrainingError> {
270        self.initialize_deltas(network);
271
272        let mut total_error = T::zero();
273
274        // Accumulate gradients over all patterns
275        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
276            let output = network.run(input);
277            total_error = total_error + self.error_function.calculate(&output, desired_output);
278
279            // Accumulate gradients here (placeholder)
280        }
281
282        // Update weights after processing all patterns
283        // Placeholder for actual batch update implementation
284
285        Ok(total_error / T::from(data.inputs.len()).unwrap())
286    }
287
288    fn calculate_error(&self, network: &Network<T>, data: &TrainingData<T>) -> T {
289        let mut total_error = T::zero();
290        let mut network_clone = network.clone();
291
292        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
293            let output = network_clone.run(input);
294            total_error = total_error + self.error_function.calculate(&output, desired_output);
295        }
296
297        total_error / T::from(data.inputs.len()).unwrap()
298    }
299
300    fn count_bit_fails(
301        &self,
302        network: &Network<T>,
303        data: &TrainingData<T>,
304        bit_fail_limit: T,
305    ) -> usize {
306        let mut bit_fails = 0;
307        let mut network_clone = network.clone();
308
309        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
310            let output = network_clone.run(input);
311            for (&actual, &desired) in output.iter().zip(desired_output.iter()) {
312                if (actual - desired).abs() > bit_fail_limit {
313                    bit_fails += 1;
314                }
315            }
316        }
317
318        bit_fails
319    }
320
321    fn save_state(&self) -> TrainingState<T> {
322        let mut state = HashMap::new();
323        state.insert("learning_rate".to_string(), vec![self.learning_rate]);
324        state.insert("momentum".to_string(), vec![self.momentum]);
325
326        TrainingState {
327            epoch: 0,
328            best_error: T::from(f32::MAX).unwrap(),
329            algorithm_specific: state,
330        }
331    }
332
333    fn restore_state(&mut self, state: TrainingState<T>) {
334        if let Some(lr) = state.algorithm_specific.get("learning_rate") {
335            if !lr.is_empty() {
336                self.learning_rate = lr[0];
337            }
338        }
339        if let Some(mom) = state.algorithm_specific.get("momentum") {
340            if !mom.is_empty() {
341                self.momentum = mom[0];
342            }
343        }
344    }
345
346    fn set_callback(&mut self, callback: TrainingCallback<T>) {
347        self.callback = Some(callback);
348    }
349
350    fn call_callback(
351        &mut self,
352        epoch: usize,
353        network: &Network<T>,
354        data: &TrainingData<T>,
355    ) -> bool {
356        let error = self.calculate_error(network, data);
357        if let Some(ref mut callback) = self.callback {
358            callback(epoch, error)
359        } else {
360            true
361        }
362    }
363}