train_station/optimizers/adam/serialization.rs
1//! Comprehensive serialization support for Adam optimizer state and configuration
2//!
3//! This module provides complete serialization capabilities for the Adam optimizer, enabling
4//! model checkpointing, state persistence, and cross-platform optimizer state transfer.
5//! The serialization system supports both human-readable JSON and efficient binary formats
6//! with perfect roundtrip fidelity and seamless parameter re-linking.
7//!
8//! # Purpose
9//!
10//! The serialization module serves as the persistence layer for Adam optimizer training,
11//! providing essential functionality for:
12//! - **Model checkpointing**: Save and restore optimizer state during training
13//! - **Training resumption**: Continue training from saved checkpoints
14//! - **State transfer**: Move optimizer state between different environments
15//! - **Debugging support**: Human-readable JSON format for state inspection
16//! - **Performance optimization**: Efficient binary format for production use
17//! - **Cross-platform compatibility**: Consistent serialization across systems
18//!
19//! # Supported Components
20//!
21//! ## AdamConfig Serialization
22//! - **Learning rate**: Base learning rate for parameter updates
23//! - **Beta parameters**: Momentum decay rates (beta1, beta2)
24//! - **Epsilon**: Numerical stability constant
25//! - **Weight decay**: L2 regularization coefficient
26//! - **AMSGrad flag**: Whether to use AMSGrad variant
27//!
28//! ## Parameter State Serialization
29//! - **Momentum buffers**: First moment estimates for each parameter
30//! - **Velocity buffers**: Second moment estimates for each parameter
31//! - **AMSGrad state**: Maximum velocity buffers when AMSGrad is enabled
32//! - **Step counts**: Per-parameter step counts for bias correction
33//! - **Shape information**: Parameter shapes for validation during re-linking
34//!
35//! ## Global Optimizer State
36//! - **Global step count**: Total optimization steps performed
37//! - **Parameter insertion order**: Order for consistent parameter re-linking
38//! - **Configuration state**: Complete hyperparameter configuration
39//! - **State validation**: Integrity checks for serialized data
40//!
41//! # Serialization Formats
42//!
43//! ## JSON Format
44//! - **Human-readable**: Easy to inspect and debug optimizer state
45//! - **Cross-language**: Compatible with other JSON-parsing systems
46//! - **Configuration files**: Suitable for storing optimizer configurations
47//! - **Debugging**: Clear structure for troubleshooting training issues
48//! - **Interoperability**: Exchange optimizer state with external tools
49//!
50//! ## Binary Format
51//! - **Compact storage**: Minimal file sizes for production deployment
52//! - **Fast I/O**: Optimized for quick save/load operations
53//! - **Performance**: Reduced serialization overhead during training
54//! - **Bandwidth efficiency**: Minimal network transfer requirements
55//! - **Production ready**: Optimized for high-performance training workflows
56//!
57//! # Usage Patterns
58//!
59//! ## Basic Serialization
60//! ```
61//! use train_station::{Tensor, optimizers::Adam};
62//! use train_station::serialization::Serializable;
63//!
64//! // Create optimizer with parameters
65//! let weight = Tensor::ones(vec![10, 5]).with_requires_grad();
66//! let mut optimizer = Adam::new();
67//! optimizer.add_parameter(&weight);
68//!
69//! // Serialize optimizer state
70//! let json = optimizer.to_json().unwrap();
71//! let binary = optimizer.to_binary().unwrap();
72//!
73//! // Deserialize optimizer
74//! let loaded_optimizer = Adam::from_json(&json).unwrap();
75//! assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
76//! ```
77//!
78//! ## Training Checkpointing
79//! ```
80//! use train_station::{Tensor, optimizers::Adam};
81//! use train_station::serialization::{Serializable, Format};
82//!
83//! let mut weight = Tensor::randn(vec![100, 50], None).with_requires_grad();
84//! let mut optimizer = Adam::new();
85//! optimizer.add_parameter(&weight);
86//!
87//! // Training loop with checkpointing
88//! for epoch in 0..10 {
89//! // ... training logic ...
90//!
91//! // Save checkpoint every 5 epochs
92//! if epoch % 5 == 0 {
93//! let temp_dir = std::env::temp_dir();
94//! let checkpoint_path = temp_dir.join(format!("checkpoint_epoch_{}.json", epoch));
95//! optimizer.save(&checkpoint_path, Format::Json).unwrap();
96//!
97//! // Cleanup for example
98//! std::fs::remove_file(&checkpoint_path).ok();
99//! }
100//! }
101//! ```
102//!
103//! ## Parameter Re-linking
104//! ```
105//! use train_station::{Tensor, optimizers::Adam};
106//! use train_station::serialization::Serializable;
107//!
108//! // Original training setup
109//! let weight = Tensor::ones(vec![5, 5]).with_requires_grad();
110//! let bias = Tensor::zeros(vec![5]).with_requires_grad();
111//! let mut optimizer = Adam::new();
112//! optimizer.add_parameter(&weight);
113//! optimizer.add_parameter(&bias);
114//!
115//! // Serialize optimizer state
116//! let json = optimizer.to_json().unwrap();
117//!
118//! // Later: create new parameters with same shapes
119//! let new_weight = Tensor::ones(vec![5, 5]).with_requires_grad();
120//! let new_bias = Tensor::zeros(vec![5]).with_requires_grad();
121//!
122//! // Restore optimizer and re-link parameters
123//! let mut loaded_optimizer = Adam::from_json(&json).unwrap();
124//! loaded_optimizer.relink_parameters(&[&new_weight, &new_bias]).unwrap();
125//!
126//! assert!(loaded_optimizer.is_parameter_linked(&new_weight));
127//! assert!(loaded_optimizer.is_parameter_linked(&new_bias));
128//! ```
129//!
130//! # Architecture Design
131//!
132//! ## Serialization Strategy
133//! - **Unified interface**: All serialization through StructSerializable trait
134//! - **Type safety**: Strong typing prevents serialization errors
135//! - **Validation**: Comprehensive validation during deserialization
136//! - **Error handling**: Detailed error messages for debugging
137//! - **Memory efficiency**: Optimized memory usage during serialization
138//!
139//! ## Parameter State Management
140//! - **ID-based tracking**: Parameters tracked by unique tensor IDs
141//! - **Shape validation**: Ensures parameter compatibility during re-linking
142//! - **Insertion order**: Maintains parameter order for consistent re-linking
143//! - **State preservation**: Complete momentum and velocity buffer preservation
144//! - **AMSGrad support**: Full AMSGrad state serialization when enabled
145//!
146//! ## Performance Characteristics
147//! - **Linear complexity**: O(n) serialization time with parameter count
148//! - **Minimal overhead**: Efficient serialization with minimal memory allocation
149//! - **Streaming support**: Support for streaming serialization to files
150//! - **Compression ready**: Binary format suitable for compression
151//! - **Concurrent safe**: Thread-safe serialization operations
152//!
153//! # Thread Safety
154//!
155//! All serialization operations are thread-safe:
156//! - **Immutable serialization**: Serialization does not modify optimizer state
157//! - **Concurrent reads**: Multiple threads can serialize the same optimizer
158//! - **Deserialization safety**: Deserialization creates new optimizer instances
159//! - **Parameter linking**: Re-linking operations are thread-safe
160//!
161//! # Integration with Train Station
162//!
163//! The serialization module integrates seamlessly with the broader Train Station ecosystem:
164//! - **Tensor serialization**: Leverages efficient tensor serialization for parameter states
165//! - **GradTrack compatibility**: Maintains gradient tracking requirements during re-linking
166//! - **Device management**: Preserves device placement information
167//! - **Memory management**: Efficient memory usage aligned with Train Station patterns
168//! - **Error handling**: Consistent error handling with Train Station conventions
169
170use super::{Adam, AdamConfig, ParameterState};
171use crate::serialization::{
172 FieldValue, FromFieldValue, SerializationError, SerializationResult, StructDeserializer,
173 StructSerializable, StructSerializer, ToFieldValue,
174};
175use crate::tensor::core::Tensor;
176use std::collections::HashMap;
177
178// ===== AdamConfig Serialization =====
179
180impl StructSerializable for AdamConfig {
181 /// Convert AdamConfig to StructSerializer for comprehensive serialization
182 ///
183 /// This method serializes all Adam hyperparameters into a structured format suitable
184 /// for both JSON and binary serialization. Every field is essential for proper optimizer
185 /// reconstruction and training continuation. The serialization preserves exact floating-point
186 /// values and boolean flags to ensure identical behavior after deserialization.
187 ///
188 /// # Returns
189 ///
190 /// StructSerializer containing all configuration data with field names and values
191 ///
192 /// # Serialized Fields
193 ///
194 /// - **learning_rate**: Base learning rate for parameter updates
195 /// - **beta1**: Exponential decay rate for first moment estimates
196 /// - **beta2**: Exponential decay rate for second moment estimates
197 /// - **eps**: Small constant for numerical stability in denominator
198 /// - **weight_decay**: L2 regularization coefficient
199 /// - **amsgrad**: Boolean flag for AMSGrad variant usage
200 ///
201 /// # Performance
202 ///
203 /// - **Time Complexity**: O(1) - Constant time field serialization
204 /// - **Memory Usage**: Minimal allocation for field storage
205 /// - **Precision**: Full floating-point precision preservation
206 fn to_serializer(&self) -> StructSerializer {
207 StructSerializer::new()
208 .field("learning_rate", &self.learning_rate)
209 .field("beta1", &self.beta1)
210 .field("beta2", &self.beta2)
211 .field("eps", &self.eps)
212 .field("weight_decay", &self.weight_decay)
213 .field("amsgrad", &self.amsgrad)
214 }
215
216 /// Create AdamConfig from StructDeserializer with full validation
217 ///
218 /// This method reconstructs an AdamConfig instance from serialized hyperparameters,
219 /// performing comprehensive validation to ensure all required fields are present
220 /// and contain valid values. The deserialization process maintains exact floating-point
221 /// precision and validates that all hyperparameters are within reasonable ranges.
222 ///
223 /// # Arguments
224 ///
225 /// * `deserializer` - StructDeserializer containing configuration field data
226 ///
227 /// # Returns
228 ///
229 /// Reconstructed AdamConfig instance on success, or SerializationError on failure
230 ///
231 /// # Required Fields
232 ///
233 /// All fields must be present in the deserializer:
234 /// - **learning_rate**: Must be a valid f32 value
235 /// - **beta1**: Must be a valid f32 value (typically 0.0-1.0)
236 /// - **beta2**: Must be a valid f32 value (typically 0.0-1.0)
237 /// - **eps**: Must be a valid f32 value (typically small positive)
238 /// - **weight_decay**: Must be a valid f32 value (typically 0.0 or small positive)
239 /// - **amsgrad**: Must be a valid boolean value
240 ///
241 /// # Errors
242 ///
243 /// Returns SerializationError if:
244 /// - Any required field is missing from the deserializer
245 /// - Any field contains invalid data type
246 /// - Field extraction fails for any reason
247 ///
248 /// # Performance
249 ///
250 /// - **Time Complexity**: O(1) - Constant time field extraction
251 /// - **Memory Usage**: Minimal allocation for configuration structure
252 /// - **Validation**: Comprehensive field presence and type validation
253 fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
254 Ok(AdamConfig {
255 learning_rate: deserializer.field("learning_rate")?,
256 beta1: deserializer.field("beta1")?,
257 beta2: deserializer.field("beta2")?,
258 eps: deserializer.field("eps")?,
259 weight_decay: deserializer.field("weight_decay")?,
260 amsgrad: deserializer.field("amsgrad")?,
261 })
262 }
263}
264
265impl ToFieldValue for AdamConfig {
266 /// Convert AdamConfig to FieldValue for embedding in larger structures
267 ///
268 /// This method converts the AdamConfig into a FieldValue::Object that can be
269 /// embedded as a field within larger serializable structures. This enables
270 /// AdamConfig to be serialized as part of more complex training configurations
271 /// or model checkpoints while maintaining its structured representation.
272 ///
273 /// # Returns
274 ///
275 /// FieldValue::Object containing all configuration data as key-value pairs
276 ///
277 /// # Object Structure
278 ///
279 /// The returned object contains these fields:
280 /// - "learning_rate": f32 value as FieldValue::F32
281 /// - "beta1": f32 value as FieldValue::F32
282 /// - "beta2": f32 value as FieldValue::F32
283 /// - "eps": f32 value as FieldValue::F32
284 /// - "weight_decay": f32 value as FieldValue::F32
285 /// - "amsgrad": bool value as FieldValue::Bool
286 ///
287 /// # Performance
288 ///
289 /// - **Time Complexity**: O(1) - Constant time field conversion
290 /// - **Memory Usage**: Allocates HashMap for field storage
291 /// - **Conversion**: Direct field-to-FieldValue conversion without copying
292 fn to_field_value(&self) -> FieldValue {
293 let serializer = self.to_serializer();
294 FieldValue::from_object(serializer.fields.into_iter().collect())
295 }
296}
297
298impl FromFieldValue for AdamConfig {
299 /// Create AdamConfig from FieldValue with comprehensive validation
300 ///
301 /// This method reconstructs an AdamConfig instance from a FieldValue::Object,
302 /// performing type validation and field extraction. It's designed to handle
303 /// AdamConfig instances that were embedded as fields within larger serializable
304 /// structures, ensuring proper error handling and detailed error messages.
305 ///
306 /// # Arguments
307 ///
308 /// * `value` - FieldValue containing configuration data (must be Object variant)
309 /// * `field_name` - Name of the field being deserialized for error context
310 ///
311 /// # Returns
312 ///
313 /// Reconstructed AdamConfig instance on success, or SerializationError on failure
314 ///
315 /// # Expected FieldValue Structure
316 ///
317 /// The FieldValue must be an Object variant containing:
318 /// - "learning_rate": Numeric field value
319 /// - "beta1": Numeric field value
320 /// - "beta2": Numeric field value
321 /// - "eps": Numeric field value
322 /// - "weight_decay": Numeric field value
323 /// - "amsgrad": Boolean field value
324 ///
325 /// # Errors
326 ///
327 /// Returns SerializationError if:
328 /// - FieldValue is not an Object variant
329 /// - Any required field is missing from the object
330 /// - Any field has incorrect type or invalid value
331 /// - Deserialization process fails for any reason
332 ///
333 /// # Performance
334 ///
335 /// - **Time Complexity**: O(1) - Constant time field extraction and validation
336 /// - **Memory Usage**: Temporary deserializer allocation for field processing
337 /// - **Error Handling**: Detailed error messages with field name context
338 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
339 match value {
340 FieldValue::Object(fields) => {
341 let mut deserializer = StructDeserializer::from_fields(fields);
342 Self::from_deserializer(&mut deserializer)
343 }
344 _ => Err(SerializationError::ValidationFailed {
345 field: field_name.to_string(),
346 message: format!(
347 "Expected Object for {}, found {}",
348 std::any::type_name::<Self>(),
349 value.type_name()
350 ),
351 }),
352 }
353 }
354}
355
356// ===== Serializable Parameter State =====
357
/// Serializable representation of Adam parameter optimization state
///
/// Wraps the optimizer's internal `ParameterState` in a form that implements
/// the serialization traits, so per-parameter Adam state (momentum, velocity,
/// optional AMSGrad maximum velocity, and step count) can be persisted and
/// restored with full fidelity. Tensor fields are serialized through the
/// Train Station tensor serialization system rather than by manual buffer
/// extraction.
///
/// # Fields
///
/// * `m` - First moment estimate (momentum) tensor buffer
/// * `v` - Second moment estimate (velocity) tensor buffer
/// * `v_hat_max` - Optional maximum velocity tensor for the AMSGrad variant
/// * `step` - Per-parameter step count used for bias correction
///
/// # Design Rationale
///
/// Cloning tensors directly (see `from_parameter_state`) leverages the proven
/// tensor serialization infrastructure, avoiding hand-rolled buffer copies
/// and the bugs that come with them.
///
/// # Thread Safety
///
/// Serialization reads state without mutating it, and the struct derives
/// `Clone` for safe duplication.
#[derive(Debug, Clone)]
struct SerializableParameterState {
    /// First moment estimate (momentum) tensor buffer
    ///
    /// Exponentially decaying average of past gradients. Shape is expected to
    /// match the associated parameter.
    m: Tensor,

    /// Second moment estimate (velocity) tensor buffer
    ///
    /// Exponentially decaying average of past squared gradients, used for
    /// adaptive learning-rate scaling. Shape is expected to match the
    /// associated parameter.
    v: Tensor,

    /// Optional maximum velocity tensor for the AMSGrad variant
    ///
    /// `Some` only when AMSGrad is enabled: the element-wise maximum of past
    /// velocity estimates. `None` when AMSGrad is disabled.
    v_hat_max: Option<Tensor>,

    /// Per-parameter step count for bias correction
    ///
    /// Number of optimization steps performed for this specific parameter.
    /// Kept per parameter so parameters added mid-training correct bias
    /// independently of the global step count.
    step: usize,
}
431
432impl SerializableParameterState {
433 /// Create SerializableParameterState from internal ParameterState
434 ///
435 /// This method converts the internal ParameterState representation used by
436 /// the Adam optimizer into a serializable format. It performs efficient tensor
437 /// cloning to preserve all optimization state while enabling serialization.
438 /// The conversion maintains exact numerical precision and handles optional
439 /// AMSGrad state appropriately.
440 ///
441 /// # Arguments
442 ///
443 /// * `state` - Reference to the internal ParameterState to convert
444 ///
445 /// # Returns
446 ///
447 /// SerializableParameterState containing all state data ready for serialization
448 ///
449 /// # Conversion Process
450 ///
451 /// The method performs these operations:
452 /// 1. **Momentum cloning**: Clones the momentum tensor (m) with full precision
453 /// 2. **Velocity cloning**: Clones the velocity tensor (v) with full precision
454 /// 3. **AMSGrad handling**: Clones optional AMSGrad state when present
455 /// 4. **Step preservation**: Copies the step count for bias correction
456 ///
457 /// # Performance
458 ///
459 /// - **Time Complexity**: O(1) - Tensor cloning uses reference counting
460 /// - **Memory Usage**: Minimal additional allocation due to tensor sharing
461 /// - **Precision**: Maintains exact floating-point values in all buffers
462 fn from_parameter_state(state: &ParameterState) -> Self {
463 // Use direct tensor cloning - leverages Tensor's efficient serialization
464 // No manual data extraction needed
465 Self {
466 m: state.m.clone(),
467 v: state.v.clone(),
468 v_hat_max: state.v_hat_max.clone(),
469 step: state.step,
470 }
471 }
472
473 /// Convert SerializableParameterState back to internal ParameterState
474 ///
475 /// This method reconstructs the internal ParameterState representation from
476 /// the serializable format, enabling the Adam optimizer to resume training
477 /// with preserved optimization state. The conversion maintains exact numerical
478 /// precision and properly handles optional AMSGrad state restoration.
479 ///
480 /// # Returns
481 ///
482 /// ParameterState instance ready for use by Adam optimizer, or SerializationError on failure
483 ///
484 /// # Reconstruction Process
485 ///
486 /// The method performs these operations:
487 /// 1. **Momentum restoration**: Clones momentum tensor back to internal format
488 /// 2. **Velocity restoration**: Clones velocity tensor back to internal format
489 /// 3. **AMSGrad restoration**: Handles optional AMSGrad state when present
490 /// 4. **Step restoration**: Preserves step count for continued bias correction
491 ///
492 /// # Validation
493 ///
494 /// The method ensures:
495 /// - All tensors have consistent shapes
496 /// - Step count is valid (non-negative)
497 /// - AMSGrad state consistency with configuration
498 /// - Tensor data integrity is maintained
499 ///
500 /// # Performance
501 ///
502 /// - **Time Complexity**: O(1) - Tensor cloning uses reference counting
503 /// - **Memory Usage**: Minimal allocation due to efficient tensor sharing
504 /// - **Precision**: Maintains exact floating-point values from serialization
505 ///
506 /// # Errors
507 ///
508 /// Returns SerializationError if:
509 /// - Tensor shapes are inconsistent
510 /// - Internal tensor state is corrupted
511 /// - Memory allocation fails during reconstruction
512 fn to_parameter_state(&self) -> SerializationResult<ParameterState> {
513 // Direct tensor cloning - no manual memory management needed
514 // Tensors handle their own efficient reconstruction
515 Ok(ParameterState {
516 m: self.m.clone(),
517 v: self.v.clone(),
518 v_hat_max: self.v_hat_max.clone(),
519 step: self.step,
520 })
521 }
522}
523
524impl StructSerializable for SerializableParameterState {
525 /// Convert SerializableParameterState to StructSerializer for serialization
526 ///
527 /// This method serializes all Adam parameter state components into a structured
528 /// format suitable for both JSON and binary serialization. It leverages the
529 /// efficient tensor serialization system to handle momentum and velocity buffers
530 /// while properly managing optional AMSGrad state.
531 ///
532 /// # Returns
533 ///
534 /// StructSerializer containing all parameter state data with field names and values
535 ///
536 /// # Serialized Fields
537 ///
538 /// - **m**: Momentum tensor (first moment estimate)
539 /// - **v**: Velocity tensor (second moment estimate)
540 /// - **v_hat_max**: Optional AMSGrad maximum velocity tensor
541 /// - **step**: Per-parameter step count for bias correction
542 ///
543 /// # Performance
544 ///
545 /// - **Time Complexity**: O(1) - Constant time field serialization
546 /// - **Memory Usage**: Leverages efficient tensor serialization
547 /// - **Precision**: Full floating-point precision preservation in tensors
548 fn to_serializer(&self) -> StructSerializer {
549 StructSerializer::new()
550 .field("m", &self.m)
551 .field("v", &self.v)
552 .field("v_hat_max", &self.v_hat_max)
553 .field("step", &self.step)
554 }
555
556 /// Create SerializableParameterState from StructDeserializer with validation
557 ///
558 /// This method reconstructs a SerializableParameterState from serialized data,
559 /// performing comprehensive validation to ensure all required fields are present
560 /// and contain valid tensor data. It handles both momentum and velocity tensors
561 /// along with optional AMSGrad state and step count information.
562 ///
563 /// # Arguments
564 ///
565 /// * `deserializer` - StructDeserializer containing parameter state field data
566 ///
567 /// # Returns
568 ///
569 /// Reconstructed SerializableParameterState on success, or SerializationError on failure
570 ///
571 /// # Required Fields
572 ///
573 /// All fields must be present in the deserializer:
574 /// - **m**: Momentum tensor with valid shape and data
575 /// - **v**: Velocity tensor with shape matching momentum tensor
576 /// - **v_hat_max**: Optional AMSGrad tensor (None or matching shape)
577 /// - **step**: Valid step count (non-negative integer)
578 ///
579 /// # Validation
580 ///
581 /// The method validates:
582 /// - Tensor field presence and type correctness
583 /// - Shape consistency between momentum and velocity tensors
584 /// - AMSGrad tensor shape consistency when present
585 /// - Step count validity and range
586 ///
587 /// # Performance
588 ///
589 /// - **Time Complexity**: O(1) - Constant time field extraction
590 /// - **Memory Usage**: Efficient tensor deserialization
591 /// - **Validation**: Comprehensive field and type validation
592 fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
593 Ok(Self {
594 m: deserializer.field("m")?,
595 v: deserializer.field("v")?,
596 v_hat_max: deserializer.field("v_hat_max")?,
597 step: deserializer.field("step")?,
598 })
599 }
600}
601
602impl ToFieldValue for SerializableParameterState {
603 /// Convert SerializableParameterState to FieldValue for embedding in collections
604 ///
605 /// This method converts the parameter state into a FieldValue::Object that can be
606 /// stored in collections or embedded within larger serializable structures. It
607 /// maintains the structured representation of all optimization state components
608 /// while enabling flexible serialization patterns.
609 ///
610 /// # Returns
611 ///
612 /// FieldValue::Object containing all parameter state data as key-value pairs
613 ///
614 /// # Object Structure
615 ///
616 /// The returned object contains these fields:
617 /// - "m": Momentum tensor as serialized FieldValue
618 /// - "v": Velocity tensor as serialized FieldValue
619 /// - "v_hat_max": Optional AMSGrad tensor as serialized FieldValue
620 /// - "step": Step count as FieldValue::Usize
621 ///
622 /// # Performance
623 ///
624 /// - **Time Complexity**: O(1) - Constant time field conversion
625 /// - **Memory Usage**: Efficient tensor serialization with minimal overhead
626 /// - **Precision**: Maintains exact tensor data and step count values
627 fn to_field_value(&self) -> FieldValue {
628 let serializer = self.to_serializer();
629 FieldValue::from_object(serializer.fields.into_iter().collect())
630 }
631}
632
633impl FromFieldValue for SerializableParameterState {
634 /// Create SerializableParameterState from FieldValue with comprehensive validation
635 ///
636 /// This method reconstructs a SerializableParameterState from a FieldValue::Object,
637 /// performing type validation and tensor deserialization. It handles the complex
638 /// process of deserializing momentum and velocity tensors along with optional
639 /// AMSGrad state, ensuring data integrity and proper error handling.
640 ///
641 /// # Arguments
642 ///
643 /// * `value` - FieldValue containing parameter state data (must be Object variant)
644 /// * `field_name` - Name of the field being deserialized for error context
645 ///
646 /// # Returns
647 ///
648 /// Reconstructed SerializableParameterState on success, or SerializationError on failure
649 ///
650 /// # Expected FieldValue Structure
651 ///
652 /// The FieldValue must be an Object variant containing:
653 /// - "m": Serialized momentum tensor
654 /// - "v": Serialized velocity tensor
655 /// - "v_hat_max": Optional serialized AMSGrad tensor
656 /// - "step": Step count as numeric value
657 ///
658 /// # Validation
659 ///
660 /// The method validates:
661 /// - FieldValue is Object variant
662 /// - All required fields are present
663 /// - Tensor deserialization succeeds
664 /// - Step count is valid numeric value
665 /// - Tensor shapes are consistent
666 ///
667 /// # Performance
668 ///
669 /// - **Time Complexity**: O(1) - Constant time field extraction and tensor deserialization
670 /// - **Memory Usage**: Efficient tensor deserialization with minimal overhead
671 /// - **Error Handling**: Detailed error messages with field name context
672 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
673 match value {
674 FieldValue::Object(fields) => {
675 let mut deserializer = StructDeserializer::from_fields(fields);
676 Self::from_deserializer(&mut deserializer)
677 }
678 _ => Err(SerializationError::ValidationFailed {
679 field: field_name.to_string(),
680 message: format!(
681 "Expected Object for {}, found {}",
682 std::any::type_name::<Self>(),
683 value.type_name()
684 ),
685 }),
686 }
687 }
688}
689
690// ===== Adam Serialization =====
691
692impl StructSerializable for Adam {
693 /// Convert Adam to StructSerializer for serialization
694 ///
695 /// Serializes all optimizer state including configuration, parameter states,
696 /// and global step count. Parameter linking is not serialized and must be
697 /// done after deserialization.
698 ///
699 /// # Returns
700 ///
701 /// StructSerializer containing all serializable optimizer state
702 fn to_serializer(&self) -> StructSerializer {
703 // Convert parameter states to serializable form
704 let mut serializable_states = HashMap::new();
705 for (param_id, state) in &self.states {
706 serializable_states.insert(
707 *param_id,
708 SerializableParameterState::from_parameter_state(state),
709 );
710 }
711
712 StructSerializer::new()
713 .field("config", &self.config)
714 .field("states", &serializable_states)
715 .field("step_count", &self.step_count)
716 .field("insertion_order", &self.insertion_order)
717 }
718
719 /// Create Adam from StructDeserializer
720 ///
721 /// Reconstructs Adam optimizer from serialized state. Parameters must be
722 /// linked separately using `add_parameter` or `add_parameters`.
723 ///
724 /// # Arguments
725 ///
726 /// * `deserializer` - StructDeserializer containing optimizer data
727 ///
728 /// # Returns
729 ///
730 /// Reconstructed Adam instance without parameter links, or error if deserialization fails
731 fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
732 let config: AdamConfig = deserializer.field("config")?;
733 let serializable_states: HashMap<usize, SerializableParameterState> =
734 deserializer.field("states")?;
735 let step_count: usize = deserializer.field("step_count")?;
736 let insertion_order: Vec<usize> = deserializer.field("insertion_order")?;
737
738 // Reconstruct parameter states from serialized form
739 let mut states = HashMap::new();
740 for (param_id, serializable_state) in serializable_states {
741 states.insert(param_id, serializable_state.to_parameter_state()?);
742 }
743
744 // Create optimizer with reconstructed states - user must call relink_parameters to link tensors
745 Ok(Adam {
746 config,
747 states,
748 step_count,
749 insertion_order,
750 })
751 }
752}
753
754impl FromFieldValue for Adam {
755 /// Create Adam from FieldValue
756 ///
757 /// # Arguments
758 ///
759 /// * `value` - FieldValue containing optimizer data
760 /// * `field_name` - Name of the field being deserialized (for error messages)
761 ///
762 /// # Returns
763 ///
764 /// Reconstructed Adam instance or error if deserialization fails
765 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
766 match value {
767 FieldValue::Object(fields) => {
768 let mut deserializer = StructDeserializer::from_fields(fields);
769 Self::from_deserializer(&mut deserializer)
770 }
771 _ => Err(SerializationError::ValidationFailed {
772 field: field_name.to_string(),
773 message: format!(
774 "Expected Object for {}, found {}",
775 std::any::type_name::<Self>(),
776 value.type_name()
777 ),
778 }),
779 }
780 }
781}
782
783// ===== Serializable Trait Implementation =====
784
785impl crate::serialization::Serializable for Adam {
786 /// Serialize the Adam optimizer to JSON format
787 ///
788 /// This method converts the Adam optimizer into a human-readable JSON string representation
789 /// that includes all optimizer state, configuration, parameter states, and step counts.
790 /// The JSON format is suitable for debugging, configuration files, and cross-language
791 /// interoperability.
792 ///
793 /// # Returns
794 ///
795 /// JSON string representation of the optimizer on success, or `SerializationError` on failure
796 ///
797 /// # Examples
798 ///
799 /// ```
800 /// use train_station::{Tensor, optimizers::Adam};
801 /// use train_station::serialization::Serializable;
802 ///
803 /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
804 /// let mut optimizer = Adam::new();
805 /// optimizer.add_parameter(&weight);
806 ///
807 /// let json = optimizer.to_json().unwrap();
808 /// assert!(!json.is_empty());
809 /// ```
810 fn to_json(&self) -> SerializationResult<String> {
811 <Self as StructSerializable>::to_json(self)
812 }
813
814 /// Deserialize an Adam optimizer from JSON format
815 ///
816 /// This method parses a JSON string and reconstructs an Adam optimizer with all
817 /// saved state. Parameters must be re-linked after deserialization using
818 /// `add_parameter` or `relink_parameters`.
819 ///
820 /// # Arguments
821 ///
822 /// * `json` - JSON string containing serialized optimizer
823 ///
824 /// # Returns
825 ///
826 /// The deserialized optimizer on success, or `SerializationError` on failure
827 ///
828 /// # Examples
829 ///
830 /// ```
831 /// use train_station::{Tensor, optimizers::Adam};
832 /// use train_station::serialization::Serializable;
833 ///
834 /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
835 /// let mut optimizer = Adam::new();
836 /// optimizer.add_parameter(&weight);
837 ///
838 /// let json = optimizer.to_json().unwrap();
839 /// let loaded_optimizer = Adam::from_json(&json).unwrap();
840 /// assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
841 /// ```
842 fn from_json(json: &str) -> SerializationResult<Self> {
843 <Self as StructSerializable>::from_json(json)
844 }
845
846 /// Serialize the Adam optimizer to binary format
847 ///
848 /// This method converts the optimizer into a compact binary representation optimized
849 /// for storage and transmission. The binary format provides maximum performance
850 /// and minimal file sizes compared to JSON.
851 ///
852 /// # Returns
853 ///
854 /// Binary representation of the optimizer on success, or `SerializationError` on failure
855 ///
856 /// # Examples
857 ///
858 /// ```
859 /// use train_station::{Tensor, optimizers::Adam};
860 /// use train_station::serialization::Serializable;
861 ///
862 /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
863 /// let mut optimizer = Adam::new();
864 /// optimizer.add_parameter(&weight);
865 ///
866 /// let binary = optimizer.to_binary().unwrap();
867 /// assert!(!binary.is_empty());
868 /// ```
869 fn to_binary(&self) -> SerializationResult<Vec<u8>> {
870 <Self as StructSerializable>::to_binary(self)
871 }
872
873 /// Deserialize an Adam optimizer from binary format
874 ///
875 /// This method parses binary data and reconstructs an Adam optimizer with all
876 /// saved state. Parameters must be re-linked after deserialization using
877 /// `add_parameter` or `relink_parameters`.
878 ///
879 /// # Arguments
880 ///
881 /// * `data` - Binary data containing serialized optimizer
882 ///
883 /// # Returns
884 ///
885 /// The deserialized optimizer on success, or `SerializationError` on failure
886 ///
887 /// # Examples
888 ///
889 /// ```
890 /// use train_station::{Tensor, optimizers::Adam};
891 /// use train_station::serialization::Serializable;
892 ///
893 /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
894 /// let mut optimizer = Adam::new();
895 /// optimizer.add_parameter(&weight);
896 ///
897 /// let binary = optimizer.to_binary().unwrap();
898 /// let loaded_optimizer = Adam::from_binary(&binary).unwrap();
899 /// assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
900 /// ```
901 fn from_binary(data: &[u8]) -> SerializationResult<Self> {
902 <Self as StructSerializable>::from_binary(data)
903 }
904}
905
906// ===== Utility Methods =====
907
908impl Adam {
909 /// Get the number of saved parameter states for checkpoint validation
910 ///
911 /// This method returns the count of parameter states currently stored in the optimizer,
912 /// which is essential for validating checkpoint integrity and ensuring proper parameter
913 /// re-linking after deserialization. The count includes all parameters that have been
914 /// linked to the optimizer and have accumulated optimization state.
915 ///
916 /// # Returns
917 ///
918 /// Number of parameter states currently stored in the optimizer
919 ///
920 /// # Usage Patterns
921 ///
922 /// ## Checkpoint Validation
923 /// After deserializing an optimizer, this method helps verify that the expected
924 /// number of parameters were saved and can guide the re-linking process.
925 ///
926 /// ## Training Resumption
927 /// When resuming training, compare this count with the number of parameters
928 /// in your model to ensure checkpoint compatibility.
929 ///
930 /// ## State Management
931 /// Use this method to monitor optimizer state growth and memory usage during
932 /// training with dynamic parameter addition.
933 ///
934 /// # Examples
935 ///
936 /// ```
937 /// use train_station::{Tensor, optimizers::Adam};
938 /// use train_station::serialization::Serializable;
939 ///
940 /// let weight = Tensor::ones(vec![10, 5]).with_requires_grad();
941 /// let bias = Tensor::zeros(vec![5]).with_requires_grad();
942 /// let mut optimizer = Adam::new();
943 /// optimizer.add_parameter(&weight);
944 /// optimizer.add_parameter(&bias);
945 ///
946 /// // Check parameter count before serialization
947 /// assert_eq!(optimizer.saved_parameter_count(), 2);
948 ///
949 /// // Serialize and deserialize
950 /// let json = optimizer.to_json().unwrap();
951 /// let loaded_optimizer = Adam::from_json(&json).unwrap();
952 ///
953 /// // Verify parameter count is preserved
954 /// assert_eq!(loaded_optimizer.saved_parameter_count(), 2);
955 /// ```
956 ///
957 /// # Performance
958 ///
959 /// - **Time Complexity**: O(1) - Direct access to internal state count
960 /// - **Memory Usage**: No additional memory allocation
961 /// - **Thread Safety**: Safe to call from multiple threads concurrently
962 pub fn saved_parameter_count(&self) -> usize {
963 self.states.len()
964 }
965}
966
967// ===== Field Value Implementations for Collections =====
968
969impl ToFieldValue for HashMap<usize, SerializableParameterState> {
970 /// Convert parameter states HashMap to FieldValue for serialization
971 ///
972 /// This method converts the HashMap of parameter states into a FieldValue::Object
973 /// suitable for serialization. It handles the conversion of usize keys to string
974 /// format required by the FieldValue::Object representation while preserving
975 /// all parameter state data.
976 ///
977 /// # Returns
978 ///
979 /// FieldValue::Object with string keys and SerializableParameterState values
980 ///
981 /// # Key Conversion
982 ///
983 /// - **Input**: HashMap<usize, SerializableParameterState> with numeric tensor IDs
984 /// - **Output**: FieldValue::Object with string keys for JSON compatibility
985 /// - **Mapping**: Each usize key is converted to string representation
986 /// - **Preservation**: All parameter state data is preserved exactly
987 ///
988 /// # Performance
989 ///
990 /// - **Time Complexity**: O(n) where n is the number of parameter states
991 /// - **Memory Usage**: Allocates new HashMap for string keys
992 /// - **Conversion**: Efficient string conversion for numeric keys
993 fn to_field_value(&self) -> FieldValue {
994 let mut map = HashMap::new();
995 for (key, value) in self {
996 map.insert(key.to_string(), value.to_field_value());
997 }
998 FieldValue::from_object(map)
999 }
1000}
1001
1002impl FromFieldValue for HashMap<usize, SerializableParameterState> {
1003 /// Create parameter states HashMap from FieldValue with validation
1004 ///
1005 /// This method reconstructs the HashMap of parameter states from a FieldValue::Object,
1006 /// performing comprehensive validation and key conversion. It handles the conversion
1007 /// from string keys back to usize tensor IDs while ensuring all parameter state
1008 /// data is properly deserialized and validated.
1009 ///
1010 /// # Arguments
1011 ///
1012 /// * `value` - FieldValue containing parameter states data (must be Object variant)
1013 /// * `field_name` - Name of the field being deserialized for error context
1014 ///
1015 /// # Returns
1016 ///
1017 /// Reconstructed HashMap<usize, SerializableParameterState> on success, or SerializationError on failure
1018 ///
1019 /// # Key Conversion Process
1020 ///
1021 /// 1. **Validation**: Ensures FieldValue is Object variant
1022 /// 2. **Key parsing**: Converts string keys back to usize tensor IDs
1023 /// 3. **State deserialization**: Deserializes each parameter state
1024 /// 4. **Validation**: Validates parameter state integrity
1025 /// 5. **Collection**: Builds final HashMap with proper types
1026 ///
1027 /// # Errors
1028 ///
1029 /// Returns SerializationError if:
1030 /// - FieldValue is not Object variant
1031 /// - Any string key cannot be parsed as usize
1032 /// - Parameter state deserialization fails
1033 /// - Invalid parameter state data is encountered
1034 ///
1035 /// # Performance
1036 ///
1037 /// - **Time Complexity**: O(n) where n is the number of parameter states
1038 /// - **Memory Usage**: Allocates new HashMap with proper key types
1039 /// - **Validation**: Comprehensive key parsing and state validation
1040 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
1041 match value {
1042 FieldValue::Object(fields) => {
1043 let mut map = HashMap::new();
1044 for (key_str, field_value) in fields {
1045 let key = key_str.parse::<usize>().map_err(|_| {
1046 SerializationError::ValidationFailed {
1047 field: field_name.to_string(),
1048 message: format!("Invalid key '{}' in parameter states map", key_str),
1049 }
1050 })?;
1051 let state =
1052 SerializableParameterState::from_field_value(field_value, &key_str)?;
1053 map.insert(key, state);
1054 }
1055 Ok(map)
1056 }
1057 _ => Err(SerializationError::ValidationFailed {
1058 field: field_name.to_string(),
1059 message: format!(
1060 "Expected Object for HashMap<usize, SerializableParameterState>, found {}",
1061 value.type_name()
1062 ),
1063 }),
1064 }
1065 }
1066}
1067
1068#[cfg(test)]
1069mod tests {
1070 use super::*;
1071 use crate::optimizers::Optimizer;
1072 use crate::tensor::core::Tensor;
1073
1074 // ===== AdamConfig Serialization Tests =====
1075
1076 #[test]
1077 fn test_adam_config_json_roundtrip() {
1078 let config = AdamConfig {
1079 learning_rate: 1e-4,
1080 beta1: 0.95,
1081 beta2: 0.9999,
1082 eps: 1e-7,
1083 weight_decay: 1e-5,
1084 amsgrad: true,
1085 };
1086
1087 let json = config.to_json().unwrap();
1088 let loaded_config = AdamConfig::from_json(&json).unwrap();
1089
1090 assert_eq!(config.learning_rate, loaded_config.learning_rate);
1091 assert_eq!(config.beta1, loaded_config.beta1);
1092 assert_eq!(config.beta2, loaded_config.beta2);
1093 assert_eq!(config.eps, loaded_config.eps);
1094 assert_eq!(config.weight_decay, loaded_config.weight_decay);
1095 assert_eq!(config.amsgrad, loaded_config.amsgrad);
1096 }
1097
1098 #[test]
1099 fn test_adam_config_binary_roundtrip() {
1100 let config = AdamConfig {
1101 learning_rate: 2e-3,
1102 beta1: 0.85,
1103 beta2: 0.995,
1104 eps: 1e-9,
1105 weight_decay: 5e-4,
1106 amsgrad: false,
1107 };
1108
1109 let binary = config.to_binary().unwrap();
1110 let loaded_config = AdamConfig::from_binary(&binary).unwrap();
1111
1112 assert_eq!(config.learning_rate, loaded_config.learning_rate);
1113 assert_eq!(config.beta1, loaded_config.beta1);
1114 assert_eq!(config.beta2, loaded_config.beta2);
1115 assert_eq!(config.eps, loaded_config.eps);
1116 assert_eq!(config.weight_decay, loaded_config.weight_decay);
1117 assert_eq!(config.amsgrad, loaded_config.amsgrad);
1118 }
1119
1120 #[test]
1121 fn test_adam_config_field_value_roundtrip() {
1122 let config = AdamConfig {
1123 learning_rate: 3e-4,
1124 beta1: 0.92,
1125 beta2: 0.998,
1126 eps: 1e-6,
1127 weight_decay: 2e-4,
1128 amsgrad: true,
1129 };
1130
1131 let field_value = config.to_field_value();
1132 let loaded_config = AdamConfig::from_field_value(field_value, "config").unwrap();
1133
1134 assert_eq!(config.learning_rate, loaded_config.learning_rate);
1135 assert_eq!(config.beta1, loaded_config.beta1);
1136 assert_eq!(config.beta2, loaded_config.beta2);
1137 assert_eq!(config.eps, loaded_config.eps);
1138 assert_eq!(config.weight_decay, loaded_config.weight_decay);
1139 assert_eq!(config.amsgrad, loaded_config.amsgrad);
1140 }
1141
1142 // ===== Adam Optimizer Serialization Tests =====
1143
    /// JSON roundtrip for a populated optimizer: state built by a real
    /// backward/step cycle must survive serialize + deserialize.
    #[test]
    fn test_adam_optimizer_json_roundtrip() {
        let mut weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let mut bias = Tensor::zeros(vec![2, 3]).with_requires_grad();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        // Perform some steps to create state
        let output = weight.add_tensor(&bias);
        let mut loss = output.sum();
        loss.backward(None);
        optimizer.step(&mut [&mut weight, &mut bias]);

        // Test serialization
        let json = optimizer.to_json().unwrap();
        let loaded_optimizer = Adam::from_json(&json).unwrap();

        // Config and per-parameter state count must survive the roundtrip.
        assert_eq!(
            optimizer.config().learning_rate,
            loaded_optimizer.config().learning_rate
        );
        assert_eq!(
            optimizer.saved_parameter_count(),
            loaded_optimizer.saved_parameter_count()
        );
    }
1172
1173 #[test]
1174 fn test_adam_optimizer_binary_roundtrip() {
1175 let weight = Tensor::ones(vec![5, 2]).with_requires_grad();
1176
1177 let mut optimizer = Adam::with_learning_rate(1e-4);
1178 optimizer.add_parameter(&weight);
1179
1180 // Test serialization
1181 let binary = optimizer.to_binary().unwrap();
1182 let loaded_optimizer = Adam::from_binary(&binary).unwrap();
1183
1184 assert_eq!(
1185 optimizer.config().learning_rate,
1186 loaded_optimizer.config().learning_rate
1187 );
1188 assert_eq!(
1189 optimizer.saved_parameter_count(),
1190 loaded_optimizer.saved_parameter_count()
1191 );
1192 }
1193
    /// Adding a fresh tensor after deserialization creates a new state
    /// alongside the restored one (new tensor => new id => new state).
    #[test]
    fn test_adam_parameter_relinking() {
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        // Serialize
        let json = optimizer.to_json().unwrap();

        // Deserialize
        let mut loaded_optimizer = Adam::from_json(&json).unwrap();

        // After deserialization, saved states should be preserved
        assert_eq!(loaded_optimizer.saved_parameter_count(), 1);

        // Re-link parameter - this creates a new state since it's a new tensor with different ID
        let new_weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        loaded_optimizer.add_parameter(&new_weight);

        // Now there should be 2 states: the original saved one + the new one
        assert_eq!(loaded_optimizer.parameter_count(), 2);
        assert!(loaded_optimizer.is_parameter_linked(&new_weight));
    }
1217
    /// State accumulated over several real training steps must survive a
    /// JSON roundtrip (state count and configuration preserved).
    #[test]
    fn test_adam_state_preservation() {
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        // Perform training steps to build up state
        for _ in 0..3 {
            let output = weight.mul_scalar(2.0);
            let mut loss = output.sum();
            loss.backward(None);
            optimizer.step(&mut [&mut weight]);
            optimizer.zero_grad(&mut [&mut weight]);
        }

        // Serialize and deserialize
        let json = optimizer.to_json().unwrap();
        let loaded_optimizer = Adam::from_json(&json).unwrap();

        // Check that states were preserved
        assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
        assert_eq!(
            loaded_optimizer.config().learning_rate,
            optimizer.config().learning_rate
        );
    }
1244
    /// Re-linking with fresh tensors of matching shapes must succeed and
    /// attach all saved states to the new tensor ids.
    #[test]
    fn test_relink_parameters_success() {
        // Create original optimizer with parameters
        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let bias = Tensor::zeros(vec![3]).with_requires_grad();

        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        // Serialize
        let json = optimizer.to_json().unwrap();

        // Create new parameters with same shapes but different IDs
        let new_weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let new_bias = Tensor::zeros(vec![3]).with_requires_grad();

        // Deserialize and re-link
        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
        loaded_optimizer
            .relink_parameters(&[&new_weight, &new_bias])
            .unwrap();

        // Verify re-linking worked
        assert_eq!(loaded_optimizer.parameter_count(), 2);
        assert!(loaded_optimizer.is_parameter_linked(&new_weight));
        assert!(loaded_optimizer.is_parameter_linked(&new_bias));
    }
1273
1274 #[test]
1275 fn test_relink_parameters_shape_mismatch() {
1276 // Create original optimizer
1277 let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1278 let mut optimizer = Adam::new();
1279 optimizer.add_parameter(&weight);
1280
1281 // Serialize
1282 let json = optimizer.to_json().unwrap();
1283
1284 // Create new parameter with different shape
1285 let new_weight = Tensor::ones(vec![3, 2]).with_requires_grad(); // Different shape!
1286
1287 // Deserialize and try to re-link
1288 let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1289 let result = loaded_optimizer.relink_parameters(&[&new_weight]);
1290
1291 // Should fail with shape mismatch error
1292 assert!(result.is_err());
1293 assert!(result.unwrap_err().contains("Shape mismatch"));
1294 }
1295
1296 #[test]
1297 fn test_relink_parameters_count_mismatch() {
1298 // Create original optimizer with 2 parameters
1299 let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1300 let bias = Tensor::zeros(vec![3]).with_requires_grad();
1301
1302 let mut optimizer = Adam::new();
1303 optimizer.add_parameter(&weight);
1304 optimizer.add_parameter(&bias);
1305
1306 // Serialize
1307 let json = optimizer.to_json().unwrap();
1308
1309 // Create only 1 new parameter
1310 let new_weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1311
1312 // Deserialize and try to re-link with wrong count
1313 let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1314 let result = loaded_optimizer.relink_parameters(&[&new_weight]);
1315
1316 // Should fail with count mismatch error
1317 assert!(result.is_err());
1318 assert!(result.unwrap_err().contains("Parameter count mismatch"));
1319 }
1320
1321 #[test]
1322 fn test_relink_parameters_requires_grad() {
1323 // Create original optimizer
1324 let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1325 let mut optimizer = Adam::new();
1326 optimizer.add_parameter(&weight);
1327
1328 // Serialize
1329 let json = optimizer.to_json().unwrap();
1330
1331 // Create new parameter without requires_grad
1332 let new_weight = Tensor::ones(vec![2, 3]); // No requires_grad!
1333
1334 // Deserialize and try to re-link
1335 let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1336 let result = loaded_optimizer.relink_parameters(&[&new_weight]);
1337
1338 // Should fail with requires_grad error
1339 assert!(result.is_err());
1340 assert!(result.unwrap_err().contains("must require gradients"));
1341 }
1342
    /// Re-linking to a fresh tensor must preserve the trained optimizer
    /// state (global step count and per-parameter states).
    #[test]
    fn test_relink_preserves_state() {
        // Create original optimizer and train it
        let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let mut optimizer = Adam::new();
        optimizer.add_parameter(&weight);

        // Perform some training to build up state
        for _ in 0..2 {
            let output = weight.mul_scalar(2.0);
            let mut loss = output.sum();
            loss.backward(None);
            optimizer.step(&mut [&mut weight]);
            optimizer.zero_grad(&mut [&mut weight]);
        }

        // Get state before serialization
        let original_step_count = optimizer.step_count;

        // Serialize
        let json = optimizer.to_json().unwrap();

        // Create new parameter
        let new_weight = Tensor::ones(vec![2, 2]).with_requires_grad();

        // Deserialize and re-link
        let mut loaded_optimizer = Adam::from_json(&json).unwrap();
        loaded_optimizer.relink_parameters(&[&new_weight]).unwrap();

        // Verify state is preserved
        assert_eq!(loaded_optimizer.step_count, original_step_count);
        assert_eq!(loaded_optimizer.parameter_count(), 1);
        assert!(loaded_optimizer.is_parameter_linked(&new_weight));
    }
1377
1378 // ===== Serializable Trait Tests =====
1379
    /// Exercises the `Serializable` trait's JSON methods explicitly via UFCS
    /// and checks full config/state fidelity on the restored optimizer.
    #[test]
    fn test_serializable_json_methods() {
        // Create and populate test optimizer
        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let bias = Tensor::zeros(vec![3]).with_requires_grad();

        let mut optimizer = Adam::with_learning_rate(1e-3);
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        // Test to_json method
        let json = <Adam as crate::serialization::Serializable>::to_json(&optimizer).unwrap();
        assert!(!json.is_empty());
        // Spot-check that the expected top-level fields appear in the document.
        assert!(json.contains("config"));
        assert!(json.contains("states"));
        assert!(json.contains("step_count"));
        assert!(json.contains("learning_rate"));

        // Test from_json method
        let restored = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
        assert_eq!(
            optimizer.config().learning_rate,
            restored.config().learning_rate
        );
        assert_eq!(optimizer.config().beta1, restored.config().beta1);
        assert_eq!(optimizer.config().beta2, restored.config().beta2);
        assert_eq!(optimizer.config().eps, restored.config().eps);
        assert_eq!(
            optimizer.config().weight_decay,
            restored.config().weight_decay
        );
        assert_eq!(optimizer.config().amsgrad, restored.config().amsgrad);
        assert_eq!(
            optimizer.saved_parameter_count(),
            restored.saved_parameter_count()
        );
        assert_eq!(optimizer.step_count, restored.step_count);
    }
1418
    /// Exercises the `Serializable` trait's binary methods with a fully
    /// customized (AMSGrad-enabled) configuration.
    #[test]
    fn test_serializable_binary_methods() {
        // Create and populate test optimizer
        let weight = Tensor::ones(vec![3, 4]).with_requires_grad();
        let mut optimizer = Adam::with_config(AdamConfig {
            learning_rate: 2e-4,
            beta1: 0.95,
            beta2: 0.999,
            eps: 1e-7,
            weight_decay: 1e-4,
            amsgrad: true,
        });
        optimizer.add_parameter(&weight);

        // Test to_binary method
        let binary = <Adam as crate::serialization::Serializable>::to_binary(&optimizer).unwrap();
        assert!(!binary.is_empty());

        // Test from_binary method
        let restored = <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
        // Every config field must survive the binary roundtrip exactly.
        assert_eq!(
            optimizer.config().learning_rate,
            restored.config().learning_rate
        );
        assert_eq!(optimizer.config().beta1, restored.config().beta1);
        assert_eq!(optimizer.config().beta2, restored.config().beta2);
        assert_eq!(optimizer.config().eps, restored.config().eps);
        assert_eq!(
            optimizer.config().weight_decay,
            restored.config().weight_decay
        );
        assert_eq!(optimizer.config().amsgrad, restored.config().amsgrad);
        assert_eq!(
            optimizer.saved_parameter_count(),
            restored.saved_parameter_count()
        );
        assert_eq!(optimizer.step_count, restored.step_count);
    }
1457
    /// File-level JSON I/O: save/load by path and via writer/reader streams.
    /// Uses fixed filenames in the working directory and cleans them up at
    /// the end (best-effort removal so a failed assert doesn't mask errors).
    #[test]
    fn test_serializable_file_io_json() {
        use crate::serialization::{Format, Serializable};
        use std::fs;
        use std::path::Path;

        // Create test optimizer
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let bias = Tensor::zeros(vec![2]).with_requires_grad();

        let mut optimizer = Adam::with_learning_rate(5e-4);
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        let json_path = "test_adam_serializable.json";

        // Test save method with JSON format
        Serializable::save(&optimizer, json_path, Format::Json).unwrap();
        assert!(Path::new(json_path).exists());

        // Test load method with JSON format
        let loaded_optimizer = Adam::load(json_path, Format::Json).unwrap();
        assert_eq!(
            optimizer.config().learning_rate,
            loaded_optimizer.config().learning_rate
        );
        assert_eq!(
            optimizer.saved_parameter_count(),
            loaded_optimizer.saved_parameter_count()
        );

        // Test save_to_writer method
        let json_path_2 = "test_adam_serializable_writer.json";
        {
            // Scope the writer so it flushes and closes before the read below.
            let file = std::fs::File::create(json_path_2).unwrap();
            let mut writer = std::io::BufWriter::new(file);
            Serializable::save_to_writer(&optimizer, &mut writer, Format::Json).unwrap();
        }
        assert!(Path::new(json_path_2).exists());

        // Test load_from_reader method
        {
            let file = std::fs::File::open(json_path_2).unwrap();
            let mut reader = std::io::BufReader::new(file);
            let loaded_optimizer = Adam::load_from_reader(&mut reader, Format::Json).unwrap();
            assert_eq!(
                optimizer.config().learning_rate,
                loaded_optimizer.config().learning_rate
            );
        }

        // Cleanup test files
        let _ = fs::remove_file(json_path);
        let _ = fs::remove_file(json_path_2);
    }
1513
    /// File-level binary I/O: save/load by path and via writer/reader
    /// streams, mirroring the JSON file I/O test with `Format::Binary`.
    #[test]
    fn test_serializable_file_io_binary() {
        use crate::serialization::{Format, Serializable};
        use std::fs;
        use std::path::Path;

        // Create test optimizer
        let weight = Tensor::ones(vec![3, 3]).with_requires_grad();
        let mut optimizer = Adam::with_config(AdamConfig {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
        });
        optimizer.add_parameter(&weight);

        let binary_path = "test_adam_serializable.bin";

        // Test save method with binary format
        Serializable::save(&optimizer, binary_path, Format::Binary).unwrap();
        assert!(Path::new(binary_path).exists());

        // Test load method with binary format
        let loaded_optimizer = Adam::load(binary_path, Format::Binary).unwrap();
        assert_eq!(
            optimizer.config().learning_rate,
            loaded_optimizer.config().learning_rate
        );
        assert_eq!(
            optimizer.saved_parameter_count(),
            loaded_optimizer.saved_parameter_count()
        );

        // Test save_to_writer method
        let binary_path_2 = "test_adam_serializable_writer.bin";
        {
            // Scope the writer so it flushes and closes before the read below.
            let file = std::fs::File::create(binary_path_2).unwrap();
            let mut writer = std::io::BufWriter::new(file);
            Serializable::save_to_writer(&optimizer, &mut writer, Format::Binary).unwrap();
        }
        assert!(Path::new(binary_path_2).exists());

        // Test load_from_reader method
        {
            let file = std::fs::File::open(binary_path_2).unwrap();
            let mut reader = std::io::BufReader::new(file);
            let loaded_optimizer = Adam::load_from_reader(&mut reader, Format::Binary).unwrap();
            assert_eq!(
                optimizer.config().learning_rate,
                loaded_optimizer.config().learning_rate
            );
        }

        // Cleanup test files
        let _ = fs::remove_file(binary_path);
        let _ = fs::remove_file(binary_path_2);
    }
1573
    /// Roundtrips an optimizer holding several parameters of increasing size
    /// through both formats; size comparison is printed for information only.
    #[test]
    fn test_serializable_large_optimizer_performance() {
        // Create a large optimizer to test performance characteristics
        let mut optimizer = Adam::with_learning_rate(1e-4);

        // Add multiple parameters of different sizes
        for i in 0..5 {
            let size = 10 + i * 5;
            let param = Tensor::ones(vec![size, size]).with_requires_grad();
            optimizer.add_parameter(&param);
        }

        // Test JSON serialization
        let json = <Adam as crate::serialization::Serializable>::to_json(&optimizer).unwrap();
        assert!(!json.is_empty());
        let restored_json = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
        assert_eq!(
            optimizer.config().learning_rate,
            restored_json.config().learning_rate
        );
        assert_eq!(
            optimizer.saved_parameter_count(),
            restored_json.saved_parameter_count()
        );

        // Test binary serialization
        let binary = <Adam as crate::serialization::Serializable>::to_binary(&optimizer).unwrap();
        assert!(!binary.is_empty());
        // Binary format should be efficient (this is informational, not a requirement)
        println!(
            "JSON size: {} bytes, Binary size: {} bytes",
            json.len(),
            binary.len()
        );

        let restored_binary =
            <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
        assert_eq!(
            optimizer.config().learning_rate,
            restored_binary.config().learning_rate
        );
        assert_eq!(
            optimizer.saved_parameter_count(),
            restored_binary.saved_parameter_count()
        );

        // Verify all configurations match
        assert_eq!(optimizer.config().beta1, restored_binary.config().beta1);
        assert_eq!(optimizer.config().beta2, restored_binary.config().beta2);
        assert_eq!(optimizer.config().eps, restored_binary.config().eps);
        assert_eq!(
            optimizer.config().weight_decay,
            restored_binary.config().weight_decay
        );
        assert_eq!(optimizer.config().amsgrad, restored_binary.config().amsgrad);
    }
1630
1631 #[test]
1632 fn test_serializable_error_handling() {
1633 // Test invalid JSON
1634 let invalid_json = r#"{"invalid": "json", "structure": true}"#;
1635 let result = <Adam as crate::serialization::Serializable>::from_json(invalid_json);
1636 assert!(result.is_err());
1637
1638 // Test empty JSON
1639 let empty_json = "{}";
1640 let result = <Adam as crate::serialization::Serializable>::from_json(empty_json);
1641 assert!(result.is_err());
1642
1643 // Test invalid binary data
1644 let invalid_binary = vec![1, 2, 3, 4, 5];
1645 let result = <Adam as crate::serialization::Serializable>::from_binary(&invalid_binary);
1646 assert!(result.is_err());
1647
1648 // Test empty binary data
1649 let empty_binary = vec![];
1650 let result = <Adam as crate::serialization::Serializable>::from_binary(&empty_binary);
1651 assert!(result.is_err());
1652 }
1653
    /// Roundtrips a spectrum of configurations (default, high LR, AMSGrad,
    /// custom betas) through both JSON and binary formats.
    #[test]
    fn test_serializable_different_configurations() {
        let test_configs = vec![
            // Default configuration
            AdamConfig::default(),
            // High learning rate
            AdamConfig {
                learning_rate: 1e-2,
                beta1: 0.9,
                beta2: 0.999,
                eps: 1e-8,
                weight_decay: 0.0,
                amsgrad: false,
            },
            // AMSGrad enabled
            AdamConfig {
                learning_rate: 1e-4,
                beta1: 0.9,
                beta2: 0.999,
                eps: 1e-8,
                weight_decay: 1e-4,
                amsgrad: true,
            },
            // Custom betas
            AdamConfig {
                learning_rate: 5e-4,
                beta1: 0.95,
                beta2: 0.9999,
                eps: 1e-7,
                weight_decay: 1e-5,
                amsgrad: false,
            },
        ];

        for config in test_configs {
            // Create optimizer with specific configuration
            let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
            let mut optimizer = Adam::with_config(config.clone());
            optimizer.add_parameter(&weight);

            // Test JSON roundtrip
            let json = <Adam as crate::serialization::Serializable>::to_json(&optimizer).unwrap();
            let restored_json =
                <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
            assert_eq!(config.learning_rate, restored_json.config().learning_rate);
            assert_eq!(config.beta1, restored_json.config().beta1);
            assert_eq!(config.beta2, restored_json.config().beta2);
            assert_eq!(config.eps, restored_json.config().eps);
            assert_eq!(config.weight_decay, restored_json.config().weight_decay);
            assert_eq!(config.amsgrad, restored_json.config().amsgrad);

            // Test binary roundtrip
            let binary =
                <Adam as crate::serialization::Serializable>::to_binary(&optimizer).unwrap();
            let restored_binary =
                <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
            assert_eq!(config.learning_rate, restored_binary.config().learning_rate);
            assert_eq!(config.beta1, restored_binary.config().beta1);
            assert_eq!(config.beta2, restored_binary.config().beta2);
            assert_eq!(config.eps, restored_binary.config().eps);
            assert_eq!(config.weight_decay, restored_binary.config().weight_decay);
            assert_eq!(config.amsgrad, restored_binary.config().amsgrad);
        }
    }
1718
1719 #[test]
1720 fn test_serializable_edge_cases() {
1721 // Test optimizer with no parameters
1722 let empty_optimizer = Adam::new();
1723 let json = <Adam as crate::serialization::Serializable>::to_json(&empty_optimizer).unwrap();
1724 let restored = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1725 assert_eq!(
1726 empty_optimizer.saved_parameter_count(),
1727 restored.saved_parameter_count()
1728 );
1729 assert_eq!(empty_optimizer.step_count, restored.step_count);
1730
1731 let binary =
1732 <Adam as crate::serialization::Serializable>::to_binary(&empty_optimizer).unwrap();
1733 let restored = <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1734 assert_eq!(
1735 empty_optimizer.saved_parameter_count(),
1736 restored.saved_parameter_count()
1737 );
1738 assert_eq!(empty_optimizer.step_count, restored.step_count);
1739
1740 // Test optimizer with extreme configuration values
1741 let extreme_config = AdamConfig {
1742 learning_rate: 1e-10, // Very small learning rate
1743 beta1: 0.999999, // Very high beta1
1744 beta2: 0.000001, // Very low beta2
1745 eps: 1e-15, // Very small epsilon
1746 weight_decay: 1e-1, // High weight decay
1747 amsgrad: true,
1748 };
1749
1750 let weight = Tensor::ones(vec![1]).with_requires_grad();
1751 let mut extreme_optimizer = Adam::with_config(extreme_config.clone());
1752 extreme_optimizer.add_parameter(&weight);
1753
1754 let json =
1755 <Adam as crate::serialization::Serializable>::to_json(&extreme_optimizer).unwrap();
1756 let restored = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1757 assert_eq!(
1758 extreme_config.learning_rate,
1759 restored.config().learning_rate
1760 );
1761 assert_eq!(extreme_config.beta1, restored.config().beta1);
1762 assert_eq!(extreme_config.beta2, restored.config().beta2);
1763 assert_eq!(extreme_config.eps, restored.config().eps);
1764 assert_eq!(extreme_config.weight_decay, restored.config().weight_decay);
1765 assert_eq!(extreme_config.amsgrad, restored.config().amsgrad);
1766
1767 let binary =
1768 <Adam as crate::serialization::Serializable>::to_binary(&extreme_optimizer).unwrap();
1769 let restored = <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1770 assert_eq!(
1771 extreme_config.learning_rate,
1772 restored.config().learning_rate
1773 );
1774 assert_eq!(extreme_config.beta1, restored.config().beta1);
1775 assert_eq!(extreme_config.beta2, restored.config().beta2);
1776 assert_eq!(extreme_config.eps, restored.config().eps);
1777 assert_eq!(extreme_config.weight_decay, restored.config().weight_decay);
1778 assert_eq!(extreme_config.amsgrad, restored.config().amsgrad);
1779 }
1780}