use crate::error::OptimizeError;
use crate::stochastic::{
    clip_gradients, generate_batch_indices, update_learning_rate, DataProvider,
    LearningRateSchedule, StochasticGradientFunction,
};
use crate::unconstrained::result::OptimizeResult;
use scirs2_core::ndarray::Array1;

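/// Options for the AdamW optimizer (Adam with decoupled weight decay,
/// Loshchilov & Hutter, "Decoupled Weight Decay Regularization", 2019).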
#[derive(Debug, Clone)]
pub struct AdamWOptions {
    /// Step size for parameter updates
    pub learning_rate: f64,
    /// Exponential decay rate for the first moment estimate
    pub beta1: f64,
    /// Exponential decay rate for the second moment estimate
    pub beta2: f64,
    /// Small constant added to the denominator for numerical stability
    pub epsilon: f64,
    /// Weight decay coefficient (lambda)
    pub weight_decay: f64,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance on the change in full-data loss
    pub tol: f64,
    /// Learning rate schedule
    pub lr_schedule: LearningRateSchedule,
    /// Optional gradient clipping threshold
    pub gradient_clip: Option<f64>,
    /// Mini-batch size; if `None`, a default based on the dataset size is used (at most 32)
    pub batch_size: Option<usize>,
    /// If true, decouple weight decay from the gradient update (AdamW); if false, fold it into the gradient (classic L2 Adam)
    pub decouple_weight_decay: bool,
}

impl Default for AdamWOptions {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            weight_decay: 0.01,
            max_iter: 1000,
            tol: 1e-6,
            lr_schedule: LearningRateSchedule::Constant,
            gradient_clip: None,
            batch_size: None,
            decouple_weight_decay: true,
        }
    }
}

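/// Minimize a stochastic objective with AdamW.
///
/// Runs mini-batch AdamW updates with bias-corrected first and second moment
/// estimates and (by default) decoupled weight decay. The full-data loss is
/// evaluated every 10 iterations for the convergence and stagnation checks.
///
/// Illustrative usage (a minimal sketch, not compiled as a doctest;
/// `QuadraticFunction` and `x0` stand in for a user-defined
/// `StochasticGradientFunction` and starting point, as in the tests at the
/// bottom of this file):
///
/// ```ignore
/// let options = AdamWOptions {
///     learning_rate: 0.1,
///     max_iter: 200,
///     ..Default::default()
/// };
/// let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));
/// let result = minimize_adamw(QuadraticFunction, x0, data_provider, options)?;
/// println!("best loss: {:.3e}", result.fun);
/// ```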
#[allow(dead_code)]
pub fn minimize_adamw<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: AdamWOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    // Determine the mini-batch size (at least 1, at most the dataset size)
    let num_samples = data_provider.num_samples();
    let batch_size = options.batch_size.unwrap_or(32.min(num_samples / 10).max(1));
    let actual_batch_size = batch_size.min(num_samples);

    // First and second moment estimate vectors
    let mut m: Array1<f64> = Array1::zeros(x.len());
    let mut v: Array1<f64> = Array1::zeros(x.len());

    // Track the best parameters and loss seen so far
    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    // Convergence and stagnation tracking
    let mut prev_loss = f64::INFINITY;
    let mut stagnant_iterations = 0;

    println!("Starting AdamW optimization:");
    println!(" Parameters: {}", x.len());
    println!(" Dataset size: {}", num_samples);
    println!(" Batch size: {}", actual_batch_size);
    println!(" Initial learning rate: {}", options.learning_rate);
    println!(" Beta1: {}, Beta2: {}", options.beta1, options.beta2);
    println!(" Weight decay: {}", options.weight_decay);
    println!(" Decoupled: {}", options.decouple_weight_decay);

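    // Per-iteration AdamW update (decoupled weight decay):
    //   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    //   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    //   m_hat = m_t / (1 - beta1^t),  v_hat = v_t / (1 - beta2^t)
    //   x_t = (1 - lr * lambda) * x_{t-1} - lr * m_hat / (sqrt(v_hat) + eps)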
    #[allow(clippy::explicit_counter_loop)]
    for iteration in 0..options.max_iter {
        // Update the learning rate according to the configured schedule
        let current_lr = update_learning_rate(
            options.learning_rate,
            iteration,
            options.max_iter,
            &options.lr_schedule,
        );

        // Sample a mini-batch of indices, or use the full dataset when the batch covers it
        let batch_indices = if actual_batch_size < num_samples {
            generate_batch_indices(num_samples, actual_batch_size, true)
        } else {
            (0..num_samples).collect()
        };

        let batch_data = data_provider.get_batch(&batch_indices);

        // Compute the stochastic gradient on the mini-batch
        let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
        _grad_evals += 1;

        // Optionally clip the gradient
        if let Some(clip_threshold) = options.gradient_clip {
            clip_gradients(&mut gradient, clip_threshold);
        }

        // Weight decay: decoupled (applied directly to the parameters, AdamW)
        // or coupled (folded into the gradient, classic L2-regularized Adam)
        if options.decouple_weight_decay && options.weight_decay > 0.0 {
            x = &x * (1.0 - current_lr * options.weight_decay);
        } else if options.weight_decay > 0.0 {
            gradient = &gradient + &x * options.weight_decay;
        }

        // Update biased first and second moment estimates
        m = &m * options.beta1 + &gradient * (1.0 - options.beta1);

        let gradient_sq = gradient.mapv(|g| g * g);
        v = &v * options.beta2 + &gradient_sq * (1.0 - options.beta2);

        // Bias-correct the moment estimates (t = iteration + 1)
        let bias_correction1 = 1.0 - options.beta1.powi((iteration + 1) as i32);
        let m_hat = &m / bias_correction1;

        let bias_correction2 = 1.0 - options.beta2.powi((iteration + 1) as i32);
        let v_hat = &v / bias_correction2;

        // Parameter update: x <- x - lr * m_hat / (sqrt(v_hat) + eps)
        let denominator = v_hat.mapv(|v: f64| v.sqrt() + options.epsilon);
        let gradient_update = &m_hat / &denominator * current_lr;
        x = &x - &gradient_update;

        // Periodically evaluate the loss on the full dataset
        if iteration % 10 == 0 || iteration == options.max_iter - 1 {
            let full_data = data_provider.get_full_data();
            let current_loss = grad_func.compute_value(&x.view(), &full_data);
            func_evals += 1;

            // Keep the best parameters found so far
            if current_loss < best_f {
                best_f = current_loss;
                best_x = x.clone();
                stagnant_iterations = 0;
            } else {
                stagnant_iterations += 1;
            }

            // Detailed progress report every 100 iterations
            if iteration % 100 == 0 {
                let grad_norm = gradient.mapv(|g| g * g).sum().sqrt();
                let param_norm = x.mapv(|p| p * p).sum().sqrt();
                let m_norm = m_hat.mapv(|g: f64| g * g).sum().sqrt();
                let v_mean = v_hat.mean().unwrap_or(0.0);
                println!(" Iteration {}: loss = {:.6e}, |grad| = {:.3e}, |param| = {:.3e}, |m| = {:.3e}, <v> = {:.3e}, lr = {:.3e}",
                    iteration, current_loss, grad_norm, param_norm, m_norm, v_mean, current_lr);
            }

            // Check convergence on the change in full-data loss
            let loss_change = (prev_loss - current_loss).abs();
            if loss_change < options.tol {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: true,
                    message: format!(
                        "AdamW converged: loss change {:.2e} < {:.2e}",
                        loss_change, options.tol
                    ),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }

            prev_loss = current_loss;

            // Stop early if the full-data loss has not improved for many checks
            if stagnant_iterations > 50 {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: false,
                    message: "AdamW stopped due to stagnation".to_string(),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }
        }
    }

    // Final evaluation on the full dataset
    let full_data = data_provider.get_full_data();
    let final_loss = grad_func.compute_value(&best_x.view(), &full_data);
    func_evals += 1;

    Ok(OptimizeResult {
        x: best_x,
        fun: final_loss.min(best_f),
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "AdamW reached maximum iterations".to_string(),
        jacobian: None,
        hessian: None,
    })
}

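/// Minimize a stochastic objective with AdamW using cosine annealing with warm
/// restarts (SGDR-style schedule).
///
/// The optimizer runs in cycles: within each cycle the learning rate is annealed
/// from `options.learning_rate` down to `eta_min` along a cosine curve, and the
/// schedule restarts at the end of the cycle. The first cycle lasts `t_initial`
/// iterations and each subsequent cycle is `t_mult` times longer. The best
/// parameters found across all cycles are returned.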
#[allow(dead_code)]
pub fn minimize_adamw_cosine_restarts<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: AdamWOptions,
    t_initial: usize,
    t_mult: f64,
    eta_min: f64,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    // Warm-restart bookkeeping: the cycle length grows by `t_mult` after each restart
    let mut current_cycle_length = t_initial;
    let mut cycle_start = 0;
    let mut restart_count = 0;
    let initial_lr = options.learning_rate;
    let total_max_iter = options.max_iter;

    // Best parameters seen across all cycles
    let mut global_best_x = x.clone();
    let mut global_best_f = f64::INFINITY;

    while cycle_start < total_max_iter {
        let cycle_end = (cycle_start + current_cycle_length).min(total_max_iter);

        println!(
            "Starting restart {} (cycle {}-{}, length {})",
            restart_count, cycle_start, cycle_end, current_cycle_length
        );

        // Configure this cycle: cosine-annealed learning rate over the cycle length
        let mut cycle_options = options.clone();
        cycle_options.lr_schedule = LearningRateSchedule::CosineAnnealing;
        cycle_options.max_iter = cycle_end - cycle_start;
        cycle_options.learning_rate = initial_lr;

        // Run AdamW for one cycle, starting from the current parameters
        let cycle_result = minimize_adamw_cycle(
            &mut grad_func,
            x.clone(),
            data_provider.as_ref(),
            &cycle_options,
            initial_lr,
            eta_min,
            cycle_start,
        )?;

        // Keep the best parameters found across cycles
        if cycle_result.fun < global_best_f {
            global_best_f = cycle_result.fun;
            global_best_x = cycle_result.x.clone();
        }

        // Continue from the end of this cycle; the next cycle is `t_mult` times longer
        x = cycle_result.x;
        cycle_start = cycle_end;
        current_cycle_length = (current_cycle_length as f64 * t_mult) as usize;
        restart_count += 1;

        // Early exit once the best loss is below the tolerance
        if global_best_f < options.tol {
            break;
        }
    }

    Ok(OptimizeResult {
        x: global_best_x,
        fun: global_best_f,
        nit: cycle_start,
        // Function evaluations are not tracked across restart cycles
        func_evals: 0,
        nfev: 0,
        success: global_best_f < options.tol,
        message: format!(
            "AdamW with cosine restarts completed ({} restarts)",
            restart_count
        ),
        jacobian: None,
        hessian: None,
    })
}

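/// Run a single AdamW cycle with a cosine-annealed learning rate.
///
/// Helper for [`minimize_adamw_cosine_restarts`]: the learning rate decays from
/// `lr_max` to `lr_min` over the cycle, and `cycle_offset` keeps the
/// bias-correction step count consistent across cycles.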
#[allow(dead_code)]
fn minimize_adamw_cycle<F>(
    grad_func: &mut F,
    mut x: Array1<f64>,
    data_provider: &dyn DataProvider,
    options: &AdamWOptions,
    lr_max: f64,
    lr_min: f64,
    cycle_offset: usize,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut m: Array1<f64> = Array1::zeros(x.len());
    let mut v: Array1<f64> = Array1::zeros(x.len());
    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    let num_samples = data_provider.num_samples();
    let batch_size = options.batch_size.unwrap_or(32.min(num_samples / 10).max(1));
    let actual_batch_size = batch_size.min(num_samples);

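    // Cosine-annealed learning rate within the cycle:
    //   lr(t) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / T)),
    // with t = iteration and T = options.max_iter for this cycle.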
    #[allow(clippy::explicit_counter_loop)]
    for iteration in 0..options.max_iter {
        let progress = iteration as f64 / options.max_iter as f64;
        let current_lr =
            lr_min + 0.5 * (lr_max - lr_min) * (1.0 + (std::f64::consts::PI * progress).cos());

        let batch_indices = if actual_batch_size < num_samples {
            generate_batch_indices(num_samples, actual_batch_size, true)
        } else {
            (0..num_samples).collect()
        };

        let batch_data = data_provider.get_batch(&batch_indices);
        let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);

        if let Some(clip_threshold) = options.gradient_clip {
            clip_gradients(&mut gradient, clip_threshold);
        }

        // Weight decay, mirroring `minimize_adamw`: decoupled (AdamW) or coupled (L2)
        if options.decouple_weight_decay && options.weight_decay > 0.0 {
            x = &x * (1.0 - current_lr * options.weight_decay);
        } else if options.weight_decay > 0.0 {
            gradient = &gradient + &x * options.weight_decay;
        }

        // Moment updates with bias correction based on the global step,
        // so corrections remain consistent across restart cycles
        m = &m * options.beta1 + &gradient * (1.0 - options.beta1);
        let gradient_sq = gradient.mapv(|g| g * g);
        v = &v * options.beta2 + &gradient_sq * (1.0 - options.beta2);

        let global_step = cycle_offset + iteration + 1;
        let bias_correction1 = 1.0 - options.beta1.powi(global_step as i32);
        let bias_correction2 = 1.0 - options.beta2.powi(global_step as i32);

        let m_hat = &m / bias_correction1;
        let v_hat = &v / bias_correction2;

        let denominator = v_hat.mapv(|v: f64| v.sqrt() + options.epsilon);
        let update = &m_hat / &denominator * current_lr;
        x = &x - &update;

        // Periodically track the best parameters on the full dataset
        if iteration % 10 == 0 {
            let full_data = data_provider.get_full_data();
            let current_loss = grad_func.compute_value(&x.view(), &full_data);

            if current_loss < best_f {
                best_f = current_loss;
                best_x = x.clone();
            }
        }
    }

    Ok(OptimizeResult {
        x: best_x,
        fun: best_f,
        nit: options.max_iter,
        func_evals: 0,
        nfev: 0,
        success: false,
        message: "Cycle completed".to_string(),
        jacobian: None,
        hessian: None,
    })
}
412
413#[cfg(test)]
414mod tests {
415 use super::*;
416 use crate::stochastic::InMemoryDataProvider;
417 use approx::assert_abs_diff_eq;
418 use scirs2_core::ndarray::ArrayView1;
419
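    /// Simple quadratic objective f(x) = sum(x_i^2) with gradient 2*x; its
    /// unique minimum at the origin makes convergence easy to verify.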
    struct QuadraticFunction;

    impl StochasticGradientFunction for QuadraticFunction {
        fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> Array1<f64> {
            x.mapv(|xi| 2.0 * xi)
        }

        fn compute_value(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> f64 {
            x.mapv(|xi| xi * xi).sum()
        }
    }

    #[test]
    fn test_adamw_quadratic() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, 2.0, -1.5]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = AdamWOptions {
            learning_rate: 0.1,
            max_iter: 200,
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_adamw(grad_func, x0, data_provider, options).unwrap();

        assert!(result.success || result.fun < 1e-4);
        for &xi in result.x.iter() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-2);
        }
    }

    #[test]
    fn test_adamw_weight_decay() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, -1.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

        let options = AdamWOptions {
            learning_rate: 0.1,
            weight_decay: 0.01,
            max_iter: 100,
            batch_size: Some(10),
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_adamw(grad_func, x0, data_provider, options).unwrap();

        assert!(result.success || result.fun < 1e-4);
    }

    #[test]
    fn test_adamw_decoupled_vs_coupled() {
        let x0 = Array1::from_vec(vec![2.0, -2.0]);
        let data_provider1 = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));
        let data_provider2 = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options_decoupled = AdamWOptions {
            learning_rate: 0.01,
            weight_decay: 0.1,
            decouple_weight_decay: true,
            max_iter: 500,
            tol: 1e-4,
            ..Default::default()
        };

        let grad_func1 = QuadraticFunction;
        let result_decoupled =
            minimize_adamw(grad_func1, x0.clone(), data_provider1, options_decoupled).unwrap();

        let options_coupled = AdamWOptions {
            learning_rate: 0.01,
            weight_decay: 0.1,
            decouple_weight_decay: false,
            max_iter: 500,
            tol: 1e-4,
            ..Default::default()
        };

        let grad_func2 = QuadraticFunction;
        let result_coupled =
            minimize_adamw(grad_func2, x0, data_provider2, options_coupled).unwrap();

        assert!(result_decoupled.fun < 1.0);
        assert!(result_coupled.fun < 1.0);
    }

    #[test]
    fn test_adamw_cosine_restarts() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![3.0, -3.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = AdamWOptions {
            learning_rate: 0.1,
            max_iter: 500,
            tol: 1e-4,
            ..Default::default()
        };

        let result = minimize_adamw_cosine_restarts(
            grad_func,
            x0,
            data_provider,
            options,
            50,   // t_initial
            1.5,  // t_mult
            1e-6, // eta_min
        )
        .unwrap();

        assert!(result.fun < 10.0);
    }

    #[test]
    fn test_adamw_gradient_clipping() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![10.0, -10.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

        let options = AdamWOptions {
            learning_rate: 0.1,
            max_iter: 1000,
            gradient_clip: Some(1.0),
            tol: 1e-4,
            ..Default::default()
        };

        let result = minimize_adamw(grad_func, x0, data_provider, options).unwrap();

        assert!(result.success || result.fun < 1e-1);
    }
}