// train_station/optimizers/adam/mod.rs
1//! Adam optimizer implementation for neural network training
2//!
3//! This module provides the Adam optimization algorithm with PyTorch-compatible interface.
4//! Adam combines the benefits of AdaGrad and RMSprop by using adaptive learning rates
5//! with momentum for efficient training of neural networks.
6//!
7//! # Features
8//!
9//! - **Adaptive Learning Rates**: Per-parameter learning rates based on gradient history
10//! - **Momentum Integration**: Combines momentum with adaptive learning rates
11//! - **Bias Correction**: Corrects for initialization bias in moment estimates
12//! - **Weight Decay**: Optional L2 regularization support
13//! - **AMSGrad Variant**: Optional AMSGrad for improved convergence stability
14//! - **SIMD Optimization**: AVX2-optimized parameter updates for maximum performance
15//! - **Thread Safety**: Send + Sync implementation for multi-threaded training
16//! - **Hybrid API**: Both safe (RwLock-based) and unsafe (direct pointer) access patterns
17//!
18//! # Thread Safety
19//!
20//! The optimizer provides two usage patterns:
21//!
22//! **Safe Multi-threaded Usage (Default)**:
23//! - Uses `Arc<RwLock<Tensor>>` for thread-safe parameter access
24//! - Multiple threads can read tensors simultaneously
25//! - Optimizer steps acquire write locks only during parameter updates
26//! - Recommended for most use cases
27//!
28//! **Unsafe Single-threaded Usage (Performance)**:
29//! - Uses raw pointers for maximum performance
30//! - No locking overhead during optimizer steps
31//! - Caller must ensure exclusive access during optimization
32//! - Use only when you can guarantee no concurrent tensor access
33//!
34//! # Algorithm
35//!
36//! Adam implements the following update rule for each parameter theta:
37//!
38//! ```text
39//! m_t = beta1 * m_{t-1} + (1 - beta1) * grad_theta_t
40//! v_t = beta2 * v_{t-1} + (1 - beta2) * (grad_theta_t)^2
41//! m_hat_t = m_t / (1 - beta1^t)
42//! v_hat_t = v_t / (1 - beta2^t)
43//! theta_{t+1} = theta_t - lr * m_hat_t / (sqrt(v_hat_t) + eps)
44//! ```
45//!
46//! Where:
47//! - lr is the learning rate
48//! - beta1, beta2 are exponential decay rates for moment estimates
49//! - eps is a small constant for numerical stability
50//! - m_t, v_t are biased first and second moment estimates
51//! - m_hat_t, v_hat_t are bias-corrected moment estimates
52//!
53//! # Performance Characteristics
54//!
55//! - **SIMD Optimization**: Uses AVX2 instructions for 8x vectorization when available
56//! - **Memory Efficiency**: In-place updates with minimal temporary allocations
57//! - **Cache-Friendly**: Optimized memory access patterns for large parameter tensors
58//! - **Zero-Cost Abstractions**: Compile-time optimization with minimal runtime overhead
59//! - **Lock-Free Reads**: RwLock allows concurrent tensor reads during training
60//!
61//! # References
62//!
63//! - Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization.
64//! - PyTorch Adam implementation: <https://pytorch.org/docs/stable/generated/torch.optim.Adam.html>
65
66pub mod serialization;
67
68use super::Optimizer;
69use crate::tensor::core::Tensor;
70use std::collections::HashMap;
71
72// SIMD optimizations for performance-critical operations
73#[cfg(target_arch = "x86_64")]
74use std::arch::x86_64::*;
75
76/// Configuration for the Adam optimization algorithm
77///
78/// Contains all hyperparameters that control the behavior of Adam optimization.
79/// Default values follow PyTorch conventions for maximum compatibility and
80/// optimal convergence across a wide range of neural network architectures.
81///
82/// # Fields
83///
84/// * `learning_rate` - Step size for parameter updates (default: 1e-3)
85/// * `beta1` - Exponential decay rate for first moment estimates (default: 0.9)
86/// * `beta2` - Exponential decay rate for second moment estimates (default: 0.999)
87/// * `eps` - Small constant for numerical stability (default: 1e-8)
88/// * `weight_decay` - L2 regularization coefficient (default: 0.0)
89/// * `amsgrad` - Whether to use AMSGrad variant for improved stability (default: false)
#[derive(Debug, Clone)]
pub struct AdamConfig {
    /// Learning rate for parameter updates (default: 1e-3)
    pub learning_rate: f32,
    /// Exponential decay rate for first moment estimates (default: 0.9)
    pub beta1: f32,
    /// Exponential decay rate for second moment estimates (default: 0.999)
    pub beta2: f32,
    /// Small constant added to the denominator for numerical stability (default: 1e-8)
    pub eps: f32,
    /// Weight decay coefficient for L2 regularization; 0.0 disables it (default: 0.0)
    pub weight_decay: f32,
    /// Whether to use the AMSGrad variant, which keeps a running maximum of
    /// the second moment estimates (default: false)
    pub amsgrad: bool,
}
105
106impl Default for AdamConfig {
107 fn default() -> Self {
108 Self {
109 learning_rate: 1e-3,
110 beta1: 0.9,
111 beta2: 0.999,
112 eps: 1e-8,
113 weight_decay: 0.0,
114 amsgrad: false,
115 }
116 }
117}
118
119/// Internal state tracking for a single parameter during Adam optimization
120///
121/// Stores the momentum and velocity buffers needed for Adam optimization,
122/// along with step count for bias correction and optional AMSGrad state.
123/// Memory layout is optimized for cache efficiency and SIMD operations.
124///
125/// # Fields
126///
127/// * `m` - First moment estimate (momentum buffer)
128/// * `v` - Second moment estimate (velocity buffer)
129/// * `v_hat_max` - Maximum of second moment estimates (for AMSGrad variant)
130/// * `step` - Step count for bias correction calculations
#[derive(Debug)]
struct ParameterState {
    /// First moment estimate (momentum); zero-initialized with the parameter's shape
    m: Tensor,
    /// Second moment estimate (velocity); zero-initialized with the parameter's shape
    v: Tensor,
    /// Running maximum of bias-corrected second moment estimates;
    /// `None` until the AMSGrad path first updates this parameter
    v_hat_max: Option<Tensor>,
    /// Per-parameter update count, incremented on every update of this parameter
    step: usize,
}
142
143impl ParameterState {
144 fn new(param_shape: &[usize]) -> Self {
145 Self {
146 m: Tensor::zeros(param_shape.to_vec()),
147 v: Tensor::zeros(param_shape.to_vec()),
148 v_hat_max: None,
149 step: 0,
150 }
151 }
152}
153
154/// Adam optimizer for neural network parameter optimization
155///
156/// Implements the Adam optimization algorithm with PyTorch-compatible interface.
157/// Provides adaptive learning rates with momentum for efficient training of neural networks.
158/// The optimizer maintains per-parameter state for momentum and velocity estimates,
159/// enabling adaptive learning rates that improve convergence across diverse architectures.
160///
161/// # Usage Pattern
162///
163/// The optimizer uses ID-based parameter linking for maximum flexibility and thread safety:
164/// - Parameters are linked to the optimizer via `add_parameter` or `add_parameters`
165/// - The `step` method takes mutable references to parameters for thread-safe updates
166/// - Parameter states are maintained by tensor ID, allowing for dynamic parameter management
167/// - Supports serialization and deserialization with parameter re-linking
168///
169/// # Dynamic Parameter Management
170///
171/// Parameters can be added, removed, or re-linked at runtime:
172/// - `add_parameter`: Link a single parameter
173/// - `add_parameters`: Link multiple parameters at once
174/// - `unlink_parameter`: Remove parameter state by ID
175/// - `clear_states`: Remove all parameter states
176/// - `is_parameter_linked`: Check if a parameter is linked
177///
178/// # Serialization Support
179///
180/// The optimizer supports full serialization and deserialization with state preservation:
181/// - Parameter states are saved with their shapes and insertion order for validation
182/// - After deserialization, use `relink_parameters` to restore saved states to new tensors
183/// - Parameters must be re-linked in the same chronological order they were originally added
184/// - Shape validation ensures consistency between saved and current parameters
185///
186/// # Features
187///
188/// - **ID-Based Parameter Linking**: Dynamic parameter management via tensor IDs
189/// - **Thread-Safe Step Method**: Takes mutable references for safe concurrent access
190/// - **Per-Parameter State**: Each parameter maintains its own momentum and velocity buffers
191/// - **Bias Correction**: Automatically corrects initialization bias in moment estimates
192/// - **Weight Decay**: Optional L2 regularization with efficient implementation
193/// - **AMSGrad Support**: Optional AMSGrad variant for improved convergence stability
194/// - **SIMD Optimization**: AVX2-optimized updates for maximum performance
195/// - **Full Serialization**: Complete state persistence and restoration
196///
197/// # Thread Safety
198///
199/// This type is thread-safe and can be shared between threads. The step method
200/// takes mutable references to parameters, ensuring exclusive access during updates.
pub struct Adam {
    /// Optimizer configuration (hyperparameters)
    config: AdamConfig,
    /// Per-parameter states indexed by tensor ID
    states: HashMap<usize, ParameterState>,
    /// Current global step count, shared across all parameters
    /// (NOTE(review): presumably incremented once per `step` call — the
    /// `step` implementation is outside this view; confirm. Each
    /// `ParameterState` also keeps its own per-parameter `step`.)
    step_count: usize,
    /// Order in which parameters were added (for serialization/re-linking)
    insertion_order: Vec<usize>,
}
211
212impl Default for Adam {
213 fn default() -> Self {
214 Self::with_config(AdamConfig::default())
215 }
216}
217
218impl Adam {
219 /// Create a new Adam optimizer with default configuration
220 ///
221 /// Initializes an Adam optimizer with PyTorch-compatible default hyperparameters.
222 /// Parameters must be linked separately using `add_parameter` or `add_parameters`.
223 ///
224 /// # Returns
225 ///
226 /// A new Adam optimizer instance with default hyperparameters
227 pub fn new() -> Self {
228 Self::default()
229 }
230
231 /// Create a new Adam optimizer with custom configuration
232 ///
233 /// Allows full control over all Adam hyperparameters for specialized training
234 /// scenarios such as fine-tuning, transfer learning, or research applications.
235 /// Parameters must be linked separately using `add_parameter` or `add_parameters`.
236 ///
237 /// # Arguments
238 ///
239 /// * `config` - Adam configuration with custom hyperparameters
240 ///
241 /// # Returns
242 ///
243 /// A new Adam optimizer instance with the specified configuration
244 pub fn with_config(config: AdamConfig) -> Self {
245 Self {
246 config,
247 states: HashMap::new(),
248 step_count: 0,
249 insertion_order: Vec::new(),
250 }
251 }
252
253 /// Create a new Adam optimizer with custom learning rate
254 ///
255 /// A convenience constructor that allows setting only the learning rate while
256 /// using default values for all other hyperparameters. Parameters must be
257 /// linked separately using `add_parameter` or `add_parameters`.
258 ///
259 /// # Arguments
260 ///
261 /// * `learning_rate` - Learning rate for optimization
262 ///
263 /// # Returns
264 ///
265 /// A new Adam optimizer instance with the specified learning rate and default
266 /// values for all other hyperparameters
267 pub fn with_learning_rate(learning_rate: f32) -> Self {
268 let config = AdamConfig {
269 learning_rate,
270 ..Default::default()
271 };
272 Self::with_config(config)
273 }
274
275 /// Add a single parameter to the optimizer
276 ///
277 /// Links a parameter to the optimizer by creating a new parameter state
278 /// indexed by the tensor's ID. The parameter must have `requires_grad` set to true.
279 ///
280 /// # Arguments
281 ///
282 /// * `parameter` - Reference to the tensor to link
283 ///
284 /// # Panics
285 ///
286 /// Panics if the parameter does not have `requires_grad` set to true
287 pub fn add_parameter(&mut self, parameter: &Tensor) {
288 assert!(
289 parameter.requires_grad(),
290 "Parameter must require gradients"
291 );
292
293 let param_id = parameter.id();
294 let param_shape = parameter.shape().dims.clone();
295
296 // Initialize state for this parameter if not already present
297 use std::collections::hash_map::Entry;
298 if let Entry::Vacant(entry) = self.states.entry(param_id) {
299 entry.insert(ParameterState::new(¶m_shape));
300 self.insertion_order.push(param_id);
301 }
302 }
303
304 /// Add multiple parameters to the optimizer
305 ///
306 /// Links multiple parameters to the optimizer by creating parameter states
307 /// indexed by each tensor's ID. All parameters must have `requires_grad` set to true.
308 ///
309 /// # Arguments
310 ///
311 /// * `parameters` - Slice of references to tensors to link
312 ///
313 /// # Panics
314 ///
315 /// Panics if any parameter does not have `requires_grad` set to true
316 pub fn add_parameters(&mut self, parameters: &[&Tensor]) {
317 for parameter in parameters {
318 self.add_parameter(parameter);
319 }
320 }
321
322 /// Remove a parameter from the optimizer
323 ///
324 /// Unlinks a parameter by removing its state from the optimizer.
325 /// The parameter ID is used for identification.
326 ///
327 /// # Arguments
328 ///
329 /// * `parameter` - Reference to the tensor to unlink
330 ///
331 /// # Returns
332 ///
333 /// True if the parameter was linked and removed, false if it was not linked
334 pub fn unlink_parameter(&mut self, parameter: &Tensor) -> bool {
335 let param_id = parameter.id();
336 let was_linked = self.states.remove(¶m_id).is_some();
337 if was_linked {
338 self.insertion_order.retain(|&id| id != param_id);
339 }
340 was_linked
341 }
342
343 /// Remove all parameter states from the optimizer
344 ///
345 /// Clears all parameter states, effectively unlinking all parameters.
346 /// This is useful for resetting the optimizer or preparing for parameter re-linking.
347 pub fn clear_states(&mut self) {
348 self.states.clear();
349 self.insertion_order.clear();
350 }
351
352 /// Check if a parameter is linked to the optimizer
353 ///
354 /// Returns true if the parameter has an associated state in the optimizer.
355 ///
356 /// # Arguments
357 ///
358 /// * `parameter` - Reference to the tensor to check
359 ///
360 /// # Returns
361 ///
362 /// True if the parameter is linked, false otherwise
363 pub fn is_parameter_linked(&self, parameter: &Tensor) -> bool {
364 let param_id = parameter.id();
365 self.states.contains_key(¶m_id)
366 }
367
    /// Get the number of linked parameters
    ///
    /// Returns the count of parameters currently linked to the optimizer,
    /// i.e. the number of per-parameter states held.
    ///
    /// # Returns
    ///
    /// Number of linked parameters
    pub fn parameter_count(&self) -> usize {
        self.states.len()
    }
378
379 /// Re-link parameters to saved optimizer states in chronological order
380 ///
381 /// After deserializing an optimizer, use this method to restore saved parameter states
382 /// to new tensors. Parameters must be provided in the same chronological order they
383 /// were originally added to the optimizer. Shape validation ensures parameter compatibility.
384 ///
385 /// # Arguments
386 ///
387 /// * `parameters` - Slice of parameter references in chronological order
388 ///
389 /// # Returns
390 ///
391 /// Result indicating success or failure with detailed error message
392 ///
393 /// # Panics
394 ///
395 /// Panics if any parameter does not have `requires_grad` set to true
396 pub fn relink_parameters(&mut self, parameters: &[&Tensor]) -> Result<(), String> {
397 // Validate all parameters have requires_grad first
398 for (i, param) in parameters.iter().enumerate() {
399 if !param.requires_grad() {
400 return Err(format!("Parameter at index {} must require gradients", i));
401 }
402 }
403
404 // Check parameter count matches saved states
405 if parameters.len() != self.insertion_order.len() {
406 return Err(format!(
407 "Parameter count mismatch: expected {} parameters, got {}",
408 self.insertion_order.len(),
409 parameters.len()
410 ));
411 }
412
413 // Create new states map with parameter IDs mapped to saved states in chronological order
414 let mut new_states = HashMap::new();
415 let mut new_insertion_order = Vec::new();
416
417 for (i, param) in parameters.iter().enumerate() {
418 let new_param_id = param.id();
419 let old_param_id = self.insertion_order[i];
420
421 // Get the saved state for this position
422 let saved_state = self
423 .states
424 .get(&old_param_id)
425 .ok_or_else(|| format!("No saved state found for parameter at position {}", i))?;
426
427 // Validate shape matches
428 let param_shape = ¶m.shape().dims;
429 let saved_shape = &saved_state.m.shape().dims;
430 if param_shape != saved_shape {
431 return Err(format!(
432 "Shape mismatch for parameter at position {}: expected {:?}, got {:?}",
433 i, saved_shape, param_shape
434 ));
435 }
436
437 // Create new state for this parameter
438 let new_state = ParameterState {
439 m: saved_state.m.clone(),
440 v: saved_state.v.clone(),
441 v_hat_max: saved_state.v_hat_max.clone(),
442 step: saved_state.step,
443 };
444
445 new_states.insert(new_param_id, new_state);
446 new_insertion_order.push(new_param_id);
447 }
448
449 // Replace the states and insertion order
450 self.states = new_states;
451 self.insertion_order = new_insertion_order;
452
453 Ok(())
454 }
455
    /// Get the current optimizer configuration
    ///
    /// Returns a reference to the current configuration, allowing inspection
    /// of all hyperparameters without modification. To change hyperparameters,
    /// construct a new optimizer via `with_config`.
    ///
    /// # Returns
    ///
    /// Reference to the current Adam configuration
    pub fn config(&self) -> &AdamConfig {
        &self.config
    }
467
468 /// Update a single parameter using Adam algorithm
469 ///
470 /// Implements the core Adam update rule with bias correction and optional AMSGrad.
471 /// Uses SIMD optimization when available for improved performance.
472 /// The parameter must be linked to the optimizer before calling this method.
473 ///
474 /// # Arguments
475 ///
476 /// * `param` - Mutable reference to the parameter tensor to update
477 ///
478 /// # Returns
479 ///
480 /// Result indicating success or failure of the parameter update
481 ///
482 /// # Panics
483 ///
484 /// Panics if the parameter is not linked to the optimizer
485 fn update_parameter(&mut self, param: &mut Tensor) -> Result<(), String> {
486 let param_id = param.id();
487
488 // Ensure parameter is linked
489 assert!(
490 self.states.contains_key(¶m_id),
491 "Parameter must be linked to optimizer before stepping. Use add_parameter() first."
492 );
493
494 // Get parameter gradient
495 let grad = param
496 .grad_by_value()
497 .ok_or_else(|| format!("Parameter {} has no gradient", param_id))?;
498
499 // Get parameter state
500 let state = self
501 .states
502 .get_mut(¶m_id)
503 .expect("Parameter state should exist after link check");
504
505 // Increment step count
506 state.step += 1;
507 let step = self.step_count as f32; // Use global step count for bias correction
508
509 // Apply weight decay if enabled
510 let effective_grad = if self.config.weight_decay > 0.0 {
511 // L2 regularization: grad + weight_decay * param
512 let mut grad_with_decay = grad.clone();
513 Self::add_weight_decay(&mut grad_with_decay, param, self.config.weight_decay);
514 grad_with_decay
515 } else {
516 grad
517 };
518
519 // Update biased first moment estimate: m = beta1 * m + (1 - beta1) * grad
520 Self::update_momentum(&mut state.m, &effective_grad, self.config.beta1);
521
522 // Update biased second moment estimate: v = beta2 * v + (1 - beta2) * grad^2
523 Self::update_velocity(&mut state.v, &effective_grad, self.config.beta2);
524
525 // Compute bias-corrected first moment estimate
526 let bias_correction1 = 1.0 - (self.config.beta1 as f64).powf(step as f64);
527 let m_hat = Self::bias_correct(&state.m, bias_correction1 as f32);
528
529 // Compute bias-corrected second moment estimate
530 let bias_correction2 = 1.0 - (self.config.beta2 as f64).powf(step as f64);
531 let mut v_hat = Self::bias_correct(&state.v, bias_correction2 as f32);
532
533 // AMSGrad: use maximum of v_hat over time
534 if self.config.amsgrad {
535 if state.v_hat_max.is_none() {
536 state.v_hat_max = Some(v_hat.clone());
537 }
538 let v_hat_max = state.v_hat_max.as_mut().unwrap();
539 Self::element_wise_max(v_hat_max, &v_hat);
540 v_hat = v_hat_max.clone();
541 }
542
543 // Compute parameter update: param = param - lr * m_hat / (sqrt(v_hat) + eps)
544 Self::apply_adam_update(
545 param,
546 &m_hat,
547 &v_hat,
548 self.config.learning_rate,
549 self.config.eps,
550 );
551
552 Ok(())
553 }
554
    /// Apply weight decay (L2 regularization) to gradient
    ///
    /// Adds `weight_decay * param` to the gradient in-place for memory efficiency.
    /// This implements L2 regularization by modifying the gradient before the
    /// Adam update step, equivalent to adding a regularization term to the loss.
    ///
    /// Dispatches to an AVX2 kernel when the CPU supports it and both tensors
    /// report SIMD alignment; otherwise falls back to a scalar loop.
    ///
    /// # Arguments
    ///
    /// * `grad` - Gradient tensor to modify in-place
    /// * `param` - Parameter tensor for weight decay calculation
    /// * `weight_decay` - Weight decay coefficient
    ///
    /// # Panics
    ///
    /// Panics if `grad` and `param` have different element counts.
    #[inline]
    fn add_weight_decay(grad: &mut Tensor, param: &Tensor, weight_decay: f32) {
        assert_eq!(
            grad.size(),
            param.size(),
            "Gradient and parameter size mismatch"
        );

        // SAFETY: the assert above guarantees both buffers hold `size`
        // elements, the raw pointers come from tensors borrowed for the
        // duration of this call, and every access below stays in 0..size.
        unsafe {
            let grad_ptr = grad.as_mut_ptr();
            let param_ptr = param.as_ptr();
            let size = grad.size();

            #[cfg(target_arch = "x86_64")]
            {
                // Runtime CPU-feature detection plus alignment checks gate
                // the vectorized path.
                if is_x86_feature_detected!("avx2")
                    && grad.is_simd_aligned()
                    && param.is_simd_aligned()
                {
                    Self::add_weight_decay_simd_avx2(grad_ptr, param_ptr, weight_decay, size);
                    return;
                }
            }

            // Scalar fallback
            for i in 0..size {
                *grad_ptr.add(i) += weight_decay * *param_ptr.add(i);
            }
        }
    }

    /// AVX2 kernel for `add_weight_decay`: `grad[i] += weight_decay * param[i]`,
    /// 8 f32 lanes per iteration with a scalar tail loop for the remainder.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available and both pointers are valid for
    /// `size` f32 elements (`param_ptr` is only read, `grad_ptr` is read and
    /// written).
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn add_weight_decay_simd_avx2(
        grad_ptr: *mut f32,
        param_ptr: *const f32,
        weight_decay: f32,
        size: usize,
    ) {
        let decay_vec = _mm256_set1_ps(weight_decay);
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            // Unaligned loads/stores (loadu/storeu) are valid for aligned
            // data too, so the alignment pre-check is a performance hint
            // rather than a correctness requirement here.
            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));
            let param_vec = _mm256_loadu_ps(param_ptr.add(offset));
            let decay_term = _mm256_mul_ps(decay_vec, param_vec);
            let result = _mm256_add_ps(grad_vec, decay_term);
            _mm256_storeu_ps(grad_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            *grad_ptr.add(i) += weight_decay * *param_ptr.add(i);
        }
    }
622
    /// Update momentum (first moment estimate)
    ///
    /// Implements the momentum update rule: `m = beta1 * m + (1 - beta1) * grad`
    /// This computes the exponentially decaying average of gradients, providing
    /// momentum-like behavior that helps accelerate convergence in relevant directions.
    ///
    /// Dispatches to an AVX2 kernel when available and both buffers report
    /// SIMD alignment; otherwise uses the scalar loop.
    ///
    /// # Arguments
    ///
    /// * `momentum` - Momentum buffer to update in-place
    /// * `grad` - Current gradient tensor
    /// * `beta1` - Exponential decay rate for momentum
    ///
    /// # Panics
    ///
    /// Panics if `momentum` and `grad` have different element counts.
    #[inline]
    fn update_momentum(momentum: &mut Tensor, grad: &Tensor, beta1: f32) {
        assert_eq!(
            momentum.size(),
            grad.size(),
            "Momentum and gradient size mismatch"
        );

        // Hoisted so the scalar and SIMD paths use the same (1 - beta1).
        let beta1_complement = 1.0 - beta1;

        // SAFETY: sizes were asserted equal above; the pointers come from
        // tensors borrowed for the duration of this call and all accesses
        // stay in 0..size.
        unsafe {
            let m_ptr = momentum.as_mut_ptr();
            let grad_ptr = grad.as_ptr();
            let size = momentum.size();

            #[cfg(target_arch = "x86_64")]
            {
                if is_x86_feature_detected!("avx2")
                    && momentum.is_simd_aligned()
                    && grad.is_simd_aligned()
                {
                    Self::update_momentum_simd_avx2(m_ptr, grad_ptr, beta1, beta1_complement, size);
                    return;
                }
            }

            // Scalar fallback
            for i in 0..size {
                *m_ptr.add(i) = beta1 * *m_ptr.add(i) + beta1_complement * *grad_ptr.add(i);
            }
        }
    }

    /// AVX2 kernel for `update_momentum`:
    /// `m[i] = beta1 * m[i] + (1 - beta1) * grad[i]`, 8 f32 lanes per
    /// iteration with a scalar tail loop for the remainder.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available and both pointers are valid for
    /// `size` f32 elements.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn update_momentum_simd_avx2(
        m_ptr: *mut f32,
        grad_ptr: *const f32,
        beta1: f32,
        beta1_complement: f32,
        size: usize,
    ) {
        let beta1_vec = _mm256_set1_ps(beta1);
        let beta1_comp_vec = _mm256_set1_ps(beta1_complement);
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let m_vec = _mm256_loadu_ps(m_ptr.add(offset));
            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));

            let momentum_term = _mm256_mul_ps(beta1_vec, m_vec);
            let gradient_term = _mm256_mul_ps(beta1_comp_vec, grad_vec);
            let result = _mm256_add_ps(momentum_term, gradient_term);

            _mm256_storeu_ps(m_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            *m_ptr.add(i) = beta1 * *m_ptr.add(i) + beta1_complement * *grad_ptr.add(i);
        }
    }
697
    /// Update velocity (second moment estimate)
    ///
    /// Implements the velocity update rule: `v = beta2 * v + (1 - beta2) * grad^2`
    /// This computes the exponentially decaying average of squared gradients,
    /// providing adaptive learning rates that scale inversely with gradient magnitude.
    ///
    /// Dispatches to an AVX2 kernel when available and both buffers report
    /// SIMD alignment; otherwise uses the scalar loop.
    ///
    /// # Arguments
    ///
    /// * `velocity` - Velocity buffer to update in-place
    /// * `grad` - Current gradient tensor
    /// * `beta2` - Exponential decay rate for velocity
    ///
    /// # Panics
    ///
    /// Panics if `velocity` and `grad` have different element counts.
    #[inline]
    fn update_velocity(velocity: &mut Tensor, grad: &Tensor, beta2: f32) {
        assert_eq!(
            velocity.size(),
            grad.size(),
            "Velocity and gradient size mismatch"
        );

        // Hoisted so the scalar and SIMD paths use the same (1 - beta2).
        let beta2_complement = 1.0 - beta2;

        // SAFETY: sizes were asserted equal above; the pointers come from
        // tensors borrowed for the duration of this call and all accesses
        // stay in 0..size.
        unsafe {
            let v_ptr = velocity.as_mut_ptr();
            let grad_ptr = grad.as_ptr();
            let size = velocity.size();

            #[cfg(target_arch = "x86_64")]
            {
                if is_x86_feature_detected!("avx2")
                    && velocity.is_simd_aligned()
                    && grad.is_simd_aligned()
                {
                    Self::update_velocity_simd_avx2(v_ptr, grad_ptr, beta2, beta2_complement, size);
                    return;
                }
            }

            // Scalar fallback
            for i in 0..size {
                let grad_val = *grad_ptr.add(i);
                *v_ptr.add(i) = beta2 * *v_ptr.add(i) + beta2_complement * grad_val * grad_val;
            }
        }
    }

    /// AVX2 kernel for `update_velocity`:
    /// `v[i] = beta2 * v[i] + (1 - beta2) * grad[i]^2`, 8 f32 lanes per
    /// iteration with a scalar tail loop for the remainder.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available and both pointers are valid for
    /// `size` f32 elements.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn update_velocity_simd_avx2(
        v_ptr: *mut f32,
        grad_ptr: *const f32,
        beta2: f32,
        beta2_complement: f32,
        size: usize,
    ) {
        let beta2_vec = _mm256_set1_ps(beta2);
        let beta2_comp_vec = _mm256_set1_ps(beta2_complement);
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let v_vec = _mm256_loadu_ps(v_ptr.add(offset));
            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));

            let velocity_term = _mm256_mul_ps(beta2_vec, v_vec);
            let grad_squared = _mm256_mul_ps(grad_vec, grad_vec);
            let gradient_term = _mm256_mul_ps(beta2_comp_vec, grad_squared);
            let result = _mm256_add_ps(velocity_term, gradient_term);

            _mm256_storeu_ps(v_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            let grad_val = *grad_ptr.add(i);
            *v_ptr.add(i) = beta2 * *v_ptr.add(i) + beta2_complement * grad_val * grad_val;
        }
    }
775
    /// Apply bias correction to moment estimates
    ///
    /// Returns `tensor / (1 - beta^step)` for unbiasing the moment estimates.
    /// This correction is necessary because the moment estimates are initialized
    /// to zero, creating a bias towards zero that becomes negligible as training progresses.
    ///
    /// The caller supplies `bias_correction = 1 - beta^step`; it must be
    /// non-zero (i.e. `step >= 1`), otherwise the reciprocal below is infinite.
    ///
    /// # Arguments
    ///
    /// * `tensor` - Moment estimate tensor to correct
    /// * `bias_correction` - Bias correction factor (1 - beta^step)
    ///
    /// # Returns
    ///
    /// Bias-corrected tensor with the same shape as input
    #[inline]
    fn bias_correct(tensor: &Tensor, bias_correction: f32) -> Tensor {
        let mut result = tensor.clone();
        // Multiply by the reciprocal once instead of dividing per element.
        let correction_factor = 1.0 / bias_correction;

        // SAFETY: `result` is a freshly cloned tensor exclusively owned by
        // this function; all accesses stay in 0..size.
        unsafe {
            let ptr = result.as_mut_ptr();
            let size = result.size();

            #[cfg(target_arch = "x86_64")]
            {
                if is_x86_feature_detected!("avx2") && result.is_simd_aligned() {
                    Self::bias_correct_simd_avx2(ptr, correction_factor, size);
                    return result;
                }
            }

            // Scalar fallback
            for i in 0..size {
                *ptr.add(i) *= correction_factor;
            }
        }

        result
    }

    /// AVX2 kernel for `bias_correct`: multiplies `size` f32 values in place
    /// by `correction_factor`, 8 lanes per iteration with a scalar tail loop.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available and `ptr` is valid for `size`
    /// f32 elements.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn bias_correct_simd_avx2(ptr: *mut f32, correction_factor: f32, size: usize) {
        let factor_vec = _mm256_set1_ps(correction_factor);
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let data_vec = _mm256_loadu_ps(ptr.add(offset));
            let result = _mm256_mul_ps(data_vec, factor_vec);
            _mm256_storeu_ps(ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            *ptr.add(i) *= correction_factor;
        }
    }
834
    /// Element-wise maximum for AMSGrad
    ///
    /// Updates first tensor in-place with `max(first, second)` for AMSGrad variant.
    /// This maintains a running maximum of the second moment estimates, preventing
    /// the learning rate from increasing over time and improving convergence stability.
    ///
    /// # Arguments
    ///
    /// * `first` - Tensor to update in-place with maximum values
    /// * `second` - Tensor to compare against for maximum calculation
    ///
    /// # Panics
    ///
    /// Panics if the tensors have different element counts.
    #[inline]
    fn element_wise_max(first: &mut Tensor, second: &Tensor) {
        assert_eq!(
            first.size(),
            second.size(),
            "Tensor size mismatch for element-wise max"
        );

        // SAFETY: sizes were asserted equal above; the pointers come from
        // tensors borrowed for the duration of this call and all accesses
        // stay in 0..size.
        unsafe {
            let first_ptr = first.as_mut_ptr();
            let second_ptr = second.as_ptr();
            let size = first.size();

            #[cfg(target_arch = "x86_64")]
            {
                if is_x86_feature_detected!("avx2")
                    && first.is_simd_aligned()
                    && second.is_simd_aligned()
                {
                    Self::element_wise_max_simd_avx2(first_ptr, second_ptr, size);
                    return;
                }
            }

            // Scalar fallback. Deliberately a comparison (not `f32::max`):
            // when the comparison fails (equal values or NaN) the second
            // operand wins, matching the VMAXPS semantics of the SIMD path.
            for i in 0..size {
                let a = *first_ptr.add(i);
                let b = *second_ptr.add(i);
                *first_ptr.add(i) = if a > b { a } else { b };
            }
        }
    }

    /// AVX2 kernel for `element_wise_max`:
    /// `first[i] = max(first[i], second[i])`, 8 f32 lanes per iteration with
    /// a scalar tail loop for the remainder.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available and both pointers are valid for
    /// `size` f32 elements.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn element_wise_max_simd_avx2(first_ptr: *mut f32, second_ptr: *const f32, size: usize) {
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let first_vec = _mm256_loadu_ps(first_ptr.add(offset));
            let second_vec = _mm256_loadu_ps(second_ptr.add(offset));
            // _mm256_max_ps returns the second operand for equal or NaN
            // lanes, consistent with the scalar fallback above.
            let result = _mm256_max_ps(first_vec, second_vec);
            _mm256_storeu_ps(first_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            let a = *first_ptr.add(i);
            let b = *second_ptr.add(i);
            *first_ptr.add(i) = if a > b { a } else { b };
        }
    }
898
899 /// Apply the final Adam parameter update
900 ///
901 /// Implements the core Adam update rule: `param = param - lr * m_hat / (sqrt(v_hat) + eps)`
902 /// This applies the bias-corrected momentum and velocity estimates to update
903 /// the parameter values, with adaptive learning rates that scale inversely
904 /// with the square root of the velocity estimates.
905 ///
906 /// # Arguments
907 ///
908 /// * `param` - Parameter tensor to update in-place
909 /// * `m_hat` - Bias-corrected first moment estimate
910 /// * `v_hat` - Bias-corrected second moment estimate
911 /// * `learning_rate` - Learning rate for the update
912 /// * `eps` - Small constant for numerical stability
913 #[inline]
914 fn apply_adam_update(
915 param: &mut Tensor,
916 m_hat: &Tensor,
917 v_hat: &Tensor,
918 learning_rate: f32,
919 eps: f32,
920 ) {
921 assert_eq!(
922 param.size(),
923 m_hat.size(),
924 "Parameter and momentum size mismatch"
925 );
926 assert_eq!(
927 param.size(),
928 v_hat.size(),
929 "Parameter and velocity size mismatch"
930 );
931
932 unsafe {
933 let param_ptr = param.as_mut_ptr();
934 let m_ptr = m_hat.as_ptr();
935 let v_ptr = v_hat.as_ptr();
936 let size = param.size();
937
938 #[cfg(target_arch = "x86_64")]
939 {
940 if is_x86_feature_detected!("avx2")
941 && param.is_simd_aligned()
942 && m_hat.is_simd_aligned()
943 && v_hat.is_simd_aligned()
944 {
945 Self::apply_adam_update_simd_avx2(
946 param_ptr,
947 m_ptr,
948 v_ptr,
949 learning_rate,
950 eps,
951 size,
952 );
953 return;
954 }
955 }
956
957 // Scalar fallback
958 for i in 0..size {
959 let m_val = *m_ptr.add(i);
960 let v_val = *v_ptr.add(i);
961 let denominator = v_val.sqrt() + eps;
962 *param_ptr.add(i) -= learning_rate * m_val / denominator;
963 }
964 }
965 }
966
    /// AVX2-vectorized Adam parameter update kernel
    ///
    /// Computes `param[i] -= lr * m_hat[i] / (sqrt(v_hat[i]) + eps)` for `size`
    /// elements, processing 8 lanes per iteration with 256-bit intrinsics and
    /// finishing any remainder (`size % 8`) with scalar code. Uses unaligned
    /// loads/stores (`loadu`/`storeu`), so correctness does not depend on
    /// buffer alignment even though the caller checks it for performance.
    ///
    /// # Arguments
    ///
    /// * `param_ptr` - Pointer to the parameter buffer, updated in place
    /// * `m_ptr` - Pointer to the bias-corrected first moment buffer
    /// * `v_ptr` - Pointer to the bias-corrected second moment buffer
    /// * `learning_rate` - Step size for the update
    /// * `eps` - Small constant for numerical stability
    /// * `size` - Number of `f32` elements in each buffer
    ///
    /// # Safety
    ///
    /// The caller must guarantee that:
    /// - All three pointers are valid for reads (and `param_ptr` for writes)
    ///   of at least `size` `f32` elements, with no aliasing between
    ///   `param_ptr` and the read-only buffers
    /// - The AVX2 instruction set is available on the running CPU (verified
    ///   via `is_x86_feature_detected!("avx2")` by the caller)
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn apply_adam_update_simd_avx2(
        param_ptr: *mut f32,
        m_ptr: *const f32,
        v_ptr: *const f32,
        learning_rate: f32,
        eps: f32,
        size: usize,
    ) {
        // Broadcast scalar constants across all 8 lanes once, outside the loop
        let lr_vec = _mm256_set1_ps(learning_rate);
        let eps_vec = _mm256_set1_ps(eps);
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let param_vec = _mm256_loadu_ps(param_ptr.add(offset));
            let m_vec = _mm256_loadu_ps(m_ptr.add(offset));
            let v_vec = _mm256_loadu_ps(v_ptr.add(offset));

            // sqrt(v_hat) + eps
            let sqrt_v = _mm256_sqrt_ps(v_vec);
            let denominator = _mm256_add_ps(sqrt_v, eps_vec);

            // lr * m_hat / denominator
            let lr_m = _mm256_mul_ps(lr_vec, m_vec);
            let update = _mm256_div_ps(lr_m, denominator);

            // param - update
            let result = _mm256_sub_ps(param_vec, update);
            _mm256_storeu_ps(param_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            let m_val = *m_ptr.add(i);
            let v_val = *v_ptr.add(i);
            let denominator = v_val.sqrt() + eps;
            *param_ptr.add(i) -= learning_rate * m_val / denominator;
        }
    }
1008}
1009
1010impl Optimizer for Adam {
1011 /// Perform a single optimization step
1012 ///
1013 /// Updates all provided parameters based on their accumulated gradients using the Adam algorithm.
1014 /// Each parameter is updated according to the Adam update rule with bias correction
1015 /// and optional AMSGrad variant if enabled. All parameters must be linked to the optimizer
1016 /// before calling this method.
1017 ///
1018 /// # Arguments
1019 ///
1020 /// * `parameters` - Mutable slice of parameter references to update
1021 ///
1022 /// # Thread Safety
1023 ///
1024 /// This method is thread-safe as it takes mutable references to parameters,
1025 /// ensuring exclusive access during updates.
1026 ///
1027 /// # Performance
1028 ///
1029 /// - Uses SIMD optimization (AVX2) when available for 8x vectorization
1030 /// - Processes parameters in sequence for optimal cache usage
1031 /// - Maintains per-parameter state for momentum and velocity estimates
1032 ///
1033 /// # Panics
1034 ///
1035 /// Panics if any parameter is not linked to the optimizer
1036 fn step(&mut self, parameters: &mut [&mut Tensor]) {
1037 self.step_count += 1;
1038
1039 for param in parameters {
1040 if let Err(e) = self.update_parameter(param) {
1041 eprintln!("Warning: Failed to update parameter: {}", e);
1042 }
1043 }
1044 }
1045
1046 /// Zero out all parameter gradients
1047 ///
1048 /// Clears accumulated gradients for all provided parameters. This should be called
1049 /// before each backward pass to prevent gradient accumulation across multiple
1050 /// forward/backward passes. Also clears the global autograd gradient map.
1051 ///
1052 /// # Arguments
1053 ///
1054 /// * `parameters` - Mutable slice of parameter references to clear gradients for
1055 ///
1056 /// # Performance
1057 ///
1058 /// - Efficiently clears gradients using optimized tensor operations
1059 /// - Clears both per-tensor gradients and global autograd state
1060 /// - Thread-safe as it takes mutable references to parameters
1061 fn zero_grad(&mut self, parameters: &mut [&mut Tensor]) {
1062 for param in parameters {
1063 param.zero_grad();
1064 }
1065
1066 // Also clear gradient tracking gradient map
1067 crate::gradtrack::clear_gradients();
1068 }
1069
1070 /// Get the current learning rate
1071 ///
1072 /// Returns the current learning rate used for parameter updates.
1073 ///
1074 /// # Returns
1075 ///
1076 /// Current learning rate as f32
1077 fn learning_rate(&self) -> f32 {
1078 self.config.learning_rate
1079 }
1080
1081 /// Set the learning rate for all parameters
1082 ///
1083 /// Updates the learning rate for all parameters in the optimizer.
1084 /// This allows dynamic learning rate scheduling during training.
1085 ///
1086 /// # Arguments
1087 ///
1088 /// * `lr` - New learning rate value
1089 fn set_learning_rate(&mut self, lr: f32) {
1090 self.config.learning_rate = lr;
1091 }
1092}
1093
1094// Adam is automatically Send + Sync since it no longer contains raw pointers
1095// and all fields are Send + Sync (HashMap, usize, AdamConfig)
1096
#[cfg(test)]
mod tests {
    //! Unit tests for the Adam optimizer covering construction and
    //! configuration, the parameter linking lifecycle (link/unlink/clear),
    //! gradient handling, shape handling, and multi-step training behavior.

    use super::*;
    use crate::tensor::core::Tensor;

    /// Test Adam optimizer creation with default configuration
    ///
    /// Verifies that the optimizer is created correctly with default hyperparameters
    /// and that parameter states are properly initialized.
    #[test]
    fn test_adam_creation() {
        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let bias = Tensor::zeros(vec![3, 1]).with_requires_grad();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        assert_eq!(optimizer.learning_rate(), 1e-3);
        assert_eq!(optimizer.config.beta1, 0.9);
        assert_eq!(optimizer.config.beta2, 0.999);
        assert_eq!(optimizer.parameter_count(), 2);
    }

    /// Test Adam optimizer creation with custom configuration
    ///
    /// Verifies that custom hyperparameters are properly set and that
    /// the optimizer configuration matches the provided values.
    #[test]
    fn test_adam_with_config() {
        let weight = Tensor::ones(vec![5, 5]).with_requires_grad();

        let config = AdamConfig {
            learning_rate: 1e-4,
            beta1: 0.95,
            beta2: 0.9999,
            weight_decay: 1e-5,
            amsgrad: true,
            ..Default::default()
        };

        let mut optimizer = Adam::with_config(config);
        optimizer.add_parameter(&weight);

        assert_eq!(optimizer.learning_rate(), 1e-4);
        assert_eq!(optimizer.config.beta1, 0.95);
        assert_eq!(optimizer.config.beta2, 0.9999);
        assert_eq!(optimizer.config.weight_decay, 1e-5);
        assert!(optimizer.config.amsgrad);
    }

    /// Test Adam optimizer creation with custom learning rate
    ///
    /// Verifies that the convenience constructor properly sets the learning rate
    /// while maintaining default values for other hyperparameters.
    #[test]
    fn test_adam_with_learning_rate() {
        let weight = Tensor::ones(vec![3, 3]).with_requires_grad();
        let mut optimizer = Adam::with_learning_rate(5e-4);
        optimizer.add_parameter(&weight);

        assert_eq!(optimizer.learning_rate(), 5e-4);
        assert_eq!(optimizer.config.beta1, 0.9); // Should use defaults for other params
    }

    /// Test Adam step without gradients
    ///
    /// Verifies that the optimizer handles the case where parameters have no
    /// gradients gracefully, leaving parameters unchanged.
    #[test]
    fn test_adam_step_without_gradients() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let original_data = weight.clone();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        optimizer.step(&mut [&mut weight]); // Should not update without gradients

        // Parameters should remain unchanged without gradients
        for i in 0..weight.size() {
            unsafe {
                assert_eq!(*weight.as_ptr().add(i), *original_data.as_ptr().add(i));
            }
        }
    }

    /// Test learning rate update functionality
    ///
    /// Verifies that the learning rate can be dynamically updated during training
    /// and that the optimizer uses the new learning rate for subsequent steps.
    #[test]
    fn test_learning_rate_update() {
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        assert_eq!(optimizer.learning_rate(), 1e-3);

        optimizer.set_learning_rate(1e-2);
        assert_eq!(optimizer.learning_rate(), 1e-2);
    }

    /// Test gradient zeroing functionality
    ///
    /// Verifies that zero_grad properly clears accumulated gradients for all
    /// parameters and the global autograd state.
    #[test]
    fn test_zero_grad() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();

        // Check that zero_grad clears gradients
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
        optimizer.zero_grad(&mut [&mut weight]);

        // After zero_grad, there should be no accumulated gradients
        assert!(weight.grad_by_value().is_none());
    }

    /// Test requires_grad assertion
    ///
    /// Verifies that the optimizer correctly panics when parameters do not
    /// have requires_grad set to true, ensuring proper gradient tracking.
    #[test]
    #[should_panic(expected = "Parameter must require gradients")]
    fn test_adam_requires_grad_assertion() {
        let weight = Tensor::ones(vec![2, 2]); // No requires_grad
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
    }

    /// Test AdamConfig default values
    ///
    /// Verifies that the default configuration matches PyTorch conventions
    /// and provides optimal settings for most training scenarios.
    #[test]
    fn test_adam_config_default() {
        let config = AdamConfig::default();

        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.beta1, 0.9);
        assert_eq!(config.beta2, 0.999);
        assert_eq!(config.eps, 1e-8);
        assert_eq!(config.weight_decay, 0.0);
        assert!(!config.amsgrad);
    }

    /// Test ParameterState creation and initialization
    ///
    /// Verifies that parameter states are properly initialized with zero tensors
    /// and that the step count starts at zero.
    #[test]
    fn test_parameter_state_creation() {
        let state = ParameterState::new(&[3, 4]);

        assert_eq!(state.m.shape().dims, vec![3, 4]);
        assert_eq!(state.v.shape().dims, vec![3, 4]);
        assert!(state.v_hat_max.is_none());
        assert_eq!(state.step, 0);

        // Verify tensors are zero-initialized
        for i in 0..state.m.size() {
            unsafe {
                assert_eq!(*state.m.as_ptr().add(i), 0.0);
                assert_eq!(*state.v.as_ptr().add(i), 0.0);
            }
        }
    }

    /// Test parameter linking functionality
    ///
    /// Verifies that parameters can be linked and unlinked from the optimizer.
    #[test]
    fn test_parameter_linking() {
        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let bias = Tensor::zeros(vec![3]).with_requires_grad();

        let mut optimizer = Adam::new();

        // Initially no parameters linked
        assert_eq!(optimizer.parameter_count(), 0);
        assert!(!optimizer.is_parameter_linked(&weight));
        assert!(!optimizer.is_parameter_linked(&bias));

        // Link weight
        optimizer.add_parameter(&weight);
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight));
        assert!(!optimizer.is_parameter_linked(&bias));

        // Link bias
        optimizer.add_parameter(&bias);
        assert_eq!(optimizer.parameter_count(), 2);
        assert!(optimizer.is_parameter_linked(&weight));
        assert!(optimizer.is_parameter_linked(&bias));

        // Unlink weight
        let was_linked = optimizer.unlink_parameter(&weight);
        assert!(was_linked);
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(!optimizer.is_parameter_linked(&weight));
        assert!(optimizer.is_parameter_linked(&bias));

        // Clear all states
        optimizer.clear_states();
        assert_eq!(optimizer.parameter_count(), 0);
        assert!(!optimizer.is_parameter_linked(&weight));
        assert!(!optimizer.is_parameter_linked(&bias));
    }

    /// Test parameter linking with multiple parameters at once
    ///
    /// Verifies that multiple parameters can be linked simultaneously.
    #[test]
    fn test_add_multiple_parameters() {
        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let bias = Tensor::zeros(vec![3]).with_requires_grad();
        let weight2 = Tensor::ones(vec![3, 2]).with_requires_grad();

        let mut optimizer = Adam::new();

        // Link multiple parameters at once
        optimizer.add_parameters(&[&weight, &bias, &weight2]);

        assert_eq!(optimizer.parameter_count(), 3);
        assert!(optimizer.is_parameter_linked(&weight));
        assert!(optimizer.is_parameter_linked(&bias));
        assert!(optimizer.is_parameter_linked(&weight2));
    }

    /// Test stepping with unlinked parameter
    ///
    /// Verifies that the optimizer panics when trying to step with an unlinked parameter.
    #[test]
    #[should_panic(expected = "Parameter must be linked to optimizer before stepping")]
    fn test_step_with_unlinked_parameter() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();

        // Don't link the parameter
        optimizer.step(&mut [&mut weight]); // Should panic
    }

    /// Test optimizer with actual gradients
    ///
    /// Verifies that the optimizer properly updates parameters when gradients are present.
    #[test]
    fn test_optimizer_with_gradients() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let original_data = weight.clone();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        // Generate some gradients
        let output = weight.mul_scalar(2.0);
        let mut loss = output.sum();
        loss.backward(None);

        // Step should update parameters
        optimizer.step(&mut [&mut weight]);

        // Parameters should have changed
        let mut changed = false;
        for i in 0..weight.size() {
            unsafe {
                if (*weight.as_ptr().add(i) - *original_data.as_ptr().add(i)).abs() > 1e-6 {
                    changed = true;
                    break;
                }
            }
        }
        assert!(
            changed,
            "Parameters should have been updated by optimizer step"
        );
    }

    /// Test optimizer with multiple parameters and gradients
    ///
    /// Verifies that the optimizer works correctly with multiple parameters.
    #[test]
    fn test_optimizer_multiple_parameters() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut bias = Tensor::zeros(vec![2, 2]).with_requires_grad(); // Same shape as weight

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        // Generate gradients for both parameters
        let output = weight.mul_scalar(2.0).add_tensor(&bias);
        let mut loss = output.sum();
        loss.backward(None);

        // Step should update both parameters
        optimizer.step(&mut [&mut weight, &mut bias]);

        // Both parameters should have gradients
        assert!(weight.grad_by_value().is_some());
        assert!(bias.grad_by_value().is_some());
    }

    /// Test optimizer with custom configuration and multiple steps
    ///
    /// Verifies that the optimizer works correctly with custom configuration over multiple steps.
    #[test]
    fn test_optimizer_custom_config_multiple_steps() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();

        let config = AdamConfig {
            learning_rate: 1e-4,
            beta1: 0.95,
            beta2: 0.9999,
            weight_decay: 1e-5,
            amsgrad: true,
            ..Default::default()
        };

        let mut optimizer = Adam::with_config(config);
        optimizer.add_parameter(&weight);

        // Multiple training steps
        for _ in 0..5 {
            // Generate gradients
            let output = weight.mul_scalar(2.0);
            let mut loss = output.sum();
            loss.backward(None);

            // Step
            optimizer.step(&mut [&mut weight]);
            optimizer.zero_grad(&mut [&mut weight]);
        }

        // Should complete without errors
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight));
    }

    /// Test optimizer with different tensor shapes
    ///
    /// Verifies that the optimizer works correctly with various tensor shapes.
    #[test]
    fn test_optimizer_different_shapes() {
        let shapes = vec![
            vec![1],       // Scalar
            vec![3],       // 1D
            vec![2, 2],    // 2D square
            vec![2, 3],    // 2D rectangular
            vec![1, 1, 3], // 3D
            vec![2, 2, 2], // 3D cube
        ];

        for shape in shapes {
            let mut tensor = Tensor::ones(shape.clone()).with_requires_grad();
            let mut optimizer = Adam::new();
            optimizer.add_parameter(&tensor);

            // Generate gradients
            let output = tensor.mul_scalar(2.0);
            let mut loss = output.sum();
            loss.backward(None);

            // Step should work for all shapes
            optimizer.step(&mut [&mut tensor]);

            // Verify tensor is still valid
            assert_eq!(tensor.shape().dims, shape);
            assert!(tensor.requires_grad());
        }
    }

    /// Test double parameter linking
    ///
    /// Verifies that linking the same parameter twice doesn't create duplicate states.
    #[test]
    fn test_double_parameter_linking() {
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();

        // Link parameter twice
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&weight);

        // Should only have one state
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight));
    }

    /// Test unlink non-linked parameter
    ///
    /// Verifies that unlinking a non-linked parameter returns false.
    #[test]
    fn test_unlink_non_linked_parameter() {
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();

        // Try to unlink a parameter that was never linked
        let was_linked = optimizer.unlink_parameter(&weight);
        assert!(!was_linked);
    }

    /// Test parameter shape validation
    ///
    /// Verifies that parameters with different shapes can be linked correctly
    /// and maintain their shape information in the optimizer state.
    #[test]
    fn test_parameter_shape_validation() {
        let weight_1d = Tensor::ones(vec![5]).with_requires_grad();
        let weight_2d = Tensor::ones(vec![3, 4]).with_requires_grad();
        let weight_3d = Tensor::ones(vec![2, 3, 4]).with_requires_grad();

        let mut optimizer = Adam::new();

        // Test linking parameters with different shapes
        optimizer.add_parameter(&weight_1d);
        optimizer.add_parameter(&weight_2d);
        optimizer.add_parameter(&weight_3d);

        assert_eq!(optimizer.parameter_count(), 3);
        assert!(optimizer.is_parameter_linked(&weight_1d));
        assert!(optimizer.is_parameter_linked(&weight_2d));
        assert!(optimizer.is_parameter_linked(&weight_3d));

        // Test that each parameter maintains its shape after linking
        let state_1d = &optimizer.states[&weight_1d.id()];
        let state_2d = &optimizer.states[&weight_2d.id()];
        let state_3d = &optimizer.states[&weight_3d.id()];

        assert_eq!(state_1d.m.shape().dims, vec![5]);
        assert_eq!(state_2d.m.shape().dims, vec![3, 4]);
        assert_eq!(state_3d.m.shape().dims, vec![2, 3, 4]);
    }

    /// Test parameter relinking with shape consistency
    ///
    /// Verifies that when re-linking parameters (e.g., after deserialization),
    /// shape information is correctly maintained.
    #[test]
    fn test_parameter_relinking_shape_consistency() {
        // Create initial parameter and optimizer
        let mut weight_original = Tensor::ones(vec![3, 3]).with_requires_grad();
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight_original);

        // Perform a step to create state
        let output = weight_original.mul_scalar(2.0);
        let mut loss = output.sum();
        loss.backward(None);
        optimizer.step(&mut [&mut weight_original]);

        // Verify state was created with correct shape
        let original_state = &optimizer.states[&weight_original.id()];
        assert_eq!(original_state.m.shape().dims, vec![3, 3]);
        assert_eq!(original_state.v.shape().dims, vec![3, 3]);

        // Create new parameter with same shape (will get different ID)
        let weight_new = Tensor::ones(vec![3, 3]).with_requires_grad();

        // This should create a new state since it's a different parameter
        optimizer.add_parameter(&weight_new);

        // Should now have 2 states
        assert_eq!(optimizer.parameter_count(), 2);

        // Both should be linked with correct shapes
        assert!(optimizer.is_parameter_linked(&weight_original));
        assert!(optimizer.is_parameter_linked(&weight_new));

        let new_state = &optimizer.states[&weight_new.id()];
        assert_eq!(new_state.m.shape().dims, vec![3, 3]);
        assert_eq!(new_state.v.shape().dims, vec![3, 3]);
    }

    /// Test large parameter count handling
    ///
    /// Verifies that the optimizer can handle many parameters efficiently
    /// and correctly manage linking/unlinking operations.
    #[test]
    fn test_large_parameter_count() {
        let mut optimizer = Adam::new();
        let mut params = Vec::new();

        // Create 50 parameters of different shapes
        for i in 1..=50 {
            let param = Tensor::ones(vec![i]).with_requires_grad();
            optimizer.add_parameter(&param);
            params.push(param);
        }

        assert_eq!(optimizer.parameter_count(), 50);

        // Verify all parameters are linked
        for param in &params {
            assert!(optimizer.is_parameter_linked(param));
        }

        // Test unlinking some parameters
        // take(25).step_by(2) visits indices 0, 2, ..., 24 -> 13 parameters
        for param in params.iter().take(25).step_by(2) {
            assert!(optimizer.unlink_parameter(param));
        }

        assert_eq!(optimizer.parameter_count(), 37); // 50 - 13 = 37

        // Verify correct parameters are unlinked
        for (i, param) in params.iter().enumerate() {
            if i < 25 && i % 2 == 0 {
                assert!(!optimizer.is_parameter_linked(param));
            } else {
                assert!(optimizer.is_parameter_linked(param));
            }
        }
    }

    /// Test clear_states functionality
    ///
    /// Verifies that clearing all states works correctly and allows
    /// re-adding parameters afterwards.
    #[test]
    fn test_clear_states_functionality() {
        let weight1 = Tensor::ones(vec![2, 2]).with_requires_grad();
        let weight2 = Tensor::ones(vec![3, 3]).with_requires_grad();
        let weight3 = Tensor::ones(vec![4, 4]).with_requires_grad();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight1);
        optimizer.add_parameter(&weight2);
        optimizer.add_parameter(&weight3);

        assert_eq!(optimizer.parameter_count(), 3);

        // Clear all states
        optimizer.clear_states();

        assert_eq!(optimizer.parameter_count(), 0);
        assert!(!optimizer.is_parameter_linked(&weight1));
        assert!(!optimizer.is_parameter_linked(&weight2));
        assert!(!optimizer.is_parameter_linked(&weight3));

        // Should be able to re-add parameters after clearing
        optimizer.add_parameter(&weight1);
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight1));
    }
}