ruv_fann/training/rprop.rs

use super::*;
use num_traits::Float;
use std::collections::HashMap;

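/// Resilient backpropagation (RPROP) trainer.
///
/// RPROP adapts an individual step size for every weight and bias using only
/// the *sign* of the partial derivative: a step grows while the gradient
/// keeps its sign between epochs and shrinks when the sign flips, so the
/// method is insensitive to the gradient's magnitude.
///
/// A minimal usage sketch (`network` and `data` are hypothetical values,
/// shaped as elsewhere in this crate; not a doctest):
///
/// ```ignore
/// let mut trainer = Rprop::new().with_parameters(1.2, 0.5, 0.0, 50.0, 0.1);
/// let mse = trainer.train_epoch(&mut network, &data)?;
/// ```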
pub struct Rprop<T: Float + Send> {
    // Step-size adaptation parameters and the error metric.
    increase_factor: T,
    decrease_factor: T,
    delta_min: T,
    delta_max: T,
    delta_zero: T,
    error_function: Box<dyn ErrorFunction<T>>,

    // Per-layer adaptive state, allocated lazily on the first epoch.
    weight_step_sizes: Vec<Vec<T>>,
    bias_step_sizes: Vec<Vec<T>>,
    previous_weight_gradients: Vec<Vec<T>>,
    previous_bias_gradients: Vec<Vec<T>>,

    callback: Option<TrainingCallback<T>>,
}

impl<T: Float + Send> Rprop<T> {
    /// Create a trainer with the standard RPROP defaults from Riedmiller &
    /// Braun (increase 1.2, decrease 0.5, initial step 0.1, maximum step 50).
    pub fn new() -> Self {
        Self {
            increase_factor: T::from(1.2).unwrap(),
            decrease_factor: T::from(0.5).unwrap(),
            delta_min: T::zero(),
            delta_max: T::from(50.0).unwrap(),
            delta_zero: T::from(0.1).unwrap(),
            error_function: Box::new(MseError),
            weight_step_sizes: Vec::new(),
            bias_step_sizes: Vec::new(),
            previous_weight_gradients: Vec::new(),
            previous_bias_gradients: Vec::new(),
            callback: None,
        }
    }

    pub fn with_parameters(
        mut self,
        increase_factor: T,
        decrease_factor: T,
        delta_min: T,
        delta_max: T,
        delta_zero: T,
    ) -> Self {
        self.increase_factor = increase_factor;
        self.decrease_factor = decrease_factor;
        self.delta_min = delta_min;
        self.delta_max = delta_max;
        self.delta_zero = delta_zero;
        self
    }

    pub fn with_error_function(mut self, error_function: Box<dyn ErrorFunction<T>>) -> Self {
        self.error_function = error_function;
        self
    }

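    /// Allocate per-connection step sizes and previous-gradient buffers on
    /// first use. Layer 0 is the input layer and has no incoming
    /// connections, so it is skipped; every remaining entry starts at
    /// `delta_zero` (step sizes) or zero (gradients).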
    fn initialize_state(&mut self, network: &Network<T>) {
        if self.weight_step_sizes.is_empty() {
            self.weight_step_sizes = network
                .layers
                .iter()
                .skip(1)
                .map(|layer| {
                    let num_neurons = layer.neurons.len();
                    let num_connections = if layer.neurons.is_empty() {
                        0
                    } else {
                        layer.neurons[0].connections.len()
                    };
                    vec![self.delta_zero; num_neurons * num_connections]
                })
                .collect();

            self.bias_step_sizes = network
                .layers
                .iter()
                .skip(1)
                .map(|layer| vec![self.delta_zero; layer.neurons.len()])
                .collect();

            self.previous_weight_gradients = network
                .layers
                .iter()
                .skip(1)
                .map(|layer| {
                    let num_neurons = layer.neurons.len();
                    let num_connections = if layer.neurons.is_empty() {
                        0
                    } else {
                        layer.neurons[0].connections.len()
                    };
                    vec![T::zero(); num_neurons * num_connections]
                })
                .collect();

            self.previous_bias_gradients = network
                .layers
                .iter()
                .skip(1)
                .map(|layer| vec![T::zero(); layer.neurons.len()])
                .collect();
        }
    }

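    /// Core RPROP rule: grow the step size while the gradient keeps its
    /// sign, shrink it after a sign flip, and leave it unchanged when either
    /// gradient is zero, clamping the result to `[delta_min, delta_max]`.
    ///
    /// A sketch of how a caller would typically apply this per weight
    /// (hypothetical names; the actual update loop is not part of this
    /// listing):
    ///
    /// ```ignore
    /// let step = self.update_step_size(step, grad, prev_grad);
    /// weight = weight - grad.signum() * step; // step against the gradient
    /// // Classic RPROP zeroes the stored gradient after a sign flip so the
    /// // step size is not shrunk twice in a row.
    /// prev_grad = if grad * prev_grad < T::zero() { T::zero() } else { grad };
    /// ```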
    fn update_step_size(&self, step_size: T, gradient: T, previous_gradient: T) -> T {
        let sign_change = gradient * previous_gradient;

        if sign_change > T::zero() {
            // Same sign as the previous epoch: accelerate.
            (step_size * self.increase_factor).min(self.delta_max)
        } else if sign_change < T::zero() {
            // Sign flip: the previous step overshot a minimum, so back off.
            (step_size * self.decrease_factor).max(self.delta_min)
        } else {
            // Either gradient is zero: keep the current step size.
            step_size
        }
    }
}

impl<T: Float + Send> Default for Rprop<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Float + Send> TrainingAlgorithm<T> for Rprop<T> {
    fn train_epoch(
        &mut self,
        network: &mut Network<T>,
        data: &TrainingData<T>,
    ) -> Result<T, TrainingError> {
        self.initialize_state(network);

        let mut total_error = T::zero();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network.run(input);
            total_error = total_error + self.error_function.calculate(&output, desired_output);
            // ... per-sample work elided here (presumably gradient accumulation) ...
        }

        // ... batch weight update elided here (presumably where `update_step_size`
        // is applied to every weight and bias) ...

        Ok(total_error / T::from(data.inputs.len()).unwrap())
    }

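    /// Mean error over the data set. The network is cloned because `run`
    /// needs `&mut self` while this method only receives a shared borrow.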
    fn calculate_error(&self, network: &Network<T>, data: &TrainingData<T>) -> T {
        let mut total_error = T::zero();
        let mut network_clone = network.clone();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network_clone.run(input);
            total_error = total_error + self.error_function.calculate(&output, desired_output);
        }

        total_error / T::from(data.inputs.len()).unwrap()
    }

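    /// FANN-style "bit fail" count: the number of individual output values
    /// whose absolute error exceeds `bit_fail_limit`, summed over every
    /// sample in the data set.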
    fn count_bit_fails(
        &self,
        network: &Network<T>,
        data: &TrainingData<T>,
        bit_fail_limit: T,
    ) -> usize {
        let mut bit_fails = 0;
        let mut network_clone = network.clone();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network_clone.run(input);

            for (&actual, &desired) in output.iter().zip(desired_output.iter()) {
                if (actual - desired).abs() > bit_fail_limit {
                    bit_fails += 1;
                }
            }
        }

        bit_fails
    }

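    /// Snapshot the trainer's configuration and adaptive state. The
    /// per-layer step-size vectors are flattened into single vectors so they
    /// fit the `Vec<T>` values of `algorithm_specific`.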
    fn save_state(&self) -> TrainingState<T> {
        let mut state = HashMap::new();

        state.insert("increase_factor".to_string(), vec![self.increase_factor]);
        state.insert("decrease_factor".to_string(), vec![self.decrease_factor]);
        state.insert("delta_min".to_string(), vec![self.delta_min]);
        state.insert("delta_max".to_string(), vec![self.delta_max]);
        state.insert("delta_zero".to_string(), vec![self.delta_zero]);

        // Flatten the per-layer step sizes into one contiguous vector each.
        let mut all_weight_steps = Vec::new();
        for layer_steps in &self.weight_step_sizes {
            all_weight_steps.extend_from_slice(layer_steps);
        }
        state.insert("weight_step_sizes".to_string(), all_weight_steps);

        let mut all_bias_steps = Vec::new();
        for layer_steps in &self.bias_step_sizes {
            all_bias_steps.extend_from_slice(layer_steps);
        }
        state.insert("bias_step_sizes".to_string(), all_bias_steps);

        TrainingState {
            // Bookkeeping placeholders; this trainer does not track the
            // epoch count or the best error itself.
            epoch: 0,
            best_error: T::from(f32::MAX).unwrap(),
            algorithm_specific: state,
        }
    }

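    /// Restore the scalar parameters saved by `save_state`. The flattened
    /// step-size vectors are not unpacked here: recovering their per-layer
    /// shape would require the network, so while they remain empty
    /// `initialize_state` re-seeds them with `delta_zero` on the next epoch.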
    fn restore_state(&mut self, state: TrainingState<T>) {
        if let Some(val) = state.algorithm_specific.get("increase_factor") {
            if !val.is_empty() {
                self.increase_factor = val[0];
            }
        }
        if let Some(val) = state.algorithm_specific.get("decrease_factor") {
            if !val.is_empty() {
                self.decrease_factor = val[0];
            }
        }
        if let Some(val) = state.algorithm_specific.get("delta_min") {
            if !val.is_empty() {
                self.delta_min = val[0];
            }
        }
        if let Some(val) = state.algorithm_specific.get("delta_max") {
            if !val.is_empty() {
                self.delta_max = val[0];
            }
        }
        if let Some(val) = state.algorithm_specific.get("delta_zero") {
            if !val.is_empty() {
                self.delta_zero = val[0];
            }
        }
    }

    fn set_callback(&mut self, callback: TrainingCallback<T>) {
        self.callback = Some(callback);
    }

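    /// Compute the current error and pass it, with the epoch number, to the
    /// user callback, returning the callback's result; with no callback
    /// installed this always returns `true` (continue training).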
    fn call_callback(
        &mut self,
        epoch: usize,
        network: &Network<T>,
        data: &TrainingData<T>,
    ) -> bool {
        let error = self.calculate_error(network, data);
        if let Some(ref mut callback) = self.callback {
            callback(epoch, error)
        } else {
            true
        }
    }
}