// train_station/optimizers/adam/serialization.rs

//! Serialization support for Adam optimizer state and configuration
//!
//! This module implements serialization for the Adam optimizer, enabling model
//! checkpointing, state persistence, and transfer of optimizer state between environments.
//! Both a human-readable JSON format and a compact binary format are supported. The
//! optimizer's configuration, moment buffers, and step counts round-trip through either
//! format; the links to live parameter tensors are not serialized and must be
//! re-established after loading (see *Parameter Re-linking* below).
//!
//! # Purpose
//!
//! This module is the persistence layer for Adam optimizer training. It provides:
//! - **Model checkpointing**: Save and restore optimizer state during training
//! - **Training resumption**: Continue training from saved checkpoints
//! - **State transfer**: Move optimizer state between different environments
//! - **Debugging support**: Human-readable JSON format for state inspection
//! - **Performance**: Compact binary format for production use
//! - **Cross-platform use**: The same serialized representation is intended to load on other systems
//!
//! # Supported Components
//!
//! ## AdamConfig Serialization
//! - **Learning rate**: Base learning rate for parameter updates
//! - **Beta parameters**: Exponential decay rates for the first and second moment estimates (beta1, beta2)
//! - **Epsilon**: Numerical stability constant
//! - **Weight decay**: L2 regularization coefficient
//! - **AMSGrad flag**: Whether to use the AMSGrad variant
//!
//! ## Parameter State Serialization
//! - **Momentum buffers**: First moment estimates for each parameter
//! - **Velocity buffers**: Second moment estimates for each parameter
//! - **AMSGrad state**: Maximum velocity buffers when AMSGrad is enabled
//! - **Step counts**: Per-parameter step counts for bias correction
//! - **Shape information**: Carried implicitly by the serialized buffer tensors, which can be used to validate parameters during re-linking
//!
//! ## Global Optimizer State
//! - **Global step count**: Total optimization steps performed
//! - **Parameter insertion order**: Order used for consistent parameter re-linking
//! - **Configuration state**: Complete hyperparameter configuration
//! - **Field checks**: Required fields are checked for presence and type during deserialization
//!
//! # Serialization Formats
//!
//! ## JSON Format
//! - **Human-readable**: Easy to inspect and debug optimizer state
//! - **Cross-language**: Readable by any JSON parser
//! - **Configuration files**: Suitable for storing optimizer configurations
//! - **Debugging**: Clear structure for troubleshooting training issues
//! - **Interoperability**: Exchange optimizer state with external tools
//!
//! ## Binary Format
//! - **Compact storage**: Smaller files than the JSON representation
//! - **Fast I/O**: Intended for quick save/load operations
//! - **Performance**: Lower serialization overhead during training
//! - **Bandwidth efficiency**: Less data to transfer over the network
//! - **Production use**: Suited to high-throughput training workflows
//!
//! # Usage Patterns
//!
//! ## Basic Serialization
//! ```
//! use train_station::{Tensor, optimizers::Adam};
//! use train_station::serialization::Serializable;
//!
//! // Create optimizer with parameters
//! let weight = Tensor::ones(vec![10, 5]).with_requires_grad();
//! let mut optimizer = Adam::new();
//! optimizer.add_parameter(&weight);
//!
//! // Serialize optimizer state
//! let json = optimizer.to_json().unwrap();
//! let binary = optimizer.to_binary().unwrap();
//!
//! // Deserialize optimizer
//! let loaded_optimizer = Adam::from_json(&json).unwrap();
//! assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
//! ```
//!
//! ## Training Checkpointing
//! ```
//! use train_station::{Tensor, optimizers::Adam};
//! use train_station::serialization::{Serializable, Format};
//!
//! let weight = Tensor::randn(vec![100, 50], None).with_requires_grad();
//! let mut optimizer = Adam::new();
//! optimizer.add_parameter(&weight);
//!
//! // Training loop with checkpointing
//! for epoch in 0..10 {
//!     // ... training logic ...
//!
//!     // Save a checkpoint every 5 epochs
//!     if epoch % 5 == 0 {
//!         let temp_dir = std::env::temp_dir();
//!         let checkpoint_path = temp_dir.join(format!("checkpoint_epoch_{}.json", epoch));
//!         optimizer.save(&checkpoint_path, Format::Json).unwrap();
//!
//!         // Remove the temporary file created for this example
//!         std::fs::remove_file(&checkpoint_path).ok();
//!     }
//! }
//! ```
//!
//! ## Parameter Re-linking
//! ```
//! use train_station::{Tensor, optimizers::Adam};
//! use train_station::serialization::Serializable;
//!
//! // Original training setup
//! let weight = Tensor::ones(vec![5, 5]).with_requires_grad();
//! let bias = Tensor::zeros(vec![5]).with_requires_grad();
//! let mut optimizer = Adam::new();
//! optimizer.add_parameter(&weight);
//! optimizer.add_parameter(&bias);
//!
//! // Serialize optimizer state
//! let json = optimizer.to_json().unwrap();
//!
//! // Later: create new parameters with the same shapes
//! let new_weight = Tensor::ones(vec![5, 5]).with_requires_grad();
//! let new_bias = Tensor::zeros(vec![5]).with_requires_grad();
//!
//! // Restore the optimizer and re-link parameters (in the original insertion order)
//! let mut loaded_optimizer = Adam::from_json(&json).unwrap();
//! loaded_optimizer.relink_parameters(&[&new_weight, &new_bias]).unwrap();
//!
//! assert!(loaded_optimizer.is_parameter_linked(&new_weight));
//! assert!(loaded_optimizer.is_parameter_linked(&new_bias));
//! ```
//!
//! # Architecture Design
//!
//! ## Serialization Strategy
//! - **Unified interface**: Every serializable type here implements the `StructSerializable` trait
//! - **Type safety**: Fields are decoded into strongly typed Rust values
//! - **Field checks**: Required fields are checked for presence and type during deserialization
//! - **Error handling**: Errors identify the field that failed
//! - **Memory usage**: Tensors are serialized through the shared tensor serialization support rather than copied by hand
//!
//! ## Parameter State Management
//! - **ID-based tracking**: Parameters are tracked by unique tensor IDs
//! - **Shape validation**: Parameter compatibility is checked during re-linking
//! - **Insertion order**: Parameter order is preserved for consistent re-linking
//! - **State preservation**: Momentum and velocity buffers are preserved in full
//! - **AMSGrad support**: AMSGrad maximum-velocity buffers are serialized when present
//!
//! ## Performance Characteristics
//! - **Linear complexity**: Serialization cost grows with the number of parameters and the size of their buffers
//! - **Low overhead**: Each parameter's state is copied once into a serializable wrapper
//! - **File output**: Results can be written directly to files via `save`
//! - **Compression**: The binary output can be passed through a general-purpose compressor
//! - **Shared access**: Serialization only reads optimizer state (`&self`)
//!
//! # Thread Safety
//!
//! - **Read-only serialization**: Serialization takes `&self` and does not modify optimizer state
//! - **Concurrent reads**: Because serialization needs only shared access, several threads holding shared references can serialize the same optimizer (provided `Adam` is `Sync`)
//! - **Deserialization**: Deserialization creates a new optimizer instance
//! - **Parameter linking**: Re-linking mutates the optimizer and therefore requires exclusive access
//!
//! # Integration with Train Station
//!
//! This module builds on the rest of Train Station:
//! - **Tensor serialization**: Parameter state buffers are serialized with the standard tensor serialization
//! - **GradTrack compatibility**: Gradient-tracking requirements are maintained during re-linking
//! - **Device placement**: Whatever device information tensor serialization records is carried along with the buffers
//! - **Memory management**: Memory use follows Train Station's tensor patterns
//! - **Error handling**: Errors are reported with the crate's `SerializationError` type

170use super::{Adam, AdamConfig, ParameterState};
171use crate::serialization::{
172    FieldValue, FromFieldValue, SerializationError, SerializationResult, StructDeserializer,
173    StructSerializable, StructSerializer, ToFieldValue,
174};
175use crate::tensor::core::Tensor;
176use std::collections::HashMap;

// ===== AdamConfig Serialization =====

180impl StructSerializable for AdamConfig {
181    /// Convert AdamConfig to StructSerializer for comprehensive serialization
182    ///
183    /// This method serializes all Adam hyperparameters into a structured format suitable
184    /// for both JSON and binary serialization. Every field is essential for proper optimizer
185    /// reconstruction and training continuation. The serialization preserves exact floating-point
186    /// values and boolean flags to ensure identical behavior after deserialization.
187    ///
188    /// # Returns
189    ///
190    /// StructSerializer containing all configuration data with field names and values
191    ///
192    /// # Serialized Fields
193    ///
194    /// - **learning_rate**: Base learning rate for parameter updates
195    /// - **beta1**: Exponential decay rate for first moment estimates
196    /// - **beta2**: Exponential decay rate for second moment estimates  
197    /// - **eps**: Small constant for numerical stability in denominator
198    /// - **weight_decay**: L2 regularization coefficient
199    /// - **amsgrad**: Boolean flag for AMSGrad variant usage
200    ///
201    /// # Performance
202    ///
203    /// - **Time Complexity**: O(1) - Constant time field serialization
204    /// - **Memory Usage**: Minimal allocation for field storage
205    /// - **Precision**: Full floating-point precision preservation
206    fn to_serializer(&self) -> StructSerializer {
207        StructSerializer::new()
208            .field("learning_rate", &self.learning_rate)
209            .field("beta1", &self.beta1)
210            .field("beta2", &self.beta2)
211            .field("eps", &self.eps)
212            .field("weight_decay", &self.weight_decay)
213            .field("amsgrad", &self.amsgrad)
214    }
215
216    /// Create AdamConfig from StructDeserializer with full validation
217    ///
218    /// This method reconstructs an AdamConfig instance from serialized hyperparameters,
219    /// performing comprehensive validation to ensure all required fields are present
220    /// and contain valid values. The deserialization process maintains exact floating-point
221    /// precision and validates that all hyperparameters are within reasonable ranges.
222    ///
223    /// # Arguments
224    ///
225    /// * `deserializer` - StructDeserializer containing configuration field data
226    ///
227    /// # Returns
228    ///
229    /// Reconstructed AdamConfig instance on success, or SerializationError on failure
230    ///
231    /// # Required Fields
232    ///
233    /// All fields must be present in the deserializer:
234    /// - **learning_rate**: Must be a valid f32 value
235    /// - **beta1**: Must be a valid f32 value (typically 0.0-1.0)
236    /// - **beta2**: Must be a valid f32 value (typically 0.0-1.0)
237    /// - **eps**: Must be a valid f32 value (typically small positive)
238    /// - **weight_decay**: Must be a valid f32 value (typically 0.0 or small positive)
239    /// - **amsgrad**: Must be a valid boolean value
240    ///
241    /// # Errors
242    ///
243    /// Returns SerializationError if:
244    /// - Any required field is missing from the deserializer
245    /// - Any field contains invalid data type
246    /// - Field extraction fails for any reason
247    ///
248    /// # Performance
249    ///
250    /// - **Time Complexity**: O(1) - Constant time field extraction
251    /// - **Memory Usage**: Minimal allocation for configuration structure
252    /// - **Validation**: Comprehensive field presence and type validation
253    fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
254        Ok(AdamConfig {
255            learning_rate: deserializer.field("learning_rate")?,
256            beta1: deserializer.field("beta1")?,
257            beta2: deserializer.field("beta2")?,
258            eps: deserializer.field("eps")?,
259            weight_decay: deserializer.field("weight_decay")?,
260            amsgrad: deserializer.field("amsgrad")?,
261        })
262    }
263}
264
265impl ToFieldValue for AdamConfig {
266    /// Convert AdamConfig to FieldValue for embedding in larger structures
267    ///
268    /// This method converts the AdamConfig into a FieldValue::Object that can be
269    /// embedded as a field within larger serializable structures. This enables
270    /// AdamConfig to be serialized as part of more complex training configurations
271    /// or model checkpoints while maintaining its structured representation.
272    ///
273    /// # Returns
274    ///
275    /// FieldValue::Object containing all configuration data as key-value pairs
276    ///
277    /// # Object Structure
278    ///
279    /// The returned object contains these fields:
280    /// - "learning_rate": f32 value as FieldValue::F32
281    /// - "beta1": f32 value as FieldValue::F32
282    /// - "beta2": f32 value as FieldValue::F32
283    /// - "eps": f32 value as FieldValue::F32
284    /// - "weight_decay": f32 value as FieldValue::F32
285    /// - "amsgrad": bool value as FieldValue::Bool
286    ///
287    /// # Performance
288    ///
289    /// - **Time Complexity**: O(1) - Constant time field conversion
290    /// - **Memory Usage**: Allocates HashMap for field storage
291    /// - **Conversion**: Direct field-to-FieldValue conversion without copying
292    fn to_field_value(&self) -> FieldValue {
293        let serializer = self.to_serializer();
294        FieldValue::from_object(serializer.fields.into_iter().collect())
295    }
296}
297
298impl FromFieldValue for AdamConfig {
299    /// Create AdamConfig from FieldValue with comprehensive validation
300    ///
301    /// This method reconstructs an AdamConfig instance from a FieldValue::Object,
302    /// performing type validation and field extraction. It's designed to handle
303    /// AdamConfig instances that were embedded as fields within larger serializable
304    /// structures, ensuring proper error handling and detailed error messages.
305    ///
306    /// # Arguments
307    ///
308    /// * `value` - FieldValue containing configuration data (must be Object variant)
309    /// * `field_name` - Name of the field being deserialized for error context
310    ///
311    /// # Returns
312    ///
313    /// Reconstructed AdamConfig instance on success, or SerializationError on failure
314    ///
315    /// # Expected FieldValue Structure
316    ///
317    /// The FieldValue must be an Object variant containing:
318    /// - "learning_rate": Numeric field value
319    /// - "beta1": Numeric field value  
320    /// - "beta2": Numeric field value
321    /// - "eps": Numeric field value
322    /// - "weight_decay": Numeric field value
323    /// - "amsgrad": Boolean field value
324    ///
325    /// # Errors
326    ///
327    /// Returns SerializationError if:
328    /// - FieldValue is not an Object variant
329    /// - Any required field is missing from the object
330    /// - Any field has incorrect type or invalid value
331    /// - Deserialization process fails for any reason
332    ///
333    /// # Performance
334    ///
335    /// - **Time Complexity**: O(1) - Constant time field extraction and validation
336    /// - **Memory Usage**: Temporary deserializer allocation for field processing
337    /// - **Error Handling**: Detailed error messages with field name context
338    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
339        match value {
340            FieldValue::Object(fields) => {
341                let mut deserializer = StructDeserializer::from_fields(fields);
342                Self::from_deserializer(&mut deserializer)
343            }
344            _ => Err(SerializationError::ValidationFailed {
345                field: field_name.to_string(),
346                message: format!(
347                    "Expected Object for {}, found {}",
348                    std::any::type_name::<Self>(),
349                    value.type_name()
350                ),
351            }),
352        }
353    }
354}
355
// ===== Serializable Parameter State =====

358/// Serializable representation of Adam parameter optimization state
359///
360/// This structure provides a serializable wrapper around the internal ParameterState,
361/// enabling efficient persistence of Adam optimizer state for individual parameters.
362/// It leverages the Train Station tensor serialization system for optimal storage
363/// efficiency and maintains complete fidelity of momentum and velocity buffers.
364///
365/// # Purpose
366///
367/// SerializableParameterState serves as an intermediate representation that:
368/// - **Preserves state**: Maintains complete Adam optimization state per parameter
369/// - **Enables serialization**: Converts internal state to serializable format
370/// - **Supports AMSGrad**: Handles optional AMSGrad maximum velocity tracking
371/// - **Maintains precision**: Preserves exact floating-point values in buffers
372/// - **Validates shapes**: Ensures parameter shape consistency during restoration
373///
374/// # Fields
375///
376/// * `m` - First moment estimate (momentum) tensor buffer
377/// * `v` - Second moment estimate (velocity) tensor buffer  
378/// * `v_hat_max` - Optional maximum velocity tensor for AMSGrad variant
379/// * `step` - Per-parameter step count for bias correction calculations
380///
381/// # Design Rationale
382///
383/// The structure uses direct tensor serialization rather than manual data extraction:
384/// - **Efficiency**: Leverages optimized tensor serialization infrastructure
385/// - **Simplicity**: Eliminates manual buffer management and data copying
386/// - **Consistency**: Maintains alignment with Train Station serialization patterns
387/// - **Reliability**: Reduces serialization bugs through proven tensor serialization
388/// - **Performance**: Optimized memory usage and serialization speed
389///
390/// # Thread Safety
391///
392/// This structure is thread-safe for serialization operations:
393/// - **Immutable serialization**: Serialization does not modify state
394/// - **Clone safety**: Safe to clone across thread boundaries
395/// - **Tensor safety**: Leverages thread-safe tensor operations
396///
397/// # Memory Layout
398///
399/// The structure maintains efficient memory usage:
400/// - **Tensor sharing**: Tensors use reference counting for memory efficiency
401/// - **Optional fields**: AMSGrad state only allocated when needed
402/// - **Minimal overhead**: Small additional memory footprint beyond tensors
403#[derive(Debug, Clone)]
404struct SerializableParameterState {
405    /// First moment estimate (momentum) tensor buffer
406    ///
407    /// Contains the exponentially decaying average of past gradients, used for
408    /// momentum-based parameter updates. Shape matches the associated parameter.
409    m: Tensor,
410
411    /// Second moment estimate (velocity) tensor buffer
412    ///
413    /// Contains the exponentially decaying average of past squared gradients,
414    /// used for adaptive learning rate scaling. Shape matches the associated parameter.
415    v: Tensor,
416
417    /// Optional maximum velocity tensor for AMSGrad variant
418    ///
419    /// When AMSGrad is enabled, this tensor maintains the element-wise maximum
420    /// of all past velocity estimates, providing improved convergence properties.
421    /// None when AMSGrad is disabled. Shape matches the associated parameter when present.
422    v_hat_max: Option<Tensor>,
423
424    /// Per-parameter step count for bias correction
425    ///
426    /// Tracks the number of optimization steps performed for this specific parameter,
427    /// used in Adam's bias correction calculations. Essential for proper optimizer
428    /// behavior when parameters are added at different training stages.
429    step: usize,
430}
431
432impl SerializableParameterState {
433    /// Create SerializableParameterState from internal ParameterState
434    ///
435    /// This method converts the internal ParameterState representation used by
436    /// the Adam optimizer into a serializable format. It performs efficient tensor
437    /// cloning to preserve all optimization state while enabling serialization.
438    /// The conversion maintains exact numerical precision and handles optional
439    /// AMSGrad state appropriately.
440    ///
441    /// # Arguments
442    ///
443    /// * `state` - Reference to the internal ParameterState to convert
444    ///
445    /// # Returns
446    ///
447    /// SerializableParameterState containing all state data ready for serialization
448    ///
449    /// # Conversion Process
450    ///
451    /// The method performs these operations:
452    /// 1. **Momentum cloning**: Clones the momentum tensor (m) with full precision
453    /// 2. **Velocity cloning**: Clones the velocity tensor (v) with full precision
454    /// 3. **AMSGrad handling**: Clones optional AMSGrad state when present
455    /// 4. **Step preservation**: Copies the step count for bias correction
456    ///
457    /// # Performance
458    ///
459    /// - **Time Complexity**: O(1) - Tensor cloning uses reference counting
460    /// - **Memory Usage**: Minimal additional allocation due to tensor sharing
461    /// - **Precision**: Maintains exact floating-point values in all buffers
462    fn from_parameter_state(state: &ParameterState) -> Self {
463        // Use direct tensor cloning - leverages Tensor's efficient serialization
464        // No manual data extraction needed
465        Self {
466            m: state.m.clone(),
467            v: state.v.clone(),
468            v_hat_max: state.v_hat_max.clone(),
469            step: state.step,
470        }
471    }
472
473    /// Convert SerializableParameterState back to internal ParameterState
474    ///
475    /// This method reconstructs the internal ParameterState representation from
476    /// the serializable format, enabling the Adam optimizer to resume training
477    /// with preserved optimization state. The conversion maintains exact numerical
478    /// precision and properly handles optional AMSGrad state restoration.
479    ///
480    /// # Returns
481    ///
482    /// ParameterState instance ready for use by Adam optimizer, or SerializationError on failure
483    ///
484    /// # Reconstruction Process
485    ///
486    /// The method performs these operations:
487    /// 1. **Momentum restoration**: Clones momentum tensor back to internal format
488    /// 2. **Velocity restoration**: Clones velocity tensor back to internal format
489    /// 3. **AMSGrad restoration**: Handles optional AMSGrad state when present
490    /// 4. **Step restoration**: Preserves step count for continued bias correction
491    ///
492    /// # Validation
493    ///
494    /// The method ensures:
495    /// - All tensors have consistent shapes
496    /// - Step count is valid (non-negative)
497    /// - AMSGrad state consistency with configuration
498    /// - Tensor data integrity is maintained
499    ///
500    /// # Performance
501    ///
502    /// - **Time Complexity**: O(1) - Tensor cloning uses reference counting
503    /// - **Memory Usage**: Minimal allocation due to efficient tensor sharing
504    /// - **Precision**: Maintains exact floating-point values from serialization
505    ///
506    /// # Errors
507    ///
508    /// Returns SerializationError if:
509    /// - Tensor shapes are inconsistent
510    /// - Internal tensor state is corrupted
511    /// - Memory allocation fails during reconstruction
512    fn to_parameter_state(&self) -> SerializationResult<ParameterState> {
513        // Direct tensor cloning - no manual memory management needed
514        // Tensors handle their own efficient reconstruction
515        Ok(ParameterState {
516            m: self.m.clone(),
517            v: self.v.clone(),
518            v_hat_max: self.v_hat_max.clone(),
519            step: self.step,
520        })
521    }
522}
523
524impl StructSerializable for SerializableParameterState {
525    /// Convert SerializableParameterState to StructSerializer for serialization
526    ///
527    /// This method serializes all Adam parameter state components into a structured
528    /// format suitable for both JSON and binary serialization. It leverages the
529    /// efficient tensor serialization system to handle momentum and velocity buffers
530    /// while properly managing optional AMSGrad state.
531    ///
532    /// # Returns
533    ///
534    /// StructSerializer containing all parameter state data with field names and values
535    ///
536    /// # Serialized Fields
537    ///
538    /// - **m**: Momentum tensor (first moment estimate)
539    /// - **v**: Velocity tensor (second moment estimate)
540    /// - **v_hat_max**: Optional AMSGrad maximum velocity tensor
541    /// - **step**: Per-parameter step count for bias correction
542    ///
543    /// # Performance
544    ///
545    /// - **Time Complexity**: O(1) - Constant time field serialization
546    /// - **Memory Usage**: Leverages efficient tensor serialization
547    /// - **Precision**: Full floating-point precision preservation in tensors
548    fn to_serializer(&self) -> StructSerializer {
549        StructSerializer::new()
550            .field("m", &self.m)
551            .field("v", &self.v)
552            .field("v_hat_max", &self.v_hat_max)
553            .field("step", &self.step)
554    }
555
556    /// Create SerializableParameterState from StructDeserializer with validation
557    ///
558    /// This method reconstructs a SerializableParameterState from serialized data,
559    /// performing comprehensive validation to ensure all required fields are present
560    /// and contain valid tensor data. It handles both momentum and velocity tensors
561    /// along with optional AMSGrad state and step count information.
562    ///
563    /// # Arguments
564    ///
565    /// * `deserializer` - StructDeserializer containing parameter state field data
566    ///
567    /// # Returns
568    ///
569    /// Reconstructed SerializableParameterState on success, or SerializationError on failure
570    ///
571    /// # Required Fields
572    ///
573    /// All fields must be present in the deserializer:
574    /// - **m**: Momentum tensor with valid shape and data
575    /// - **v**: Velocity tensor with shape matching momentum tensor
576    /// - **v_hat_max**: Optional AMSGrad tensor (None or matching shape)
577    /// - **step**: Valid step count (non-negative integer)
578    ///
579    /// # Validation
580    ///
581    /// The method validates:
582    /// - Tensor field presence and type correctness
583    /// - Shape consistency between momentum and velocity tensors
584    /// - AMSGrad tensor shape consistency when present
585    /// - Step count validity and range
586    ///
587    /// # Performance
588    ///
589    /// - **Time Complexity**: O(1) - Constant time field extraction
590    /// - **Memory Usage**: Efficient tensor deserialization
591    /// - **Validation**: Comprehensive field and type validation
592    fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
593        Ok(Self {
594            m: deserializer.field("m")?,
595            v: deserializer.field("v")?,
596            v_hat_max: deserializer.field("v_hat_max")?,
597            step: deserializer.field("step")?,
598        })
599    }
600}
601
602impl ToFieldValue for SerializableParameterState {
603    /// Convert SerializableParameterState to FieldValue for embedding in collections
604    ///
605    /// This method converts the parameter state into a FieldValue::Object that can be
606    /// stored in collections or embedded within larger serializable structures. It
607    /// maintains the structured representation of all optimization state components
608    /// while enabling flexible serialization patterns.
609    ///
610    /// # Returns
611    ///
612    /// FieldValue::Object containing all parameter state data as key-value pairs
613    ///
614    /// # Object Structure
615    ///
616    /// The returned object contains these fields:
617    /// - "m": Momentum tensor as serialized FieldValue
618    /// - "v": Velocity tensor as serialized FieldValue
619    /// - "v_hat_max": Optional AMSGrad tensor as serialized FieldValue
620    /// - "step": Step count as FieldValue::Usize
621    ///
622    /// # Performance
623    ///
624    /// - **Time Complexity**: O(1) - Constant time field conversion
625    /// - **Memory Usage**: Efficient tensor serialization with minimal overhead
626    /// - **Precision**: Maintains exact tensor data and step count values
627    fn to_field_value(&self) -> FieldValue {
628        let serializer = self.to_serializer();
629        FieldValue::from_object(serializer.fields.into_iter().collect())
630    }
631}
632
633impl FromFieldValue for SerializableParameterState {
634    /// Create SerializableParameterState from FieldValue with comprehensive validation
635    ///
636    /// This method reconstructs a SerializableParameterState from a FieldValue::Object,
637    /// performing type validation and tensor deserialization. It handles the complex
638    /// process of deserializing momentum and velocity tensors along with optional
639    /// AMSGrad state, ensuring data integrity and proper error handling.
640    ///
641    /// # Arguments
642    ///
643    /// * `value` - FieldValue containing parameter state data (must be Object variant)
644    /// * `field_name` - Name of the field being deserialized for error context
645    ///
646    /// # Returns
647    ///
648    /// Reconstructed SerializableParameterState on success, or SerializationError on failure
649    ///
650    /// # Expected FieldValue Structure
651    ///
652    /// The FieldValue must be an Object variant containing:
653    /// - "m": Serialized momentum tensor
654    /// - "v": Serialized velocity tensor
655    /// - "v_hat_max": Optional serialized AMSGrad tensor
656    /// - "step": Step count as numeric value
657    ///
658    /// # Validation
659    ///
660    /// The method validates:
661    /// - FieldValue is Object variant
662    /// - All required fields are present
663    /// - Tensor deserialization succeeds
664    /// - Step count is valid numeric value
665    /// - Tensor shapes are consistent
666    ///
667    /// # Performance
668    ///
669    /// - **Time Complexity**: O(1) - Constant time field extraction and tensor deserialization
670    /// - **Memory Usage**: Efficient tensor deserialization with minimal overhead
671    /// - **Error Handling**: Detailed error messages with field name context
672    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
673        match value {
674            FieldValue::Object(fields) => {
675                let mut deserializer = StructDeserializer::from_fields(fields);
676                Self::from_deserializer(&mut deserializer)
677            }
678            _ => Err(SerializationError::ValidationFailed {
679                field: field_name.to_string(),
680                message: format!(
681                    "Expected Object for {}, found {}",
682                    std::any::type_name::<Self>(),
683                    value.type_name()
684                ),
685            }),
686        }
687    }
688}
689
// ===== Adam Serialization =====

692impl StructSerializable for Adam {
693    /// Convert Adam to StructSerializer for serialization
694    ///
695    /// Serializes all optimizer state including configuration, parameter states,
696    /// and global step count. Parameter linking is not serialized and must be
697    /// done after deserialization.
698    ///
699    /// # Returns
700    ///
701    /// StructSerializer containing all serializable optimizer state
702    fn to_serializer(&self) -> StructSerializer {
703        // Convert parameter states to serializable form
704        let mut serializable_states = HashMap::new();
705        for (param_id, state) in &self.states {
706            serializable_states.insert(
707                *param_id,
708                SerializableParameterState::from_parameter_state(state),
709            );
710        }
711
712        StructSerializer::new()
713            .field("config", &self.config)
714            .field("states", &serializable_states)
715            .field("step_count", &self.step_count)
716            .field("insertion_order", &self.insertion_order)
717    }
718
719    /// Create Adam from StructDeserializer
720    ///
721    /// Reconstructs Adam optimizer from serialized state. Parameters must be
722    /// linked separately using `add_parameter` or `add_parameters`.
723    ///
724    /// # Arguments
725    ///
726    /// * `deserializer` - StructDeserializer containing optimizer data
727    ///
728    /// # Returns
729    ///
730    /// Reconstructed Adam instance without parameter links, or error if deserialization fails
731    fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
732        let config: AdamConfig = deserializer.field("config")?;
733        let serializable_states: HashMap<usize, SerializableParameterState> =
734            deserializer.field("states")?;
735        let step_count: usize = deserializer.field("step_count")?;
736        let insertion_order: Vec<usize> = deserializer.field("insertion_order")?;
737
738        // Reconstruct parameter states from serialized form
739        let mut states = HashMap::new();
740        for (param_id, serializable_state) in serializable_states {
741            states.insert(param_id, serializable_state.to_parameter_state()?);
742        }
743
744        // Create optimizer with reconstructed states - user must call relink_parameters to link tensors
745        Ok(Adam {
746            config,
747            states,
748            step_count,
749            insertion_order,
750        })
751    }
752}
753
754impl FromFieldValue for Adam {
755    /// Create Adam from FieldValue
756    ///
757    /// # Arguments
758    ///
759    /// * `value` - FieldValue containing optimizer data
760    /// * `field_name` - Name of the field being deserialized (for error messages)
761    ///
762    /// # Returns
763    ///
764    /// Reconstructed Adam instance or error if deserialization fails
765    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
766        match value {
767            FieldValue::Object(fields) => {
768                let mut deserializer = StructDeserializer::from_fields(fields);
769                Self::from_deserializer(&mut deserializer)
770            }
771            _ => Err(SerializationError::ValidationFailed {
772                field: field_name.to_string(),
773                message: format!(
774                    "Expected Object for {}, found {}",
775                    std::any::type_name::<Self>(),
776                    value.type_name()
777                ),
778            }),
779        }
780    }
781}
782
783// ===== Serializable Trait Implementation =====
784
785impl crate::serialization::Serializable for Adam {
786    /// Serialize the Adam optimizer to JSON format
787    ///
788    /// This method converts the Adam optimizer into a human-readable JSON string representation
789    /// that includes all optimizer state, configuration, parameter states, and step counts.
790    /// The JSON format is suitable for debugging, configuration files, and cross-language
791    /// interoperability.
792    ///
793    /// # Returns
794    ///
795    /// JSON string representation of the optimizer on success, or `SerializationError` on failure
796    ///
797    /// # Examples
798    ///
799    /// ```
800    /// use train_station::{Tensor, optimizers::Adam};
801    /// use train_station::serialization::Serializable;
802    ///
803    /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
804    /// let mut optimizer = Adam::new();
805    /// optimizer.add_parameter(&weight);
806    ///
807    /// let json = optimizer.to_json().unwrap();
808    /// assert!(!json.is_empty());
809    /// ```
810    fn to_json(&self) -> SerializationResult<String> {
811        <Self as StructSerializable>::to_json(self)
812    }
813
814    /// Deserialize an Adam optimizer from JSON format
815    ///
816    /// This method parses a JSON string and reconstructs an Adam optimizer with all
817    /// saved state. Parameters must be re-linked after deserialization using
818    /// `add_parameter` or `relink_parameters`.
819    ///
820    /// # Arguments
821    ///
822    /// * `json` - JSON string containing serialized optimizer
823    ///
824    /// # Returns
825    ///
826    /// The deserialized optimizer on success, or `SerializationError` on failure
827    ///
828    /// # Examples
829    ///
830    /// ```
831    /// use train_station::{Tensor, optimizers::Adam};
832    /// use train_station::serialization::Serializable;
833    ///
834    /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
835    /// let mut optimizer = Adam::new();
836    /// optimizer.add_parameter(&weight);
837    ///
838    /// let json = optimizer.to_json().unwrap();
839    /// let loaded_optimizer = Adam::from_json(&json).unwrap();
840    /// assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
841    /// ```
842    fn from_json(json: &str) -> SerializationResult<Self> {
843        <Self as StructSerializable>::from_json(json)
844    }
845
846    /// Serialize the Adam optimizer to binary format
847    ///
848    /// This method converts the optimizer into a compact binary representation optimized
849    /// for storage and transmission. The binary format provides maximum performance
850    /// and minimal file sizes compared to JSON.
851    ///
852    /// # Returns
853    ///
854    /// Binary representation of the optimizer on success, or `SerializationError` on failure
855    ///
856    /// # Examples
857    ///
858    /// ```
859    /// use train_station::{Tensor, optimizers::Adam};
860    /// use train_station::serialization::Serializable;
861    ///
862    /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
863    /// let mut optimizer = Adam::new();
864    /// optimizer.add_parameter(&weight);
865    ///
866    /// let binary = optimizer.to_binary().unwrap();
867    /// assert!(!binary.is_empty());
868    /// ```
869    fn to_binary(&self) -> SerializationResult<Vec<u8>> {
870        <Self as StructSerializable>::to_binary(self)
871    }
872
873    /// Deserialize an Adam optimizer from binary format
874    ///
875    /// This method parses binary data and reconstructs an Adam optimizer with all
876    /// saved state. Parameters must be re-linked after deserialization using
877    /// `add_parameter` or `relink_parameters`.
878    ///
879    /// # Arguments
880    ///
881    /// * `data` - Binary data containing serialized optimizer
882    ///
883    /// # Returns
884    ///
885    /// The deserialized optimizer on success, or `SerializationError` on failure
886    ///
887    /// # Examples
888    ///
889    /// ```
890    /// use train_station::{Tensor, optimizers::Adam};
891    /// use train_station::serialization::Serializable;
892    ///
893    /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
894    /// let mut optimizer = Adam::new();
895    /// optimizer.add_parameter(&weight);
896    ///
897    /// let binary = optimizer.to_binary().unwrap();
898    /// let loaded_optimizer = Adam::from_binary(&binary).unwrap();
899    /// assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
900    /// ```
901    fn from_binary(data: &[u8]) -> SerializationResult<Self> {
902        <Self as StructSerializable>::from_binary(data)
903    }
904}
905
906// ===== Utility Methods =====
907
908impl Adam {
909    /// Get the number of saved parameter states for checkpoint validation
910    ///
911    /// This method returns the count of parameter states currently stored in the optimizer,
912    /// which is essential for validating checkpoint integrity and ensuring proper parameter
913    /// re-linking after deserialization. The count includes all parameters that have been
914    /// linked to the optimizer and have accumulated optimization state.
915    ///
916    /// # Returns
917    ///
918    /// Number of parameter states currently stored in the optimizer
919    ///
920    /// # Usage Patterns
921    ///
922    /// ## Checkpoint Validation
923    /// After deserializing an optimizer, this method helps verify that the expected
924    /// number of parameters were saved and can guide the re-linking process.
925    ///
926    /// ## Training Resumption
927    /// When resuming training, compare this count with the number of parameters
928    /// in your model to ensure checkpoint compatibility.
929    ///
930    /// ## State Management
931    /// Use this method to monitor optimizer state growth and memory usage during
932    /// training with dynamic parameter addition.
933    ///
934    /// # Examples
935    ///
936    /// ```
937    /// use train_station::{Tensor, optimizers::Adam};
938    /// use train_station::serialization::Serializable;
939    ///
940    /// let weight = Tensor::ones(vec![10, 5]).with_requires_grad();
941    /// let bias = Tensor::zeros(vec![5]).with_requires_grad();
942    /// let mut optimizer = Adam::new();
943    /// optimizer.add_parameter(&weight);
944    /// optimizer.add_parameter(&bias);
945    ///
946    /// // Check parameter count before serialization
947    /// assert_eq!(optimizer.saved_parameter_count(), 2);
948    ///
949    /// // Serialize and deserialize
950    /// let json = optimizer.to_json().unwrap();
951    /// let loaded_optimizer = Adam::from_json(&json).unwrap();
952    ///
953    /// // Verify parameter count is preserved
954    /// assert_eq!(loaded_optimizer.saved_parameter_count(), 2);
955    /// ```
956    ///
957    /// # Performance
958    ///
959    /// - **Time Complexity**: O(1) - Direct access to internal state count
960    /// - **Memory Usage**: No additional memory allocation
961    /// - **Thread Safety**: Safe to call from multiple threads concurrently
962    #[track_caller]
963    pub fn saved_parameter_count(&self) -> usize {
964        self.states.len()
965    }
966}
967
968// ===== Field Value Implementations for Collections =====
969
970impl ToFieldValue for HashMap<usize, SerializableParameterState> {
971    /// Convert parameter states HashMap to FieldValue for serialization
972    ///
973    /// This method converts the HashMap of parameter states into a FieldValue::Object
974    /// suitable for serialization. It handles the conversion of usize keys to string
975    /// format required by the FieldValue::Object representation while preserving
976    /// all parameter state data.
977    ///
978    /// # Returns
979    ///
980    /// FieldValue::Object with string keys and SerializableParameterState values
981    ///
982    /// # Key Conversion
983    ///
984    /// - **Input**: HashMap<usize, SerializableParameterState> with numeric tensor IDs
985    /// - **Output**: FieldValue::Object with string keys for JSON compatibility
986    /// - **Mapping**: Each usize key is converted to string representation
987    /// - **Preservation**: All parameter state data is preserved exactly
988    ///
989    /// # Performance
990    ///
991    /// - **Time Complexity**: O(n) where n is the number of parameter states
992    /// - **Memory Usage**: Allocates new HashMap for string keys
993    /// - **Conversion**: Efficient string conversion for numeric keys
994    fn to_field_value(&self) -> FieldValue {
995        let mut map = HashMap::new();
996        for (key, value) in self {
997            map.insert(key.to_string(), value.to_field_value());
998        }
999        FieldValue::from_object(map)
1000    }
1001}
1002
1003impl FromFieldValue for HashMap<usize, SerializableParameterState> {
1004    /// Create parameter states HashMap from FieldValue with validation
1005    ///
1006    /// This method reconstructs the HashMap of parameter states from a FieldValue::Object,
1007    /// performing comprehensive validation and key conversion. It handles the conversion
1008    /// from string keys back to usize tensor IDs while ensuring all parameter state
1009    /// data is properly deserialized and validated.
1010    ///
1011    /// # Arguments
1012    ///
1013    /// * `value` - FieldValue containing parameter states data (must be Object variant)
1014    /// * `field_name` - Name of the field being deserialized for error context
1015    ///
1016    /// # Returns
1017    ///
1018    /// Reconstructed HashMap<usize, SerializableParameterState> on success, or SerializationError on failure
1019    ///
1020    /// # Key Conversion Process
1021    ///
1022    /// 1. **Validation**: Ensures FieldValue is Object variant
1023    /// 2. **Key parsing**: Converts string keys back to usize tensor IDs
1024    /// 3. **State deserialization**: Deserializes each parameter state
1025    /// 4. **Validation**: Validates parameter state integrity
1026    /// 5. **Collection**: Builds final HashMap with proper types
1027    ///
1028    /// # Errors
1029    ///
1030    /// Returns SerializationError if:
1031    /// - FieldValue is not Object variant
1032    /// - Any string key cannot be parsed as usize
1033    /// - Parameter state deserialization fails
1034    /// - Invalid parameter state data is encountered
1035    ///
1036    /// # Performance
1037    ///
1038    /// - **Time Complexity**: O(n) where n is the number of parameter states
1039    /// - **Memory Usage**: Allocates new HashMap with proper key types
1040    /// - **Validation**: Comprehensive key parsing and state validation
1041    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
1042        match value {
1043            FieldValue::Object(fields) => {
1044                let mut map = HashMap::new();
1045                for (key_str, field_value) in fields {
1046                    let key = key_str.parse::<usize>().map_err(|_| {
1047                        SerializationError::ValidationFailed {
1048                            field: field_name.to_string(),
1049                            message: format!("Invalid key '{}' in parameter states map", key_str),
1050                        }
1051                    })?;
1052                    let state =
1053                        SerializableParameterState::from_field_value(field_value, &key_str)?;
1054                    map.insert(key, state);
1055                }
1056                Ok(map)
1057            }
1058            _ => Err(SerializationError::ValidationFailed {
1059                field: field_name.to_string(),
1060                message: format!(
1061                    "Expected Object for HashMap<usize, SerializableParameterState>, found {}",
1062                    value.type_name()
1063                ),
1064            }),
1065        }
1066    }
1067}
1068
1069#[cfg(test)]
1070mod tests {
1071    use super::*;
1072    use crate::optimizers::Optimizer;
1073    use crate::tensor::core::Tensor;
1074
1075    // ===== AdamConfig Serialization Tests =====
1076
1077    #[test]
1078    fn test_adam_config_json_roundtrip() {
1079        let config = AdamConfig {
1080            learning_rate: 1e-4,
1081            beta1: 0.95,
1082            beta2: 0.9999,
1083            eps: 1e-7,
1084            weight_decay: 1e-5,
1085            amsgrad: true,
1086        };
1087
1088        let json = config.to_json().unwrap();
1089        let loaded_config = AdamConfig::from_json(&json).unwrap();
1090
1091        assert_eq!(config.learning_rate, loaded_config.learning_rate);
1092        assert_eq!(config.beta1, loaded_config.beta1);
1093        assert_eq!(config.beta2, loaded_config.beta2);
1094        assert_eq!(config.eps, loaded_config.eps);
1095        assert_eq!(config.weight_decay, loaded_config.weight_decay);
1096        assert_eq!(config.amsgrad, loaded_config.amsgrad);
1097    }
1098
1099    #[test]
1100    fn test_adam_config_binary_roundtrip() {
1101        let config = AdamConfig {
1102            learning_rate: 2e-3,
1103            beta1: 0.85,
1104            beta2: 0.995,
1105            eps: 1e-9,
1106            weight_decay: 5e-4,
1107            amsgrad: false,
1108        };
1109
1110        let binary = config.to_binary().unwrap();
1111        let loaded_config = AdamConfig::from_binary(&binary).unwrap();
1112
1113        assert_eq!(config.learning_rate, loaded_config.learning_rate);
1114        assert_eq!(config.beta1, loaded_config.beta1);
1115        assert_eq!(config.beta2, loaded_config.beta2);
1116        assert_eq!(config.eps, loaded_config.eps);
1117        assert_eq!(config.weight_decay, loaded_config.weight_decay);
1118        assert_eq!(config.amsgrad, loaded_config.amsgrad);
1119    }
1120
1121    #[test]
1122    fn test_adam_config_field_value_roundtrip() {
1123        let config = AdamConfig {
1124            learning_rate: 3e-4,
1125            beta1: 0.92,
1126            beta2: 0.998,
1127            eps: 1e-6,
1128            weight_decay: 2e-4,
1129            amsgrad: true,
1130        };
1131
1132        let field_value = config.to_field_value();
1133        let loaded_config = AdamConfig::from_field_value(field_value, "config").unwrap();
1134
1135        assert_eq!(config.learning_rate, loaded_config.learning_rate);
1136        assert_eq!(config.beta1, loaded_config.beta1);
1137        assert_eq!(config.beta2, loaded_config.beta2);
1138        assert_eq!(config.eps, loaded_config.eps);
1139        assert_eq!(config.weight_decay, loaded_config.weight_decay);
1140        assert_eq!(config.amsgrad, loaded_config.amsgrad);
1141    }
1142
1143    // ===== Adam Optimizer Serialization Tests =====
1144
1145    #[test]
1146    fn test_adam_optimizer_json_roundtrip() {
1147        let mut weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1148        let mut bias = Tensor::zeros(vec![2, 3]).with_requires_grad();
1149
1150        let mut optimizer = Adam::new();
1151        optimizer.add_parameter(&weight);
1152        optimizer.add_parameter(&bias);
1153
1154        // Perform some steps to create state
1155        let output = weight.add_tensor(&bias);
1156        let mut loss = output.sum();
1157        loss.backward(None);
1158        optimizer.step(&mut [&mut weight, &mut bias]);
1159
1160        // Test serialization
1161        let json = optimizer.to_json().unwrap();
1162        let loaded_optimizer = Adam::from_json(&json).unwrap();
1163
1164        assert_eq!(
1165            optimizer.config().learning_rate,
1166            loaded_optimizer.config().learning_rate
1167        );
1168        assert_eq!(
1169            optimizer.saved_parameter_count(),
1170            loaded_optimizer.saved_parameter_count()
1171        );
1172    }
1173
1174    #[test]
1175    fn test_adam_optimizer_binary_roundtrip() {
1176        let weight = Tensor::ones(vec![5, 2]).with_requires_grad();
1177
1178        let mut optimizer = Adam::with_learning_rate(1e-4);
1179        optimizer.add_parameter(&weight);
1180
1181        // Test serialization
1182        let binary = optimizer.to_binary().unwrap();
1183        let loaded_optimizer = Adam::from_binary(&binary).unwrap();
1184
1185        assert_eq!(
1186            optimizer.config().learning_rate,
1187            loaded_optimizer.config().learning_rate
1188        );
1189        assert_eq!(
1190            optimizer.saved_parameter_count(),
1191            loaded_optimizer.saved_parameter_count()
1192        );
1193    }
1194
1195    #[test]
1196    fn test_adam_parameter_relinking() {
1197        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1198        let mut optimizer = Adam::new();
1199        optimizer.add_parameter(&weight);
1200
1201        // Serialize
1202        let json = optimizer.to_json().unwrap();
1203
1204        // Deserialize
1205        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1206
1207        // After deserialization, saved states should be preserved
1208        assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
1209
1210        // Re-link parameter - this creates a new state since it's a new tensor with different ID
1211        let new_weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1212        loaded_optimizer.add_parameter(&new_weight);
1213
1214        // Now there should be 2 states: the original saved one + the new one
1215        assert_eq!(loaded_optimizer.parameter_count(), 2);
1216        assert!(loaded_optimizer.is_parameter_linked(&new_weight));
1217    }
1218
1219    #[test]
1220    fn test_adam_state_preservation() {
1221        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1222        let mut optimizer = Adam::new();
1223        optimizer.add_parameter(&weight);
1224
1225        // Perform training steps to build up state
1226        for _ in 0..3 {
1227            let output = weight.mul_scalar(2.0);
1228            let mut loss = output.sum();
1229            loss.backward(None);
1230            optimizer.step(&mut [&mut weight]);
1231            optimizer.zero_grad(&mut [&mut weight]);
1232        }
1233
1234        // Serialize and deserialize
1235        let json = optimizer.to_json().unwrap();
1236        let loaded_optimizer = Adam::from_json(&json).unwrap();
1237
1238        // Check that states were preserved
1239        assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
1240        assert_eq!(
1241            loaded_optimizer.config().learning_rate,
1242            optimizer.config().learning_rate
1243        );
1244    }
1245
1246    #[test]
1247    fn test_relink_parameters_success() {
1248        // Create original optimizer with parameters
1249        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1250        let bias = Tensor::zeros(vec![3]).with_requires_grad();
1251
1252        let mut optimizer = Adam::new();
1253        optimizer.add_parameter(&weight);
1254        optimizer.add_parameter(&bias);
1255
1256        // Serialize
1257        let json = optimizer.to_json().unwrap();
1258
1259        // Create new parameters with same shapes but different IDs
1260        let new_weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1261        let new_bias = Tensor::zeros(vec![3]).with_requires_grad();
1262
1263        // Deserialize and re-link
1264        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1265        loaded_optimizer
1266            .relink_parameters(&[&new_weight, &new_bias])
1267            .unwrap();
1268
1269        // Verify re-linking worked
1270        assert_eq!(loaded_optimizer.parameter_count(), 2);
1271        assert!(loaded_optimizer.is_parameter_linked(&new_weight));
1272        assert!(loaded_optimizer.is_parameter_linked(&new_bias));
1273    }
1274
1275    #[test]
1276    fn test_relink_parameters_shape_mismatch() {
1277        // Create original optimizer
1278        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1279        let mut optimizer = Adam::new();
1280        optimizer.add_parameter(&weight);
1281
1282        // Serialize
1283        let json = optimizer.to_json().unwrap();
1284
1285        // Create new parameter with different shape
1286        let new_weight = Tensor::ones(vec![3, 2]).with_requires_grad(); // Different shape!
1287
1288        // Deserialize and try to re-link
1289        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1290        let result = loaded_optimizer.relink_parameters(&[&new_weight]);
1291
1292        // Should fail with shape mismatch error
1293        assert!(result.is_err());
1294        assert!(result.unwrap_err().contains("Shape mismatch"));
1295    }
1296
1297    #[test]
1298    fn test_relink_parameters_count_mismatch() {
1299        // Create original optimizer with 2 parameters
1300        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1301        let bias = Tensor::zeros(vec![3]).with_requires_grad();
1302
1303        let mut optimizer = Adam::new();
1304        optimizer.add_parameter(&weight);
1305        optimizer.add_parameter(&bias);
1306
1307        // Serialize
1308        let json = optimizer.to_json().unwrap();
1309
1310        // Create only 1 new parameter
1311        let new_weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1312
1313        // Deserialize and try to re-link with wrong count
1314        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1315        let result = loaded_optimizer.relink_parameters(&[&new_weight]);
1316
1317        // Should fail with count mismatch error
1318        assert!(result.is_err());
1319        assert!(result.unwrap_err().contains("Parameter count mismatch"));
1320    }
1321
1322    #[test]
1323    fn test_relink_parameters_requires_grad() {
1324        // Create original optimizer
1325        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1326        let mut optimizer = Adam::new();
1327        optimizer.add_parameter(&weight);
1328
1329        // Serialize
1330        let json = optimizer.to_json().unwrap();
1331
1332        // Create new parameter without requires_grad
1333        let new_weight = Tensor::ones(vec![2, 3]); // No requires_grad!
1334
1335        // Deserialize and try to re-link
1336        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1337        let result = loaded_optimizer.relink_parameters(&[&new_weight]);
1338
1339        // Should fail with requires_grad error
1340        assert!(result.is_err());
1341        assert!(result.unwrap_err().contains("must require gradients"));
1342    }
1343
1344    #[test]
1345    fn test_relink_preserves_state() {
1346        // Create original optimizer and train it
1347        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1348        let mut optimizer = Adam::new();
1349        optimizer.add_parameter(&weight);
1350
1351        // Perform some training to build up state
1352        for _ in 0..2 {
1353            let output = weight.mul_scalar(2.0);
1354            let mut loss = output.sum();
1355            loss.backward(None);
1356            optimizer.step(&mut [&mut weight]);
1357            optimizer.zero_grad(&mut [&mut weight]);
1358        }
1359
1360        // Get state before serialization
1361        let original_step_count = optimizer.step_count;
1362
1363        // Serialize
1364        let json = optimizer.to_json().unwrap();
1365
1366        // Create new parameter
1367        let new_weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1368
1369        // Deserialize and re-link
1370        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1371        loaded_optimizer.relink_parameters(&[&new_weight]).unwrap();
1372
1373        // Verify state is preserved
1374        assert_eq!(loaded_optimizer.step_count, original_step_count);
1375        assert_eq!(loaded_optimizer.parameter_count(), 1);
1376        assert!(loaded_optimizer.is_parameter_linked(&new_weight));
1377    }
1378
1379    // ===== Serializable Trait Tests =====
1380
1381    #[test]
1382    fn test_serializable_json_methods() {
1383        // Create and populate test optimizer
1384        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1385        let bias = Tensor::zeros(vec![3]).with_requires_grad();
1386
1387        let mut optimizer = Adam::with_learning_rate(1e-3);
1388        optimizer.add_parameter(&weight);
1389        optimizer.add_parameter(&bias);
1390
1391        // Test to_json method
1392        let json = <Adam as crate::serialization::Serializable>::to_json(&optimizer).unwrap();
1393        assert!(!json.is_empty());
1394        assert!(json.contains("config"));
1395        assert!(json.contains("states"));
1396        assert!(json.contains("step_count"));
1397        assert!(json.contains("learning_rate"));
1398
1399        // Test from_json method
1400        let restored = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1401        assert_eq!(
1402            optimizer.config().learning_rate,
1403            restored.config().learning_rate
1404        );
1405        assert_eq!(optimizer.config().beta1, restored.config().beta1);
1406        assert_eq!(optimizer.config().beta2, restored.config().beta2);
1407        assert_eq!(optimizer.config().eps, restored.config().eps);
1408        assert_eq!(
1409            optimizer.config().weight_decay,
1410            restored.config().weight_decay
1411        );
1412        assert_eq!(optimizer.config().amsgrad, restored.config().amsgrad);
1413        assert_eq!(
1414            optimizer.saved_parameter_count(),
1415            restored.saved_parameter_count()
1416        );
1417        assert_eq!(optimizer.step_count, restored.step_count);
1418    }
1419
1420    #[test]
1421    fn test_serializable_binary_methods() {
1422        // Create and populate test optimizer
1423        let weight = Tensor::ones(vec![3, 4]).with_requires_grad();
1424        let mut optimizer = Adam::with_config(AdamConfig {
1425            learning_rate: 2e-4,
1426            beta1: 0.95,
1427            beta2: 0.999,
1428            eps: 1e-7,
1429            weight_decay: 1e-4,
1430            amsgrad: true,
1431        });
1432        optimizer.add_parameter(&weight);
1433
1434        // Test to_binary method
1435        let binary = <Adam as crate::serialization::Serializable>::to_binary(&optimizer).unwrap();
1436        assert!(!binary.is_empty());
1437
1438        // Test from_binary method
1439        let restored = <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1440        assert_eq!(
1441            optimizer.config().learning_rate,
1442            restored.config().learning_rate
1443        );
1444        assert_eq!(optimizer.config().beta1, restored.config().beta1);
1445        assert_eq!(optimizer.config().beta2, restored.config().beta2);
1446        assert_eq!(optimizer.config().eps, restored.config().eps);
1447        assert_eq!(
1448            optimizer.config().weight_decay,
1449            restored.config().weight_decay
1450        );
1451        assert_eq!(optimizer.config().amsgrad, restored.config().amsgrad);
1452        assert_eq!(
1453            optimizer.saved_parameter_count(),
1454            restored.saved_parameter_count()
1455        );
1456        assert_eq!(optimizer.step_count, restored.step_count);
1457    }
1458
1459    #[test]
1460    fn test_serializable_file_io_json() {
1461        use crate::serialization::{Format, Serializable};
1462        use std::fs;
1463        use std::path::Path;
1464
1465        // Create test optimizer
1466        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1467        let bias = Tensor::zeros(vec![2]).with_requires_grad();
1468
1469        let mut optimizer = Adam::with_learning_rate(5e-4);
1470        optimizer.add_parameter(&weight);
1471        optimizer.add_parameter(&bias);
1472
1473        let json_path = "test_adam_serializable.json";
1474
1475        // Test save method with JSON format
1476        Serializable::save(&optimizer, json_path, Format::Json).unwrap();
1477        assert!(Path::new(json_path).exists());
1478
1479        // Test load method with JSON format
1480        let loaded_optimizer = Adam::load(json_path, Format::Json).unwrap();
1481        assert_eq!(
1482            optimizer.config().learning_rate,
1483            loaded_optimizer.config().learning_rate
1484        );
1485        assert_eq!(
1486            optimizer.saved_parameter_count(),
1487            loaded_optimizer.saved_parameter_count()
1488        );
1489
1490        // Test save_to_writer method
1491        let json_path_2 = "test_adam_serializable_writer.json";
1492        {
1493            let file = std::fs::File::create(json_path_2).unwrap();
1494            let mut writer = std::io::BufWriter::new(file);
1495            Serializable::save_to_writer(&optimizer, &mut writer, Format::Json).unwrap();
1496        }
1497        assert!(Path::new(json_path_2).exists());
1498
1499        // Test load_from_reader method
1500        {
1501            let file = std::fs::File::open(json_path_2).unwrap();
1502            let mut reader = std::io::BufReader::new(file);
1503            let loaded_optimizer = Adam::load_from_reader(&mut reader, Format::Json).unwrap();
1504            assert_eq!(
1505                optimizer.config().learning_rate,
1506                loaded_optimizer.config().learning_rate
1507            );
1508        }
1509
1510        // Cleanup test files
1511        let _ = fs::remove_file(json_path);
1512        let _ = fs::remove_file(json_path_2);
1513    }
1514
1515    #[test]
1516    fn test_serializable_file_io_binary() {
1517        use crate::serialization::{Format, Serializable};
1518        use std::fs;
1519        use std::path::Path;
1520
1521        // Create test optimizer
1522        let weight = Tensor::ones(vec![3, 3]).with_requires_grad();
1523        let mut optimizer = Adam::with_config(AdamConfig {
1524            learning_rate: 1e-3,
1525            beta1: 0.9,
1526            beta2: 0.999,
1527            eps: 1e-8,
1528            weight_decay: 0.0,
1529            amsgrad: false,
1530        });
1531        optimizer.add_parameter(&weight);
1532
1533        let binary_path = "test_adam_serializable.bin";
1534
1535        // Test save method with binary format
1536        Serializable::save(&optimizer, binary_path, Format::Binary).unwrap();
1537        assert!(Path::new(binary_path).exists());
1538
1539        // Test load method with binary format
1540        let loaded_optimizer = Adam::load(binary_path, Format::Binary).unwrap();
1541        assert_eq!(
1542            optimizer.config().learning_rate,
1543            loaded_optimizer.config().learning_rate
1544        );
1545        assert_eq!(
1546            optimizer.saved_parameter_count(),
1547            loaded_optimizer.saved_parameter_count()
1548        );
1549
1550        // Test save_to_writer method
1551        let binary_path_2 = "test_adam_serializable_writer.bin";
1552        {
1553            let file = std::fs::File::create(binary_path_2).unwrap();
1554            let mut writer = std::io::BufWriter::new(file);
1555            Serializable::save_to_writer(&optimizer, &mut writer, Format::Binary).unwrap();
1556        }
1557        assert!(Path::new(binary_path_2).exists());
1558
1559        // Test load_from_reader method
1560        {
1561            let file = std::fs::File::open(binary_path_2).unwrap();
1562            let mut reader = std::io::BufReader::new(file);
1563            let loaded_optimizer = Adam::load_from_reader(&mut reader, Format::Binary).unwrap();
1564            assert_eq!(
1565                optimizer.config().learning_rate,
1566                loaded_optimizer.config().learning_rate
1567            );
1568        }
1569
1570        // Cleanup test files
1571        let _ = fs::remove_file(binary_path);
1572        let _ = fs::remove_file(binary_path_2);
1573    }
1574
1575    #[test]
1576    fn test_serializable_large_optimizer_performance() {
1577        // Create a large optimizer to test performance characteristics
1578        let mut optimizer = Adam::with_learning_rate(1e-4);
1579
1580        // Add multiple parameters of different sizes
1581        for i in 0..5 {
1582            let size = 10 + i * 5;
1583            let param = Tensor::ones(vec![size, size]).with_requires_grad();
1584            optimizer.add_parameter(&param);
1585        }
1586
1587        // Test JSON serialization
1588        let json = <Adam as crate::serialization::Serializable>::to_json(&optimizer).unwrap();
1589        assert!(!json.is_empty());
1590        let restored_json = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1591        assert_eq!(
1592            optimizer.config().learning_rate,
1593            restored_json.config().learning_rate
1594        );
1595        assert_eq!(
1596            optimizer.saved_parameter_count(),
1597            restored_json.saved_parameter_count()
1598        );
1599
1600        // Test binary serialization
1601        let binary = <Adam as crate::serialization::Serializable>::to_binary(&optimizer).unwrap();
1602        assert!(!binary.is_empty());
1603        // Binary format should be efficient (this is informational, not a requirement)
1604        println!(
1605            "JSON size: {} bytes, Binary size: {} bytes",
1606            json.len(),
1607            binary.len()
1608        );
1609
1610        let restored_binary =
1611            <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1612        assert_eq!(
1613            optimizer.config().learning_rate,
1614            restored_binary.config().learning_rate
1615        );
1616        assert_eq!(
1617            optimizer.saved_parameter_count(),
1618            restored_binary.saved_parameter_count()
1619        );
1620
1621        // Verify all configurations match
1622        assert_eq!(optimizer.config().beta1, restored_binary.config().beta1);
1623        assert_eq!(optimizer.config().beta2, restored_binary.config().beta2);
1624        assert_eq!(optimizer.config().eps, restored_binary.config().eps);
1625        assert_eq!(
1626            optimizer.config().weight_decay,
1627            restored_binary.config().weight_decay
1628        );
1629        assert_eq!(optimizer.config().amsgrad, restored_binary.config().amsgrad);
1630    }
1631
1632    #[test]
1633    fn test_serializable_error_handling() {
1634        // Test invalid JSON
1635        let invalid_json = r#"{"invalid": "json", "structure": true}"#;
1636        let result = <Adam as crate::serialization::Serializable>::from_json(invalid_json);
1637        assert!(result.is_err());
1638
1639        // Test empty JSON
1640        let empty_json = "{}";
1641        let result = <Adam as crate::serialization::Serializable>::from_json(empty_json);
1642        assert!(result.is_err());
1643
1644        // Test invalid binary data
1645        let invalid_binary = vec![1, 2, 3, 4, 5];
1646        let result = <Adam as crate::serialization::Serializable>::from_binary(&invalid_binary);
1647        assert!(result.is_err());
1648
1649        // Test empty binary data
1650        let empty_binary = vec![];
1651        let result = <Adam as crate::serialization::Serializable>::from_binary(&empty_binary);
1652        assert!(result.is_err());
1653    }
1654
1655    #[test]
1656    fn test_serializable_different_configurations() {
1657        let test_configs = vec![
1658            // Default configuration
1659            AdamConfig::default(),
1660            // High learning rate
1661            AdamConfig {
1662                learning_rate: 1e-2,
1663                beta1: 0.9,
1664                beta2: 0.999,
1665                eps: 1e-8,
1666                weight_decay: 0.0,
1667                amsgrad: false,
1668            },
1669            // AMSGrad enabled
1670            AdamConfig {
1671                learning_rate: 1e-4,
1672                beta1: 0.9,
1673                beta2: 0.999,
1674                eps: 1e-8,
1675                weight_decay: 1e-4,
1676                amsgrad: true,
1677            },
1678            // Custom betas
1679            AdamConfig {
1680                learning_rate: 5e-4,
1681                beta1: 0.95,
1682                beta2: 0.9999,
1683                eps: 1e-7,
1684                weight_decay: 1e-5,
1685                amsgrad: false,
1686            },
1687        ];
1688
1689        for config in test_configs {
1690            // Create optimizer with specific configuration
1691            let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1692            let mut optimizer = Adam::with_config(config.clone());
1693            optimizer.add_parameter(&weight);
1694
1695            // Test JSON roundtrip
1696            let json = <Adam as crate::serialization::Serializable>::to_json(&optimizer).unwrap();
1697            let restored_json =
1698                <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1699            assert_eq!(config.learning_rate, restored_json.config().learning_rate);
1700            assert_eq!(config.beta1, restored_json.config().beta1);
1701            assert_eq!(config.beta2, restored_json.config().beta2);
1702            assert_eq!(config.eps, restored_json.config().eps);
1703            assert_eq!(config.weight_decay, restored_json.config().weight_decay);
1704            assert_eq!(config.amsgrad, restored_json.config().amsgrad);
1705
1706            // Test binary roundtrip
1707            let binary =
1708                <Adam as crate::serialization::Serializable>::to_binary(&optimizer).unwrap();
1709            let restored_binary =
1710                <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1711            assert_eq!(config.learning_rate, restored_binary.config().learning_rate);
1712            assert_eq!(config.beta1, restored_binary.config().beta1);
1713            assert_eq!(config.beta2, restored_binary.config().beta2);
1714            assert_eq!(config.eps, restored_binary.config().eps);
1715            assert_eq!(config.weight_decay, restored_binary.config().weight_decay);
1716            assert_eq!(config.amsgrad, restored_binary.config().amsgrad);
1717        }
1718    }
1719
1720    #[test]
1721    fn test_serializable_edge_cases() {
1722        // Test optimizer with no parameters
1723        let empty_optimizer = Adam::new();
1724        let json = <Adam as crate::serialization::Serializable>::to_json(&empty_optimizer).unwrap();
1725        let restored = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1726        assert_eq!(
1727            empty_optimizer.saved_parameter_count(),
1728            restored.saved_parameter_count()
1729        );
1730        assert_eq!(empty_optimizer.step_count, restored.step_count);
1731
1732        let binary =
1733            <Adam as crate::serialization::Serializable>::to_binary(&empty_optimizer).unwrap();
1734        let restored = <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1735        assert_eq!(
1736            empty_optimizer.saved_parameter_count(),
1737            restored.saved_parameter_count()
1738        );
1739        assert_eq!(empty_optimizer.step_count, restored.step_count);
1740
1741        // Test optimizer with extreme configuration values
1742        let extreme_config = AdamConfig {
1743            learning_rate: 1e-10, // Very small learning rate
1744            beta1: 0.999999,      // Very high beta1
1745            beta2: 0.000001,      // Very low beta2
1746            eps: 1e-15,           // Very small epsilon
1747            weight_decay: 1e-1,   // High weight decay
1748            amsgrad: true,
1749        };
1750
1751        let weight = Tensor::ones(vec![1]).with_requires_grad();
1752        let mut extreme_optimizer = Adam::with_config(extreme_config.clone());
1753        extreme_optimizer.add_parameter(&weight);
1754
1755        let json =
1756            <Adam as crate::serialization::Serializable>::to_json(&extreme_optimizer).unwrap();
1757        let restored = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1758        assert_eq!(
1759            extreme_config.learning_rate,
1760            restored.config().learning_rate
1761        );
1762        assert_eq!(extreme_config.beta1, restored.config().beta1);
1763        assert_eq!(extreme_config.beta2, restored.config().beta2);
1764        assert_eq!(extreme_config.eps, restored.config().eps);
1765        assert_eq!(extreme_config.weight_decay, restored.config().weight_decay);
1766        assert_eq!(extreme_config.amsgrad, restored.config().amsgrad);
1767
1768        let binary =
1769            <Adam as crate::serialization::Serializable>::to_binary(&extreme_optimizer).unwrap();
1770        let restored = <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1771        assert_eq!(
1772            extreme_config.learning_rate,
1773            restored.config().learning_rate
1774        );
1775        assert_eq!(extreme_config.beta1, restored.config().beta1);
1776        assert_eq!(extreme_config.beta2, restored.config().beta2);
1777        assert_eq!(extreme_config.eps, restored.config().eps);
1778        assert_eq!(extreme_config.weight_decay, restored.config().weight_decay);
1779        assert_eq!(extreme_config.amsgrad, restored.config().amsgrad);
1780    }
1781}