train_station/optimizers/adam/mod.rs
//! Adam optimizer implementation for neural network training
//!
//! This module provides the Adam optimization algorithm with PyTorch-compatible interface.
//! Adam combines the benefits of AdaGrad and RMSprop by using adaptive learning rates
//! with momentum for efficient training of neural networks.
//!
//! # Features
//!
//! - **Adaptive Learning Rates**: Per-parameter learning rates based on gradient history
//! - **Momentum Integration**: Combines momentum with adaptive learning rates
//! - **Bias Correction**: Corrects for initialization bias in moment estimates
//! - **Weight Decay**: Optional L2 regularization support
//! - **AMSGrad Variant**: Optional AMSGrad for improved convergence stability
//! - **SIMD Optimization**: AVX2-optimized parameter updates for maximum performance
//! - **Thread Safety**: Send + Sync implementation for multi-threaded training
//! - **Hybrid API**: Both safe (RwLock-based) and unsafe (direct pointer) access patterns
//!
//! # Thread Safety
//!
//! The optimizer provides two usage patterns:
//!
//! **Safe Multi-threaded Usage (Default)**:
//! - Uses `Arc<RwLock<Tensor>>` for thread-safe parameter access
//! - Multiple threads can read tensors simultaneously
//! - Optimizer steps acquire write locks only during parameter updates
//! - Recommended for most use cases
//!
//! **Unsafe Single-threaded Usage (Performance)**:
//! - Uses raw pointers for maximum performance
//! - No locking overhead during optimizer steps
//! - Caller must ensure exclusive access during optimization
//! - Use only when you can guarantee no concurrent tensor access
//!
//! # Algorithm
//!
//! Adam implements the following update rule for each parameter theta:
//!
//! ```text
//! m_t = beta1 * m_{t-1} + (1 - beta1) * grad_theta_t
//! v_t = beta2 * v_{t-1} + (1 - beta2) * (grad_theta_t)^2
//! m_hat_t = m_t / (1 - beta1^t)
//! v_hat_t = v_t / (1 - beta2^t)
//! theta_{t+1} = theta_t - lr * m_hat_t / (sqrt(v_hat_t) + eps)
//! ```
//!
//! Where:
//! - lr is the learning rate
//! - beta1, beta2 are exponential decay rates for moment estimates
//! - eps is a small constant for numerical stability
//! - m_t, v_t are biased first and second moment estimates
//! - m_hat_t, v_hat_t are bias-corrected moment estimates
//!
//! # Performance Characteristics
//!
//! - **SIMD Optimization**: Uses AVX2 instructions for 8x vectorization when available
//! - **Memory Efficiency**: In-place updates with minimal temporary allocations
//! - **Cache-Friendly**: Optimized memory access patterns for large parameter tensors
//! - **Zero-Cost Abstractions**: Compile-time optimization with minimal runtime overhead
//! - **Lock-Free Reads**: RwLock allows concurrent tensor reads during training
//!
//! # References
//!
//! - Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization.
//! - PyTorch Adam implementation: <https://pytorch.org/docs/stable/generated/torch.optim.Adam.html>
65
66pub mod serialization;
67
68use super::Optimizer;
69use crate::tensor::core::Tensor;
70use std::collections::HashMap;
71
72// SIMD optimizations for performance-critical operations
73#[cfg(target_arch = "x86_64")]
74use std::arch::x86_64::*;
75
/// Configuration for the Adam optimization algorithm
///
/// Contains all hyperparameters that control the behavior of Adam optimization.
/// Default values follow PyTorch conventions for maximum compatibility and
/// optimal convergence across a wide range of neural network architectures.
///
/// # Fields
///
/// * `learning_rate` - Step size for parameter updates (default: 1e-3)
/// * `beta1` - Exponential decay rate for first moment estimates (default: 0.9)
/// * `beta2` - Exponential decay rate for second moment estimates (default: 0.999)
/// * `eps` - Small constant for numerical stability (default: 1e-8)
/// * `weight_decay` - L2 regularization coefficient (default: 0.0)
/// * `amsgrad` - Whether to use AMSGrad variant for improved stability (default: false)
#[derive(Debug, Clone)]
pub struct AdamConfig {
    /// Learning rate for parameter updates (default: 1e-3)
    pub learning_rate: f32,
    /// Exponential decay rate for first moment estimates (default: 0.9)
    pub beta1: f32,
    /// Exponential decay rate for second moment estimates (default: 0.999)
    pub beta2: f32,
    /// Small constant for numerical stability, added to sqrt(v_hat) in the
    /// update denominator (default: 1e-8)
    pub eps: f32,
    /// Weight decay coefficient for L2 regularization; 0.0 disables it (default: 0.0)
    pub weight_decay: f32,
    /// Whether to use AMSGrad variant (default: false)
    pub amsgrad: bool,
}
105
106impl Default for AdamConfig {
107 fn default() -> Self {
108 Self {
109 learning_rate: 1e-3,
110 beta1: 0.9,
111 beta2: 0.999,
112 eps: 1e-8,
113 weight_decay: 0.0,
114 amsgrad: false,
115 }
116 }
117}
118
/// Internal state tracking for a single parameter during Adam optimization
///
/// Stores the momentum and velocity buffers needed for Adam optimization,
/// along with step count for bias correction and optional AMSGrad state.
/// Memory layout is optimized for cache efficiency and SIMD operations.
///
/// # Fields
///
/// * `m` - First moment estimate (momentum buffer)
/// * `v` - Second moment estimate (velocity buffer)
/// * `v_hat_max` - Maximum of second moment estimates (for AMSGrad variant)
/// * `step` - Step count for bias correction calculations
#[derive(Debug)]
struct ParameterState {
    /// First moment estimate (momentum); same shape as the parameter
    m: Tensor,
    /// Second moment estimate (velocity); same shape as the parameter
    v: Tensor,
    /// Maximum of second moment estimates (for AMSGrad); `None` until the
    /// first AMSGrad update touches this parameter
    v_hat_max: Option<Tensor>,
    /// Number of updates applied to this parameter; drives bias correction
    step: usize,
}
142
143impl ParameterState {
144 fn new(param_shape: &[usize]) -> Self {
145 Self {
146 m: Tensor::zeros(param_shape.to_vec()),
147 v: Tensor::zeros(param_shape.to_vec()),
148 v_hat_max: None,
149 step: 0,
150 }
151 }
152}
153
/// Adam optimizer for neural network parameter optimization
///
/// Implements the Adam optimization algorithm with PyTorch-compatible interface.
/// Provides adaptive learning rates with momentum for efficient training of neural networks.
/// The optimizer maintains per-parameter state for momentum and velocity estimates,
/// enabling adaptive learning rates that improve convergence across diverse architectures.
///
/// # Usage Pattern
///
/// The optimizer uses ID-based parameter linking for maximum flexibility and thread safety:
/// - Parameters are linked to the optimizer via `add_parameter` or `add_parameters`
/// - The `step` method takes mutable references to parameters for thread-safe updates
/// - Parameter states are maintained by tensor ID, allowing for dynamic parameter management
/// - Supports serialization and deserialization with parameter re-linking
///
/// # Dynamic Parameter Management
///
/// Parameters can be added, removed, or re-linked at runtime:
/// - `add_parameter`: Link a single parameter
/// - `add_parameters`: Link multiple parameters at once
/// - `unlink_parameter`: Remove parameter state by ID
/// - `clear_states`: Remove all parameter states
/// - `is_parameter_linked`: Check if a parameter is linked
///
/// # Serialization Support
///
/// The optimizer supports full serialization and deserialization with state preservation:
/// - Parameter states are saved with their shapes and insertion order for validation
/// - After deserialization, use `relink_parameters` to restore saved states to new tensors
/// - Parameters must be re-linked in the same chronological order they were originally added
/// - Shape validation ensures consistency between saved and current parameters
///
/// # Features
///
/// - **ID-Based Parameter Linking**: Dynamic parameter management via tensor IDs
/// - **Thread-Safe Step Method**: Takes mutable references for safe concurrent access
/// - **Per-Parameter State**: Each parameter maintains its own momentum and velocity buffers
/// - **Bias Correction**: Automatically corrects initialization bias in moment estimates
/// - **Weight Decay**: Optional L2 regularization with efficient implementation
/// - **AMSGrad Support**: Optional AMSGrad variant for improved convergence stability
/// - **SIMD Optimization**: AVX2-optimized updates for maximum performance
/// - **Full Serialization**: Complete state persistence and restoration
///
/// # Thread Safety
///
/// This type is thread-safe and can be shared between threads. The step method
/// takes mutable references to parameters, ensuring exclusive access during updates.
pub struct Adam {
    /// Optimizer configuration (hyperparameters)
    config: AdamConfig,
    /// Parameter states indexed by tensor ID
    states: HashMap<usize, ParameterState>,
    /// Current global step count
    /// NOTE(review): not modified by `update_parameter` (which uses the
    /// per-parameter `ParameterState::step`); presumably maintained by the
    /// `Optimizer::step` impl below — confirm against that code.
    step_count: usize,
    /// Order in which parameters were added (for serialization/re-linking)
    insertion_order: Vec<usize>,
}
211
212impl Default for Adam {
213 fn default() -> Self {
214 Self::with_config(AdamConfig::default())
215 }
216}
217
218impl Adam {
219 /// Create a new Adam optimizer with default configuration
220 ///
221 /// Initializes an Adam optimizer with PyTorch-compatible default hyperparameters.
222 /// Parameters must be linked separately using `add_parameter` or `add_parameters`.
223 ///
224 /// # Returns
225 ///
226 /// A new Adam optimizer instance with default hyperparameters
227 #[track_caller]
228 pub fn new() -> Self {
229 Self::default()
230 }
231
232 /// Create a new Adam optimizer with custom configuration
233 ///
234 /// Allows full control over all Adam hyperparameters for specialized training
235 /// scenarios such as fine-tuning, transfer learning, or research applications.
236 /// Parameters must be linked separately using `add_parameter` or `add_parameters`.
237 ///
238 /// # Arguments
239 ///
240 /// * `config` - Adam configuration with custom hyperparameters
241 ///
242 /// # Returns
243 ///
244 /// A new Adam optimizer instance with the specified configuration
245 #[track_caller]
246 pub fn with_config(config: AdamConfig) -> Self {
247 Self {
248 config,
249 states: HashMap::new(),
250 step_count: 0,
251 insertion_order: Vec::new(),
252 }
253 }
254
255 /// Create a new Adam optimizer with custom learning rate
256 ///
257 /// A convenience constructor that allows setting only the learning rate while
258 /// using default values for all other hyperparameters. Parameters must be
259 /// linked separately using `add_parameter` or `add_parameters`.
260 ///
261 /// # Arguments
262 ///
263 /// * `learning_rate` - Learning rate for optimization
264 ///
265 /// # Returns
266 ///
267 /// A new Adam optimizer instance with the specified learning rate and default
268 /// values for all other hyperparameters
269 #[track_caller]
270 pub fn with_learning_rate(learning_rate: f32) -> Self {
271 let config = AdamConfig {
272 learning_rate,
273 ..Default::default()
274 };
275 Self::with_config(config)
276 }
277
278 /// Add a single parameter to the optimizer
279 ///
280 /// Links a parameter to the optimizer by creating a new parameter state
281 /// indexed by the tensor's ID. The parameter must have `requires_grad` set to true.
282 ///
283 /// # Arguments
284 ///
285 /// * `parameter` - Reference to the tensor to link
286 ///
287 /// # Panics
288 ///
289 /// Panics if the parameter does not have `requires_grad` set to true
290 #[track_caller]
291 pub fn add_parameter(&mut self, parameter: &Tensor) {
292 assert!(
293 parameter.requires_grad(),
294 "Parameter must require gradients"
295 );
296
297 let param_id = parameter.id();
298 let param_shape = parameter.shape().dims();
299
300 // Initialize state for this parameter if not already present
301 use std::collections::hash_map::Entry;
302 if let Entry::Vacant(entry) = self.states.entry(param_id) {
303 entry.insert(ParameterState::new(param_shape));
304 self.insertion_order.push(param_id);
305 }
306 }
307
308 /// Add multiple parameters to the optimizer
309 ///
310 /// Links multiple parameters to the optimizer by creating parameter states
311 /// indexed by each tensor's ID. All parameters must have `requires_grad` set to true.
312 ///
313 /// # Arguments
314 ///
315 /// * `parameters` - Slice of references to tensors to link
316 ///
317 /// # Panics
318 ///
319 /// Panics if any parameter does not have `requires_grad` set to true
320 #[track_caller]
321 pub fn add_parameters(&mut self, parameters: &[&Tensor]) {
322 for parameter in parameters {
323 self.add_parameter(parameter);
324 }
325 }
326
327 /// Remove a parameter from the optimizer
328 ///
329 /// Unlinks a parameter by removing its state from the optimizer.
330 /// The parameter ID is used for identification.
331 ///
332 /// # Arguments
333 ///
334 /// * `parameter` - Reference to the tensor to unlink
335 ///
336 /// # Returns
337 ///
338 /// True if the parameter was linked and removed, false if it was not linked
339 #[track_caller]
340 pub fn unlink_parameter(&mut self, parameter: &Tensor) -> bool {
341 let param_id = parameter.id();
342 let was_linked = self.states.remove(¶m_id).is_some();
343 if was_linked {
344 self.insertion_order.retain(|&id| id != param_id);
345 }
346 was_linked
347 }
348
349 /// Remove all parameter states from the optimizer
350 ///
351 /// Clears all parameter states, effectively unlinking all parameters.
352 /// This is useful for resetting the optimizer or preparing for parameter re-linking.
353 #[track_caller]
354 pub fn clear_states(&mut self) {
355 self.states.clear();
356 self.insertion_order.clear();
357 }
358
359 /// Check if a parameter is linked to the optimizer
360 ///
361 /// Returns true if the parameter has an associated state in the optimizer.
362 ///
363 /// # Arguments
364 ///
365 /// * `parameter` - Reference to the tensor to check
366 ///
367 /// # Returns
368 ///
369 /// True if the parameter is linked, false otherwise
370 #[track_caller]
371 pub fn is_parameter_linked(&self, parameter: &Tensor) -> bool {
372 let param_id = parameter.id();
373 self.states.contains_key(¶m_id)
374 }
375
    /// Get the number of linked parameters
    ///
    /// Returns the count of parameters currently linked to the optimizer.
    ///
    /// # Returns
    ///
    /// Number of linked parameters
    #[track_caller]
    pub fn parameter_count(&self) -> usize {
        self.states.len()
    }
387
388 /// Re-link parameters to saved optimizer states in chronological order
389 ///
390 /// After deserializing an optimizer, use this method to restore saved parameter states
391 /// to new tensors. Parameters must be provided in the same chronological order they
392 /// were originally added to the optimizer. Shape validation ensures parameter compatibility.
393 ///
394 /// # Arguments
395 ///
396 /// * `parameters` - Slice of parameter references in chronological order
397 ///
398 /// # Returns
399 ///
400 /// Result indicating success or failure with detailed error message
401 ///
402 /// # Panics
403 ///
404 /// Panics if any parameter does not have `requires_grad` set to true
405 #[track_caller]
406 pub fn relink_parameters(&mut self, parameters: &[&Tensor]) -> Result<(), String> {
407 // Validate all parameters have requires_grad first
408 for (i, param) in parameters.iter().enumerate() {
409 if !param.requires_grad() {
410 return Err(format!("Parameter at index {} must require gradients", i));
411 }
412 }
413
414 // Check parameter count matches saved states
415 if parameters.len() != self.insertion_order.len() {
416 return Err(format!(
417 "Parameter count mismatch: expected {} parameters, got {}",
418 self.insertion_order.len(),
419 parameters.len()
420 ));
421 }
422
423 // Create new states map with parameter IDs mapped to saved states in chronological order
424 let mut new_states = HashMap::new();
425 let mut new_insertion_order = Vec::new();
426
427 for (i, param) in parameters.iter().enumerate() {
428 let new_param_id = param.id();
429 let old_param_id = self.insertion_order[i];
430
431 // Get the saved state for this position
432 let saved_state = self
433 .states
434 .get(&old_param_id)
435 .ok_or_else(|| format!("No saved state found for parameter at position {}", i))?;
436
437 // Validate shape matches
438 let param_shape = ¶m.shape().dims();
439 let saved_shape = &saved_state.m.shape().dims();
440 if param_shape != saved_shape {
441 return Err(format!(
442 "Shape mismatch for parameter at position {}: expected {:?}, got {:?}",
443 i, saved_shape, param_shape
444 ));
445 }
446
447 // Create new state for this parameter
448 let new_state = ParameterState {
449 m: saved_state.m.clone(),
450 v: saved_state.v.clone(),
451 v_hat_max: saved_state.v_hat_max.clone(),
452 step: saved_state.step,
453 };
454
455 new_states.insert(new_param_id, new_state);
456 new_insertion_order.push(new_param_id);
457 }
458
459 // Replace the states and insertion order
460 self.states = new_states;
461 self.insertion_order = new_insertion_order;
462
463 Ok(())
464 }
465
    /// Get the current optimizer configuration
    ///
    /// Returns a reference to the current configuration, allowing inspection
    /// of all hyperparameters without modification.
    ///
    /// # Returns
    ///
    /// Reference to the current Adam configuration
    #[track_caller]
    pub fn config(&self) -> &AdamConfig {
        &self.config
    }
478
479 /// Update a single parameter using Adam algorithm
480 ///
481 /// Implements the core Adam update rule with bias correction and optional AMSGrad.
482 /// Uses SIMD optimization when available for improved performance.
483 /// The parameter must be linked to the optimizer before calling this method.
484 ///
485 /// # Arguments
486 ///
487 /// * `param` - Mutable reference to the parameter tensor to update
488 ///
489 /// # Returns
490 ///
491 /// Result indicating success or failure of the parameter update
492 ///
493 /// # Panics
494 ///
495 /// Panics if the parameter is not linked to the optimizer
496 fn update_parameter(&mut self, param: &mut Tensor) -> Result<(), String> {
497 let param_id = param.id();
498
499 // Ensure parameter is linked
500 assert!(
501 self.states.contains_key(¶m_id),
502 "Parameter must be linked to optimizer before stepping. Use add_parameter() first."
503 );
504
505 // Get parameter gradient
506 let grad = param
507 .grad_owned()
508 .ok_or_else(|| format!("Parameter {} has no gradient", param_id))?;
509
510 // Get parameter state
511 let state = self
512 .states
513 .get_mut(¶m_id)
514 .expect("Parameter state should exist after link check");
515
516 // Increment per-parameter step count and use it for bias correction
517 state.step += 1;
518 let step = state.step as f32;
519
520 // Apply weight decay if enabled
521 let effective_grad = if self.config.weight_decay > 0.0 {
522 // L2 regularization: grad + weight_decay * param
523 let mut grad_with_decay = grad.clone();
524 Self::add_weight_decay(&mut grad_with_decay, param, self.config.weight_decay);
525 grad_with_decay
526 } else {
527 grad
528 };
529
530 // Update biased first moment estimate: m = beta1 * m + (1 - beta1) * grad
531 Self::update_momentum(&mut state.m, &effective_grad, self.config.beta1);
532
533 // Update biased second moment estimate: v = beta2 * v + (1 - beta2) * grad^2
534 Self::update_velocity(&mut state.v, &effective_grad, self.config.beta2);
535
536 // Compute bias-corrected first moment estimate
537 let bias_correction1 = 1.0 - (self.config.beta1).powf(step);
538 let m_hat = Self::bias_correct(&state.m, bias_correction1);
539
540 // Compute bias-corrected second moment estimate
541 let bias_correction2 = 1.0 - (self.config.beta2).powf(step);
542 let mut v_hat = Self::bias_correct(&state.v, bias_correction2);
543
544 // AMSGrad: use maximum of v_hat over time
545 if self.config.amsgrad {
546 if state.v_hat_max.is_none() {
547 state.v_hat_max = Some(v_hat.clone());
548 }
549 let v_hat_max = state.v_hat_max.as_mut().unwrap();
550 Self::element_wise_max(v_hat_max, &v_hat);
551 v_hat = v_hat_max.clone();
552 }
553
554 // Compute parameter update: param = param - lr * m_hat / (sqrt(v_hat) + eps)
555 Self::apply_adam_update(
556 param,
557 &m_hat,
558 &v_hat,
559 self.config.learning_rate,
560 self.config.eps,
561 );
562
563 Ok(())
564 }
565
    /// Apply weight decay (L2 regularization) to gradient
    ///
    /// Adds `weight_decay * param` to the gradient in-place for memory efficiency.
    /// This implements L2 regularization by modifying the gradient before the
    /// Adam update step, equivalent to adding a regularization term to the loss.
    ///
    /// # Arguments
    ///
    /// * `grad` - Gradient tensor to modify in-place
    /// * `param` - Parameter tensor for weight decay calculation
    /// * `weight_decay` - Weight decay coefficient
    ///
    /// # Panics
    ///
    /// Panics if `grad` and `param` hold a different number of elements.
    #[inline]
    fn add_weight_decay(grad: &mut Tensor, param: &Tensor, weight_decay: f32) {
        assert_eq!(
            grad.size(),
            param.size(),
            "Gradient and parameter size mismatch"
        );

        // SAFETY: the assert above guarantees both buffers hold `size`
        // elements, so every `add(i)` below stays in bounds; `grad` is
        // borrowed mutably and `param` immutably, so the pointers cannot
        // alias a write.
        unsafe {
            let grad_ptr = grad.as_mut_ptr();
            let param_ptr = param.as_ptr();
            let size = grad.size();

            #[cfg(target_arch = "x86_64")]
            {
                // AVX2 fast path: requires runtime CPU support plus the
                // library's SIMD alignment guarantee on both tensors.
                if is_x86_feature_detected!("avx2")
                    && grad.is_simd_aligned()
                    && param.is_simd_aligned()
                {
                    Self::add_weight_decay_simd_avx2(grad_ptr, param_ptr, weight_decay, size);
                    return;
                }
            }

            // Scalar fallback: grad[i] += weight_decay * param[i]
            for i in 0..size {
                *grad_ptr.add(i) += weight_decay * *param_ptr.add(i);
            }
        }
    }
607
    /// AVX2 kernel for `add_weight_decay`: processes 8 `f32` lanes per
    /// iteration with unaligned loads/stores, then finishes the `size % 8`
    /// tail scalarly.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available on the running CPU and that
    /// `grad_ptr` (writable) and `param_ptr` (readable) are valid for `size`
    /// contiguous `f32` elements.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn add_weight_decay_simd_avx2(
        grad_ptr: *mut f32,
        param_ptr: *const f32,
        weight_decay: f32,
        size: usize,
    ) {
        // Broadcast the scalar coefficient across all 8 lanes.
        let decay_vec = _mm256_set1_ps(weight_decay);
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));
            let param_vec = _mm256_loadu_ps(param_ptr.add(offset));
            // grad += weight_decay * param
            let decay_term = _mm256_mul_ps(decay_vec, param_vec);
            let result = _mm256_add_ps(grad_vec, decay_term);
            _mm256_storeu_ps(grad_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            *grad_ptr.add(i) += weight_decay * *param_ptr.add(i);
        }
    }
633
    /// Update momentum (first moment estimate)
    ///
    /// Implements the momentum update rule: `m = beta1 * m + (1 - beta1) * grad`
    /// This computes the exponentially decaying average of gradients, providing
    /// momentum-like behavior that helps accelerate convergence in relevant directions.
    ///
    /// # Arguments
    ///
    /// * `momentum` - Momentum buffer to update in-place
    /// * `grad` - Current gradient tensor
    /// * `beta1` - Exponential decay rate for momentum
    ///
    /// # Panics
    ///
    /// Panics if `momentum` and `grad` hold a different number of elements.
    #[inline]
    fn update_momentum(momentum: &mut Tensor, grad: &Tensor, beta1: f32) {
        assert_eq!(
            momentum.size(),
            grad.size(),
            "Momentum and gradient size mismatch"
        );

        // Precompute (1 - beta1); it scales the incoming gradient.
        let beta1_complement = 1.0 - beta1;

        // SAFETY: the assert above guarantees both buffers hold `size`
        // elements, so every `add(i)` stays in bounds; `momentum` is borrowed
        // mutably and `grad` immutably, so the pointers cannot alias a write.
        unsafe {
            let m_ptr = momentum.as_mut_ptr();
            let grad_ptr = grad.as_ptr();
            let size = momentum.size();

            #[cfg(target_arch = "x86_64")]
            {
                // AVX2 fast path: requires runtime CPU support plus the
                // library's SIMD alignment guarantee on both tensors.
                if is_x86_feature_detected!("avx2")
                    && momentum.is_simd_aligned()
                    && grad.is_simd_aligned()
                {
                    Self::update_momentum_simd_avx2(m_ptr, grad_ptr, beta1, beta1_complement, size);
                    return;
                }
            }

            // Scalar fallback: m[i] = beta1 * m[i] + (1 - beta1) * grad[i]
            for i in 0..size {
                *m_ptr.add(i) = beta1 * *m_ptr.add(i) + beta1_complement * *grad_ptr.add(i);
            }
        }
    }
677
    /// AVX2 kernel for `update_momentum`: 8 `f32` lanes per iteration with
    /// unaligned loads/stores, then a scalar loop for the `size % 8` tail.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available on the running CPU and that
    /// `m_ptr` (writable) and `grad_ptr` (readable) are valid for `size`
    /// contiguous `f32` elements.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn update_momentum_simd_avx2(
        m_ptr: *mut f32,
        grad_ptr: *const f32,
        beta1: f32,
        beta1_complement: f32,
        size: usize,
    ) {
        // Broadcast the scalar coefficients across all 8 lanes.
        let beta1_vec = _mm256_set1_ps(beta1);
        let beta1_comp_vec = _mm256_set1_ps(beta1_complement);
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let m_vec = _mm256_loadu_ps(m_ptr.add(offset));
            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));

            // beta1 * m + (1 - beta1) * grad
            let momentum_term = _mm256_mul_ps(beta1_vec, m_vec);
            let gradient_term = _mm256_mul_ps(beta1_comp_vec, grad_vec);
            let result = _mm256_add_ps(momentum_term, gradient_term);

            _mm256_storeu_ps(m_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            *m_ptr.add(i) = beta1 * *m_ptr.add(i) + beta1_complement * *grad_ptr.add(i);
        }
    }
708
    /// Update velocity (second moment estimate)
    ///
    /// Implements the velocity update rule: `v = beta2 * v + (1 - beta2) * grad^2`
    /// This computes the exponentially decaying average of squared gradients,
    /// providing adaptive learning rates that scale inversely with gradient magnitude.
    ///
    /// # Arguments
    ///
    /// * `velocity` - Velocity buffer to update in-place
    /// * `grad` - Current gradient tensor
    /// * `beta2` - Exponential decay rate for velocity
    ///
    /// # Panics
    ///
    /// Panics if `velocity` and `grad` hold a different number of elements.
    #[inline]
    fn update_velocity(velocity: &mut Tensor, grad: &Tensor, beta2: f32) {
        assert_eq!(
            velocity.size(),
            grad.size(),
            "Velocity and gradient size mismatch"
        );

        // Precompute (1 - beta2); it scales the squared gradient.
        let beta2_complement = 1.0 - beta2;

        // SAFETY: the assert above guarantees both buffers hold `size`
        // elements, so every `add(i)` stays in bounds; `velocity` is borrowed
        // mutably and `grad` immutably, so the pointers cannot alias a write.
        unsafe {
            let v_ptr = velocity.as_mut_ptr();
            let grad_ptr = grad.as_ptr();
            let size = velocity.size();

            #[cfg(target_arch = "x86_64")]
            {
                // AVX2 fast path: requires runtime CPU support plus the
                // library's SIMD alignment guarantee on both tensors.
                if is_x86_feature_detected!("avx2")
                    && velocity.is_simd_aligned()
                    && grad.is_simd_aligned()
                {
                    Self::update_velocity_simd_avx2(v_ptr, grad_ptr, beta2, beta2_complement, size);
                    return;
                }
            }

            // Scalar fallback: v[i] = beta2 * v[i] + (1 - beta2) * grad[i]^2
            for i in 0..size {
                let grad_val = *grad_ptr.add(i);
                *v_ptr.add(i) = beta2 * *v_ptr.add(i) + beta2_complement * grad_val * grad_val;
            }
        }
    }
753
    /// AVX2 kernel for `update_velocity`: 8 `f32` lanes per iteration with
    /// unaligned loads/stores, then a scalar loop for the `size % 8` tail.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available on the running CPU and that
    /// `v_ptr` (writable) and `grad_ptr` (readable) are valid for `size`
    /// contiguous `f32` elements.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn update_velocity_simd_avx2(
        v_ptr: *mut f32,
        grad_ptr: *const f32,
        beta2: f32,
        beta2_complement: f32,
        size: usize,
    ) {
        // Broadcast the scalar coefficients across all 8 lanes.
        let beta2_vec = _mm256_set1_ps(beta2);
        let beta2_comp_vec = _mm256_set1_ps(beta2_complement);
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let v_vec = _mm256_loadu_ps(v_ptr.add(offset));
            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));

            // beta2 * v + (1 - beta2) * grad^2
            let velocity_term = _mm256_mul_ps(beta2_vec, v_vec);
            let grad_squared = _mm256_mul_ps(grad_vec, grad_vec);
            let gradient_term = _mm256_mul_ps(beta2_comp_vec, grad_squared);
            let result = _mm256_add_ps(velocity_term, gradient_term);

            _mm256_storeu_ps(v_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            let grad_val = *grad_ptr.add(i);
            *v_ptr.add(i) = beta2 * *v_ptr.add(i) + beta2_complement * grad_val * grad_val;
        }
    }
786
    /// Apply bias correction to moment estimates
    ///
    /// Returns `tensor / (1 - beta^step)` for unbiasing the moment estimates.
    /// This correction is necessary because the moment estimates are initialized
    /// to zero, creating a bias towards zero that becomes negligible as training progresses.
    ///
    /// # Arguments
    ///
    /// * `tensor` - Moment estimate tensor to correct
    /// * `bias_correction` - Bias correction factor (1 - beta^step)
    ///
    /// # Returns
    ///
    /// Bias-corrected tensor with the same shape as input
    #[inline]
    fn bias_correct(tensor: &Tensor, bias_correction: f32) -> Tensor {
        // Work on a copy; the stored moment buffer must stay biased so the
        // running averages accumulate correctly.
        let mut result = tensor.clone();
        // Divide once, multiply per element (division is the slower operation).
        let correction_factor = 1.0 / bias_correction;

        // SAFETY: `ptr` points at `result`'s own buffer of exactly `size`
        // elements, so every `add(i)` stays in bounds; `result` is locally
        // owned, so no aliasing is possible.
        unsafe {
            let ptr = result.as_mut_ptr();
            let size = result.size();

            #[cfg(target_arch = "x86_64")]
            {
                // AVX2 fast path: requires runtime CPU support plus the
                // library's SIMD alignment guarantee.
                if is_x86_feature_detected!("avx2") && result.is_simd_aligned() {
                    Self::bias_correct_simd_avx2(ptr, correction_factor, size);
                    return result;
                }
            }

            // Scalar fallback: result[i] *= 1 / (1 - beta^step)
            for i in 0..size {
                *ptr.add(i) *= correction_factor;
            }
        }

        result
    }
826
    /// AVX2 kernel for `bias_correct`: multiplies 8 `f32` lanes per iteration
    /// by the precomputed reciprocal, then finishes the `size % 8` tail scalarly.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available on the running CPU and that `ptr`
    /// is valid for `size` contiguous writable `f32` elements.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn bias_correct_simd_avx2(ptr: *mut f32, correction_factor: f32, size: usize) {
        // Broadcast the reciprocal correction factor across all 8 lanes.
        let factor_vec = _mm256_set1_ps(correction_factor);
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let data_vec = _mm256_loadu_ps(ptr.add(offset));
            let result = _mm256_mul_ps(data_vec, factor_vec);
            _mm256_storeu_ps(ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            *ptr.add(i) *= correction_factor;
        }
    }
845
    /// Element-wise maximum for AMSGrad
    ///
    /// Updates first tensor in-place with `max(first, second)` for AMSGrad variant.
    /// This maintains a running maximum of the second moment estimates, preventing
    /// the learning rate from increasing over time and improving convergence stability.
    ///
    /// # Arguments
    ///
    /// * `first` - Tensor to update in-place with maximum values
    /// * `second` - Tensor to compare against for maximum calculation
    ///
    /// # Panics
    ///
    /// Panics if `first` and `second` hold a different number of elements.
    #[inline]
    fn element_wise_max(first: &mut Tensor, second: &Tensor) {
        assert_eq!(
            first.size(),
            second.size(),
            "Tensor size mismatch for element-wise max"
        );

        // SAFETY: the assert above guarantees both buffers hold `size`
        // elements, so every `add(i)` stays in bounds; `first` is borrowed
        // mutably and `second` immutably, so the pointers cannot alias a write.
        unsafe {
            let first_ptr = first.as_mut_ptr();
            let second_ptr = second.as_ptr();
            let size = first.size();

            #[cfg(target_arch = "x86_64")]
            {
                // AVX2 fast path: requires runtime CPU support plus the
                // library's SIMD alignment guarantee on both tensors.
                if is_x86_feature_detected!("avx2")
                    && first.is_simd_aligned()
                    && second.is_simd_aligned()
                {
                    Self::element_wise_max_simd_avx2(first_ptr, second_ptr, size);
                    return;
                }
            }

            // Scalar fallback. NOTE: `if a > b { a } else { b }` picks `b`
            // when the comparison is false or unordered (NaN), mirroring the
            // second-operand behavior of `_mm256_max_ps` in the SIMD path.
            for i in 0..size {
                let a = *first_ptr.add(i);
                let b = *second_ptr.add(i);
                *first_ptr.add(i) = if a > b { a } else { b };
            }
        }
    }
888
    /// AVX2 kernel for `element_wise_max`: 8 `f32` lanes per iteration via
    /// `_mm256_max_ps`, then a scalar loop for the `size % 8` tail.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available on the running CPU and that
    /// `first_ptr` (writable) and `second_ptr` (readable) are valid for `size`
    /// contiguous `f32` elements.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn element_wise_max_simd_avx2(first_ptr: *mut f32, second_ptr: *const f32, size: usize) {
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let first_vec = _mm256_loadu_ps(first_ptr.add(offset));
            let second_vec = _mm256_loadu_ps(second_ptr.add(offset));
            let result = _mm256_max_ps(first_vec, second_vec);
            _mm256_storeu_ps(first_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            let a = *first_ptr.add(i);
            let b = *second_ptr.add(i);
            *first_ptr.add(i) = if a > b { a } else { b };
        }
    }
909
    /// Apply the final Adam parameter update
    ///
    /// Implements the core Adam update rule: `param = param - lr * m_hat / (sqrt(v_hat) + eps)`
    /// This applies the bias-corrected momentum and velocity estimates to update
    /// the parameter values, with adaptive learning rates that scale inversely
    /// with the square root of the velocity estimates.
    ///
    /// # Arguments
    ///
    /// * `param` - Parameter tensor to update in-place
    /// * `m_hat` - Bias-corrected first moment estimate
    /// * `v_hat` - Bias-corrected second moment estimate
    /// * `learning_rate` - Learning rate for the update
    /// * `eps` - Small constant for numerical stability
    ///
    /// # Panics
    ///
    /// Panics if `m_hat` or `v_hat` differ in element count from `param`.
    #[inline]
    fn apply_adam_update(
        param: &mut Tensor,
        m_hat: &Tensor,
        v_hat: &Tensor,
        learning_rate: f32,
        eps: f32,
    ) {
        assert_eq!(
            param.size(),
            m_hat.size(),
            "Parameter and momentum size mismatch"
        );
        assert_eq!(
            param.size(),
            v_hat.size(),
            "Parameter and velocity size mismatch"
        );

        // SAFETY: the asserts above guarantee all three buffers hold `size`
        // elements, so every `add(i)` stays in bounds; only `param` (the sole
        // mutable borrow) is written.
        unsafe {
            let param_ptr = param.as_mut_ptr();
            let m_ptr = m_hat.as_ptr();
            let v_ptr = v_hat.as_ptr();
            let size = param.size();

            #[cfg(target_arch = "x86_64")]
            {
                // AVX2 fast path: requires runtime CPU support plus the
                // library's SIMD alignment guarantee on all three tensors.
                if is_x86_feature_detected!("avx2")
                    && param.is_simd_aligned()
                    && m_hat.is_simd_aligned()
                    && v_hat.is_simd_aligned()
                {
                    Self::apply_adam_update_simd_avx2(
                        param_ptr,
                        m_ptr,
                        v_ptr,
                        learning_rate,
                        eps,
                        size,
                    );
                    return;
                }
            }

            // Scalar fallback: param[i] -= lr * m_hat[i] / (sqrt(v_hat[i]) + eps)
            for i in 0..size {
                let m_val = *m_ptr.add(i);
                let v_val = *v_ptr.add(i);
                let denominator = v_val.sqrt() + eps;
                *param_ptr.add(i) -= learning_rate * m_val / denominator;
            }
        }
    }
977
978 #[cfg(target_arch = "x86_64")]
979 #[inline]
980 unsafe fn apply_adam_update_simd_avx2(
981 param_ptr: *mut f32,
982 m_ptr: *const f32,
983 v_ptr: *const f32,
984 learning_rate: f32,
985 eps: f32,
986 size: usize,
987 ) {
988 let lr_vec = _mm256_set1_ps(learning_rate);
989 let eps_vec = _mm256_set1_ps(eps);
990 let simd_count = size / 8;
991
992 for i in 0..simd_count {
993 let offset = i * 8;
994 let param_vec = _mm256_loadu_ps(param_ptr.add(offset));
995 let m_vec = _mm256_loadu_ps(m_ptr.add(offset));
996 let v_vec = _mm256_loadu_ps(v_ptr.add(offset));
997
998 // sqrt(v_hat) + eps
999 let sqrt_v = _mm256_sqrt_ps(v_vec);
1000 let denominator = _mm256_add_ps(sqrt_v, eps_vec);
1001
1002 // lr * m_hat / denominator
1003 let lr_m = _mm256_mul_ps(lr_vec, m_vec);
1004 let update = _mm256_div_ps(lr_m, denominator);
1005
1006 // param - update
1007 let result = _mm256_sub_ps(param_vec, update);
1008 _mm256_storeu_ps(param_ptr.add(offset), result);
1009 }
1010
1011 // Handle remaining elements
1012 for i in (simd_count * 8)..size {
1013 let m_val = *m_ptr.add(i);
1014 let v_val = *v_ptr.add(i);
1015 let denominator = v_val.sqrt() + eps;
1016 *param_ptr.add(i) -= learning_rate * m_val / denominator;
1017 }
1018 }
1019}
1020
1021impl Optimizer for Adam {
1022 /// Perform a single optimization step
1023 ///
1024 /// Updates all provided parameters based on their accumulated gradients using the Adam algorithm.
1025 /// Each parameter is updated according to the Adam update rule with bias correction
1026 /// and optional AMSGrad variant if enabled. All parameters must be linked to the optimizer
1027 /// before calling this method.
1028 ///
1029 /// # Arguments
1030 ///
1031 /// * `parameters` - Mutable slice of parameter references to update
1032 ///
1033 /// # Thread Safety
1034 ///
1035 /// This method is thread-safe as it takes mutable references to parameters,
1036 /// ensuring exclusive access during updates.
1037 ///
1038 /// # Performance
1039 ///
1040 /// - Uses SIMD optimization (AVX2) when available for 8x vectorization
1041 /// - Processes parameters in sequence for optimal cache usage
1042 /// - Maintains per-parameter state for momentum and velocity estimates
1043 ///
1044 /// # Panics
1045 ///
1046 /// Panics if any parameter is not linked to the optimizer
1047 fn step(&mut self, parameters: &mut [&mut Tensor]) {
1048 self.step_count += 1;
1049
1050 for param in parameters {
1051 if let Err(e) = self.update_parameter(param) {
1052 eprintln!("Warning: Failed to update parameter: {}", e);
1053 }
1054 }
1055 }
1056
1057 /// Zero out all parameter gradients
1058 ///
1059 /// Clears accumulated gradients for all provided parameters. This should be called
1060 /// before each backward pass to prevent gradient accumulation across multiple
1061 /// forward/backward passes. Also clears the global autograd gradient map.
1062 ///
1063 /// # Arguments
1064 ///
1065 /// * `parameters` - Mutable slice of parameter references to clear gradients for
1066 ///
1067 /// # Performance
1068 ///
1069 /// - Efficiently clears gradients using optimized tensor operations
1070 /// - Clears both per-tensor gradients and global autograd state
1071 /// - Thread-safe as it takes mutable references to parameters
1072 fn zero_grad(&mut self, parameters: &mut [&mut Tensor]) {
1073 for param in parameters {
1074 param.zero_grad();
1075 }
1076 }
1077
1078 /// Get the current learning rate
1079 ///
1080 /// Returns the current learning rate used for parameter updates.
1081 ///
1082 /// # Returns
1083 ///
1084 /// Current learning rate as f32
1085 fn learning_rate(&self) -> f32 {
1086 self.config.learning_rate
1087 }
1088
1089 /// Set the learning rate for all parameters
1090 ///
1091 /// Updates the learning rate for all parameters in the optimizer.
1092 /// This allows dynamic learning rate scheduling during training.
1093 ///
1094 /// # Arguments
1095 ///
1096 /// * `lr` - New learning rate value
1097 fn set_learning_rate(&mut self, lr: f32) {
1098 self.config.learning_rate = lr;
1099 }
1100}
1101
1102// Adam is automatically Send + Sync since it no longer contains raw pointers
1103// and all fields are Send + Sync (HashMap, usize, AdamConfig)
1104
1105#[cfg(test)]
1106mod tests {
1107 use super::*;
1108 use crate::tensor::core::Tensor;
1109
1110 /// Test Adam optimizer creation with default configuration
1111 ///
1112 /// Verifies that the optimizer is created correctly with default hyperparameters
1113 /// and that parameter states are properly initialized.
1114 #[test]
1115 fn test_adam_creation() {
1116 let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1117 let bias = Tensor::zeros(vec![3, 1]).with_requires_grad();
1118
1119 let mut optimizer = Adam::new();
1120 optimizer.add_parameter(&weight);
1121 optimizer.add_parameter(&bias);
1122
1123 assert_eq!(optimizer.learning_rate(), 1e-3);
1124 assert_eq!(optimizer.config.beta1, 0.9);
1125 assert_eq!(optimizer.config.beta2, 0.999);
1126 assert_eq!(optimizer.parameter_count(), 2);
1127 }
1128
1129 /// Test Adam optimizer creation with custom configuration
1130 ///
1131 /// Verifies that custom hyperparameters are properly set and that
1132 /// the optimizer configuration matches the provided values.
1133 #[test]
1134 fn test_adam_with_config() {
1135 let weight = Tensor::ones(vec![5, 5]).with_requires_grad();
1136
1137 let config = AdamConfig {
1138 learning_rate: 1e-4,
1139 beta1: 0.95,
1140 beta2: 0.9999,
1141 weight_decay: 1e-5,
1142 amsgrad: true,
1143 ..Default::default()
1144 };
1145
1146 let mut optimizer = Adam::with_config(config);
1147 optimizer.add_parameter(&weight);
1148
1149 assert_eq!(optimizer.learning_rate(), 1e-4);
1150 assert_eq!(optimizer.config.beta1, 0.95);
1151 assert_eq!(optimizer.config.beta2, 0.9999);
1152 assert_eq!(optimizer.config.weight_decay, 1e-5);
1153 assert!(optimizer.config.amsgrad);
1154 }
1155
1156 /// Test Adam optimizer creation with custom learning rate
1157 ///
1158 /// Verifies that the convenience constructor properly sets the learning rate
1159 /// while maintaining default values for other hyperparameters.
1160 #[test]
1161 fn test_adam_with_learning_rate() {
1162 let weight = Tensor::ones(vec![3, 3]).with_requires_grad();
1163 let mut optimizer = Adam::with_learning_rate(5e-4);
1164 optimizer.add_parameter(&weight);
1165
1166 assert_eq!(optimizer.learning_rate(), 5e-4);
1167 assert_eq!(optimizer.config.beta1, 0.9); // Should use defaults for other params
1168 }
1169
1170 /// Test Adam step without gradients
1171 ///
1172 /// Verifies that the optimizer handles the case where parameters have no
1173 /// gradients gracefully, leaving parameters unchanged.
1174 #[test]
1175 fn test_adam_step_without_gradients() {
1176 let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1177 let original_data = weight.clone();
1178
1179 let mut optimizer = Adam::new();
1180 optimizer.add_parameter(&weight);
1181
1182 optimizer.step(&mut [&mut weight]); // Should not update without gradients
1183
1184 // Parameters should remain unchanged without gradients
1185 for i in 0..weight.size() {
1186 unsafe {
1187 assert_eq!(*weight.as_ptr().add(i), *original_data.as_ptr().add(i));
1188 }
1189 }
1190 }
1191
1192 /// Test learning rate update functionality
1193 ///
1194 /// Verifies that the learning rate can be dynamically updated during training
1195 /// and that the optimizer uses the new learning rate for subsequent steps.
1196 #[test]
1197 fn test_learning_rate_update() {
1198 let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1199 let mut optimizer = Adam::new();
1200 optimizer.add_parameter(&weight);
1201
1202 assert_eq!(optimizer.learning_rate(), 1e-3);
1203
1204 optimizer.set_learning_rate(1e-2);
1205 assert_eq!(optimizer.learning_rate(), 1e-2);
1206 }
1207
1208 /// Test gradient zeroing functionality
1209 ///
1210 /// Verifies that zero_grad properly clears accumulated gradients for all
1211 /// parameters and the global autograd state.
1212 #[test]
1213 fn test_zero_grad() {
1214 let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1215
1216 // Check that zero_grad clears gradients
1217 let mut optimizer = Adam::new();
1218 optimizer.add_parameter(&weight);
1219 optimizer.zero_grad(&mut [&mut weight]);
1220
1221 // After zero_grad, there should be no accumulated gradients
1222 assert!(weight.grad_owned().is_none());
1223 }
1224
1225 /// Test requires_grad assertion
1226 ///
1227 /// Verifies that the optimizer correctly panics when parameters do not
1228 /// have requires_grad set to true, ensuring proper gradient tracking.
1229 #[test]
1230 #[should_panic(expected = "Parameter must require gradients")]
1231 fn test_adam_requires_grad_assertion() {
1232 let weight = Tensor::ones(vec![2, 2]); // No requires_grad
1233 let mut optimizer = Adam::new();
1234 optimizer.add_parameter(&weight);
1235 }
1236
1237 /// Test AdamConfig default values
1238 ///
1239 /// Verifies that the default configuration matches PyTorch conventions
1240 /// and provides optimal settings for most training scenarios.
1241 #[test]
1242 fn test_adam_config_default() {
1243 let config = AdamConfig::default();
1244
1245 assert_eq!(config.learning_rate, 1e-3);
1246 assert_eq!(config.beta1, 0.9);
1247 assert_eq!(config.beta2, 0.999);
1248 assert_eq!(config.eps, 1e-8);
1249 assert_eq!(config.weight_decay, 0.0);
1250 assert!(!config.amsgrad);
1251 }
1252
1253 /// Test ParameterState creation and initialization
1254 ///
1255 /// Verifies that parameter states are properly initialized with zero tensors
1256 /// and that the step count starts at zero.
1257 #[test]
1258 fn test_parameter_state_creation() {
1259 let state = ParameterState::new(&[3, 4]);
1260
1261 assert_eq!(state.m.shape().dims(), vec![3, 4]);
1262 assert_eq!(state.v.shape().dims(), vec![3, 4]);
1263 assert!(state.v_hat_max.is_none());
1264 assert_eq!(state.step, 0);
1265
1266 // Verify tensors are zero-initialized
1267 for i in 0..state.m.size() {
1268 unsafe {
1269 assert_eq!(*state.m.as_ptr().add(i), 0.0);
1270 assert_eq!(*state.v.as_ptr().add(i), 0.0);
1271 }
1272 }
1273 }
1274
1275 /// Test parameter linking functionality
1276 ///
1277 /// Verifies that parameters can be linked and unlinked from the optimizer.
1278 #[test]
1279 fn test_parameter_linking() {
1280 let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1281 let bias = Tensor::zeros(vec![3]).with_requires_grad();
1282
1283 let mut optimizer = Adam::new();
1284
1285 // Initially no parameters linked
1286 assert_eq!(optimizer.parameter_count(), 0);
1287 assert!(!optimizer.is_parameter_linked(&weight));
1288 assert!(!optimizer.is_parameter_linked(&bias));
1289
1290 // Link weight
1291 optimizer.add_parameter(&weight);
1292 assert_eq!(optimizer.parameter_count(), 1);
1293 assert!(optimizer.is_parameter_linked(&weight));
1294 assert!(!optimizer.is_parameter_linked(&bias));
1295
1296 // Link bias
1297 optimizer.add_parameter(&bias);
1298 assert_eq!(optimizer.parameter_count(), 2);
1299 assert!(optimizer.is_parameter_linked(&weight));
1300 assert!(optimizer.is_parameter_linked(&bias));
1301
1302 // Unlink weight
1303 let was_linked = optimizer.unlink_parameter(&weight);
1304 assert!(was_linked);
1305 assert_eq!(optimizer.parameter_count(), 1);
1306 assert!(!optimizer.is_parameter_linked(&weight));
1307 assert!(optimizer.is_parameter_linked(&bias));
1308
1309 // Clear all states
1310 optimizer.clear_states();
1311 assert_eq!(optimizer.parameter_count(), 0);
1312 assert!(!optimizer.is_parameter_linked(&weight));
1313 assert!(!optimizer.is_parameter_linked(&bias));
1314 }
1315
1316 /// Test parameter linking with multiple parameters at once
1317 ///
1318 /// Verifies that multiple parameters can be linked simultaneously.
1319 #[test]
1320 fn test_add_multiple_parameters() {
1321 let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1322 let bias = Tensor::zeros(vec![3]).with_requires_grad();
1323 let weight2 = Tensor::ones(vec![3, 2]).with_requires_grad();
1324
1325 let mut optimizer = Adam::new();
1326
1327 // Link multiple parameters at once
1328 optimizer.add_parameters(&[&weight, &bias, &weight2]);
1329
1330 assert_eq!(optimizer.parameter_count(), 3);
1331 assert!(optimizer.is_parameter_linked(&weight));
1332 assert!(optimizer.is_parameter_linked(&bias));
1333 assert!(optimizer.is_parameter_linked(&weight2));
1334 }
1335
1336 /// Test stepping with unlinked parameter
1337 ///
1338 /// Verifies that the optimizer panics when trying to step with an unlinked parameter.
1339 #[test]
1340 #[should_panic(expected = "Parameter must be linked to optimizer before stepping")]
1341 fn test_step_with_unlinked_parameter() {
1342 let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1343 let mut optimizer = Adam::new();
1344
1345 // Don't link the parameter
1346 optimizer.step(&mut [&mut weight]); // Should panic
1347 }
1348
1349 /// Test optimizer with actual gradients
1350 ///
1351 /// Verifies that the optimizer properly updates parameters when gradients are present.
1352 #[test]
1353 fn test_optimizer_with_gradients() {
1354 let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1355 let original_data = weight.clone();
1356
1357 let mut optimizer = Adam::new();
1358 optimizer.add_parameter(&weight);
1359
1360 // Generate some gradients
1361 let output = weight.mul_scalar(2.0);
1362 let mut loss = output.sum();
1363 loss.backward(None);
1364
1365 // Step should update parameters
1366 optimizer.step(&mut [&mut weight]);
1367
1368 // Parameters should have changed
1369 let mut changed = false;
1370 for i in 0..weight.size() {
1371 unsafe {
1372 if (*weight.as_ptr().add(i) - *original_data.as_ptr().add(i)).abs() > 1e-6 {
1373 changed = true;
1374 break;
1375 }
1376 }
1377 }
1378 assert!(
1379 changed,
1380 "Parameters should have been updated by optimizer step"
1381 );
1382 }
1383
1384 /// Test optimizer with multiple parameters and gradients
1385 ///
1386 /// Verifies that the optimizer works correctly with multiple parameters.
1387 #[test]
1388 fn test_optimizer_multiple_parameters() {
1389 let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1390 let mut bias = Tensor::zeros(vec![2, 2]).with_requires_grad(); // Same shape as weight
1391
1392 let mut optimizer = Adam::new();
1393 optimizer.add_parameter(&weight);
1394 optimizer.add_parameter(&bias);
1395
1396 // Generate gradients for both parameters
1397 let output = weight.mul_scalar(2.0).add_tensor(&bias);
1398 let mut loss = output.sum();
1399 loss.backward(None);
1400
1401 // Step should update both parameters
1402 optimizer.step(&mut [&mut weight, &mut bias]);
1403
1404 // Both parameters should have gradients
1405 assert!(weight.grad_owned().is_some());
1406 assert!(bias.grad_owned().is_some());
1407 }
1408
1409 /// Test optimizer with custom configuration and multiple steps
1410 ///
1411 /// Verifies that the optimizer works correctly with custom configuration over multiple steps.
1412 #[test]
1413 fn test_optimizer_custom_config_multiple_steps() {
1414 let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1415
1416 let config = AdamConfig {
1417 learning_rate: 1e-4,
1418 beta1: 0.95,
1419 beta2: 0.9999,
1420 weight_decay: 1e-5,
1421 amsgrad: true,
1422 ..Default::default()
1423 };
1424
1425 let mut optimizer = Adam::with_config(config);
1426 optimizer.add_parameter(&weight);
1427
1428 // Multiple training steps
1429 for _ in 0..5 {
1430 // Generate gradients
1431 let output = weight.mul_scalar(2.0);
1432 let mut loss = output.sum();
1433 loss.backward(None);
1434
1435 // Step
1436 optimizer.step(&mut [&mut weight]);
1437 optimizer.zero_grad(&mut [&mut weight]);
1438 }
1439
1440 // Should complete without errors
1441 assert_eq!(optimizer.parameter_count(), 1);
1442 assert!(optimizer.is_parameter_linked(&weight));
1443 }
1444
1445 /// Test optimizer with different tensor shapes
1446 ///
1447 /// Verifies that the optimizer works correctly with various tensor shapes.
1448 #[test]
1449 fn test_optimizer_different_shapes() {
1450 let shapes = vec![
1451 vec![1], // Scalar
1452 vec![3], // 1D
1453 vec![2, 2], // 2D square
1454 vec![2, 3], // 2D rectangular
1455 vec![1, 1, 3], // 3D
1456 vec![2, 2, 2], // 3D cube
1457 ];
1458
1459 for shape in shapes {
1460 let mut tensor = Tensor::ones(shape.clone()).with_requires_grad();
1461 let mut optimizer = Adam::new();
1462 optimizer.add_parameter(&tensor);
1463
1464 // Generate gradients
1465 let output = tensor.mul_scalar(2.0);
1466 let mut loss = output.sum();
1467 loss.backward(None);
1468
1469 // Step should work for all shapes
1470 optimizer.step(&mut [&mut tensor]);
1471
1472 // Verify tensor is still valid
1473 assert_eq!(tensor.shape().dims(), shape);
1474 assert!(tensor.requires_grad());
1475 }
1476 }
1477
1478 /// Test double parameter linking
1479 ///
1480 /// Verifies that linking the same parameter twice doesn't create duplicate states.
1481 #[test]
1482 fn test_double_parameter_linking() {
1483 let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1484 let mut optimizer = Adam::new();
1485
1486 // Link parameter twice
1487 optimizer.add_parameter(&weight);
1488 optimizer.add_parameter(&weight);
1489
1490 // Should only have one state
1491 assert_eq!(optimizer.parameter_count(), 1);
1492 assert!(optimizer.is_parameter_linked(&weight));
1493 }
1494
1495 /// Test unlink non-linked parameter
1496 ///
1497 /// Verifies that unlinking a non-linked parameter returns false.
1498 #[test]
1499 fn test_unlink_non_linked_parameter() {
1500 let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1501 let mut optimizer = Adam::new();
1502
1503 // Try to unlink a parameter that was never linked
1504 let was_linked = optimizer.unlink_parameter(&weight);
1505 assert!(!was_linked);
1506 }
1507
1508 /// Test parameter shape validation
1509 ///
1510 /// Verifies that parameters with different shapes can be linked correctly
1511 /// and maintain their shape information in the optimizer state.
1512 #[test]
1513 fn test_parameter_shape_validation() {
1514 let weight_1d = Tensor::ones(vec![5]).with_requires_grad();
1515 let weight_2d = Tensor::ones(vec![3, 4]).with_requires_grad();
1516 let weight_3d = Tensor::ones(vec![2, 3, 4]).with_requires_grad();
1517
1518 let mut optimizer = Adam::new();
1519
1520 // Test linking parameters with different shapes
1521 optimizer.add_parameter(&weight_1d);
1522 optimizer.add_parameter(&weight_2d);
1523 optimizer.add_parameter(&weight_3d);
1524
1525 assert_eq!(optimizer.parameter_count(), 3);
1526 assert!(optimizer.is_parameter_linked(&weight_1d));
1527 assert!(optimizer.is_parameter_linked(&weight_2d));
1528 assert!(optimizer.is_parameter_linked(&weight_3d));
1529
1530 // Test that each parameter maintains its shape after linking
1531 let state_1d = &optimizer.states[&weight_1d.id()];
1532 let state_2d = &optimizer.states[&weight_2d.id()];
1533 let state_3d = &optimizer.states[&weight_3d.id()];
1534
1535 assert_eq!(state_1d.m.shape().dims(), vec![5]);
1536 assert_eq!(state_2d.m.shape().dims(), vec![3, 4]);
1537 assert_eq!(state_3d.m.shape().dims(), vec![2, 3, 4]);
1538 }
1539
1540 /// Test parameter relinking with shape consistency
1541 ///
1542 /// Verifies that when re-linking parameters (e.g., after deserialization),
1543 /// shape information is correctly maintained.
1544 #[test]
1545 fn test_parameter_relinking_shape_consistency() {
1546 // Create initial parameter and optimizer
1547 let mut weight_original = Tensor::ones(vec![3, 3]).with_requires_grad();
1548 let mut optimizer = Adam::new();
1549 optimizer.add_parameter(&weight_original);
1550
1551 // Perform a step to create state
1552 let output = weight_original.mul_scalar(2.0);
1553 let mut loss = output.sum();
1554 loss.backward(None);
1555 optimizer.step(&mut [&mut weight_original]);
1556
1557 // Verify state was created with correct shape
1558 let original_state = &optimizer.states[&weight_original.id()];
1559 assert_eq!(original_state.m.shape().dims(), vec![3, 3]);
1560 assert_eq!(original_state.v.shape().dims(), vec![3, 3]);
1561
1562 // Create new parameter with same shape (will get different ID)
1563 let weight_new = Tensor::ones(vec![3, 3]).with_requires_grad();
1564
1565 // This should create a new state since it's a different parameter
1566 optimizer.add_parameter(&weight_new);
1567
1568 // Should now have 2 states
1569 assert_eq!(optimizer.parameter_count(), 2);
1570
1571 // Both should be linked with correct shapes
1572 assert!(optimizer.is_parameter_linked(&weight_original));
1573 assert!(optimizer.is_parameter_linked(&weight_new));
1574
1575 let new_state = &optimizer.states[&weight_new.id()];
1576 assert_eq!(new_state.m.shape().dims(), vec![3, 3]);
1577 assert_eq!(new_state.v.shape().dims(), vec![3, 3]);
1578 }
1579
1580 /// Test large parameter count handling
1581 ///
1582 /// Verifies that the optimizer can handle many parameters efficiently
1583 /// and correctly manage linking/unlinking operations.
1584 #[test]
1585 fn test_large_parameter_count() {
1586 let mut optimizer = Adam::new();
1587 let mut params = Vec::new();
1588
1589 // Create 50 parameters of different shapes
1590 for i in 1..=50 {
1591 let param = Tensor::ones(vec![i]).with_requires_grad();
1592 optimizer.add_parameter(¶m);
1593 params.push(param);
1594 }
1595
1596 assert_eq!(optimizer.parameter_count(), 50);
1597
1598 // Verify all parameters are linked
1599 for param in ¶ms {
1600 assert!(optimizer.is_parameter_linked(param));
1601 }
1602
1603 // Test unlinking some parameters
1604 for param in params.iter().take(25).step_by(2) {
1605 assert!(optimizer.unlink_parameter(param));
1606 }
1607
1608 assert_eq!(optimizer.parameter_count(), 37); // 50 - 13 = 37
1609
1610 // Verify correct parameters are unlinked
1611 for (i, param) in params.iter().enumerate() {
1612 if i < 25 && i % 2 == 0 {
1613 assert!(!optimizer.is_parameter_linked(param));
1614 } else {
1615 assert!(optimizer.is_parameter_linked(param));
1616 }
1617 }
1618 }
1619
1620 /// Test clear_states functionality
1621 ///
1622 /// Verifies that clearing all states works correctly and allows
1623 /// re-adding parameters afterwards.
1624 #[test]
1625 fn test_clear_states_functionality() {
1626 let weight1 = Tensor::ones(vec![2, 2]).with_requires_grad();
1627 let weight2 = Tensor::ones(vec![3, 3]).with_requires_grad();
1628 let weight3 = Tensor::ones(vec![4, 4]).with_requires_grad();
1629
1630 let mut optimizer = Adam::new();
1631 optimizer.add_parameter(&weight1);
1632 optimizer.add_parameter(&weight2);
1633 optimizer.add_parameter(&weight3);
1634
1635 assert_eq!(optimizer.parameter_count(), 3);
1636
1637 // Clear all states
1638 optimizer.clear_states();
1639
1640 assert_eq!(optimizer.parameter_count(), 0);
1641 assert!(!optimizer.is_parameter_linked(&weight1));
1642 assert!(!optimizer.is_parameter_linked(&weight2));
1643 assert!(!optimizer.is_parameter_linked(&weight3));
1644
1645 // Should be able to re-add parameters after clearing
1646 optimizer.add_parameter(&weight1);
1647 assert_eq!(optimizer.parameter_count(), 1);
1648 assert!(optimizer.is_parameter_linked(&weight1));
1649 }
1650}