train_station/optimizers/adam/mod.rs

1//! Adam optimizer implementation for neural network training
2//!
3//! This module provides the Adam optimization algorithm with PyTorch-compatible interface.
4//! Adam combines the benefits of AdaGrad and RMSprop by using adaptive learning rates
5//! with momentum for efficient training of neural networks.
6//!
7//! # Features
8//!
9//! - **Adaptive Learning Rates**: Per-parameter learning rates based on gradient history
10//! - **Momentum Integration**: Combines momentum with adaptive learning rates
11//! - **Bias Correction**: Corrects for initialization bias in moment estimates
12//! - **Weight Decay**: Optional L2 regularization support
13//! - **AMSGrad Variant**: Optional AMSGrad for improved convergence stability
14//! - **SIMD Optimization**: AVX2-optimized parameter updates for maximum performance
15//! - **Thread Safety**: Send + Sync implementation for multi-threaded training
//! - **ID-Based Parameter Linking**: Parameters are linked by tensor ID and updated through mutable references
17//!
18//! # Thread Safety
19//!
//! The optimizer itself is `Send + Sync`: it stores only per-parameter state keyed by
//! tensor ID and holds no raw pointers. The `step` method takes mutable references to
//! the parameters being updated, so the borrow checker guarantees exclusive access
//! during an update and no internal locking is required. When training from multiple
//! threads, synchronize access to the parameter tensors themselves (for example by
//! confining updates to a single thread) before calling `step`.
33//!
34//! # Algorithm
35//!
36//! Adam implements the following update rule for each parameter theta:
37//!
38//! ```text
39//! m_t = beta1 * m_{t-1} + (1 - beta1) * grad_theta_t
40//! v_t = beta2 * v_{t-1} + (1 - beta2) * (grad_theta_t)^2
41//! m_hat_t = m_t / (1 - beta1^t)
42//! v_hat_t = v_t / (1 - beta2^t)
43//! theta_{t+1} = theta_t - lr * m_hat_t / (sqrt(v_hat_t) + eps)
44//! ```
45//!
46//! Where:
47//! - lr is the learning rate
48//! - beta1, beta2 are exponential decay rates for moment estimates
49//! - eps is a small constant for numerical stability
50//! - m_t, v_t are biased first and second moment estimates
51//! - m_hat_t, v_hat_t are bias-corrected moment estimates
52//!
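//! # Examples
//!
//! A minimal single-parameter training step, shown as a sketch rather than a doctest;
//! the `use` paths are assumptions about the crate's public re-exports and may need to
//! be adjusted to the actual module layout.
//!
//! ```ignore
//! use train_station::optimizers::{Adam, Optimizer};
//! use train_station::tensor::core::Tensor;
//!
//! // Create a parameter that tracks gradients and link it to the optimizer.
//! let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
//! let mut optimizer = Adam::with_learning_rate(1e-3);
//! optimizer.add_parameter(&weight);
//!
//! // Forward pass, backward pass, then one Adam update.
//! let mut loss = weight.mul_scalar(2.0).sum();
//! loss.backward(None);
//! optimizer.step(&mut [&mut weight]);
//! optimizer.zero_grad(&mut [&mut weight]);
//! ```
//!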
53//! # Performance Characteristics
54//!
55//! - **SIMD Optimization**: Uses AVX2 instructions for 8x vectorization when available
56//! - **Memory Efficiency**: In-place updates with minimal temporary allocations
57//! - **Cache-Friendly**: Optimized memory access patterns for large parameter tensors
58//! - **Zero-Cost Abstractions**: Compile-time optimization with minimal runtime overhead
//! - **No Internal Locking**: Exclusive access during updates is enforced through mutable references
60//!
61//! # References
62//!
63//! - Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization.
64//! - PyTorch Adam implementation: <https://pytorch.org/docs/stable/generated/torch.optim.Adam.html>
65
66pub mod serialization;
67
68use super::Optimizer;
69use crate::tensor::core::Tensor;
70use std::collections::HashMap;
71
72// SIMD optimizations for performance-critical operations
73#[cfg(target_arch = "x86_64")]
74use std::arch::x86_64::*;
75
76/// Configuration for the Adam optimization algorithm
77///
78/// Contains all hyperparameters that control the behavior of Adam optimization.
79/// Default values follow PyTorch conventions for maximum compatibility and
80/// optimal convergence across a wide range of neural network architectures.
81///
82/// # Fields
83///
84/// * `learning_rate` - Step size for parameter updates (default: 1e-3)
85/// * `beta1` - Exponential decay rate for first moment estimates (default: 0.9)
86/// * `beta2` - Exponential decay rate for second moment estimates (default: 0.999)
87/// * `eps` - Small constant for numerical stability (default: 1e-8)
88/// * `weight_decay` - L2 regularization coefficient (default: 0.0)
89/// * `amsgrad` - Whether to use AMSGrad variant for improved stability (default: false)
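///
/// # Examples
///
/// A sketch of overriding a few hyperparameters while keeping the rest at their
/// defaults, using struct-update syntax as in this module's tests:
///
/// ```ignore
/// let config = AdamConfig {
///     learning_rate: 1e-4,
///     weight_decay: 1e-5,
///     ..Default::default()
/// };
/// let optimizer = Adam::with_config(config);
/// ```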
90#[derive(Debug, Clone)]
91pub struct AdamConfig {
92    /// Learning rate for parameter updates (default: 1e-3)
93    pub learning_rate: f32,
94    /// Exponential decay rate for first moment estimates (default: 0.9)  
95    pub beta1: f32,
96    /// Exponential decay rate for second moment estimates (default: 0.999)
97    pub beta2: f32,
98    /// Small constant for numerical stability (default: 1e-8)
99    pub eps: f32,
100    /// Weight decay coefficient for L2 regularization (default: 0.0)
101    pub weight_decay: f32,
102    /// Whether to use AMSGrad variant (default: false)
103    pub amsgrad: bool,
104}
105
106impl Default for AdamConfig {
107    fn default() -> Self {
108        Self {
109            learning_rate: 1e-3,
110            beta1: 0.9,
111            beta2: 0.999,
112            eps: 1e-8,
113            weight_decay: 0.0,
114            amsgrad: false,
115        }
116    }
117}
118
119/// Internal state tracking for a single parameter during Adam optimization
120///
121/// Stores the momentum and velocity buffers needed for Adam optimization,
122/// along with step count for bias correction and optional AMSGrad state.
123/// Memory layout is optimized for cache efficiency and SIMD operations.
124///
125/// # Fields
126///
127/// * `m` - First moment estimate (momentum buffer)
128/// * `v` - Second moment estimate (velocity buffer)
129/// * `v_hat_max` - Maximum of second moment estimates (for AMSGrad variant)
130/// * `step` - Step count for bias correction calculations
131#[derive(Debug)]
132struct ParameterState {
133    /// First moment estimate (momentum)
134    m: Tensor,
135    /// Second moment estimate (velocity)  
136    v: Tensor,
137    /// Maximum of second moment estimates (for AMSGrad)
138    v_hat_max: Option<Tensor>,
139    /// Step count for bias correction
140    step: usize,
141}
142
143impl ParameterState {
144    fn new(param_shape: &[usize]) -> Self {
145        Self {
146            m: Tensor::zeros(param_shape.to_vec()),
147            v: Tensor::zeros(param_shape.to_vec()),
148            v_hat_max: None,
149            step: 0,
150        }
151    }
152}
153
154/// Adam optimizer for neural network parameter optimization
155///
156/// Implements the Adam optimization algorithm with PyTorch-compatible interface.
157/// Provides adaptive learning rates with momentum for efficient training of neural networks.
158/// The optimizer maintains per-parameter state for momentum and velocity estimates,
159/// enabling adaptive learning rates that improve convergence across diverse architectures.
160///
161/// # Usage Pattern
162///
163/// The optimizer uses ID-based parameter linking for maximum flexibility and thread safety:
164/// - Parameters are linked to the optimizer via `add_parameter` or `add_parameters`
165/// - The `step` method takes mutable references to parameters for thread-safe updates
166/// - Parameter states are maintained by tensor ID, allowing for dynamic parameter management
167/// - Supports serialization and deserialization with parameter re-linking
168///
169/// # Dynamic Parameter Management
170///
171/// Parameters can be added, removed, or re-linked at runtime:
172/// - `add_parameter`: Link a single parameter
173/// - `add_parameters`: Link multiple parameters at once
174/// - `unlink_parameter`: Remove parameter state by ID
175/// - `clear_states`: Remove all parameter states
176/// - `is_parameter_linked`: Check if a parameter is linked
177///
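/// As a sketch (mirroring the unit tests below), a typical link/unlink sequence looks like:
///
/// ```ignore
/// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
/// let bias = Tensor::zeros(vec![3]).with_requires_grad();
///
/// let mut optimizer = Adam::new();
/// optimizer.add_parameters(&[&weight, &bias]);
/// assert_eq!(optimizer.parameter_count(), 2);
///
/// optimizer.unlink_parameter(&bias);
/// assert!(!optimizer.is_parameter_linked(&bias));
/// optimizer.clear_states();
/// ```
///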
178/// # Serialization Support
179///
180/// The optimizer supports full serialization and deserialization with state preservation:
181/// - Parameter states are saved with their shapes and insertion order for validation
182/// - After deserialization, use `relink_parameters` to restore saved states to new tensors
183/// - Parameters must be re-linked in the same chronological order they were originally added
184/// - Shape validation ensures consistency between saved and current parameters
185///
186/// # Features
187///
188/// - **ID-Based Parameter Linking**: Dynamic parameter management via tensor IDs
189/// - **Thread-Safe Step Method**: Takes mutable references for safe concurrent access
190/// - **Per-Parameter State**: Each parameter maintains its own momentum and velocity buffers
191/// - **Bias Correction**: Automatically corrects initialization bias in moment estimates
192/// - **Weight Decay**: Optional L2 regularization with efficient implementation
193/// - **AMSGrad Support**: Optional AMSGrad variant for improved convergence stability
194/// - **SIMD Optimization**: AVX2-optimized updates for maximum performance
195/// - **Full Serialization**: Complete state persistence and restoration
196///
197/// # Thread Safety
198///
199/// This type is thread-safe and can be shared between threads. The step method
200/// takes mutable references to parameters, ensuring exclusive access during updates.
201pub struct Adam {
202    /// Optimizer configuration
203    config: AdamConfig,
204    /// Parameter states indexed by tensor ID
205    states: HashMap<usize, ParameterState>,
206    /// Current global step count
207    step_count: usize,
208    /// Order in which parameters were added (for serialization/re-linking)
209    insertion_order: Vec<usize>,
210}
211
212impl Default for Adam {
213    fn default() -> Self {
214        Self::with_config(AdamConfig::default())
215    }
216}
217
218impl Adam {
219    /// Create a new Adam optimizer with default configuration
220    ///
221    /// Initializes an Adam optimizer with PyTorch-compatible default hyperparameters.
222    /// Parameters must be linked separately using `add_parameter` or `add_parameters`.
223    ///
224    /// # Returns
225    ///
226    /// A new Adam optimizer instance with default hyperparameters
227    pub fn new() -> Self {
228        Self::default()
229    }
230
231    /// Create a new Adam optimizer with custom configuration
232    ///
233    /// Allows full control over all Adam hyperparameters for specialized training
234    /// scenarios such as fine-tuning, transfer learning, or research applications.
235    /// Parameters must be linked separately using `add_parameter` or `add_parameters`.
236    ///
237    /// # Arguments
238    ///
239    /// * `config` - Adam configuration with custom hyperparameters
240    ///
241    /// # Returns
242    ///
243    /// A new Adam optimizer instance with the specified configuration
244    pub fn with_config(config: AdamConfig) -> Self {
245        Self {
246            config,
247            states: HashMap::new(),
248            step_count: 0,
249            insertion_order: Vec::new(),
250        }
251    }
252
253    /// Create a new Adam optimizer with custom learning rate
254    ///
255    /// A convenience constructor that allows setting only the learning rate while
256    /// using default values for all other hyperparameters. Parameters must be
257    /// linked separately using `add_parameter` or `add_parameters`.
258    ///
259    /// # Arguments
260    ///
261    /// * `learning_rate` - Learning rate for optimization
262    ///
263    /// # Returns
264    ///
265    /// A new Adam optimizer instance with the specified learning rate and default
266    /// values for all other hyperparameters
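    ///
    /// # Examples
    ///
    /// A one-line sketch:
    ///
    /// ```ignore
    /// // beta1, beta2, eps, weight_decay, and amsgrad keep their defaults.
    /// let optimizer = Adam::with_learning_rate(5e-4);
    /// ```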
267    pub fn with_learning_rate(learning_rate: f32) -> Self {
268        let config = AdamConfig {
269            learning_rate,
270            ..Default::default()
271        };
272        Self::with_config(config)
273    }
274
275    /// Add a single parameter to the optimizer
276    ///
277    /// Links a parameter to the optimizer by creating a new parameter state
278    /// indexed by the tensor's ID. The parameter must have `requires_grad` set to true.
279    ///
280    /// # Arguments
281    ///
282    /// * `parameter` - Reference to the tensor to link
283    ///
284    /// # Panics
285    ///
286    /// Panics if the parameter does not have `requires_grad` set to true
287    pub fn add_parameter(&mut self, parameter: &Tensor) {
288        assert!(
289            parameter.requires_grad(),
290            "Parameter must require gradients"
291        );
292
293        let param_id = parameter.id();
294        let param_shape = parameter.shape().dims.clone();
295
296        // Initialize state for this parameter if not already present
297        use std::collections::hash_map::Entry;
298        if let Entry::Vacant(entry) = self.states.entry(param_id) {
299            entry.insert(ParameterState::new(&param_shape));
300            self.insertion_order.push(param_id);
301        }
302    }
303
304    /// Add multiple parameters to the optimizer
305    ///
306    /// Links multiple parameters to the optimizer by creating parameter states
307    /// indexed by each tensor's ID. All parameters must have `requires_grad` set to true.
308    ///
309    /// # Arguments
310    ///
311    /// * `parameters` - Slice of references to tensors to link
312    ///
313    /// # Panics
314    ///
315    /// Panics if any parameter does not have `requires_grad` set to true
316    pub fn add_parameters(&mut self, parameters: &[&Tensor]) {
317        for parameter in parameters {
318            self.add_parameter(parameter);
319        }
320    }
321
322    /// Remove a parameter from the optimizer
323    ///
324    /// Unlinks a parameter by removing its state from the optimizer.
325    /// The parameter ID is used for identification.
326    ///
327    /// # Arguments
328    ///
329    /// * `parameter` - Reference to the tensor to unlink
330    ///
331    /// # Returns
332    ///
333    /// True if the parameter was linked and removed, false if it was not linked
334    pub fn unlink_parameter(&mut self, parameter: &Tensor) -> bool {
335        let param_id = parameter.id();
336        let was_linked = self.states.remove(&param_id).is_some();
337        if was_linked {
338            self.insertion_order.retain(|&id| id != param_id);
339        }
340        was_linked
341    }
342
343    /// Remove all parameter states from the optimizer
344    ///
345    /// Clears all parameter states, effectively unlinking all parameters.
346    /// This is useful for resetting the optimizer or preparing for parameter re-linking.
347    pub fn clear_states(&mut self) {
348        self.states.clear();
349        self.insertion_order.clear();
350    }
351
352    /// Check if a parameter is linked to the optimizer
353    ///
354    /// Returns true if the parameter has an associated state in the optimizer.
355    ///
356    /// # Arguments
357    ///
358    /// * `parameter` - Reference to the tensor to check
359    ///
360    /// # Returns
361    ///
362    /// True if the parameter is linked, false otherwise
363    pub fn is_parameter_linked(&self, parameter: &Tensor) -> bool {
364        let param_id = parameter.id();
365        self.states.contains_key(&param_id)
366    }
367
368    /// Get the number of linked parameters
369    ///
370    /// Returns the count of parameters currently linked to the optimizer.
371    ///
372    /// # Returns
373    ///
374    /// Number of linked parameters
375    pub fn parameter_count(&self) -> usize {
376        self.states.len()
377    }
378
379    /// Re-link parameters to saved optimizer states in chronological order
380    ///
381    /// After deserializing an optimizer, use this method to restore saved parameter states
382    /// to new tensors. Parameters must be provided in the same chronological order they
383    /// were originally added to the optimizer. Shape validation ensures parameter compatibility.
384    ///
385    /// # Arguments
386    ///
387    /// * `parameters` - Slice of parameter references in chronological order
388    ///
389    /// # Returns
390    ///
391    /// Result indicating success or failure with detailed error message
392    ///
393    /// # Panics
394    ///
395    /// Panics if any parameter does not have `requires_grad` set to true
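    ///
    /// # Examples
    ///
    /// A sketch of restoring state after deserialization. `load_optimizer_from_checkpoint`
    /// is a hypothetical helper standing in for whatever the `serialization` submodule
    /// provides; the key point is that parameters are passed in the order they were
    /// originally added.
    ///
    /// ```ignore
    /// // Recreate the parameters (e.g. from a checkpoint), in their original order.
    /// let weight = Tensor::ones(vec![3, 3]).with_requires_grad();
    /// let bias = Tensor::zeros(vec![3]).with_requires_grad();
    ///
    /// let mut optimizer: Adam = load_optimizer_from_checkpoint(); // hypothetical
    /// optimizer
    ///     .relink_parameters(&[&weight, &bias])
    ///     .expect("shapes and order must match the saved states");
    /// ```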
396    pub fn relink_parameters(&mut self, parameters: &[&Tensor]) -> Result<(), String> {
397        // Validate all parameters have requires_grad first
398        for (i, param) in parameters.iter().enumerate() {
399            if !param.requires_grad() {
400                return Err(format!("Parameter at index {} must require gradients", i));
401            }
402        }
403
404        // Check parameter count matches saved states
405        if parameters.len() != self.insertion_order.len() {
406            return Err(format!(
407                "Parameter count mismatch: expected {} parameters, got {}",
408                self.insertion_order.len(),
409                parameters.len()
410            ));
411        }
412
413        // Create new states map with parameter IDs mapped to saved states in chronological order
414        let mut new_states = HashMap::new();
415        let mut new_insertion_order = Vec::new();
416
417        for (i, param) in parameters.iter().enumerate() {
418            let new_param_id = param.id();
419            let old_param_id = self.insertion_order[i];
420
421            // Get the saved state for this position
422            let saved_state = self
423                .states
424                .get(&old_param_id)
425                .ok_or_else(|| format!("No saved state found for parameter at position {}", i))?;
426
427            // Validate shape matches
428            let param_shape = &param.shape().dims;
429            let saved_shape = &saved_state.m.shape().dims;
430            if param_shape != saved_shape {
431                return Err(format!(
432                    "Shape mismatch for parameter at position {}: expected {:?}, got {:?}",
433                    i, saved_shape, param_shape
434                ));
435            }
436
437            // Create new state for this parameter
438            let new_state = ParameterState {
439                m: saved_state.m.clone(),
440                v: saved_state.v.clone(),
441                v_hat_max: saved_state.v_hat_max.clone(),
442                step: saved_state.step,
443            };
444
445            new_states.insert(new_param_id, new_state);
446            new_insertion_order.push(new_param_id);
447        }
448
449        // Replace the states and insertion order
450        self.states = new_states;
451        self.insertion_order = new_insertion_order;
452
453        Ok(())
454    }
455
456    /// Get the current optimizer configuration
457    ///
458    /// Returns a reference to the current configuration, allowing inspection
459    /// of all hyperparameters without modification.
460    ///
461    /// # Returns
462    ///
463    /// Reference to the current Adam configuration
464    pub fn config(&self) -> &AdamConfig {
465        &self.config
466    }
467
468    /// Update a single parameter using Adam algorithm
469    ///
470    /// Implements the core Adam update rule with bias correction and optional AMSGrad.
471    /// Uses SIMD optimization when available for improved performance.
472    /// The parameter must be linked to the optimizer before calling this method.
473    ///
474    /// # Arguments
475    ///
476    /// * `param` - Mutable reference to the parameter tensor to update
477    ///
478    /// # Returns
479    ///
480    /// Result indicating success or failure of the parameter update
481    ///
482    /// # Panics
483    ///
484    /// Panics if the parameter is not linked to the optimizer
485    fn update_parameter(&mut self, param: &mut Tensor) -> Result<(), String> {
486        let param_id = param.id();
487
488        // Ensure parameter is linked
489        assert!(
490            self.states.contains_key(&param_id),
491            "Parameter must be linked to optimizer before stepping. Use add_parameter() first."
492        );
493
494        // Get parameter gradient
495        let grad = param
496            .grad_by_value()
497            .ok_or_else(|| format!("Parameter {} has no gradient", param_id))?;
498
499        // Get parameter state
500        let state = self
501            .states
502            .get_mut(&param_id)
503            .expect("Parameter state should exist after link check");
504
505        // Increment step count
506        state.step += 1;
        let step = state.step as f32; // Use this parameter's step count for bias correction
508
509        // Apply weight decay if enabled
510        let effective_grad = if self.config.weight_decay > 0.0 {
511            // L2 regularization: grad + weight_decay * param
512            let mut grad_with_decay = grad.clone();
513            Self::add_weight_decay(&mut grad_with_decay, param, self.config.weight_decay);
514            grad_with_decay
515        } else {
516            grad
517        };
518
519        // Update biased first moment estimate: m = beta1 * m + (1 - beta1) * grad
520        Self::update_momentum(&mut state.m, &effective_grad, self.config.beta1);
521
522        // Update biased second moment estimate: v = beta2 * v + (1 - beta2) * grad^2
523        Self::update_velocity(&mut state.v, &effective_grad, self.config.beta2);
524
525        // Compute bias-corrected first moment estimate
526        let bias_correction1 = 1.0 - (self.config.beta1 as f64).powf(step as f64);
527        let m_hat = Self::bias_correct(&state.m, bias_correction1 as f32);
528
529        // Compute bias-corrected second moment estimate
530        let bias_correction2 = 1.0 - (self.config.beta2 as f64).powf(step as f64);
531        let mut v_hat = Self::bias_correct(&state.v, bias_correction2 as f32);
532
533        // AMSGrad: use maximum of v_hat over time
534        if self.config.amsgrad {
535            if state.v_hat_max.is_none() {
536                state.v_hat_max = Some(v_hat.clone());
537            }
538            let v_hat_max = state.v_hat_max.as_mut().unwrap();
539            Self::element_wise_max(v_hat_max, &v_hat);
540            v_hat = v_hat_max.clone();
541        }
542
543        // Compute parameter update: param = param - lr * m_hat / (sqrt(v_hat) + eps)
544        Self::apply_adam_update(
545            param,
546            &m_hat,
547            &v_hat,
548            self.config.learning_rate,
549            self.config.eps,
550        );
551
552        Ok(())
553    }
554
555    /// Apply weight decay (L2 regularization) to gradient
556    ///
557    /// Adds `weight_decay * param` to the gradient in-place for memory efficiency.
558    /// This implements L2 regularization by modifying the gradient before the
559    /// Adam update step, equivalent to adding a regularization term to the loss.
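    ///
    /// Concretely, adding `(weight_decay / 2) * ||param||^2` to the loss contributes
    /// `weight_decay * param` to its gradient, which is exactly the term added here.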
560    ///
561    /// # Arguments
562    ///
563    /// * `grad` - Gradient tensor to modify in-place
564    /// * `param` - Parameter tensor for weight decay calculation
565    /// * `weight_decay` - Weight decay coefficient
566    #[inline]
567    fn add_weight_decay(grad: &mut Tensor, param: &Tensor, weight_decay: f32) {
568        assert_eq!(
569            grad.size(),
570            param.size(),
571            "Gradient and parameter size mismatch"
572        );
573
574        unsafe {
575            let grad_ptr = grad.as_mut_ptr();
576            let param_ptr = param.as_ptr();
577            let size = grad.size();
578
579            #[cfg(target_arch = "x86_64")]
580            {
581                if is_x86_feature_detected!("avx2")
582                    && grad.is_simd_aligned()
583                    && param.is_simd_aligned()
584                {
585                    Self::add_weight_decay_simd_avx2(grad_ptr, param_ptr, weight_decay, size);
586                    return;
587                }
588            }
589
590            // Scalar fallback
591            for i in 0..size {
592                *grad_ptr.add(i) += weight_decay * *param_ptr.add(i);
593            }
594        }
595    }
596
597    #[cfg(target_arch = "x86_64")]
598    #[inline]
599    unsafe fn add_weight_decay_simd_avx2(
600        grad_ptr: *mut f32,
601        param_ptr: *const f32,
602        weight_decay: f32,
603        size: usize,
604    ) {
605        let decay_vec = _mm256_set1_ps(weight_decay);
606        let simd_count = size / 8;
607
608        for i in 0..simd_count {
609            let offset = i * 8;
610            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));
611            let param_vec = _mm256_loadu_ps(param_ptr.add(offset));
612            let decay_term = _mm256_mul_ps(decay_vec, param_vec);
613            let result = _mm256_add_ps(grad_vec, decay_term);
614            _mm256_storeu_ps(grad_ptr.add(offset), result);
615        }
616
617        // Handle remaining elements
618        for i in (simd_count * 8)..size {
619            *grad_ptr.add(i) += weight_decay * *param_ptr.add(i);
620        }
621    }
622
623    /// Update momentum (first moment estimate)
624    ///
625    /// Implements the momentum update rule: `m = beta1 * m + (1 - beta1) * grad`
626    /// This computes the exponentially decaying average of gradients, providing
627    /// momentum-like behavior that helps accelerate convergence in relevant directions.
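    ///
    /// For example, with `beta1 = 0.9` and a constant gradient `g` starting from `m = 0`,
    /// one update gives `m = 0.1 * g`, a second gives `m = 0.19 * g`, and `m` approaches
    /// `g` geometrically as updates continue.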
628    ///
629    /// # Arguments
630    ///
631    /// * `momentum` - Momentum buffer to update in-place
632    /// * `grad` - Current gradient tensor
633    /// * `beta1` - Exponential decay rate for momentum
634    #[inline]
635    fn update_momentum(momentum: &mut Tensor, grad: &Tensor, beta1: f32) {
636        assert_eq!(
637            momentum.size(),
638            grad.size(),
639            "Momentum and gradient size mismatch"
640        );
641
642        let beta1_complement = 1.0 - beta1;
643
644        unsafe {
645            let m_ptr = momentum.as_mut_ptr();
646            let grad_ptr = grad.as_ptr();
647            let size = momentum.size();
648
649            #[cfg(target_arch = "x86_64")]
650            {
651                if is_x86_feature_detected!("avx2")
652                    && momentum.is_simd_aligned()
653                    && grad.is_simd_aligned()
654                {
655                    Self::update_momentum_simd_avx2(m_ptr, grad_ptr, beta1, beta1_complement, size);
656                    return;
657                }
658            }
659
660            // Scalar fallback
661            for i in 0..size {
662                *m_ptr.add(i) = beta1 * *m_ptr.add(i) + beta1_complement * *grad_ptr.add(i);
663            }
664        }
665    }
666
667    #[cfg(target_arch = "x86_64")]
668    #[inline]
669    unsafe fn update_momentum_simd_avx2(
670        m_ptr: *mut f32,
671        grad_ptr: *const f32,
672        beta1: f32,
673        beta1_complement: f32,
674        size: usize,
675    ) {
676        let beta1_vec = _mm256_set1_ps(beta1);
677        let beta1_comp_vec = _mm256_set1_ps(beta1_complement);
678        let simd_count = size / 8;
679
680        for i in 0..simd_count {
681            let offset = i * 8;
682            let m_vec = _mm256_loadu_ps(m_ptr.add(offset));
683            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));
684
685            let momentum_term = _mm256_mul_ps(beta1_vec, m_vec);
686            let gradient_term = _mm256_mul_ps(beta1_comp_vec, grad_vec);
687            let result = _mm256_add_ps(momentum_term, gradient_term);
688
689            _mm256_storeu_ps(m_ptr.add(offset), result);
690        }
691
692        // Handle remaining elements
693        for i in (simd_count * 8)..size {
694            *m_ptr.add(i) = beta1 * *m_ptr.add(i) + beta1_complement * *grad_ptr.add(i);
695        }
696    }
697
698    /// Update velocity (second moment estimate)
699    ///
700    /// Implements the velocity update rule: `v = beta2 * v + (1 - beta2) * grad^2`
701    /// This computes the exponentially decaying average of squared gradients,
702    /// providing adaptive learning rates that scale inversely with gradient magnitude.
703    ///
704    /// # Arguments
705    ///
706    /// * `velocity` - Velocity buffer to update in-place
707    /// * `grad` - Current gradient tensor
708    /// * `beta2` - Exponential decay rate for velocity
709    #[inline]
710    fn update_velocity(velocity: &mut Tensor, grad: &Tensor, beta2: f32) {
711        assert_eq!(
712            velocity.size(),
713            grad.size(),
714            "Velocity and gradient size mismatch"
715        );
716
717        let beta2_complement = 1.0 - beta2;
718
719        unsafe {
720            let v_ptr = velocity.as_mut_ptr();
721            let grad_ptr = grad.as_ptr();
722            let size = velocity.size();
723
724            #[cfg(target_arch = "x86_64")]
725            {
726                if is_x86_feature_detected!("avx2")
727                    && velocity.is_simd_aligned()
728                    && grad.is_simd_aligned()
729                {
730                    Self::update_velocity_simd_avx2(v_ptr, grad_ptr, beta2, beta2_complement, size);
731                    return;
732                }
733            }
734
735            // Scalar fallback
736            for i in 0..size {
737                let grad_val = *grad_ptr.add(i);
738                *v_ptr.add(i) = beta2 * *v_ptr.add(i) + beta2_complement * grad_val * grad_val;
739            }
740        }
741    }
742
743    #[cfg(target_arch = "x86_64")]
744    #[inline]
745    unsafe fn update_velocity_simd_avx2(
746        v_ptr: *mut f32,
747        grad_ptr: *const f32,
748        beta2: f32,
749        beta2_complement: f32,
750        size: usize,
751    ) {
752        let beta2_vec = _mm256_set1_ps(beta2);
753        let beta2_comp_vec = _mm256_set1_ps(beta2_complement);
754        let simd_count = size / 8;
755
756        for i in 0..simd_count {
757            let offset = i * 8;
758            let v_vec = _mm256_loadu_ps(v_ptr.add(offset));
759            let grad_vec = _mm256_loadu_ps(grad_ptr.add(offset));
760
761            let velocity_term = _mm256_mul_ps(beta2_vec, v_vec);
762            let grad_squared = _mm256_mul_ps(grad_vec, grad_vec);
763            let gradient_term = _mm256_mul_ps(beta2_comp_vec, grad_squared);
764            let result = _mm256_add_ps(velocity_term, gradient_term);
765
766            _mm256_storeu_ps(v_ptr.add(offset), result);
767        }
768
769        // Handle remaining elements
770        for i in (simd_count * 8)..size {
771            let grad_val = *grad_ptr.add(i);
772            *v_ptr.add(i) = beta2 * *v_ptr.add(i) + beta2_complement * grad_val * grad_val;
773        }
774    }
775
776    /// Apply bias correction to moment estimates
777    ///
778    /// Returns `tensor / (1 - beta^step)` for unbiasing the moment estimates.
779    /// This correction is necessary because the moment estimates are initialized
780    /// to zero, creating a bias towards zero that becomes negligible as training progresses.
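    ///
    /// For example, with `beta1 = 0.9` the correction factor is `1 / (1 - 0.9^1) = 10`
    /// at step 1, roughly `1 / (1 - 0.9^10) ≈ 1.54` at step 10, and approaches 1 as
    /// training progresses.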
781    ///
782    /// # Arguments
783    ///
784    /// * `tensor` - Moment estimate tensor to correct
785    /// * `bias_correction` - Bias correction factor (1 - beta^step)
786    ///
787    /// # Returns
788    ///
789    /// Bias-corrected tensor with the same shape as input
790    #[inline]
791    fn bias_correct(tensor: &Tensor, bias_correction: f32) -> Tensor {
792        let mut result = tensor.clone();
793        let correction_factor = 1.0 / bias_correction;
794
795        unsafe {
796            let ptr = result.as_mut_ptr();
797            let size = result.size();
798
799            #[cfg(target_arch = "x86_64")]
800            {
801                if is_x86_feature_detected!("avx2") && result.is_simd_aligned() {
802                    Self::bias_correct_simd_avx2(ptr, correction_factor, size);
803                    return result;
804                }
805            }
806
807            // Scalar fallback
808            for i in 0..size {
809                *ptr.add(i) *= correction_factor;
810            }
811        }
812
813        result
814    }
815
816    #[cfg(target_arch = "x86_64")]
817    #[inline]
818    unsafe fn bias_correct_simd_avx2(ptr: *mut f32, correction_factor: f32, size: usize) {
819        let factor_vec = _mm256_set1_ps(correction_factor);
820        let simd_count = size / 8;
821
822        for i in 0..simd_count {
823            let offset = i * 8;
824            let data_vec = _mm256_loadu_ps(ptr.add(offset));
825            let result = _mm256_mul_ps(data_vec, factor_vec);
826            _mm256_storeu_ps(ptr.add(offset), result);
827        }
828
829        // Handle remaining elements
830        for i in (simd_count * 8)..size {
831            *ptr.add(i) *= correction_factor;
832        }
833    }
834
835    /// Element-wise maximum for AMSGrad
836    ///
837    /// Updates first tensor in-place with `max(first, second)` for AMSGrad variant.
838    /// This maintains a running maximum of the second moment estimates, preventing
839    /// the learning rate from increasing over time and improving convergence stability.
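    ///
    /// Because the running maximum is non-decreasing, the per-element denominator
    /// `sqrt(v_hat_max) + eps` used in the update never shrinks, so the effective step
    /// size for that element can only stay the same or decrease over time.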
840    ///
841    /// # Arguments
842    ///
843    /// * `first` - Tensor to update in-place with maximum values
844    /// * `second` - Tensor to compare against for maximum calculation
845    #[inline]
846    fn element_wise_max(first: &mut Tensor, second: &Tensor) {
847        assert_eq!(
848            first.size(),
849            second.size(),
850            "Tensor size mismatch for element-wise max"
851        );
852
853        unsafe {
854            let first_ptr = first.as_mut_ptr();
855            let second_ptr = second.as_ptr();
856            let size = first.size();
857
858            #[cfg(target_arch = "x86_64")]
859            {
860                if is_x86_feature_detected!("avx2")
861                    && first.is_simd_aligned()
862                    && second.is_simd_aligned()
863                {
864                    Self::element_wise_max_simd_avx2(first_ptr, second_ptr, size);
865                    return;
866                }
867            }
868
869            // Scalar fallback
870            for i in 0..size {
871                let a = *first_ptr.add(i);
872                let b = *second_ptr.add(i);
873                *first_ptr.add(i) = if a > b { a } else { b };
874            }
875        }
876    }
877
878    #[cfg(target_arch = "x86_64")]
879    #[inline]
880    unsafe fn element_wise_max_simd_avx2(first_ptr: *mut f32, second_ptr: *const f32, size: usize) {
881        let simd_count = size / 8;
882
883        for i in 0..simd_count {
884            let offset = i * 8;
885            let first_vec = _mm256_loadu_ps(first_ptr.add(offset));
886            let second_vec = _mm256_loadu_ps(second_ptr.add(offset));
887            let result = _mm256_max_ps(first_vec, second_vec);
888            _mm256_storeu_ps(first_ptr.add(offset), result);
889        }
890
891        // Handle remaining elements
892        for i in (simd_count * 8)..size {
893            let a = *first_ptr.add(i);
894            let b = *second_ptr.add(i);
895            *first_ptr.add(i) = if a > b { a } else { b };
896        }
897    }
898
899    /// Apply the final Adam parameter update
900    ///
901    /// Implements the core Adam update rule: `param = param - lr * m_hat / (sqrt(v_hat) + eps)`
902    /// This applies the bias-corrected momentum and velocity estimates to update
903    /// the parameter values, with adaptive learning rates that scale inversely
904    /// with the square root of the velocity estimates.
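    ///
    /// For a single element with `learning_rate = 1e-3`, `m_hat = 0.1`, `v_hat = 0.01`,
    /// and negligible `eps`, the step is `1e-3 * 0.1 / sqrt(0.01) = 1e-3`: when gradients
    /// are consistent, each element moves by roughly the learning rate regardless of the
    /// raw gradient scale.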
905    ///
906    /// # Arguments
907    ///
908    /// * `param` - Parameter tensor to update in-place
909    /// * `m_hat` - Bias-corrected first moment estimate
910    /// * `v_hat` - Bias-corrected second moment estimate
911    /// * `learning_rate` - Learning rate for the update
912    /// * `eps` - Small constant for numerical stability
913    #[inline]
914    fn apply_adam_update(
915        param: &mut Tensor,
916        m_hat: &Tensor,
917        v_hat: &Tensor,
918        learning_rate: f32,
919        eps: f32,
920    ) {
921        assert_eq!(
922            param.size(),
923            m_hat.size(),
924            "Parameter and momentum size mismatch"
925        );
926        assert_eq!(
927            param.size(),
928            v_hat.size(),
929            "Parameter and velocity size mismatch"
930        );
931
932        unsafe {
933            let param_ptr = param.as_mut_ptr();
934            let m_ptr = m_hat.as_ptr();
935            let v_ptr = v_hat.as_ptr();
936            let size = param.size();
937
938            #[cfg(target_arch = "x86_64")]
939            {
940                if is_x86_feature_detected!("avx2")
941                    && param.is_simd_aligned()
942                    && m_hat.is_simd_aligned()
943                    && v_hat.is_simd_aligned()
944                {
945                    Self::apply_adam_update_simd_avx2(
946                        param_ptr,
947                        m_ptr,
948                        v_ptr,
949                        learning_rate,
950                        eps,
951                        size,
952                    );
953                    return;
954                }
955            }
956
957            // Scalar fallback
958            for i in 0..size {
959                let m_val = *m_ptr.add(i);
960                let v_val = *v_ptr.add(i);
961                let denominator = v_val.sqrt() + eps;
962                *param_ptr.add(i) -= learning_rate * m_val / denominator;
963            }
964        }
965    }
966
967    #[cfg(target_arch = "x86_64")]
968    #[inline]
969    unsafe fn apply_adam_update_simd_avx2(
970        param_ptr: *mut f32,
971        m_ptr: *const f32,
972        v_ptr: *const f32,
973        learning_rate: f32,
974        eps: f32,
975        size: usize,
976    ) {
977        let lr_vec = _mm256_set1_ps(learning_rate);
978        let eps_vec = _mm256_set1_ps(eps);
979        let simd_count = size / 8;
980
981        for i in 0..simd_count {
982            let offset = i * 8;
983            let param_vec = _mm256_loadu_ps(param_ptr.add(offset));
984            let m_vec = _mm256_loadu_ps(m_ptr.add(offset));
985            let v_vec = _mm256_loadu_ps(v_ptr.add(offset));
986
987            // sqrt(v_hat) + eps
988            let sqrt_v = _mm256_sqrt_ps(v_vec);
989            let denominator = _mm256_add_ps(sqrt_v, eps_vec);
990
991            // lr * m_hat / denominator
992            let lr_m = _mm256_mul_ps(lr_vec, m_vec);
993            let update = _mm256_div_ps(lr_m, denominator);
994
995            // param - update
996            let result = _mm256_sub_ps(param_vec, update);
997            _mm256_storeu_ps(param_ptr.add(offset), result);
998        }
999
1000        // Handle remaining elements
1001        for i in (simd_count * 8)..size {
1002            let m_val = *m_ptr.add(i);
1003            let v_val = *v_ptr.add(i);
1004            let denominator = v_val.sqrt() + eps;
1005            *param_ptr.add(i) -= learning_rate * m_val / denominator;
1006        }
1007    }
1008}
1009
1010impl Optimizer for Adam {
1011    /// Perform a single optimization step
1012    ///
1013    /// Updates all provided parameters based on their accumulated gradients using the Adam algorithm.
1014    /// Each parameter is updated according to the Adam update rule with bias correction
1015    /// and optional AMSGrad variant if enabled. All parameters must be linked to the optimizer
1016    /// before calling this method.
1017    ///
1018    /// # Arguments
1019    ///
1020    /// * `parameters` - Mutable slice of parameter references to update
1021    ///
1022    /// # Thread Safety
1023    ///
1024    /// This method is thread-safe as it takes mutable references to parameters,
1025    /// ensuring exclusive access during updates.
1026    ///
1027    /// # Performance
1028    ///
1029    /// - Uses SIMD optimization (AVX2) when available for 8x vectorization
1030    /// - Processes parameters in sequence for optimal cache usage
1031    /// - Maintains per-parameter state for momentum and velocity estimates
1032    ///
1033    /// # Panics
1034    ///
1035    /// Panics if any parameter is not linked to the optimizer
1036    fn step(&mut self, parameters: &mut [&mut Tensor]) {
1037        self.step_count += 1;
1038
1039        for param in parameters {
1040            if let Err(e) = self.update_parameter(param) {
1041                eprintln!("Warning: Failed to update parameter: {}", e);
1042            }
1043        }
1044    }
1045
1046    /// Zero out all parameter gradients
1047    ///
1048    /// Clears accumulated gradients for all provided parameters. This should be called
1049    /// before each backward pass to prevent gradient accumulation across multiple
1050    /// forward/backward passes. Also clears the global autograd gradient map.
1051    ///
1052    /// # Arguments
1053    ///
1054    /// * `parameters` - Mutable slice of parameter references to clear gradients for
1055    ///
1056    /// # Performance
1057    ///
1058    /// - Efficiently clears gradients using optimized tensor operations
1059    /// - Clears both per-tensor gradients and global autograd state
1060    /// - Thread-safe as it takes mutable references to parameters
1061    fn zero_grad(&mut self, parameters: &mut [&mut Tensor]) {
1062        for param in parameters {
1063            param.zero_grad();
1064        }
1065
        // Also clear the global gradient-tracking (gradtrack) gradient map
1067        crate::gradtrack::clear_gradients();
1068    }
1069
1070    /// Get the current learning rate
1071    ///
1072    /// Returns the current learning rate used for parameter updates.
1073    ///
1074    /// # Returns
1075    ///
1076    /// Current learning rate as f32
1077    fn learning_rate(&self) -> f32 {
1078        self.config.learning_rate
1079    }
1080
1081    /// Set the learning rate for all parameters
1082    ///
1083    /// Updates the learning rate for all parameters in the optimizer.
1084    /// This allows dynamic learning rate scheduling during training.
1085    ///
1086    /// # Arguments
1087    ///
1088    /// * `lr` - New learning rate value
1089    fn set_learning_rate(&mut self, lr: f32) {
1090        self.config.learning_rate = lr;
1091    }
1092}
1093
1094// Adam is automatically Send + Sync since it no longer contains raw pointers
1095// and all fields are Send + Sync (HashMap, usize, AdamConfig)
1096
1097#[cfg(test)]
1098mod tests {
1099    use super::*;
1100    use crate::tensor::core::Tensor;
1101
1102    /// Test Adam optimizer creation with default configuration
1103    ///
1104    /// Verifies that the optimizer is created correctly with default hyperparameters
1105    /// and that parameter states are properly initialized.
1106    #[test]
1107    fn test_adam_creation() {
1108        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1109        let bias = Tensor::zeros(vec![3, 1]).with_requires_grad();
1110
1111        let mut optimizer = Adam::new();
1112        optimizer.add_parameter(&weight);
1113        optimizer.add_parameter(&bias);
1114
1115        assert_eq!(optimizer.learning_rate(), 1e-3);
1116        assert_eq!(optimizer.config.beta1, 0.9);
1117        assert_eq!(optimizer.config.beta2, 0.999);
1118        assert_eq!(optimizer.parameter_count(), 2);
1119    }
1120
1121    /// Test Adam optimizer creation with custom configuration
1122    ///
1123    /// Verifies that custom hyperparameters are properly set and that
1124    /// the optimizer configuration matches the provided values.
1125    #[test]
1126    fn test_adam_with_config() {
1127        let weight = Tensor::ones(vec![5, 5]).with_requires_grad();
1128
1129        let config = AdamConfig {
1130            learning_rate: 1e-4,
1131            beta1: 0.95,
1132            beta2: 0.9999,
1133            weight_decay: 1e-5,
1134            amsgrad: true,
1135            ..Default::default()
1136        };
1137
1138        let mut optimizer = Adam::with_config(config);
1139        optimizer.add_parameter(&weight);
1140
1141        assert_eq!(optimizer.learning_rate(), 1e-4);
1142        assert_eq!(optimizer.config.beta1, 0.95);
1143        assert_eq!(optimizer.config.beta2, 0.9999);
1144        assert_eq!(optimizer.config.weight_decay, 1e-5);
1145        assert!(optimizer.config.amsgrad);
1146    }
1147
1148    /// Test Adam optimizer creation with custom learning rate
1149    ///
1150    /// Verifies that the convenience constructor properly sets the learning rate
1151    /// while maintaining default values for other hyperparameters.
1152    #[test]
1153    fn test_adam_with_learning_rate() {
1154        let weight = Tensor::ones(vec![3, 3]).with_requires_grad();
1155        let mut optimizer = Adam::with_learning_rate(5e-4);
1156        optimizer.add_parameter(&weight);
1157
1158        assert_eq!(optimizer.learning_rate(), 5e-4);
1159        assert_eq!(optimizer.config.beta1, 0.9); // Should use defaults for other params
1160    }
1161
1162    /// Test Adam step without gradients
1163    ///
1164    /// Verifies that the optimizer handles the case where parameters have no
1165    /// gradients gracefully, leaving parameters unchanged.
1166    #[test]
1167    fn test_adam_step_without_gradients() {
1168        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1169        let original_data = weight.clone();
1170
1171        let mut optimizer = Adam::new();
1172        optimizer.add_parameter(&weight);
1173
1174        optimizer.step(&mut [&mut weight]); // Should not update without gradients
1175
1176        // Parameters should remain unchanged without gradients
1177        for i in 0..weight.size() {
1178            unsafe {
1179                assert_eq!(*weight.as_ptr().add(i), *original_data.as_ptr().add(i));
1180            }
1181        }
1182    }
1183
1184    /// Test learning rate update functionality
1185    ///
1186    /// Verifies that the learning rate can be dynamically updated during training
1187    /// and that the optimizer uses the new learning rate for subsequent steps.
1188    #[test]
1189    fn test_learning_rate_update() {
1190        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1191        let mut optimizer = Adam::new();
1192        optimizer.add_parameter(&weight);
1193
1194        assert_eq!(optimizer.learning_rate(), 1e-3);
1195
1196        optimizer.set_learning_rate(1e-2);
1197        assert_eq!(optimizer.learning_rate(), 1e-2);
1198    }
1199
1200    /// Test gradient zeroing functionality
1201    ///
1202    /// Verifies that zero_grad properly clears accumulated gradients for all
1203    /// parameters and the global autograd state.
1204    #[test]
1205    fn test_zero_grad() {
1206        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1207
1208        // Check that zero_grad clears gradients
1209        let mut optimizer = Adam::new();
1210        optimizer.add_parameter(&weight);
1211        optimizer.zero_grad(&mut [&mut weight]);
1212
1213        // After zero_grad, there should be no accumulated gradients
1214        assert!(weight.grad_by_value().is_none());
1215    }
1216
1217    /// Test requires_grad assertion
1218    ///
1219    /// Verifies that the optimizer correctly panics when parameters do not
1220    /// have requires_grad set to true, ensuring proper gradient tracking.
1221    #[test]
1222    #[should_panic(expected = "Parameter must require gradients")]
1223    fn test_adam_requires_grad_assertion() {
1224        let weight = Tensor::ones(vec![2, 2]); // No requires_grad
1225        let mut optimizer = Adam::new();
1226        optimizer.add_parameter(&weight);
1227    }
1228
1229    /// Test AdamConfig default values
1230    ///
1231    /// Verifies that the default configuration matches PyTorch conventions
1232    /// and provides optimal settings for most training scenarios.
1233    #[test]
1234    fn test_adam_config_default() {
1235        let config = AdamConfig::default();
1236
1237        assert_eq!(config.learning_rate, 1e-3);
1238        assert_eq!(config.beta1, 0.9);
1239        assert_eq!(config.beta2, 0.999);
1240        assert_eq!(config.eps, 1e-8);
1241        assert_eq!(config.weight_decay, 0.0);
1242        assert!(!config.amsgrad);
1243    }
1244
1245    /// Test ParameterState creation and initialization
1246    ///
1247    /// Verifies that parameter states are properly initialized with zero tensors
1248    /// and that the step count starts at zero.
1249    #[test]
1250    fn test_parameter_state_creation() {
1251        let state = ParameterState::new(&[3, 4]);
1252
1253        assert_eq!(state.m.shape().dims, vec![3, 4]);
1254        assert_eq!(state.v.shape().dims, vec![3, 4]);
1255        assert!(state.v_hat_max.is_none());
1256        assert_eq!(state.step, 0);
1257
1258        // Verify tensors are zero-initialized
1259        for i in 0..state.m.size() {
1260            unsafe {
1261                assert_eq!(*state.m.as_ptr().add(i), 0.0);
1262                assert_eq!(*state.v.as_ptr().add(i), 0.0);
1263            }
1264        }
1265    }
1266
1267    /// Test parameter linking functionality
1268    ///
1269    /// Verifies that parameters can be linked and unlinked from the optimizer.
1270    #[test]
1271    fn test_parameter_linking() {
1272        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1273        let bias = Tensor::zeros(vec![3]).with_requires_grad();
1274
1275        let mut optimizer = Adam::new();
1276
1277        // Initially no parameters linked
1278        assert_eq!(optimizer.parameter_count(), 0);
1279        assert!(!optimizer.is_parameter_linked(&weight));
1280        assert!(!optimizer.is_parameter_linked(&bias));
1281
1282        // Link weight
1283        optimizer.add_parameter(&weight);
1284        assert_eq!(optimizer.parameter_count(), 1);
1285        assert!(optimizer.is_parameter_linked(&weight));
1286        assert!(!optimizer.is_parameter_linked(&bias));
1287
1288        // Link bias
1289        optimizer.add_parameter(&bias);
1290        assert_eq!(optimizer.parameter_count(), 2);
1291        assert!(optimizer.is_parameter_linked(&weight));
1292        assert!(optimizer.is_parameter_linked(&bias));
1293
1294        // Unlink weight
1295        let was_linked = optimizer.unlink_parameter(&weight);
1296        assert!(was_linked);
1297        assert_eq!(optimizer.parameter_count(), 1);
1298        assert!(!optimizer.is_parameter_linked(&weight));
1299        assert!(optimizer.is_parameter_linked(&bias));
1300
1301        // Clear all states
1302        optimizer.clear_states();
1303        assert_eq!(optimizer.parameter_count(), 0);
1304        assert!(!optimizer.is_parameter_linked(&weight));
1305        assert!(!optimizer.is_parameter_linked(&bias));
1306    }
1307
1308    /// Test parameter linking with multiple parameters at once
1309    ///
1310    /// Verifies that multiple parameters can be linked simultaneously.
1311    #[test]
1312    fn test_add_multiple_parameters() {
1313        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1314        let bias = Tensor::zeros(vec![3]).with_requires_grad();
1315        let weight2 = Tensor::ones(vec![3, 2]).with_requires_grad();
1316
1317        let mut optimizer = Adam::new();
1318
1319        // Link multiple parameters at once
1320        optimizer.add_parameters(&[&weight, &bias, &weight2]);
1321
1322        assert_eq!(optimizer.parameter_count(), 3);
1323        assert!(optimizer.is_parameter_linked(&weight));
1324        assert!(optimizer.is_parameter_linked(&bias));
1325        assert!(optimizer.is_parameter_linked(&weight2));
1326    }
1327
1328    /// Test stepping with unlinked parameter
1329    ///
1330    /// Verifies that the optimizer panics when trying to step with an unlinked parameter.
1331    #[test]
1332    #[should_panic(expected = "Parameter must be linked to optimizer before stepping")]
1333    fn test_step_with_unlinked_parameter() {
1334        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1335        let mut optimizer = Adam::new();
1336
1337        // Don't link the parameter
1338        optimizer.step(&mut [&mut weight]); // Should panic
1339    }
1340
1341    /// Test optimizer with actual gradients
1342    ///
1343    /// Verifies that the optimizer properly updates parameters when gradients are present.
1344    #[test]
1345    fn test_optimizer_with_gradients() {
1346        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1347        let original_data = weight.clone();
1348
1349        let mut optimizer = Adam::new();
1350        optimizer.add_parameter(&weight);
1351
1352        // Generate some gradients
1353        let output = weight.mul_scalar(2.0);
1354        let mut loss = output.sum();
1355        loss.backward(None);
1356
1357        // Step should update parameters
1358        optimizer.step(&mut [&mut weight]);
1359
1360        // Parameters should have changed
1361        let mut changed = false;
1362        for i in 0..weight.size() {
1363            unsafe {
1364                if (*weight.as_ptr().add(i) - *original_data.as_ptr().add(i)).abs() > 1e-6 {
1365                    changed = true;
1366                    break;
1367                }
1368            }
1369        }
1370        assert!(
1371            changed,
1372            "Parameters should have been updated by optimizer step"
1373        );
1374    }
1375
1376    /// Test optimizer with multiple parameters and gradients
1377    ///
1378    /// Verifies that the optimizer works correctly with multiple parameters.
1379    #[test]
1380    fn test_optimizer_multiple_parameters() {
1381        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1382        let mut bias = Tensor::zeros(vec![2, 2]).with_requires_grad(); // Same shape as weight
1383
1384        let mut optimizer = Adam::new();
1385        optimizer.add_parameter(&weight);
1386        optimizer.add_parameter(&bias);
1387
1388        // Generate gradients for both parameters
1389        let output = weight.mul_scalar(2.0).add_tensor(&bias);
1390        let mut loss = output.sum();
1391        loss.backward(None);
1392
1393        // Step should update both parameters
1394        optimizer.step(&mut [&mut weight, &mut bias]);
1395
1396        // Both parameters should have gradients
1397        assert!(weight.grad_by_value().is_some());
1398        assert!(bias.grad_by_value().is_some());
1399    }
1400
1401    /// Test optimizer with custom configuration and multiple steps
1402    ///
1403    /// Verifies that the optimizer works correctly with custom configuration over multiple steps.
1404    #[test]
1405    fn test_optimizer_custom_config_multiple_steps() {
1406        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1407
1408        let config = AdamConfig {
1409            learning_rate: 1e-4,
1410            beta1: 0.95,
1411            beta2: 0.9999,
1412            weight_decay: 1e-5,
1413            amsgrad: true,
1414            ..Default::default()
1415        };
1416
1417        let mut optimizer = Adam::with_config(config);
1418        optimizer.add_parameter(&weight);
1419
1420        // Multiple training steps
1421        for _ in 0..5 {
1422            // Generate gradients
1423            let output = weight.mul_scalar(2.0);
1424            let mut loss = output.sum();
1425            loss.backward(None);
1426
1427            // Step
1428            optimizer.step(&mut [&mut weight]);
1429            optimizer.zero_grad(&mut [&mut weight]);
1430        }
1431
1432        // Should complete without errors
1433        assert_eq!(optimizer.parameter_count(), 1);
1434        assert!(optimizer.is_parameter_linked(&weight));
1435    }
1436
1437    /// Test optimizer with different tensor shapes
1438    ///
1439    /// Verifies that the optimizer works correctly with various tensor shapes.
1440    #[test]
1441    fn test_optimizer_different_shapes() {
1442        let shapes = vec![
1443            vec![1],       // Scalar
1444            vec![3],       // 1D
1445            vec![2, 2],    // 2D square
1446            vec![2, 3],    // 2D rectangular
1447            vec![1, 1, 3], // 3D
1448            vec![2, 2, 2], // 3D cube
1449        ];
1450
        for shape in shapes {
            let mut tensor = Tensor::ones(shape.clone()).with_requires_grad();
            let mut optimizer = Adam::new();
            optimizer.add_parameter(&tensor);

            // Generate gradients
            let output = tensor.mul_scalar(2.0);
            let mut loss = output.sum();
            loss.backward(None);

            // Step should work for all shapes
            optimizer.step(&mut [&mut tensor]);

            // Verify tensor is still valid
            assert_eq!(tensor.shape().dims, shape);
            assert!(tensor.requires_grad());
        }
    }

    /// Test double parameter linking
    ///
    /// Verifies that linking the same parameter twice doesn't create duplicate states.
    #[test]
    fn test_double_parameter_linking() {
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();

        // Link parameter twice
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&weight);

        // Should only have one state
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight));
    }

    /// Test unlink non-linked parameter
    ///
    /// Verifies that unlinking a non-linked parameter returns false.
    #[test]
    fn test_unlink_non_linked_parameter() {
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();

        // Try to unlink a parameter that was never linked
        let was_linked = optimizer.unlink_parameter(&weight);
        assert!(!was_linked);
    }

    /// Test parameter shape validation
    ///
    /// Verifies that parameters with different shapes can be linked correctly
    /// and maintain their shape information in the optimizer state.
    #[test]
    fn test_parameter_shape_validation() {
        let weight_1d = Tensor::ones(vec![5]).with_requires_grad();
        let weight_2d = Tensor::ones(vec![3, 4]).with_requires_grad();
        let weight_3d = Tensor::ones(vec![2, 3, 4]).with_requires_grad();

        let mut optimizer = Adam::new();

        // Test linking parameters with different shapes
        optimizer.add_parameter(&weight_1d);
        optimizer.add_parameter(&weight_2d);
        optimizer.add_parameter(&weight_3d);

        assert_eq!(optimizer.parameter_count(), 3);
        assert!(optimizer.is_parameter_linked(&weight_1d));
        assert!(optimizer.is_parameter_linked(&weight_2d));
        assert!(optimizer.is_parameter_linked(&weight_3d));

        // Test that each parameter maintains its shape after linking
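        // (per-parameter state is looked up by the tensor's unique id)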
        let state_1d = &optimizer.states[&weight_1d.id()];
        let state_2d = &optimizer.states[&weight_2d.id()];
        let state_3d = &optimizer.states[&weight_3d.id()];

        assert_eq!(state_1d.m.shape().dims, vec![5]);
        assert_eq!(state_2d.m.shape().dims, vec![3, 4]);
        assert_eq!(state_3d.m.shape().dims, vec![2, 3, 4]);
    }

    /// Test parameter relinking with shape consistency
    ///
    /// Verifies that when re-linking parameters (e.g., after deserialization),
    /// shape information is correctly maintained.
    #[test]
    fn test_parameter_relinking_shape_consistency() {
        // Create initial parameter and optimizer
        let mut weight_original = Tensor::ones(vec![3, 3]).with_requires_grad();
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight_original);

        // Perform a step to create state
        let output = weight_original.mul_scalar(2.0);
        let mut loss = output.sum();
        loss.backward(None);
        optimizer.step(&mut [&mut weight_original]);

        // Verify state was created with correct shape
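        // (m and v are the first- and second-moment buffers of the Adam update
        // rule, so each should be allocated with the parameter's shape)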
        let original_state = &optimizer.states[&weight_original.id()];
        assert_eq!(original_state.m.shape().dims, vec![3, 3]);
        assert_eq!(original_state.v.shape().dims, vec![3, 3]);

        // Create new parameter with same shape (will get different ID)
        let weight_new = Tensor::ones(vec![3, 3]).with_requires_grad();

        // This should create a new state since it's a different parameter
        optimizer.add_parameter(&weight_new);

        // Should now have 2 states
        assert_eq!(optimizer.parameter_count(), 2);

        // Both should be linked with correct shapes
        assert!(optimizer.is_parameter_linked(&weight_original));
        assert!(optimizer.is_parameter_linked(&weight_new));

        let new_state = &optimizer.states[&weight_new.id()];
        assert_eq!(new_state.m.shape().dims, vec![3, 3]);
        assert_eq!(new_state.v.shape().dims, vec![3, 3]);
    }

    /// Test large parameter count handling
    ///
    /// Verifies that the optimizer can handle many parameters efficiently
    /// and correctly manage linking/unlinking operations.
    #[test]
    fn test_large_parameter_count() {
        let mut optimizer = Adam::new();
        let mut params = Vec::new();

        // Create 50 parameters of different shapes
        for i in 1..=50 {
            let param = Tensor::ones(vec![i]).with_requires_grad();
            optimizer.add_parameter(&param);
            params.push(param);
        }

        assert_eq!(optimizer.parameter_count(), 50);

        // Verify all parameters are linked
        for param in &params {
            assert!(optimizer.is_parameter_linked(param));
        }

        // Unlink every other parameter among the first 25 (indices 0, 2, ..., 24,
        // i.e. 13 parameters in total)
        for param in params.iter().take(25).step_by(2) {
            assert!(optimizer.unlink_parameter(param));
        }

        assert_eq!(optimizer.parameter_count(), 37); // 50 - 13 = 37

        // Verify that exactly the even indices below 25 were unlinked
        for (i, param) in params.iter().enumerate() {
            if i < 25 && i % 2 == 0 {
                assert!(!optimizer.is_parameter_linked(param));
            } else {
                assert!(optimizer.is_parameter_linked(param));
            }
        }
    }

    /// Test clear_states functionality
    ///
    /// Verifies that clearing all states works correctly and allows
    /// re-adding parameters afterwards.
    #[test]
    fn test_clear_states_functionality() {
        let weight1 = Tensor::ones(vec![2, 2]).with_requires_grad();
        let weight2 = Tensor::ones(vec![3, 3]).with_requires_grad();
        let weight3 = Tensor::ones(vec![4, 4]).with_requires_grad();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight1);
        optimizer.add_parameter(&weight2);
        optimizer.add_parameter(&weight3);

        assert_eq!(optimizer.parameter_count(), 3);

        // Clear all states
        optimizer.clear_states();

        assert_eq!(optimizer.parameter_count(), 0);
        assert!(!optimizer.is_parameter_linked(&weight1));
        assert!(!optimizer.is_parameter_linked(&weight2));
        assert!(!optimizer.is_parameter_linked(&weight3));

        // Should be able to re-add parameters after clearing
        optimizer.add_parameter(&weight1);
        assert_eq!(optimizer.parameter_count(), 1);
        assert!(optimizer.is_parameter_linked(&weight1));
    }
}