ruv_fann/training/
rprop.rs

//! Resilient Propagation (RPROP) training algorithm

use super::*;
use num_traits::Float;
use std::collections::HashMap;

/// RPROP (Resilient Propagation) trainer
/// An adaptive learning algorithm that uses only the sign of the gradient, not its magnitude.
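/// Each weight maintains its own step size: the step grows by `increase_factor`
/// while the gradient sign is stable, shrinks by `decrease_factor` when the sign
/// flips, and is clamped to `[delta_min, delta_max]` (Riedmiller & Braun, 1993).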
pub struct Rprop<T: Float + Send> {
    increase_factor: T,
    decrease_factor: T,
    delta_min: T,
    delta_max: T,
    delta_zero: T,
    error_function: Box<dyn ErrorFunction<T>>,

    // State variables
    weight_step_sizes: Vec<Vec<T>>,
    bias_step_sizes: Vec<Vec<T>>,
    previous_weight_gradients: Vec<Vec<T>>,
    previous_bias_gradients: Vec<Vec<T>>,

    callback: Option<TrainingCallback<T>>,
}

impl<T: Float + Send> Rprop<T> {
    pub fn new() -> Self {
        Self {
            increase_factor: T::from(1.2).unwrap(),
            decrease_factor: T::from(0.5).unwrap(),
            delta_min: T::zero(),
            delta_max: T::from(50.0).unwrap(),
            delta_zero: T::from(0.1).unwrap(),
            error_function: Box::new(MseError),
            weight_step_sizes: Vec::new(),
            bias_step_sizes: Vec::new(),
            previous_weight_gradients: Vec::new(),
            previous_bias_gradients: Vec::new(),
            callback: None,
        }
    }

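    /// Overrides the RPROP hyperparameters. The defaults set by [`Rprop::new`]
    /// (1.2, 0.5, 0.0, 50.0, 0.1) match the values commonly used in the RPROP
    /// literature.
    ///
    /// ```ignore
    /// // Illustrative usage; `delta_min = 1e-6` is an example value, not the default.
    /// let trainer = Rprop::<f32>::new().with_parameters(1.2, 0.5, 1e-6, 50.0, 0.1);
    /// ```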
    pub fn with_parameters(
        mut self,
        increase_factor: T,
        decrease_factor: T,
        delta_min: T,
        delta_max: T,
        delta_zero: T,
    ) -> Self {
        self.increase_factor = increase_factor;
        self.decrease_factor = decrease_factor;
        self.delta_min = delta_min;
        self.delta_max = delta_max;
        self.delta_zero = delta_zero;
        self
    }

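    /// Replaces the default mean-squared-error criterion with a custom error function.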
    pub fn with_error_function(mut self, error_function: Box<dyn ErrorFunction<T>>) -> Self {
        self.error_function = error_function;
        self
    }

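    /// Lazily allocates per-connection step sizes and gradient history on the
    /// first training epoch; the shapes mirror the network's non-input layers.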
    fn initialize_state(&mut self, network: &Network<T>) {
        if self.weight_step_sizes.is_empty() {
            // Initialize step sizes and gradients for each layer
            self.weight_step_sizes = network
                .layers
                .iter()
                .skip(1) // Skip input layer
                .map(|layer| {
                    let num_neurons = layer.neurons.len();
                    let num_connections = if layer.neurons.is_empty() {
                        0
                    } else {
                        layer.neurons[0].connections.len()
                    };
                    vec![self.delta_zero; num_neurons * num_connections]
                })
                .collect();

            self.bias_step_sizes = network
                .layers
                .iter()
                .skip(1) // Skip input layer
                .map(|layer| vec![self.delta_zero; layer.neurons.len()])
                .collect();

            // Initialize previous gradients to zero
            self.previous_weight_gradients = network
                .layers
                .iter()
                .skip(1) // Skip input layer
                .map(|layer| {
                    let num_neurons = layer.neurons.len();
                    let num_connections = if layer.neurons.is_empty() {
                        0
                    } else {
                        layer.neurons[0].connections.len()
                    };
                    vec![T::zero(); num_neurons * num_connections]
                })
                .collect();

            self.previous_bias_gradients = network
                .layers
                .iter()
                .skip(1) // Skip input layer
                .map(|layer| vec![T::zero(); layer.neurons.len()])
                .collect();
        }
    }

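    /// Core RPROP step-size adaptation. The product `gradient * previous_gradient`
    /// encodes whether the gradient kept or flipped its sign since the last epoch:
    /// growth is capped at `delta_max`, shrinkage is floored at `delta_min`, and a
    /// zero product (no previous gradient, or a zero gradient) leaves the step untouched.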
    fn update_step_size(&self, step_size: T, gradient: T, previous_gradient: T) -> T {
        let sign_change = gradient * previous_gradient;

        if sign_change > T::zero() {
            // Same sign: increase step size
            (step_size * self.increase_factor).min(self.delta_max)
        } else if sign_change < T::zero() {
            // Different sign: decrease step size
            (step_size * self.decrease_factor).max(self.delta_min)
        } else {
            // No previous gradient or zero gradient: keep step size
            step_size
        }
    }
}

impl<T: Float + Send> Default for Rprop<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Float + Send> TrainingAlgorithm<T> for Rprop<T> {
    fn train_epoch(
        &mut self,
        network: &mut Network<T>,
        data: &TrainingData<T>,
    ) -> Result<T, TrainingError> {
        self.initialize_state(network);

        let mut total_error = T::zero();

        // RPROP is a batch algorithm: error (and, in a complete implementation,
        // gradients) is accumulated over the entire dataset before any weight update
        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network.run(input);
            total_error = total_error + self.error_function.calculate(&output, desired_output);

            // Gradient accumulation is not yet implemented. A full implementation would:
            // 1. Backpropagate the error to obtain weight and bias gradients
            // 2. Accumulate them over the batch for the sign comparison in update_step_size
        }

        // Placeholder for the RPROP weight update: each weight would move by its
        // adapted step size in the direction opposite to its gradient sign
        // (see the sketch after this method).

        Ok(total_error / T::from(data.inputs.len()).unwrap())
    }
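
    // A minimal sketch of the elided update pass, assuming gradients were
    // accumulated into flat per-layer vectors shaped like `weight_step_sizes`
    // (`accumulated_gradients` and the weight-mutation call are hypothetical,
    // not part of `Network<T>`):
    //
    //     for (l, grads) in accumulated_gradients.iter().enumerate() {
    //         for (i, &grad) in grads.iter().enumerate() {
    //             let prev = self.previous_weight_gradients[l][i];
    //             let step = self.update_step_size(self.weight_step_sizes[l][i], grad, prev);
    //             self.weight_step_sizes[l][i] = step;
    //             // Step against the gradient sign; a zero gradient means no move.
    //             let delta = if grad > T::zero() { -step }
    //                         else if grad < T::zero() { step }
    //                         else { T::zero() };
    //             // network.adjust_weight(l, i, delta); // hypothetical mutation API
    //             self.previous_weight_gradients[l][i] = grad;
    //         }
    //     }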

    fn calculate_error(&self, network: &Network<T>, data: &TrainingData<T>) -> T {
        let mut total_error = T::zero();
        // `run` needs mutable access, so evaluate on a clone of the network
        let mut network_clone = network.clone();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network_clone.run(input);
            total_error = total_error + self.error_function.calculate(&output, desired_output);
        }

        total_error / T::from(data.inputs.len()).unwrap()
    }

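    // FANN-style "bit fail" metric: each output element whose absolute error
    // exceeds `bit_fail_limit` counts as one failure.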
    fn count_bit_fails(
        &self,
        network: &Network<T>,
        data: &TrainingData<T>,
        bit_fail_limit: T,
    ) -> usize {
        let mut bit_fails = 0;
        let mut network_clone = network.clone();

        for (input, desired_output) in data.inputs.iter().zip(data.outputs.iter()) {
            let output = network_clone.run(input);

            for (&actual, &desired) in output.iter().zip(desired_output.iter()) {
                if (actual - desired).abs() > bit_fail_limit {
                    bit_fails += 1;
                }
            }
        }

        bit_fails
    }

    fn save_state(&self) -> TrainingState<T> {
        let mut state = HashMap::new();

        // Save RPROP parameters
        state.insert("increase_factor".to_string(), vec![self.increase_factor]);
        state.insert("decrease_factor".to_string(), vec![self.decrease_factor]);
        state.insert("delta_min".to_string(), vec![self.delta_min]);
        state.insert("delta_max".to_string(), vec![self.delta_max]);
        state.insert("delta_zero".to_string(), vec![self.delta_zero]);

        // Save step sizes (flattened)
        let mut all_weight_steps = Vec::new();
        for layer_steps in &self.weight_step_sizes {
            all_weight_steps.extend_from_slice(layer_steps);
        }
        state.insert("weight_step_sizes".to_string(), all_weight_steps);

        let mut all_bias_steps = Vec::new();
        for layer_steps in &self.bias_step_sizes {
            all_bias_steps.extend_from_slice(layer_steps);
        }
        state.insert("bias_step_sizes".to_string(), all_bias_steps);

        TrainingState {
            epoch: 0,
            best_error: T::from(f32::MAX).unwrap(),
            algorithm_specific: state,
        }
    }

    fn restore_state(&mut self, state: TrainingState<T>) {
        // Restore RPROP parameters; missing or empty entries keep the current value
        let restore = |key: &str, target: &mut T| {
            if let Some(val) = state.algorithm_specific.get(key) {
                if let Some(&first) = val.first() {
                    *target = first;
                }
            }
        };
        restore("increase_factor", &mut self.increase_factor);
        restore("decrease_factor", &mut self.decrease_factor);
        restore("delta_min", &mut self.delta_min);
        restore("delta_max", &mut self.delta_max);
        restore("delta_zero", &mut self.delta_zero);

        // Note: restoring the flattened step sizes requires the network's layer
        // shapes, which this simplified state does not record; a production
        // version would store the layer sizes alongside the values.
    }

    fn set_callback(&mut self, callback: TrainingCallback<T>) {
        self.callback = Some(callback);
    }

    fn call_callback(
        &mut self,
        epoch: usize,
        network: &Network<T>,
        data: &TrainingData<T>,
    ) -> bool {
        let error = self.calculate_error(network, data);
        if let Some(ref mut callback) = self.callback {
            callback(epoch, error)
        } else {
            true
        }
    }
}
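
// Illustrative tests for the step-size adaptation rule; they exercise only the
// private `update_step_size` helper defined above, with example hyperparameters.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn step_size_adapts_to_gradient_sign() {
        let trainer = Rprop::<f64>::new().with_parameters(1.2, 0.5, 1e-6, 50.0, 0.1);

        // Same sign: the step grows by the increase factor (0.1 * 1.2).
        assert!((trainer.update_step_size(0.1, 1.0, 1.0) - 0.12).abs() < 1e-12);

        // Sign flip: the step shrinks by the decrease factor (0.1 * 0.5).
        assert!((trainer.update_step_size(0.1, -1.0, 1.0) - 0.05).abs() < 1e-12);

        // No previous gradient: the step is unchanged.
        assert!((trainer.update_step_size(0.1, 1.0, 0.0) - 0.1).abs() < 1e-12);

        // Growth is capped at delta_max.
        assert!(trainer.update_step_size(49.0, 1.0, 1.0) <= 50.0);
    }
}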