train_station/serialization/
mod.rs

1//! Serialization and deserialization system for Train Station objects
2//!
3//! This module provides a robust, zero-dependency serialization framework that enables
4//! persistent storage and model checkpointing for all Train Station objects. The system
5//! supports both human-readable JSON format for debugging and efficient binary format
6//! for production deployment.
7//!
8//! # Design Philosophy
9//!
10//! The serialization system follows Train Station's core principles:
11//! - **Zero external dependencies**: Uses only the standard library
12//! - **Maximum performance**: Optimized binary format for production use
13//! - **Safety first**: Comprehensive validation and error handling
14//! - **Future-proof**: Generic trait-based design for extensibility
15//!
16//! # Supported Formats
17//!
18//! ## JSON Format
19//! Human-readable format suitable for:
20//! - Model inspection and debugging
21//! - Configuration files and version control
22//! - Cross-language interoperability
23//! - Development and testing workflows
24//!
25//! ## Binary Format
26//! Optimized binary format suitable for:
27//! - Production model deployment
28//! - High-frequency checkpointing
29//! - Network transmission and storage
30//! - Memory and storage-constrained environments
31//!
32//! # Organization
33//!
34//! - `core/` - Core serialization types, traits, and functionality
35//! - `binary/` - Binary format serialization and deserialization
36//! - `json/` - JSON format serialization and deserialization
37//!
38//! # Examples
39//!
40//! Basic serialization usage:
41//!
42//! ```
43//! use train_station::serialization::{StructSerializer, StructDeserializer, Format};
44//! use std::collections::HashMap;
45//!
46//! // Create a simple data structure
47//! let mut data = HashMap::new();
48//! data.insert("name".to_string(), "test".to_string());
49//! data.insert("value".to_string(), "42".to_string());
50//!
51//! // Serialize to JSON
52//! let serializer = StructSerializer::new()
53//!     .field("data", &data)
54//!     .field("version", &1u32);
55//! let json = serializer.to_json().unwrap();
56//! assert!(json.contains("test"));
57//!
58//! // Deserialize from JSON
59//! let mut deserializer = StructDeserializer::from_json(&json).unwrap();
60//! let loaded_data: HashMap<String, String> = deserializer.field("data").unwrap();
61//! let version: u32 = deserializer.field("version").unwrap();
62//! assert_eq!(loaded_data.get("name").unwrap(), "test");
63//! assert_eq!(version, 1);
64//! ```
65//!
66//! # Thread Safety
67//!
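//! The provided methods of [`Serializable`] operate only on the object and on the
//! path, reader, or writer passed to them, so different objects can be serialized
//! concurrently from different threads (provided each implementation's
//! `to_json`/`to_binary` is itself thread-safe). No file locking is performed:
//! concurrent saves to the same path are not coordinated and can overwrite one
//! another, so callers must synchronize access to shared files themselves.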
71//!
72//! # Error Handling
73//!
74//! The serialization system provides comprehensive error handling through the
75//! `SerializationError` type, which includes detailed information about what
76//! went wrong during serialization or deserialization. All operations return
77//! `Result` types to ensure errors are handled explicitly.
78
79use std::fs::{File, OpenOptions};
80use std::io::{BufReader, BufWriter, Read, Write};
81use std::path::Path;
82
83pub(crate) mod binary;
84pub(crate) mod core;
85pub(crate) mod json;
86
87// Re-export core functionality for convenience
88pub use core::{
89    FieldValue, FromFieldValue, SerializationError, SerializationResult, StructDeserializer,
90    StructSerializable, StructSerializer, ToFieldValue,
91};
92
/// Serialization format options for saving and loading objects
///
/// This enum defines the available serialization formats supported by the
/// Train Station serialization system. Each format has specific use cases
/// and performance characteristics.
///
/// `Format` is `Copy` and implements `Eq` and `Hash`, so it can be passed by value
/// and used as a key in hashed collections such as `HashMap` and `HashSet`.
///
/// # Variants
///
/// * `Json` - Human-readable JSON format for debugging and inspection
/// * `Binary` - Efficient binary format for production deployment
///
/// # Examples
///
/// ```
/// use train_station::serialization::Format;
/// use std::collections::HashSet;
///
/// // Check format variants
/// let json_format = Format::Json;
/// let binary_format = Format::Binary;
/// assert_ne!(json_format, binary_format);
///
/// // Formats can be stored in hashed collections
/// let formats: HashSet<Format> = [Format::Json, Format::Binary, Format::Json]
///     .iter()
///     .copied()
///     .collect();
/// assert_eq!(formats.len(), 2);
/// ```
///
/// # Performance Considerations
///
/// - **JSON**: Larger file sizes, slower serialization, human-readable
/// - **Binary**: Smaller file sizes, faster serialization, machine-optimized
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Format {
    /// Human-readable JSON format
    ///
    /// Suitable for debugging, configuration files, and cross-language
    /// interoperability. Produces larger files but allows easy inspection
    /// and modification of serialized data.
    Json,
    /// Efficient binary format
    ///
    /// Optimized for production deployment with minimal file sizes and
    /// maximum serialization speed. Not human-readable but provides
    /// the best performance characteristics.
    Binary,
}
134
135/// Core serialization trait for Train Station objects
136///
137/// This trait provides a unified interface for saving and loading objects in multiple
138/// formats. All serializable objects must implement this trait to enable persistent
139/// storage and model checkpointing. The trait includes both file-based and writer-based
140/// operations for maximum flexibility.
141///
142/// # Required Methods
143///
144/// * `to_json` - Serialize the object to JSON format
145/// * `from_json` - Deserialize an object from JSON format
146/// * `to_binary` - Serialize the object to binary format
147/// * `from_binary` - Deserialize an object from binary format
148///
149/// # Provided Methods
150///
151/// * `save` - Save the object to a file in the specified format
152/// * `save_to_writer` - Save the object to a writer in the specified format
153/// * `load` - Load an object from a file in the specified format
154/// * `load_from_reader` - Load an object from a reader in the specified format
155///
156/// # Safety
157///
158/// Implementations must ensure that:
159/// - Serialized data contains all necessary information for reconstruction
160/// - Deserialization validates all input data thoroughly
161/// - Memory safety is maintained during reconstruction
162/// - No undefined behavior occurs with malformed input
163///
164/// # Examples
165///
166/// ```
167/// use train_station::serialization::{Serializable, Format, SerializationResult};
168///
169/// // Example implementation for a simple struct
170/// #[derive(Debug, PartialEq)]
171/// struct SimpleData {
172///     value: i32,
173/// }
174///
175/// impl Serializable for SimpleData {
176///     fn to_json(&self) -> SerializationResult<String> {
177///         Ok(format!(r#"{{"value":{}}}"#, self.value))
178///     }
179///     
180///     fn from_json(json: &str) -> SerializationResult<Self> {
181///         // Simple parsing for demonstration
182///         if let Some(start) = json.find("value\":") {
183///             let value_str = &json[start + 7..];
184///             if let Some(end) = value_str.find('}') {
185///                 let value: i32 = value_str[..end].parse()
186///                     .map_err(|_| "Invalid number format")?;
187///                 return Ok(SimpleData { value });
188///             }
189///         }
190///         Err("Invalid JSON format".into())
191///     }
192///     
193///     fn to_binary(&self) -> SerializationResult<Vec<u8>> {
194///         Ok(self.value.to_le_bytes().to_vec())
195///     }
196///     
197///     fn from_binary(data: &[u8]) -> SerializationResult<Self> {
198///         if data.len() != 4 {
199///             return Err("Invalid binary data length".into());
200///         }
201///         let value = i32::from_le_bytes([data[0], data[1], data[2], data[3]]);
202///         Ok(SimpleData { value })
203///     }
204/// }
205///
206/// // Usage example
207/// let data = SimpleData { value: 42 };
208/// let json = data.to_json().unwrap();
209/// let loaded = SimpleData::from_json(&json).unwrap();
210/// assert_eq!(data, loaded);
211/// ```
212///
213/// # Implementors
214///
215/// Common types that implement this trait include:
216/// * `Tensor` - For serializing tensor data and metadata
217/// * `AdamConfig` - For serializing optimizer configuration
218/// * `SerializableAdam` - For serializing optimizer state
219pub trait Serializable: Sized {
220    /// Save the object to a file in the specified format
221    ///
222    /// This method creates or overwrites a file at the specified path and writes
223    /// the serialized object data in the requested format. The file is created
224    /// with write permissions and truncated if it already exists.
225    ///
226    /// # Arguments
227    ///
228    /// * `path` - File path where the object should be saved
229    /// * `format` - Serialization format (JSON or Binary)
230    ///
231    /// # Returns
232    ///
233    /// `Ok(())` on success, or `SerializationError` on failure
234    ///
235    /// # Examples
236    ///
237    /// ```
238    /// use train_station::serialization::{Serializable, Format, SerializationResult};
239    /// use std::io::Write;
240    ///
241    /// // Simple example struct
242    /// struct TestData { value: i32 }
243    /// impl Serializable for TestData {
244    ///     fn to_json(&self) -> SerializationResult<String> {
245    ///         Ok(format!(r#"{{"value":{}}}"#, self.value))
246    ///     }
247    ///     fn from_json(json: &str) -> SerializationResult<Self> {
248    ///         Ok(TestData { value: 42 }) // Simplified for example
249    ///     }
250    ///     fn to_binary(&self) -> SerializationResult<Vec<u8>> {
251    ///         Ok(self.value.to_le_bytes().to_vec())
252    ///     }
253    ///     fn from_binary(_data: &[u8]) -> SerializationResult<Self> {
254    ///         Ok(TestData { value: 42 }) // Simplified for example
255    ///     }
256    /// }
257    ///
258    /// let data = TestData { value: 42 };
259    ///
260    /// // Save to temporary file (cleanup handled by temp directory)
261    /// let temp_dir = std::env::temp_dir();
262    /// let json_path = temp_dir.join("test_data.json");
263    /// data.save(&json_path, Format::Json).unwrap();
264    ///
265    /// // Verify file was created
266    /// assert!(json_path.exists());
267    ///
268    /// // Clean up
269    /// std::fs::remove_file(&json_path).ok();
270    /// ```
271    #[track_caller]
272    fn save<P: AsRef<Path>>(&self, path: P, format: Format) -> SerializationResult<()> {
273        let file = OpenOptions::new()
274            .write(true)
275            .create(true)
276            .truncate(true)
277            .open(path)?;
278
279        let mut writer = BufWriter::new(file);
280        self.save_to_writer(&mut writer, format)
281    }
282
283    /// Save the object to a writer in the specified format
284    ///
285    /// This method serializes the object and writes the data to the provided writer.
286    /// The writer is flushed after writing to ensure all data is written. This method
287    /// is useful for streaming serialization or writing to non-file destinations.
288    ///
289    /// # Arguments
290    ///
291    /// * `writer` - Writer to output serialized data
292    /// * `format` - Serialization format (JSON or Binary)
293    ///
294    /// # Returns
295    ///
296    /// `Ok(())` on success, or `SerializationError` on failure
297    #[track_caller]
298    fn save_to_writer<W: Write>(&self, writer: &mut W, format: Format) -> SerializationResult<()> {
299        match format {
300            Format::Json => {
301                let json_data = self.to_json()?;
302                writer.write_all(json_data.as_bytes())?;
303            }
304            Format::Binary => {
305                let binary_data = self.to_binary()?;
306                writer.write_all(&binary_data)?;
307            }
308        }
309        writer.flush()?;
310        Ok(())
311    }
312
313    /// Load an object from a file in the specified format
314    ///
315    /// This method reads the entire file content and deserializes it into an object
316    /// of the implementing type. The file must exist and contain valid serialized
317    /// data in the specified format.
318    ///
319    /// # Arguments
320    ///
321    /// * `path` - File path to read from
322    /// * `format` - Expected serialization format
323    ///
324    /// # Returns
325    ///
326    /// The deserialized object on success, or `SerializationError` on failure
327    ///
328    /// # Examples
329    ///
330    /// ```
331    /// use train_station::serialization::{Serializable, Format, SerializationResult};
332    /// use std::io::Write;
333    ///
334    /// // Simple example struct
335    /// #[derive(Debug, PartialEq)]
336    /// struct TestData { value: i32 }
337    /// impl Serializable for TestData {
338    ///     fn to_json(&self) -> SerializationResult<String> {
339    ///         Ok(format!(r#"{{"value":{}}}"#, self.value))
340    ///     }
341    ///     fn from_json(json: &str) -> SerializationResult<Self> {
342    ///         // Simple parsing for demonstration
343    ///         if json.contains("42") {
344    ///             Ok(TestData { value: 42 })
345    ///         } else {
346    ///             Ok(TestData { value: 0 })
347    ///         }
348    ///     }
349    ///     fn to_binary(&self) -> SerializationResult<Vec<u8>> {
350    ///         Ok(self.value.to_le_bytes().to_vec())
351    ///     }
352    ///     fn from_binary(data: &[u8]) -> SerializationResult<Self> {
353    ///         if data.len() >= 4 {
354    ///             let value = i32::from_le_bytes([data[0], data[1], data[2], data[3]]);
355    ///             Ok(TestData { value })
356    ///         } else {
357    ///             Ok(TestData { value: 0 })
358    ///         }
359    ///     }
360    /// }
361    ///
362    /// let original = TestData { value: 42 };
363    ///
364    /// // Save and load from temporary file
365    /// let temp_dir = std::env::temp_dir();
366    /// let json_path = temp_dir.join("test_load.json");
367    /// original.save(&json_path, Format::Json).unwrap();
368    ///
369    /// let loaded = TestData::load(&json_path, Format::Json).unwrap();
370    /// assert_eq!(original, loaded);
371    ///
372    /// // Clean up
373    /// std::fs::remove_file(&json_path).ok();
374    /// ```
375    #[track_caller]
376    fn load<P: AsRef<Path>>(path: P, format: Format) -> SerializationResult<Self> {
377        let file = File::open(path)?;
378        let mut reader = BufReader::new(file);
379        Self::load_from_reader(&mut reader, format)
380    }
381
382    /// Load an object from a reader in the specified format
383    ///
384    /// This method reads all available data from the provided reader and deserializes
385    /// it into an object of the implementing type. The reader must contain complete
386    /// serialized data in the specified format.
387    ///
388    /// # Arguments
389    ///
390    /// * `reader` - Reader containing serialized data
391    /// * `format` - Expected serialization format
392    ///
393    /// # Returns
394    ///
395    /// The deserialized object on success, or `SerializationError` on failure
396    #[track_caller]
397    fn load_from_reader<R: Read>(reader: &mut R, format: Format) -> SerializationResult<Self> {
398        match format {
399            Format::Json => {
400                let mut json_data = String::new();
401                reader.read_to_string(&mut json_data)?;
402                Self::from_json(&json_data)
403            }
404            Format::Binary => {
405                let mut binary_data = Vec::new();
406                reader.read_to_end(&mut binary_data)?;
407                Self::from_binary(&binary_data)
408            }
409        }
410    }
411
412    /// Serialize the object to JSON format
413    ///
414    /// This method converts the object into a human-readable JSON string representation.
415    /// The JSON format is suitable for debugging, configuration files, and cross-language
416    /// interoperability.
417    ///
418    /// # Returns
419    ///
420    /// JSON string representation of the object on success, or `SerializationError` on failure
421    #[track_caller]
422    fn to_json(&self) -> SerializationResult<String>;
423
424    /// Deserialize an object from JSON format
425    ///
426    /// This method parses a JSON string and reconstructs an object of the implementing
427    /// type. The JSON must contain all necessary fields and data in the expected format.
428    ///
429    /// # Arguments
430    ///
431    /// * `json` - JSON string containing serialized object
432    ///
433    /// # Returns
434    ///
435    /// The deserialized object on success, or `SerializationError` on failure
436    #[track_caller]
437    fn from_json(json: &str) -> SerializationResult<Self>;
438
439    /// Serialize the object to binary format
440    ///
441    /// This method converts the object into a compact binary representation optimized
442    /// for storage and transmission. The binary format provides maximum performance
443    /// and minimal file sizes.
444    ///
445    /// # Returns
446    ///
447    /// Binary representation of the object on success, or `SerializationError` on failure
448    #[track_caller]
449    fn to_binary(&self) -> SerializationResult<Vec<u8>>;
450
451    /// Deserialize an object from binary format
452    ///
453    /// This method parses binary data and reconstructs an object of the implementing
454    /// type. The binary data must contain complete serialized information in the
455    /// expected format.
456    ///
457    /// # Arguments
458    ///
459    /// * `data` - Binary data containing serialized object
460    ///
461    /// # Returns
462    ///
463    /// The deserialized object on success, or `SerializationError` on failure
464    #[track_caller]
465    fn from_binary(data: &[u8]) -> SerializationResult<Self>;
466}
467
/// Helpers for format/extension mapping and binary size estimation
///
/// This module groups small helpers related to serialization file handling:
/// - mapping a [`Format`](super::Format) to its conventional file extension,
/// - inferring a format from a path's extension (ASCII case-insensitive for the
///   recognized extensions),
/// - producing a rough byte-count estimate for binary serialization.
///
/// Every function here is gated behind `#[cfg(test)]`, so none of them are
/// compiled into non-test builds; they currently exist to support this crate's
/// tests.
pub(crate) mod utils {
    #[cfg(test)]
    use super::Format;
    #[cfg(test)]
    use std::path::Path;

    /// Map a serialization format to its conventional file extension
    ///
    /// # Arguments
    ///
    /// * `format` - The serialization format to look up
    ///
    /// # Returns
    ///
    /// `"json"` for [`Format::Json`] and `"bin"` for [`Format::Binary`], without a
    /// leading dot.
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::serialization::Format;
    ///
    /// // This function is internal, but demonstrates the concept
    /// fn format_extension(format: Format) -> &'static str {
    ///     match format {
    ///         Format::Json => "json",
    ///         Format::Binary => "bin",
    ///     }
    /// }
    ///
    /// assert_eq!(format_extension(Format::Json), "json");
    /// assert_eq!(format_extension(Format::Binary), "bin");
    /// ```
    #[cfg(test)]
    pub(crate) fn format_extension(format: Format) -> &'static str {
        match format {
            Format::Binary => "bin",
            Format::Json => "json",
        }
    }

    /// Infer a serialization format from a path's extension
    ///
    /// The extension is lower-cased before comparison, so `model.JSON` and
    /// `model.json` are treated alike. Paths without an extension, with a
    /// non-UTF-8 extension, or with an unrecognized extension yield `None`.
    ///
    /// # Arguments
    ///
    /// * `path` - File path whose extension should be inspected
    ///
    /// # Returns
    ///
    /// `Some(Format)` for a `json` or `bin` extension, `None` otherwise
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::serialization::Format;
    /// use std::path::Path;
    ///
    /// // This function is internal, but demonstrates the concept
    /// fn detect_format<P: AsRef<Path>>(path: P) -> Option<Format> {
    ///     path.as_ref()
    ///         .extension()
    ///         .and_then(|ext| ext.to_str())
    ///         .and_then(|ext| match ext.to_lowercase().as_str() {
    ///             "json" => Some(Format::Json),
    ///             "bin" => Some(Format::Binary),
    ///             _ => None,
    ///         })
    /// }
    ///
    /// assert_eq!(detect_format("model.json"), Some(Format::Json));
    /// assert_eq!(detect_format("model.JSON"), Some(Format::Json));
    /// assert_eq!(detect_format("model.bin"), Some(Format::Binary));
    /// assert_eq!(detect_format("model.txt"), None);
    /// ```
    #[cfg(test)]
    pub(crate) fn detect_format<P: AsRef<Path>>(path: P) -> Option<Format> {
        let lowered = path.as_ref().extension()?.to_str()?.to_lowercase();
        match &*lowered {
            "bin" => Some(Format::Binary),
            "json" => Some(Format::Json),
            _ => None,
        }
    }

    /// Roughly estimate the size in bytes of a binary serialization
    ///
    /// The estimate is the sum of three parts:
    /// - a fixed 16-byte header,
    /// - 4 bytes per element (one `f32` each),
    /// - per tensor, 8 bytes per metadata field plus a fixed 64 bytes.
    ///
    /// # Arguments
    ///
    /// * `tensor_count` - Number of tensors to be serialized
    /// * `total_elements` - Total number of elements across all tensors
    /// * `metadata_fields` - Number of metadata fields per tensor
    ///
    /// # Returns
    ///
    /// Estimated size in bytes for the binary serialization
    ///
    /// # Examples
    ///
    /// ```
    /// // This function is internal, but demonstrates the concept
    /// fn estimate_binary_size(
    ///     tensor_count: usize,
    ///     total_elements: usize,
    ///     metadata_fields: usize,
    /// ) -> usize {
    ///     // Header + magic number + version
    ///     let header_size = 16;
    ///     // Tensor data (f32 per element)
    ///     let data_size = total_elements * 4;
    ///     // Shape information (dimensions, strides, metadata)
    ///     let shape_size = tensor_count * (metadata_fields * 8 + 64);
    ///     header_size + data_size + shape_size
    /// }
    ///
    /// let estimated_size = estimate_binary_size(3, 1000, 5);
    /// assert!(estimated_size > 4000); // At least data size
    /// ```
    #[cfg(test)]
    pub(crate) fn estimate_binary_size(
        tensor_count: usize,
        total_elements: usize,
        metadata_fields: usize,
    ) -> usize {
        // Fixed header (magic number + version).
        const HEADER_BYTES: usize = 16;
        // Element payload is assumed to be f32.
        const BYTES_PER_ELEMENT: usize = 4;
        // Per-tensor shape/stride/metadata bookkeeping.
        const BYTES_PER_METADATA_FIELD: usize = 8;
        const FIXED_BYTES_PER_TENSOR: usize = 64;

        let per_tensor_overhead = metadata_fields * BYTES_PER_METADATA_FIELD + FIXED_BYTES_PER_TENSOR;
        HEADER_BYTES + total_elements * BYTES_PER_ELEMENT + tensor_count * per_tensor_overhead
    }
}
629
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_detection() {
        // (input path, expected detected format); matching ignores extension case.
        let cases: [(&str, Option<Format>); 7] = [
            ("model.json", Some(Format::Json)),
            ("model.JSON", Some(Format::Json)),
            ("model.bin", Some(Format::Binary)),
            ("model.BIN", Some(Format::Binary)),
            ("model.txt", None),
            ("model", None),
            ("", None),
        ];
        for &(path, expected) in &cases {
            assert_eq!(utils::detect_format(path), expected, "detect_format({:?})", path);
        }
    }

    #[test]
    fn test_format_extensions() {
        let expected = [(Format::Json, "json"), (Format::Binary, "bin")];
        for &(format, extension) in &expected {
            assert_eq!(utils::format_extension(format), extension);
        }
    }

    #[test]
    fn test_binary_size_estimation() {
        // One tensor of 1000 f32 elements needs 4000 data bytes plus modest overhead.
        let single = utils::estimate_binary_size(1, 1000, 5);
        assert!(single >= 4000);
        assert!(single <= 5000);

        // Three tensors totalling 3000 elements need at least 12000 data bytes.
        let multi = utils::estimate_binary_size(3, 3000, 5);
        assert!(multi >= 12000);
        assert!(multi > single * 2);
    }
}