train_station/serialization/
mod.rs

1//! Serialization and deserialization system for Train Station objects
2//!
3//! This module provides a robust, zero-dependency serialization framework that enables
4//! persistent storage and model checkpointing for all Train Station objects. The system
5//! supports both human-readable JSON format for debugging and efficient binary format
6//! for production deployment.
7//!
8//! # Design Philosophy
9//!
10//! The serialization system follows Train Station's core principles:
11//! - **Zero external dependencies**: Uses only the standard library
12//! - **Maximum performance**: Optimized binary format for production use
13//! - **Safety first**: Comprehensive validation and error handling
14//! - **Future-proof**: Generic trait-based design for extensibility
15//!
16//! # Supported Formats
17//!
18//! ## JSON Format
19//! Human-readable format suitable for:
20//! - Model inspection and debugging
21//! - Configuration files and version control
22//! - Cross-language interoperability
23//! - Development and testing workflows
24//!
25//! ## Binary Format
26//! Optimized binary format suitable for:
27//! - Production model deployment
28//! - High-frequency checkpointing
29//! - Network transmission and storage
30//! - Memory and storage-constrained environments
31//!
32//! # Organization
33//!
34//! - `core/` - Core serialization types, traits, and functionality
35//! - `binary/` - Binary format serialization and deserialization
36//! - `json/` - JSON format serialization and deserialization
37//!
38//! # Examples
39//!
40//! Basic serialization usage:
41//!
42//! ```
43//! use train_station::serialization::{StructSerializer, StructDeserializer, Format};
44//! use std::collections::HashMap;
45//!
46//! // Create a simple data structure
47//! let mut data = HashMap::new();
48//! data.insert("name".to_string(), "test".to_string());
49//! data.insert("value".to_string(), "42".to_string());
50//!
51//! // Serialize to JSON
52//! let serializer = StructSerializer::new()
53//!     .field("data", &data)
54//!     .field("version", &1u32);
55//! let json = serializer.to_json().unwrap();
56//! assert!(json.contains("test"));
57//!
58//! // Deserialize from JSON
59//! let mut deserializer = StructDeserializer::from_json(&json).unwrap();
60//! let loaded_data: HashMap<String, String> = deserializer.field("data").unwrap();
61//! let version: u32 = deserializer.field("version").unwrap();
62//! assert_eq!(loaded_data.get("name").unwrap(), "test");
63//! assert_eq!(version, 1);
64//! ```
65//!
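//! # Thread Safety
//!
//! The provided `save`/`load` helpers hold no shared state of their own. Different
//! objects can therefore be serialized concurrently from different threads, provided
//! each type's `to_json`/`to_binary` implementation does not rely on shared mutable
//! state. Access to the *same* file path is not coordinated: concurrent saves to one
//! path, or a load racing a save, can produce or observe partial data. Callers must
//! synchronize such access themselves.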
71//!
72//! # Error Handling
73//!
74//! The serialization system provides comprehensive error handling through the
75//! `SerializationError` type, which includes detailed information about what
76//! went wrong during serialization or deserialization. All operations return
77//! `Result` types to ensure errors are handled explicitly.
78
79use std::fs::{File, OpenOptions};
80use std::io::{BufReader, BufWriter, Read, Write};
81use std::path::Path;
82
83pub(crate) mod binary;
84pub(crate) mod core;
85pub(crate) mod json;
86
87// Re-export core functionality for convenience
88pub use core::{
89    FieldValue, FromFieldValue, SerializationError, SerializationResult, StructDeserializer,
90    StructSerializable, StructSerializer, ToFieldValue,
91};
92
/// Serialization format options for saving and loading objects
///
/// Selects the representation used by `Serializable::save`, `Serializable::load`
/// and their writer/reader counterparts.
///
/// # Variants
///
/// * `Json` - Human-readable JSON format for debugging and inspection
/// * `Binary` - Compact binary format for production deployment
///
/// `Format` is a small `Copy` tag. It is also `Eq + Hash`, so it can be passed by
/// value and used as a key in hash-based collections, for example to map each
/// format to an output path.
///
/// # Examples
///
/// ```
/// use train_station::serialization::Format;
///
/// // Check format variants
/// let json_format = Format::Json;
/// let binary_format = Format::Binary;
/// assert_ne!(json_format, binary_format);
///
/// // Formats can be stored in hash-based collections
/// let mut seen = std::collections::HashSet::new();
/// seen.insert(json_format);
/// seen.insert(binary_format);
/// assert_eq!(seen.len(), 2);
/// ```
///
/// # Performance Considerations
///
/// - **JSON**: Larger file sizes, slower serialization, human-readable
/// - **Binary**: Smaller file sizes, faster serialization, machine-optimized
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Format {
    /// Human-readable JSON format
    ///
    /// Suitable for debugging, configuration files, and cross-language
    /// interoperability. Produces larger files but allows easy inspection
    /// and modification of serialized data.
    Json,
    /// Efficient binary format
    ///
    /// Optimized for production deployment with minimal file sizes and
    /// maximum serialization speed. Not human-readable but provides
    /// the best performance characteristics.
    Binary,
}
134
135/// Core serialization trait for Train Station objects
136///
137/// This trait provides a unified interface for saving and loading objects in multiple
138/// formats. All serializable objects must implement this trait to enable persistent
139/// storage and model checkpointing. The trait includes both file-based and writer-based
140/// operations for maximum flexibility.
141///
142/// # Required Methods
143///
144/// * `to_json` - Serialize the object to JSON format
145/// * `from_json` - Deserialize an object from JSON format
146/// * `to_binary` - Serialize the object to binary format
147/// * `from_binary` - Deserialize an object from binary format
148///
149/// # Provided Methods
150///
151/// * `save` - Save the object to a file in the specified format
152/// * `save_to_writer` - Save the object to a writer in the specified format
153/// * `load` - Load an object from a file in the specified format
154/// * `load_from_reader` - Load an object from a reader in the specified format
155///
156/// # Safety
157///
158/// Implementations must ensure that:
159/// - Serialized data contains all necessary information for reconstruction
160/// - Deserialization validates all input data thoroughly
161/// - Memory safety is maintained during reconstruction
162/// - No undefined behavior occurs with malformed input
163///
164/// # Examples
165///
166/// ```
167/// use train_station::serialization::{Serializable, Format, SerializationResult};
168///
169/// // Example implementation for a simple struct
170/// #[derive(Debug, PartialEq)]
171/// struct SimpleData {
172///     value: i32,
173/// }
174///
175/// impl Serializable for SimpleData {
176///     fn to_json(&self) -> SerializationResult<String> {
177///         Ok(format!(r#"{{"value":{}}}"#, self.value))
178///     }
179///     
180///     fn from_json(json: &str) -> SerializationResult<Self> {
181///         // Simple parsing for demonstration
182///         if let Some(start) = json.find("value\":") {
183///             let value_str = &json[start + 7..];
184///             if let Some(end) = value_str.find('}') {
185///                 let value: i32 = value_str[..end].parse()
186///                     .map_err(|_| "Invalid number format")?;
187///                 return Ok(SimpleData { value });
188///             }
189///         }
190///         Err("Invalid JSON format".into())
191///     }
192///     
193///     fn to_binary(&self) -> SerializationResult<Vec<u8>> {
194///         Ok(self.value.to_le_bytes().to_vec())
195///     }
196///     
197///     fn from_binary(data: &[u8]) -> SerializationResult<Self> {
198///         if data.len() != 4 {
199///             return Err("Invalid binary data length".into());
200///         }
201///         let value = i32::from_le_bytes([data[0], data[1], data[2], data[3]]);
202///         Ok(SimpleData { value })
203///     }
204/// }
205///
206/// // Usage example
207/// let data = SimpleData { value: 42 };
208/// let json = data.to_json().unwrap();
209/// let loaded = SimpleData::from_json(&json).unwrap();
210/// assert_eq!(data, loaded);
211/// ```
212///
213/// # Implementors
214///
215/// Common types that implement this trait include:
216/// * `Tensor` - For serializing tensor data and metadata
217/// * `AdamConfig` - For serializing optimizer configuration
218/// * `SerializableAdam` - For serializing optimizer state
219pub trait Serializable: Sized {
220    /// Save the object to a file in the specified format
221    ///
222    /// This method creates or overwrites a file at the specified path and writes
223    /// the serialized object data in the requested format. The file is created
224    /// with write permissions and truncated if it already exists.
225    ///
226    /// # Arguments
227    ///
228    /// * `path` - File path where the object should be saved
229    /// * `format` - Serialization format (JSON or Binary)
230    ///
231    /// # Returns
232    ///
233    /// `Ok(())` on success, or `SerializationError` on failure
234    ///
235    /// # Examples
236    ///
237    /// ```
238    /// use train_station::serialization::{Serializable, Format, SerializationResult};
239    /// use std::io::Write;
240    ///
241    /// // Simple example struct
242    /// struct TestData { value: i32 }
243    /// impl Serializable for TestData {
244    ///     fn to_json(&self) -> SerializationResult<String> {
245    ///         Ok(format!(r#"{{"value":{}}}"#, self.value))
246    ///     }
247    ///     fn from_json(json: &str) -> SerializationResult<Self> {
248    ///         Ok(TestData { value: 42 }) // Simplified for example
249    ///     }
250    ///     fn to_binary(&self) -> SerializationResult<Vec<u8>> {
251    ///         Ok(self.value.to_le_bytes().to_vec())
252    ///     }
253    ///     fn from_binary(_data: &[u8]) -> SerializationResult<Self> {
254    ///         Ok(TestData { value: 42 }) // Simplified for example
255    ///     }
256    /// }
257    ///
258    /// let data = TestData { value: 42 };
259    ///
260    /// // Save to temporary file (cleanup handled by temp directory)
261    /// let temp_dir = std::env::temp_dir();
262    /// let json_path = temp_dir.join("test_data.json");
263    /// data.save(&json_path, Format::Json).unwrap();
264    ///
265    /// // Verify file was created
266    /// assert!(json_path.exists());
267    ///
268    /// // Clean up
269    /// std::fs::remove_file(&json_path).ok();
270    /// ```
271    fn save<P: AsRef<Path>>(&self, path: P, format: Format) -> SerializationResult<()> {
272        let file = OpenOptions::new()
273            .write(true)
274            .create(true)
275            .truncate(true)
276            .open(path)?;
277
278        let mut writer = BufWriter::new(file);
279        self.save_to_writer(&mut writer, format)
280    }
281
282    /// Save the object to a writer in the specified format
283    ///
284    /// This method serializes the object and writes the data to the provided writer.
285    /// The writer is flushed after writing to ensure all data is written. This method
286    /// is useful for streaming serialization or writing to non-file destinations.
287    ///
288    /// # Arguments
289    ///
290    /// * `writer` - Writer to output serialized data
291    /// * `format` - Serialization format (JSON or Binary)
292    ///
293    /// # Returns
294    ///
295    /// `Ok(())` on success, or `SerializationError` on failure
296    fn save_to_writer<W: Write>(&self, writer: &mut W, format: Format) -> SerializationResult<()> {
297        match format {
298            Format::Json => {
299                let json_data = self.to_json()?;
300                writer.write_all(json_data.as_bytes())?;
301            }
302            Format::Binary => {
303                let binary_data = self.to_binary()?;
304                writer.write_all(&binary_data)?;
305            }
306        }
307        writer.flush()?;
308        Ok(())
309    }
310
311    /// Load an object from a file in the specified format
312    ///
313    /// This method reads the entire file content and deserializes it into an object
314    /// of the implementing type. The file must exist and contain valid serialized
315    /// data in the specified format.
316    ///
317    /// # Arguments
318    ///
319    /// * `path` - File path to read from
320    /// * `format` - Expected serialization format
321    ///
322    /// # Returns
323    ///
324    /// The deserialized object on success, or `SerializationError` on failure
325    ///
326    /// # Examples
327    ///
328    /// ```
329    /// use train_station::serialization::{Serializable, Format, SerializationResult};
330    /// use std::io::Write;
331    ///
332    /// // Simple example struct
333    /// #[derive(Debug, PartialEq)]
334    /// struct TestData { value: i32 }
335    /// impl Serializable for TestData {
336    ///     fn to_json(&self) -> SerializationResult<String> {
337    ///         Ok(format!(r#"{{"value":{}}}"#, self.value))
338    ///     }
339    ///     fn from_json(json: &str) -> SerializationResult<Self> {
340    ///         // Simple parsing for demonstration
341    ///         if json.contains("42") {
342    ///             Ok(TestData { value: 42 })
343    ///         } else {
344    ///             Ok(TestData { value: 0 })
345    ///         }
346    ///     }
347    ///     fn to_binary(&self) -> SerializationResult<Vec<u8>> {
348    ///         Ok(self.value.to_le_bytes().to_vec())
349    ///     }
350    ///     fn from_binary(data: &[u8]) -> SerializationResult<Self> {
351    ///         if data.len() >= 4 {
352    ///             let value = i32::from_le_bytes([data[0], data[1], data[2], data[3]]);
353    ///             Ok(TestData { value })
354    ///         } else {
355    ///             Ok(TestData { value: 0 })
356    ///         }
357    ///     }
358    /// }
359    ///
360    /// let original = TestData { value: 42 };
361    ///
362    /// // Save and load from temporary file
363    /// let temp_dir = std::env::temp_dir();
364    /// let json_path = temp_dir.join("test_load.json");
365    /// original.save(&json_path, Format::Json).unwrap();
366    ///
367    /// let loaded = TestData::load(&json_path, Format::Json).unwrap();
368    /// assert_eq!(original, loaded);
369    ///
370    /// // Clean up
371    /// std::fs::remove_file(&json_path).ok();
372    /// ```
373    fn load<P: AsRef<Path>>(path: P, format: Format) -> SerializationResult<Self> {
374        let file = File::open(path)?;
375        let mut reader = BufReader::new(file);
376        Self::load_from_reader(&mut reader, format)
377    }
378
379    /// Load an object from a reader in the specified format
380    ///
381    /// This method reads all available data from the provided reader and deserializes
382    /// it into an object of the implementing type. The reader must contain complete
383    /// serialized data in the specified format.
384    ///
385    /// # Arguments
386    ///
387    /// * `reader` - Reader containing serialized data
388    /// * `format` - Expected serialization format
389    ///
390    /// # Returns
391    ///
392    /// The deserialized object on success, or `SerializationError` on failure
393    fn load_from_reader<R: Read>(reader: &mut R, format: Format) -> SerializationResult<Self> {
394        match format {
395            Format::Json => {
396                let mut json_data = String::new();
397                reader.read_to_string(&mut json_data)?;
398                Self::from_json(&json_data)
399            }
400            Format::Binary => {
401                let mut binary_data = Vec::new();
402                reader.read_to_end(&mut binary_data)?;
403                Self::from_binary(&binary_data)
404            }
405        }
406    }
407
408    /// Serialize the object to JSON format
409    ///
410    /// This method converts the object into a human-readable JSON string representation.
411    /// The JSON format is suitable for debugging, configuration files, and cross-language
412    /// interoperability.
413    ///
414    /// # Returns
415    ///
416    /// JSON string representation of the object on success, or `SerializationError` on failure
417    fn to_json(&self) -> SerializationResult<String>;
418
419    /// Deserialize an object from JSON format
420    ///
421    /// This method parses a JSON string and reconstructs an object of the implementing
422    /// type. The JSON must contain all necessary fields and data in the expected format.
423    ///
424    /// # Arguments
425    ///
426    /// * `json` - JSON string containing serialized object
427    ///
428    /// # Returns
429    ///
430    /// The deserialized object on success, or `SerializationError` on failure
431    fn from_json(json: &str) -> SerializationResult<Self>;
432
433    /// Serialize the object to binary format
434    ///
435    /// This method converts the object into a compact binary representation optimized
436    /// for storage and transmission. The binary format provides maximum performance
437    /// and minimal file sizes.
438    ///
439    /// # Returns
440    ///
441    /// Binary representation of the object on success, or `SerializationError` on failure
442    fn to_binary(&self) -> SerializationResult<Vec<u8>>;
443
444    /// Deserialize an object from binary format
445    ///
446    /// This method parses binary data and reconstructs an object of the implementing
447    /// type. The binary data must contain complete serialized information in the
448    /// expected format.
449    ///
450    /// # Arguments
451    ///
452    /// * `data` - Binary data containing serialized object
453    ///
454    /// # Returns
455    ///
456    /// The deserialized object on success, or `SerializationError` on failure
457    fn from_binary(data: &[u8]) -> SerializationResult<Self>;
458}
459
/// Utility functions for common serialization tasks
///
/// This module provides helper functions for file extension management, format
/// detection from file paths, and rough size estimation for binary serialization.
///
/// # Availability
///
/// Every helper in this module is currently gated behind `#[cfg(test)]`. They are
/// compiled only for this crate's unit tests, and non-test builds see an empty module.
/// Remove the gate on a helper before calling it from production code.
///
/// # Purpose
///
/// The utilities in this module handle:
/// - File extension mapping for different serialization formats
/// - Automatic format detection based on file paths
/// - Size estimation for binary serialization planning
pub(crate) mod utils {
    #[cfg(test)]
    use super::Format;
    #[cfg(test)]
    use std::path::Path;

    /// Get the appropriate file extension for a format
    ///
    /// Returns the standard file extension (without a leading dot) associated with
    /// each serialization format. It is the inverse of `detect_format` for
    /// lower-case extensions.
    ///
    /// # Arguments
    ///
    /// * `format` - The serialization format
    ///
    /// # Returns
    ///
    /// The file extension as a string slice
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::serialization::Format;
    ///
    /// // This function is internal, but demonstrates the concept
    /// fn format_extension(format: Format) -> &'static str {
    ///     match format {
    ///         Format::Json => "json",
    ///         Format::Binary => "bin",
    ///     }
    /// }
    ///
    /// assert_eq!(format_extension(Format::Json), "json");
    /// assert_eq!(format_extension(Format::Binary), "bin");
    /// ```
    #[cfg(test)]
    pub(crate) fn format_extension(format: Format) -> &'static str {
        match format {
            Format::Json => "json",
            Format::Binary => "bin",
        }
    }

    /// Detect format from file extension
    ///
    /// Attempts to determine the serialization format from the file extension.
    /// Matching is ASCII case-insensitive (`json`/`JSON`, `bin`/`BIN`, ...) and
    /// does not allocate.
    ///
    /// # Arguments
    ///
    /// * `path` - File path to analyze
    ///
    /// # Returns
    ///
    /// `Some(Format)` if the extension is recognized. Returns `None` when the path
    /// has no extension, the extension is not valid UTF-8, or it is not recognized.
    ///
    /// # Examples
    ///
    /// ```
    /// use train_station::serialization::Format;
    /// use std::path::Path;
    ///
    /// // This function is internal, but demonstrates the concept
    /// fn detect_format<P: AsRef<Path>>(path: P) -> Option<Format> {
    ///     let ext = path.as_ref().extension()?.to_str()?;
    ///     if ext.eq_ignore_ascii_case("json") {
    ///         Some(Format::Json)
    ///     } else if ext.eq_ignore_ascii_case("bin") {
    ///         Some(Format::Binary)
    ///     } else {
    ///         None
    ///     }
    /// }
    ///
    /// assert_eq!(detect_format("model.json"), Some(Format::Json));
    /// assert_eq!(detect_format("model.JSON"), Some(Format::Json));
    /// assert_eq!(detect_format("model.bin"), Some(Format::Binary));
    /// assert_eq!(detect_format("model.txt"), None);
    /// ```
    #[cfg(test)]
    pub(crate) fn detect_format<P: AsRef<Path>>(path: P) -> Option<Format> {
        let ext = path.as_ref().extension()?.to_str()?;
        if ext.eq_ignore_ascii_case("json") {
            Some(Format::Json)
        } else if ext.eq_ignore_ascii_case("bin") {
            Some(Format::Binary)
        } else {
            None
        }
    }

    /// Estimate serialized size for binary format
    ///
    /// Provides a rough estimate of the binary serialized size from the number of
    /// tensors, the total number of elements, and the number of metadata fields.
    /// The estimate is a 16-byte header, plus 4 bytes (one `f32`) per element, plus
    /// `8 * metadata_fields + 64` bytes per tensor.
    ///
    /// All arithmetic saturates at `usize::MAX`. Absurdly large inputs therefore
    /// yield `usize::MAX` instead of overflowing, which would panic in debug builds.
    ///
    /// # Arguments
    ///
    /// * `tensor_count` - Number of tensors to be serialized
    /// * `total_elements` - Total number of elements across all tensors
    /// * `metadata_fields` - Number of metadata fields per tensor
    ///
    /// # Returns
    ///
    /// Estimated size in bytes for the binary serialization
    ///
    /// # Examples
    ///
    /// ```
    /// // This function is internal, but demonstrates the concept
    /// fn estimate_binary_size(
    ///     tensor_count: usize,
    ///     total_elements: usize,
    ///     metadata_fields: usize,
    /// ) -> usize {
    ///     let data_size = total_elements.saturating_mul(4);
    ///     let per_tensor = metadata_fields.saturating_mul(8).saturating_add(64);
    ///     16usize
    ///         .saturating_add(data_size)
    ///         .saturating_add(tensor_count.saturating_mul(per_tensor))
    /// }
    ///
    /// let estimated_size = estimate_binary_size(3, 1000, 5);
    /// assert!(estimated_size > 4000); // At least data size
    /// ```
    #[cfg(test)]
    pub(crate) fn estimate_binary_size(
        tensor_count: usize,
        total_elements: usize,
        metadata_fields: usize,
    ) -> usize {
        // Header + magic number + version
        const HEADER_SIZE: usize = 16;
        // Tensor data is stored as one f32 per element
        const BYTES_PER_ELEMENT: usize = 4;
        // Per-tensor shape information (dimensions, strides, metadata)
        const BYTES_PER_METADATA_FIELD: usize = 8;
        const PER_TENSOR_OVERHEAD: usize = 64;

        let data_size = total_elements.saturating_mul(BYTES_PER_ELEMENT);
        let per_tensor = metadata_fields
            .saturating_mul(BYTES_PER_METADATA_FIELD)
            .saturating_add(PER_TENSOR_OVERHEAD);
        let shape_size = tensor_count.saturating_mul(per_tensor);

        HEADER_SIZE
            .saturating_add(data_size)
            .saturating_add(shape_size)
    }
}
621
#[cfg(test)]
mod tests {
    use super::{utils, Format};

    /// Table-driven check of extension-based format detection.
    #[test]
    fn test_format_detection() {
        // (path, expected format); upper-case entries check case-insensitive matching.
        let cases = [
            ("model.json", Some(Format::Json)),
            ("model.JSON", Some(Format::Json)),
            ("model.bin", Some(Format::Binary)),
            ("model.BIN", Some(Format::Binary)),
            ("model.txt", None),
            ("model", None),
            ("", None),
        ];
        for &(path, expected) in cases.iter() {
            assert_eq!(
                utils::detect_format(path),
                expected,
                "unexpected format for {:?}",
                path
            );
        }
    }

    /// Every format maps to its canonical lower-case extension.
    #[test]
    fn test_format_extensions() {
        let expected = [(Format::Json, "json"), (Format::Binary, "bin")];
        for &(format, extension) in expected.iter() {
            assert_eq!(utils::format_extension(format), extension);
        }
    }

    /// Size estimates must cover the raw f32 payload with bounded overhead.
    #[test]
    fn test_binary_size_estimation() {
        const BYTES_PER_F32: usize = 4;

        // One tensor holding 1000 elements
        let single = utils::estimate_binary_size(1, 1000, 5);
        assert!(single >= 1000 * BYTES_PER_F32);
        assert!(single <= 5000);

        // Three tensors holding 3000 elements in total
        let multi = utils::estimate_binary_size(3, 3000, 5);
        assert!(multi >= 3000 * BYTES_PER_F32);
        assert!(multi > single * 2);
    }
}