ruv_fann/training/quickprop.rs

use super::*;
use num_traits::Float;
use std::collections::HashMap;
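/// Quickprop trainer (Fahlman, 1988).
///
/// For each weight, Quickprop fits a parabola through the current and previous error
/// gradients and scales the previous step by the ratio of the current gradient to the
/// change in gradient, clamping the result to `mu * |previous_delta|` and adding an
/// optional weight-decay term (see `calculate_quickprop_delta`).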
pub struct Quickprop<T: Float + Send> {
    learning_rate: T,
    mu: T,
    decay: T,
    error_function: Box<dyn ErrorFunction<T>>,

    // State carried between updates, one flattened Vec per non-input layer.
    previous_weight_gradients: Vec<Vec<T>>,
    previous_bias_gradients: Vec<Vec<T>>,
    previous_weight_deltas: Vec<Vec<T>>,
    previous_bias_deltas: Vec<Vec<T>>,

    callback: Option<TrainingCallback<T>>,
}

impl<T: Float + Send> Quickprop<T> {
    pub fn new() -> Self {
        Self {
            // Defaults follow the commonly used Quickprop settings (mu = 1.75).
            learning_rate: T::from(0.7).unwrap(),
            mu: T::from(1.75).unwrap(),
            decay: T::from(-0.0001).unwrap(),
            error_function: Box::new(MseError),
            previous_weight_gradients: Vec::new(),
            previous_bias_gradients: Vec::new(),
            previous_weight_deltas: Vec::new(),
            previous_bias_deltas: Vec::new(),
            callback: None,
        }
    }

    pub fn with_parameters(mut self, learning_rate: T, mu: T, decay: T) -> Self {
        self.learning_rate = learning_rate;
        self.mu = mu;
        self.decay = decay;
        self
    }

    pub fn with_error_function(mut self, error_function: Box<dyn ErrorFunction<T>>) -> Self {
        self.error_function = error_function;
        self
    }

    fn initialize_state(&mut self, network: &Network<T>) {
        if self.previous_weight_gradients.is_empty() {
            self.previous_weight_gradients = network
                .layers
                .iter()
                .skip(1) // skip the input layer; it has no incoming connections
                .map(|layer| {
                    let num_neurons = layer.neurons.len();
                    let num_connections = if layer.neurons.is_empty() {
                        0
                    } else {
                        layer.neurons[0].connections.len()
                    };
                    vec![T::zero(); num_neurons * num_connections]
                })
                .collect();

            self.previous_bias_gradients = network
                .layers
                .iter()
                .skip(1)
                .map(|layer| vec![T::zero(); layer.neurons.len()])
                .collect();

            self.previous_weight_deltas = network
                .layers
                .iter()
                .skip(1)
                .map(|layer| {
                    let num_neurons = layer.neurons.len();
                    let num_connections = if layer.neurons.is_empty() {
                        0
                    } else {
                        layer.neurons[0].connections.len()
                    };
                    vec![T::zero(); num_neurons * num_connections]
                })
                .collect();

            self.previous_bias_deltas = network
                .layers
                .iter()
                .skip(1)
                .map(|layer| vec![T::zero(); layer.neurons.len()])
                .collect();
        }
    }

    /// Compute the Quickprop step for a single weight from the current gradient,
    /// the gradient and step of the previous update, and the weight itself
    /// (used for the decay term).
    fn calculate_quickprop_delta(
        &self,
        gradient: T,
        previous_gradient: T,
        previous_delta: T,
        weight: T,
    ) -> T {
        // No gradient history yet: fall back to plain gradient descent plus decay.
        if previous_gradient == T::zero() {
            return -self.learning_rate * gradient + self.decay * weight;
        }

        let gradient_diff = previous_gradient - gradient;

        // Equal gradients would make the quadratic step blow up; fall back again.
        if gradient_diff == T::zero() {
            return -self.learning_rate * gradient + self.decay * weight;
        }

        // Step toward the minimum of the parabola fitted through the two gradients:
        // delta = g / (g_prev - g) * prev_delta.
        let factor = gradient / gradient_diff;
        let mut delta = factor * previous_delta;

        // Limit the step to mu times the previous step to prevent runaway growth.
        let max_delta = self.mu * previous_delta.abs();
        if delta.abs() > max_delta {
            delta = if delta > T::zero() {
                max_delta
            } else {
                -max_delta
            };
        }

        // Weight decay (decay is typically a small negative value).
        delta + self.decay * weight
    }
}
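
// A minimal usage sketch: drive repeated epochs until a target error is reached.
// `train_until` is a hypothetical helper added for illustration; constructing the
// `Network` and `TrainingData` values is assumed to be handled by the caller.
#[allow(dead_code)]
fn train_until(
    network: &mut Network<f32>,
    data: &TrainingData<f32>,
    target_error: f32,
    max_epochs: usize,
) -> Result<f32, TrainingError> {
    let mut trainer = Quickprop::new().with_parameters(0.7, 1.75, -0.0001);
    let mut error = f32::MAX;
    for epoch in 0..max_epochs {
        error = trainer.train_epoch(network, data)?;
        // Stop if the user callback asks to abort or the target error is reached.
        if !trainer.call_callback(epoch, network, data) || error <= target_error {
            break;
        }
    }
    Ok(error)
}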

impl<T: Float + Send> Default for Quickprop<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Float + Send> TrainingAlgorithm<T> for Quickprop<T> {
    fn train_epoch(
        &mut self,
        network: &mut Network<T>,
        data: &TrainingData<T>,
    ) -> Result<T, TrainingError> {
        self.initialize_state(network);

        let mut total_error = T::zero();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            // Forward pass and error accumulation only; the Quickprop weight update
            // (via `calculate_quickprop_delta`) is not applied in this epoch loop.
            let output = network.run(input);
            total_error = total_error + self.error_function.calculate(&output, desired_output);
        }

        Ok(total_error / T::from(data.inputs.len()).unwrap())
    }

    fn calculate_error(&self, network: &Network<T>, data: &TrainingData<T>) -> T {
        let mut total_error = T::zero();
        let mut network_clone = network.clone();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network_clone.run(input);
            total_error = total_error + self.error_function.calculate(&output, desired_output);
        }

        total_error / T::from(data.inputs.len()).unwrap()
    }

    fn count_bit_fails(
        &self,
        network: &Network<T>,
        data: &TrainingData<T>,
        bit_fail_limit: T,
    ) -> usize {
        let mut bit_fails = 0;
        let mut network_clone = network.clone();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network_clone.run(input);

            for (&actual, &desired) in output.iter().zip(desired_output.iter()) {
                if (actual - desired).abs() > bit_fail_limit {
                    bit_fails += 1;
                }
            }
        }

        bit_fails
    }

    fn save_state(&self) -> TrainingState<T> {
        let mut state = HashMap::new();

        state.insert("learning_rate".to_string(), vec![self.learning_rate]);
        state.insert("mu".to_string(), vec![self.mu]);
        state.insert("decay".to_string(), vec![self.decay]);

        // Flatten the per-layer history into single vectors for storage.
        let mut all_weight_gradients = Vec::new();
        for layer_gradients in &self.previous_weight_gradients {
            all_weight_gradients.extend_from_slice(layer_gradients);
        }
        state.insert(
            "previous_weight_gradients".to_string(),
            all_weight_gradients,
        );

        let mut all_bias_gradients = Vec::new();
        for layer_gradients in &self.previous_bias_gradients {
            all_bias_gradients.extend_from_slice(layer_gradients);
        }
        state.insert("previous_bias_gradients".to_string(), all_bias_gradients);

        let mut all_weight_deltas = Vec::new();
        for layer_deltas in &self.previous_weight_deltas {
            all_weight_deltas.extend_from_slice(layer_deltas);
        }
        state.insert("previous_weight_deltas".to_string(), all_weight_deltas);

        let mut all_bias_deltas = Vec::new();
        for layer_deltas in &self.previous_bias_deltas {
            all_bias_deltas.extend_from_slice(layer_deltas);
        }
        state.insert("previous_bias_deltas".to_string(), all_bias_deltas);

        TrainingState {
            epoch: 0,
            best_error: T::from(f32::MAX).unwrap(),
            algorithm_specific: state,
        }
    }

    fn restore_state(&mut self, state: TrainingState<T>) {
        if let Some(val) = state.algorithm_specific.get("learning_rate") {
            if !val.is_empty() {
                self.learning_rate = val[0];
            }
        }
        if let Some(val) = state.algorithm_specific.get("mu") {
            if !val.is_empty() {
                self.mu = val[0];
            }
        }
        if let Some(val) = state.algorithm_specific.get("decay") {
            if !val.is_empty() {
                self.decay = val[0];
            }
        }

        // The flattened gradient/delta history is not restored here; if the per-layer
        // vectors are empty, `initialize_state` rebuilds them on the next epoch.
    }

    fn set_callback(&mut self, callback: TrainingCallback<T>) {
        self.callback = Some(callback);
    }

    fn call_callback(
        &mut self,
        epoch: usize,
        network: &Network<T>,
        data: &TrainingData<T>,
    ) -> bool {
        let error = self.calculate_error(network, data);
        if let Some(ref mut callback) = self.callback {
            callback(epoch, error)
        } else {
            true
        }
    }
}
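
// Minimal sanity checks for the private delta rule; the expected values follow
// directly from the default parameters above (learning_rate = 0.7, mu = 1.75).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn falls_back_to_gradient_descent_without_history() {
        let qp = Quickprop::<f32>::new();
        // No previous gradient: the step is plain -learning_rate * gradient
        // (the decay term vanishes because weight = 0).
        let delta = qp.calculate_quickprop_delta(1.0, 0.0, 0.0, 0.0);
        assert!((delta + 0.7).abs() < 1e-6);
    }

    #[test]
    fn quadratic_step_is_clamped_by_mu() {
        let qp = Quickprop::<f32>::new();
        // The raw quadratic step has magnitude |0.9 / 0.1| * 0.1 = 0.9, which exceeds
        // mu * |prev_delta| = 0.175, so the step is clamped to that bound.
        let delta = qp.calculate_quickprop_delta(0.9, 1.0, -0.1, 0.0);
        assert!((delta.abs() - 0.175).abs() < 1e-6);
    }
}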