use crate::error::OptimizeError;
use crate::stochastic::{
    clip_gradients, generate_batch_indices, update_learning_rate, DataProvider,
    LearningRateSchedule, StochasticGradientFunction,
};
use crate::unconstrained::result::OptimizeResult;
use scirs2_core::ndarray::Array1;
use statrs::statistics::Statistics;

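/// Options for the RMSProp stochastic optimizer.
///
/// RMSProp keeps an exponential moving average of squared gradients,
/// `s <- decay_rate * s + (1 - decay_rate) * g^2`, and takes steps of
/// `learning_rate * g / sqrt(s + epsilon)`. With `centered` set, the squared
/// running mean of the gradient is subtracted from `s` before the square root
/// (an estimate of the gradient variance); `momentum` applies classical
/// momentum to the scaled update. Leaving `batch_size` as `None` lets the
/// solver pick roughly `min(32, num_samples / 10)`, with a minimum of one
/// sample per batch.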
#[derive(Debug, Clone)]
pub struct RMSPropOptions {
    pub learning_rate: f64,
    pub decay_rate: f64,
    pub epsilon: f64,
    pub max_iter: usize,
    pub tol: f64,
    pub lr_schedule: LearningRateSchedule,
    pub gradient_clip: Option<f64>,
    pub batch_size: Option<usize>,
    pub centered: bool,
    pub momentum: Option<f64>,
}

impl Default for RMSPropOptions {
    fn default() -> Self {
        Self {
            learning_rate: 0.01,
            decay_rate: 0.99,
            epsilon: 1e-8,
            max_iter: 1000,
            tol: 1e-6,
            lr_schedule: LearningRateSchedule::Constant,
            gradient_clip: None,
            batch_size: None,
            centered: false,
            momentum: None,
        }
    }
}

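/// Minimizes an objective with RMSProp using stochastic mini-batch gradients.
///
/// Gradients come from a [`StochasticGradientFunction`] and batches are drawn
/// from the given [`DataProvider`]. Every 10 iterations the loss is evaluated
/// on the full dataset; the run stops early when the change in that loss falls
/// below `options.tol`, or when the best loss has not improved for more than
/// 50 consecutive checks.
///
/// A minimal usage sketch (illustrative only; `QuadraticFunction` stands in
/// for any gradient function such as the one defined in the tests below):
///
/// ```ignore
/// let options = RMSPropOptions { learning_rate: 0.1, max_iter: 200, ..Default::default() };
/// let provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));
/// let x0 = Array1::from_vec(vec![1.0, -2.0]);
/// let result = minimize_rmsprop(QuadraticFunction, x0, provider, options)?;
/// println!("best loss {:.3e} at {:?}", result.fun, result.x);
/// ```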
#[allow(dead_code)]
pub fn minimize_rmsprop<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: RMSPropOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    // Default batch size: min(32, num_samples / 10), but never zero.
    let batch_size = options.batch_size.unwrap_or(32.min(num_samples / 10).max(1));
    let actual_batch_size = batch_size.min(num_samples);

    let mut s: Array1<f64> = Array1::zeros(x.len());
    let mut g_mean = if options.centered {
        Some(Array1::<f64>::zeros(x.len()))
    } else {
        None
    };
    let mut momentum_buffer = if options.momentum.is_some() {
        Some(Array1::<f64>::zeros(x.len()))
    } else {
        None
    };

    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    let mut prev_loss = f64::INFINITY;
    let mut stagnant_iterations = 0;

    println!("Starting RMSProp optimization:");
    println!("  Parameters: {}", x.len());
    println!("  Dataset size: {}", num_samples);
    println!("  Batch size: {}", actual_batch_size);
    println!("  Initial learning rate: {}", options.learning_rate);
    println!("  Decay rate: {}", options.decay_rate);
    println!("  Centered: {}", options.centered);
    if let Some(mom) = options.momentum {
        println!("  Momentum: {}", mom);
    }

    #[allow(clippy::explicit_counter_loop)]
    for iteration in 0..options.max_iter {
        let current_lr = update_learning_rate(
            options.learning_rate,
            iteration,
            options.max_iter,
            &options.lr_schedule,
        );

        let batch_indices = if actual_batch_size < num_samples {
            generate_batch_indices(num_samples, actual_batch_size, true)
        } else {
            (0..num_samples).collect()
        };

        let batch_data = data_provider.get_batch(&batch_indices);

        let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
        _grad_evals += 1;

        if let Some(clip_threshold) = options.gradient_clip {
            clip_gradients(&mut gradient, clip_threshold);
        }

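        // Exponential moving average of squared gradients:
        // s <- rho * s + (1 - rho) * g^2.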
        let gradient_sq = gradient.mapv(|g| g * g);
        s = &s * options.decay_rate + &gradient_sq * (1.0 - options.decay_rate);

        if let Some(ref mut g_avg) = g_mean {
            *g_avg = &*g_avg * options.decay_rate + &gradient * (1.0 - options.decay_rate);
        }

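        // Scale the gradient by 1 / sqrt(second moment + epsilon). The centered
        // variant subtracts the squared running mean first, so the denominator
        // approximates the standard deviation of the gradient.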
        let effective_gradient = if options.centered {
            if let Some(ref g_avg) = g_mean {
                let centered_s = &s - &g_avg.mapv(|g| g * g);
                let denominator = centered_s.mapv(|s| (s + options.epsilon).sqrt());
                &gradient / &denominator
            } else {
                unreachable!("g_mean should be Some when centered is true");
            }
        } else {
            let denominator = s.mapv(|s| (s + options.epsilon).sqrt());
            &gradient / &denominator
        };

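        // With momentum, accumulate buf <- mu * buf + lr * g_hat and step by the
        // buffer; otherwise step directly by lr * g_hat.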
        let update = if let Some(momentum_factor) = options.momentum {
            if let Some(ref mut momentum_buf) = momentum_buffer {
                *momentum_buf =
                    &*momentum_buf * momentum_factor + &effective_gradient * current_lr;
                momentum_buf.clone()
            } else {
                unreachable!("momentum_buffer should be Some when momentum is Some");
            }
        } else {
            &effective_gradient * current_lr
        };

        x = &x - &update;

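        // Every 10 iterations (and on the final one), evaluate the loss on the
        // full dataset to track the best point and test for convergence.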
        if iteration % 10 == 0 || iteration == options.max_iter - 1 {
            let full_data = data_provider.get_full_data();
            let current_loss = grad_func.compute_value(&x.view(), &full_data);
            func_evals += 1;

            if current_loss < best_f {
                best_f = current_loss;
                best_x = x.clone();
                stagnant_iterations = 0;
            } else {
                stagnant_iterations += 1;
            }

            if iteration % 100 == 0 {
                let grad_norm = gradient.mapv(|g| g * g).sum().sqrt();
                let rms_norm = s.mapv(|s| s.sqrt()).mean();
                println!(
                    "  Iteration {}: loss = {:.6e}, |grad| = {:.3e}, RMS = {:.3e}, lr = {:.3e}",
                    iteration, current_loss, grad_norm, rms_norm, current_lr
                );
            }

            let loss_change = (prev_loss - current_loss).abs();
            if loss_change < options.tol {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: true,
                    message: format!(
                        "RMSProp converged: loss change {:.2e} < {:.2e}",
                        loss_change, options.tol
                    ),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }

            prev_loss = current_loss;

            if stagnant_iterations > 50 {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: false,
                    message: "RMSProp stopped due to stagnation".to_string(),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }
        }
    }

    let full_data = data_provider.get_full_data();
    let final_loss = grad_func.compute_value(&best_x.view(), &full_data);
    func_evals += 1;

    Ok(OptimizeResult {
        x: best_x,
        fun: final_loss.min(best_f),
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "RMSProp reached maximum iterations".to_string(),
        jacobian: None,
        hessian: None,
    })
}

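/// Minimizes an objective using Graves' variant of RMSProp (Graves, 2013).
///
/// In addition to the running average of squared gradients `n`, this variant
/// tracks a running average of the gradients themselves `g`; Graves'
/// formulation scales each step by `1 / sqrt(n - g^2 + epsilon)`, i.e. by an
/// estimate of the gradient's standard deviation rather than its raw RMS
/// value.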
#[allow(dead_code)]
pub fn minimize_graves_rmsprop<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: RMSPropOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    // Default batch size: min(32, num_samples / 10), but never zero.
    let batch_size = options.batch_size.unwrap_or(32.min(num_samples / 10).max(1));
    let actual_batch_size = batch_size.min(num_samples);

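    // State for Graves' RMSProp: `n` is the running average of squared gradients,
    // `g` the running average of gradients, and `delta` a running average of
    // squared parameter updates (tracked here but not fed back into the step).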
    let mut n: Array1<f64> = Array1::zeros(x.len());
    let mut g: Array1<f64> = Array1::zeros(x.len());
    let mut delta: Array1<f64> = Array1::zeros(x.len());
    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    println!("Starting Graves' RMSProp optimization:");
    println!("  Parameters: {}", x.len());
    println!("  Dataset size: {}", num_samples);
    println!("  Batch size: {}", actual_batch_size);

    #[allow(clippy::explicit_counter_loop)]
    for iteration in 0..options.max_iter {
        let current_lr = update_learning_rate(
            options.learning_rate,
            iteration,
            options.max_iter,
            &options.lr_schedule,
        );

        let batch_indices = if actual_batch_size < num_samples {
            generate_batch_indices(num_samples, actual_batch_size, true)
        } else {
            (0..num_samples).collect()
        };

        let batch_data = data_provider.get_batch(&batch_indices);
        let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
        _grad_evals += 1;

        if let Some(clip_threshold) = options.gradient_clip {
            clip_gradients(&mut gradient, clip_threshold);
        }

        // Update running averages of the gradient and its square.
        n = &n * options.decay_rate + &gradient.mapv(|g| g * g) * (1.0 - options.decay_rate);
        g = &g * options.decay_rate + &gradient * (1.0 - options.decay_rate);

        // Graves' update divides by sqrt(n - g_bar^2 + epsilon), an estimate of the
        // gradient's standard deviation (n - g_bar^2 is nonnegative for averages
        // that start at zero).
        let centered_n = &n - &g.mapv(|g_i| g_i * g_i);
        let rms_n = centered_n.mapv(|v| (v + options.epsilon).sqrt());
        let scaled_gradient = &gradient / &rms_n;
        let final_update = scaled_gradient * current_lr;

        x = &x - &final_update;
        delta = &delta * options.decay_rate
            + &final_update.mapv(|u| u * u) * (1.0 - options.decay_rate);

        if iteration % 10 == 0 || iteration == options.max_iter - 1 {
            let full_data = data_provider.get_full_data();
            let current_loss = grad_func.compute_value(&x.view(), &full_data);
            func_evals += 1;

            if current_loss < best_f {
                best_f = current_loss;
                best_x = x.clone();
            }

            if iteration % 100 == 0 {
                let grad_norm = gradient.mapv(|g| g * g).sum().sqrt();
                println!(
                    "  Iteration {}: loss = {:.6e}, |grad| = {:.3e}, lr = {:.3e}",
                    iteration, current_loss, grad_norm, current_lr
                );
            }
        }
    }

    Ok(OptimizeResult {
        x: best_x,
        fun: best_f,
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "Graves' RMSProp completed".to_string(),
        jacobian: None,
        hessian: None,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::stochastic::InMemoryDataProvider;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::ArrayView1;

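    // Deterministic test objective: f(x) = sum(x_i^2) with gradient 2x. The
    // batch data is ignored, so every mini-batch produces the exact gradient.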
    struct QuadraticFunction;

    impl StochasticGradientFunction for QuadraticFunction {
        fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> Array1<f64> {
            x.mapv(|xi| 2.0 * xi)
        }

        fn compute_value(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> f64 {
            x.mapv(|xi| xi * xi).sum()
        }
    }

    #[test]
    fn test_rmsprop_quadratic() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, 2.0, -1.5]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = RMSPropOptions {
            learning_rate: 0.1,
            max_iter: 200,
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_rmsprop(grad_func, x0, data_provider, options).unwrap();

        assert!(result.success || result.fun < 1e-4);
        for &xi in result.x.iter() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-2);
        }
    }

    #[test]
    fn test_rmsprop_centered() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, -1.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

        let options = RMSPropOptions {
            learning_rate: 0.1,
            max_iter: 500,
            batch_size: Some(10),
            centered: true,
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_rmsprop(grad_func, x0, data_provider, options).unwrap();

        assert!(result.success || result.fun < 1e-4);
    }

    #[test]
    fn test_rmsprop_with_momentum() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![2.0, -2.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = RMSPropOptions {
            learning_rate: 0.01,
            max_iter: 150,
            batch_size: Some(20),
            momentum: Some(0.9),
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_rmsprop(grad_func, x0, data_provider, options).unwrap();

        assert!(result.success || result.fun < 1e-3);
    }

    #[test]
    fn test_graves_rmsprop() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.5, -1.5]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

        let options = RMSPropOptions {
            learning_rate: 0.1,
            max_iter: 500,
            batch_size: Some(10),
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_graves_rmsprop(grad_func, x0, data_provider, options).unwrap();

        assert!(result.fun < 1.0);
    }

    #[test]
    fn test_rmsprop_different_decay_rates() {
        let x0 = Array1::from_vec(vec![1.0, 1.0]);

        let decay_rates = [0.9, 0.95, 0.99, 0.999];

        for &decay_rate in &decay_rates {
            let options = RMSPropOptions {
                learning_rate: 0.1,
                decay_rate,
                max_iter: 500,
                tol: 1e-6,
                ..Default::default()
            };

            let grad_func_clone = QuadraticFunction;
            let x0_clone = x0.clone();
            let data_provider_clone = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

            let result =
                minimize_rmsprop(grad_func_clone, x0_clone, data_provider_clone, options).unwrap();

            assert!(result.fun < 1e-2, "Failed with decay rate {}", decay_rate);
        }
    }
}