scirs2_optimize/stochastic/rmsprop.rs

//! RMSProp (Root Mean Square Propagation) optimizer
//!
//! RMSProp is an adaptive learning rate method that addresses AdaGrad's learning rate
//! decay problem. It maintains a moving average of the squared gradients and divides
//! the gradient elementwise by the square root of this average.
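//!
//! A minimal sketch of the per-coordinate update rule, using plain `f64` values
//! rather than this module's types (the constants below are illustrative only):
//!
//! ```
//! let (learning_rate, decay_rate, epsilon) = (0.01_f64, 0.99_f64, 1e-8_f64);
//! let mut x = 1.0_f64; // parameter being optimized
//! let mut s = 0.0_f64; // moving average of squared gradients
//! for _ in 0..100 {
//!     let g = 2.0 * x; // gradient of f(x) = x^2
//!     s = decay_rate * s + (1.0 - decay_rate) * g * g;
//!     x -= learning_rate * g / (s + epsilon).sqrt();
//! }
//! assert!(x.abs() < 1.0);
//! ```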

use crate::error::OptimizeError;
use crate::stochastic::{
    clip_gradients, generate_batch_indices, update_learning_rate, DataProvider,
    LearningRateSchedule, StochasticGradientFunction,
};
use crate::unconstrained::result::OptimizeResult;
use scirs2_core::ndarray::Array1;

/// Options for RMSProp optimization
#[derive(Debug, Clone)]
pub struct RMSPropOptions {
    /// Learning rate (step size)
    pub learning_rate: f64,
    /// Decay rate for the moving average
    pub decay_rate: f64,
    /// Small constant for numerical stability
    pub epsilon: f64,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: f64,
    /// Learning rate schedule
    pub lr_schedule: LearningRateSchedule,
    /// Gradient clipping threshold
    pub gradient_clip: Option<f64>,
    /// Batch size for mini-batch optimization
    pub batch_size: Option<usize>,
    /// Use centered RMSProp (subtract the squared mean of the gradients)
    pub centered: bool,
    /// Momentum parameter (when set, uses RMSProp with momentum)
    pub momentum: Option<f64>,
}

impl Default for RMSPropOptions {
    fn default() -> Self {
        Self {
            learning_rate: 0.01,
            decay_rate: 0.99,
            epsilon: 1e-8,
            max_iter: 1000,
            tol: 1e-6,
            lr_schedule: LearningRateSchedule::Constant,
            gradient_clip: None,
            batch_size: None,
            centered: false,
            momentum: None,
        }
    }
}

/// RMSProp optimizer implementation
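///
/// Runs mini-batch RMSProp on `grad_func`, drawing batches from `data_provider`.
/// Every 10 iterations the loss is evaluated on the full dataset to track the
/// best parameters seen so far, test the `tol` convergence criterion on the
/// change in loss, and stop early after prolonged stagnation. The returned
/// result contains the best parameters found, not necessarily the last iterate.
/// See the tests at the bottom of this module for usage examples.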
#[allow(dead_code)]
pub fn minimize_rmsprop<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: RMSPropOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    // Default to min(32, num_samples / 10), but never allow a zero batch size
    let batch_size = options.batch_size.unwrap_or((num_samples / 10).clamp(1, 32));
    let actual_batch_size = batch_size.min(num_samples);

    // Initialize moving averages
    let mut s: Array1<f64> = Array1::zeros(x.len()); // Moving average of squared gradients
    let mut g_mean = if options.centered {
        Some(Array1::<f64>::zeros(x.len())) // Moving average of gradients (for centered variant)
    } else {
        None
    };
    let mut momentum_buffer = if options.momentum.is_some() {
        Some(Array1::<f64>::zeros(x.len())) // Momentum buffer
    } else {
        None
    };

    // Track the best solution found
    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    // Convergence tracking
    let mut prev_loss = f64::INFINITY;
    let mut stagnant_iterations = 0;

    println!("Starting RMSProp optimization:");
    println!("  Parameters: {}", x.len());
    println!("  Dataset size: {}", num_samples);
    println!("  Batch size: {}", actual_batch_size);
    println!("  Initial learning rate: {}", options.learning_rate);
    println!("  Decay rate: {}", options.decay_rate);
    println!("  Centered: {}", options.centered);
    if let Some(mom) = options.momentum {
        println!("  Momentum: {}", mom);
    }

    #[allow(clippy::explicit_counter_loop)]
    for iteration in 0..options.max_iter {
        // Update learning rate according to schedule
        let current_lr = update_learning_rate(
            options.learning_rate,
            iteration,
            options.max_iter,
            &options.lr_schedule,
        );

        // Generate batch indices
        let batch_indices = if actual_batch_size < num_samples {
            generate_batch_indices(num_samples, actual_batch_size, true)
        } else {
            (0..num_samples).collect()
        };

        // Get batch data
        let batch_data = data_provider.get_batch(&batch_indices);

        // Compute gradient on batch
        let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
        _grad_evals += 1;

        // Apply gradient clipping if specified
        if let Some(clip_threshold) = options.gradient_clip {
            clip_gradients(&mut gradient, clip_threshold);
        }

        // Update moving average of squared gradients
        let gradient_sq = gradient.mapv(|g| g * g);
        s = &s * options.decay_rate + &gradient_sq * (1.0 - options.decay_rate);

        // Update moving average of gradients (centered variant)
        if let Some(ref mut g_avg) = g_mean {
            *g_avg = &*g_avg * options.decay_rate + &gradient * (1.0 - options.decay_rate);
        }

        // Compute the adaptive learning rate
        let effective_gradient = if options.centered {
            if let Some(ref g_avg) = g_mean {
                // Centered RMSProp: s - g_mean^2 estimates the gradient variance
                // (E[g^2] - E[g]^2), which gives better conditioning
                let centered_s = &s - &g_avg.mapv(|g| g * g);
                let denominator = centered_s.mapv(|c| (c + options.epsilon).sqrt());
                &gradient / &denominator
            } else {
                unreachable!("g_mean should be Some when centered is true");
            }
        } else {
            // Standard RMSProp
            let denominator = s.mapv(|si| (si + options.epsilon).sqrt());
            &gradient / &denominator
        };

        // Apply momentum if specified
        let update = if let Some(momentum_factor) = options.momentum {
            if let Some(ref mut momentum_buf) = momentum_buffer {
                // RMSProp with momentum: v = momentum * v + lr * effective_gradient
                *momentum_buf = &*momentum_buf * momentum_factor + &effective_gradient * current_lr;
                momentum_buf.clone()
            } else {
                unreachable!("momentum_buffer should be Some when momentum is Some");
            }
        } else {
            // Standard RMSProp update
            &effective_gradient * current_lr
        };

        // Update parameters
        x = &x - &update;

        // Evaluate on full dataset periodically for convergence check
        if iteration % 10 == 0 || iteration == options.max_iter - 1 {
            let full_data = data_provider.get_full_data();
            let current_loss = grad_func.compute_value(&x.view(), &full_data);
            func_evals += 1;

            // Update best solution
            if current_loss < best_f {
                best_f = current_loss;
                best_x = x.clone();
                stagnant_iterations = 0;
            } else {
                stagnant_iterations += 1;
            }

            // Progress reporting
            if iteration % 100 == 0 {
                let grad_norm = gradient.mapv(|g| g * g).sum().sqrt();
                let rms_norm = s.mapv(|si| si.sqrt()).mean().unwrap_or(0.0);
                println!(
                    "  Iteration {}: loss = {:.6e}, |grad| = {:.3e}, RMS = {:.3e}, lr = {:.3e}",
                    iteration, current_loss, grad_norm, rms_norm, current_lr
                );
            }

            // Check convergence
            let loss_change = (prev_loss - current_loss).abs();
            if loss_change < options.tol {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: true,
                    message: format!(
                        "RMSProp converged: loss change {:.2e} < {:.2e}",
                        loss_change, options.tol
                    ),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }

            prev_loss = current_loss;

            // Early stopping for stagnation
            if stagnant_iterations > 50 {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: false,
                    message: "RMSProp stopped due to stagnation".to_string(),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }
        }
    }

    // Final evaluation
    let full_data = data_provider.get_full_data();
    let final_loss = grad_func.compute_value(&best_x.view(), &full_data);
    func_evals += 1;

    Ok(OptimizeResult {
        x: best_x,
        fun: final_loss.min(best_f),
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "RMSProp reached maximum iterations".to_string(),
        jacobian: None,
        hessian: None,
    })
}

/// Graves' RMSProp variant, implemented here with a simplified update for numerical stability
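///
/// Graves' formulation (Graves, 2013, "Generating Sequences With Recurrent
/// Neural Networks") keeps running averages of both the gradient and its
/// square and normalizes by an estimate of the gradient's standard deviation,
/// roughly:
///
/// ```text
/// n     <- rho * n     + (1 - rho) * g^2
/// g_bar <- rho * g_bar + (1 - rho) * g
/// x     <- x - lr * g / sqrt(n - g_bar^2 + eps)
/// ```
///
/// This implementation keeps the `n`, `g`, and `delta` accumulators but applies
/// a simplified, RMSProp-style scaling in the update (see the comments in the
/// loop below).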
#[allow(dead_code)]
pub fn minimize_graves_rmsprop<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: RMSPropOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    // Default to min(32, num_samples / 10), but never allow a zero batch size
    let batch_size = options.batch_size.unwrap_or((num_samples / 10).clamp(1, 32));
    let actual_batch_size = batch_size.min(num_samples);

    // Initialize Graves' RMSProp variables
    let mut n: Array1<f64> = Array1::zeros(x.len()); // Accumulated squared gradients
    let mut g: Array1<f64> = Array1::zeros(x.len()); // Accumulated gradients
    let mut delta: Array1<f64> = Array1::zeros(x.len()); // Accumulated squared updates

    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    println!("Starting Graves' RMSProp optimization:");
    println!("  Parameters: {}", x.len());
    println!("  Dataset size: {}", num_samples);
    println!("  Batch size: {}", actual_batch_size);

    #[allow(clippy::explicit_counter_loop)]
    for iteration in 0..options.max_iter {
        let current_lr = update_learning_rate(
            options.learning_rate,
            iteration,
            options.max_iter,
            &options.lr_schedule,
        );

        // Generate batch and compute gradient
        let batch_indices = if actual_batch_size < num_samples {
            generate_batch_indices(num_samples, actual_batch_size, true)
        } else {
            (0..num_samples).collect()
        };

        let batch_data = data_provider.get_batch(&batch_indices);
        let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
        _grad_evals += 1;

        if let Some(clip_threshold) = options.gradient_clip {
            clip_gradients(&mut gradient, clip_threshold);
        }

        // Graves' RMSProp updates
        n = &n * options.decay_rate + &gradient.mapv(|gi| gi * gi) * (1.0 - options.decay_rate);
        g = &g * options.decay_rate + &gradient * (1.0 - options.decay_rate);

        // Compute the parameter update
        let rms_n = n.mapv(|n_i| (n_i + options.epsilon).sqrt());
        // let rms_delta = delta.mapv(|d_i| (d_i + options.epsilon).sqrt());

        // Simplified Graves' formula: just use gradient scaling similar to standard RMSProp.
        // The original Graves' formula (see the doc comment above) had numerical stability issues.
        let scaled_gradient = &gradient / &rms_n;
        let final_update = &scaled_gradient * current_lr;

        // Update parameters and accumulate squared updates
        x = &x - &final_update;
        delta = &delta * options.decay_rate
            + &final_update.mapv(|u| u * u) * (1.0 - options.decay_rate);

        // Evaluate progress
        if iteration % 10 == 0 || iteration == options.max_iter - 1 {
            let full_data = data_provider.get_full_data();
            let current_loss = grad_func.compute_value(&x.view(), &full_data);
            func_evals += 1;

            if current_loss < best_f {
                best_f = current_loss;
                best_x = x.clone();
            }

            if iteration % 100 == 0 {
                let grad_norm = gradient.mapv(|gi| gi * gi).sum().sqrt();
                println!(
                    "  Iteration {}: loss = {:.6e}, |grad| = {:.3e}, lr = {:.3e}",
                    iteration, current_loss, grad_norm, current_lr
                );
            }
        }
    }

    Ok(OptimizeResult {
        x: best_x,
        fun: best_f,
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "Graves' RMSProp completed".to_string(),
        jacobian: None,
        hessian: None,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::stochastic::InMemoryDataProvider;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::ArrayView1;

    // Simple quadratic function for testing
    struct QuadraticFunction;

    impl StochasticGradientFunction for QuadraticFunction {
        fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> Array1<f64> {
            // Gradient of f(x) = sum(x_i^2) is 2*x
            x.mapv(|xi| 2.0 * xi)
        }

        fn compute_value(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> f64 {
            // f(x) = sum(x_i^2)
            x.mapv(|xi| xi * xi).sum()
        }
    }

    #[test]
    fn test_rmsprop_quadratic() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, 2.0, -1.5]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = RMSPropOptions {
            learning_rate: 0.1,
            max_iter: 200,
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_rmsprop(grad_func, x0, data_provider, options).unwrap();

        // Should converge to zero
        assert!(result.success || result.fun < 1e-4);
        for &xi in result.x.iter() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-2);
        }
    }

    #[test]
    fn test_rmsprop_centered() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, -1.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

        let options = RMSPropOptions {
            learning_rate: 0.1,
            max_iter: 500,
            batch_size: Some(10),
            centered: true,
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_rmsprop(grad_func, x0, data_provider, options).unwrap();

        // Centered RMSProp should converge
        assert!(result.success || result.fun < 1e-4);
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![2.0, -2.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = RMSPropOptions {
            learning_rate: 0.01,
            max_iter: 150,
            batch_size: Some(20),
            momentum: Some(0.9),
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_rmsprop(grad_func, x0, data_provider, options).unwrap();

        // RMSProp with momentum should help convergence
        assert!(result.success || result.fun < 1e-3);
    }

    #[test]
    fn test_graves_rmsprop() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.5, -1.5]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

        let options = RMSPropOptions {
            learning_rate: 0.1,
            max_iter: 500,
            batch_size: Some(10),
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_graves_rmsprop(grad_func, x0, data_provider, options).unwrap();

        // Graves' variant should also work well (very relaxed tolerance)
        assert!(result.fun < 1.0);
    }

    #[test]
    fn test_rmsprop_different_decay_rates() {
        let x0 = Array1::from_vec(vec![1.0, 1.0]);

        // Test different decay rates
        let decay_rates = [0.9, 0.95, 0.99, 0.999];

        for &decay_rate in &decay_rates {
            let options = RMSPropOptions {
                learning_rate: 0.1,
                decay_rate,
                max_iter: 500,
                tol: 1e-6,
                ..Default::default()
            };

            let grad_func = QuadraticFunction;
            let x0_clone = x0.clone();
            let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

            let result = minimize_rmsprop(grad_func, x0_clone, data_provider, options).unwrap();

            // All decay rates should lead to reasonable convergence
            assert!(result.fun < 1e-2, "Failed with decay rate {}", decay_rate);
        }
    }
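
    // An additional sketch (not part of the original suite): exercise the
    // `gradient_clip` option on the same quadratic; the clipping threshold and
    // tolerance below are illustrative choices, not tuned values.
    #[test]
    fn test_rmsprop_with_gradient_clipping() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, 2.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = RMSPropOptions {
            learning_rate: 0.1,
            max_iter: 500,
            gradient_clip: Some(1.0),
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_rmsprop(grad_func, x0, data_provider, options).unwrap();

        // Clipping should not prevent convergence on this problem (relaxed tolerance)
        assert!(result.success || result.fun < 1e-1);
    }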
}