scirs2_optimize/stochastic/adam.rs

//! ADAM (Adaptive Moment Estimation) optimizer
//!
//! ADAM combines the advantages of two other extensions of stochastic gradient
//! descent, AdaGrad and RMSProp: it computes an adaptive learning rate for each
//! parameter by keeping exponentially decaying averages of past gradients (the
//! first moment, used as momentum) and of past squared gradients (the second
//! moment, used to scale the step size).
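//!
//! With gradient `g_t` at step `t`, decay rates `beta1`/`beta2`, learning rate `lr`,
//! and a small `epsilon`, the update implemented below is the standard ADAM recurrence:
//!
//! ```text
//! m_t     = beta1 * m_{t-1} + (1 - beta1) * g_t
//! v_t     = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//! m_hat_t = m_t / (1 - beta1^t)
//! v_hat_t = v_t / (1 - beta2^t)
//! x_{t+1} = x_t - lr * m_hat_t / (sqrt(v_hat_t) + epsilon)
//! ```
//!
//! The AMSGrad variant replaces `v_hat_t` with the element-wise running maximum of
//! all past `v_hat` values.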

use crate::error::OptimizeError;
use crate::stochastic::{
    clip_gradients, generate_batch_indices, update_learning_rate, DataProvider,
    LearningRateSchedule, StochasticGradientFunction,
};
use crate::unconstrained::result::OptimizeResult;
use ndarray::Array1;

/// Options for ADAM optimization
#[derive(Debug, Clone)]
pub struct AdamOptions {
    /// Learning rate (step size)
    pub learning_rate: f64,
    /// First moment decay parameter (momentum)
    pub beta1: f64,
    /// Second moment decay parameter (RMSProp-like)
    pub beta2: f64,
    /// Small constant for numerical stability
    pub epsilon: f64,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: f64,
    /// Learning rate schedule
    pub lr_schedule: LearningRateSchedule,
    /// Gradient clipping threshold
    pub gradient_clip: Option<f64>,
    /// Batch size for mini-batch optimization
    pub batch_size: Option<usize>,
    /// Use AMSGrad variant (max of past second moments)
    pub amsgrad: bool,
}

impl Default for AdamOptions {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            max_iter: 1000,
            tol: 1e-6,
            lr_schedule: LearningRateSchedule::Constant,
            gradient_clip: None,
            batch_size: None,
            amsgrad: false,
        }
    }
}

/// ADAM optimizer implementation
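///
/// # Example
///
/// A minimal sketch (marked `ignore`, mirroring the unit tests below) that assumes
/// `minimize_adam`, `AdamOptions`, `InMemoryDataProvider`, and
/// `StochasticGradientFunction` are reachable from the caller's scope; adjust the
/// paths to the crate's actual public re-exports.
///
/// ```ignore
/// use ndarray::{Array1, ArrayView1};
///
/// // Objective: f(x) = sum(x_i^2), with gradient 2 * x.
/// struct Quadratic;
///
/// impl StochasticGradientFunction for Quadratic {
///     fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batch: &[f64]) -> Array1<f64> {
///         x.mapv(|xi| 2.0 * xi)
///     }
///     fn compute_value(&mut self, x: &ArrayView1<f64>, _batch: &[f64]) -> f64 {
///         x.mapv(|xi| xi * xi).sum()
///     }
/// }
///
/// let x0 = Array1::from_vec(vec![1.0, -2.0]);
/// let data = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));
/// let options = AdamOptions {
///     learning_rate: 0.1,
///     max_iter: 200,
///     ..Default::default()
/// };
/// let result = minimize_adam(Quadratic, x0, data, options).unwrap();
/// assert!(result.success || result.fun < 1e-3);
/// ```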
#[allow(dead_code)]
pub fn minimize_adam<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: AdamOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    // Default batch size: 32, capped at 10% of the dataset, but never zero
    let batch_size = options.batch_size.unwrap_or(32.min(num_samples / 10).max(1));
    let actual_batch_size = batch_size.min(num_samples);

    // Initialize moment estimates
    let mut m: Array1<f64> = Array1::zeros(x.len()); // First moment estimate (momentum)
    let mut v: Array1<f64> = Array1::zeros(x.len()); // Second moment estimate (adaptive lr)
    let mut v_hat_max: Array1<f64> = Array1::zeros(x.len()); // For AMSGrad variant

    // Track the best solution found
    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    // Convergence tracking
    let mut prev_loss = f64::INFINITY;
    let mut stagnant_iterations = 0;

    println!("Starting ADAM optimization:");
    println!("  Parameters: {}", x.len());
    println!("  Dataset size: {}", num_samples);
    println!("  Batch size: {}", actual_batch_size);
    println!("  Initial learning rate: {}", options.learning_rate);
    println!("  Beta1: {}, Beta2: {}", options.beta1, options.beta2);
    println!("  AMSGrad: {}", options.amsgrad);

    #[allow(clippy::explicit_counter_loop)]
    for iteration in 0..options.max_iter {
        // Update learning rate according to schedule
        let current_lr = update_learning_rate(
            options.learning_rate,
            iteration,
            options.max_iter,
            &options.lr_schedule,
        );

        // Generate batch indices
        let batch_indices = if actual_batch_size < num_samples {
            generate_batch_indices(num_samples, actual_batch_size, true)
        } else {
            (0..num_samples).collect()
        };

        // Get batch data
        let batch_data = data_provider.get_batch(&batch_indices);

        // Compute gradient on batch
        let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
        _grad_evals += 1;

        // Apply gradient clipping if specified
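        // (the batch gradient is rescaled so its norm stays within the threshold,
        // guarding against occasional very noisy mini-batches)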
        if let Some(clip_threshold) = options.gradient_clip {
            clip_gradients(&mut gradient, clip_threshold);
        }

        // Update biased first moment estimate
        m = &m * options.beta1 + &gradient * (1.0 - options.beta1);

        // Update biased second raw moment estimate
        let gradient_sq = gradient.mapv(|g| g * g);
        v = &v * options.beta2 + &gradient_sq * (1.0 - options.beta2);

        // Compute bias-corrected first moment estimate
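        // (m and v start at zero, so early estimates are biased toward zero; dividing by
        // 1 - beta^t removes the bias. With beta1 = 0.9, the first-step factor is 0.1,
        // exactly undoing the (1 - beta1) = 0.1 weight given to the first gradient.)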
        let bias_correction1 = 1.0 - options.beta1.powi((iteration + 1) as i32);
        let m_hat = &m / bias_correction1;

        // Compute bias-corrected second moment estimate
        let bias_correction2 = 1.0 - options.beta2.powi((iteration + 1) as i32);
        let v_hat = &v / bias_correction2;

        // AMSGrad: Use max of current and past second moments
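        // (keeping the running maximum means the effective per-parameter step size can
        // never grow between iterations, which addresses cases where plain ADAM fails
        // to converge because the second-moment estimate shrinks)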
        let v_final = if options.amsgrad {
            // Element-wise maximum of v_hat and v_hat_max
            for i in 0..v_hat_max.len() {
                v_hat_max[i] = v_hat_max[i].max(v_hat[i]);
            }
            &v_hat_max
        } else {
            &v_hat
        };

        // Update parameters: x = x - lr * m_hat / (sqrt(v_final) + epsilon)
        let denominator = v_final.mapv(|vi| vi.sqrt() + options.epsilon);
        let update = &m_hat / &denominator * current_lr;
        x = &x - &update;

        // Evaluate on full dataset periodically for convergence check
        if iteration % 10 == 0 || iteration == options.max_iter - 1 {
            let full_data = data_provider.get_full_data();
            let current_loss = grad_func.compute_value(&x.view(), &full_data);
            func_evals += 1;

            // Update best solution
            if current_loss < best_f {
                best_f = current_loss;
                best_x = x.clone();
                stagnant_iterations = 0;
            } else {
                stagnant_iterations += 1;
            }

            // Progress reporting
            if iteration % 100 == 0 {
                let grad_norm = gradient.mapv(|g| g * g).sum().sqrt();
                let m_norm = m_hat.mapv(|g: f64| g * g).sum().sqrt();
                let v_mean = v_final.mean().unwrap_or(0.0);
                println!("  Iteration {}: loss = {:.6e}, |grad| = {:.3e}, |m| = {:.3e}, <v> = {:.3e}, lr = {:.3e}",
                    iteration, current_loss, grad_norm, m_norm, v_mean, current_lr);
            }

            // Check convergence
            let loss_change = (prev_loss - current_loss).abs();
            if loss_change < options.tol {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: true,
                    message: format!(
                        "ADAM converged: loss change {:.2e} < {:.2e}",
                        loss_change, options.tol
                    ),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }

            prev_loss = current_loss;

            // Early stopping for stagnation
            if stagnant_iterations > 50 {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: false,
                    message: "ADAM stopped due to stagnation".to_string(),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }
        }
    }

    // Final evaluation
    let full_data = data_provider.get_full_data();
    let final_loss = grad_func.compute_value(&best_x.view(), &full_data);
    func_evals += 1;

    Ok(OptimizeResult {
        x: best_x,
        fun: final_loss.min(best_f),
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "ADAM reached maximum iterations".to_string(),
        jacobian: None,
        hessian: None,
    })
}

/// ADAM with learning rate warmup
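///
/// During the first `warmup_steps` iterations the learning rate ramps up linearly
/// from zero, `lr_t = base_lr * t / warmup_steps`; afterwards the configured schedule
/// is used unchanged. This damps the large early updates that occur while the moment
/// estimates are still unreliable.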
#[allow(dead_code)]
pub fn minimize_adam_with_warmup<F>(
    grad_func: F,
    x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: AdamOptions,
    warmup_steps: usize,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let original_lr = options.learning_rate;

    // Implement warmup by modifying the learning rate schedule
    let warmup_schedule =
        move |epoch: usize, max_epochs: usize, base_schedule: &LearningRateSchedule| -> f64 {
            let base_lr = update_learning_rate(original_lr, epoch, max_epochs, base_schedule);

            if epoch < warmup_steps {
                // Linear warmup from 0 to base_lr
                base_lr * (epoch as f64 / warmup_steps as f64)
            } else {
                base_lr
            }
        };

    // Delegate to the custom-schedule driver; the closure above applies the warmup at each iteration
    minimize_adam_with_custom_schedule(grad_func, x, data_provider, options, warmup_schedule)
}

/// ADAM with custom learning rate schedule function
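///
/// `lr_scheduler` is called once per iteration as
/// `lr_scheduler(iteration, max_iter, &options.lr_schedule)` and must return the
/// learning rate to use for that step.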
#[allow(dead_code)]
fn minimize_adam_with_custom_schedule<F, S>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: AdamOptions,
    lr_scheduler: S,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
    S: Fn(usize, usize, &LearningRateSchedule) -> f64,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    // Default batch size: 32, capped at 10% of the dataset, but never zero
    let batch_size = options.batch_size.unwrap_or(32.min(num_samples / 10).max(1));
    let actual_batch_size = batch_size.min(num_samples);

    // Initialize moment estimates
    let mut m: Array1<f64> = Array1::zeros(x.len());
    let mut v: Array1<f64> = Array1::zeros(x.len());
    let mut v_hat_max: Array1<f64> = Array1::zeros(x.len());

    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    #[allow(clippy::explicit_counter_loop)]
    for iteration in 0..options.max_iter {
        // Use custom learning rate schedule
        let current_lr = lr_scheduler(iteration, options.max_iter, &options.lr_schedule);

        // Generate batch and compute gradient
        let batch_indices = if actual_batch_size < num_samples {
            generate_batch_indices(num_samples, actual_batch_size, true)
        } else {
            (0..num_samples).collect()
        };

        let batch_data = data_provider.get_batch(&batch_indices);
        let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
        _grad_evals += 1;

        if let Some(clip_threshold) = options.gradient_clip {
            clip_gradients(&mut gradient, clip_threshold);
        }

        // ADAM updates
        m = &m * options.beta1 + &gradient * (1.0 - options.beta1);
        let gradient_sq = gradient.mapv(|g| g * g);
        v = &v * options.beta2 + &gradient_sq * (1.0 - options.beta2);

        let bias_correction1 = 1.0 - options.beta1.powi((iteration + 1) as i32);
        let bias_correction2 = 1.0 - options.beta2.powi((iteration + 1) as i32);
        let m_hat = &m / bias_correction1;
        let v_hat = &v / bias_correction2;

        let v_final = if options.amsgrad {
            for i in 0..v_hat_max.len() {
                v_hat_max[i] = v_hat_max[i].max(v_hat[i]);
            }
            &v_hat_max
        } else {
            &v_hat
        };

        let denominator = v_final.mapv(|vi| vi.sqrt() + options.epsilon);
        let update = &m_hat / &denominator * current_lr;
        x = &x - &update;

        // Evaluate progress
        if iteration % 10 == 0 || iteration == options.max_iter - 1 {
            let full_data = data_provider.get_full_data();
            let current_loss = grad_func.compute_value(&x.view(), &full_data);
            func_evals += 1;

            if current_loss < best_f {
                best_f = current_loss;
                best_x = x.clone();
            }

            if iteration % 100 == 0 {
                println!(
                    "  Iteration {}: loss = {:.6e}, lr = {:.3e} (custom schedule)",
                    iteration, current_loss, current_lr
                );
            }
        }
    }

    Ok(OptimizeResult {
        x: best_x,
        fun: best_f,
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "ADAM with custom schedule completed".to_string(),
        jacobian: None,
        hessian: None,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::stochastic::InMemoryDataProvider;
    use approx::assert_abs_diff_eq;
    use ndarray::ArrayView1;

    // Simple quadratic function for testing
    struct QuadraticFunction;

    impl StochasticGradientFunction for QuadraticFunction {
        fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> Array1<f64> {
            // Gradient of f(x) = sum(x_i^2) is 2*x
            x.mapv(|xi| 2.0 * xi)
        }

        fn compute_value(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> f64 {
            // f(x) = sum(x_i^2)
            x.mapv(|xi| xi * xi).sum()
        }
    }

    #[test]
    fn test_adam_quadratic() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, 2.0, -1.5]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = AdamOptions {
            learning_rate: 0.1,
            max_iter: 200,
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_adam(grad_func, x0, data_provider, options).unwrap();

        // Should converge to zero
        assert!(result.success || result.fun < 1e-4);
        for &xi in result.x.iter() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-2);
        }
    }

    #[test]
    fn test_adam_amsgrad() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, -1.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

        let options = AdamOptions {
            learning_rate: 0.1,
            max_iter: 100,
            batch_size: Some(10),
            amsgrad: true,
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_adam(grad_func, x0, data_provider, options).unwrap();

        // AMSGrad should converge reliably
        assert!(result.success || result.fun < 1e-4);
    }

    #[test]
    fn test_adam_with_warmup() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![2.0, -2.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = AdamOptions {
            learning_rate: 0.1,
            max_iter: 100,
            batch_size: Some(20),
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_adam_with_warmup(grad_func, x0, data_provider, options, 10).unwrap();

        // Warmup should help with convergence
        assert!(result.success || result.fun < 1e-3);
    }

    #[test]
    fn test_adam_gradient_clipping() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![10.0, -10.0]); // Large initial values
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

        let options = AdamOptions {
            learning_rate: 0.1,       // Increased learning rate to compensate for clipping
            max_iter: 1000,           // More iterations for convergence with clipping
            gradient_clip: Some(1.0), // Clip gradients to norm 1.0
            tol: 1e-4,
            ..Default::default()
        };

        let result = minimize_adam(grad_func, x0, data_provider, options).unwrap();

        // Should still converge even with large initial gradients (relaxed tolerance for clipped gradients)
        assert!(result.success || result.fun < 1e-1);
    }
}