train_station/optimizers/adam/
mod.rs

1//! Adam optimizer implementation for neural network training
2//!
3//! This module provides the Adam optimization algorithm with PyTorch-compatible interface.
4//! Adam combines the benefits of AdaGrad and RMSprop by using adaptive learning rates
5//! with momentum for efficient training of neural networks.
6//!
7//! # Features
8//!
9//! - **Adaptive Learning Rates**: Per-parameter learning rates based on gradient history
10//! - **Momentum Integration**: Combines momentum with adaptive learning rates
11//! - **Bias Correction**: Corrects for initialization bias in moment estimates
12//! - **Weight Decay**: Optional L2 regularization support
13//! - **AMSGrad Variant**: Optional AMSGrad for improved convergence stability
14//! - **SIMD Optimization**: AVX2-optimized parameter updates for maximum performance
15//! - **Thread Safety**: Send + Sync implementation for multi-threaded training
16//! - **Hybrid API**: Both safe (RwLock-based) and unsafe (direct pointer) access patterns
17//!
18//! # Thread Safety
19//!
20//! The optimizer provides two usage patterns:
21//!
22//! **Safe Multi-threaded Usage (Default)**:
23//! - Uses `Arc<RwLock<Tensor>>` for thread-safe parameter access
24//! - Multiple threads can read tensors simultaneously
25//! - Optimizer steps acquire write locks only during parameter updates
26//! - Recommended for most use cases
27//!
28//! **Unsafe Single-threaded Usage (Performance)**:
29//! - Uses raw pointers for maximum performance
30//! - No locking overhead during optimizer steps
31//! - Caller must ensure exclusive access during optimization
32//! - Use only when you can guarantee no concurrent tensor access
33//!
34//! # Algorithm
35//!
36//! Adam implements the following update rule for each parameter theta:
37//!
38//! ```text
39//! m_t = beta1 * m_{t-1} + (1 - beta1) * grad_theta_t
40//! v_t = beta2 * v_{t-1} + (1 - beta2) * (grad_theta_t)^2
41//! m_hat_t = m_t / (1 - beta1^t)
42//! v_hat_t = v_t / (1 - beta2^t)
43//! theta_{t+1} = theta_t - lr * m_hat_t / (sqrt(v_hat_t) + eps)
44//! ```
45//!
46//! Where:
47//! - lr is the learning rate
48//! - beta1, beta2 are exponential decay rates for moment estimates
49//! - eps is a small constant for numerical stability
50//! - m_t, v_t are biased first and second moment estimates
51//! - m_hat_t, v_hat_t are bias-corrected moment estimates
52//!
53//! # Performance Characteristics
54//!
55//! - **SIMD Optimization**: Uses AVX2 instructions for 8x vectorization when available
56//! - **Memory Efficiency**: In-place updates with minimal temporary allocations
57//! - **Cache-Friendly**: Optimized memory access patterns for large parameter tensors
58//! - **Zero-Cost Abstractions**: Compile-time optimization with minimal runtime overhead
59//! - **Lock-Free Reads**: RwLock allows concurrent tensor reads during training
60//!
61//! # References
62//!
63//! - Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization.
64//! - PyTorch Adam implementation: <https://pytorch.org/docs/stable/generated/torch.optim.Adam.html>
65
66pub mod serialization;
67
68use super::Optimizer;
69use crate::tensor::core::Tensor;
70use std::collections::HashMap;
71
72// SIMD optimizations for performance-critical operations
73#[cfg(target_arch = "x86_64")]
74use std::arch::x86_64::*;
75
76/// Configuration for the Adam optimization algorithm
77///
78/// Contains all hyperparameters that control the behavior of Adam optimization.
79/// Default values follow PyTorch conventions for maximum compatibility and
80/// optimal convergence across a wide range of neural network architectures.
81///
82/// # Fields
83///
84/// * `learning_rate` - Step size for parameter updates (default: 1e-3)
85/// * `beta1` - Exponential decay rate for first moment estimates (default: 0.9)
86/// * `beta2` - Exponential decay rate for second moment estimates (default: 0.999)
87/// * `eps` - Small constant for numerical stability (default: 1e-8)
88/// * `weight_decay` - L2 regularization coefficient (default: 0.0)
89/// * `amsgrad` - Whether to use AMSGrad variant for improved stability (default: false)
/// Hyperparameter set controlling Adam optimization behavior
///
/// Groups every tunable knob of the Adam algorithm in one place. The
/// `Default` implementation mirrors PyTorch's `torch.optim.Adam` defaults,
/// so an unmodified config behaves well across a wide range of architectures.
#[derive(Debug, Clone)]
pub struct AdamConfig {
    /// Step size applied to each parameter update (default: 1e-3)
    pub learning_rate: f32,
    /// Decay rate for the running mean of gradients (default: 0.9)
    pub beta1: f32,
    /// Decay rate for the running mean of squared gradients (default: 0.999)
    pub beta2: f32,
    /// Stability constant added to the denominator (default: 1e-8)
    pub eps: f32,
    /// L2 regularization strength; 0.0 disables weight decay (default: 0.0)
    pub weight_decay: f32,
    /// Enable the AMSGrad variant of Adam (default: false)
    pub amsgrad: bool,
}
105
106impl Default for AdamConfig {
107    fn default() -> Self {
108        Self {
109            learning_rate: 1e-3,
110            beta1: 0.9,
111            beta2: 0.999,
112            eps: 1e-8,
113            weight_decay: 0.0,
114            amsgrad: false,
115        }
116    }
117}
118
119/// Internal state tracking for a single parameter during Adam optimization
120///
121/// Stores the momentum and velocity buffers needed for Adam optimization,
122/// along with step count for bias correction and optional AMSGrad state.
123/// Memory layout is optimized for cache efficiency and SIMD operations.
124///
125/// # Fields
126///
127/// * `m` - First moment estimate (momentum buffer)
128/// * `v` - Second moment estimate (velocity buffer)
129/// * `v_hat_max` - Maximum of second moment estimates (for AMSGrad variant)
130/// * `step` - Step count for bias correction calculations
131#[derive(Debug)]
132struct ParameterState {
133    /// First moment estimate (momentum)
134    m: Tensor,
135    /// Second moment estimate (velocity)  
136    v: Tensor,
137    /// Maximum of second moment estimates (for AMSGrad)
138    v_hat_max: Option<Tensor>,
139    /// Step count for bias correction
140    step: usize,
141}
142
143impl ParameterState {
144    fn new(param_shape: &[usize]) -> Self {
145        Self {
146            m: Tensor::zeros(param_shape.to_vec()),
147            v: Tensor::zeros(param_shape.to_vec()),
148            v_hat_max: None,
149            step: 0,
150        }
151    }
152}
153
154/// Adam optimizer for neural network parameter optimization
155///
156/// Implements the Adam optimization algorithm with PyTorch-compatible interface.
157/// Provides adaptive learning rates with momentum for efficient training of neural networks.
158/// The optimizer maintains per-parameter state for momentum and velocity estimates,
159/// enabling adaptive learning rates that improve convergence across diverse architectures.
160///
161/// # Usage Pattern
162///
163/// The optimizer uses ID-based parameter linking for maximum flexibility and thread safety:
164/// - Parameters are linked to the optimizer via `add_parameter` or `add_parameters`
165/// - The `step` method takes mutable references to parameters for thread-safe updates
166/// - Parameter states are maintained by tensor ID, allowing for dynamic parameter management
167/// - Supports serialization and deserialization with parameter re-linking
168///
169/// # Dynamic Parameter Management
170///
171/// Parameters can be added, removed, or re-linked at runtime:
172/// - `add_parameter`: Link a single parameter
173/// - `add_parameters`: Link multiple parameters at once
174/// - `unlink_parameter`: Remove parameter state by ID
175/// - `clear_states`: Remove all parameter states
176/// - `is_parameter_linked`: Check if a parameter is linked
177///
178/// # Serialization Support
179///
180/// The optimizer supports full serialization and deserialization with state preservation:
181/// - Parameter states are saved with their shapes and insertion order for validation
182/// - After deserialization, use `relink_parameters` to restore saved states to new tensors
183/// - Parameters must be re-linked in the same chronological order they were originally added
184/// - Shape validation ensures consistency between saved and current parameters
185///
186/// # Features
187///
188/// - **ID-Based Parameter Linking**: Dynamic parameter management via tensor IDs
189/// - **Thread-Safe Step Method**: Takes mutable references for safe concurrent access
190/// - **Per-Parameter State**: Each parameter maintains its own momentum and velocity buffers
191/// - **Bias Correction**: Automatically corrects initialization bias in moment estimates
192/// - **Weight Decay**: Optional L2 regularization with efficient implementation
193/// - **AMSGrad Support**: Optional AMSGrad variant for improved convergence stability
194/// - **SIMD Optimization**: AVX2-optimized updates for maximum performance
195/// - **Full Serialization**: Complete state persistence and restoration
196///
197/// # Thread Safety
198///
199/// This type is thread-safe and can be shared between threads. The step method
200/// takes mutable references to parameters, ensuring exclusive access during updates.
201pub struct Adam {
202    /// Optimizer configuration
203    config: AdamConfig,
204    /// Parameter states indexed by tensor ID
205    states: HashMap<usize, ParameterState>,
206    /// Current global step count
207    step_count: usize,
208    /// Order in which parameters were added (for serialization/re-linking)
209    insertion_order: Vec<usize>,
210}
211
212impl Default for Adam {
213    fn default() -> Self {
214        Self::with_config(AdamConfig::default())
215    }
216}
217
218impl Adam {
219    /// Create a new Adam optimizer with default configuration
220    ///
221    /// Initializes an Adam optimizer with PyTorch-compatible default hyperparameters.
222    /// Parameters must be linked separately using `add_parameter` or `add_parameters`.
223    ///
224    /// # Returns
225    ///
226    /// A new Adam optimizer instance with default hyperparameters
227    #[track_caller]
228    pub fn new() -> Self {
229        Self::default()
230    }
231
232    /// Create a new Adam optimizer with custom configuration
233    ///
234    /// Allows full control over all Adam hyperparameters for specialized training
235    /// scenarios such as fine-tuning, transfer learning, or research applications.
236    /// Parameters must be linked separately using `add_parameter` or `add_parameters`.
237    ///
238    /// # Arguments
239    ///
240    /// * `config` - Adam configuration with custom hyperparameters
241    ///
242    /// # Returns
243    ///
244    /// A new Adam optimizer instance with the specified configuration
245    #[track_caller]
246    pub fn with_config(config: AdamConfig) -> Self {
247        Self {
248            config,
249            states: HashMap::new(),
250            step_count: 0,
251            insertion_order: Vec::new(),
252        }
253    }
254
255    /// Create a new Adam optimizer with custom learning rate
256    ///
257    /// A convenience constructor that allows setting only the learning rate while
258    /// using default values for all other hyperparameters. Parameters must be
259    /// linked separately using `add_parameter` or `add_parameters`.
260    ///
261    /// # Arguments
262    ///
263    /// * `learning_rate` - Learning rate for optimization
264    ///
265    /// # Returns
266    ///
267    /// A new Adam optimizer instance with the specified learning rate and default
268    /// values for all other hyperparameters
269    #[track_caller]
270    pub fn with_learning_rate(learning_rate: f32) -> Self {
271        let config = AdamConfig {
272            learning_rate,
273            ..Default::default()
274        };
275        Self::with_config(config)
276    }
277
278    /// Add a single parameter to the optimizer
279    ///
280    /// Links a parameter to the optimizer by creating a new parameter state
281    /// indexed by the tensor's ID. The parameter must have `requires_grad` set to true.
282    ///
283    /// # Arguments
284    ///
285    /// * `parameter` - Reference to the tensor to link
286    ///
287    /// # Panics
288    ///
289    /// Panics if the parameter does not have `requires_grad` set to true
290    #[track_caller]
291    pub fn add_parameter(&mut self, parameter: &Tensor) {
292        assert!(
293            parameter.requires_grad(),
294            "Parameter must require gradients"
295        );
296
297        let param_id = parameter.id();
298        let param_shape = parameter.shape().dims();
299
300        // Initialize state for this parameter if not already present
301        use std::collections::hash_map::Entry;
302        if let Entry::Vacant(entry) = self.states.entry(param_id) {
303            entry.insert(ParameterState::new(param_shape));
304            self.insertion_order.push(param_id);
305        }
306    }
307
308    /// Add multiple parameters to the optimizer
309    ///
310    /// Links multiple parameters to the optimizer by creating parameter states
311    /// indexed by each tensor's ID. All parameters must have `requires_grad` set to true.
312    ///
313    /// # Arguments
314    ///
315    /// * `parameters` - Slice of references to tensors to link
316    ///
317    /// # Panics
318    ///
319    /// Panics if any parameter does not have `requires_grad` set to true
320    #[track_caller]
321    pub fn add_parameters(&mut self, parameters: &[&Tensor]) {
322        for parameter in parameters {
323            self.add_parameter(parameter);
324        }
325    }
326
327    /// Remove a parameter from the optimizer
328    ///
329    /// Unlinks a parameter by removing its state from the optimizer.
330    /// The parameter ID is used for identification.
331    ///
332    /// # Arguments
333    ///
334    /// * `parameter` - Reference to the tensor to unlink
335    ///
336    /// # Returns
337    ///
338    /// True if the parameter was linked and removed, false if it was not linked
339    #[track_caller]
340    pub fn unlink_parameter(&mut self, parameter: &Tensor) -> bool {
341        let param_id = parameter.id();
342        let was_linked = self.states.remove(&param_id).is_some();
343        if was_linked {
344            self.insertion_order.retain(|&id| id != param_id);
345        }
346        was_linked
347    }
348
349    /// Remove all parameter states from the optimizer
350    ///
351    /// Clears all parameter states, effectively unlinking all parameters.
352    /// This is useful for resetting the optimizer or preparing for parameter re-linking.
353    #[track_caller]
354    pub fn clear_states(&mut self) {
355        self.states.clear();
356        self.insertion_order.clear();
357    }
358
359    /// Check if a parameter is linked to the optimizer
360    ///
361    /// Returns true if the parameter has an associated state in the optimizer.
362    ///
363    /// # Arguments
364    ///
365    /// * `parameter` - Reference to the tensor to check
366    ///
367    /// # Returns
368    ///
369    /// True if the parameter is linked, false otherwise
370    #[track_caller]
371    pub fn is_parameter_linked(&self, parameter: &Tensor) -> bool {
372        let param_id = parameter.id();
373        self.states.contains_key(&param_id)
374    }
375
376    /// Get the number of linked parameters
377    ///
378    /// Returns the count of parameters currently linked to the optimizer.
379    ///
380    /// # Returns
381    ///
382    /// Number of linked parameters
383    #[track_caller]
384    pub fn parameter_count(&self) -> usize {
385        self.states.len()
386    }
387
388    /// Re-link parameters to saved optimizer states in chronological order
389    ///
390    /// After deserializing an optimizer, use this method to restore saved parameter states
391    /// to new tensors. Parameters must be provided in the same chronological order they
392    /// were originally added to the optimizer. Shape validation ensures parameter compatibility.
393    ///
394    /// # Arguments
395    ///
396    /// * `parameters` - Slice of parameter references in chronological order
397    ///
398    /// # Returns
399    ///
400    /// Result indicating success or failure with detailed error message
401    ///
402    /// # Panics
403    ///
404    /// Panics if any parameter does not have `requires_grad` set to true
405    #[track_caller]
406    pub fn relink_parameters(&mut self, parameters: &[&Tensor]) -> Result<(), String> {
407        // Validate all parameters have requires_grad first
408        for (i, param) in parameters.iter().enumerate() {
409            if !param.requires_grad() {
410                return Err(format!("Parameter at index {} must require gradients", i));
411            }
412        }
413
414        // Check parameter count matches saved states
415        if parameters.len() != self.insertion_order.len() {
416            return Err(format!(
417                "Parameter count mismatch: expected {} parameters, got {}",
418                self.insertion_order.len(),
419                parameters.len()
420            ));
421        }
422
423        // Create new states map with parameter IDs mapped to saved states in chronological order
424        let mut new_states = HashMap::new();
425        let mut new_insertion_order = Vec::new();
426
427        for (i, param) in parameters.iter().enumerate() {
428            let new_param_id = param.id();
429            let old_param_id = self.insertion_order[i];
430
431            // Get the saved state for this position
432            let saved_state = self
433                .states
434                .get(&old_param_id)
435                .ok_or_else(|| format!("No saved state found for parameter at position {}", i))?;
436
437            // Validate shape matches
438            let param_shape = &param.shape().dims();
439            let saved_shape = &saved_state.m.shape().dims();
440            if param_shape != saved_shape {
441                return Err(format!(
442                    "Shape mismatch for parameter at position {}: expected {:?}, got {:?}",
443                    i, saved_shape, param_shape
444                ));
445            }
446
447            // Create new state for this parameter
448            let new_state = ParameterState {
449                m: saved_state.m.clone(),
450                v: saved_state.v.clone(),
451                v_hat_max: saved_state.v_hat_max.clone(),
452                step: saved_state.step,
453            };
454
455            new_states.insert(new_param_id, new_state);
456            new_insertion_order.push(new_param_id);
457        }
458
459        // Replace the states and insertion order
460        self.states = new_states;
461        self.insertion_order = new_insertion_order;
462
463        Ok(())
464    }
465
466    /// Get the current optimizer configuration
467    ///
468    /// Returns a reference to the current configuration, allowing inspection
469    /// of all hyperparameters without modification.
470    ///
471    /// # Returns
472    ///
473    /// Reference to the current Adam configuration
474    #[track_caller]
475    pub fn config(&self) -> &AdamConfig {
476        &self.config
477    }
478
479    /// Update a single parameter using Adam algorithm
480    ///
481    /// Implements the core Adam update rule with bias correction and optional AMSGrad.
482    /// Uses SIMD optimization when available for improved performance.
483    /// The parameter must be linked to the optimizer before calling this method.
484    ///
485    /// # Arguments
486    ///
487    /// * `param` - Mutable reference to the parameter tensor to update
488    ///
489    /// # Returns
490    ///
491    /// Result indicating success or failure of the parameter update
492    ///
493    /// # Panics
494    ///
495    /// Panics if the parameter is not linked to the optimizer
496    fn update_parameter(&mut self, param: &mut Tensor) -> Result<(), String> {
497        let param_id = param.id();
498
499        // Ensure parameter is linked
500        assert!(
501            self.states.contains_key(&param_id),
502            "Parameter must be linked to optimizer before stepping. Use add_parameter() first."
503        );
504
505        // Get parameter gradient
506        let grad = param
507            .grad_owned()
508            .ok_or_else(|| format!("Parameter {} has no gradient", param_id))?;
509
510        // Get parameter state
511        let state = self
512            .states
513            .get_mut(&param_id)
514            .expect("Parameter state should exist after link check");
515
516        // Increment per-parameter step count and use it for bias correction
517        state.step += 1;
518        let step = state.step as f32;
519
520        // Apply weight decay if enabled
521        let effective_grad = if self.config.weight_decay > 0.0 {
522            // L2 regularization: grad + weight_decay * param
523            let mut grad_with_decay = grad.clone();
524            Self::add_weight_decay(&mut grad_with_decay, param, self.config.weight_decay);
525            grad_with_decay
526        } else {
527            grad
528        };
529
530        // Update biased first moment estimate: m = beta1 * m + (1 - beta1) * grad
531        Self::update_momentum(&mut state.m, &effective_grad, self.config.beta1);
532
533        // Update biased second moment estimate: v = beta2 * v + (1 - beta2) * grad^2
534        Self::update_velocity(&mut state.v, &effective_grad, self.config.beta2);
535
536        // Compute bias-corrected first moment estimate
537        let bias_correction1 = 1.0 - (self.config.beta1).powf(step);
538        let m_hat = Self::bias_correct(&state.m, bias_correction1);
539
540        // Compute bias-corrected second moment estimate
541        let bias_correction2 = 1.0 - (self.config.beta2).powf(step);
542        let mut v_hat = Self::bias_correct(&state.v, bias_correction2);
543
544        // AMSGrad: use maximum of v_hat over time
545        if self.config.amsgrad {
546            if state.v_hat_max.is_none() {
547                state.v_hat_max = Some(v_hat.clone());
548            }
549            let v_hat_max = state.v_hat_max.as_mut().unwrap();
550            Self::element_wise_max(v_hat_max, &v_hat);
551            v_hat = v_hat_max.clone();
552        }
553
554        // Compute parameter update: param = param - lr * m_hat / (sqrt(v_hat) + eps)
555        Self::apply_adam_update(
556            param,
557            &m_hat,
558            &v_hat,
559            self.config.learning_rate,
560            self.config.eps,
561        );
562
563        Ok(())
564    }
565
    /// Apply weight decay (L2 regularization) to gradient
    ///
    /// Adds `weight_decay * param` to the gradient in-place for memory efficiency.
    /// This implements L2 regularization by modifying the gradient before the
    /// Adam update step, equivalent to adding a regularization term to the loss.
    ///
    /// Dispatches to an AVX2 kernel when the CPU supports it and both buffers
    /// are SIMD-aligned; otherwise falls back to the scalar loop below.
    ///
    /// # Arguments
    ///
    /// * `grad` - Gradient tensor to modify in-place
    /// * `param` - Parameter tensor for weight decay calculation
    /// * `weight_decay` - Weight decay coefficient
    #[inline]
    fn add_weight_decay(grad: &mut Tensor, param: &Tensor, weight_decay: f32) {
        // Equal sizes guarantee every pointer offset below stays in bounds.
        assert_eq!(
            grad.size(),
            param.size(),
            "Gradient and parameter size mismatch"
        );

        // SAFETY: sizes were asserted equal above and both pointers come from
        // live tensors, so all indexed reads/writes stay inside their buffers.
        unsafe {
            let grad_ptr = grad.as_mut_ptr();
            let param_ptr = param.as_ptr();
            let size = grad.size();

            #[cfg(target_arch = "x86_64")]
            {
                if is_x86_feature_detected!("avx2")
                    && grad.is_simd_aligned()
                    && param.is_simd_aligned()
                {
                    Self::add_weight_decay_simd_avx2(grad_ptr, param_ptr, weight_decay, size);
                    return;
                }
            }

            // Scalar fallback
            for i in 0..size {
                *grad_ptr.add(i) += weight_decay * *param_ptr.add(i);
            }
        }
    }
607
608    #[cfg(target_arch = "x86_64")]
609    #[inline]
610    unsafe fn add_weight_decay_simd_avx2(
611        grad_ptr: *mut f32,
612        param_ptr: *const f32,
613        weight_decay: f32,
614        size: usize,
615    ) {
616        let decay_vec = _mm256_set1_ps(weight_decay);
617        let simd_count = size / 8;
618
619        for i in 0..simd_count {
620            let offset = i * 8;
621            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));
622            let param_vec = _mm256_loadu_ps(param_ptr.add(offset));
623            let decay_term = _mm256_mul_ps(decay_vec, param_vec);
624            let result = _mm256_add_ps(grad_vec, decay_term);
625            _mm256_storeu_ps(grad_ptr.add(offset), result);
626        }
627
628        // Handle remaining elements
629        for i in (simd_count * 8)..size {
630            *grad_ptr.add(i) += weight_decay * *param_ptr.add(i);
631        }
632    }
633
    /// Update momentum (first moment estimate)
    ///
    /// Implements the momentum update rule: `m = beta1 * m + (1 - beta1) * grad`
    /// This computes the exponentially decaying average of gradients, providing
    /// momentum-like behavior that helps accelerate convergence in relevant directions.
    ///
    /// Dispatches to an AVX2 kernel when the CPU supports it and both buffers
    /// are SIMD-aligned; otherwise falls back to the scalar loop below.
    ///
    /// # Arguments
    ///
    /// * `momentum` - Momentum buffer to update in-place
    /// * `grad` - Current gradient tensor
    /// * `beta1` - Exponential decay rate for momentum
    #[inline]
    fn update_momentum(momentum: &mut Tensor, grad: &Tensor, beta1: f32) {
        // Equal sizes guarantee every pointer offset below stays in bounds.
        assert_eq!(
            momentum.size(),
            grad.size(),
            "Momentum and gradient size mismatch"
        );

        let beta1_complement = 1.0 - beta1;

        // SAFETY: sizes were asserted equal above and both pointers come from
        // live tensors, so all indexed reads/writes stay inside their buffers.
        unsafe {
            let m_ptr = momentum.as_mut_ptr();
            let grad_ptr = grad.as_ptr();
            let size = momentum.size();

            #[cfg(target_arch = "x86_64")]
            {
                if is_x86_feature_detected!("avx2")
                    && momentum.is_simd_aligned()
                    && grad.is_simd_aligned()
                {
                    Self::update_momentum_simd_avx2(m_ptr, grad_ptr, beta1, beta1_complement, size);
                    return;
                }
            }

            // Scalar fallback
            for i in 0..size {
                *m_ptr.add(i) = beta1 * *m_ptr.add(i) + beta1_complement * *grad_ptr.add(i);
            }
        }
    }
677
678    #[cfg(target_arch = "x86_64")]
679    #[inline]
680    unsafe fn update_momentum_simd_avx2(
681        m_ptr: *mut f32,
682        grad_ptr: *const f32,
683        beta1: f32,
684        beta1_complement: f32,
685        size: usize,
686    ) {
687        let beta1_vec = _mm256_set1_ps(beta1);
688        let beta1_comp_vec = _mm256_set1_ps(beta1_complement);
689        let simd_count = size / 8;
690
691        for i in 0..simd_count {
692            let offset = i * 8;
693            let m_vec = _mm256_loadu_ps(m_ptr.add(offset));
694            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));
695
696            let momentum_term = _mm256_mul_ps(beta1_vec, m_vec);
697            let gradient_term = _mm256_mul_ps(beta1_comp_vec, grad_vec);
698            let result = _mm256_add_ps(momentum_term, gradient_term);
699
700            _mm256_storeu_ps(m_ptr.add(offset), result);
701        }
702
703        // Handle remaining elements
704        for i in (simd_count * 8)..size {
705            *m_ptr.add(i) = beta1 * *m_ptr.add(i) + beta1_complement * *grad_ptr.add(i);
706        }
707    }
708
    /// Update velocity (second moment estimate)
    ///
    /// Implements the velocity update rule: `v = beta2 * v + (1 - beta2) * grad^2`
    /// This computes the exponentially decaying average of squared gradients,
    /// providing adaptive learning rates that scale inversely with gradient magnitude.
    ///
    /// Dispatches to an AVX2 kernel when the CPU supports it and both buffers
    /// are SIMD-aligned; otherwise falls back to the scalar loop below.
    ///
    /// # Arguments
    ///
    /// * `velocity` - Velocity buffer to update in-place
    /// * `grad` - Current gradient tensor
    /// * `beta2` - Exponential decay rate for velocity
    #[inline]
    fn update_velocity(velocity: &mut Tensor, grad: &Tensor, beta2: f32) {
        // Equal sizes guarantee every pointer offset below stays in bounds.
        assert_eq!(
            velocity.size(),
            grad.size(),
            "Velocity and gradient size mismatch"
        );

        let beta2_complement = 1.0 - beta2;

        // SAFETY: sizes were asserted equal above and both pointers come from
        // live tensors, so all indexed reads/writes stay inside their buffers.
        unsafe {
            let v_ptr = velocity.as_mut_ptr();
            let grad_ptr = grad.as_ptr();
            let size = velocity.size();

            #[cfg(target_arch = "x86_64")]
            {
                if is_x86_feature_detected!("avx2")
                    && velocity.is_simd_aligned()
                    && grad.is_simd_aligned()
                {
                    Self::update_velocity_simd_avx2(v_ptr, grad_ptr, beta2, beta2_complement, size);
                    return;
                }
            }

            // Scalar fallback
            for i in 0..size {
                let grad_val = *grad_ptr.add(i);
                *v_ptr.add(i) = beta2 * *v_ptr.add(i) + beta2_complement * grad_val * grad_val;
            }
        }
    }
753
754    #[cfg(target_arch = "x86_64")]
755    #[inline]
756    unsafe fn update_velocity_simd_avx2(
757        v_ptr: *mut f32,
758        grad_ptr: *const f32,
759        beta2: f32,
760        beta2_complement: f32,
761        size: usize,
762    ) {
763        let beta2_vec = _mm256_set1_ps(beta2);
764        let beta2_comp_vec = _mm256_set1_ps(beta2_complement);
765        let simd_count = size / 8;
766
767        for i in 0..simd_count {
768            let offset = i * 8;
769            let v_vec = _mm256_loadu_ps(v_ptr.add(offset));
770            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));
771
772            let velocity_term = _mm256_mul_ps(beta2_vec, v_vec);
773            let grad_squared = _mm256_mul_ps(grad_vec, grad_vec);
774            let gradient_term = _mm256_mul_ps(beta2_comp_vec, grad_squared);
775            let result = _mm256_add_ps(velocity_term, gradient_term);
776
777            _mm256_storeu_ps(v_ptr.add(offset), result);
778        }
779
780        // Handle remaining elements
781        for i in (simd_count * 8)..size {
782            let grad_val = *grad_ptr.add(i);
783            *v_ptr.add(i) = beta2 * *v_ptr.add(i) + beta2_complement * grad_val * grad_val;
784        }
785    }
786
    /// Apply bias correction to moment estimates
    ///
    /// Returns `tensor / (1 - beta^step)` for unbiasing the moment estimates.
    /// This correction is necessary because the moment estimates are initialized
    /// to zero, creating a bias towards zero that becomes negligible as training progresses.
    ///
    /// Implemented as a multiply by the reciprocal of `bias_correction`, with
    /// an AVX2 fast path when available and the buffer is SIMD-aligned.
    ///
    /// # Arguments
    ///
    /// * `tensor` - Moment estimate tensor to correct
    /// * `bias_correction` - Bias correction factor (1 - beta^step)
    ///
    /// # Returns
    ///
    /// Bias-corrected tensor with the same shape as input
    #[inline]
    fn bias_correct(tensor: &Tensor, bias_correction: f32) -> Tensor {
        // Work on a copy: the caller's moment buffer must stay biased for the
        // next step's exponential-average update.
        let mut result = tensor.clone();
        let correction_factor = 1.0 / bias_correction;

        // SAFETY: `ptr` points at `result`'s own buffer of `size` elements, so
        // every indexed write below is in bounds.
        unsafe {
            let ptr = result.as_mut_ptr();
            let size = result.size();

            #[cfg(target_arch = "x86_64")]
            {
                if is_x86_feature_detected!("avx2") && result.is_simd_aligned() {
                    Self::bias_correct_simd_avx2(ptr, correction_factor, size);
                    return result;
                }
            }

            // Scalar fallback
            for i in 0..size {
                *ptr.add(i) *= correction_factor;
            }
        }

        result
    }
826
827    #[cfg(target_arch = "x86_64")]
828    #[inline]
829    unsafe fn bias_correct_simd_avx2(ptr: *mut f32, correction_factor: f32, size: usize) {
830        let factor_vec = _mm256_set1_ps(correction_factor);
831        let simd_count = size / 8;
832
833        for i in 0..simd_count {
834            let offset = i * 8;
835            let data_vec = _mm256_loadu_ps(ptr.add(offset));
836            let result = _mm256_mul_ps(data_vec, factor_vec);
837            _mm256_storeu_ps(ptr.add(offset), result);
838        }
839
840        // Handle remaining elements
841        for i in (simd_count * 8)..size {
842            *ptr.add(i) *= correction_factor;
843        }
844    }
845
846    /// Element-wise maximum for AMSGrad
847    ///
848    /// Updates first tensor in-place with `max(first, second)` for AMSGrad variant.
849    /// This maintains a running maximum of the second moment estimates, preventing
850    /// the learning rate from increasing over time and improving convergence stability.
851    ///
852    /// # Arguments
853    ///
854    /// * `first` - Tensor to update in-place with maximum values
855    /// * `second` - Tensor to compare against for maximum calculation
856    #[inline]
857    fn element_wise_max(first: &mut Tensor, second: &Tensor) {
858        assert_eq!(
859            first.size(),
860            second.size(),
861            "Tensor size mismatch for element-wise max"
862        );
863
864        unsafe {
865            let first_ptr = first.as_mut_ptr();
866            let second_ptr = second.as_ptr();
867            let size = first.size();
868
869            #[cfg(target_arch = "x86_64")]
870            {
871                if is_x86_feature_detected!("avx2")
872                    && first.is_simd_aligned()
873                    && second.is_simd_aligned()
874                {
875                    Self::element_wise_max_simd_avx2(first_ptr, second_ptr, size);
876                    return;
877                }
878            }
879
880            // Scalar fallback
881            for i in 0..size {
882                let a = *first_ptr.add(i);
883                let b = *second_ptr.add(i);
884                *first_ptr.add(i) = if a > b { a } else { b };
885            }
886        }
887    }
888
889    #[cfg(target_arch = "x86_64")]
890    #[inline]
891    unsafe fn element_wise_max_simd_avx2(first_ptr: *mut f32, second_ptr: *const f32, size: usize) {
892        let simd_count = size / 8;
893
894        for i in 0..simd_count {
895            let offset = i * 8;
896            let first_vec = _mm256_loadu_ps(first_ptr.add(offset));
897            let second_vec = _mm256_loadu_ps(second_ptr.add(offset));
898            let result = _mm256_max_ps(first_vec, second_vec);
899            _mm256_storeu_ps(first_ptr.add(offset), result);
900        }
901
902        // Handle remaining elements
903        for i in (simd_count * 8)..size {
904            let a = *first_ptr.add(i);
905            let b = *second_ptr.add(i);
906            *first_ptr.add(i) = if a > b { a } else { b };
907        }
908    }
909
910    /// Apply the final Adam parameter update
911    ///
912    /// Implements the core Adam update rule: `param = param - lr * m_hat / (sqrt(v_hat) + eps)`
913    /// This applies the bias-corrected momentum and velocity estimates to update
914    /// the parameter values, with adaptive learning rates that scale inversely
915    /// with the square root of the velocity estimates.
916    ///
917    /// # Arguments
918    ///
919    /// * `param` - Parameter tensor to update in-place
920    /// * `m_hat` - Bias-corrected first moment estimate
921    /// * `v_hat` - Bias-corrected second moment estimate
922    /// * `learning_rate` - Learning rate for the update
923    /// * `eps` - Small constant for numerical stability
924    #[inline]
925    fn apply_adam_update(
926        param: &mut Tensor,
927        m_hat: &Tensor,
928        v_hat: &Tensor,
929        learning_rate: f32,
930        eps: f32,
931    ) {
932        assert_eq!(
933            param.size(),
934            m_hat.size(),
935            "Parameter and momentum size mismatch"
936        );
937        assert_eq!(
938            param.size(),
939            v_hat.size(),
940            "Parameter and velocity size mismatch"
941        );
942
943        unsafe {
944            let param_ptr = param.as_mut_ptr();
945            let m_ptr = m_hat.as_ptr();
946            let v_ptr = v_hat.as_ptr();
947            let size = param.size();
948
949            #[cfg(target_arch = "x86_64")]
950            {
951                if is_x86_feature_detected!("avx2")
952                    && param.is_simd_aligned()
953                    && m_hat.is_simd_aligned()
954                    && v_hat.is_simd_aligned()
955                {
956                    Self::apply_adam_update_simd_avx2(
957                        param_ptr,
958                        m_ptr,
959                        v_ptr,
960                        learning_rate,
961                        eps,
962                        size,
963                    );
964                    return;
965                }
966            }
967
968            // Scalar fallback
969            for i in 0..size {
970                let m_val = *m_ptr.add(i);
971                let v_val = *v_ptr.add(i);
972                let denominator = v_val.sqrt() + eps;
973                *param_ptr.add(i) -= learning_rate * m_val / denominator;
974            }
975        }
976    }
977
978    #[cfg(target_arch = "x86_64")]
979    #[inline]
980    unsafe fn apply_adam_update_simd_avx2(
981        param_ptr: *mut f32,
982        m_ptr: *const f32,
983        v_ptr: *const f32,
984        learning_rate: f32,
985        eps: f32,
986        size: usize,
987    ) {
988        let lr_vec = _mm256_set1_ps(learning_rate);
989        let eps_vec = _mm256_set1_ps(eps);
990        let simd_count = size / 8;
991
992        for i in 0..simd_count {
993            let offset = i * 8;
994            let param_vec = _mm256_loadu_ps(param_ptr.add(offset));
995            let m_vec = _mm256_loadu_ps(m_ptr.add(offset));
996            let v_vec = _mm256_loadu_ps(v_ptr.add(offset));
997
998            // sqrt(v_hat) + eps
999            let sqrt_v = _mm256_sqrt_ps(v_vec);
1000            let denominator = _mm256_add_ps(sqrt_v, eps_vec);
1001
1002            // lr * m_hat / denominator
1003            let lr_m = _mm256_mul_ps(lr_vec, m_vec);
1004            let update = _mm256_div_ps(lr_m, denominator);
1005
1006            // param - update
1007            let result = _mm256_sub_ps(param_vec, update);
1008            _mm256_storeu_ps(param_ptr.add(offset), result);
1009        }
1010
1011        // Handle remaining elements
1012        for i in (simd_count * 8)..size {
1013            let m_val = *m_ptr.add(i);
1014            let v_val = *v_ptr.add(i);
1015            let denominator = v_val.sqrt() + eps;
1016            *param_ptr.add(i) -= learning_rate * m_val / denominator;
1017        }
1018    }
1019}
1020
1021impl Optimizer for Adam {
1022    /// Perform a single optimization step
1023    ///
1024    /// Updates all provided parameters based on their accumulated gradients using the Adam algorithm.
1025    /// Each parameter is updated according to the Adam update rule with bias correction
1026    /// and optional AMSGrad variant if enabled. All parameters must be linked to the optimizer
1027    /// before calling this method.
1028    ///
1029    /// # Arguments
1030    ///
1031    /// * `parameters` - Mutable slice of parameter references to update
1032    ///
1033    /// # Thread Safety
1034    ///
1035    /// This method is thread-safe as it takes mutable references to parameters,
1036    /// ensuring exclusive access during updates.
1037    ///
1038    /// # Performance
1039    ///
1040    /// - Uses SIMD optimization (AVX2) when available for 8x vectorization
1041    /// - Processes parameters in sequence for optimal cache usage
1042    /// - Maintains per-parameter state for momentum and velocity estimates
1043    ///
1044    /// # Panics
1045    ///
1046    /// Panics if any parameter is not linked to the optimizer
1047    fn step(&mut self, parameters: &mut [&mut Tensor]) {
1048        self.step_count += 1;
1049
1050        for param in parameters {
1051            if let Err(e) = self.update_parameter(param) {
1052                eprintln!("Warning: Failed to update parameter: {}", e);
1053            }
1054        }
1055    }
1056
1057    /// Zero out all parameter gradients
1058    ///
1059    /// Clears accumulated gradients for all provided parameters. This should be called
1060    /// before each backward pass to prevent gradient accumulation across multiple
1061    /// forward/backward passes. Also clears the global autograd gradient map.
1062    ///
1063    /// # Arguments
1064    ///
1065    /// * `parameters` - Mutable slice of parameter references to clear gradients for
1066    ///
1067    /// # Performance
1068    ///
1069    /// - Efficiently clears gradients using optimized tensor operations
1070    /// - Clears both per-tensor gradients and global autograd state
1071    /// - Thread-safe as it takes mutable references to parameters
1072    fn zero_grad(&mut self, parameters: &mut [&mut Tensor]) {
1073        for param in parameters {
1074            param.zero_grad();
1075        }
1076    }
1077
1078    /// Get the current learning rate
1079    ///
1080    /// Returns the current learning rate used for parameter updates.
1081    ///
1082    /// # Returns
1083    ///
1084    /// Current learning rate as f32
1085    fn learning_rate(&self) -> f32 {
1086        self.config.learning_rate
1087    }
1088
1089    /// Set the learning rate for all parameters
1090    ///
1091    /// Updates the learning rate for all parameters in the optimizer.
1092    /// This allows dynamic learning rate scheduling during training.
1093    ///
1094    /// # Arguments
1095    ///
1096    /// * `lr` - New learning rate value
1097    fn set_learning_rate(&mut self, lr: f32) {
1098        self.config.learning_rate = lr;
1099    }
1100}
1101
1102// Adam is automatically Send + Sync since it no longer contains raw pointers
1103// and all fields are Send + Sync (HashMap, usize, AdamConfig)
1104
#[cfg(test)]
mod tests {
    //! Unit tests covering Adam construction, configuration, parameter
    //! linking/unlinking, gradient handling, and multi-step training loops.

    use super::*;
    use crate::tensor::core::Tensor;

    /// Test Adam optimizer creation with default configuration
    ///
    /// Verifies that the optimizer is created correctly with default hyperparameters
    /// and that parameter states are properly initialized.
    #[test]
    fn test_adam_creation() {
        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let bias = Tensor::zeros(vec![3, 1]).with_requires_grad();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        // Defaults follow PyTorch conventions (lr=1e-3, beta1=0.9, beta2=0.999).
        assert_eq!(optimizer.learning_rate(), 1e-3);
        assert_eq!(optimizer.config.beta1, 0.9);
        assert_eq!(optimizer.config.beta2, 0.999);
        assert_eq!(optimizer.parameter_count(), 2);
    }

    /// Test Adam optimizer creation with custom configuration
    ///
    /// Verifies that custom hyperparameters are properly set and that
    /// the optimizer configuration matches the provided values.
    #[test]
    fn test_adam_with_config() {
        let weight = Tensor::ones(vec![5, 5]).with_requires_grad();

        let config = AdamConfig {
            learning_rate: 1e-4,
            beta1: 0.95,
            beta2: 0.9999,
            weight_decay: 1e-5,
            amsgrad: true,
            ..Default::default()
        };

        let mut optimizer = Adam::with_config(config);
        optimizer.add_parameter(&weight);

        assert_eq!(optimizer.learning_rate(), 1e-4);
        assert_eq!(optimizer.config.beta1, 0.95);
        assert_eq!(optimizer.config.beta2, 0.9999);
        assert_eq!(optimizer.config.weight_decay, 1e-5);
        assert!(optimizer.config.amsgrad);
    }

    /// Test Adam optimizer creation with custom learning rate
    ///
    /// Verifies that the convenience constructor properly sets the learning rate
    /// while maintaining default values for other hyperparameters.
    #[test]
    fn test_adam_with_learning_rate() {
        let weight = Tensor::ones(vec![3, 3]).with_requires_grad();
        let mut optimizer = Adam::with_learning_rate(5e-4);
        optimizer.add_parameter(&weight);

        assert_eq!(optimizer.learning_rate(), 5e-4);
        assert_eq!(optimizer.config.beta1, 0.9); // Should use defaults for other params
    }

    /// Test Adam step without gradients
    ///
    /// Verifies that the optimizer handles the case where parameters have no
    /// gradients gracefully, leaving parameters unchanged.
    #[test]
    fn test_adam_step_without_gradients() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let original_data = weight.clone();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        optimizer.step(&mut [&mut weight]); // Should not update without gradients

        // Parameters should remain unchanged without gradients
        for i in 0..weight.size() {
            unsafe {
                assert_eq!(*weight.as_ptr().add(i), *original_data.as_ptr().add(i));
            }
        }
    }

    /// Test learning rate update functionality
    ///
    /// Verifies that the learning rate can be dynamically updated during training
    /// and that the optimizer uses the new learning rate for subsequent steps.
    #[test]
    fn test_learning_rate_update() {
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        assert_eq!(optimizer.learning_rate(), 1e-3);

        optimizer.set_learning_rate(1e-2);
        assert_eq!(optimizer.learning_rate(), 1e-2);
    }

    /// Test gradient zeroing functionality
    ///
    /// Verifies that zero_grad properly clears accumulated gradients for all
    /// parameters and the global autograd state.
    #[test]
    fn test_zero_grad() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();

        // Check that zero_grad clears gradients
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
        optimizer.zero_grad(&mut [&mut weight]);

        // After zero_grad, there should be no accumulated gradients
        assert!(weight.grad_owned().is_none());
    }

    /// Test requires_grad assertion
    ///
    /// Verifies that the optimizer correctly panics when parameters do not
    /// have requires_grad set to true, ensuring proper gradient tracking.
    #[test]
    #[should_panic(expected = "Parameter must require gradients")]
    fn test_adam_requires_grad_assertion() {
        let weight = Tensor::ones(vec![2, 2]); // No requires_grad
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
    }

    /// Test AdamConfig default values
    ///
    /// Verifies that the default configuration matches PyTorch conventions
    /// and provides optimal settings for most training scenarios.
    #[test]
    fn test_adam_config_default() {
        let config = AdamConfig::default();

        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.beta1, 0.9);
        assert_eq!(config.beta2, 0.999);
        assert_eq!(config.eps, 1e-8);
        assert_eq!(config.weight_decay, 0.0);
        assert!(!config.amsgrad);
    }

    /// Test ParameterState creation and initialization
    ///
    /// Verifies that parameter states are properly initialized with zero tensors
    /// and that the step count starts at zero.
    #[test]
    fn test_parameter_state_creation() {
        let state = ParameterState::new(&[3, 4]);

        // Momentum (m) and velocity (v) buffers match the parameter shape.
        assert_eq!(state.m.shape().dims(), vec![3, 4]);
        assert_eq!(state.v.shape().dims(), vec![3, 4]);
        // AMSGrad running maximum is only allocated when the variant is enabled.
        assert!(state.v_hat_max.is_none());
        assert_eq!(state.step, 0);

        // Verify tensors are zero-initialized
        for i in 0..state.m.size() {
            unsafe {
                assert_eq!(*state.m.as_ptr().add(i), 0.0);
                assert_eq!(*state.v.as_ptr().add(i), 0.0);
            }
        }
    }

    /// Test parameter linking functionality
    ///
    /// Verifies that parameters can be linked and unlinked from the optimizer.
    #[test]
    fn test_parameter_linking() {
        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let bias = Tensor::zeros(vec![3]).with_requires_grad();

        let mut optimizer = Adam::new();

        // Initially no parameters linked
        assert_eq!(optimizer.parameter_count(), 0);
        assert!(!optimizer.is_parameter_linked(&weight));
        assert!(!optimizer.is_parameter_linked(&bias));

        // Link weight
        optimizer.add_parameter(&weight);
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight));
        assert!(!optimizer.is_parameter_linked(&bias));

        // Link bias
        optimizer.add_parameter(&bias);
        assert_eq!(optimizer.parameter_count(), 2);
        assert!(optimizer.is_parameter_linked(&weight));
        assert!(optimizer.is_parameter_linked(&bias));

        // Unlink weight
        let was_linked = optimizer.unlink_parameter(&weight);
        assert!(was_linked);
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(!optimizer.is_parameter_linked(&weight));
        assert!(optimizer.is_parameter_linked(&bias));

        // Clear all states
        optimizer.clear_states();
        assert_eq!(optimizer.parameter_count(), 0);
        assert!(!optimizer.is_parameter_linked(&weight));
        assert!(!optimizer.is_parameter_linked(&bias));
    }

    /// Test parameter linking with multiple parameters at once
    ///
    /// Verifies that multiple parameters can be linked simultaneously.
    #[test]
    fn test_add_multiple_parameters() {
        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let bias = Tensor::zeros(vec![3]).with_requires_grad();
        let weight2 = Tensor::ones(vec![3, 2]).with_requires_grad();

        let mut optimizer = Adam::new();

        // Link multiple parameters at once
        optimizer.add_parameters(&[&weight, &bias, &weight2]);

        assert_eq!(optimizer.parameter_count(), 3);
        assert!(optimizer.is_parameter_linked(&weight));
        assert!(optimizer.is_parameter_linked(&bias));
        assert!(optimizer.is_parameter_linked(&weight2));
    }

    /// Test stepping with unlinked parameter
    ///
    /// Verifies that the optimizer panics when trying to step with an unlinked parameter.
    #[test]
    #[should_panic(expected = "Parameter must be linked to optimizer before stepping")]
    fn test_step_with_unlinked_parameter() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();

        // Don't link the parameter
        optimizer.step(&mut [&mut weight]); // Should panic
    }

    /// Test optimizer with actual gradients
    ///
    /// Verifies that the optimizer properly updates parameters when gradients are present.
    #[test]
    fn test_optimizer_with_gradients() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let original_data = weight.clone();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        // Generate some gradients: loss = sum(2 * w), so d(loss)/dw = 2
        // for every element, giving each entry a nonzero gradient.
        let output = weight.mul_scalar(2.0);
        let mut loss = output.sum();
        loss.backward(None);

        // Step should update parameters
        optimizer.step(&mut [&mut weight]);

        // Parameters should have changed
        let mut changed = false;
        for i in 0..weight.size() {
            unsafe {
                if (*weight.as_ptr().add(i) - *original_data.as_ptr().add(i)).abs() > 1e-6 {
                    changed = true;
                    break;
                }
            }
        }
        assert!(
            changed,
            "Parameters should have been updated by optimizer step"
        );
    }

    /// Test optimizer with multiple parameters and gradients
    ///
    /// Verifies that the optimizer works correctly with multiple parameters.
    #[test]
    fn test_optimizer_multiple_parameters() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut bias = Tensor::zeros(vec![2, 2]).with_requires_grad(); // Same shape as weight

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        // Generate gradients for both parameters
        let output = weight.mul_scalar(2.0).add_tensor(&bias);
        let mut loss = output.sum();
        loss.backward(None);

        // Step should update both parameters
        optimizer.step(&mut [&mut weight, &mut bias]);

        // Both parameters should have gradients
        assert!(weight.grad_owned().is_some());
        assert!(bias.grad_owned().is_some());
    }

    /// Test optimizer with custom configuration and multiple steps
    ///
    /// Verifies that the optimizer works correctly with custom configuration over multiple steps.
    #[test]
    fn test_optimizer_custom_config_multiple_steps() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();

        let config = AdamConfig {
            learning_rate: 1e-4,
            beta1: 0.95,
            beta2: 0.9999,
            weight_decay: 1e-5,
            amsgrad: true,
            ..Default::default()
        };

        let mut optimizer = Adam::with_config(config);
        optimizer.add_parameter(&weight);

        // Multiple training steps exercise bias correction and the AMSGrad
        // running-maximum path across increasing step counts.
        for _ in 0..5 {
            // Generate gradients
            let output = weight.mul_scalar(2.0);
            let mut loss = output.sum();
            loss.backward(None);

            // Step
            optimizer.step(&mut [&mut weight]);
            optimizer.zero_grad(&mut [&mut weight]);
        }

        // Should complete without errors
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight));
    }

    /// Test optimizer with different tensor shapes
    ///
    /// Verifies that the optimizer works correctly with various tensor shapes.
    #[test]
    fn test_optimizer_different_shapes() {
        let shapes = vec![
            vec![1],       // Scalar
            vec![3],       // 1D
            vec![2, 2],    // 2D square
            vec![2, 3],    // 2D rectangular
            vec![1, 1, 3], // 3D
            vec![2, 2, 2], // 3D cube
        ];

        for shape in shapes {
            let mut tensor = Tensor::ones(shape.clone()).with_requires_grad();
            let mut optimizer = Adam::new();
            optimizer.add_parameter(&tensor);

            // Generate gradients
            let output = tensor.mul_scalar(2.0);
            let mut loss = output.sum();
            loss.backward(None);

            // Step should work for all shapes
            optimizer.step(&mut [&mut tensor]);

            // Verify tensor is still valid
            assert_eq!(tensor.shape().dims(), shape);
            assert!(tensor.requires_grad());
        }
    }

    /// Test double parameter linking
    ///
    /// Verifies that linking the same parameter twice doesn't create duplicate states.
    #[test]
    fn test_double_parameter_linking() {
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();

        // Link parameter twice
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&weight);

        // Should only have one state
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight));
    }

    /// Test unlink non-linked parameter
    ///
    /// Verifies that unlinking a non-linked parameter returns false.
    #[test]
    fn test_unlink_non_linked_parameter() {
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();

        // Try to unlink a parameter that was never linked
        let was_linked = optimizer.unlink_parameter(&weight);
        assert!(!was_linked);
    }

    /// Test parameter shape validation
    ///
    /// Verifies that parameters with different shapes can be linked correctly
    /// and maintain their shape information in the optimizer state.
    #[test]
    fn test_parameter_shape_validation() {
        let weight_1d = Tensor::ones(vec![5]).with_requires_grad();
        let weight_2d = Tensor::ones(vec![3, 4]).with_requires_grad();
        let weight_3d = Tensor::ones(vec![2, 3, 4]).with_requires_grad();

        let mut optimizer = Adam::new();

        // Test linking parameters with different shapes
        optimizer.add_parameter(&weight_1d);
        optimizer.add_parameter(&weight_2d);
        optimizer.add_parameter(&weight_3d);

        assert_eq!(optimizer.parameter_count(), 3);
        assert!(optimizer.is_parameter_linked(&weight_1d));
        assert!(optimizer.is_parameter_linked(&weight_2d));
        assert!(optimizer.is_parameter_linked(&weight_3d));

        // Test that each parameter maintains its shape after linking
        // (states are keyed by tensor id).
        let state_1d = &optimizer.states[&weight_1d.id()];
        let state_2d = &optimizer.states[&weight_2d.id()];
        let state_3d = &optimizer.states[&weight_3d.id()];

        assert_eq!(state_1d.m.shape().dims(), vec![5]);
        assert_eq!(state_2d.m.shape().dims(), vec![3, 4]);
        assert_eq!(state_3d.m.shape().dims(), vec![2, 3, 4]);
    }

    /// Test parameter relinking with shape consistency
    ///
    /// Verifies that when re-linking parameters (e.g., after deserialization),
    /// shape information is correctly maintained.
    #[test]
    fn test_parameter_relinking_shape_consistency() {
        // Create initial parameter and optimizer
        let mut weight_original = Tensor::ones(vec![3, 3]).with_requires_grad();
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight_original);

        // Perform a step to create state
        let output = weight_original.mul_scalar(2.0);
        let mut loss = output.sum();
        loss.backward(None);
        optimizer.step(&mut [&mut weight_original]);

        // Verify state was created with correct shape
        let original_state = &optimizer.states[&weight_original.id()];
        assert_eq!(original_state.m.shape().dims(), vec![3, 3]);
        assert_eq!(original_state.v.shape().dims(), vec![3, 3]);

        // Create new parameter with same shape (will get different ID)
        let weight_new = Tensor::ones(vec![3, 3]).with_requires_grad();

        // This should create a new state since it's a different parameter
        optimizer.add_parameter(&weight_new);

        // Should now have 2 states
        assert_eq!(optimizer.parameter_count(), 2);

        // Both should be linked with correct shapes
        assert!(optimizer.is_parameter_linked(&weight_original));
        assert!(optimizer.is_parameter_linked(&weight_new));

        let new_state = &optimizer.states[&weight_new.id()];
        assert_eq!(new_state.m.shape().dims(), vec![3, 3]);
        assert_eq!(new_state.v.shape().dims(), vec![3, 3]);
    }

    /// Test large parameter count handling
    ///
    /// Verifies that the optimizer can handle many parameters efficiently
    /// and correctly manage linking/unlinking operations.
    #[test]
    fn test_large_parameter_count() {
        let mut optimizer = Adam::new();
        let mut params = Vec::new();

        // Create 50 parameters of different shapes
        for i in 1..=50 {
            let param = Tensor::ones(vec![i]).with_requires_grad();
            optimizer.add_parameter(&param);
            params.push(param);
        }

        assert_eq!(optimizer.parameter_count(), 50);

        // Verify all parameters are linked
        for param in &params {
            assert!(optimizer.is_parameter_linked(param));
        }

        // Unlink every other parameter among the first 25
        // (indices 0, 2, ..., 24 -> 13 parameters).
        for param in params.iter().take(25).step_by(2) {
            assert!(optimizer.unlink_parameter(param));
        }

        assert_eq!(optimizer.parameter_count(), 37); // 50 - 13 = 37

        // Verify correct parameters are unlinked
        for (i, param) in params.iter().enumerate() {
            if i < 25 && i % 2 == 0 {
                assert!(!optimizer.is_parameter_linked(param));
            } else {
                assert!(optimizer.is_parameter_linked(param));
            }
        }
    }

    /// Test clear_states functionality
    ///
    /// Verifies that clearing all states works correctly and allows
    /// re-adding parameters afterwards.
    #[test]
    fn test_clear_states_functionality() {
        let weight1 = Tensor::ones(vec![2, 2]).with_requires_grad();
        let weight2 = Tensor::ones(vec![3, 3]).with_requires_grad();
        let weight3 = Tensor::ones(vec![4, 4]).with_requires_grad();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight1);
        optimizer.add_parameter(&weight2);
        optimizer.add_parameter(&weight3);

        assert_eq!(optimizer.parameter_count(), 3);

        // Clear all states
        optimizer.clear_states();

        assert_eq!(optimizer.parameter_count(), 0);
        assert!(!optimizer.is_parameter_linked(&weight1));
        assert!(!optimizer.is_parameter_linked(&weight2));
        assert!(!optimizer.is_parameter_linked(&weight3));

        // Should be able to re-add parameters after clearing
        optimizer.add_parameter(&weight1);
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight1));
    }
}