ruv_fann/training/quickprop.rs

1//! Quickprop training algorithm
2
3use super::*;
4use num_traits::Float;
5use std::collections::HashMap;
6
/// Quickprop trainer
/// An advanced batch training algorithm that uses second-order information
/// (a parabolic approximation of the error surface) to accelerate convergence.
pub struct Quickprop<T: Float + Send> {
    // Step size used for the plain gradient-descent fallback update.
    learning_rate: T,
    // Maximum growth factor: each step is clamped to `mu` times the previous step.
    mu: T,
    // Weight-decay coefficient; `decay * weight` is added to every delta
    // (negative by default, shrinking weights toward zero).
    decay: T,
    // Pluggable error metric; defaults to MSE (see `new`).
    error_function: Box<dyn ErrorFunction<T>>,

    // State variables
    // Per-layer flattened history from the previous epoch, one Vec per
    // non-input layer (the input layer is skipped during initialization).
    previous_weight_gradients: Vec<Vec<T>>,
    previous_bias_gradients: Vec<Vec<T>>,
    previous_weight_deltas: Vec<Vec<T>>,
    previous_bias_deltas: Vec<Vec<T>>,

    // Optional progress callback invoked by `call_callback` with (epoch, error).
    callback: Option<TrainingCallback<T>>,
}
23
24impl<T: Float + Send> Quickprop<T> {
25    pub fn new() -> Self {
26        Self {
27            learning_rate: T::from(0.7).unwrap(),
28            mu: T::from(1.75).unwrap(),
29            decay: T::from(-0.0001).unwrap(),
30            error_function: Box::new(MseError),
31            previous_weight_gradients: Vec::new(),
32            previous_bias_gradients: Vec::new(),
33            previous_weight_deltas: Vec::new(),
34            previous_bias_deltas: Vec::new(),
35            callback: None,
36        }
37    }
38
39    pub fn with_parameters(mut self, learning_rate: T, mu: T, decay: T) -> Self {
40        self.learning_rate = learning_rate;
41        self.mu = mu;
42        self.decay = decay;
43        self
44    }
45
46    pub fn with_error_function(mut self, error_function: Box<dyn ErrorFunction<T>>) -> Self {
47        self.error_function = error_function;
48        self
49    }
50
51    fn initialize_state(&mut self, network: &Network<T>) {
52        if self.previous_weight_gradients.is_empty() {
53            // Initialize state for each layer
54            self.previous_weight_gradients = network
55                .layers
56                .iter()
57                .skip(1) // Skip input layer
58                .map(|layer| {
59                    let num_neurons = layer.neurons.len();
60                    let num_connections = if layer.neurons.is_empty() {
61                        0
62                    } else {
63                        layer.neurons[0].connections.len()
64                    };
65                    vec![T::zero(); num_neurons * num_connections]
66                })
67                .collect();
68
69            self.previous_bias_gradients = network
70                .layers
71                .iter()
72                .skip(1) // Skip input layer
73                .map(|layer| vec![T::zero(); layer.neurons.len()])
74                .collect();
75
76            self.previous_weight_deltas = network
77                .layers
78                .iter()
79                .skip(1) // Skip input layer
80                .map(|layer| {
81                    let num_neurons = layer.neurons.len();
82                    let num_connections = if layer.neurons.is_empty() {
83                        0
84                    } else {
85                        layer.neurons[0].connections.len()
86                    };
87                    vec![T::zero(); num_neurons * num_connections]
88                })
89                .collect();
90
91            self.previous_bias_deltas = network
92                .layers
93                .iter()
94                .skip(1) // Skip input layer
95                .map(|layer| vec![T::zero(); layer.neurons.len()])
96                .collect();
97        }
98    }
99
100    fn calculate_quickprop_delta(
101        &self,
102        gradient: T,
103        previous_gradient: T,
104        previous_delta: T,
105        weight: T,
106    ) -> T {
107        if previous_gradient == T::zero() {
108            // First epoch or no previous gradient: use standard gradient descent
109            return -self.learning_rate * gradient + self.decay * weight;
110        }
111
112        let gradient_diff = gradient - previous_gradient;
113
114        if gradient_diff == T::zero() {
115            // No change in gradient: use momentum-like update
116            return -self.learning_rate * gradient + self.decay * weight;
117        }
118
119        // Quickprop formula: delta = (gradient / (previous_gradient - gradient)) * previous_delta
120        let factor = gradient / gradient_diff;
121        let mut delta = factor * previous_delta;
122
123        // Limit the maximum step size
124        let max_delta = self.mu * previous_delta.abs();
125        if delta.abs() > max_delta {
126            delta = if delta > T::zero() {
127                max_delta
128            } else {
129                -max_delta
130            };
131        }
132
133        // Add decay term
134        delta + self.decay * weight
135    }
136}
137
138impl<T: Float + Send> Default for Quickprop<T> {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
144impl<T: Float + Send> TrainingAlgorithm<T> for Quickprop<T> {
145    fn train_epoch(
146        &mut self,
147        network: &mut Network<T>,
148        data: &TrainingData<T>,
149    ) -> Result<T, TrainingError> {
150        self.initialize_state(network);
151
152        let mut total_error = T::zero();
153
154        // Calculate gradients over entire dataset
155        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
156            let output = network.run(input);
157            total_error = total_error + self.error_function.calculate(&output, desired_output);
158
159            // Calculate and accumulate gradients (placeholder)
160            // In a full implementation, you would:
161            // 1. Perform backpropagation to calculate gradients
162            // 2. Update weights using Quickprop rules
163        }
164
165        // Placeholder for Quickprop weight updates
166        // This would apply the Quickprop algorithm to update weights
167
168        Ok(total_error / T::from(data.inputs.len()).unwrap())
169    }
170
171    fn calculate_error(&self, network: &Network<T>, data: &TrainingData<T>) -> T {
172        let mut total_error = T::zero();
173        let mut network_clone = network.clone();
174
175        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
176            let output = network_clone.run(input);
177            total_error = total_error + self.error_function.calculate(&output, desired_output);
178        }
179
180        total_error / T::from(data.inputs.len()).unwrap()
181    }
182
183    fn count_bit_fails(
184        &self,
185        network: &Network<T>,
186        data: &TrainingData<T>,
187        bit_fail_limit: T,
188    ) -> usize {
189        let mut bit_fails = 0;
190        let mut network_clone = network.clone();
191
192        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
193            let output = network_clone.run(input);
194
195            for (&actual, &desired) in output.iter().zip(desired_output.iter()) {
196                if (actual - desired).abs() > bit_fail_limit {
197                    bit_fails += 1;
198                }
199            }
200        }
201
202        bit_fails
203    }
204
205    fn save_state(&self) -> TrainingState<T> {
206        let mut state = HashMap::new();
207
208        // Save Quickprop parameters
209        state.insert("learning_rate".to_string(), vec![self.learning_rate]);
210        state.insert("mu".to_string(), vec![self.mu]);
211        state.insert("decay".to_string(), vec![self.decay]);
212
213        // Save previous gradients and deltas (flattened)
214        let mut all_weight_gradients = Vec::new();
215        for layer_gradients in &self.previous_weight_gradients {
216            all_weight_gradients.extend_from_slice(layer_gradients);
217        }
218        state.insert(
219            "previous_weight_gradients".to_string(),
220            all_weight_gradients,
221        );
222
223        let mut all_bias_gradients = Vec::new();
224        for layer_gradients in &self.previous_bias_gradients {
225            all_bias_gradients.extend_from_slice(layer_gradients);
226        }
227        state.insert("previous_bias_gradients".to_string(), all_bias_gradients);
228
229        let mut all_weight_deltas = Vec::new();
230        for layer_deltas in &self.previous_weight_deltas {
231            all_weight_deltas.extend_from_slice(layer_deltas);
232        }
233        state.insert("previous_weight_deltas".to_string(), all_weight_deltas);
234
235        let mut all_bias_deltas = Vec::new();
236        for layer_deltas in &self.previous_bias_deltas {
237            all_bias_deltas.extend_from_slice(layer_deltas);
238        }
239        state.insert("previous_bias_deltas".to_string(), all_bias_deltas);
240
241        TrainingState {
242            epoch: 0,
243            best_error: T::from(f32::MAX).unwrap(),
244            algorithm_specific: state,
245        }
246    }
247
248    fn restore_state(&mut self, state: TrainingState<T>) {
249        // Restore Quickprop parameters
250        if let Some(val) = state.algorithm_specific.get("learning_rate") {
251            if !val.is_empty() {
252                self.learning_rate = val[0];
253            }
254        }
255        if let Some(val) = state.algorithm_specific.get("mu") {
256            if !val.is_empty() {
257                self.mu = val[0];
258            }
259        }
260        if let Some(val) = state.algorithm_specific.get("decay") {
261            if !val.is_empty() {
262                self.decay = val[0];
263            }
264        }
265
266        // Note: Previous gradients and deltas would need network structure info to properly restore
267        // This is a simplified version - in production, you'd need to store layer sizes too
268    }
269
270    fn set_callback(&mut self, callback: TrainingCallback<T>) {
271        self.callback = Some(callback);
272    }
273
274    fn call_callback(
275        &mut self,
276        epoch: usize,
277        network: &Network<T>,
278        data: &TrainingData<T>,
279    ) -> bool {
280        let error = self.calculate_error(network, data);
281        if let Some(ref mut callback) = self.callback {
282            callback(epoch, error)
283        } else {
284            true
285        }
286    }
287}