train_station/serialization/mod.rs
1//! Serialization and deserialization system for Train Station objects
2//!
3//! This module provides a robust, zero-dependency serialization framework that enables
4//! persistent storage and model checkpointing for all Train Station objects. The system
5//! supports both human-readable JSON format for debugging and efficient binary format
6//! for production deployment.
7//!
8//! # Design Philosophy
9//!
10//! The serialization system follows Train Station's core principles:
11//! - **Zero external dependencies**: Uses only the standard library
12//! - **Maximum performance**: Optimized binary format for production use
13//! - **Safety first**: Comprehensive validation and error handling
14//! - **Future-proof**: Generic trait-based design for extensibility
15//!
16//! # Supported Formats
17//!
18//! ## JSON Format
19//! Human-readable format suitable for:
20//! - Model inspection and debugging
21//! - Configuration files and version control
22//! - Cross-language interoperability
23//! - Development and testing workflows
24//!
25//! ## Binary Format
26//! Optimized binary format suitable for:
27//! - Production model deployment
28//! - High-frequency checkpointing
29//! - Network transmission and storage
30//! - Memory and storage-constrained environments
31//!
32//! # Organization
33//!
34//! - `core/` - Core serialization types, traits, and functionality
35//! - `binary/` - Binary format serialization and deserialization
36//! - `json/` - JSON format serialization and deserialization
37//!
38//! # Examples
39//!
40//! Basic serialization usage:
41//!
42//! ```
43//! use train_station::serialization::{StructSerializer, StructDeserializer, Format};
44//! use std::collections::HashMap;
45//!
46//! // Create a simple data structure
47//! let mut data = HashMap::new();
48//! data.insert("name".to_string(), "test".to_string());
49//! data.insert("value".to_string(), "42".to_string());
50//!
51//! // Serialize to JSON
52//! let serializer = StructSerializer::new()
53//! .field("data", &data)
54//! .field("version", &1u32);
55//! let json = serializer.to_json().unwrap();
56//! assert!(json.contains("test"));
57//!
58//! // Deserialize from JSON
59//! let mut deserializer = StructDeserializer::from_json(&json).unwrap();
60//! let loaded_data: HashMap<String, String> = deserializer.field("data").unwrap();
61//! let version: u32 = deserializer.field("version").unwrap();
62//! assert_eq!(loaded_data.get("name").unwrap(), "test");
63//! assert_eq!(version, 1);
64//! ```
65//!
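//! # Thread Safety
//!
//! The provided save/load helpers keep no shared global state, so different objects
//! can be serialized concurrently from separate threads. Whether an individual
//! `to_json`/`to_binary` implementation is thread-safe is up to the implementing type.
//! Concurrent saves or loads that target the *same* file path are not synchronized;
//! callers must coordinate access to a given file themselves.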
71//!
72//! # Error Handling
73//!
74//! The serialization system provides comprehensive error handling through the
75//! `SerializationError` type, which includes detailed information about what
76//! went wrong during serialization or deserialization. All operations return
77//! `Result` types to ensure errors are handled explicitly.
78
79use std::fs::{File, OpenOptions};
80use std::io::{BufReader, BufWriter, Read, Write};
81use std::path::Path;
82
83pub(crate) mod binary;
84pub(crate) mod core;
85pub(crate) mod json;
86
87// Re-export core functionality for convenience
88pub use core::{
89 FieldValue, FromFieldValue, SerializationError, SerializationResult, StructDeserializer,
90 StructSerializable, StructSerializer, ToFieldValue,
91};
92
93/// Serialization format options for saving and loading objects
94///
95/// This enum defines the available serialization formats supported by the
96/// Train Station serialization system. Each format has specific use cases
97/// and performance characteristics.
98///
99/// # Variants
100///
101/// * `Json` - Human-readable JSON format for debugging and inspection
102/// * `Binary` - Efficient binary format for production deployment
103///
104/// # Examples
105///
106/// ```
107/// use train_station::serialization::Format;
108///
109/// // Check format variants
110/// let json_format = Format::Json;
111/// let binary_format = Format::Binary;
112/// assert_ne!(json_format, binary_format);
113/// ```
114///
115/// # Performance Considerations
116///
117/// - **JSON**: Larger file sizes, slower serialization, human-readable
118/// - **Binary**: Smaller file sizes, faster serialization, machine-optimized
119#[derive(Debug, Clone, Copy, PartialEq, Eq)]
120pub enum Format {
121 /// Human-readable JSON format
122 ///
123 /// Suitable for debugging, configuration files, and cross-language
124 /// interoperability. Produces larger files but allows easy inspection
125 /// and modification of serialized data.
126 Json,
127 /// Efficient binary format
128 ///
129 /// Optimized for production deployment with minimal file sizes and
130 /// maximum serialization speed. Not human-readable but provides
131 /// the best performance characteristics.
132 Binary,
133}
134
135/// Core serialization trait for Train Station objects
136///
137/// This trait provides a unified interface for saving and loading objects in multiple
138/// formats. All serializable objects must implement this trait to enable persistent
139/// storage and model checkpointing. The trait includes both file-based and writer-based
140/// operations for maximum flexibility.
141///
142/// # Required Methods
143///
144/// * `to_json` - Serialize the object to JSON format
145/// * `from_json` - Deserialize an object from JSON format
146/// * `to_binary` - Serialize the object to binary format
147/// * `from_binary` - Deserialize an object from binary format
148///
149/// # Provided Methods
150///
151/// * `save` - Save the object to a file in the specified format
152/// * `save_to_writer` - Save the object to a writer in the specified format
153/// * `load` - Load an object from a file in the specified format
154/// * `load_from_reader` - Load an object from a reader in the specified format
155///
156/// # Safety
157///
158/// Implementations must ensure that:
159/// - Serialized data contains all necessary information for reconstruction
160/// - Deserialization validates all input data thoroughly
161/// - Memory safety is maintained during reconstruction
162/// - No undefined behavior occurs with malformed input
163///
164/// # Examples
165///
166/// ```
167/// use train_station::serialization::{Serializable, Format, SerializationResult};
168///
169/// // Example implementation for a simple struct
170/// #[derive(Debug, PartialEq)]
171/// struct SimpleData {
172/// value: i32,
173/// }
174///
175/// impl Serializable for SimpleData {
176/// fn to_json(&self) -> SerializationResult<String> {
177/// Ok(format!(r#"{{"value":{}}}"#, self.value))
178/// }
179///
180/// fn from_json(json: &str) -> SerializationResult<Self> {
181/// // Simple parsing for demonstration
182/// if let Some(start) = json.find("value\":") {
183/// let value_str = &json[start + 7..];
184/// if let Some(end) = value_str.find('}') {
185/// let value: i32 = value_str[..end].parse()
186/// .map_err(|_| "Invalid number format")?;
187/// return Ok(SimpleData { value });
188/// }
189/// }
190/// Err("Invalid JSON format".into())
191/// }
192///
193/// fn to_binary(&self) -> SerializationResult<Vec<u8>> {
194/// Ok(self.value.to_le_bytes().to_vec())
195/// }
196///
197/// fn from_binary(data: &[u8]) -> SerializationResult<Self> {
198/// if data.len() != 4 {
199/// return Err("Invalid binary data length".into());
200/// }
201/// let value = i32::from_le_bytes([data[0], data[1], data[2], data[3]]);
202/// Ok(SimpleData { value })
203/// }
204/// }
205///
206/// // Usage example
207/// let data = SimpleData { value: 42 };
208/// let json = data.to_json().unwrap();
209/// let loaded = SimpleData::from_json(&json).unwrap();
210/// assert_eq!(data, loaded);
211/// ```
212///
213/// # Implementors
214///
215/// Common types that implement this trait include:
216/// * `Tensor` - For serializing tensor data and metadata
217/// * `AdamConfig` - For serializing optimizer configuration
218/// * `SerializableAdam` - For serializing optimizer state
219pub trait Serializable: Sized {
220 /// Save the object to a file in the specified format
221 ///
222 /// This method creates or overwrites a file at the specified path and writes
223 /// the serialized object data in the requested format. The file is created
224 /// with write permissions and truncated if it already exists.
225 ///
226 /// # Arguments
227 ///
228 /// * `path` - File path where the object should be saved
229 /// * `format` - Serialization format (JSON or Binary)
230 ///
231 /// # Returns
232 ///
233 /// `Ok(())` on success, or `SerializationError` on failure
234 ///
235 /// # Examples
236 ///
237 /// ```
238 /// use train_station::serialization::{Serializable, Format, SerializationResult};
239 /// use std::io::Write;
240 ///
241 /// // Simple example struct
242 /// struct TestData { value: i32 }
243 /// impl Serializable for TestData {
244 /// fn to_json(&self) -> SerializationResult<String> {
245 /// Ok(format!(r#"{{"value":{}}}"#, self.value))
246 /// }
247 /// fn from_json(json: &str) -> SerializationResult<Self> {
248 /// Ok(TestData { value: 42 }) // Simplified for example
249 /// }
250 /// fn to_binary(&self) -> SerializationResult<Vec<u8>> {
251 /// Ok(self.value.to_le_bytes().to_vec())
252 /// }
253 /// fn from_binary(_data: &[u8]) -> SerializationResult<Self> {
254 /// Ok(TestData { value: 42 }) // Simplified for example
255 /// }
256 /// }
257 ///
258 /// let data = TestData { value: 42 };
259 ///
260 /// // Save to temporary file (cleanup handled by temp directory)
261 /// let temp_dir = std::env::temp_dir();
262 /// let json_path = temp_dir.join("test_data.json");
263 /// data.save(&json_path, Format::Json).unwrap();
264 ///
265 /// // Verify file was created
266 /// assert!(json_path.exists());
267 ///
268 /// // Clean up
269 /// std::fs::remove_file(&json_path).ok();
270 /// ```
271 #[track_caller]
272 fn save<P: AsRef<Path>>(&self, path: P, format: Format) -> SerializationResult<()> {
273 let file = OpenOptions::new()
274 .write(true)
275 .create(true)
276 .truncate(true)
277 .open(path)?;
278
279 let mut writer = BufWriter::new(file);
280 self.save_to_writer(&mut writer, format)
281 }
282
283 /// Save the object to a writer in the specified format
284 ///
285 /// This method serializes the object and writes the data to the provided writer.
286 /// The writer is flushed after writing to ensure all data is written. This method
287 /// is useful for streaming serialization or writing to non-file destinations.
288 ///
289 /// # Arguments
290 ///
291 /// * `writer` - Writer to output serialized data
292 /// * `format` - Serialization format (JSON or Binary)
293 ///
294 /// # Returns
295 ///
296 /// `Ok(())` on success, or `SerializationError` on failure
297 #[track_caller]
298 fn save_to_writer<W: Write>(&self, writer: &mut W, format: Format) -> SerializationResult<()> {
299 match format {
300 Format::Json => {
301 let json_data = self.to_json()?;
302 writer.write_all(json_data.as_bytes())?;
303 }
304 Format::Binary => {
305 let binary_data = self.to_binary()?;
306 writer.write_all(&binary_data)?;
307 }
308 }
309 writer.flush()?;
310 Ok(())
311 }
312
313 /// Load an object from a file in the specified format
314 ///
315 /// This method reads the entire file content and deserializes it into an object
316 /// of the implementing type. The file must exist and contain valid serialized
317 /// data in the specified format.
318 ///
319 /// # Arguments
320 ///
321 /// * `path` - File path to read from
322 /// * `format` - Expected serialization format
323 ///
324 /// # Returns
325 ///
326 /// The deserialized object on success, or `SerializationError` on failure
327 ///
328 /// # Examples
329 ///
330 /// ```
331 /// use train_station::serialization::{Serializable, Format, SerializationResult};
332 /// use std::io::Write;
333 ///
334 /// // Simple example struct
335 /// #[derive(Debug, PartialEq)]
336 /// struct TestData { value: i32 }
337 /// impl Serializable for TestData {
338 /// fn to_json(&self) -> SerializationResult<String> {
339 /// Ok(format!(r#"{{"value":{}}}"#, self.value))
340 /// }
341 /// fn from_json(json: &str) -> SerializationResult<Self> {
342 /// // Simple parsing for demonstration
343 /// if json.contains("42") {
344 /// Ok(TestData { value: 42 })
345 /// } else {
346 /// Ok(TestData { value: 0 })
347 /// }
348 /// }
349 /// fn to_binary(&self) -> SerializationResult<Vec<u8>> {
350 /// Ok(self.value.to_le_bytes().to_vec())
351 /// }
352 /// fn from_binary(data: &[u8]) -> SerializationResult<Self> {
353 /// if data.len() >= 4 {
354 /// let value = i32::from_le_bytes([data[0], data[1], data[2], data[3]]);
355 /// Ok(TestData { value })
356 /// } else {
357 /// Ok(TestData { value: 0 })
358 /// }
359 /// }
360 /// }
361 ///
362 /// let original = TestData { value: 42 };
363 ///
364 /// // Save and load from temporary file
365 /// let temp_dir = std::env::temp_dir();
366 /// let json_path = temp_dir.join("test_load.json");
367 /// original.save(&json_path, Format::Json).unwrap();
368 ///
369 /// let loaded = TestData::load(&json_path, Format::Json).unwrap();
370 /// assert_eq!(original, loaded);
371 ///
372 /// // Clean up
373 /// std::fs::remove_file(&json_path).ok();
374 /// ```
375 #[track_caller]
376 fn load<P: AsRef<Path>>(path: P, format: Format) -> SerializationResult<Self> {
377 let file = File::open(path)?;
378 let mut reader = BufReader::new(file);
379 Self::load_from_reader(&mut reader, format)
380 }
381
382 /// Load an object from a reader in the specified format
383 ///
384 /// This method reads all available data from the provided reader and deserializes
385 /// it into an object of the implementing type. The reader must contain complete
386 /// serialized data in the specified format.
387 ///
388 /// # Arguments
389 ///
390 /// * `reader` - Reader containing serialized data
391 /// * `format` - Expected serialization format
392 ///
393 /// # Returns
394 ///
395 /// The deserialized object on success, or `SerializationError` on failure
396 #[track_caller]
397 fn load_from_reader<R: Read>(reader: &mut R, format: Format) -> SerializationResult<Self> {
398 match format {
399 Format::Json => {
400 let mut json_data = String::new();
401 reader.read_to_string(&mut json_data)?;
402 Self::from_json(&json_data)
403 }
404 Format::Binary => {
405 let mut binary_data = Vec::new();
406 reader.read_to_end(&mut binary_data)?;
407 Self::from_binary(&binary_data)
408 }
409 }
410 }
411
412 /// Serialize the object to JSON format
413 ///
414 /// This method converts the object into a human-readable JSON string representation.
415 /// The JSON format is suitable for debugging, configuration files, and cross-language
416 /// interoperability.
417 ///
418 /// # Returns
419 ///
420 /// JSON string representation of the object on success, or `SerializationError` on failure
421 #[track_caller]
422 fn to_json(&self) -> SerializationResult<String>;
423
424 /// Deserialize an object from JSON format
425 ///
426 /// This method parses a JSON string and reconstructs an object of the implementing
427 /// type. The JSON must contain all necessary fields and data in the expected format.
428 ///
429 /// # Arguments
430 ///
431 /// * `json` - JSON string containing serialized object
432 ///
433 /// # Returns
434 ///
435 /// The deserialized object on success, or `SerializationError` on failure
436 #[track_caller]
437 fn from_json(json: &str) -> SerializationResult<Self>;
438
439 /// Serialize the object to binary format
440 ///
441 /// This method converts the object into a compact binary representation optimized
442 /// for storage and transmission. The binary format provides maximum performance
443 /// and minimal file sizes.
444 ///
445 /// # Returns
446 ///
447 /// Binary representation of the object on success, or `SerializationError` on failure
448 #[track_caller]
449 fn to_binary(&self) -> SerializationResult<Vec<u8>>;
450
451 /// Deserialize an object from binary format
452 ///
453 /// This method parses binary data and reconstructs an object of the implementing
454 /// type. The binary data must contain complete serialized information in the
455 /// expected format.
456 ///
457 /// # Arguments
458 ///
459 /// * `data` - Binary data containing serialized object
460 ///
461 /// # Returns
462 ///
463 /// The deserialized object on success, or `SerializationError` on failure
464 #[track_caller]
465 fn from_binary(data: &[u8]) -> SerializationResult<Self>;
466}
467
468/// Utility functions for common serialization tasks
469///
470/// This module provides helper functions for format detection, file extension
471/// management, and size estimation for serialization operations. These functions
472/// are used internally by the serialization system to support file operations
473/// and provide estimates for memory allocation.
474///
475/// # Purpose
476///
477/// The utilities in this module handle:
478/// - File extension mapping for different serialization formats
479/// - Automatic format detection based on file paths
480/// - Size estimation for binary serialization planning
481/// - Common helper functions used across the serialization system
482pub(crate) mod utils {
483 #[cfg(test)]
484 use super::Format;
485 #[cfg(test)]
486 use std::path::Path;
487
488 /// Get the appropriate file extension for a format
489 ///
490 /// Returns the standard file extension associated with each serialization format.
491 /// This is useful for automatically determining file extensions when saving
492 /// or for format detection based on file paths.
493 ///
494 /// # Arguments
495 ///
496 /// * `format` - The serialization format
497 ///
498 /// # Returns
499 ///
500 /// The file extension as a string slice
501 ///
502 /// # Examples
503 ///
504 /// ```
505 /// use train_station::serialization::Format;
506 ///
507 /// // This function is internal, but demonstrates the concept
508 /// fn format_extension(format: Format) -> &'static str {
509 /// match format {
510 /// Format::Json => "json",
511 /// Format::Binary => "bin",
512 /// }
513 /// }
514 ///
515 /// assert_eq!(format_extension(Format::Json), "json");
516 /// assert_eq!(format_extension(Format::Binary), "bin");
517 /// ```
518 #[cfg(test)]
519 pub(crate) fn format_extension(format: Format) -> &'static str {
520 match format {
521 Format::Json => "json",
522 Format::Binary => "bin",
523 }
524 }
525
526 /// Detect format from file extension
527 ///
528 /// Attempts to determine the serialization format based on the file extension.
529 /// Supports case-insensitive extension matching for common format extensions.
530 ///
531 /// # Arguments
532 ///
533 /// * `path` - File path to analyze
534 ///
535 /// # Returns
536 ///
537 /// `Some(Format)` if the extension is recognized, `None` otherwise
538 ///
539 /// # Examples
540 ///
541 /// ```
542 /// use train_station::serialization::Format;
543 /// use std::path::Path;
544 ///
545 /// // This function is internal, but demonstrates the concept
546 /// fn detect_format<P: AsRef<Path>>(path: P) -> Option<Format> {
547 /// path.as_ref()
548 /// .extension()
549 /// .and_then(|ext| ext.to_str())
550 /// .and_then(|ext| match ext.to_lowercase().as_str() {
551 /// "json" => Some(Format::Json),
552 /// "bin" => Some(Format::Binary),
553 /// _ => None,
554 /// })
555 /// }
556 ///
557 /// assert_eq!(detect_format("model.json"), Some(Format::Json));
558 /// assert_eq!(detect_format("model.JSON"), Some(Format::Json));
559 /// assert_eq!(detect_format("model.bin"), Some(Format::Binary));
560 /// assert_eq!(detect_format("model.txt"), None);
561 /// ```
562 #[cfg(test)]
563 pub(crate) fn detect_format<P: AsRef<Path>>(path: P) -> Option<Format> {
564 path.as_ref()
565 .extension()
566 .and_then(|ext| ext.to_str())
567 .and_then(|ext| match ext.to_lowercase().as_str() {
568 "json" => Some(Format::Json),
569 "bin" => Some(Format::Binary),
570 _ => None,
571 })
572 }
573
574 /// Estimate serialized size for binary format
575 ///
576 /// Provides a rough estimate of the binary serialized size based on the number
577 /// of tensors, total elements, and metadata fields. This is useful for memory
578 /// allocation and storage planning.
579 ///
580 /// # Arguments
581 ///
582 /// * `tensor_count` - Number of tensors to be serialized
583 /// * `total_elements` - Total number of elements across all tensors
584 /// * `metadata_fields` - Number of metadata fields per tensor
585 ///
586 /// # Returns
587 ///
588 /// Estimated size in bytes for the binary serialization
589 ///
590 /// # Examples
591 ///
592 /// ```
593 /// // This function is internal, but demonstrates the concept
594 /// fn estimate_binary_size(
595 /// tensor_count: usize,
596 /// total_elements: usize,
597 /// metadata_fields: usize,
598 /// ) -> usize {
599 /// // Header + magic number + version
600 /// let header_size = 16;
601 /// // Tensor data (f32 per element)
602 /// let data_size = total_elements * 4;
603 /// // Shape information (dimensions, strides, metadata)
604 /// let shape_size = tensor_count * (metadata_fields * 8 + 64);
605 /// header_size + data_size + shape_size
606 /// }
607 ///
608 /// let estimated_size = estimate_binary_size(3, 1000, 5);
609 /// assert!(estimated_size > 4000); // At least data size
610 /// ```
611 #[cfg(test)]
612 pub(crate) fn estimate_binary_size(
613 tensor_count: usize,
614 total_elements: usize,
615 metadata_fields: usize,
616 ) -> usize {
617 // Header + magic number + version
618 let header_size = 16;
619
620 // Tensor data (f32 per element)
621 let data_size = total_elements * 4;
622
623 // Shape information (dimensions, strides, metadata)
624 let shape_size = tensor_count * (metadata_fields * 8 + 64);
625
626 header_size + data_size + shape_size
627 }
628}
629
630#[cfg(test)]
631mod tests {
632 use super::*;
633
634 #[test]
635 fn test_format_detection() {
636 assert_eq!(utils::detect_format("model.json"), Some(Format::Json));
637 assert_eq!(utils::detect_format("model.JSON"), Some(Format::Json)); // Case insensitive
638 assert_eq!(utils::detect_format("model.bin"), Some(Format::Binary));
639 assert_eq!(utils::detect_format("model.BIN"), Some(Format::Binary)); // Case insensitive
640 assert_eq!(utils::detect_format("model.txt"), None);
641 assert_eq!(utils::detect_format("model"), None);
642 assert_eq!(utils::detect_format(""), None);
643 }
644
645 #[test]
646 fn test_format_extensions() {
647 assert_eq!(utils::format_extension(Format::Json), "json");
648 assert_eq!(utils::format_extension(Format::Binary), "bin");
649 }
650
651 #[test]
652 fn test_binary_size_estimation() {
653 // Single tensor with 1000 elements
654 let estimated = utils::estimate_binary_size(1, 1000, 5);
655 assert!(estimated >= 4000); // At least data size (1000 * 4 bytes)
656 assert!(estimated <= 5000); // Reasonable metadata overhead
657
658 // Multiple tensors
659 let estimated_multi = utils::estimate_binary_size(3, 3000, 5);
660 assert!(estimated_multi >= 12000); // At least data size (3000 * 4 bytes)
661 assert!(estimated_multi > estimated * 2); // Should be larger than single tensor
662 }
663}