train_station/optimizers/adam/
mod.rs

1//! Adam optimizer implementation for neural network training
2//!
3//! This module provides the Adam optimization algorithm with PyTorch-compatible interface.
4//! Adam combines the benefits of AdaGrad and RMSprop by using adaptive learning rates
5//! with momentum for efficient training of neural networks.
6//!
7//! # Features
8//!
9//! - **Adaptive Learning Rates**: Per-parameter learning rates based on gradient history
10//! - **Momentum Integration**: Combines momentum with adaptive learning rates
11//! - **Bias Correction**: Corrects for initialization bias in moment estimates
12//! - **Weight Decay**: Optional L2 regularization support
13//! - **AMSGrad Variant**: Optional AMSGrad for improved convergence stability
14//! - **SIMD Optimization**: AVX2-optimized parameter updates for maximum performance
15//! - **Thread Safety**: Send + Sync implementation for multi-threaded training
16//! - **Hybrid API**: Both safe (RwLock-based) and unsafe (direct pointer) access patterns
17//!
18//! # Thread Safety
19//!
20//! The optimizer provides two usage patterns:
21//!
22//! **Safe Multi-threaded Usage (Default)**:
23//! - Uses `Arc<RwLock<Tensor>>` for thread-safe parameter access
24//! - Multiple threads can read tensors simultaneously
25//! - Optimizer steps acquire write locks only during parameter updates
26//! - Recommended for most use cases
27//!
28//! **Unsafe Single-threaded Usage (Performance)**:
29//! - Uses raw pointers for maximum performance
30//! - No locking overhead during optimizer steps
31//! - Caller must ensure exclusive access during optimization
32//! - Use only when you can guarantee no concurrent tensor access
33//!
34//! # Algorithm
35//!
36//! Adam implements the following update rule for each parameter theta:
37//!
38//! ```text
39//! m_t = beta1 * m_{t-1} + (1 - beta1) * grad_theta_t
40//! v_t = beta2 * v_{t-1} + (1 - beta2) * (grad_theta_t)^2
41//! m_hat_t = m_t / (1 - beta1^t)
42//! v_hat_t = v_t / (1 - beta2^t)
43//! theta_{t+1} = theta_t - lr * m_hat_t / (sqrt(v_hat_t) + eps)
44//! ```
45//!
46//! Where:
47//! - lr is the learning rate
48//! - beta1, beta2 are exponential decay rates for moment estimates
49//! - eps is a small constant for numerical stability
50//! - m_t, v_t are biased first and second moment estimates
51//! - m_hat_t, v_hat_t are bias-corrected moment estimates
52//!
53//! # Performance Characteristics
54//!
55//! - **SIMD Optimization**: Uses AVX2 instructions for 8x vectorization when available
56//! - **Memory Efficiency**: In-place updates with minimal temporary allocations
57//! - **Cache-Friendly**: Optimized memory access patterns for large parameter tensors
58//! - **Zero-Cost Abstractions**: Compile-time optimization with minimal runtime overhead
59//! - **Lock-Free Reads**: RwLock allows concurrent tensor reads during training
60//!
61//! # References
62//!
63//! - Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization.
64//! - PyTorch Adam implementation: <https://pytorch.org/docs/stable/generated/torch.optim.Adam.html>
65
66pub mod serialization;
67
68use super::Optimizer;
69use crate::tensor::core::Tensor;
70use std::collections::HashMap;
71
72// SIMD optimizations for performance-critical operations
73#[cfg(target_arch = "x86_64")]
74use std::arch::x86_64::*;
75
/// Configuration for the Adam optimization algorithm
///
/// Contains all hyperparameters that control the behavior of Adam optimization.
/// Default values follow PyTorch conventions for maximum compatibility and
/// optimal convergence across a wide range of neural network architectures.
///
/// # Fields
///
/// * `learning_rate` - Step size for parameter updates (default: 1e-3)
/// * `beta1` - Exponential decay rate for first moment estimates (default: 0.9)
/// * `beta2` - Exponential decay rate for second moment estimates (default: 0.999)
/// * `eps` - Small constant for numerical stability (default: 1e-8)
/// * `weight_decay` - L2 regularization coefficient; 0.0 disables it (default: 0.0)
/// * `amsgrad` - Whether to use AMSGrad variant for improved stability (default: false)
#[derive(Debug, Clone)]
pub struct AdamConfig {
    /// Learning rate (step size) for parameter updates (default: 1e-3)
    pub learning_rate: f32,
    /// Exponential decay rate for first moment (momentum) estimates (default: 0.9)
    pub beta1: f32,
    /// Exponential decay rate for second moment (velocity) estimates (default: 0.999)
    pub beta2: f32,
    /// Small constant added to the denominator for numerical stability (default: 1e-8)
    pub eps: f32,
    /// Weight decay coefficient for L2 regularization; 0.0 disables it (default: 0.0)
    pub weight_decay: f32,
    /// Whether to use the AMSGrad variant (running max of v_hat) (default: false)
    pub amsgrad: bool,
}
105
106impl Default for AdamConfig {
107    fn default() -> Self {
108        Self {
109            learning_rate: 1e-3,
110            beta1: 0.9,
111            beta2: 0.999,
112            eps: 1e-8,
113            weight_decay: 0.0,
114            amsgrad: false,
115        }
116    }
117}
118
/// Internal state tracking for a single parameter during Adam optimization
///
/// Stores the momentum and velocity buffers needed for Adam optimization,
/// along with step count for bias correction and optional AMSGrad state.
/// Memory layout is optimized for cache efficiency and SIMD operations.
///
/// # Fields
///
/// * `m` - First moment estimate (momentum buffer)
/// * `v` - Second moment estimate (velocity buffer)
/// * `v_hat_max` - Maximum of second moment estimates (for AMSGrad variant)
/// * `step` - Step count for bias correction calculations
#[derive(Debug)]
struct ParameterState {
    /// First moment estimate (momentum); zero-initialized with the parameter's shape
    m: Tensor,
    /// Second moment estimate (velocity); zero-initialized with the parameter's shape
    v: Tensor,
    /// Running maximum of bias-corrected second moments for AMSGrad;
    /// `None` until the first AMSGrad step allocates it lazily
    v_hat_max: Option<Tensor>,
    /// Per-parameter step count used for bias correction
    step: usize,
}
142
143impl ParameterState {
144    fn new(param_shape: &[usize]) -> Self {
145        Self {
146            m: Tensor::zeros(param_shape.to_vec()),
147            v: Tensor::zeros(param_shape.to_vec()),
148            v_hat_max: None,
149            step: 0,
150        }
151    }
152}
153
/// Adam optimizer for neural network parameter optimization
///
/// Implements the Adam optimization algorithm with PyTorch-compatible interface.
/// Provides adaptive learning rates with momentum for efficient training of neural networks.
/// The optimizer maintains per-parameter state for momentum and velocity estimates,
/// enabling adaptive learning rates that improve convergence across diverse architectures.
///
/// # Usage Pattern
///
/// The optimizer uses ID-based parameter linking for maximum flexibility and thread safety:
/// - Parameters are linked to the optimizer via `add_parameter` or `add_parameters`
/// - The `step` method takes mutable references to parameters for thread-safe updates
/// - Parameter states are maintained by tensor ID, allowing for dynamic parameter management
/// - Supports serialization and deserialization with parameter re-linking
///
/// # Dynamic Parameter Management
///
/// Parameters can be added, removed, or re-linked at runtime:
/// - `add_parameter`: Link a single parameter
/// - `add_parameters`: Link multiple parameters at once
/// - `unlink_parameter`: Remove parameter state by ID
/// - `clear_states`: Remove all parameter states
/// - `is_parameter_linked`: Check if a parameter is linked
///
/// # Serialization Support
///
/// The optimizer supports full serialization and deserialization with state preservation:
/// - Parameter states are saved with their shapes and insertion order for validation
/// - After deserialization, use `relink_parameters` to restore saved states to new tensors
/// - Parameters must be re-linked in the same chronological order they were originally added
/// - Shape validation ensures consistency between saved and current parameters
///
/// # Features
///
/// - **ID-Based Parameter Linking**: Dynamic parameter management via tensor IDs
/// - **Thread-Safe Step Method**: Takes mutable references for safe concurrent access
/// - **Per-Parameter State**: Each parameter maintains its own momentum and velocity buffers
/// - **Bias Correction**: Automatically corrects initialization bias in moment estimates
/// - **Weight Decay**: Optional L2 regularization with efficient implementation
/// - **AMSGrad Support**: Optional AMSGrad variant for improved convergence stability
/// - **SIMD Optimization**: AVX2-optimized updates for maximum performance
/// - **Full Serialization**: Complete state persistence and restoration
///
/// # Thread Safety
///
/// This type is thread-safe and can be shared between threads. The step method
/// takes mutable references to parameters, ensuring exclusive access during updates.
pub struct Adam {
    /// Optimizer hyperparameter configuration
    config: AdamConfig,
    /// Per-parameter optimization state, keyed by tensor ID
    states: HashMap<usize, ParameterState>,
    /// Global step count across optimizer steps
    step_count: usize,
    /// Tensor IDs in the order parameters were added; used by serialization
    /// and `relink_parameters` for positional matching
    insertion_order: Vec<usize>,
}
211
212impl Default for Adam {
213    fn default() -> Self {
214        Self::with_config(AdamConfig::default())
215    }
216}
217
218impl Adam {
219    /// Create a new Adam optimizer with default configuration
220    ///
221    /// Initializes an Adam optimizer with PyTorch-compatible default hyperparameters.
222    /// Parameters must be linked separately using `add_parameter` or `add_parameters`.
223    ///
224    /// # Returns
225    ///
226    /// A new Adam optimizer instance with default hyperparameters
227    #[track_caller]
228    pub fn new() -> Self {
229        Self::default()
230    }
231
232    /// Create a new Adam optimizer with custom configuration
233    ///
234    /// Allows full control over all Adam hyperparameters for specialized training
235    /// scenarios such as fine-tuning, transfer learning, or research applications.
236    /// Parameters must be linked separately using `add_parameter` or `add_parameters`.
237    ///
238    /// # Arguments
239    ///
240    /// * `config` - Adam configuration with custom hyperparameters
241    ///
242    /// # Returns
243    ///
244    /// A new Adam optimizer instance with the specified configuration
245    #[track_caller]
246    pub fn with_config(config: AdamConfig) -> Self {
247        Self {
248            config,
249            states: HashMap::new(),
250            step_count: 0,
251            insertion_order: Vec::new(),
252        }
253    }
254
255    /// Create a new Adam optimizer with custom learning rate
256    ///
257    /// A convenience constructor that allows setting only the learning rate while
258    /// using default values for all other hyperparameters. Parameters must be
259    /// linked separately using `add_parameter` or `add_parameters`.
260    ///
261    /// # Arguments
262    ///
263    /// * `learning_rate` - Learning rate for optimization
264    ///
265    /// # Returns
266    ///
267    /// A new Adam optimizer instance with the specified learning rate and default
268    /// values for all other hyperparameters
269    #[track_caller]
270    pub fn with_learning_rate(learning_rate: f32) -> Self {
271        let config = AdamConfig {
272            learning_rate,
273            ..Default::default()
274        };
275        Self::with_config(config)
276    }
277
278    /// Add a single parameter to the optimizer
279    ///
280    /// Links a parameter to the optimizer by creating a new parameter state
281    /// indexed by the tensor's ID. The parameter must have `requires_grad` set to true.
282    ///
283    /// # Arguments
284    ///
285    /// * `parameter` - Reference to the tensor to link
286    ///
287    /// # Panics
288    ///
289    /// Panics if the parameter does not have `requires_grad` set to true
290    #[track_caller]
291    pub fn add_parameter(&mut self, parameter: &Tensor) {
292        assert!(
293            parameter.requires_grad(),
294            "Parameter must require gradients"
295        );
296
297        let param_id = parameter.id();
298        let param_shape = parameter.shape().dims.clone();
299
300        // Initialize state for this parameter if not already present
301        use std::collections::hash_map::Entry;
302        if let Entry::Vacant(entry) = self.states.entry(param_id) {
303            entry.insert(ParameterState::new(&param_shape));
304            self.insertion_order.push(param_id);
305        }
306    }
307
308    /// Add multiple parameters to the optimizer
309    ///
310    /// Links multiple parameters to the optimizer by creating parameter states
311    /// indexed by each tensor's ID. All parameters must have `requires_grad` set to true.
312    ///
313    /// # Arguments
314    ///
315    /// * `parameters` - Slice of references to tensors to link
316    ///
317    /// # Panics
318    ///
319    /// Panics if any parameter does not have `requires_grad` set to true
320    #[track_caller]
321    pub fn add_parameters(&mut self, parameters: &[&Tensor]) {
322        for parameter in parameters {
323            self.add_parameter(parameter);
324        }
325    }
326
327    /// Remove a parameter from the optimizer
328    ///
329    /// Unlinks a parameter by removing its state from the optimizer.
330    /// The parameter ID is used for identification.
331    ///
332    /// # Arguments
333    ///
334    /// * `parameter` - Reference to the tensor to unlink
335    ///
336    /// # Returns
337    ///
338    /// True if the parameter was linked and removed, false if it was not linked
339    #[track_caller]
340    pub fn unlink_parameter(&mut self, parameter: &Tensor) -> bool {
341        let param_id = parameter.id();
342        let was_linked = self.states.remove(&param_id).is_some();
343        if was_linked {
344            self.insertion_order.retain(|&id| id != param_id);
345        }
346        was_linked
347    }
348
349    /// Remove all parameter states from the optimizer
350    ///
351    /// Clears all parameter states, effectively unlinking all parameters.
352    /// This is useful for resetting the optimizer or preparing for parameter re-linking.
353    #[track_caller]
354    pub fn clear_states(&mut self) {
355        self.states.clear();
356        self.insertion_order.clear();
357    }
358
359    /// Check if a parameter is linked to the optimizer
360    ///
361    /// Returns true if the parameter has an associated state in the optimizer.
362    ///
363    /// # Arguments
364    ///
365    /// * `parameter` - Reference to the tensor to check
366    ///
367    /// # Returns
368    ///
369    /// True if the parameter is linked, false otherwise
370    #[track_caller]
371    pub fn is_parameter_linked(&self, parameter: &Tensor) -> bool {
372        let param_id = parameter.id();
373        self.states.contains_key(&param_id)
374    }
375
376    /// Get the number of linked parameters
377    ///
378    /// Returns the count of parameters currently linked to the optimizer.
379    ///
380    /// # Returns
381    ///
382    /// Number of linked parameters
383    #[track_caller]
384    pub fn parameter_count(&self) -> usize {
385        self.states.len()
386    }
387
388    /// Re-link parameters to saved optimizer states in chronological order
389    ///
390    /// After deserializing an optimizer, use this method to restore saved parameter states
391    /// to new tensors. Parameters must be provided in the same chronological order they
392    /// were originally added to the optimizer. Shape validation ensures parameter compatibility.
393    ///
394    /// # Arguments
395    ///
396    /// * `parameters` - Slice of parameter references in chronological order
397    ///
398    /// # Returns
399    ///
400    /// Result indicating success or failure with detailed error message
401    ///
402    /// # Panics
403    ///
404    /// Panics if any parameter does not have `requires_grad` set to true
405    #[track_caller]
406    pub fn relink_parameters(&mut self, parameters: &[&Tensor]) -> Result<(), String> {
407        // Validate all parameters have requires_grad first
408        for (i, param) in parameters.iter().enumerate() {
409            if !param.requires_grad() {
410                return Err(format!("Parameter at index {} must require gradients", i));
411            }
412        }
413
414        // Check parameter count matches saved states
415        if parameters.len() != self.insertion_order.len() {
416            return Err(format!(
417                "Parameter count mismatch: expected {} parameters, got {}",
418                self.insertion_order.len(),
419                parameters.len()
420            ));
421        }
422
423        // Create new states map with parameter IDs mapped to saved states in chronological order
424        let mut new_states = HashMap::new();
425        let mut new_insertion_order = Vec::new();
426
427        for (i, param) in parameters.iter().enumerate() {
428            let new_param_id = param.id();
429            let old_param_id = self.insertion_order[i];
430
431            // Get the saved state for this position
432            let saved_state = self
433                .states
434                .get(&old_param_id)
435                .ok_or_else(|| format!("No saved state found for parameter at position {}", i))?;
436
437            // Validate shape matches
438            let param_shape = &param.shape().dims;
439            let saved_shape = &saved_state.m.shape().dims;
440            if param_shape != saved_shape {
441                return Err(format!(
442                    "Shape mismatch for parameter at position {}: expected {:?}, got {:?}",
443                    i, saved_shape, param_shape
444                ));
445            }
446
447            // Create new state for this parameter
448            let new_state = ParameterState {
449                m: saved_state.m.clone(),
450                v: saved_state.v.clone(),
451                v_hat_max: saved_state.v_hat_max.clone(),
452                step: saved_state.step,
453            };
454
455            new_states.insert(new_param_id, new_state);
456            new_insertion_order.push(new_param_id);
457        }
458
459        // Replace the states and insertion order
460        self.states = new_states;
461        self.insertion_order = new_insertion_order;
462
463        Ok(())
464    }
465
466    /// Get the current optimizer configuration
467    ///
468    /// Returns a reference to the current configuration, allowing inspection
469    /// of all hyperparameters without modification.
470    ///
471    /// # Returns
472    ///
473    /// Reference to the current Adam configuration
474    #[track_caller]
475    pub fn config(&self) -> &AdamConfig {
476        &self.config
477    }
478
479    /// Update a single parameter using Adam algorithm
480    ///
481    /// Implements the core Adam update rule with bias correction and optional AMSGrad.
482    /// Uses SIMD optimization when available for improved performance.
483    /// The parameter must be linked to the optimizer before calling this method.
484    ///
485    /// # Arguments
486    ///
487    /// * `param` - Mutable reference to the parameter tensor to update
488    ///
489    /// # Returns
490    ///
491    /// Result indicating success or failure of the parameter update
492    ///
493    /// # Panics
494    ///
495    /// Panics if the parameter is not linked to the optimizer
496    fn update_parameter(&mut self, param: &mut Tensor) -> Result<(), String> {
497        let param_id = param.id();
498
499        // Ensure parameter is linked
500        assert!(
501            self.states.contains_key(&param_id),
502            "Parameter must be linked to optimizer before stepping. Use add_parameter() first."
503        );
504
505        // Get parameter gradient
506        let grad = param
507            .grad_by_value()
508            .ok_or_else(|| format!("Parameter {} has no gradient", param_id))?;
509
510        // Get parameter state
511        let state = self
512            .states
513            .get_mut(&param_id)
514            .expect("Parameter state should exist after link check");
515
516        // Increment step count
517        state.step += 1;
518        let step = self.step_count as f32; // Use global step count for bias correction
519
520        // Apply weight decay if enabled
521        let effective_grad = if self.config.weight_decay > 0.0 {
522            // L2 regularization: grad + weight_decay * param
523            let mut grad_with_decay = grad.clone();
524            Self::add_weight_decay(&mut grad_with_decay, param, self.config.weight_decay);
525            grad_with_decay
526        } else {
527            grad
528        };
529
530        // Update biased first moment estimate: m = beta1 * m + (1 - beta1) * grad
531        Self::update_momentum(&mut state.m, &effective_grad, self.config.beta1);
532
533        // Update biased second moment estimate: v = beta2 * v + (1 - beta2) * grad^2
534        Self::update_velocity(&mut state.v, &effective_grad, self.config.beta2);
535
536        // Compute bias-corrected first moment estimate
537        let bias_correction1 = 1.0 - (self.config.beta1 as f64).powf(step as f64);
538        let m_hat = Self::bias_correct(&state.m, bias_correction1 as f32);
539
540        // Compute bias-corrected second moment estimate
541        let bias_correction2 = 1.0 - (self.config.beta2 as f64).powf(step as f64);
542        let mut v_hat = Self::bias_correct(&state.v, bias_correction2 as f32);
543
544        // AMSGrad: use maximum of v_hat over time
545        if self.config.amsgrad {
546            if state.v_hat_max.is_none() {
547                state.v_hat_max = Some(v_hat.clone());
548            }
549            let v_hat_max = state.v_hat_max.as_mut().unwrap();
550            Self::element_wise_max(v_hat_max, &v_hat);
551            v_hat = v_hat_max.clone();
552        }
553
554        // Compute parameter update: param = param - lr * m_hat / (sqrt(v_hat) + eps)
555        Self::apply_adam_update(
556            param,
557            &m_hat,
558            &v_hat,
559            self.config.learning_rate,
560            self.config.eps,
561        );
562
563        Ok(())
564    }
565
    /// Apply weight decay (L2 regularization) to gradient
    ///
    /// Adds `weight_decay * param` to the gradient in-place for memory efficiency.
    /// This implements L2 regularization by modifying the gradient before the
    /// Adam update step, equivalent to adding a regularization term to the loss.
    ///
    /// # Arguments
    ///
    /// * `grad` - Gradient tensor to modify in-place
    /// * `param` - Parameter tensor for weight decay calculation
    /// * `weight_decay` - Weight decay coefficient
    ///
    /// # Panics
    ///
    /// Panics if `grad` and `param` have different element counts.
    #[inline]
    fn add_weight_decay(grad: &mut Tensor, param: &Tensor, weight_decay: f32) {
        assert_eq!(
            grad.size(),
            param.size(),
            "Gradient and parameter size mismatch"
        );

        // SAFETY: the assert above guarantees both buffers hold `size`
        // elements, so every `add(i)` with i < size stays in bounds. The
        // `&mut`/`&` borrows rule out aliased mutation for the duration.
        unsafe {
            let grad_ptr = grad.as_mut_ptr();
            let param_ptr = param.as_ptr();
            let size = grad.size();

            #[cfg(target_arch = "x86_64")]
            {
                // Runtime AVX2 dispatch; falls through to the scalar loop
                // when the CPU or tensor alignment does not qualify.
                if is_x86_feature_detected!("avx2")
                    && grad.is_simd_aligned()
                    && param.is_simd_aligned()
                {
                    Self::add_weight_decay_simd_avx2(grad_ptr, param_ptr, weight_decay, size);
                    return;
                }
            }

            // Scalar fallback
            for i in 0..size {
                *grad_ptr.add(i) += weight_decay * *param_ptr.add(i);
            }
        }
    }
607
    /// AVX2 kernel for `add_weight_decay`: `grad[i] += weight_decay * param[i]`.
    ///
    /// Processes 8 f32 lanes per iteration, then finishes the tail with a
    /// scalar loop.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available and that both pointers are valid
    /// for `size` f32 elements (`grad_ptr` for reads and writes, `param_ptr`
    /// for reads). Unaligned intrinsics (`loadu`/`storeu`) are used, so no
    /// alignment requirement beyond validity is imposed here.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn add_weight_decay_simd_avx2(
        grad_ptr: *mut f32,
        param_ptr: *const f32,
        weight_decay: f32,
        size: usize,
    ) {
        let decay_vec = _mm256_set1_ps(weight_decay);
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));
            let param_vec = _mm256_loadu_ps(param_ptr.add(offset));
            let decay_term = _mm256_mul_ps(decay_vec, param_vec);
            let result = _mm256_add_ps(grad_vec, decay_term);
            _mm256_storeu_ps(grad_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            *grad_ptr.add(i) += weight_decay * *param_ptr.add(i);
        }
    }
633
    /// Update momentum (first moment estimate)
    ///
    /// Implements the momentum update rule: `m = beta1 * m + (1 - beta1) * grad`
    /// This computes the exponentially decaying average of gradients, providing
    /// momentum-like behavior that helps accelerate convergence in relevant directions.
    ///
    /// # Arguments
    ///
    /// * `momentum` - Momentum buffer to update in-place
    /// * `grad` - Current gradient tensor
    /// * `beta1` - Exponential decay rate for momentum
    ///
    /// # Panics
    ///
    /// Panics if `momentum` and `grad` have different element counts.
    #[inline]
    fn update_momentum(momentum: &mut Tensor, grad: &Tensor, beta1: f32) {
        assert_eq!(
            momentum.size(),
            grad.size(),
            "Momentum and gradient size mismatch"
        );

        let beta1_complement = 1.0 - beta1;

        // SAFETY: the assert above guarantees both buffers hold `size`
        // elements, so every `add(i)` with i < size stays in bounds.
        unsafe {
            let m_ptr = momentum.as_mut_ptr();
            let grad_ptr = grad.as_ptr();
            let size = momentum.size();

            #[cfg(target_arch = "x86_64")]
            {
                // Runtime AVX2 dispatch; falls through to the scalar loop below.
                if is_x86_feature_detected!("avx2")
                    && momentum.is_simd_aligned()
                    && grad.is_simd_aligned()
                {
                    Self::update_momentum_simd_avx2(m_ptr, grad_ptr, beta1, beta1_complement, size);
                    return;
                }
            }

            // Scalar fallback
            for i in 0..size {
                *m_ptr.add(i) = beta1 * *m_ptr.add(i) + beta1_complement * *grad_ptr.add(i);
            }
        }
    }
677
    /// AVX2 kernel for `update_momentum`:
    /// `m[i] = beta1 * m[i] + (1 - beta1) * grad[i]`, 8 f32 lanes per
    /// iteration plus a scalar tail.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available and that both pointers are valid
    /// for `size` f32 elements (`m_ptr` for reads and writes, `grad_ptr` for
    /// reads). Unaligned intrinsics are used, so only pointer validity is
    /// required.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn update_momentum_simd_avx2(
        m_ptr: *mut f32,
        grad_ptr: *const f32,
        beta1: f32,
        beta1_complement: f32,
        size: usize,
    ) {
        let beta1_vec = _mm256_set1_ps(beta1);
        let beta1_comp_vec = _mm256_set1_ps(beta1_complement);
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let m_vec = _mm256_loadu_ps(m_ptr.add(offset));
            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));

            let momentum_term = _mm256_mul_ps(beta1_vec, m_vec);
            let gradient_term = _mm256_mul_ps(beta1_comp_vec, grad_vec);
            let result = _mm256_add_ps(momentum_term, gradient_term);

            _mm256_storeu_ps(m_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            *m_ptr.add(i) = beta1 * *m_ptr.add(i) + beta1_complement * *grad_ptr.add(i);
        }
    }
708
    /// Update velocity (second moment estimate)
    ///
    /// Implements the velocity update rule: `v = beta2 * v + (1 - beta2) * grad^2`
    /// This computes the exponentially decaying average of squared gradients,
    /// providing adaptive learning rates that scale inversely with gradient magnitude.
    ///
    /// # Arguments
    ///
    /// * `velocity` - Velocity buffer to update in-place
    /// * `grad` - Current gradient tensor
    /// * `beta2` - Exponential decay rate for velocity
    ///
    /// # Panics
    ///
    /// Panics if `velocity` and `grad` have different element counts.
    #[inline]
    fn update_velocity(velocity: &mut Tensor, grad: &Tensor, beta2: f32) {
        assert_eq!(
            velocity.size(),
            grad.size(),
            "Velocity and gradient size mismatch"
        );

        let beta2_complement = 1.0 - beta2;

        // SAFETY: the assert above guarantees both buffers hold `size`
        // elements, so every `add(i)` with i < size stays in bounds.
        unsafe {
            let v_ptr = velocity.as_mut_ptr();
            let grad_ptr = grad.as_ptr();
            let size = velocity.size();

            #[cfg(target_arch = "x86_64")]
            {
                // Runtime AVX2 dispatch; falls through to the scalar loop below.
                if is_x86_feature_detected!("avx2")
                    && velocity.is_simd_aligned()
                    && grad.is_simd_aligned()
                {
                    Self::update_velocity_simd_avx2(v_ptr, grad_ptr, beta2, beta2_complement, size);
                    return;
                }
            }

            // Scalar fallback
            for i in 0..size {
                let grad_val = *grad_ptr.add(i);
                *v_ptr.add(i) = beta2 * *v_ptr.add(i) + beta2_complement * grad_val * grad_val;
            }
        }
    }
753
    /// AVX2 kernel for `update_velocity`:
    /// `v[i] = beta2 * v[i] + (1 - beta2) * grad[i]^2`, 8 f32 lanes per
    /// iteration plus a scalar tail.
    ///
    /// # Safety
    ///
    /// Caller must ensure AVX2 is available and that both pointers are valid
    /// for `size` f32 elements (`v_ptr` for reads and writes, `grad_ptr` for
    /// reads). Unaligned intrinsics are used, so only pointer validity is
    /// required.
    #[cfg(target_arch = "x86_64")]
    #[inline]
    unsafe fn update_velocity_simd_avx2(
        v_ptr: *mut f32,
        grad_ptr: *const f32,
        beta2: f32,
        beta2_complement: f32,
        size: usize,
    ) {
        let beta2_vec = _mm256_set1_ps(beta2);
        let beta2_comp_vec = _mm256_set1_ps(beta2_complement);
        let simd_count = size / 8;

        for i in 0..simd_count {
            let offset = i * 8;
            let v_vec = _mm256_loadu_ps(v_ptr.add(offset));
            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));

            let velocity_term = _mm256_mul_ps(beta2_vec, v_vec);
            let grad_squared = _mm256_mul_ps(grad_vec, grad_vec);
            let gradient_term = _mm256_mul_ps(beta2_comp_vec, grad_squared);
            let result = _mm256_add_ps(velocity_term, gradient_term);

            _mm256_storeu_ps(v_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (simd_count * 8)..size {
            let grad_val = *grad_ptr.add(i);
            *v_ptr.add(i) = beta2 * *v_ptr.add(i) + beta2_complement * grad_val * grad_val;
        }
    }
786
    /// Apply bias correction to moment estimates
    ///
    /// Returns `tensor / (1 - beta^step)` for unbiasing the moment estimates.
    /// This correction is necessary because the moment estimates are initialized
    /// to zero, creating a bias towards zero that becomes negligible as training progresses.
    ///
    /// # Arguments
    ///
    /// * `tensor` - Moment estimate tensor to correct
    /// * `bias_correction` - Bias correction factor (1 - beta^step); must be
    ///   non-zero or the reciprocal below is infinite
    ///
    /// # Returns
    ///
    /// Bias-corrected tensor with the same shape as input
    #[inline]
    fn bias_correct(tensor: &Tensor, bias_correction: f32) -> Tensor {
        let mut result = tensor.clone();
        // Multiply by the reciprocal once rather than dividing per element.
        let correction_factor = 1.0 / bias_correction;

        // SAFETY: `result` is a fresh clone owning `size` contiguous f32
        // elements, so every `add(i)` with i < size stays in bounds.
        unsafe {
            let ptr = result.as_mut_ptr();
            let size = result.size();

            #[cfg(target_arch = "x86_64")]
            {
                // Runtime AVX2 dispatch; falls through to the scalar loop below.
                if is_x86_feature_detected!("avx2") && result.is_simd_aligned() {
                    Self::bias_correct_simd_avx2(ptr, correction_factor, size);
                    return result;
                }
            }

            // Scalar fallback
            for i in 0..size {
                *ptr.add(i) *= correction_factor;
            }
        }

        result
    }
826
827    #[cfg(target_arch = "x86_64")]
828    #[inline]
829    unsafe fn bias_correct_simd_avx2(ptr: *mut f32, correction_factor: f32, size: usize) {
830        let factor_vec = _mm256_set1_ps(correction_factor);
831        let simd_count = size / 8;
832
833        for i in 0..simd_count {
834            let offset = i * 8;
835            let data_vec = _mm256_loadu_ps(ptr.add(offset));
836            let result = _mm256_mul_ps(data_vec, factor_vec);
837            _mm256_storeu_ps(ptr.add(offset), result);
838        }
839
840        // Handle remaining elements
841        for i in (simd_count * 8)..size {
842            *ptr.add(i) *= correction_factor;
843        }
844    }
845
846    /// Element-wise maximum for AMSGrad
847    ///
848    /// Updates first tensor in-place with `max(first, second)` for AMSGrad variant.
849    /// This maintains a running maximum of the second moment estimates, preventing
850    /// the learning rate from increasing over time and improving convergence stability.
851    ///
852    /// # Arguments
853    ///
854    /// * `first` - Tensor to update in-place with maximum values
855    /// * `second` - Tensor to compare against for maximum calculation
856    #[inline]
857    fn element_wise_max(first: &mut Tensor, second: &Tensor) {
858        assert_eq!(
859            first.size(),
860            second.size(),
861            "Tensor size mismatch for element-wise max"
862        );
863
864        unsafe {
865            let first_ptr = first.as_mut_ptr();
866            let second_ptr = second.as_ptr();
867            let size = first.size();
868
869            #[cfg(target_arch = "x86_64")]
870            {
871                if is_x86_feature_detected!("avx2")
872                    && first.is_simd_aligned()
873                    && second.is_simd_aligned()
874                {
875                    Self::element_wise_max_simd_avx2(first_ptr, second_ptr, size);
876                    return;
877                }
878            }
879
880            // Scalar fallback
881            for i in 0..size {
882                let a = *first_ptr.add(i);
883                let b = *second_ptr.add(i);
884                *first_ptr.add(i) = if a > b { a } else { b };
885            }
886        }
887    }
888
889    #[cfg(target_arch = "x86_64")]
890    #[inline]
891    unsafe fn element_wise_max_simd_avx2(first_ptr: *mut f32, second_ptr: *const f32, size: usize) {
892        let simd_count = size / 8;
893
894        for i in 0..simd_count {
895            let offset = i * 8;
896            let first_vec = _mm256_loadu_ps(first_ptr.add(offset));
897            let second_vec = _mm256_loadu_ps(second_ptr.add(offset));
898            let result = _mm256_max_ps(first_vec, second_vec);
899            _mm256_storeu_ps(first_ptr.add(offset), result);
900        }
901
902        // Handle remaining elements
903        for i in (simd_count * 8)..size {
904            let a = *first_ptr.add(i);
905            let b = *second_ptr.add(i);
906            *first_ptr.add(i) = if a > b { a } else { b };
907        }
908    }
909
910    /// Apply the final Adam parameter update
911    ///
912    /// Implements the core Adam update rule: `param = param - lr * m_hat / (sqrt(v_hat) + eps)`
913    /// This applies the bias-corrected momentum and velocity estimates to update
914    /// the parameter values, with adaptive learning rates that scale inversely
915    /// with the square root of the velocity estimates.
916    ///
917    /// # Arguments
918    ///
919    /// * `param` - Parameter tensor to update in-place
920    /// * `m_hat` - Bias-corrected first moment estimate
921    /// * `v_hat` - Bias-corrected second moment estimate
922    /// * `learning_rate` - Learning rate for the update
923    /// * `eps` - Small constant for numerical stability
924    #[inline]
925    fn apply_adam_update(
926        param: &mut Tensor,
927        m_hat: &Tensor,
928        v_hat: &Tensor,
929        learning_rate: f32,
930        eps: f32,
931    ) {
932        assert_eq!(
933            param.size(),
934            m_hat.size(),
935            "Parameter and momentum size mismatch"
936        );
937        assert_eq!(
938            param.size(),
939            v_hat.size(),
940            "Parameter and velocity size mismatch"
941        );
942
943        unsafe {
944            let param_ptr = param.as_mut_ptr();
945            let m_ptr = m_hat.as_ptr();
946            let v_ptr = v_hat.as_ptr();
947            let size = param.size();
948
949            #[cfg(target_arch = "x86_64")]
950            {
951                if is_x86_feature_detected!("avx2")
952                    && param.is_simd_aligned()
953                    && m_hat.is_simd_aligned()
954                    && v_hat.is_simd_aligned()
955                {
956                    Self::apply_adam_update_simd_avx2(
957                        param_ptr,
958                        m_ptr,
959                        v_ptr,
960                        learning_rate,
961                        eps,
962                        size,
963                    );
964                    return;
965                }
966            }
967
968            // Scalar fallback
969            for i in 0..size {
970                let m_val = *m_ptr.add(i);
971                let v_val = *v_ptr.add(i);
972                let denominator = v_val.sqrt() + eps;
973                *param_ptr.add(i) -= learning_rate * m_val / denominator;
974            }
975        }
976    }
977
978    #[cfg(target_arch = "x86_64")]
979    #[inline]
980    unsafe fn apply_adam_update_simd_avx2(
981        param_ptr: *mut f32,
982        m_ptr: *const f32,
983        v_ptr: *const f32,
984        learning_rate: f32,
985        eps: f32,
986        size: usize,
987    ) {
988        let lr_vec = _mm256_set1_ps(learning_rate);
989        let eps_vec = _mm256_set1_ps(eps);
990        let simd_count = size / 8;
991
992        for i in 0..simd_count {
993            let offset = i * 8;
994            let param_vec = _mm256_loadu_ps(param_ptr.add(offset));
995            let m_vec = _mm256_loadu_ps(m_ptr.add(offset));
996            let v_vec = _mm256_loadu_ps(v_ptr.add(offset));
997
998            // sqrt(v_hat) + eps
999            let sqrt_v = _mm256_sqrt_ps(v_vec);
1000            let denominator = _mm256_add_ps(sqrt_v, eps_vec);
1001
1002            // lr * m_hat / denominator
1003            let lr_m = _mm256_mul_ps(lr_vec, m_vec);
1004            let update = _mm256_div_ps(lr_m, denominator);
1005
1006            // param - update
1007            let result = _mm256_sub_ps(param_vec, update);
1008            _mm256_storeu_ps(param_ptr.add(offset), result);
1009        }
1010
1011        // Handle remaining elements
1012        for i in (simd_count * 8)..size {
1013            let m_val = *m_ptr.add(i);
1014            let v_val = *v_ptr.add(i);
1015            let denominator = v_val.sqrt() + eps;
1016            *param_ptr.add(i) -= learning_rate * m_val / denominator;
1017        }
1018    }
1019}
1020
1021impl Optimizer for Adam {
1022    /// Perform a single optimization step
1023    ///
1024    /// Updates all provided parameters based on their accumulated gradients using the Adam algorithm.
1025    /// Each parameter is updated according to the Adam update rule with bias correction
1026    /// and optional AMSGrad variant if enabled. All parameters must be linked to the optimizer
1027    /// before calling this method.
1028    ///
1029    /// # Arguments
1030    ///
1031    /// * `parameters` - Mutable slice of parameter references to update
1032    ///
1033    /// # Thread Safety
1034    ///
1035    /// This method is thread-safe as it takes mutable references to parameters,
1036    /// ensuring exclusive access during updates.
1037    ///
1038    /// # Performance
1039    ///
1040    /// - Uses SIMD optimization (AVX2) when available for 8x vectorization
1041    /// - Processes parameters in sequence for optimal cache usage
1042    /// - Maintains per-parameter state for momentum and velocity estimates
1043    ///
1044    /// # Panics
1045    ///
1046    /// Panics if any parameter is not linked to the optimizer
1047    fn step(&mut self, parameters: &mut [&mut Tensor]) {
1048        self.step_count += 1;
1049
1050        for param in parameters {
1051            if let Err(e) = self.update_parameter(param) {
1052                eprintln!("Warning: Failed to update parameter: {}", e);
1053            }
1054        }
1055    }
1056
1057    /// Zero out all parameter gradients
1058    ///
1059    /// Clears accumulated gradients for all provided parameters. This should be called
1060    /// before each backward pass to prevent gradient accumulation across multiple
1061    /// forward/backward passes. Also clears the global autograd gradient map.
1062    ///
1063    /// # Arguments
1064    ///
1065    /// * `parameters` - Mutable slice of parameter references to clear gradients for
1066    ///
1067    /// # Performance
1068    ///
1069    /// - Efficiently clears gradients using optimized tensor operations
1070    /// - Clears both per-tensor gradients and global autograd state
1071    /// - Thread-safe as it takes mutable references to parameters
1072    fn zero_grad(&mut self, parameters: &mut [&mut Tensor]) {
1073        for param in parameters {
1074            param.zero_grad();
1075        }
1076
1077        // Also clear gradient tracking gradient map
1078        crate::gradtrack::clear_gradients();
1079    }
1080
1081    /// Get the current learning rate
1082    ///
1083    /// Returns the current learning rate used for parameter updates.
1084    ///
1085    /// # Returns
1086    ///
1087    /// Current learning rate as f32
1088    fn learning_rate(&self) -> f32 {
1089        self.config.learning_rate
1090    }
1091
1092    /// Set the learning rate for all parameters
1093    ///
1094    /// Updates the learning rate for all parameters in the optimizer.
1095    /// This allows dynamic learning rate scheduling during training.
1096    ///
1097    /// # Arguments
1098    ///
1099    /// * `lr` - New learning rate value
1100    fn set_learning_rate(&mut self, lr: f32) {
1101        self.config.learning_rate = lr;
1102    }
1103}
1104
1105// Adam is automatically Send + Sync since it no longer contains raw pointers
1106// and all fields are Send + Sync (HashMap, usize, AdamConfig)
1107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::core::Tensor;

    /// Test Adam optimizer creation with default configuration
    ///
    /// Verifies that the optimizer is created correctly with default hyperparameters
    /// and that parameter states are properly initialized.
    #[test]
    fn test_adam_creation() {
        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let bias = Tensor::zeros(vec![3, 1]).with_requires_grad();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        assert_eq!(optimizer.learning_rate(), 1e-3);
        assert_eq!(optimizer.config.beta1, 0.9);
        assert_eq!(optimizer.config.beta2, 0.999);
        assert_eq!(optimizer.parameter_count(), 2);
    }

    /// Test Adam optimizer creation with custom configuration
    ///
    /// Verifies that custom hyperparameters are properly set and that
    /// the optimizer configuration matches the provided values.
    #[test]
    fn test_adam_with_config() {
        let weight = Tensor::ones(vec![5, 5]).with_requires_grad();

        let config = AdamConfig {
            learning_rate: 1e-4,
            beta1: 0.95,
            beta2: 0.9999,
            weight_decay: 1e-5,
            amsgrad: true,
            ..Default::default()
        };

        let mut optimizer = Adam::with_config(config);
        optimizer.add_parameter(&weight);

        assert_eq!(optimizer.learning_rate(), 1e-4);
        assert_eq!(optimizer.config.beta1, 0.95);
        assert_eq!(optimizer.config.beta2, 0.9999);
        assert_eq!(optimizer.config.weight_decay, 1e-5);
        assert!(optimizer.config.amsgrad);
    }

    /// Test Adam optimizer creation with custom learning rate
    ///
    /// Verifies that the convenience constructor properly sets the learning rate
    /// while maintaining default values for other hyperparameters.
    #[test]
    fn test_adam_with_learning_rate() {
        let weight = Tensor::ones(vec![3, 3]).with_requires_grad();
        let mut optimizer = Adam::with_learning_rate(5e-4);
        optimizer.add_parameter(&weight);

        assert_eq!(optimizer.learning_rate(), 5e-4);
        assert_eq!(optimizer.config.beta1, 0.9); // Should use defaults for other params
    }

    /// Test Adam step without gradients
    ///
    /// Verifies that the optimizer handles the case where parameters have no
    /// gradients gracefully, leaving parameters unchanged.
    #[test]
    fn test_adam_step_without_gradients() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let original_data = weight.clone();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        optimizer.step(&mut [&mut weight]); // Should not update without gradients

        // Parameters should remain unchanged without gradients
        // Raw-buffer comparison: with no gradient present the data must be untouched.
        for i in 0..weight.size() {
            unsafe {
                assert_eq!(*weight.as_ptr().add(i), *original_data.as_ptr().add(i));
            }
        }
    }

    /// Test learning rate update functionality
    ///
    /// Verifies that the learning rate can be dynamically updated during training
    /// and that the optimizer uses the new learning rate for subsequent steps.
    #[test]
    fn test_learning_rate_update() {
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        assert_eq!(optimizer.learning_rate(), 1e-3);

        optimizer.set_learning_rate(1e-2);
        assert_eq!(optimizer.learning_rate(), 1e-2);
    }

    /// Test gradient zeroing functionality
    ///
    /// Verifies that zero_grad properly clears accumulated gradients for all
    /// parameters and the global autograd state.
    #[test]
    fn test_zero_grad() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();

        // Check that zero_grad clears gradients
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
        optimizer.zero_grad(&mut [&mut weight]);

        // After zero_grad, there should be no accumulated gradients
        assert!(weight.grad_by_value().is_none());
    }

    /// Test requires_grad assertion
    ///
    /// Verifies that the optimizer correctly panics when parameters do not
    /// have requires_grad set to true, ensuring proper gradient tracking.
    #[test]
    #[should_panic(expected = "Parameter must require gradients")]
    fn test_adam_requires_grad_assertion() {
        let weight = Tensor::ones(vec![2, 2]); // No requires_grad
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
    }

    /// Test AdamConfig default values
    ///
    /// Verifies that the default configuration matches PyTorch conventions
    /// and provides optimal settings for most training scenarios.
    #[test]
    fn test_adam_config_default() {
        let config = AdamConfig::default();

        assert_eq!(config.learning_rate, 1e-3);
        assert_eq!(config.beta1, 0.9);
        assert_eq!(config.beta2, 0.999);
        assert_eq!(config.eps, 1e-8);
        assert_eq!(config.weight_decay, 0.0);
        assert!(!config.amsgrad);
    }

    /// Test ParameterState creation and initialization
    ///
    /// Verifies that parameter states are properly initialized with zero tensors
    /// and that the step count starts at zero.
    #[test]
    fn test_parameter_state_creation() {
        let state = ParameterState::new(&[3, 4]);

        assert_eq!(state.m.shape().dims, vec![3, 4]);
        assert_eq!(state.v.shape().dims, vec![3, 4]);
        assert!(state.v_hat_max.is_none());
        assert_eq!(state.step, 0);

        // Verify tensors are zero-initialized
        for i in 0..state.m.size() {
            unsafe {
                assert_eq!(*state.m.as_ptr().add(i), 0.0);
                assert_eq!(*state.v.as_ptr().add(i), 0.0);
            }
        }
    }

    /// Test parameter linking functionality
    ///
    /// Verifies that parameters can be linked and unlinked from the optimizer.
    #[test]
    fn test_parameter_linking() {
        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let bias = Tensor::zeros(vec![3]).with_requires_grad();

        let mut optimizer = Adam::new();

        // Initially no parameters linked
        assert_eq!(optimizer.parameter_count(), 0);
        assert!(!optimizer.is_parameter_linked(&weight));
        assert!(!optimizer.is_parameter_linked(&bias));

        // Link weight
        optimizer.add_parameter(&weight);
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight));
        assert!(!optimizer.is_parameter_linked(&bias));

        // Link bias
        optimizer.add_parameter(&bias);
        assert_eq!(optimizer.parameter_count(), 2);
        assert!(optimizer.is_parameter_linked(&weight));
        assert!(optimizer.is_parameter_linked(&bias));

        // Unlink weight
        let was_linked = optimizer.unlink_parameter(&weight);
        assert!(was_linked);
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(!optimizer.is_parameter_linked(&weight));
        assert!(optimizer.is_parameter_linked(&bias));

        // Clear all states
        optimizer.clear_states();
        assert_eq!(optimizer.parameter_count(), 0);
        assert!(!optimizer.is_parameter_linked(&weight));
        assert!(!optimizer.is_parameter_linked(&bias));
    }

    /// Test parameter linking with multiple parameters at once
    ///
    /// Verifies that multiple parameters can be linked simultaneously.
    #[test]
    fn test_add_multiple_parameters() {
        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let bias = Tensor::zeros(vec![3]).with_requires_grad();
        let weight2 = Tensor::ones(vec![3, 2]).with_requires_grad();

        let mut optimizer = Adam::new();

        // Link multiple parameters at once
        optimizer.add_parameters(&[&weight, &bias, &weight2]);

        assert_eq!(optimizer.parameter_count(), 3);
        assert!(optimizer.is_parameter_linked(&weight));
        assert!(optimizer.is_parameter_linked(&bias));
        assert!(optimizer.is_parameter_linked(&weight2));
    }

    /// Test stepping with unlinked parameter
    ///
    /// Verifies that the optimizer panics when trying to step with an unlinked parameter.
    #[test]
    #[should_panic(expected = "Parameter must be linked to optimizer before stepping")]
    fn test_step_with_unlinked_parameter() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();

        // Don't link the parameter
        optimizer.step(&mut [&mut weight]); // Should panic
    }

    /// Test optimizer with actual gradients
    ///
    /// Verifies that the optimizer properly updates parameters when gradients are present.
    #[test]
    fn test_optimizer_with_gradients() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let original_data = weight.clone();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        // Generate some gradients
        // loss = sum(2 * weight), so every element receives a non-zero gradient.
        let output = weight.mul_scalar(2.0);
        let mut loss = output.sum();
        loss.backward(None);

        // Step should update parameters
        optimizer.step(&mut [&mut weight]);

        // Parameters should have changed
        let mut changed = false;
        for i in 0..weight.size() {
            unsafe {
                if (*weight.as_ptr().add(i) - *original_data.as_ptr().add(i)).abs() > 1e-6 {
                    changed = true;
                    break;
                }
            }
        }
        assert!(
            changed,
            "Parameters should have been updated by optimizer step"
        );
    }

    /// Test optimizer with multiple parameters and gradients
    ///
    /// Verifies that the optimizer works correctly with multiple parameters.
    #[test]
    fn test_optimizer_multiple_parameters() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut bias = Tensor::zeros(vec![2, 2]).with_requires_grad(); // Same shape as weight

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        // Generate gradients for both parameters
        let output = weight.mul_scalar(2.0).add_tensor(&bias);
        let mut loss = output.sum();
        loss.backward(None);

        // Step should update both parameters
        optimizer.step(&mut [&mut weight, &mut bias]);

        // Both parameters should have gradients
        assert!(weight.grad_by_value().is_some());
        assert!(bias.grad_by_value().is_some());
    }

    /// Test optimizer with custom configuration and multiple steps
    ///
    /// Verifies that the optimizer works correctly with custom configuration over multiple steps.
    #[test]
    fn test_optimizer_custom_config_multiple_steps() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();

        let config = AdamConfig {
            learning_rate: 1e-4,
            beta1: 0.95,
            beta2: 0.9999,
            weight_decay: 1e-5,
            amsgrad: true,
            ..Default::default()
        };

        let mut optimizer = Adam::with_config(config);
        optimizer.add_parameter(&weight);

        // Multiple training steps
        for _ in 0..5 {
            // Generate gradients
            let output = weight.mul_scalar(2.0);
            let mut loss = output.sum();
            loss.backward(None);

            // Step
            optimizer.step(&mut [&mut weight]);
            optimizer.zero_grad(&mut [&mut weight]);
        }

        // Should complete without errors
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight));
    }

    /// Test optimizer with different tensor shapes
    ///
    /// Verifies that the optimizer works correctly with various tensor shapes.
    #[test]
    fn test_optimizer_different_shapes() {
        let shapes = vec![
            vec![1],       // Scalar
            vec![3],       // 1D
            vec![2, 2],    // 2D square
            vec![2, 3],    // 2D rectangular
            vec![1, 1, 3], // 3D
            vec![2, 2, 2], // 3D cube
        ];

        for shape in shapes {
            let mut tensor = Tensor::ones(shape.clone()).with_requires_grad();
            let mut optimizer = Adam::new();
            optimizer.add_parameter(&tensor);

            // Generate gradients
            let output = tensor.mul_scalar(2.0);
            let mut loss = output.sum();
            loss.backward(None);

            // Step should work for all shapes
            optimizer.step(&mut [&mut tensor]);

            // Verify tensor is still valid
            assert_eq!(tensor.shape().dims, shape);
            assert!(tensor.requires_grad());
        }
    }

    /// Test double parameter linking
    ///
    /// Verifies that linking the same parameter twice doesn't create duplicate states.
    #[test]
    fn test_double_parameter_linking() {
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();

        // Link parameter twice
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&weight);

        // Should only have one state
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight));
    }

    /// Test unlink non-linked parameter
    ///
    /// Verifies that unlinking a non-linked parameter returns false.
    #[test]
    fn test_unlink_non_linked_parameter() {
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();

        // Try to unlink a parameter that was never linked
        let was_linked = optimizer.unlink_parameter(&weight);
        assert!(!was_linked);
    }

    /// Test parameter shape validation
    ///
    /// Verifies that parameters with different shapes can be linked correctly
    /// and maintain their shape information in the optimizer state.
    #[test]
    fn test_parameter_shape_validation() {
        let weight_1d = Tensor::ones(vec![5]).with_requires_grad();
        let weight_2d = Tensor::ones(vec![3, 4]).with_requires_grad();
        let weight_3d = Tensor::ones(vec![2, 3, 4]).with_requires_grad();

        let mut optimizer = Adam::new();

        // Test linking parameters with different shapes
        optimizer.add_parameter(&weight_1d);
        optimizer.add_parameter(&weight_2d);
        optimizer.add_parameter(&weight_3d);

        assert_eq!(optimizer.parameter_count(), 3);
        assert!(optimizer.is_parameter_linked(&weight_1d));
        assert!(optimizer.is_parameter_linked(&weight_2d));
        assert!(optimizer.is_parameter_linked(&weight_3d));

        // Test that each parameter maintains its shape after linking
        // Direct access to the private `states` map is possible because this
        // child module sees the parent's private items.
        let state_1d = &optimizer.states[&weight_1d.id()];
        let state_2d = &optimizer.states[&weight_2d.id()];
        let state_3d = &optimizer.states[&weight_3d.id()];

        assert_eq!(state_1d.m.shape().dims, vec![5]);
        assert_eq!(state_2d.m.shape().dims, vec![3, 4]);
        assert_eq!(state_3d.m.shape().dims, vec![2, 3, 4]);
    }

    /// Test parameter relinking with shape consistency
    ///
    /// Verifies that when re-linking parameters (e.g., after deserialization),
    /// shape information is correctly maintained.
    #[test]
    fn test_parameter_relinking_shape_consistency() {
        // Create initial parameter and optimizer
        let mut weight_original = Tensor::ones(vec![3, 3]).with_requires_grad();
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight_original);

        // Perform a step to create state
        let output = weight_original.mul_scalar(2.0);
        let mut loss = output.sum();
        loss.backward(None);
        optimizer.step(&mut [&mut weight_original]);

        // Verify state was created with correct shape
        let original_state = &optimizer.states[&weight_original.id()];
        assert_eq!(original_state.m.shape().dims, vec![3, 3]);
        assert_eq!(original_state.v.shape().dims, vec![3, 3]);

        // Create new parameter with same shape (will get different ID)
        let weight_new = Tensor::ones(vec![3, 3]).with_requires_grad();

        // This should create a new state since it's a different parameter
        optimizer.add_parameter(&weight_new);

        // Should now have 2 states
        assert_eq!(optimizer.parameter_count(), 2);

        // Both should be linked with correct shapes
        assert!(optimizer.is_parameter_linked(&weight_original));
        assert!(optimizer.is_parameter_linked(&weight_new));

        let new_state = &optimizer.states[&weight_new.id()];
        assert_eq!(new_state.m.shape().dims, vec![3, 3]);
        assert_eq!(new_state.v.shape().dims, vec![3, 3]);
    }

    /// Test large parameter count handling
    ///
    /// Verifies that the optimizer can handle many parameters efficiently
    /// and correctly manage linking/unlinking operations.
    #[test]
    fn test_large_parameter_count() {
        let mut optimizer = Adam::new();
        let mut params = Vec::new();

        // Create 50 parameters of different shapes
        for i in 1..=50 {
            let param = Tensor::ones(vec![i]).with_requires_grad();
            optimizer.add_parameter(&param);
            params.push(param);
        }

        assert_eq!(optimizer.parameter_count(), 50);

        // Verify all parameters are linked
        for param in &params {
            assert!(optimizer.is_parameter_linked(param));
        }

        // Test unlinking some parameters
        // take(25).step_by(2) visits indices 0, 2, ..., 24 -> 13 parameters unlinked.
        for param in params.iter().take(25).step_by(2) {
            assert!(optimizer.unlink_parameter(param));
        }

        assert_eq!(optimizer.parameter_count(), 37); // 50 - 13 = 37

        // Verify correct parameters are unlinked
        for (i, param) in params.iter().enumerate() {
            if i < 25 && i % 2 == 0 {
                assert!(!optimizer.is_parameter_linked(param));
            } else {
                assert!(optimizer.is_parameter_linked(param));
            }
        }
    }

    /// Test clear_states functionality
    ///
    /// Verifies that clearing all states works correctly and allows
    /// re-adding parameters afterwards.
    #[test]
    fn test_clear_states_functionality() {
        let weight1 = Tensor::ones(vec![2, 2]).with_requires_grad();
        let weight2 = Tensor::ones(vec![3, 3]).with_requires_grad();
        let weight3 = Tensor::ones(vec![4, 4]).with_requires_grad();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight1);
        optimizer.add_parameter(&weight2);
        optimizer.add_parameter(&weight3);

        assert_eq!(optimizer.parameter_count(), 3);

        // Clear all states
        optimizer.clear_states();

        assert_eq!(optimizer.parameter_count(), 0);
        assert!(!optimizer.is_parameter_linked(&weight1));
        assert!(!optimizer.is_parameter_linked(&weight2));
        assert!(!optimizer.is_parameter_linked(&weight3));

        // Should be able to re-add parameters after clearing
        optimizer.add_parameter(&weight1);
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight1));
    }
}