use crate::error::OptimizeError;
use crate::stochastic::{
    clip_gradients, generate_batch_indices, update_learning_rate, DataProvider,
    LearningRateSchedule, StochasticGradientFunction,
};
use crate::unconstrained::result::OptimizeResult;
use ndarray::Array1;
/// Options for the Adam (adaptive moment estimation) stochastic optimizer.
#[derive(Debug, Clone)]
pub struct AdamOptions {
    /// Step size for parameter updates.
    pub learning_rate: f64,
    /// Exponential decay rate for the first-moment estimate.
    pub beta1: f64,
    /// Exponential decay rate for the second-moment estimate.
    pub beta2: f64,
    /// Small constant added to the denominator for numerical stability.
    pub epsilon: f64,
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the change in full-data loss.
    pub tol: f64,
    /// Learning rate schedule applied on top of `learning_rate`.
    pub lr_schedule: LearningRateSchedule,
    /// Optional gradient clipping threshold.
    pub gradient_clip: Option<f64>,
    /// Mini-batch size; defaults to `min(32, num_samples / 10)` when `None`.
    pub batch_size: Option<usize>,
    /// Whether to use the AMSGrad variant.
    pub amsgrad: bool,
}

impl Default for AdamOptions {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            max_iter: 1000,
            tol: 1e-6,
            lr_schedule: LearningRateSchedule::Constant,
            gradient_clip: None,
            batch_size: None,
            amsgrad: false,
        }
    }
}

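/// Minimize a stochastic objective with Adam (adaptive moment estimation),
/// optionally using the AMSGrad variant, with periodic full-data convergence checks.
///
/// A minimal usage sketch, mirroring the unit tests at the bottom of this file;
/// `QuadraticFunction` stands in for any user-defined [`StochasticGradientFunction`]
/// and `InMemoryDataProvider` for any [`DataProvider`]:
///
/// ```ignore
/// let x0 = Array1::from_vec(vec![1.0, 2.0, -1.5]);
/// let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));
/// let options = AdamOptions {
///     learning_rate: 0.1,
///     max_iter: 200,
///     ..Default::default()
/// };
/// let result = minimize_adam(QuadraticFunction, x0, data_provider, options)?;
/// println!("minimum near {:?}", result.x);
/// ```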
#[allow(dead_code)]
pub fn minimize_adam<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: AdamOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    // Default to min(32, num_samples / 10), but never allow an empty batch.
    let batch_size = options.batch_size.unwrap_or(32.min((num_samples / 10).max(1)));
    let actual_batch_size = batch_size.min(num_samples);

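    // Adam state: first-moment estimate `m`, second-moment estimate `v`, and
    // the running maximum of the bias-corrected second moment (`v_hat_max`),
    // which is only consulted when `options.amsgrad` is enabled.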
    let mut m: Array1<f64> = Array1::zeros(x.len());
    let mut v: Array1<f64> = Array1::zeros(x.len());
    let mut v_hat_max: Array1<f64> = Array1::zeros(x.len());
    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    let mut prev_loss = f64::INFINITY;
    let mut stagnant_iterations = 0;

    println!("Starting ADAM optimization:");
    println!(" Parameters: {}", x.len());
    println!(" Dataset size: {}", num_samples);
    println!(" Batch size: {}", actual_batch_size);
    println!(" Initial learning rate: {}", options.learning_rate);
    println!(" Beta1: {}, Beta2: {}", options.beta1, options.beta2);
    println!(" AMSGrad: {}", options.amsgrad);

    #[allow(clippy::explicit_counter_loop)]
    for iteration in 0..options.max_iter {
        let current_lr = update_learning_rate(
            options.learning_rate,
            iteration,
            options.max_iter,
            &options.lr_schedule,
        );

        let batch_indices = if actual_batch_size < num_samples {
            generate_batch_indices(num_samples, actual_batch_size, true)
        } else {
            (0..num_samples).collect()
        };

        let batch_data = data_provider.get_batch(&batch_indices);

        let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
        _grad_evals += 1;

        if let Some(clip_threshold) = options.gradient_clip {
            clip_gradients(&mut gradient, clip_threshold);
        }

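        // Adam moment updates with bias correction (Kingma & Ba, 2015):
        //   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
        //   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
        //   m_hat = m_t / (1 - beta1^t),   v_hat = v_t / (1 - beta2^t)
        //   x_t = x_{t-1} - lr * m_hat / (sqrt(v_hat) + epsilon)
        // AMSGrad replaces v_hat with its running element-wise maximum.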
        m = &m * options.beta1 + &gradient * (1.0 - options.beta1);

        let gradient_sq = gradient.mapv(|g| g * g);
        v = &v * options.beta2 + &gradient_sq * (1.0 - options.beta2);

        let bias_correction1 = 1.0 - options.beta1.powi((iteration + 1) as i32);
        let m_hat = &m / bias_correction1;

        let bias_correction2 = 1.0 - options.beta2.powi((iteration + 1) as i32);
        let v_hat = &v / bias_correction2;

        let v_final = if options.amsgrad {
            for i in 0..v_hat_max.len() {
                v_hat_max[i] = v_hat_max[i].max(v_hat[i]);
            }
            &v_hat_max
        } else {
            &v_hat
        };

        let denominator = v_final.mapv(|v| v.sqrt() + options.epsilon);
        let update = &m_hat / &denominator * current_lr;
        x = &x - &update;

        if iteration % 10 == 0 || iteration == options.max_iter - 1 {
            let full_data = data_provider.get_full_data();
            let current_loss = grad_func.compute_value(&x.view(), &full_data);
            func_evals += 1;

            if current_loss < best_f {
                best_f = current_loss;
                best_x = x.clone();
                stagnant_iterations = 0;
            } else {
                stagnant_iterations += 1;
            }

            if iteration % 100 == 0 {
                let grad_norm = gradient.mapv(|g| g * g).sum().sqrt();
                let m_norm = m_hat.mapv(|g: f64| g * g).sum().sqrt();
                let v_mean = v_final.mean().unwrap_or(0.0);
                println!(" Iteration {}: loss = {:.6e}, |grad| = {:.3e}, |m| = {:.3e}, <v> = {:.3e}, lr = {:.3e}",
                    iteration, current_loss, grad_norm, m_norm, v_mean, current_lr);
            }

            let loss_change = (prev_loss - current_loss).abs();
            if loss_change < options.tol {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: true,
                    message: format!(
                        "ADAM converged: loss change {:.2e} < {:.2e}",
                        loss_change, options.tol
                    ),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }

            prev_loss = current_loss;

            if stagnant_iterations > 50 {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: false,
                    message: "ADAM stopped due to stagnation".to_string(),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }
        }
    }

    let full_data = data_provider.get_full_data();
    let final_loss = grad_func.compute_value(&best_x.view(), &full_data);
    func_evals += 1;

    Ok(OptimizeResult {
        x: best_x,
        fun: final_loss.min(best_f),
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "ADAM reached maximum iterations".to_string(),
        jacobian: None,
        hessian: None,
    })
}

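/// Adam with a linear learning-rate warmup: for the first `warmup_steps`
/// iterations the scheduled learning rate is scaled by `epoch / warmup_steps`,
/// after which the configured `lr_schedule` applies unchanged.
///
/// A minimal sketch, mirroring the unit tests below:
///
/// ```ignore
/// let result = minimize_adam_with_warmup(QuadraticFunction, x0, data_provider, options, 10)?;
/// ```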
#[allow(dead_code)]
pub fn minimize_adam_with_warmup<F>(
    grad_func: F,
    x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: AdamOptions,
    warmup_steps: usize,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let original_lr = options.learning_rate;

    let warmup_schedule =
        move |epoch: usize, max_epochs: usize, base_schedule: &LearningRateSchedule| -> f64 {
            let base_lr = update_learning_rate(original_lr, epoch, max_epochs, base_schedule);

            if epoch < warmup_steps {
                base_lr * (epoch as f64 / warmup_steps as f64)
            } else {
                base_lr
            }
        };

    minimize_adam_with_custom_schedule(grad_func, x, data_provider, options, warmup_schedule)
}

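/// Core Adam loop driven by a caller-supplied learning-rate scheduler with
/// signature `(iteration, max_iter, &base_schedule) -> lr`. For example, a
/// hypothetical linear-decay scheduler could be passed as a closure:
///
/// ```ignore
/// let scheduler = |iter: usize, max_iter: usize, _s: &LearningRateSchedule| -> f64 {
///     0.01 * (1.0 - iter as f64 / max_iter as f64)
/// };
/// ```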
#[allow(dead_code)]
fn minimize_adam_with_custom_schedule<F, S>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: AdamOptions,
    lr_scheduler: S,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
    S: Fn(usize, usize, &LearningRateSchedule) -> f64,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    // Default to min(32, num_samples / 10), but never allow an empty batch.
    let batch_size = options.batch_size.unwrap_or(32.min((num_samples / 10).max(1)));
    let actual_batch_size = batch_size.min(num_samples);

    let mut m: Array1<f64> = Array1::zeros(x.len());
    let mut v: Array1<f64> = Array1::zeros(x.len());
    let mut v_hat_max: Array1<f64> = Array1::zeros(x.len());

    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    #[allow(clippy::explicit_counter_loop)]
    for iteration in 0..options.max_iter {
        let current_lr = lr_scheduler(iteration, options.max_iter, &options.lr_schedule);

        let batch_indices = if actual_batch_size < num_samples {
            generate_batch_indices(num_samples, actual_batch_size, true)
        } else {
            (0..num_samples).collect()
        };

        let batch_data = data_provider.get_batch(&batch_indices);
        let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
        _grad_evals += 1;

        if let Some(clip_threshold) = options.gradient_clip {
            clip_gradients(&mut gradient, clip_threshold);
        }

        m = &m * options.beta1 + &gradient * (1.0 - options.beta1);
        let gradient_sq = gradient.mapv(|g| g * g);
        v = &v * options.beta2 + &gradient_sq * (1.0 - options.beta2);

        let bias_correction1 = 1.0 - options.beta1.powi((iteration + 1) as i32);
        let bias_correction2 = 1.0 - options.beta2.powi((iteration + 1) as i32);
        let m_hat = &m / bias_correction1;
        let v_hat = &v / bias_correction2;

        let v_final = if options.amsgrad {
            for i in 0..v_hat_max.len() {
                v_hat_max[i] = v_hat_max[i].max(v_hat[i]);
            }
            &v_hat_max
        } else {
            &v_hat
        };

        let denominator = v_final.mapv(|v| v.sqrt() + options.epsilon);
        let update = &m_hat / &denominator * current_lr;
        x = &x - &update;

        if iteration % 10 == 0 || iteration == options.max_iter - 1 {
            let full_data = data_provider.get_full_data();
            let current_loss = grad_func.compute_value(&x.view(), &full_data);
            func_evals += 1;

            if current_loss < best_f {
                best_f = current_loss;
                best_x = x.clone();
            }

            if iteration % 100 == 0 {
                println!(
                    " Iteration {}: loss = {:.6e}, lr = {:.3e} (custom schedule)",
                    iteration, current_loss, current_lr
                );
            }
        }
    }

    Ok(OptimizeResult {
        x: best_x,
        fun: best_f,
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "ADAM with custom schedule completed".to_string(),
        jacobian: None,
        hessian: None,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::stochastic::InMemoryDataProvider;
    use approx::assert_abs_diff_eq;
    use ndarray::ArrayView1;

    /// Simple test objective f(x) = sum_i x_i^2 with gradient 2 * x.
    struct QuadraticFunction;

    impl StochasticGradientFunction for QuadraticFunction {
        fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> Array1<f64> {
            x.mapv(|xi| 2.0 * xi)
        }

        fn compute_value(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> f64 {
            x.mapv(|xi| xi * xi).sum()
        }
    }

    #[test]
    fn test_adam_quadratic() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, 2.0, -1.5]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = AdamOptions {
            learning_rate: 0.1,
            max_iter: 200,
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_adam(grad_func, x0, data_provider, options).unwrap();

        assert!(result.success || result.fun < 1e-4);
        for &xi in result.x.iter() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-2);
        }
    }

    #[test]
    fn test_adam_amsgrad() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, -1.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

        let options = AdamOptions {
            learning_rate: 0.1,
            max_iter: 100,
            batch_size: Some(10),
            amsgrad: true,
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_adam(grad_func, x0, data_provider, options).unwrap();

        assert!(result.success || result.fun < 1e-4);
    }

    #[test]
    fn test_adam_with_warmup() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![2.0, -2.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = AdamOptions {
            learning_rate: 0.1,
            max_iter: 100,
            batch_size: Some(20),
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_adam_with_warmup(grad_func, x0, data_provider, options, 10).unwrap();

        assert!(result.success || result.fun < 1e-3);
    }

    #[test]
    fn test_adam_gradient_clipping() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![10.0, -10.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

        let options = AdamOptions {
            learning_rate: 0.1,
            max_iter: 1000,
            gradient_clip: Some(1.0),
            tol: 1e-4,
            ..Default::default()
        };

        let result = minimize_adam(grad_func, x0, data_provider, options).unwrap();

        assert!(result.success || result.fun < 1e-1);
    }
}