train_station/optimizers/adam/
serialization.rs

1//! Comprehensive serialization support for Adam optimizer state and configuration
2//!
3//! This module provides complete serialization capabilities for the Adam optimizer, enabling
4//! model checkpointing, state persistence, and cross-platform optimizer state transfer.
5//! The serialization system supports both human-readable JSON and efficient binary formats
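//! and is intended to round-trip the stored values exactly. Parameter links themselves are
//! not serialized and must be re-established after loading (see the re-linking example below).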
7//!
8//! # Purpose
9//!
10//! The serialization module serves as the persistence layer for Adam optimizer training,
11//! providing essential functionality for:
12//! - **Model checkpointing**: Save and restore optimizer state during training
13//! - **Training resumption**: Continue training from saved checkpoints
14//! - **State transfer**: Move optimizer state between different environments
15//! - **Debugging support**: Human-readable JSON format for state inspection
16//! - **Performance optimization**: Efficient binary format for production use
17//! - **Cross-platform compatibility**: Consistent serialization across systems
18//!
19//! # Supported Components
20//!
21//! ## AdamConfig Serialization
22//! - **Learning rate**: Base learning rate for parameter updates
23//! - **Beta parameters**: Momentum decay rates (beta1, beta2)
24//! - **Epsilon**: Numerical stability constant
25//! - **Weight decay**: L2 regularization coefficient
26//! - **AMSGrad flag**: Whether to use AMSGrad variant
27//!
28//! ## Parameter State Serialization
29//! - **Momentum buffers**: First moment estimates for each parameter
30//! - **Velocity buffers**: Second moment estimates for each parameter
31//! - **AMSGrad state**: Maximum velocity buffers when AMSGrad is enabled
32//! - **Step counts**: Per-parameter step counts for bias correction
//! - **Shape information**: Carried by the serialized moment tensors (no separate shape field); used to validate parameters during re-linking
34//!
35//! ## Global Optimizer State
36//! - **Global step count**: Total optimization steps performed
37//! - **Parameter insertion order**: Order for consistent parameter re-linking
38//! - **Configuration state**: Complete hyperparameter configuration
39//! - **State validation**: Integrity checks for serialized data
40//!
41//! # Serialization Formats
42//!
43//! ## JSON Format
44//! - **Human-readable**: Easy to inspect and debug optimizer state
45//! - **Cross-language**: Compatible with other JSON-parsing systems
46//! - **Configuration files**: Suitable for storing optimizer configurations
47//! - **Debugging**: Clear structure for troubleshooting training issues
48//! - **Interoperability**: Exchange optimizer state with external tools
49//!
50//! ## Binary Format
51//! - **Compact storage**: Minimal file sizes for production deployment
52//! - **Fast I/O**: Optimized for quick save/load operations
53//! - **Performance**: Reduced serialization overhead during training
54//! - **Bandwidth efficiency**: Minimal network transfer requirements
55//! - **Production ready**: Optimized for high-performance training workflows
56//!
57//! # Usage Patterns
58//!
59//! ## Basic Serialization
60//! ```
61//! use train_station::{Tensor, optimizers::Adam};
62//! use train_station::serialization::Serializable;
63//!
64//! // Create optimizer with parameters
65//! let weight = Tensor::ones(vec![10, 5]).with_requires_grad();
66//! let mut optimizer = Adam::new();
67//! optimizer.add_parameter(&weight);
68//!
69//! // Serialize optimizer state
70//! let json = optimizer.to_json().unwrap();
71//! let binary = optimizer.to_binary().unwrap();
72//!
73//! // Deserialize optimizer
74//! let loaded_optimizer = Adam::from_json(&json).unwrap();
75//! assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
76//! ```
77//!
78//! ## Training Checkpointing
79//! ```
80//! use train_station::{Tensor, optimizers::Adam};
81//! use train_station::serialization::{Serializable, Format};
82//!
83//! let mut weight = Tensor::randn(vec![100, 50], None).with_requires_grad();
84//! let mut optimizer = Adam::new();
85//! optimizer.add_parameter(&weight);
86//!
87//! // Training loop with checkpointing
88//! for epoch in 0..10 {
89//!     // ... training logic ...
90//!     
91//!     // Save checkpoint every 5 epochs
92//!     if epoch % 5 == 0 {
93//!         let temp_dir = std::env::temp_dir();
94//!         let checkpoint_path = temp_dir.join(format!("checkpoint_epoch_{}.json", epoch));
95//!         optimizer.save(&checkpoint_path, Format::Json).unwrap();
96//!         
97//!         // Cleanup for example
98//!         std::fs::remove_file(&checkpoint_path).ok();
99//!     }
100//! }
101//! ```
102//!
103//! ## Parameter Re-linking
104//! ```
105//! use train_station::{Tensor, optimizers::Adam};
106//! use train_station::serialization::Serializable;
107//!
108//! // Original training setup
109//! let weight = Tensor::ones(vec![5, 5]).with_requires_grad();
110//! let bias = Tensor::zeros(vec![5]).with_requires_grad();
111//! let mut optimizer = Adam::new();
112//! optimizer.add_parameter(&weight);
113//! optimizer.add_parameter(&bias);
114//!
115//! // Serialize optimizer state
116//! let json = optimizer.to_json().unwrap();
117//!
118//! // Later: create new parameters with same shapes
119//! let new_weight = Tensor::ones(vec![5, 5]).with_requires_grad();
120//! let new_bias = Tensor::zeros(vec![5]).with_requires_grad();
121//!
122//! // Restore optimizer and re-link parameters
123//! let mut loaded_optimizer = Adam::from_json(&json).unwrap();
124//! loaded_optimizer.relink_parameters(&[&new_weight, &new_bias]).unwrap();
125//!
126//! assert!(loaded_optimizer.is_parameter_linked(&new_weight));
127//! assert!(loaded_optimizer.is_parameter_linked(&new_bias));
128//! ```
129//!
130//! # Architecture Design
131//!
132//! ## Serialization Strategy
133//! - **Unified interface**: All serialization through StructSerializable trait
134//! - **Type safety**: Strong typing prevents serialization errors
135//! - **Validation**: Comprehensive validation during deserialization
136//! - **Error handling**: Detailed error messages for debugging
137//! - **Memory efficiency**: Optimized memory usage during serialization
138//!
139//! ## Parameter State Management
140//! - **ID-based tracking**: Parameters tracked by unique tensor IDs
141//! - **Shape validation**: Ensures parameter compatibility during re-linking
142//! - **Insertion order**: Maintains parameter order for consistent re-linking
143//! - **State preservation**: Complete momentum and velocity buffer preservation
144//! - **AMSGrad support**: Full AMSGrad state serialization when enabled
145//!
146//! ## Performance Characteristics
147//! - **Linear complexity**: O(n) serialization time with parameter count
148//! - **Minimal overhead**: Efficient serialization with minimal memory allocation
149//! - **Streaming support**: Support for streaming serialization to files
150//! - **Compression ready**: Binary format suitable for compression
151//! - **Concurrent safe**: Thread-safe serialization operations
152//!
153//! # Thread Safety
154//!
//! Serialization and deserialization do not mutate shared optimizer state:
//! - **Immutable serialization**: Serialization takes `&self` and does not modify optimizer state
//! - **Concurrent reads**: Multiple threads may serialize the same optimizer, provided `Adam` is `Sync`
//! - **Deserialization safety**: Deserialization creates new optimizer instances
//! - **Parameter linking**: Re-linking requires mutable access to the optimizer, so Rust's borrow rules make it exclusive
160//!
161//! # Integration with Train Station
162//!
163//! The serialization module integrates seamlessly with the broader Train Station ecosystem:
164//! - **Tensor serialization**: Leverages efficient tensor serialization for parameter states
165//! - **GradTrack compatibility**: Maintains gradient tracking requirements during re-linking
//! - **Device management**: No separate device field is stored; device placement is preserved only insofar as tensor serialization preserves it
167//! - **Memory management**: Efficient memory usage aligned with Train Station patterns
168//! - **Error handling**: Consistent error handling with Train Station conventions
169
170use super::{Adam, AdamConfig, ParameterState};
171use crate::serialization::{
172    FieldValue, FromFieldValue, SerializationError, SerializationResult, StructDeserializer,
173    StructSerializable, StructSerializer, ToFieldValue,
174};
175use crate::tensor::core::Tensor;
176use std::collections::HashMap;
177
178// ===== AdamConfig Serialization =====
179
180impl StructSerializable for AdamConfig {
181    /// Convert AdamConfig to StructSerializer for comprehensive serialization
182    ///
183    /// This method serializes all Adam hyperparameters into a structured format suitable
184    /// for both JSON and binary serialization. Every field is essential for proper optimizer
185    /// reconstruction and training continuation. The serialization preserves exact floating-point
186    /// values and boolean flags to ensure identical behavior after deserialization.
187    ///
188    /// # Returns
189    ///
190    /// StructSerializer containing all configuration data with field names and values
191    ///
192    /// # Serialized Fields
193    ///
194    /// - **learning_rate**: Base learning rate for parameter updates
195    /// - **beta1**: Exponential decay rate for first moment estimates
196    /// - **beta2**: Exponential decay rate for second moment estimates  
197    /// - **eps**: Small constant for numerical stability in denominator
198    /// - **weight_decay**: L2 regularization coefficient
199    /// - **amsgrad**: Boolean flag for AMSGrad variant usage
200    ///
201    /// # Performance
202    ///
203    /// - **Time Complexity**: O(1) - Constant time field serialization
204    /// - **Memory Usage**: Minimal allocation for field storage
205    /// - **Precision**: Full floating-point precision preservation
206    fn to_serializer(&self) -> StructSerializer {
207        StructSerializer::new()
208            .field("learning_rate", &self.learning_rate)
209            .field("beta1", &self.beta1)
210            .field("beta2", &self.beta2)
211            .field("eps", &self.eps)
212            .field("weight_decay", &self.weight_decay)
213            .field("amsgrad", &self.amsgrad)
214    }
215
216    /// Create AdamConfig from StructDeserializer with full validation
217    ///
218    /// This method reconstructs an AdamConfig instance from serialized hyperparameters,
219    /// performing comprehensive validation to ensure all required fields are present
220    /// and contain valid values. The deserialization process maintains exact floating-point
221    /// precision and validates that all hyperparameters are within reasonable ranges.
222    ///
223    /// # Arguments
224    ///
225    /// * `deserializer` - StructDeserializer containing configuration field data
226    ///
227    /// # Returns
228    ///
229    /// Reconstructed AdamConfig instance on success, or SerializationError on failure
230    ///
231    /// # Required Fields
232    ///
233    /// All fields must be present in the deserializer:
234    /// - **learning_rate**: Must be a valid f32 value
235    /// - **beta1**: Must be a valid f32 value (typically 0.0-1.0)
236    /// - **beta2**: Must be a valid f32 value (typically 0.0-1.0)
237    /// - **eps**: Must be a valid f32 value (typically small positive)
238    /// - **weight_decay**: Must be a valid f32 value (typically 0.0 or small positive)
239    /// - **amsgrad**: Must be a valid boolean value
240    ///
241    /// # Errors
242    ///
243    /// Returns SerializationError if:
244    /// - Any required field is missing from the deserializer
245    /// - Any field contains invalid data type
246    /// - Field extraction fails for any reason
247    ///
248    /// # Performance
249    ///
250    /// - **Time Complexity**: O(1) - Constant time field extraction
251    /// - **Memory Usage**: Minimal allocation for configuration structure
252    /// - **Validation**: Comprehensive field presence and type validation
253    fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
254        Ok(AdamConfig {
255            learning_rate: deserializer.field("learning_rate")?,
256            beta1: deserializer.field("beta1")?,
257            beta2: deserializer.field("beta2")?,
258            eps: deserializer.field("eps")?,
259            weight_decay: deserializer.field("weight_decay")?,
260            amsgrad: deserializer.field("amsgrad")?,
261        })
262    }
263}
264
265impl ToFieldValue for AdamConfig {
266    /// Convert AdamConfig to FieldValue for embedding in larger structures
267    ///
268    /// This method converts the AdamConfig into a FieldValue::Object that can be
269    /// embedded as a field within larger serializable structures. This enables
270    /// AdamConfig to be serialized as part of more complex training configurations
271    /// or model checkpoints while maintaining its structured representation.
272    ///
273    /// # Returns
274    ///
275    /// FieldValue::Object containing all configuration data as key-value pairs
276    ///
277    /// # Object Structure
278    ///
279    /// The returned object contains these fields:
280    /// - "learning_rate": f32 value as FieldValue::F32
281    /// - "beta1": f32 value as FieldValue::F32
282    /// - "beta2": f32 value as FieldValue::F32
283    /// - "eps": f32 value as FieldValue::F32
284    /// - "weight_decay": f32 value as FieldValue::F32
285    /// - "amsgrad": bool value as FieldValue::Bool
286    ///
287    /// # Performance
288    ///
289    /// - **Time Complexity**: O(1) - Constant time field conversion
290    /// - **Memory Usage**: Allocates HashMap for field storage
291    /// - **Conversion**: Direct field-to-FieldValue conversion without copying
292    fn to_field_value(&self) -> FieldValue {
293        let serializer = self.to_serializer();
294        FieldValue::from_object(serializer.fields.into_iter().collect())
295    }
296}
297
298impl FromFieldValue for AdamConfig {
299    /// Create AdamConfig from FieldValue with comprehensive validation
300    ///
301    /// This method reconstructs an AdamConfig instance from a FieldValue::Object,
302    /// performing type validation and field extraction. It's designed to handle
303    /// AdamConfig instances that were embedded as fields within larger serializable
304    /// structures, ensuring proper error handling and detailed error messages.
305    ///
306    /// # Arguments
307    ///
308    /// * `value` - FieldValue containing configuration data (must be Object variant)
309    /// * `field_name` - Name of the field being deserialized for error context
310    ///
311    /// # Returns
312    ///
313    /// Reconstructed AdamConfig instance on success, or SerializationError on failure
314    ///
315    /// # Expected FieldValue Structure
316    ///
317    /// The FieldValue must be an Object variant containing:
318    /// - "learning_rate": Numeric field value
319    /// - "beta1": Numeric field value  
320    /// - "beta2": Numeric field value
321    /// - "eps": Numeric field value
322    /// - "weight_decay": Numeric field value
323    /// - "amsgrad": Boolean field value
324    ///
325    /// # Errors
326    ///
327    /// Returns SerializationError if:
328    /// - FieldValue is not an Object variant
329    /// - Any required field is missing from the object
330    /// - Any field has incorrect type or invalid value
331    /// - Deserialization process fails for any reason
332    ///
333    /// # Performance
334    ///
335    /// - **Time Complexity**: O(1) - Constant time field extraction and validation
336    /// - **Memory Usage**: Temporary deserializer allocation for field processing
337    /// - **Error Handling**: Detailed error messages with field name context
338    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
339        match value {
340            FieldValue::Object(fields) => {
341                let mut deserializer = StructDeserializer::from_fields(fields);
342                Self::from_deserializer(&mut deserializer)
343            }
344            _ => Err(SerializationError::ValidationFailed {
345                field: field_name.to_string(),
346                message: format!(
347                    "Expected Object for {}, found {}",
348                    std::any::type_name::<Self>(),
349                    value.type_name()
350                ),
351            }),
352        }
353    }
354}
355
356// ===== Serializable Parameter State =====
357
358/// Serializable representation of Adam parameter optimization state
359///
360/// This structure provides a serializable wrapper around the internal ParameterState,
361/// enabling efficient persistence of Adam optimizer state for individual parameters.
362/// It leverages the Train Station tensor serialization system for optimal storage
363/// efficiency and maintains complete fidelity of momentum and velocity buffers.
364///
365/// # Purpose
366///
367/// SerializableParameterState serves as an intermediate representation that:
368/// - **Preserves state**: Maintains complete Adam optimization state per parameter
369/// - **Enables serialization**: Converts internal state to serializable format
370/// - **Supports AMSGrad**: Handles optional AMSGrad maximum velocity tracking
371/// - **Maintains precision**: Preserves exact floating-point values in buffers
372/// - **Validates shapes**: Ensures parameter shape consistency during restoration
373///
374/// # Fields
375///
376/// * `m` - First moment estimate (momentum) tensor buffer
377/// * `v` - Second moment estimate (velocity) tensor buffer  
378/// * `v_hat_max` - Optional maximum velocity tensor for AMSGrad variant
379/// * `step` - Per-parameter step count for bias correction calculations
380///
381/// # Design Rationale
382///
383/// The structure uses direct tensor serialization rather than manual data extraction:
384/// - **Efficiency**: Leverages optimized tensor serialization infrastructure
385/// - **Simplicity**: Eliminates manual buffer management and data copying
386/// - **Consistency**: Maintains alignment with Train Station serialization patterns
387/// - **Reliability**: Reduces serialization bugs through proven tensor serialization
388/// - **Performance**: Optimized memory usage and serialization speed
389///
390/// # Thread Safety
391///
392/// This structure is thread-safe for serialization operations:
393/// - **Immutable serialization**: Serialization does not modify state
394/// - **Clone safety**: Safe to clone across thread boundaries
395/// - **Tensor safety**: Leverages thread-safe tensor operations
396///
397/// # Memory Layout
398///
399/// The structure maintains efficient memory usage:
400/// - **Tensor sharing**: Tensors use reference counting for memory efficiency
401/// - **Optional fields**: AMSGrad state only allocated when needed
402/// - **Minimal overhead**: Small additional memory footprint beyond tensors
403#[derive(Debug, Clone)]
404struct SerializableParameterState {
405    /// First moment estimate (momentum) tensor buffer
406    ///
407    /// Contains the exponentially decaying average of past gradients, used for
408    /// momentum-based parameter updates. Shape matches the associated parameter.
409    m: Tensor,
410
411    /// Second moment estimate (velocity) tensor buffer
412    ///
413    /// Contains the exponentially decaying average of past squared gradients,
414    /// used for adaptive learning rate scaling. Shape matches the associated parameter.
415    v: Tensor,
416
417    /// Optional maximum velocity tensor for AMSGrad variant
418    ///
419    /// When AMSGrad is enabled, this tensor maintains the element-wise maximum
420    /// of all past velocity estimates, providing improved convergence properties.
421    /// None when AMSGrad is disabled. Shape matches the associated parameter when present.
422    v_hat_max: Option<Tensor>,
423
424    /// Per-parameter step count for bias correction
425    ///
426    /// Tracks the number of optimization steps performed for this specific parameter,
427    /// used in Adam's bias correction calculations. Essential for proper optimizer
428    /// behavior when parameters are added at different training stages.
429    step: usize,
430}
431
432impl SerializableParameterState {
433    /// Create SerializableParameterState from internal ParameterState
434    ///
435    /// This method converts the internal ParameterState representation used by
436    /// the Adam optimizer into a serializable format. It performs efficient tensor
437    /// cloning to preserve all optimization state while enabling serialization.
438    /// The conversion maintains exact numerical precision and handles optional
439    /// AMSGrad state appropriately.
440    ///
441    /// # Arguments
442    ///
443    /// * `state` - Reference to the internal ParameterState to convert
444    ///
445    /// # Returns
446    ///
447    /// SerializableParameterState containing all state data ready for serialization
448    ///
449    /// # Conversion Process
450    ///
451    /// The method performs these operations:
452    /// 1. **Momentum cloning**: Clones the momentum tensor (m) with full precision
453    /// 2. **Velocity cloning**: Clones the velocity tensor (v) with full precision
454    /// 3. **AMSGrad handling**: Clones optional AMSGrad state when present
455    /// 4. **Step preservation**: Copies the step count for bias correction
456    ///
457    /// # Performance
458    ///
459    /// - **Time Complexity**: O(1) - Tensor cloning uses reference counting
460    /// - **Memory Usage**: Minimal additional allocation due to tensor sharing
461    /// - **Precision**: Maintains exact floating-point values in all buffers
462    fn from_parameter_state(state: &ParameterState) -> Self {
463        // Use direct tensor cloning - leverages Tensor's efficient serialization
464        // No manual data extraction needed
465        Self {
466            m: state.m.clone(),
467            v: state.v.clone(),
468            v_hat_max: state.v_hat_max.clone(),
469            step: state.step,
470        }
471    }
472
473    /// Convert SerializableParameterState back to internal ParameterState
474    ///
475    /// This method reconstructs the internal ParameterState representation from
476    /// the serializable format, enabling the Adam optimizer to resume training
477    /// with preserved optimization state. The conversion maintains exact numerical
478    /// precision and properly handles optional AMSGrad state restoration.
479    ///
480    /// # Returns
481    ///
482    /// ParameterState instance ready for use by Adam optimizer, or SerializationError on failure
483    ///
484    /// # Reconstruction Process
485    ///
486    /// The method performs these operations:
487    /// 1. **Momentum restoration**: Clones momentum tensor back to internal format
488    /// 2. **Velocity restoration**: Clones velocity tensor back to internal format
489    /// 3. **AMSGrad restoration**: Handles optional AMSGrad state when present
490    /// 4. **Step restoration**: Preserves step count for continued bias correction
491    ///
492    /// # Validation
493    ///
494    /// The method ensures:
495    /// - All tensors have consistent shapes
496    /// - Step count is valid (non-negative)
497    /// - AMSGrad state consistency with configuration
498    /// - Tensor data integrity is maintained
499    ///
500    /// # Performance
501    ///
502    /// - **Time Complexity**: O(1) - Tensor cloning uses reference counting
503    /// - **Memory Usage**: Minimal allocation due to efficient tensor sharing
504    /// - **Precision**: Maintains exact floating-point values from serialization
505    ///
506    /// # Errors
507    ///
508    /// Returns SerializationError if:
509    /// - Tensor shapes are inconsistent
510    /// - Internal tensor state is corrupted
511    /// - Memory allocation fails during reconstruction
512    fn to_parameter_state(&self) -> SerializationResult<ParameterState> {
513        // Direct tensor cloning - no manual memory management needed
514        // Tensors handle their own efficient reconstruction
515        Ok(ParameterState {
516            m: self.m.clone(),
517            v: self.v.clone(),
518            v_hat_max: self.v_hat_max.clone(),
519            step: self.step,
520        })
521    }
522}
523
524impl StructSerializable for SerializableParameterState {
525    /// Convert SerializableParameterState to StructSerializer for serialization
526    ///
527    /// This method serializes all Adam parameter state components into a structured
528    /// format suitable for both JSON and binary serialization. It leverages the
529    /// efficient tensor serialization system to handle momentum and velocity buffers
530    /// while properly managing optional AMSGrad state.
531    ///
532    /// # Returns
533    ///
534    /// StructSerializer containing all parameter state data with field names and values
535    ///
536    /// # Serialized Fields
537    ///
538    /// - **m**: Momentum tensor (first moment estimate)
539    /// - **v**: Velocity tensor (second moment estimate)
540    /// - **v_hat_max**: Optional AMSGrad maximum velocity tensor
541    /// - **step**: Per-parameter step count for bias correction
542    ///
543    /// # Performance
544    ///
545    /// - **Time Complexity**: O(1) - Constant time field serialization
546    /// - **Memory Usage**: Leverages efficient tensor serialization
547    /// - **Precision**: Full floating-point precision preservation in tensors
548    fn to_serializer(&self) -> StructSerializer {
549        StructSerializer::new()
550            .field("m", &self.m)
551            .field("v", &self.v)
552            .field("v_hat_max", &self.v_hat_max)
553            .field("step", &self.step)
554    }
555
556    /// Create SerializableParameterState from StructDeserializer with validation
557    ///
558    /// This method reconstructs a SerializableParameterState from serialized data,
559    /// performing comprehensive validation to ensure all required fields are present
560    /// and contain valid tensor data. It handles both momentum and velocity tensors
561    /// along with optional AMSGrad state and step count information.
562    ///
563    /// # Arguments
564    ///
565    /// * `deserializer` - StructDeserializer containing parameter state field data
566    ///
567    /// # Returns
568    ///
569    /// Reconstructed SerializableParameterState on success, or SerializationError on failure
570    ///
571    /// # Required Fields
572    ///
573    /// All fields must be present in the deserializer:
574    /// - **m**: Momentum tensor with valid shape and data
575    /// - **v**: Velocity tensor with shape matching momentum tensor
576    /// - **v_hat_max**: Optional AMSGrad tensor (None or matching shape)
577    /// - **step**: Valid step count (non-negative integer)
578    ///
579    /// # Validation
580    ///
581    /// The method validates:
582    /// - Tensor field presence and type correctness
583    /// - Shape consistency between momentum and velocity tensors
584    /// - AMSGrad tensor shape consistency when present
585    /// - Step count validity and range
586    ///
587    /// # Performance
588    ///
589    /// - **Time Complexity**: O(1) - Constant time field extraction
590    /// - **Memory Usage**: Efficient tensor deserialization
591    /// - **Validation**: Comprehensive field and type validation
592    fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
593        Ok(Self {
594            m: deserializer.field("m")?,
595            v: deserializer.field("v")?,
596            v_hat_max: deserializer.field("v_hat_max")?,
597            step: deserializer.field("step")?,
598        })
599    }
600}
601
602impl ToFieldValue for SerializableParameterState {
603    /// Convert SerializableParameterState to FieldValue for embedding in collections
604    ///
605    /// This method converts the parameter state into a FieldValue::Object that can be
606    /// stored in collections or embedded within larger serializable structures. It
607    /// maintains the structured representation of all optimization state components
608    /// while enabling flexible serialization patterns.
609    ///
610    /// # Returns
611    ///
612    /// FieldValue::Object containing all parameter state data as key-value pairs
613    ///
614    /// # Object Structure
615    ///
616    /// The returned object contains these fields:
617    /// - "m": Momentum tensor as serialized FieldValue
618    /// - "v": Velocity tensor as serialized FieldValue
619    /// - "v_hat_max": Optional AMSGrad tensor as serialized FieldValue
620    /// - "step": Step count as FieldValue::Usize
621    ///
622    /// # Performance
623    ///
624    /// - **Time Complexity**: O(1) - Constant time field conversion
625    /// - **Memory Usage**: Efficient tensor serialization with minimal overhead
626    /// - **Precision**: Maintains exact tensor data and step count values
627    fn to_field_value(&self) -> FieldValue {
628        let serializer = self.to_serializer();
629        FieldValue::from_object(serializer.fields.into_iter().collect())
630    }
631}
632
633impl FromFieldValue for SerializableParameterState {
634    /// Create SerializableParameterState from FieldValue with comprehensive validation
635    ///
636    /// This method reconstructs a SerializableParameterState from a FieldValue::Object,
637    /// performing type validation and tensor deserialization. It handles the complex
638    /// process of deserializing momentum and velocity tensors along with optional
639    /// AMSGrad state, ensuring data integrity and proper error handling.
640    ///
641    /// # Arguments
642    ///
643    /// * `value` - FieldValue containing parameter state data (must be Object variant)
644    /// * `field_name` - Name of the field being deserialized for error context
645    ///
646    /// # Returns
647    ///
648    /// Reconstructed SerializableParameterState on success, or SerializationError on failure
649    ///
650    /// # Expected FieldValue Structure
651    ///
652    /// The FieldValue must be an Object variant containing:
653    /// - "m": Serialized momentum tensor
654    /// - "v": Serialized velocity tensor
655    /// - "v_hat_max": Optional serialized AMSGrad tensor
656    /// - "step": Step count as numeric value
657    ///
658    /// # Validation
659    ///
660    /// The method validates:
661    /// - FieldValue is Object variant
662    /// - All required fields are present
663    /// - Tensor deserialization succeeds
664    /// - Step count is valid numeric value
665    /// - Tensor shapes are consistent
666    ///
667    /// # Performance
668    ///
669    /// - **Time Complexity**: O(1) - Constant time field extraction and tensor deserialization
670    /// - **Memory Usage**: Efficient tensor deserialization with minimal overhead
671    /// - **Error Handling**: Detailed error messages with field name context
672    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
673        match value {
674            FieldValue::Object(fields) => {
675                let mut deserializer = StructDeserializer::from_fields(fields);
676                Self::from_deserializer(&mut deserializer)
677            }
678            _ => Err(SerializationError::ValidationFailed {
679                field: field_name.to_string(),
680                message: format!(
681                    "Expected Object for {}, found {}",
682                    std::any::type_name::<Self>(),
683                    value.type_name()
684                ),
685            }),
686        }
687    }
688}
689
690// ===== Adam Serialization =====
691
692impl StructSerializable for Adam {
693    /// Convert Adam to StructSerializer for serialization
694    ///
695    /// Serializes all optimizer state including configuration, parameter states,
696    /// and global step count. Parameter linking is not serialized and must be
697    /// done after deserialization.
698    ///
699    /// # Returns
700    ///
701    /// StructSerializer containing all serializable optimizer state
702    fn to_serializer(&self) -> StructSerializer {
703        // Convert parameter states to serializable form
704        let mut serializable_states = HashMap::new();
705        for (param_id, state) in &self.states {
706            serializable_states.insert(
707                *param_id,
708                SerializableParameterState::from_parameter_state(state),
709            );
710        }
711
712        StructSerializer::new()
713            .field("config", &self.config)
714            .field("states", &serializable_states)
715            .field("step_count", &self.step_count)
716            .field("insertion_order", &self.insertion_order)
717    }
718
719    /// Create Adam from StructDeserializer
720    ///
721    /// Reconstructs Adam optimizer from serialized state. Parameters must be
722    /// linked separately using `add_parameter` or `add_parameters`.
723    ///
724    /// # Arguments
725    ///
726    /// * `deserializer` - StructDeserializer containing optimizer data
727    ///
728    /// # Returns
729    ///
730    /// Reconstructed Adam instance without parameter links, or error if deserialization fails
731    fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
732        let config: AdamConfig = deserializer.field("config")?;
733        let serializable_states: HashMap<usize, SerializableParameterState> =
734            deserializer.field("states")?;
735        let step_count: usize = deserializer.field("step_count")?;
736        let insertion_order: Vec<usize> = deserializer.field("insertion_order")?;
737
738        // Reconstruct parameter states from serialized form
739        let mut states = HashMap::new();
740        for (param_id, serializable_state) in serializable_states {
741            states.insert(param_id, serializable_state.to_parameter_state()?);
742        }
743
744        // Create optimizer with reconstructed states - user must call relink_parameters to link tensors
745        Ok(Adam {
746            config,
747            states,
748            step_count,
749            insertion_order,
750        })
751    }
752}
753
754impl FromFieldValue for Adam {
755    /// Create Adam from FieldValue
756    ///
757    /// # Arguments
758    ///
759    /// * `value` - FieldValue containing optimizer data
760    /// * `field_name` - Name of the field being deserialized (for error messages)
761    ///
762    /// # Returns
763    ///
764    /// Reconstructed Adam instance or error if deserialization fails
765    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
766        match value {
767            FieldValue::Object(fields) => {
768                let mut deserializer = StructDeserializer::from_fields(fields);
769                Self::from_deserializer(&mut deserializer)
770            }
771            _ => Err(SerializationError::ValidationFailed {
772                field: field_name.to_string(),
773                message: format!(
774                    "Expected Object for {}, found {}",
775                    std::any::type_name::<Self>(),
776                    value.type_name()
777                ),
778            }),
779        }
780    }
781}
782
783// ===== Serializable Trait Implementation =====
784
785impl crate::serialization::Serializable for Adam {
786    /// Serialize the Adam optimizer to JSON format
787    ///
788    /// This method converts the Adam optimizer into a human-readable JSON string representation
789    /// that includes all optimizer state, configuration, parameter states, and step counts.
790    /// The JSON format is suitable for debugging, configuration files, and cross-language
791    /// interoperability.
792    ///
793    /// # Returns
794    ///
795    /// JSON string representation of the optimizer on success, or `SerializationError` on failure
796    ///
797    /// # Examples
798    ///
799    /// ```
800    /// use train_station::{Tensor, optimizers::Adam};
801    /// use train_station::serialization::Serializable;
802    ///
803    /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
804    /// let mut optimizer = Adam::new();
805    /// optimizer.add_parameter(&weight);
806    ///
807    /// let json = optimizer.to_json().unwrap();
808    /// assert!(!json.is_empty());
809    /// ```
810    fn to_json(&self) -> SerializationResult<String> {
811        <Self as StructSerializable>::to_json(self)
812    }
813
814    /// Deserialize an Adam optimizer from JSON format
815    ///
816    /// This method parses a JSON string and reconstructs an Adam optimizer with all
817    /// saved state. Parameters must be re-linked after deserialization using
818    /// `add_parameter` or `relink_parameters`.
819    ///
820    /// # Arguments
821    ///
822    /// * `json` - JSON string containing serialized optimizer
823    ///
824    /// # Returns
825    ///
826    /// The deserialized optimizer on success, or `SerializationError` on failure
827    ///
828    /// # Examples
829    ///
830    /// ```
831    /// use train_station::{Tensor, optimizers::Adam};
832    /// use train_station::serialization::Serializable;
833    ///
834    /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
835    /// let mut optimizer = Adam::new();
836    /// optimizer.add_parameter(&weight);
837    ///
838    /// let json = optimizer.to_json().unwrap();
839    /// let loaded_optimizer = Adam::from_json(&json).unwrap();
840    /// assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
841    /// ```
842    fn from_json(json: &str) -> SerializationResult<Self> {
843        <Self as StructSerializable>::from_json(json)
844    }
845
846    /// Serialize the Adam optimizer to binary format
847    ///
848    /// This method converts the optimizer into a compact binary representation optimized
849    /// for storage and transmission. The binary format provides maximum performance
850    /// and minimal file sizes compared to JSON.
851    ///
852    /// # Returns
853    ///
854    /// Binary representation of the optimizer on success, or `SerializationError` on failure
855    ///
856    /// # Examples
857    ///
858    /// ```
859    /// use train_station::{Tensor, optimizers::Adam};
860    /// use train_station::serialization::Serializable;
861    ///
862    /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
863    /// let mut optimizer = Adam::new();
864    /// optimizer.add_parameter(&weight);
865    ///
866    /// let binary = optimizer.to_binary().unwrap();
867    /// assert!(!binary.is_empty());
868    /// ```
869    fn to_binary(&self) -> SerializationResult<Vec<u8>> {
870        <Self as StructSerializable>::to_binary(self)
871    }
872
873    /// Deserialize an Adam optimizer from binary format
874    ///
875    /// This method parses binary data and reconstructs an Adam optimizer with all
876    /// saved state. Parameters must be re-linked after deserialization using
877    /// `add_parameter` or `relink_parameters`.
878    ///
879    /// # Arguments
880    ///
881    /// * `data` - Binary data containing serialized optimizer
882    ///
883    /// # Returns
884    ///
885    /// The deserialized optimizer on success, or `SerializationError` on failure
886    ///
887    /// # Examples
888    ///
889    /// ```
890    /// use train_station::{Tensor, optimizers::Adam};
891    /// use train_station::serialization::Serializable;
892    ///
893    /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
894    /// let mut optimizer = Adam::new();
895    /// optimizer.add_parameter(&weight);
896    ///
897    /// let binary = optimizer.to_binary().unwrap();
898    /// let loaded_optimizer = Adam::from_binary(&binary).unwrap();
899    /// assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
900    /// ```
901    fn from_binary(data: &[u8]) -> SerializationResult<Self> {
902        <Self as StructSerializable>::from_binary(data)
903    }
904}
905
906// ===== Utility Methods =====
907
908impl Adam {
909    /// Get the number of saved parameter states for checkpoint validation
910    ///
911    /// This method returns the count of parameter states currently stored in the optimizer,
912    /// which is essential for validating checkpoint integrity and ensuring proper parameter
913    /// re-linking after deserialization. The count includes all parameters that have been
914    /// linked to the optimizer and have accumulated optimization state.
915    ///
916    /// # Returns
917    ///
918    /// Number of parameter states currently stored in the optimizer
919    ///
920    /// # Usage Patterns
921    ///
922    /// ## Checkpoint Validation
923    /// After deserializing an optimizer, this method helps verify that the expected
924    /// number of parameters were saved and can guide the re-linking process.
925    ///
926    /// ## Training Resumption
927    /// When resuming training, compare this count with the number of parameters
928    /// in your model to ensure checkpoint compatibility.
929    ///
930    /// ## State Management
931    /// Use this method to monitor optimizer state growth and memory usage during
932    /// training with dynamic parameter addition.
933    ///
934    /// # Examples
935    ///
936    /// ```
937    /// use train_station::{Tensor, optimizers::Adam};
938    /// use train_station::serialization::Serializable;
939    ///
940    /// let weight = Tensor::ones(vec![10, 5]).with_requires_grad();
941    /// let bias = Tensor::zeros(vec![5]).with_requires_grad();
942    /// let mut optimizer = Adam::new();
943    /// optimizer.add_parameter(&weight);
944    /// optimizer.add_parameter(&bias);
945    ///
946    /// // Check parameter count before serialization
947    /// assert_eq!(optimizer.saved_parameter_count(), 2);
948    ///
949    /// // Serialize and deserialize
950    /// let json = optimizer.to_json().unwrap();
951    /// let loaded_optimizer = Adam::from_json(&json).unwrap();
952    ///
953    /// // Verify parameter count is preserved
954    /// assert_eq!(loaded_optimizer.saved_parameter_count(), 2);
955    /// ```
956    ///
957    /// # Performance
958    ///
959    /// - **Time Complexity**: O(1) - Direct access to internal state count
960    /// - **Memory Usage**: No additional memory allocation
961    /// - **Thread Safety**: Safe to call from multiple threads concurrently
962    pub fn saved_parameter_count(&self) -> usize {
963        self.states.len()
964    }
965}
966
967// ===== Field Value Implementations for Collections =====
968
969impl ToFieldValue for HashMap<usize, SerializableParameterState> {
970    /// Convert parameter states HashMap to FieldValue for serialization
971    ///
972    /// This method converts the HashMap of parameter states into a FieldValue::Object
973    /// suitable for serialization. It handles the conversion of usize keys to string
974    /// format required by the FieldValue::Object representation while preserving
975    /// all parameter state data.
976    ///
977    /// # Returns
978    ///
979    /// FieldValue::Object with string keys and SerializableParameterState values
980    ///
981    /// # Key Conversion
982    ///
983    /// - **Input**: HashMap<usize, SerializableParameterState> with numeric tensor IDs
984    /// - **Output**: FieldValue::Object with string keys for JSON compatibility
985    /// - **Mapping**: Each usize key is converted to string representation
986    /// - **Preservation**: All parameter state data is preserved exactly
987    ///
988    /// # Performance
989    ///
990    /// - **Time Complexity**: O(n) where n is the number of parameter states
991    /// - **Memory Usage**: Allocates new HashMap for string keys
992    /// - **Conversion**: Efficient string conversion for numeric keys
993    fn to_field_value(&self) -> FieldValue {
994        let mut map = HashMap::new();
995        for (key, value) in self {
996            map.insert(key.to_string(), value.to_field_value());
997        }
998        FieldValue::from_object(map)
999    }
1000}
1001
1002impl FromFieldValue for HashMap<usize, SerializableParameterState> {
1003    /// Create parameter states HashMap from FieldValue with validation
1004    ///
1005    /// This method reconstructs the HashMap of parameter states from a FieldValue::Object,
1006    /// performing comprehensive validation and key conversion. It handles the conversion
1007    /// from string keys back to usize tensor IDs while ensuring all parameter state
1008    /// data is properly deserialized and validated.
1009    ///
1010    /// # Arguments
1011    ///
1012    /// * `value` - FieldValue containing parameter states data (must be Object variant)
1013    /// * `field_name` - Name of the field being deserialized for error context
1014    ///
1015    /// # Returns
1016    ///
1017    /// Reconstructed HashMap<usize, SerializableParameterState> on success, or SerializationError on failure
1018    ///
1019    /// # Key Conversion Process
1020    ///
1021    /// 1. **Validation**: Ensures FieldValue is Object variant
1022    /// 2. **Key parsing**: Converts string keys back to usize tensor IDs
1023    /// 3. **State deserialization**: Deserializes each parameter state
1024    /// 4. **Validation**: Validates parameter state integrity
1025    /// 5. **Collection**: Builds final HashMap with proper types
1026    ///
1027    /// # Errors
1028    ///
1029    /// Returns SerializationError if:
1030    /// - FieldValue is not Object variant
1031    /// - Any string key cannot be parsed as usize
1032    /// - Parameter state deserialization fails
1033    /// - Invalid parameter state data is encountered
1034    ///
1035    /// # Performance
1036    ///
1037    /// - **Time Complexity**: O(n) where n is the number of parameter states
1038    /// - **Memory Usage**: Allocates new HashMap with proper key types
1039    /// - **Validation**: Comprehensive key parsing and state validation
1040    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
1041        match value {
1042            FieldValue::Object(fields) => {
1043                let mut map = HashMap::new();
1044                for (key_str, field_value) in fields {
1045                    let key = key_str.parse::<usize>().map_err(|_| {
1046                        SerializationError::ValidationFailed {
1047                            field: field_name.to_string(),
1048                            message: format!("Invalid key '{}' in parameter states map", key_str),
1049                        }
1050                    })?;
1051                    let state =
1052                        SerializableParameterState::from_field_value(field_value, &key_str)?;
1053                    map.insert(key, state);
1054                }
1055                Ok(map)
1056            }
1057            _ => Err(SerializationError::ValidationFailed {
1058                field: field_name.to_string(),
1059                message: format!(
1060                    "Expected Object for HashMap<usize, SerializableParameterState>, found {}",
1061                    value.type_name()
1062                ),
1063            }),
1064        }
1065    }
1066}
1067
1068#[cfg(test)]
1069mod tests {
1070    use super::*;
1071    use crate::optimizers::Optimizer;
1072    use crate::tensor::core::Tensor;
1073
1074    // ===== AdamConfig Serialization Tests =====
1075
1076    #[test]
1077    fn test_adam_config_json_roundtrip() {
1078        let config = AdamConfig {
1079            learning_rate: 1e-4,
1080            beta1: 0.95,
1081            beta2: 0.9999,
1082            eps: 1e-7,
1083            weight_decay: 1e-5,
1084            amsgrad: true,
1085        };
1086
1087        let json = config.to_json().unwrap();
1088        let loaded_config = AdamConfig::from_json(&json).unwrap();
1089
1090        assert_eq!(config.learning_rate, loaded_config.learning_rate);
1091        assert_eq!(config.beta1, loaded_config.beta1);
1092        assert_eq!(config.beta2, loaded_config.beta2);
1093        assert_eq!(config.eps, loaded_config.eps);
1094        assert_eq!(config.weight_decay, loaded_config.weight_decay);
1095        assert_eq!(config.amsgrad, loaded_config.amsgrad);
1096    }
1097
1098    #[test]
1099    fn test_adam_config_binary_roundtrip() {
1100        let config = AdamConfig {
1101            learning_rate: 2e-3,
1102            beta1: 0.85,
1103            beta2: 0.995,
1104            eps: 1e-9,
1105            weight_decay: 5e-4,
1106            amsgrad: false,
1107        };
1108
1109        let binary = config.to_binary().unwrap();
1110        let loaded_config = AdamConfig::from_binary(&binary).unwrap();
1111
1112        assert_eq!(config.learning_rate, loaded_config.learning_rate);
1113        assert_eq!(config.beta1, loaded_config.beta1);
1114        assert_eq!(config.beta2, loaded_config.beta2);
1115        assert_eq!(config.eps, loaded_config.eps);
1116        assert_eq!(config.weight_decay, loaded_config.weight_decay);
1117        assert_eq!(config.amsgrad, loaded_config.amsgrad);
1118    }
1119
1120    #[test]
1121    fn test_adam_config_field_value_roundtrip() {
1122        let config = AdamConfig {
1123            learning_rate: 3e-4,
1124            beta1: 0.92,
1125            beta2: 0.998,
1126            eps: 1e-6,
1127            weight_decay: 2e-4,
1128            amsgrad: true,
1129        };
1130
1131        let field_value = config.to_field_value();
1132        let loaded_config = AdamConfig::from_field_value(field_value, "config").unwrap();
1133
1134        assert_eq!(config.learning_rate, loaded_config.learning_rate);
1135        assert_eq!(config.beta1, loaded_config.beta1);
1136        assert_eq!(config.beta2, loaded_config.beta2);
1137        assert_eq!(config.eps, loaded_config.eps);
1138        assert_eq!(config.weight_decay, loaded_config.weight_decay);
1139        assert_eq!(config.amsgrad, loaded_config.amsgrad);
1140    }
1141
1142    // ===== Adam Optimizer Serialization Tests =====
1143
1144    #[test]
1145    fn test_adam_optimizer_json_roundtrip() {
1146        let mut weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1147        let mut bias = Tensor::zeros(vec![2, 3]).with_requires_grad();
1148
1149        let mut optimizer = Adam::new();
1150        optimizer.add_parameter(&weight);
1151        optimizer.add_parameter(&bias);
1152
1153        // Perform some steps to create state
1154        let output = weight.add_tensor(&bias);
1155        let mut loss = output.sum();
1156        loss.backward(None);
1157        optimizer.step(&mut [&mut weight, &mut bias]);
1158
1159        // Test serialization
1160        let json = optimizer.to_json().unwrap();
1161        let loaded_optimizer = Adam::from_json(&json).unwrap();
1162
1163        assert_eq!(
1164            optimizer.config().learning_rate,
1165            loaded_optimizer.config().learning_rate
1166        );
1167        assert_eq!(
1168            optimizer.saved_parameter_count(),
1169            loaded_optimizer.saved_parameter_count()
1170        );
1171    }
1172
1173    #[test]
1174    fn test_adam_optimizer_binary_roundtrip() {
1175        let weight = Tensor::ones(vec![5, 2]).with_requires_grad();
1176
1177        let mut optimizer = Adam::with_learning_rate(1e-4);
1178        optimizer.add_parameter(&weight);
1179
1180        // Test serialization
1181        let binary = optimizer.to_binary().unwrap();
1182        let loaded_optimizer = Adam::from_binary(&binary).unwrap();
1183
1184        assert_eq!(
1185            optimizer.config().learning_rate,
1186            loaded_optimizer.config().learning_rate
1187        );
1188        assert_eq!(
1189            optimizer.saved_parameter_count(),
1190            loaded_optimizer.saved_parameter_count()
1191        );
1192    }
1193
1194    #[test]
1195    fn test_adam_parameter_relinking() {
1196        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1197        let mut optimizer = Adam::new();
1198        optimizer.add_parameter(&weight);
1199
1200        // Serialize
1201        let json = optimizer.to_json().unwrap();
1202
1203        // Deserialize
1204        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1205
1206        // After deserialization, saved states should be preserved
1207        assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
1208
1209        // Re-link parameter - this creates a new state since it's a new tensor with different ID
1210        let new_weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1211        loaded_optimizer.add_parameter(&new_weight);
1212
1213        // Now there should be 2 states: the original saved one + the new one
1214        assert_eq!(loaded_optimizer.parameter_count(), 2);
1215        assert!(loaded_optimizer.is_parameter_linked(&new_weight));
1216    }
1217
1218    #[test]
1219    fn test_adam_state_preservation() {
1220        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1221        let mut optimizer = Adam::new();
1222        optimizer.add_parameter(&weight);
1223
1224        // Perform training steps to build up state
1225        for _ in 0..3 {
1226            let output = weight.mul_scalar(2.0);
1227            let mut loss = output.sum();
1228            loss.backward(None);
1229            optimizer.step(&mut [&mut weight]);
1230            optimizer.zero_grad(&mut [&mut weight]);
1231        }
1232
1233        // Serialize and deserialize
1234        let json = optimizer.to_json().unwrap();
1235        let loaded_optimizer = Adam::from_json(&json).unwrap();
1236
1237        // Check that states were preserved
1238        assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
1239        assert_eq!(
1240            loaded_optimizer.config().learning_rate,
1241            optimizer.config().learning_rate
1242        );
1243    }
1244
1245    #[test]
1246    fn test_relink_parameters_success() {
1247        // Create original optimizer with parameters
1248        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1249        let bias = Tensor::zeros(vec![3]).with_requires_grad();
1250
1251        let mut optimizer = Adam::new();
1252        optimizer.add_parameter(&weight);
1253        optimizer.add_parameter(&bias);
1254
1255        // Serialize
1256        let json = optimizer.to_json().unwrap();
1257
1258        // Create new parameters with same shapes but different IDs
1259        let new_weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1260        let new_bias = Tensor::zeros(vec![3]).with_requires_grad();
1261
1262        // Deserialize and re-link
1263        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1264        loaded_optimizer
1265            .relink_parameters(&[&new_weight, &new_bias])
1266            .unwrap();
1267
1268        // Verify re-linking worked
1269        assert_eq!(loaded_optimizer.parameter_count(), 2);
1270        assert!(loaded_optimizer.is_parameter_linked(&new_weight));
1271        assert!(loaded_optimizer.is_parameter_linked(&new_bias));
1272    }
1273
1274    #[test]
1275    fn test_relink_parameters_shape_mismatch() {
1276        // Create original optimizer
1277        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1278        let mut optimizer = Adam::new();
1279        optimizer.add_parameter(&weight);
1280
1281        // Serialize
1282        let json = optimizer.to_json().unwrap();
1283
1284        // Create new parameter with different shape
1285        let new_weight = Tensor::ones(vec![3, 2]).with_requires_grad(); // Different shape!
1286
1287        // Deserialize and try to re-link
1288        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1289        let result = loaded_optimizer.relink_parameters(&[&new_weight]);
1290
1291        // Should fail with shape mismatch error
1292        assert!(result.is_err());
1293        assert!(result.unwrap_err().contains("Shape mismatch"));
1294    }
1295
1296    #[test]
1297    fn test_relink_parameters_count_mismatch() {
1298        // Create original optimizer with 2 parameters
1299        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1300        let bias = Tensor::zeros(vec![3]).with_requires_grad();
1301
1302        let mut optimizer = Adam::new();
1303        optimizer.add_parameter(&weight);
1304        optimizer.add_parameter(&bias);
1305
1306        // Serialize
1307        let json = optimizer.to_json().unwrap();
1308
1309        // Create only 1 new parameter
1310        let new_weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1311
1312        // Deserialize and try to re-link with wrong count
1313        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1314        let result = loaded_optimizer.relink_parameters(&[&new_weight]);
1315
1316        // Should fail with count mismatch error
1317        assert!(result.is_err());
1318        assert!(result.unwrap_err().contains("Parameter count mismatch"));
1319    }
1320
1321    #[test]
1322    fn test_relink_parameters_requires_grad() {
1323        // Create original optimizer
1324        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1325        let mut optimizer = Adam::new();
1326        optimizer.add_parameter(&weight);
1327
1328        // Serialize
1329        let json = optimizer.to_json().unwrap();
1330
1331        // Create new parameter without requires_grad
1332        let new_weight = Tensor::ones(vec![2, 3]); // No requires_grad!
1333
1334        // Deserialize and try to re-link
1335        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1336        let result = loaded_optimizer.relink_parameters(&[&new_weight]);
1337
1338        // Should fail with requires_grad error
1339        assert!(result.is_err());
1340        assert!(result.unwrap_err().contains("must require gradients"));
1341    }
1342
1343    #[test]
1344    fn test_relink_preserves_state() {
1345        // Create original optimizer and train it
1346        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1347        let mut optimizer = Adam::new();
1348        optimizer.add_parameter(&weight);
1349
1350        // Perform some training to build up state
1351        for _ in 0..2 {
1352            let output = weight.mul_scalar(2.0);
1353            let mut loss = output.sum();
1354            loss.backward(None);
1355            optimizer.step(&mut [&mut weight]);
1356            optimizer.zero_grad(&mut [&mut weight]);
1357        }
1358
1359        // Get state before serialization
1360        let original_step_count = optimizer.step_count;
1361
1362        // Serialize
1363        let json = optimizer.to_json().unwrap();
1364
1365        // Create new parameter
1366        let new_weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1367
1368        // Deserialize and re-link
1369        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1370        loaded_optimizer.relink_parameters(&[&new_weight]).unwrap();
1371
1372        // Verify state is preserved
1373        assert_eq!(loaded_optimizer.step_count, original_step_count);
1374        assert_eq!(loaded_optimizer.parameter_count(), 1);
1375        assert!(loaded_optimizer.is_parameter_linked(&new_weight));
1376    }
1377
1378    // ===== Serializable Trait Tests =====
1379
1380    #[test]
1381    fn test_serializable_json_methods() {
1382        // Create and populate test optimizer
1383        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1384        let bias = Tensor::zeros(vec![3]).with_requires_grad();
1385
1386        let mut optimizer = Adam::with_learning_rate(1e-3);
1387        optimizer.add_parameter(&weight);
1388        optimizer.add_parameter(&bias);
1389
1390        // Test to_json method
1391        let json = <Adam as crate::serialization::Serializable>::to_json(&optimizer).unwrap();
1392        assert!(!json.is_empty());
1393        assert!(json.contains("config"));
1394        assert!(json.contains("states"));
1395        assert!(json.contains("step_count"));
1396        assert!(json.contains("learning_rate"));
1397
1398        // Test from_json method
1399        let restored = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1400        assert_eq!(
1401            optimizer.config().learning_rate,
1402            restored.config().learning_rate
1403        );
1404        assert_eq!(optimizer.config().beta1, restored.config().beta1);
1405        assert_eq!(optimizer.config().beta2, restored.config().beta2);
1406        assert_eq!(optimizer.config().eps, restored.config().eps);
1407        assert_eq!(
1408            optimizer.config().weight_decay,
1409            restored.config().weight_decay
1410        );
1411        assert_eq!(optimizer.config().amsgrad, restored.config().amsgrad);
1412        assert_eq!(
1413            optimizer.saved_parameter_count(),
1414            restored.saved_parameter_count()
1415        );
1416        assert_eq!(optimizer.step_count, restored.step_count);
1417    }
1418
1419    #[test]
1420    fn test_serializable_binary_methods() {
1421        // Create and populate test optimizer
1422        let weight = Tensor::ones(vec![3, 4]).with_requires_grad();
1423        let mut optimizer = Adam::with_config(AdamConfig {
1424            learning_rate: 2e-4,
1425            beta1: 0.95,
1426            beta2: 0.999,
1427            eps: 1e-7,
1428            weight_decay: 1e-4,
1429            amsgrad: true,
1430        });
1431        optimizer.add_parameter(&weight);
1432
1433        // Test to_binary method
1434        let binary = <Adam as crate::serialization::Serializable>::to_binary(&optimizer).unwrap();
1435        assert!(!binary.is_empty());
1436
1437        // Test from_binary method
1438        let restored = <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1439        assert_eq!(
1440            optimizer.config().learning_rate,
1441            restored.config().learning_rate
1442        );
1443        assert_eq!(optimizer.config().beta1, restored.config().beta1);
1444        assert_eq!(optimizer.config().beta2, restored.config().beta2);
1445        assert_eq!(optimizer.config().eps, restored.config().eps);
1446        assert_eq!(
1447            optimizer.config().weight_decay,
1448            restored.config().weight_decay
1449        );
1450        assert_eq!(optimizer.config().amsgrad, restored.config().amsgrad);
1451        assert_eq!(
1452            optimizer.saved_parameter_count(),
1453            restored.saved_parameter_count()
1454        );
1455        assert_eq!(optimizer.step_count, restored.step_count);
1456    }
1457
1458    #[test]
1459    fn test_serializable_file_io_json() {
1460        use crate::serialization::{Format, Serializable};
1461        use std::fs;
1462        use std::path::Path;
1463
1464        // Create test optimizer
1465        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1466        let bias = Tensor::zeros(vec![2]).with_requires_grad();
1467
1468        let mut optimizer = Adam::with_learning_rate(5e-4);
1469        optimizer.add_parameter(&weight);
1470        optimizer.add_parameter(&bias);
1471
1472        let json_path = "test_adam_serializable.json";
1473
1474        // Test save method with JSON format
1475        Serializable::save(&optimizer, json_path, Format::Json).unwrap();
1476        assert!(Path::new(json_path).exists());
1477
1478        // Test load method with JSON format
1479        let loaded_optimizer = Adam::load(json_path, Format::Json).unwrap();
1480        assert_eq!(
1481            optimizer.config().learning_rate,
1482            loaded_optimizer.config().learning_rate
1483        );
1484        assert_eq!(
1485            optimizer.saved_parameter_count(),
1486            loaded_optimizer.saved_parameter_count()
1487        );
1488
1489        // Test save_to_writer method
1490        let json_path_2 = "test_adam_serializable_writer.json";
1491        {
1492            let file = std::fs::File::create(json_path_2).unwrap();
1493            let mut writer = std::io::BufWriter::new(file);
1494            Serializable::save_to_writer(&optimizer, &mut writer, Format::Json).unwrap();
1495        }
1496        assert!(Path::new(json_path_2).exists());
1497
1498        // Test load_from_reader method
1499        {
1500            let file = std::fs::File::open(json_path_2).unwrap();
1501            let mut reader = std::io::BufReader::new(file);
1502            let loaded_optimizer = Adam::load_from_reader(&mut reader, Format::Json).unwrap();
1503            assert_eq!(
1504                optimizer.config().learning_rate,
1505                loaded_optimizer.config().learning_rate
1506            );
1507        }
1508
1509        // Cleanup test files
1510        let _ = fs::remove_file(json_path);
1511        let _ = fs::remove_file(json_path_2);
1512    }
1513
1514    #[test]
1515    fn test_serializable_file_io_binary() {
1516        use crate::serialization::{Format, Serializable};
1517        use std::fs;
1518        use std::path::Path;
1519
1520        // Create test optimizer
1521        let weight = Tensor::ones(vec![3, 3]).with_requires_grad();
1522        let mut optimizer = Adam::with_config(AdamConfig {
1523            learning_rate: 1e-3,
1524            beta1: 0.9,
1525            beta2: 0.999,
1526            eps: 1e-8,
1527            weight_decay: 0.0,
1528            amsgrad: false,
1529        });
1530        optimizer.add_parameter(&weight);
1531
1532        let binary_path = "test_adam_serializable.bin";
1533
1534        // Test save method with binary format
1535        Serializable::save(&optimizer, binary_path, Format::Binary).unwrap();
1536        assert!(Path::new(binary_path).exists());
1537
1538        // Test load method with binary format
1539        let loaded_optimizer = Adam::load(binary_path, Format::Binary).unwrap();
1540        assert_eq!(
1541            optimizer.config().learning_rate,
1542            loaded_optimizer.config().learning_rate
1543        );
1544        assert_eq!(
1545            optimizer.saved_parameter_count(),
1546            loaded_optimizer.saved_parameter_count()
1547        );
1548
1549        // Test save_to_writer method
1550        let binary_path_2 = "test_adam_serializable_writer.bin";
1551        {
1552            let file = std::fs::File::create(binary_path_2).unwrap();
1553            let mut writer = std::io::BufWriter::new(file);
1554            Serializable::save_to_writer(&optimizer, &mut writer, Format::Binary).unwrap();
1555        }
1556        assert!(Path::new(binary_path_2).exists());
1557
1558        // Test load_from_reader method
1559        {
1560            let file = std::fs::File::open(binary_path_2).unwrap();
1561            let mut reader = std::io::BufReader::new(file);
1562            let loaded_optimizer = Adam::load_from_reader(&mut reader, Format::Binary).unwrap();
1563            assert_eq!(
1564                optimizer.config().learning_rate,
1565                loaded_optimizer.config().learning_rate
1566            );
1567        }
1568
1569        // Cleanup test files
1570        let _ = fs::remove_file(binary_path);
1571        let _ = fs::remove_file(binary_path_2);
1572    }
1573
1574    #[test]
1575    fn test_serializable_large_optimizer_performance() {
1576        // Create a large optimizer to test performance characteristics
1577        let mut optimizer = Adam::with_learning_rate(1e-4);
1578
1579        // Add multiple parameters of different sizes
1580        for i in 0..5 {
1581            let size = 10 + i * 5;
1582            let param = Tensor::ones(vec![size, size]).with_requires_grad();
1583            optimizer.add_parameter(&param);
1584        }
1585
1586        // Test JSON serialization
1587        let json = <Adam as crate::serialization::Serializable>::to_json(&optimizer).unwrap();
1588        assert!(!json.is_empty());
1589        let restored_json = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1590        assert_eq!(
1591            optimizer.config().learning_rate,
1592            restored_json.config().learning_rate
1593        );
1594        assert_eq!(
1595            optimizer.saved_parameter_count(),
1596            restored_json.saved_parameter_count()
1597        );
1598
1599        // Test binary serialization
1600        let binary = <Adam as crate::serialization::Serializable>::to_binary(&optimizer).unwrap();
1601        assert!(!binary.is_empty());
1602        // Binary format should be efficient (this is informational, not a requirement)
1603        println!(
1604            "JSON size: {} bytes, Binary size: {} bytes",
1605            json.len(),
1606            binary.len()
1607        );
1608
1609        let restored_binary =
1610            <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1611        assert_eq!(
1612            optimizer.config().learning_rate,
1613            restored_binary.config().learning_rate
1614        );
1615        assert_eq!(
1616            optimizer.saved_parameter_count(),
1617            restored_binary.saved_parameter_count()
1618        );
1619
1620        // Verify all configurations match
1621        assert_eq!(optimizer.config().beta1, restored_binary.config().beta1);
1622        assert_eq!(optimizer.config().beta2, restored_binary.config().beta2);
1623        assert_eq!(optimizer.config().eps, restored_binary.config().eps);
1624        assert_eq!(
1625            optimizer.config().weight_decay,
1626            restored_binary.config().weight_decay
1627        );
1628        assert_eq!(optimizer.config().amsgrad, restored_binary.config().amsgrad);
1629    }
1630
1631    #[test]
1632    fn test_serializable_error_handling() {
1633        // Test invalid JSON
1634        let invalid_json = r#"{"invalid": "json", "structure": true}"#;
1635        let result = <Adam as crate::serialization::Serializable>::from_json(invalid_json);
1636        assert!(result.is_err());
1637
1638        // Test empty JSON
1639        let empty_json = "{}";
1640        let result = <Adam as crate::serialization::Serializable>::from_json(empty_json);
1641        assert!(result.is_err());
1642
1643        // Test invalid binary data
1644        let invalid_binary = vec![1, 2, 3, 4, 5];
1645        let result = <Adam as crate::serialization::Serializable>::from_binary(&invalid_binary);
1646        assert!(result.is_err());
1647
1648        // Test empty binary data
1649        let empty_binary = vec![];
1650        let result = <Adam as crate::serialization::Serializable>::from_binary(&empty_binary);
1651        assert!(result.is_err());
1652    }
1653
1654    #[test]
1655    fn test_serializable_different_configurations() {
1656        let test_configs = vec![
1657            // Default configuration
1658            AdamConfig::default(),
1659            // High learning rate
1660            AdamConfig {
1661                learning_rate: 1e-2,
1662                beta1: 0.9,
1663                beta2: 0.999,
1664                eps: 1e-8,
1665                weight_decay: 0.0,
1666                amsgrad: false,
1667            },
1668            // AMSGrad enabled
1669            AdamConfig {
1670                learning_rate: 1e-4,
1671                beta1: 0.9,
1672                beta2: 0.999,
1673                eps: 1e-8,
1674                weight_decay: 1e-4,
1675                amsgrad: true,
1676            },
1677            // Custom betas
1678            AdamConfig {
1679                learning_rate: 5e-4,
1680                beta1: 0.95,
1681                beta2: 0.9999,
1682                eps: 1e-7,
1683                weight_decay: 1e-5,
1684                amsgrad: false,
1685            },
1686        ];
1687
1688        for config in test_configs {
1689            // Create optimizer with specific configuration
1690            let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1691            let mut optimizer = Adam::with_config(config.clone());
1692            optimizer.add_parameter(&weight);
1693
1694            // Test JSON roundtrip
1695            let json = <Adam as crate::serialization::Serializable>::to_json(&optimizer).unwrap();
1696            let restored_json =
1697                <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1698            assert_eq!(config.learning_rate, restored_json.config().learning_rate);
1699            assert_eq!(config.beta1, restored_json.config().beta1);
1700            assert_eq!(config.beta2, restored_json.config().beta2);
1701            assert_eq!(config.eps, restored_json.config().eps);
1702            assert_eq!(config.weight_decay, restored_json.config().weight_decay);
1703            assert_eq!(config.amsgrad, restored_json.config().amsgrad);
1704
1705            // Test binary roundtrip
1706            let binary =
1707                <Adam as crate::serialization::Serializable>::to_binary(&optimizer).unwrap();
1708            let restored_binary =
1709                <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1710            assert_eq!(config.learning_rate, restored_binary.config().learning_rate);
1711            assert_eq!(config.beta1, restored_binary.config().beta1);
1712            assert_eq!(config.beta2, restored_binary.config().beta2);
1713            assert_eq!(config.eps, restored_binary.config().eps);
1714            assert_eq!(config.weight_decay, restored_binary.config().weight_decay);
1715            assert_eq!(config.amsgrad, restored_binary.config().amsgrad);
1716        }
1717    }
1718
1719    #[test]
1720    fn test_serializable_edge_cases() {
1721        // Test optimizer with no parameters
1722        let empty_optimizer = Adam::new();
1723        let json = <Adam as crate::serialization::Serializable>::to_json(&empty_optimizer).unwrap();
1724        let restored = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1725        assert_eq!(
1726            empty_optimizer.saved_parameter_count(),
1727            restored.saved_parameter_count()
1728        );
1729        assert_eq!(empty_optimizer.step_count, restored.step_count);
1730
1731        let binary =
1732            <Adam as crate::serialization::Serializable>::to_binary(&empty_optimizer).unwrap();
1733        let restored = <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1734        assert_eq!(
1735            empty_optimizer.saved_parameter_count(),
1736            restored.saved_parameter_count()
1737        );
1738        assert_eq!(empty_optimizer.step_count, restored.step_count);
1739
1740        // Test optimizer with extreme configuration values
1741        let extreme_config = AdamConfig {
1742            learning_rate: 1e-10, // Very small learning rate
1743            beta1: 0.999999,      // Very high beta1
1744            beta2: 0.000001,      // Very low beta2
1745            eps: 1e-15,           // Very small epsilon
1746            weight_decay: 1e-1,   // High weight decay
1747            amsgrad: true,
1748        };
1749
1750        let weight = Tensor::ones(vec![1]).with_requires_grad();
1751        let mut extreme_optimizer = Adam::with_config(extreme_config.clone());
1752        extreme_optimizer.add_parameter(&weight);
1753
1754        let json =
1755            <Adam as crate::serialization::Serializable>::to_json(&extreme_optimizer).unwrap();
1756        let restored = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1757        assert_eq!(
1758            extreme_config.learning_rate,
1759            restored.config().learning_rate
1760        );
1761        assert_eq!(extreme_config.beta1, restored.config().beta1);
1762        assert_eq!(extreme_config.beta2, restored.config().beta2);
1763        assert_eq!(extreme_config.eps, restored.config().eps);
1764        assert_eq!(extreme_config.weight_decay, restored.config().weight_decay);
1765        assert_eq!(extreme_config.amsgrad, restored.config().amsgrad);
1766
1767        let binary =
1768            <Adam as crate::serialization::Serializable>::to_binary(&extreme_optimizer).unwrap();
1769        let restored = <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1770        assert_eq!(
1771            extreme_config.learning_rate,
1772            restored.config().learning_rate
1773        );
1774        assert_eq!(extreme_config.beta1, restored.config().beta1);
1775        assert_eq!(extreme_config.beta2, restored.config().beta2);
1776        assert_eq!(extreme_config.eps, restored.config().eps);
1777        assert_eq!(extreme_config.weight_decay, restored.config().weight_decay);
1778        assert_eq!(extreme_config.amsgrad, restored.config().amsgrad);
1779    }
1780}