train_station/optimizers/adam/serialization.rs
1//! Comprehensive serialization support for Adam optimizer state and configuration
2//!
3//! This module provides complete serialization capabilities for the Adam optimizer, enabling
4//! model checkpointing, state persistence, and cross-platform optimizer state transfer.
5//! The serialization system supports both human-readable JSON and efficient binary formats
6//! with perfect roundtrip fidelity and seamless parameter re-linking.
7//!
8//! # Purpose
9//!
10//! The serialization module serves as the persistence layer for Adam optimizer training,
11//! providing essential functionality for:
12//! - **Model checkpointing**: Save and restore optimizer state during training
13//! - **Training resumption**: Continue training from saved checkpoints
14//! - **State transfer**: Move optimizer state between different environments
15//! - **Debugging support**: Human-readable JSON format for state inspection
16//! - **Performance optimization**: Efficient binary format for production use
17//! - **Cross-platform compatibility**: Consistent serialization across systems
18//!
19//! # Supported Components
20//!
21//! ## AdamConfig Serialization
22//! - **Learning rate**: Base learning rate for parameter updates
23//! - **Beta parameters**: Momentum decay rates (beta1, beta2)
24//! - **Epsilon**: Numerical stability constant
25//! - **Weight decay**: L2 regularization coefficient
26//! - **AMSGrad flag**: Whether to use AMSGrad variant
27//!
28//! ## Parameter State Serialization
29//! - **Momentum buffers**: First moment estimates for each parameter
30//! - **Velocity buffers**: Second moment estimates for each parameter
31//! - **AMSGrad state**: Maximum velocity buffers when AMSGrad is enabled
32//! - **Step counts**: Per-parameter step counts for bias correction
33//! - **Shape information**: Parameter shapes for validation during re-linking
34//!
35//! ## Global Optimizer State
36//! - **Global step count**: Total optimization steps performed
37//! - **Parameter insertion order**: Order for consistent parameter re-linking
38//! - **Configuration state**: Complete hyperparameter configuration
39//! - **State validation**: Integrity checks for serialized data
40//!
41//! # Serialization Formats
42//!
43//! ## JSON Format
44//! - **Human-readable**: Easy to inspect and debug optimizer state
45//! - **Cross-language**: Compatible with other JSON-parsing systems
46//! - **Configuration files**: Suitable for storing optimizer configurations
47//! - **Debugging**: Clear structure for troubleshooting training issues
48//! - **Interoperability**: Exchange optimizer state with external tools
49//!
50//! ## Binary Format
51//! - **Compact storage**: Minimal file sizes for production deployment
52//! - **Fast I/O**: Optimized for quick save/load operations
53//! - **Performance**: Reduced serialization overhead during training
54//! - **Bandwidth efficiency**: Minimal network transfer requirements
55//! - **Production ready**: Optimized for high-performance training workflows
56//!
57//! # Usage Patterns
58//!
59//! ## Basic Serialization
60//! ```
61//! use train_station::{Tensor, optimizers::Adam};
62//! use train_station::serialization::Serializable;
63//!
64//! // Create optimizer with parameters
65//! let weight = Tensor::ones(vec![10, 5]).with_requires_grad();
66//! let mut optimizer = Adam::new();
67//! optimizer.add_parameter(&weight);
68//!
69//! // Serialize optimizer state
70//! let json = optimizer.to_json().unwrap();
71//! let binary = optimizer.to_binary().unwrap();
72//!
73//! // Deserialize optimizer
74//! let loaded_optimizer = Adam::from_json(&json).unwrap();
75//! assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
76//! ```
77//!
78//! ## Training Checkpointing
79//! ```
80//! use train_station::{Tensor, optimizers::Adam};
81//! use train_station::serialization::{Serializable, Format};
82//!
83//! let mut weight = Tensor::randn(vec![100, 50], None).with_requires_grad();
84//! let mut optimizer = Adam::new();
85//! optimizer.add_parameter(&weight);
86//!
87//! // Training loop with checkpointing
88//! for epoch in 0..10 {
89//! // ... training logic ...
90//!
91//! // Save checkpoint every 5 epochs
92//! if epoch % 5 == 0 {
93//! let temp_dir = std::env::temp_dir();
94//! let checkpoint_path = temp_dir.join(format!("checkpoint_epoch_{}.json", epoch));
95//! optimizer.save(&checkpoint_path, Format::Json).unwrap();
96//!
97//! // Cleanup for example
98//! std::fs::remove_file(&checkpoint_path).ok();
99//! }
100//! }
101//! ```
102//!
103//! ## Parameter Re-linking
104//! ```
105//! use train_station::{Tensor, optimizers::Adam};
106//! use train_station::serialization::Serializable;
107//!
108//! // Original training setup
109//! let weight = Tensor::ones(vec![5, 5]).with_requires_grad();
110//! let bias = Tensor::zeros(vec![5]).with_requires_grad();
111//! let mut optimizer = Adam::new();
112//! optimizer.add_parameter(&weight);
113//! optimizer.add_parameter(&bias);
114//!
115//! // Serialize optimizer state
116//! let json = optimizer.to_json().unwrap();
117//!
118//! // Later: create new parameters with same shapes
119//! let new_weight = Tensor::ones(vec![5, 5]).with_requires_grad();
120//! let new_bias = Tensor::zeros(vec![5]).with_requires_grad();
121//!
122//! // Restore optimizer and re-link parameters
123//! let mut loaded_optimizer = Adam::from_json(&json).unwrap();
124//! loaded_optimizer.relink_parameters(&[&new_weight, &new_bias]).unwrap();
125//!
126//! assert!(loaded_optimizer.is_parameter_linked(&new_weight));
127//! assert!(loaded_optimizer.is_parameter_linked(&new_bias));
128//! ```
129//!
130//! # Architecture Design
131//!
132//! ## Serialization Strategy
133//! - **Unified interface**: All serialization through StructSerializable trait
134//! - **Type safety**: Strong typing prevents serialization errors
135//! - **Validation**: Comprehensive validation during deserialization
136//! - **Error handling**: Detailed error messages for debugging
137//! - **Memory efficiency**: Optimized memory usage during serialization
138//!
139//! ## Parameter State Management
140//! - **ID-based tracking**: Parameters tracked by unique tensor IDs
141//! - **Shape validation**: Ensures parameter compatibility during re-linking
142//! - **Insertion order**: Maintains parameter order for consistent re-linking
143//! - **State preservation**: Complete momentum and velocity buffer preservation
144//! - **AMSGrad support**: Full AMSGrad state serialization when enabled
145//!
146//! ## Performance Characteristics
147//! - **Linear complexity**: O(n) serialization time with parameter count
148//! - **Minimal overhead**: Efficient serialization with minimal memory allocation
149//! - **Streaming support**: Support for streaming serialization to files
150//! - **Compression ready**: Binary format suitable for compression
151//! - **Concurrent safe**: Thread-safe serialization operations
152//!
153//! # Thread Safety
154//!
155//! All serialization operations are thread-safe:
156//! - **Immutable serialization**: Serialization does not modify optimizer state
157//! - **Concurrent reads**: Multiple threads can serialize the same optimizer
158//! - **Deserialization safety**: Deserialization creates new optimizer instances
159//! - **Parameter linking**: Re-linking operations are thread-safe
160//!
161//! # Integration with Train Station
162//!
163//! The serialization module integrates seamlessly with the broader Train Station ecosystem:
164//! - **Tensor serialization**: Leverages efficient tensor serialization for parameter states
165//! - **GradTrack compatibility**: Maintains gradient tracking requirements during re-linking
166//! - **Device management**: Preserves device placement information
167//! - **Memory management**: Efficient memory usage aligned with Train Station patterns
168//! - **Error handling**: Consistent error handling with Train Station conventions
169
170use super::{Adam, AdamConfig, ParameterState};
171use crate::serialization::{
172 FieldValue, FromFieldValue, SerializationError, SerializationResult, StructDeserializer,
173 StructSerializable, StructSerializer, ToFieldValue,
174};
175use crate::tensor::core::Tensor;
176use std::collections::HashMap;
177
178// ===== AdamConfig Serialization =====
179
180impl StructSerializable for AdamConfig {
181 /// Convert AdamConfig to StructSerializer for comprehensive serialization
182 ///
183 /// This method serializes all Adam hyperparameters into a structured format suitable
184 /// for both JSON and binary serialization. Every field is essential for proper optimizer
185 /// reconstruction and training continuation. The serialization preserves exact floating-point
186 /// values and boolean flags to ensure identical behavior after deserialization.
187 ///
188 /// # Returns
189 ///
190 /// StructSerializer containing all configuration data with field names and values
191 ///
192 /// # Serialized Fields
193 ///
194 /// - **learning_rate**: Base learning rate for parameter updates
195 /// - **beta1**: Exponential decay rate for first moment estimates
196 /// - **beta2**: Exponential decay rate for second moment estimates
197 /// - **eps**: Small constant for numerical stability in denominator
198 /// - **weight_decay**: L2 regularization coefficient
199 /// - **amsgrad**: Boolean flag for AMSGrad variant usage
200 ///
201 /// # Performance
202 ///
203 /// - **Time Complexity**: O(1) - Constant time field serialization
204 /// - **Memory Usage**: Minimal allocation for field storage
205 /// - **Precision**: Full floating-point precision preservation
206 fn to_serializer(&self) -> StructSerializer {
207 StructSerializer::new()
208 .field("learning_rate", &self.learning_rate)
209 .field("beta1", &self.beta1)
210 .field("beta2", &self.beta2)
211 .field("eps", &self.eps)
212 .field("weight_decay", &self.weight_decay)
213 .field("amsgrad", &self.amsgrad)
214 }
215
216 /// Create AdamConfig from StructDeserializer with full validation
217 ///
218 /// This method reconstructs an AdamConfig instance from serialized hyperparameters,
219 /// performing comprehensive validation to ensure all required fields are present
220 /// and contain valid values. The deserialization process maintains exact floating-point
221 /// precision and validates that all hyperparameters are within reasonable ranges.
222 ///
223 /// # Arguments
224 ///
225 /// * `deserializer` - StructDeserializer containing configuration field data
226 ///
227 /// # Returns
228 ///
229 /// Reconstructed AdamConfig instance on success, or SerializationError on failure
230 ///
231 /// # Required Fields
232 ///
233 /// All fields must be present in the deserializer:
234 /// - **learning_rate**: Must be a valid f32 value
235 /// - **beta1**: Must be a valid f32 value (typically 0.0-1.0)
236 /// - **beta2**: Must be a valid f32 value (typically 0.0-1.0)
237 /// - **eps**: Must be a valid f32 value (typically small positive)
238 /// - **weight_decay**: Must be a valid f32 value (typically 0.0 or small positive)
239 /// - **amsgrad**: Must be a valid boolean value
240 ///
241 /// # Errors
242 ///
243 /// Returns SerializationError if:
244 /// - Any required field is missing from the deserializer
245 /// - Any field contains invalid data type
246 /// - Field extraction fails for any reason
247 ///
248 /// # Performance
249 ///
250 /// - **Time Complexity**: O(1) - Constant time field extraction
251 /// - **Memory Usage**: Minimal allocation for configuration structure
252 /// - **Validation**: Comprehensive field presence and type validation
253 fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
254 Ok(AdamConfig {
255 learning_rate: deserializer.field("learning_rate")?,
256 beta1: deserializer.field("beta1")?,
257 beta2: deserializer.field("beta2")?,
258 eps: deserializer.field("eps")?,
259 weight_decay: deserializer.field("weight_decay")?,
260 amsgrad: deserializer.field("amsgrad")?,
261 })
262 }
263}
264
265impl ToFieldValue for AdamConfig {
266 /// Convert AdamConfig to FieldValue for embedding in larger structures
267 ///
268 /// This method converts the AdamConfig into a FieldValue::Object that can be
269 /// embedded as a field within larger serializable structures. This enables
270 /// AdamConfig to be serialized as part of more complex training configurations
271 /// or model checkpoints while maintaining its structured representation.
272 ///
273 /// # Returns
274 ///
275 /// FieldValue::Object containing all configuration data as key-value pairs
276 ///
277 /// # Object Structure
278 ///
279 /// The returned object contains these fields:
280 /// - "learning_rate": f32 value as FieldValue::F32
281 /// - "beta1": f32 value as FieldValue::F32
282 /// - "beta2": f32 value as FieldValue::F32
283 /// - "eps": f32 value as FieldValue::F32
284 /// - "weight_decay": f32 value as FieldValue::F32
285 /// - "amsgrad": bool value as FieldValue::Bool
286 ///
287 /// # Performance
288 ///
289 /// - **Time Complexity**: O(1) - Constant time field conversion
290 /// - **Memory Usage**: Allocates HashMap for field storage
291 /// - **Conversion**: Direct field-to-FieldValue conversion without copying
292 fn to_field_value(&self) -> FieldValue {
293 let serializer = self.to_serializer();
294 FieldValue::from_object(serializer.fields.into_iter().collect())
295 }
296}
297
298impl FromFieldValue for AdamConfig {
299 /// Create AdamConfig from FieldValue with comprehensive validation
300 ///
301 /// This method reconstructs an AdamConfig instance from a FieldValue::Object,
302 /// performing type validation and field extraction. It's designed to handle
303 /// AdamConfig instances that were embedded as fields within larger serializable
304 /// structures, ensuring proper error handling and detailed error messages.
305 ///
306 /// # Arguments
307 ///
308 /// * `value` - FieldValue containing configuration data (must be Object variant)
309 /// * `field_name` - Name of the field being deserialized for error context
310 ///
311 /// # Returns
312 ///
313 /// Reconstructed AdamConfig instance on success, or SerializationError on failure
314 ///
315 /// # Expected FieldValue Structure
316 ///
317 /// The FieldValue must be an Object variant containing:
318 /// - "learning_rate": Numeric field value
319 /// - "beta1": Numeric field value
320 /// - "beta2": Numeric field value
321 /// - "eps": Numeric field value
322 /// - "weight_decay": Numeric field value
323 /// - "amsgrad": Boolean field value
324 ///
325 /// # Errors
326 ///
327 /// Returns SerializationError if:
328 /// - FieldValue is not an Object variant
329 /// - Any required field is missing from the object
330 /// - Any field has incorrect type or invalid value
331 /// - Deserialization process fails for any reason
332 ///
333 /// # Performance
334 ///
335 /// - **Time Complexity**: O(1) - Constant time field extraction and validation
336 /// - **Memory Usage**: Temporary deserializer allocation for field processing
337 /// - **Error Handling**: Detailed error messages with field name context
338 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
339 match value {
340 FieldValue::Object(fields) => {
341 let mut deserializer = StructDeserializer::from_fields(fields);
342 Self::from_deserializer(&mut deserializer)
343 }
344 _ => Err(SerializationError::ValidationFailed {
345 field: field_name.to_string(),
346 message: format!(
347 "Expected Object for {}, found {}",
348 std::any::type_name::<Self>(),
349 value.type_name()
350 ),
351 }),
352 }
353 }
354}
355
356// ===== Serializable Parameter State =====
357
/// Serializable representation of Adam parameter optimization state
///
/// Wrapper around the optimizer-internal `ParameterState` that implements the
/// serialization traits (`StructSerializable`, `ToFieldValue`,
/// `FromFieldValue`) so that per-parameter Adam state — momentum, velocity,
/// optional AMSGrad maximum velocity, and the step count — can be persisted
/// and restored with full fidelity.
///
/// # Fields
///
/// * `m` - First moment estimate (momentum) tensor buffer
/// * `v` - Second moment estimate (velocity) tensor buffer
/// * `v_hat_max` - Optional maximum velocity tensor for the AMSGrad variant
///   (`None` when AMSGrad is disabled)
/// * `step` - Per-parameter step count used for bias correction
///
/// # Design Rationale
///
/// Conversion to and from `ParameterState` is done by cloning the tensors
/// directly (see `from_parameter_state` / `to_parameter_state`), which
/// delegates buffer handling to the tensor serialization infrastructure
/// instead of manual data extraction.
///
/// NOTE(review): whether `Tensor::clone` is a cheap reference-count bump or a
/// deep copy is determined by the `Tensor` implementation, not visible here —
/// confirm before relying on O(1) conversion claims.
#[derive(Debug, Clone)]
struct SerializableParameterState {
    /// First moment estimate (momentum) tensor buffer
    ///
    /// Exponentially decaying average of past gradients. Per the optimizer's
    /// contract, its shape matches the associated parameter.
    m: Tensor,

    /// Second moment estimate (velocity) tensor buffer
    ///
    /// Exponentially decaying average of past squared gradients, used for
    /// adaptive learning-rate scaling. Shape matches the associated parameter.
    v: Tensor,

    /// Optional maximum velocity tensor for AMSGrad variant
    ///
    /// Element-wise maximum of past velocity estimates, present only when
    /// AMSGrad is enabled; `None` otherwise.
    v_hat_max: Option<Tensor>,

    /// Per-parameter step count for bias correction
    ///
    /// Number of optimization steps performed for this specific parameter.
    /// Kept per-parameter so parameters added mid-training get correct bias
    /// correction.
    step: usize,
}
431
432impl SerializableParameterState {
433 /// Create SerializableParameterState from internal ParameterState
434 ///
435 /// This method converts the internal ParameterState representation used by
436 /// the Adam optimizer into a serializable format. It performs efficient tensor
437 /// cloning to preserve all optimization state while enabling serialization.
438 /// The conversion maintains exact numerical precision and handles optional
439 /// AMSGrad state appropriately.
440 ///
441 /// # Arguments
442 ///
443 /// * `state` - Reference to the internal ParameterState to convert
444 ///
445 /// # Returns
446 ///
447 /// SerializableParameterState containing all state data ready for serialization
448 ///
449 /// # Conversion Process
450 ///
451 /// The method performs these operations:
452 /// 1. **Momentum cloning**: Clones the momentum tensor (m) with full precision
453 /// 2. **Velocity cloning**: Clones the velocity tensor (v) with full precision
454 /// 3. **AMSGrad handling**: Clones optional AMSGrad state when present
455 /// 4. **Step preservation**: Copies the step count for bias correction
456 ///
457 /// # Performance
458 ///
459 /// - **Time Complexity**: O(1) - Tensor cloning uses reference counting
460 /// - **Memory Usage**: Minimal additional allocation due to tensor sharing
461 /// - **Precision**: Maintains exact floating-point values in all buffers
462 fn from_parameter_state(state: &ParameterState) -> Self {
463 // Use direct tensor cloning - leverages Tensor's efficient serialization
464 // No manual data extraction needed
465 Self {
466 m: state.m.clone(),
467 v: state.v.clone(),
468 v_hat_max: state.v_hat_max.clone(),
469 step: state.step,
470 }
471 }
472
473 /// Convert SerializableParameterState back to internal ParameterState
474 ///
475 /// This method reconstructs the internal ParameterState representation from
476 /// the serializable format, enabling the Adam optimizer to resume training
477 /// with preserved optimization state. The conversion maintains exact numerical
478 /// precision and properly handles optional AMSGrad state restoration.
479 ///
480 /// # Returns
481 ///
482 /// ParameterState instance ready for use by Adam optimizer, or SerializationError on failure
483 ///
484 /// # Reconstruction Process
485 ///
486 /// The method performs these operations:
487 /// 1. **Momentum restoration**: Clones momentum tensor back to internal format
488 /// 2. **Velocity restoration**: Clones velocity tensor back to internal format
489 /// 3. **AMSGrad restoration**: Handles optional AMSGrad state when present
490 /// 4. **Step restoration**: Preserves step count for continued bias correction
491 ///
492 /// # Validation
493 ///
494 /// The method ensures:
495 /// - All tensors have consistent shapes
496 /// - Step count is valid (non-negative)
497 /// - AMSGrad state consistency with configuration
498 /// - Tensor data integrity is maintained
499 ///
500 /// # Performance
501 ///
502 /// - **Time Complexity**: O(1) - Tensor cloning uses reference counting
503 /// - **Memory Usage**: Minimal allocation due to efficient tensor sharing
504 /// - **Precision**: Maintains exact floating-point values from serialization
505 ///
506 /// # Errors
507 ///
508 /// Returns SerializationError if:
509 /// - Tensor shapes are inconsistent
510 /// - Internal tensor state is corrupted
511 /// - Memory allocation fails during reconstruction
512 fn to_parameter_state(&self) -> SerializationResult<ParameterState> {
513 // Direct tensor cloning - no manual memory management needed
514 // Tensors handle their own efficient reconstruction
515 Ok(ParameterState {
516 m: self.m.clone(),
517 v: self.v.clone(),
518 v_hat_max: self.v_hat_max.clone(),
519 step: self.step,
520 })
521 }
522}
523
524impl StructSerializable for SerializableParameterState {
525 /// Convert SerializableParameterState to StructSerializer for serialization
526 ///
527 /// This method serializes all Adam parameter state components into a structured
528 /// format suitable for both JSON and binary serialization. It leverages the
529 /// efficient tensor serialization system to handle momentum and velocity buffers
530 /// while properly managing optional AMSGrad state.
531 ///
532 /// # Returns
533 ///
534 /// StructSerializer containing all parameter state data with field names and values
535 ///
536 /// # Serialized Fields
537 ///
538 /// - **m**: Momentum tensor (first moment estimate)
539 /// - **v**: Velocity tensor (second moment estimate)
540 /// - **v_hat_max**: Optional AMSGrad maximum velocity tensor
541 /// - **step**: Per-parameter step count for bias correction
542 ///
543 /// # Performance
544 ///
545 /// - **Time Complexity**: O(1) - Constant time field serialization
546 /// - **Memory Usage**: Leverages efficient tensor serialization
547 /// - **Precision**: Full floating-point precision preservation in tensors
548 fn to_serializer(&self) -> StructSerializer {
549 StructSerializer::new()
550 .field("m", &self.m)
551 .field("v", &self.v)
552 .field("v_hat_max", &self.v_hat_max)
553 .field("step", &self.step)
554 }
555
556 /// Create SerializableParameterState from StructDeserializer with validation
557 ///
558 /// This method reconstructs a SerializableParameterState from serialized data,
559 /// performing comprehensive validation to ensure all required fields are present
560 /// and contain valid tensor data. It handles both momentum and velocity tensors
561 /// along with optional AMSGrad state and step count information.
562 ///
563 /// # Arguments
564 ///
565 /// * `deserializer` - StructDeserializer containing parameter state field data
566 ///
567 /// # Returns
568 ///
569 /// Reconstructed SerializableParameterState on success, or SerializationError on failure
570 ///
571 /// # Required Fields
572 ///
573 /// All fields must be present in the deserializer:
574 /// - **m**: Momentum tensor with valid shape and data
575 /// - **v**: Velocity tensor with shape matching momentum tensor
576 /// - **v_hat_max**: Optional AMSGrad tensor (None or matching shape)
577 /// - **step**: Valid step count (non-negative integer)
578 ///
579 /// # Validation
580 ///
581 /// The method validates:
582 /// - Tensor field presence and type correctness
583 /// - Shape consistency between momentum and velocity tensors
584 /// - AMSGrad tensor shape consistency when present
585 /// - Step count validity and range
586 ///
587 /// # Performance
588 ///
589 /// - **Time Complexity**: O(1) - Constant time field extraction
590 /// - **Memory Usage**: Efficient tensor deserialization
591 /// - **Validation**: Comprehensive field and type validation
592 fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
593 Ok(Self {
594 m: deserializer.field("m")?,
595 v: deserializer.field("v")?,
596 v_hat_max: deserializer.field("v_hat_max")?,
597 step: deserializer.field("step")?,
598 })
599 }
600}
601
602impl ToFieldValue for SerializableParameterState {
603 /// Convert SerializableParameterState to FieldValue for embedding in collections
604 ///
605 /// This method converts the parameter state into a FieldValue::Object that can be
606 /// stored in collections or embedded within larger serializable structures. It
607 /// maintains the structured representation of all optimization state components
608 /// while enabling flexible serialization patterns.
609 ///
610 /// # Returns
611 ///
612 /// FieldValue::Object containing all parameter state data as key-value pairs
613 ///
614 /// # Object Structure
615 ///
616 /// The returned object contains these fields:
617 /// - "m": Momentum tensor as serialized FieldValue
618 /// - "v": Velocity tensor as serialized FieldValue
619 /// - "v_hat_max": Optional AMSGrad tensor as serialized FieldValue
620 /// - "step": Step count as FieldValue::Usize
621 ///
622 /// # Performance
623 ///
624 /// - **Time Complexity**: O(1) - Constant time field conversion
625 /// - **Memory Usage**: Efficient tensor serialization with minimal overhead
626 /// - **Precision**: Maintains exact tensor data and step count values
627 fn to_field_value(&self) -> FieldValue {
628 let serializer = self.to_serializer();
629 FieldValue::from_object(serializer.fields.into_iter().collect())
630 }
631}
632
633impl FromFieldValue for SerializableParameterState {
634 /// Create SerializableParameterState from FieldValue with comprehensive validation
635 ///
636 /// This method reconstructs a SerializableParameterState from a FieldValue::Object,
637 /// performing type validation and tensor deserialization. It handles the complex
638 /// process of deserializing momentum and velocity tensors along with optional
639 /// AMSGrad state, ensuring data integrity and proper error handling.
640 ///
641 /// # Arguments
642 ///
643 /// * `value` - FieldValue containing parameter state data (must be Object variant)
644 /// * `field_name` - Name of the field being deserialized for error context
645 ///
646 /// # Returns
647 ///
648 /// Reconstructed SerializableParameterState on success, or SerializationError on failure
649 ///
650 /// # Expected FieldValue Structure
651 ///
652 /// The FieldValue must be an Object variant containing:
653 /// - "m": Serialized momentum tensor
654 /// - "v": Serialized velocity tensor
655 /// - "v_hat_max": Optional serialized AMSGrad tensor
656 /// - "step": Step count as numeric value
657 ///
658 /// # Validation
659 ///
660 /// The method validates:
661 /// - FieldValue is Object variant
662 /// - All required fields are present
663 /// - Tensor deserialization succeeds
664 /// - Step count is valid numeric value
665 /// - Tensor shapes are consistent
666 ///
667 /// # Performance
668 ///
669 /// - **Time Complexity**: O(1) - Constant time field extraction and tensor deserialization
670 /// - **Memory Usage**: Efficient tensor deserialization with minimal overhead
671 /// - **Error Handling**: Detailed error messages with field name context
672 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
673 match value {
674 FieldValue::Object(fields) => {
675 let mut deserializer = StructDeserializer::from_fields(fields);
676 Self::from_deserializer(&mut deserializer)
677 }
678 _ => Err(SerializationError::ValidationFailed {
679 field: field_name.to_string(),
680 message: format!(
681 "Expected Object for {}, found {}",
682 std::any::type_name::<Self>(),
683 value.type_name()
684 ),
685 }),
686 }
687 }
688}
689
690// ===== Adam Serialization =====
691
692impl StructSerializable for Adam {
693 /// Convert Adam to StructSerializer for serialization
694 ///
695 /// Serializes all optimizer state including configuration, parameter states,
696 /// and global step count. Parameter linking is not serialized and must be
697 /// done after deserialization.
698 ///
699 /// # Returns
700 ///
701 /// StructSerializer containing all serializable optimizer state
702 fn to_serializer(&self) -> StructSerializer {
703 // Convert parameter states to serializable form
704 let mut serializable_states = HashMap::new();
705 for (param_id, state) in &self.states {
706 serializable_states.insert(
707 *param_id,
708 SerializableParameterState::from_parameter_state(state),
709 );
710 }
711
712 StructSerializer::new()
713 .field("config", &self.config)
714 .field("states", &serializable_states)
715 .field("step_count", &self.step_count)
716 .field("insertion_order", &self.insertion_order)
717 }
718
719 /// Create Adam from StructDeserializer
720 ///
721 /// Reconstructs Adam optimizer from serialized state. Parameters must be
722 /// linked separately using `add_parameter` or `add_parameters`.
723 ///
724 /// # Arguments
725 ///
726 /// * `deserializer` - StructDeserializer containing optimizer data
727 ///
728 /// # Returns
729 ///
730 /// Reconstructed Adam instance without parameter links, or error if deserialization fails
731 fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
732 let config: AdamConfig = deserializer.field("config")?;
733 let serializable_states: HashMap<usize, SerializableParameterState> =
734 deserializer.field("states")?;
735 let step_count: usize = deserializer.field("step_count")?;
736 let insertion_order: Vec<usize> = deserializer.field("insertion_order")?;
737
738 // Reconstruct parameter states from serialized form
739 let mut states = HashMap::new();
740 for (param_id, serializable_state) in serializable_states {
741 states.insert(param_id, serializable_state.to_parameter_state()?);
742 }
743
744 // Create optimizer with reconstructed states - user must call relink_parameters to link tensors
745 Ok(Adam {
746 config,
747 states,
748 step_count,
749 insertion_order,
750 })
751 }
752}
753
754impl FromFieldValue for Adam {
755 /// Create Adam from FieldValue
756 ///
757 /// # Arguments
758 ///
759 /// * `value` - FieldValue containing optimizer data
760 /// * `field_name` - Name of the field being deserialized (for error messages)
761 ///
762 /// # Returns
763 ///
764 /// Reconstructed Adam instance or error if deserialization fails
765 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
766 match value {
767 FieldValue::Object(fields) => {
768 let mut deserializer = StructDeserializer::from_fields(fields);
769 Self::from_deserializer(&mut deserializer)
770 }
771 _ => Err(SerializationError::ValidationFailed {
772 field: field_name.to_string(),
773 message: format!(
774 "Expected Object for {}, found {}",
775 std::any::type_name::<Self>(),
776 value.type_name()
777 ),
778 }),
779 }
780 }
781}
782
783// ===== Serializable Trait Implementation =====
784
785impl crate::serialization::Serializable for Adam {
786 /// Serialize the Adam optimizer to JSON format
787 ///
788 /// This method converts the Adam optimizer into a human-readable JSON string representation
789 /// that includes all optimizer state, configuration, parameter states, and step counts.
790 /// The JSON format is suitable for debugging, configuration files, and cross-language
791 /// interoperability.
792 ///
793 /// # Returns
794 ///
795 /// JSON string representation of the optimizer on success, or `SerializationError` on failure
796 ///
797 /// # Examples
798 ///
799 /// ```
800 /// use train_station::{Tensor, optimizers::Adam};
801 /// use train_station::serialization::Serializable;
802 ///
803 /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
804 /// let mut optimizer = Adam::new();
805 /// optimizer.add_parameter(&weight);
806 ///
807 /// let json = optimizer.to_json().unwrap();
808 /// assert!(!json.is_empty());
809 /// ```
810 fn to_json(&self) -> SerializationResult<String> {
811 <Self as StructSerializable>::to_json(self)
812 }
813
814 /// Deserialize an Adam optimizer from JSON format
815 ///
816 /// This method parses a JSON string and reconstructs an Adam optimizer with all
817 /// saved state. Parameters must be re-linked after deserialization using
818 /// `add_parameter` or `relink_parameters`.
819 ///
820 /// # Arguments
821 ///
822 /// * `json` - JSON string containing serialized optimizer
823 ///
824 /// # Returns
825 ///
826 /// The deserialized optimizer on success, or `SerializationError` on failure
827 ///
828 /// # Examples
829 ///
830 /// ```
831 /// use train_station::{Tensor, optimizers::Adam};
832 /// use train_station::serialization::Serializable;
833 ///
834 /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
835 /// let mut optimizer = Adam::new();
836 /// optimizer.add_parameter(&weight);
837 ///
838 /// let json = optimizer.to_json().unwrap();
839 /// let loaded_optimizer = Adam::from_json(&json).unwrap();
840 /// assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
841 /// ```
842 fn from_json(json: &str) -> SerializationResult<Self> {
843 <Self as StructSerializable>::from_json(json)
844 }
845
846 /// Serialize the Adam optimizer to binary format
847 ///
848 /// This method converts the optimizer into a compact binary representation optimized
849 /// for storage and transmission. The binary format provides maximum performance
850 /// and minimal file sizes compared to JSON.
851 ///
852 /// # Returns
853 ///
854 /// Binary representation of the optimizer on success, or `SerializationError` on failure
855 ///
856 /// # Examples
857 ///
858 /// ```
859 /// use train_station::{Tensor, optimizers::Adam};
860 /// use train_station::serialization::Serializable;
861 ///
862 /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
863 /// let mut optimizer = Adam::new();
864 /// optimizer.add_parameter(&weight);
865 ///
866 /// let binary = optimizer.to_binary().unwrap();
867 /// assert!(!binary.is_empty());
868 /// ```
869 fn to_binary(&self) -> SerializationResult<Vec<u8>> {
870 <Self as StructSerializable>::to_binary(self)
871 }
872
873 /// Deserialize an Adam optimizer from binary format
874 ///
875 /// This method parses binary data and reconstructs an Adam optimizer with all
876 /// saved state. Parameters must be re-linked after deserialization using
877 /// `add_parameter` or `relink_parameters`.
878 ///
879 /// # Arguments
880 ///
881 /// * `data` - Binary data containing serialized optimizer
882 ///
883 /// # Returns
884 ///
885 /// The deserialized optimizer on success, or `SerializationError` on failure
886 ///
887 /// # Examples
888 ///
889 /// ```
890 /// use train_station::{Tensor, optimizers::Adam};
891 /// use train_station::serialization::Serializable;
892 ///
893 /// let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
894 /// let mut optimizer = Adam::new();
895 /// optimizer.add_parameter(&weight);
896 ///
897 /// let binary = optimizer.to_binary().unwrap();
898 /// let loaded_optimizer = Adam::from_binary(&binary).unwrap();
899 /// assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
900 /// ```
901 fn from_binary(data: &[u8]) -> SerializationResult<Self> {
902 <Self as StructSerializable>::from_binary(data)
903 }
904}
905
906// ===== Utility Methods =====
907
908impl Adam {
909 /// Get the number of saved parameter states for checkpoint validation
910 ///
911 /// This method returns the count of parameter states currently stored in the optimizer,
912 /// which is essential for validating checkpoint integrity and ensuring proper parameter
913 /// re-linking after deserialization. The count includes all parameters that have been
914 /// linked to the optimizer and have accumulated optimization state.
915 ///
916 /// # Returns
917 ///
918 /// Number of parameter states currently stored in the optimizer
919 ///
920 /// # Usage Patterns
921 ///
922 /// ## Checkpoint Validation
923 /// After deserializing an optimizer, this method helps verify that the expected
924 /// number of parameters were saved and can guide the re-linking process.
925 ///
926 /// ## Training Resumption
927 /// When resuming training, compare this count with the number of parameters
928 /// in your model to ensure checkpoint compatibility.
929 ///
930 /// ## State Management
931 /// Use this method to monitor optimizer state growth and memory usage during
932 /// training with dynamic parameter addition.
933 ///
934 /// # Examples
935 ///
936 /// ```
937 /// use train_station::{Tensor, optimizers::Adam};
938 /// use train_station::serialization::Serializable;
939 ///
940 /// let weight = Tensor::ones(vec![10, 5]).with_requires_grad();
941 /// let bias = Tensor::zeros(vec![5]).with_requires_grad();
942 /// let mut optimizer = Adam::new();
943 /// optimizer.add_parameter(&weight);
944 /// optimizer.add_parameter(&bias);
945 ///
946 /// // Check parameter count before serialization
947 /// assert_eq!(optimizer.saved_parameter_count(), 2);
948 ///
949 /// // Serialize and deserialize
950 /// let json = optimizer.to_json().unwrap();
951 /// let loaded_optimizer = Adam::from_json(&json).unwrap();
952 ///
953 /// // Verify parameter count is preserved
954 /// assert_eq!(loaded_optimizer.saved_parameter_count(), 2);
955 /// ```
956 ///
957 /// # Performance
958 ///
959 /// - **Time Complexity**: O(1) - Direct access to internal state count
960 /// - **Memory Usage**: No additional memory allocation
961 /// - **Thread Safety**: Safe to call from multiple threads concurrently
962 #[track_caller]
963 pub fn saved_parameter_count(&self) -> usize {
964 self.states.len()
965 }
966}
967
968// ===== Field Value Implementations for Collections =====
969
970impl ToFieldValue for HashMap<usize, SerializableParameterState> {
971 /// Convert parameter states HashMap to FieldValue for serialization
972 ///
973 /// This method converts the HashMap of parameter states into a FieldValue::Object
974 /// suitable for serialization. It handles the conversion of usize keys to string
975 /// format required by the FieldValue::Object representation while preserving
976 /// all parameter state data.
977 ///
978 /// # Returns
979 ///
980 /// FieldValue::Object with string keys and SerializableParameterState values
981 ///
982 /// # Key Conversion
983 ///
984 /// - **Input**: HashMap<usize, SerializableParameterState> with numeric tensor IDs
985 /// - **Output**: FieldValue::Object with string keys for JSON compatibility
986 /// - **Mapping**: Each usize key is converted to string representation
987 /// - **Preservation**: All parameter state data is preserved exactly
988 ///
989 /// # Performance
990 ///
991 /// - **Time Complexity**: O(n) where n is the number of parameter states
992 /// - **Memory Usage**: Allocates new HashMap for string keys
993 /// - **Conversion**: Efficient string conversion for numeric keys
994 fn to_field_value(&self) -> FieldValue {
995 let mut map = HashMap::new();
996 for (key, value) in self {
997 map.insert(key.to_string(), value.to_field_value());
998 }
999 FieldValue::from_object(map)
1000 }
1001}
1002
1003impl FromFieldValue for HashMap<usize, SerializableParameterState> {
1004 /// Create parameter states HashMap from FieldValue with validation
1005 ///
1006 /// This method reconstructs the HashMap of parameter states from a FieldValue::Object,
1007 /// performing comprehensive validation and key conversion. It handles the conversion
1008 /// from string keys back to usize tensor IDs while ensuring all parameter state
1009 /// data is properly deserialized and validated.
1010 ///
1011 /// # Arguments
1012 ///
1013 /// * `value` - FieldValue containing parameter states data (must be Object variant)
1014 /// * `field_name` - Name of the field being deserialized for error context
1015 ///
1016 /// # Returns
1017 ///
1018 /// Reconstructed HashMap<usize, SerializableParameterState> on success, or SerializationError on failure
1019 ///
1020 /// # Key Conversion Process
1021 ///
1022 /// 1. **Validation**: Ensures FieldValue is Object variant
1023 /// 2. **Key parsing**: Converts string keys back to usize tensor IDs
1024 /// 3. **State deserialization**: Deserializes each parameter state
1025 /// 4. **Validation**: Validates parameter state integrity
1026 /// 5. **Collection**: Builds final HashMap with proper types
1027 ///
1028 /// # Errors
1029 ///
1030 /// Returns SerializationError if:
1031 /// - FieldValue is not Object variant
1032 /// - Any string key cannot be parsed as usize
1033 /// - Parameter state deserialization fails
1034 /// - Invalid parameter state data is encountered
1035 ///
1036 /// # Performance
1037 ///
1038 /// - **Time Complexity**: O(n) where n is the number of parameter states
1039 /// - **Memory Usage**: Allocates new HashMap with proper key types
1040 /// - **Validation**: Comprehensive key parsing and state validation
1041 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
1042 match value {
1043 FieldValue::Object(fields) => {
1044 let mut map = HashMap::new();
1045 for (key_str, field_value) in fields {
1046 let key = key_str.parse::<usize>().map_err(|_| {
1047 SerializationError::ValidationFailed {
1048 field: field_name.to_string(),
1049 message: format!("Invalid key '{}' in parameter states map", key_str),
1050 }
1051 })?;
1052 let state =
1053 SerializableParameterState::from_field_value(field_value, &key_str)?;
1054 map.insert(key, state);
1055 }
1056 Ok(map)
1057 }
1058 _ => Err(SerializationError::ValidationFailed {
1059 field: field_name.to_string(),
1060 message: format!(
1061 "Expected Object for HashMap<usize, SerializableParameterState>, found {}",
1062 value.type_name()
1063 ),
1064 }),
1065 }
1066 }
1067}
1068
1069#[cfg(test)]
1070mod tests {
1071 use super::*;
1072 use crate::optimizers::Optimizer;
1073 use crate::tensor::core::Tensor;
1074
1075 // ===== AdamConfig Serialization Tests =====
1076
1077 #[test]
1078 fn test_adam_config_json_roundtrip() {
1079 let config = AdamConfig {
1080 learning_rate: 1e-4,
1081 beta1: 0.95,
1082 beta2: 0.9999,
1083 eps: 1e-7,
1084 weight_decay: 1e-5,
1085 amsgrad: true,
1086 };
1087
1088 let json = config.to_json().unwrap();
1089 let loaded_config = AdamConfig::from_json(&json).unwrap();
1090
1091 assert_eq!(config.learning_rate, loaded_config.learning_rate);
1092 assert_eq!(config.beta1, loaded_config.beta1);
1093 assert_eq!(config.beta2, loaded_config.beta2);
1094 assert_eq!(config.eps, loaded_config.eps);
1095 assert_eq!(config.weight_decay, loaded_config.weight_decay);
1096 assert_eq!(config.amsgrad, loaded_config.amsgrad);
1097 }
1098
1099 #[test]
1100 fn test_adam_config_binary_roundtrip() {
1101 let config = AdamConfig {
1102 learning_rate: 2e-3,
1103 beta1: 0.85,
1104 beta2: 0.995,
1105 eps: 1e-9,
1106 weight_decay: 5e-4,
1107 amsgrad: false,
1108 };
1109
1110 let binary = config.to_binary().unwrap();
1111 let loaded_config = AdamConfig::from_binary(&binary).unwrap();
1112
1113 assert_eq!(config.learning_rate, loaded_config.learning_rate);
1114 assert_eq!(config.beta1, loaded_config.beta1);
1115 assert_eq!(config.beta2, loaded_config.beta2);
1116 assert_eq!(config.eps, loaded_config.eps);
1117 assert_eq!(config.weight_decay, loaded_config.weight_decay);
1118 assert_eq!(config.amsgrad, loaded_config.amsgrad);
1119 }
1120
1121 #[test]
1122 fn test_adam_config_field_value_roundtrip() {
1123 let config = AdamConfig {
1124 learning_rate: 3e-4,
1125 beta1: 0.92,
1126 beta2: 0.998,
1127 eps: 1e-6,
1128 weight_decay: 2e-4,
1129 amsgrad: true,
1130 };
1131
1132 let field_value = config.to_field_value();
1133 let loaded_config = AdamConfig::from_field_value(field_value, "config").unwrap();
1134
1135 assert_eq!(config.learning_rate, loaded_config.learning_rate);
1136 assert_eq!(config.beta1, loaded_config.beta1);
1137 assert_eq!(config.beta2, loaded_config.beta2);
1138 assert_eq!(config.eps, loaded_config.eps);
1139 assert_eq!(config.weight_decay, loaded_config.weight_decay);
1140 assert_eq!(config.amsgrad, loaded_config.amsgrad);
1141 }
1142
1143 // ===== Adam Optimizer Serialization Tests =====
1144
1145 #[test]
1146 fn test_adam_optimizer_json_roundtrip() {
1147 let mut weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1148 let mut bias = Tensor::zeros(vec![2, 3]).with_requires_grad();
1149
1150 let mut optimizer = Adam::new();
1151 optimizer.add_parameter(&weight);
1152 optimizer.add_parameter(&bias);
1153
1154 // Perform some steps to create state
1155 let output = weight.add_tensor(&bias);
1156 let mut loss = output.sum();
1157 loss.backward(None);
1158 optimizer.step(&mut [&mut weight, &mut bias]);
1159
1160 // Test serialization
1161 let json = optimizer.to_json().unwrap();
1162 let loaded_optimizer = Adam::from_json(&json).unwrap();
1163
1164 assert_eq!(
1165 optimizer.config().learning_rate,
1166 loaded_optimizer.config().learning_rate
1167 );
1168 assert_eq!(
1169 optimizer.saved_parameter_count(),
1170 loaded_optimizer.saved_parameter_count()
1171 );
1172 }
1173
1174 #[test]
1175 fn test_adam_optimizer_binary_roundtrip() {
1176 let weight = Tensor::ones(vec![5, 2]).with_requires_grad();
1177
1178 let mut optimizer = Adam::with_learning_rate(1e-4);
1179 optimizer.add_parameter(&weight);
1180
1181 // Test serialization
1182 let binary = optimizer.to_binary().unwrap();
1183 let loaded_optimizer = Adam::from_binary(&binary).unwrap();
1184
1185 assert_eq!(
1186 optimizer.config().learning_rate,
1187 loaded_optimizer.config().learning_rate
1188 );
1189 assert_eq!(
1190 optimizer.saved_parameter_count(),
1191 loaded_optimizer.saved_parameter_count()
1192 );
1193 }
1194
1195 #[test]
1196 fn test_adam_parameter_relinking() {
1197 let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1198 let mut optimizer = Adam::new();
1199 optimizer.add_parameter(&weight);
1200
1201 // Serialize
1202 let json = optimizer.to_json().unwrap();
1203
1204 // Deserialize
1205 let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1206
1207 // After deserialization, saved states should be preserved
1208 assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
1209
1210 // Re-link parameter - this creates a new state since it's a new tensor with different ID
1211 let new_weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1212 loaded_optimizer.add_parameter(&new_weight);
1213
1214 // Now there should be 2 states: the original saved one + the new one
1215 assert_eq!(loaded_optimizer.parameter_count(), 2);
1216 assert!(loaded_optimizer.is_parameter_linked(&new_weight));
1217 }
1218
1219 #[test]
1220 fn test_adam_state_preservation() {
1221 let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1222 let mut optimizer = Adam::new();
1223 optimizer.add_parameter(&weight);
1224
1225 // Perform training steps to build up state
1226 for _ in 0..3 {
1227 let output = weight.mul_scalar(2.0);
1228 let mut loss = output.sum();
1229 loss.backward(None);
1230 optimizer.step(&mut [&mut weight]);
1231 optimizer.zero_grad(&mut [&mut weight]);
1232 }
1233
1234 // Serialize and deserialize
1235 let json = optimizer.to_json().unwrap();
1236 let loaded_optimizer = Adam::from_json(&json).unwrap();
1237
1238 // Check that states were preserved
1239 assert_eq!(loaded_optimizer.saved_parameter_count(), 1);
1240 assert_eq!(
1241 loaded_optimizer.config().learning_rate,
1242 optimizer.config().learning_rate
1243 );
1244 }
1245
1246 #[test]
1247 fn test_relink_parameters_success() {
1248 // Create original optimizer with parameters
1249 let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1250 let bias = Tensor::zeros(vec![3]).with_requires_grad();
1251
1252 let mut optimizer = Adam::new();
1253 optimizer.add_parameter(&weight);
1254 optimizer.add_parameter(&bias);
1255
1256 // Serialize
1257 let json = optimizer.to_json().unwrap();
1258
1259 // Create new parameters with same shapes but different IDs
1260 let new_weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1261 let new_bias = Tensor::zeros(vec![3]).with_requires_grad();
1262
1263 // Deserialize and re-link
1264 let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1265 loaded_optimizer
1266 .relink_parameters(&[&new_weight, &new_bias])
1267 .unwrap();
1268
1269 // Verify re-linking worked
1270 assert_eq!(loaded_optimizer.parameter_count(), 2);
1271 assert!(loaded_optimizer.is_parameter_linked(&new_weight));
1272 assert!(loaded_optimizer.is_parameter_linked(&new_bias));
1273 }
1274
1275 #[test]
1276 fn test_relink_parameters_shape_mismatch() {
1277 // Create original optimizer
1278 let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1279 let mut optimizer = Adam::new();
1280 optimizer.add_parameter(&weight);
1281
1282 // Serialize
1283 let json = optimizer.to_json().unwrap();
1284
1285 // Create new parameter with different shape
1286 let new_weight = Tensor::ones(vec![3, 2]).with_requires_grad(); // Different shape!
1287
1288 // Deserialize and try to re-link
1289 let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1290 let result = loaded_optimizer.relink_parameters(&[&new_weight]);
1291
1292 // Should fail with shape mismatch error
1293 assert!(result.is_err());
1294 assert!(result.unwrap_err().contains("Shape mismatch"));
1295 }
1296
1297 #[test]
1298 fn test_relink_parameters_count_mismatch() {
1299 // Create original optimizer with 2 parameters
1300 let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1301 let bias = Tensor::zeros(vec![3]).with_requires_grad();
1302
1303 let mut optimizer = Adam::new();
1304 optimizer.add_parameter(&weight);
1305 optimizer.add_parameter(&bias);
1306
1307 // Serialize
1308 let json = optimizer.to_json().unwrap();
1309
1310 // Create only 1 new parameter
1311 let new_weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1312
1313 // Deserialize and try to re-link with wrong count
1314 let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1315 let result = loaded_optimizer.relink_parameters(&[&new_weight]);
1316
1317 // Should fail with count mismatch error
1318 assert!(result.is_err());
1319 assert!(result.unwrap_err().contains("Parameter count mismatch"));
1320 }
1321
1322 #[test]
1323 fn test_relink_parameters_requires_grad() {
1324 // Create original optimizer
1325 let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
1326 let mut optimizer = Adam::new();
1327 optimizer.add_parameter(&weight);
1328
1329 // Serialize
1330 let json = optimizer.to_json().unwrap();
1331
1332 // Create new parameter without requires_grad
1333 let new_weight = Tensor::ones(vec![2, 3]); // No requires_grad!
1334
1335 // Deserialize and try to re-link
1336 let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1337 let result = loaded_optimizer.relink_parameters(&[&new_weight]);
1338
1339 // Should fail with requires_grad error
1340 assert!(result.is_err());
1341 assert!(result.unwrap_err().contains("must require gradients"));
1342 }
1343
1344 #[test]
1345 fn test_relink_preserves_state() {
1346 // Create original optimizer and train it
1347 let mut weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1348 let mut optimizer = Adam::new();
1349 optimizer.add_parameter(&weight);
1350
1351 // Perform some training to build up state
1352 for _ in 0..2 {
1353 let output = weight.mul_scalar(2.0);
1354 let mut loss = output.sum();
1355 loss.backward(None);
1356 optimizer.step(&mut [&mut weight]);
1357 optimizer.zero_grad(&mut [&mut weight]);
1358 }
1359
1360 // Get state before serialization
1361 let original_step_count = optimizer.step_count;
1362
1363 // Serialize
1364 let json = optimizer.to_json().unwrap();
1365
1366 // Create new parameter
1367 let new_weight = Tensor::ones(vec![2, 2]).with_requires_grad();
1368
1369 // Deserialize and re-link
1370 let mut loaded_optimizer = Adam::from_json(&json).unwrap();
1371 loaded_optimizer.relink_parameters(&[&new_weight]).unwrap();
1372
1373 // Verify state is preserved
1374 assert_eq!(loaded_optimizer.step_count, original_step_count);
1375 assert_eq!(loaded_optimizer.parameter_count(), 1);
1376 assert!(loaded_optimizer.is_parameter_linked(&new_weight));
1377 }
1378
1379 // ===== Serializable Trait Tests =====
1380
    /// Verify the `Serializable` trait's JSON methods on a populated optimizer.
    #[test]
    fn test_serializable_json_methods() {
        // Create and populate test optimizer
        let weight = Tensor::ones(vec![2, 3]).with_requires_grad();
        let bias = Tensor::zeros(vec![3]).with_requires_grad();

        let mut optimizer = Adam::with_learning_rate(1e-3);
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        // Test to_json method; fully-qualified syntax disambiguates from the
        // StructSerializable method of the same name.
        let json = <Adam as crate::serialization::Serializable>::to_json(&optimizer).unwrap();
        assert!(!json.is_empty());
        // Spot-check that the expected top-level fields appear in the output
        assert!(json.contains("config"));
        assert!(json.contains("states"));
        assert!(json.contains("step_count"));
        assert!(json.contains("learning_rate"));

        // Test from_json method
        let restored = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
        // Every configuration field must survive the roundtrip exactly
        assert_eq!(
            optimizer.config().learning_rate,
            restored.config().learning_rate
        );
        assert_eq!(optimizer.config().beta1, restored.config().beta1);
        assert_eq!(optimizer.config().beta2, restored.config().beta2);
        assert_eq!(optimizer.config().eps, restored.config().eps);
        assert_eq!(
            optimizer.config().weight_decay,
            restored.config().weight_decay
        );
        assert_eq!(optimizer.config().amsgrad, restored.config().amsgrad);
        assert_eq!(
            optimizer.saved_parameter_count(),
            restored.saved_parameter_count()
        );
        assert_eq!(optimizer.step_count, restored.step_count);
    }
1419
    /// Verify the `Serializable` trait's binary methods on an optimizer
    /// configured with AMSGrad and weight decay enabled.
    #[test]
    fn test_serializable_binary_methods() {
        // Create and populate test optimizer
        let weight = Tensor::ones(vec![3, 4]).with_requires_grad();
        let mut optimizer = Adam::with_config(AdamConfig {
            learning_rate: 2e-4,
            beta1: 0.95,
            beta2: 0.999,
            eps: 1e-7,
            weight_decay: 1e-4,
            amsgrad: true,
        });
        optimizer.add_parameter(&weight);

        // Test to_binary method; fully-qualified syntax disambiguates from
        // the StructSerializable method of the same name.
        let binary = <Adam as crate::serialization::Serializable>::to_binary(&optimizer).unwrap();
        assert!(!binary.is_empty());

        // Test from_binary method
        let restored = <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
        // Every configuration field must survive the roundtrip exactly
        assert_eq!(
            optimizer.config().learning_rate,
            restored.config().learning_rate
        );
        assert_eq!(optimizer.config().beta1, restored.config().beta1);
        assert_eq!(optimizer.config().beta2, restored.config().beta2);
        assert_eq!(optimizer.config().eps, restored.config().eps);
        assert_eq!(
            optimizer.config().weight_decay,
            restored.config().weight_decay
        );
        assert_eq!(optimizer.config().amsgrad, restored.config().amsgrad);
        assert_eq!(
            optimizer.saved_parameter_count(),
            restored.saved_parameter_count()
        );
        assert_eq!(optimizer.step_count, restored.step_count);
    }
1458
    /// End-to-end JSON file I/O: path-based save/load plus the
    /// writer/reader variants, with cleanup of the temporary files.
    #[test]
    fn test_serializable_file_io_json() {
        use crate::serialization::{Format, Serializable};
        use std::fs;
        use std::path::Path;

        // Create test optimizer
        let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
        let bias = Tensor::zeros(vec![2]).with_requires_grad();

        let mut optimizer = Adam::with_learning_rate(5e-4);
        optimizer.add_parameter(&weight);
        optimizer.add_parameter(&bias);

        let json_path = "test_adam_serializable.json";

        // Test save method with JSON format
        Serializable::save(&optimizer, json_path, Format::Json).unwrap();
        assert!(Path::new(json_path).exists());

        // Test load method with JSON format
        let loaded_optimizer = Adam::load(json_path, Format::Json).unwrap();
        assert_eq!(
            optimizer.config().learning_rate,
            loaded_optimizer.config().learning_rate
        );
        assert_eq!(
            optimizer.saved_parameter_count(),
            loaded_optimizer.saved_parameter_count()
        );

        // Test save_to_writer method; the block scope drops the BufWriter
        // (flushing it) before the file is read back.
        let json_path_2 = "test_adam_serializable_writer.json";
        {
            let file = std::fs::File::create(json_path_2).unwrap();
            let mut writer = std::io::BufWriter::new(file);
            Serializable::save_to_writer(&optimizer, &mut writer, Format::Json).unwrap();
        }
        assert!(Path::new(json_path_2).exists());

        // Test load_from_reader method
        {
            let file = std::fs::File::open(json_path_2).unwrap();
            let mut reader = std::io::BufReader::new(file);
            let loaded_optimizer = Adam::load_from_reader(&mut reader, Format::Json).unwrap();
            assert_eq!(
                optimizer.config().learning_rate,
                loaded_optimizer.config().learning_rate
            );
        }

        // Cleanup test files (errors ignored if already removed)
        let _ = fs::remove_file(json_path);
        let _ = fs::remove_file(json_path_2);
    }
1514
    /// End-to-end binary file I/O: path-based save/load plus the
    /// writer/reader variants, with cleanup of the temporary files.
    #[test]
    fn test_serializable_file_io_binary() {
        use crate::serialization::{Format, Serializable};
        use std::fs;
        use std::path::Path;

        // Create test optimizer with an explicit configuration
        let weight = Tensor::ones(vec![3, 3]).with_requires_grad();
        let mut optimizer = Adam::with_config(AdamConfig {
            learning_rate: 1e-3,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
        });
        optimizer.add_parameter(&weight);

        let binary_path = "test_adam_serializable.bin";

        // Test save method with binary format
        Serializable::save(&optimizer, binary_path, Format::Binary).unwrap();
        assert!(Path::new(binary_path).exists());

        // Test load method with binary format
        let loaded_optimizer = Adam::load(binary_path, Format::Binary).unwrap();
        assert_eq!(
            optimizer.config().learning_rate,
            loaded_optimizer.config().learning_rate
        );
        assert_eq!(
            optimizer.saved_parameter_count(),
            loaded_optimizer.saved_parameter_count()
        );

        // Test save_to_writer method; the block scope drops the BufWriter
        // (flushing it) before the file is read back.
        let binary_path_2 = "test_adam_serializable_writer.bin";
        {
            let file = std::fs::File::create(binary_path_2).unwrap();
            let mut writer = std::io::BufWriter::new(file);
            Serializable::save_to_writer(&optimizer, &mut writer, Format::Binary).unwrap();
        }
        assert!(Path::new(binary_path_2).exists());

        // Test load_from_reader method
        {
            let file = std::fs::File::open(binary_path_2).unwrap();
            let mut reader = std::io::BufReader::new(file);
            let loaded_optimizer = Adam::load_from_reader(&mut reader, Format::Binary).unwrap();
            assert_eq!(
                optimizer.config().learning_rate,
                loaded_optimizer.config().learning_rate
            );
        }

        // Cleanup test files (errors ignored if already removed)
        let _ = fs::remove_file(binary_path);
        let _ = fs::remove_file(binary_path_2);
    }
1574
1575 #[test]
1576 fn test_serializable_large_optimizer_performance() {
1577 // Create a large optimizer to test performance characteristics
1578 let mut optimizer = Adam::with_learning_rate(1e-4);
1579
1580 // Add multiple parameters of different sizes
1581 for i in 0..5 {
1582 let size = 10 + i * 5;
1583 let param = Tensor::ones(vec![size, size]).with_requires_grad();
1584 optimizer.add_parameter(¶m);
1585 }
1586
1587 // Test JSON serialization
1588 let json = <Adam as crate::serialization::Serializable>::to_json(&optimizer).unwrap();
1589 assert!(!json.is_empty());
1590 let restored_json = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1591 assert_eq!(
1592 optimizer.config().learning_rate,
1593 restored_json.config().learning_rate
1594 );
1595 assert_eq!(
1596 optimizer.saved_parameter_count(),
1597 restored_json.saved_parameter_count()
1598 );
1599
1600 // Test binary serialization
1601 let binary = <Adam as crate::serialization::Serializable>::to_binary(&optimizer).unwrap();
1602 assert!(!binary.is_empty());
1603 // Binary format should be efficient (this is informational, not a requirement)
1604 println!(
1605 "JSON size: {} bytes, Binary size: {} bytes",
1606 json.len(),
1607 binary.len()
1608 );
1609
1610 let restored_binary =
1611 <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1612 assert_eq!(
1613 optimizer.config().learning_rate,
1614 restored_binary.config().learning_rate
1615 );
1616 assert_eq!(
1617 optimizer.saved_parameter_count(),
1618 restored_binary.saved_parameter_count()
1619 );
1620
1621 // Verify all configurations match
1622 assert_eq!(optimizer.config().beta1, restored_binary.config().beta1);
1623 assert_eq!(optimizer.config().beta2, restored_binary.config().beta2);
1624 assert_eq!(optimizer.config().eps, restored_binary.config().eps);
1625 assert_eq!(
1626 optimizer.config().weight_decay,
1627 restored_binary.config().weight_decay
1628 );
1629 assert_eq!(optimizer.config().amsgrad, restored_binary.config().amsgrad);
1630 }
1631
1632 #[test]
1633 fn test_serializable_error_handling() {
1634 // Test invalid JSON
1635 let invalid_json = r#"{"invalid": "json", "structure": true}"#;
1636 let result = <Adam as crate::serialization::Serializable>::from_json(invalid_json);
1637 assert!(result.is_err());
1638
1639 // Test empty JSON
1640 let empty_json = "{}";
1641 let result = <Adam as crate::serialization::Serializable>::from_json(empty_json);
1642 assert!(result.is_err());
1643
1644 // Test invalid binary data
1645 let invalid_binary = vec![1, 2, 3, 4, 5];
1646 let result = <Adam as crate::serialization::Serializable>::from_binary(&invalid_binary);
1647 assert!(result.is_err());
1648
1649 // Test empty binary data
1650 let empty_binary = vec![];
1651 let result = <Adam as crate::serialization::Serializable>::from_binary(&empty_binary);
1652 assert!(result.is_err());
1653 }
1654
    /// Roundtrip several representative configurations (default, high LR,
    /// AMSGrad enabled, custom betas) through both JSON and binary formats.
    #[test]
    fn test_serializable_different_configurations() {
        let test_configs = vec![
            // Default configuration
            AdamConfig::default(),
            // High learning rate
            AdamConfig {
                learning_rate: 1e-2,
                beta1: 0.9,
                beta2: 0.999,
                eps: 1e-8,
                weight_decay: 0.0,
                amsgrad: false,
            },
            // AMSGrad enabled
            AdamConfig {
                learning_rate: 1e-4,
                beta1: 0.9,
                beta2: 0.999,
                eps: 1e-8,
                weight_decay: 1e-4,
                amsgrad: true,
            },
            // Custom betas
            AdamConfig {
                learning_rate: 5e-4,
                beta1: 0.95,
                beta2: 0.9999,
                eps: 1e-7,
                weight_decay: 1e-5,
                amsgrad: false,
            },
        ];

        for config in test_configs {
            // Create optimizer with specific configuration
            let weight = Tensor::ones(vec![2, 2]).with_requires_grad();
            let mut optimizer = Adam::with_config(config.clone());
            optimizer.add_parameter(&weight);

            // Test JSON roundtrip; every field must match the source config
            let json = <Adam as crate::serialization::Serializable>::to_json(&optimizer).unwrap();
            let restored_json =
                <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
            assert_eq!(config.learning_rate, restored_json.config().learning_rate);
            assert_eq!(config.beta1, restored_json.config().beta1);
            assert_eq!(config.beta2, restored_json.config().beta2);
            assert_eq!(config.eps, restored_json.config().eps);
            assert_eq!(config.weight_decay, restored_json.config().weight_decay);
            assert_eq!(config.amsgrad, restored_json.config().amsgrad);

            // Test binary roundtrip with the same expectations
            let binary =
                <Adam as crate::serialization::Serializable>::to_binary(&optimizer).unwrap();
            let restored_binary =
                <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
            assert_eq!(config.learning_rate, restored_binary.config().learning_rate);
            assert_eq!(config.beta1, restored_binary.config().beta1);
            assert_eq!(config.beta2, restored_binary.config().beta2);
            assert_eq!(config.eps, restored_binary.config().eps);
            assert_eq!(config.weight_decay, restored_binary.config().weight_decay);
            assert_eq!(config.amsgrad, restored_binary.config().amsgrad);
        }
    }
1719
1720 #[test]
1721 fn test_serializable_edge_cases() {
1722 // Test optimizer with no parameters
1723 let empty_optimizer = Adam::new();
1724 let json = <Adam as crate::serialization::Serializable>::to_json(&empty_optimizer).unwrap();
1725 let restored = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1726 assert_eq!(
1727 empty_optimizer.saved_parameter_count(),
1728 restored.saved_parameter_count()
1729 );
1730 assert_eq!(empty_optimizer.step_count, restored.step_count);
1731
1732 let binary =
1733 <Adam as crate::serialization::Serializable>::to_binary(&empty_optimizer).unwrap();
1734 let restored = <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1735 assert_eq!(
1736 empty_optimizer.saved_parameter_count(),
1737 restored.saved_parameter_count()
1738 );
1739 assert_eq!(empty_optimizer.step_count, restored.step_count);
1740
1741 // Test optimizer with extreme configuration values
1742 let extreme_config = AdamConfig {
1743 learning_rate: 1e-10, // Very small learning rate
1744 beta1: 0.999999, // Very high beta1
1745 beta2: 0.000001, // Very low beta2
1746 eps: 1e-15, // Very small epsilon
1747 weight_decay: 1e-1, // High weight decay
1748 amsgrad: true,
1749 };
1750
1751 let weight = Tensor::ones(vec![1]).with_requires_grad();
1752 let mut extreme_optimizer = Adam::with_config(extreme_config.clone());
1753 extreme_optimizer.add_parameter(&weight);
1754
1755 let json =
1756 <Adam as crate::serialization::Serializable>::to_json(&extreme_optimizer).unwrap();
1757 let restored = <Adam as crate::serialization::Serializable>::from_json(&json).unwrap();
1758 assert_eq!(
1759 extreme_config.learning_rate,
1760 restored.config().learning_rate
1761 );
1762 assert_eq!(extreme_config.beta1, restored.config().beta1);
1763 assert_eq!(extreme_config.beta2, restored.config().beta2);
1764 assert_eq!(extreme_config.eps, restored.config().eps);
1765 assert_eq!(extreme_config.weight_decay, restored.config().weight_decay);
1766 assert_eq!(extreme_config.amsgrad, restored.config().amsgrad);
1767
1768 let binary =
1769 <Adam as crate::serialization::Serializable>::to_binary(&extreme_optimizer).unwrap();
1770 let restored = <Adam as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1771 assert_eq!(
1772 extreme_config.learning_rate,
1773 restored.config().learning_rate
1774 );
1775 assert_eq!(extreme_config.beta1, restored.config().beta1);
1776 assert_eq!(extreme_config.beta2, restored.config().beta2);
1777 assert_eq!(extreme_config.eps, restored.config().eps);
1778 assert_eq!(extreme_config.weight_decay, restored.config().weight_decay);
1779 assert_eq!(extreme_config.amsgrad, restored.config().amsgrad);
1780 }
1781}