train_station/tensor/core/
serialization.rs

1//! Tensor serialization implementation
2//!
3//! This module provides comprehensive serialization support for tensors, shapes,
4//! and device information using the Train Station serialization framework.
5//! It supports both JSON and binary formats with full roundtrip fidelity.
6//!
7//! # Key Features
8//!
9//! - **Complete Tensor Serialization**: Tensor data, shape, device, and gradtrack state
10//! - **Shape Serialization**: Dimensions, strides, size, and memory layout information
11//! - **Device Serialization**: Device type and index for CPU/CUDA placement
12//! - **Efficient Binary Format**: Optimized binary serialization for performance
13//! - **Human-Readable JSON**: JSON format for debugging and interoperability
//! - **Roundtrip Fidelity**: Tensor data, shape, device, and `requires_grad` are restored from serialized data
15//! - **Struct Field Support**: Tensors can be serialized as fields within larger structures
16//!
17//! # Architecture
18//!
19//! The serialization system handles:
20//! - **Tensor Data**: Serialized as `Vec<f32>` for efficiency
21//! - **Shape Information**: Complete shape metadata including strides and layout
22//! - **Device Placement**: Device type and index for proper reconstruction
23//! - **GradTrack State**: requires_grad flag (runtime gradient state not serialized)
24//!
25//! # Non-Serialized Fields
26//!
27//! The following tensor fields are NOT serialized as they are runtime state:
28//! - `data` (raw pointer): Reconstructed from serialized `Vec<f32>`
29//! - `id`: Regenerated during deserialization for uniqueness
30//! - `grad`: Runtime gradient state, not persistent
31//! - `grad_fn`: Runtime gradient function, not persistent
32//! - `allocation_owner`: Internal memory management, reconstructed
33//! - `_phantom`: Zero-sized type, no serialization needed
34//!
//! # Usage Examples
36//!
37//! ## Basic Tensor Serialization
38//!
39//! ```rust
40//! use train_station::Tensor;
41//! use train_station::serialization::StructSerializable;
42//!
43//! // Create and populate a tensor
44//! let mut tensor = Tensor::zeros(vec![2, 3, 4]).with_requires_grad();
45//! tensor.fill(42.0);
46//!
47//! // Serialize to JSON
48//! let json = tensor.to_json().unwrap();
49//! assert!(!json.is_empty());
50//!
51//! // Deserialize from JSON
52//! let loaded_tensor = Tensor::from_json(&json).unwrap();
53//! assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
54//! assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
55//!
56//! // Serialize to binary
57//! let binary = tensor.to_binary().unwrap();
58//! assert!(!binary.is_empty());
59//!
60//! // Deserialize from binary
61//! let loaded_tensor = Tensor::from_binary(&binary).unwrap();
62//! assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
63//! ```
64//!
65//! ## Tensor as Struct Field
66//!
67//! ```rust
68//! use train_station::Tensor;
69//! use train_station::serialization::StructSerializable;
70//!
71//! // Define a struct containing tensors
72//! #[derive(Debug)]
73//! struct ModelWeights {
74//!     weight_matrix: Tensor,
75//!     bias_vector: Tensor,
76//!     learning_rate: f32,
77//!     name: String,
78//! }
79//!
80//! impl StructSerializable for ModelWeights {
81//!     fn to_serializer(&self) -> train_station::serialization::StructSerializer {
82//!         train_station::serialization::StructSerializer::new()
83//!             .field("weight_matrix", &self.weight_matrix)
84//!             .field("bias_vector", &self.bias_vector)
85//!             .field("learning_rate", &self.learning_rate)
86//!             .field("name", &self.name)
87//!     }
88//!
89//!     fn from_deserializer(
90//!         deserializer: &mut train_station::serialization::StructDeserializer,
91//!     ) -> train_station::serialization::SerializationResult<Self> {
92//!         Ok(ModelWeights {
93//!             weight_matrix: deserializer.field("weight_matrix")?,
94//!             bias_vector: deserializer.field("bias_vector")?,
95//!             learning_rate: deserializer.field("learning_rate")?,
96//!             name: deserializer.field("name")?,
97//!         })
98//!     }
99//! }
100//!
101//! // Create test struct with tensors
102//! let mut weights = ModelWeights {
103//!     weight_matrix: Tensor::zeros(vec![10, 5]),
104//!     bias_vector: Tensor::ones(vec![5]).with_requires_grad(),
105//!     learning_rate: 0.001,
106//!     name: "test_model".to_string(),
107//! };
108//!
109//! // Set some values
110//! weights.weight_matrix.set(&[0, 0], 0.5);
111//! weights.bias_vector.set(&[2], 2.0);
112//!
113//! // Test JSON serialization
114//! let json = weights.to_json().unwrap();
115//! let loaded_weights = ModelWeights::from_json(&json).unwrap();
116//!
117//! assert_eq!(weights.learning_rate, loaded_weights.learning_rate);
118//! assert_eq!(weights.name, loaded_weights.name);
119//! assert_eq!(
120//!     weights.weight_matrix.shape().dims(),
121//!     loaded_weights.weight_matrix.shape().dims()
122//! );
123//! ```
124//!
125//! ## Large Tensor Serialization
126//!
127//! ```rust
128//! use train_station::Tensor;
129//! use train_station::serialization::StructSerializable;
130//!
131//! // Create large tensor
132//! let mut tensor = Tensor::zeros(vec![100, 100]).with_requires_grad();
133//!
134//! // Set some values
135//! for i in 0..10 {
136//!     for j in 0..10 {
137//!         tensor.set(&[i, j], (i * 10 + j) as f32);
138//!     }
139//! }
140//!
141//! // Binary serialization is more efficient for large tensors
142//! let binary = tensor.to_binary().unwrap();
143//! let loaded_tensor = Tensor::from_binary(&binary).unwrap();
144//!
145//! // Verify properties
146//! assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
147//! assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
148//!
149//! // Verify data integrity
150//! for i in 0..10 {
151//!     for j in 0..10 {
152//!         assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
153//!     }
154//! }
155//! ```
156//!
157//! ## Error Handling
158//!
159//! ```rust
160//! use train_station::Tensor;
161//! use train_station::serialization::StructSerializable;
162//!
163//! // Test invalid serialization data
164//! let invalid_json = r#"{"invalid": "data"}"#;
165//! let result = Tensor::from_json(invalid_json);
166//! assert!(result.is_err());
167//!
168//! // Test empty binary data
169//! let empty_binary = vec![];
170//! let result = Tensor::from_binary(&empty_binary);
171//! assert!(result.is_err());
172//! ```
173//!
174//! # Performance Characteristics
175//!
176//! - **Binary Format**: Optimized for size and speed
177//! - **JSON Format**: Human-readable with reasonable performance
//! - **Single Copy on Write-Out**: Tensor elements are copied into a `Vec<f32>` before encoding
//! - **Not Zero-Copy**: Deserialization allocates a new tensor and copies the decoded elements into it
//! - **Runtime Validation**: Malformed shapes, devices, and data lengths are reported as `SerializationError`
181
182use std::collections::HashMap;
183
184use crate::device::{Device, DeviceType};
185use crate::serialization::{
186    FieldValue, FromFieldValue, SerializationError, SerializationResult, StructDeserializer,
187    StructSerializable, StructSerializer, ToFieldValue,
188};
189use crate::tensor::{Shape, Tensor};
190
191// ===== Device Serialization =====
192
193impl ToFieldValue for DeviceType {
194    /// Convert DeviceType to FieldValue for serialization
195    ///
196    /// # Returns
197    ///
198    /// Enum FieldValue with variant name (proper enum serialization)
199    fn to_field_value(&self) -> FieldValue {
200        match self {
201            DeviceType::Cpu => FieldValue::from_enum_unit("Cpu".to_string()),
202            DeviceType::Cuda => FieldValue::from_enum_unit("Cuda".to_string()),
203        }
204    }
205}
206
207impl FromFieldValue for DeviceType {
208    /// Convert FieldValue to DeviceType for deserialization
209    ///
210    /// # Arguments
211    ///
212    /// * `value` - FieldValue containing enum data
213    /// * `field_name` - Name of the field for error reporting
214    ///
215    /// # Returns
216    ///
217    /// DeviceType enum value or error if invalid
218    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
219        let (variant, data) =
220            value
221                .as_enum()
222                .map_err(|_| SerializationError::ValidationFailed {
223                    field: field_name.to_string(),
224                    message: "Expected enum value for device type".to_string(),
225                })?;
226
227        // Ensure no data for unit variants
228        if data.is_some() {
229            return Err(SerializationError::ValidationFailed {
230                field: field_name.to_string(),
231                message: "DeviceType variants should not have associated data".to_string(),
232            });
233        }
234
235        match variant {
236            "Cpu" => Ok(DeviceType::Cpu),
237            "Cuda" => Ok(DeviceType::Cuda),
238            _ => Err(SerializationError::ValidationFailed {
239                field: field_name.to_string(),
240                message: format!("Unknown device type variant: {}", variant),
241            }),
242        }
243    }
244}
245
246impl ToFieldValue for Device {
247    /// Convert Device to FieldValue for serialization
248    ///
249    /// # Returns
250    ///
251    /// Object containing device type and index
252    fn to_field_value(&self) -> FieldValue {
253        let mut object = HashMap::new();
254        object.insert("type".to_string(), self.device_type().to_field_value());
255        object.insert("index".to_string(), self.index().to_field_value());
256        FieldValue::from_object(object)
257    }
258}
259
260impl FromFieldValue for Device {
261    /// Convert FieldValue to Device for deserialization
262    ///
263    /// # Arguments
264    ///
265    /// * `value` - FieldValue containing device object
266    /// * `field_name` - Name of the field for error reporting
267    ///
268    /// # Returns
269    ///
270    /// Device instance or error if invalid
271    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
272        let object = value
273            .as_object()
274            .map_err(|_| SerializationError::ValidationFailed {
275                field: field_name.to_string(),
276                message: "Expected object for device".to_string(),
277            })?;
278
279        let device_type = object
280            .get("type")
281            .ok_or_else(|| SerializationError::ValidationFailed {
282                field: field_name.to_string(),
283                message: "Missing device type field".to_string(),
284            })?
285            .clone();
286
287        let index = object
288            .get("index")
289            .ok_or_else(|| SerializationError::ValidationFailed {
290                field: field_name.to_string(),
291                message: "Missing device index field".to_string(),
292            })?
293            .clone();
294
295        let device_type = DeviceType::from_field_value(device_type, "type")?;
296        let index = usize::from_field_value(index, "index")?;
297
298        match device_type {
299            DeviceType::Cpu => Ok(Device::cpu()),
300            DeviceType::Cuda => Ok(Device::cuda(index)),
301        }
302    }
303}
304
305// ===== Memory Layout Serialization =====
306
307impl ToFieldValue for crate::tensor::MemoryLayout {
308    /// Convert MemoryLayout to FieldValue for serialization
309    ///
310    /// # Returns
311    ///
312    /// Enum FieldValue with variant name (proper enum serialization)
313    fn to_field_value(&self) -> FieldValue {
314        match self {
315            crate::tensor::MemoryLayout::Contiguous => {
316                FieldValue::from_enum_unit("Contiguous".to_string())
317            }
318            crate::tensor::core::MemoryLayout::Strided => {
319                FieldValue::from_enum_unit("Strided".to_string())
320            }
321            crate::tensor::core::MemoryLayout::View => {
322                FieldValue::from_enum_unit("View".to_string())
323            }
324        }
325    }
326}
327
328impl FromFieldValue for crate::tensor::MemoryLayout {
329    /// Convert FieldValue to MemoryLayout for deserialization
330    ///
331    /// # Arguments
332    ///
333    /// * `value` - FieldValue containing enum data
334    /// * `field_name` - Name of the field for error reporting
335    ///
336    /// # Returns
337    ///
338    /// MemoryLayout enum value or error if invalid
339    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
340        let (variant, data) =
341            value
342                .as_enum()
343                .map_err(|_| SerializationError::ValidationFailed {
344                    field: field_name.to_string(),
345                    message: "Expected enum value for memory layout".to_string(),
346                })?;
347
348        // Ensure no data for unit variants
349        if data.is_some() {
350            return Err(SerializationError::ValidationFailed {
351                field: field_name.to_string(),
352                message: "MemoryLayout variants should not have associated data".to_string(),
353            });
354        }
355
356        match variant {
357            "Contiguous" => Ok(crate::tensor::MemoryLayout::Contiguous),
358            "Strided" => Ok(crate::tensor::MemoryLayout::Strided),
359            "View" => Ok(crate::tensor::MemoryLayout::View),
360            _ => Err(SerializationError::ValidationFailed {
361                field: field_name.to_string(),
362                message: format!("Unknown memory layout variant: {}", variant),
363            }),
364        }
365    }
366}
367
368// ===== Shape Serialization =====
369
370impl ToFieldValue for Shape {
371    /// Convert Shape to FieldValue for serialization
372    ///
373    /// # Returns
374    ///
375    /// Object containing all shape metadata
376    fn to_field_value(&self) -> FieldValue {
377        let mut object = HashMap::new();
378        object.insert("dims".to_string(), self.dims().to_vec().to_field_value());
379        object.insert("size".to_string(), self.size().to_field_value());
380        object.insert(
381            "strides".to_string(),
382            self.strides().to_vec().to_field_value(),
383        );
384        object.insert("layout".to_string(), self.layout().to_field_value());
385        FieldValue::from_object(object)
386    }
387}
388
389impl FromFieldValue for Shape {
390    /// Convert FieldValue to Shape for deserialization
391    ///
392    /// # Arguments
393    ///
394    /// * `value` - FieldValue containing shape object
395    /// * `field_name` - Name of the field for error reporting
396    ///
397    /// # Returns
398    ///
399    /// Shape instance or error if invalid
400    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
401        let object = value
402            .as_object()
403            .map_err(|_| SerializationError::ValidationFailed {
404                field: field_name.to_string(),
405                message: "Expected object for shape".to_string(),
406            })?;
407
408        let dims = object
409            .get("dims")
410            .ok_or_else(|| SerializationError::ValidationFailed {
411                field: field_name.to_string(),
412                message: "Missing dims field in shape".to_string(),
413            })?
414            .clone();
415
416        let size = object
417            .get("size")
418            .ok_or_else(|| SerializationError::ValidationFailed {
419                field: field_name.to_string(),
420                message: "Missing size field in shape".to_string(),
421            })?
422            .clone();
423
424        let strides = object
425            .get("strides")
426            .ok_or_else(|| SerializationError::ValidationFailed {
427                field: field_name.to_string(),
428                message: "Missing strides field in shape".to_string(),
429            })?
430            .clone();
431
432        let layout = object
433            .get("layout")
434            .ok_or_else(|| SerializationError::ValidationFailed {
435                field: field_name.to_string(),
436                message: "Missing layout field in shape".to_string(),
437            })?
438            .clone();
439
440        let dims = Vec::<usize>::from_field_value(dims, "dims")?;
441        let size = usize::from_field_value(size, "size")?;
442        let strides = Vec::<usize>::from_field_value(strides, "strides")?;
443        let layout = crate::tensor::MemoryLayout::from_field_value(layout, "layout")?;
444
445        // Validate consistency
446        let expected_size: usize = dims.iter().product();
447        if size != expected_size {
448            return Err(SerializationError::ValidationFailed {
449                field: field_name.to_string(),
450                message: format!(
451                    "Shape size {} doesn't match computed size {}",
452                    size, expected_size
453                ),
454            });
455        }
456
457        if dims.len() != strides.len() {
458            return Err(SerializationError::ValidationFailed {
459                field: field_name.to_string(),
460                message: "Dimensions and strides must have same length".to_string(),
461            });
462        }
463
464        // Use the appropriate constructor based on layout
465        let mut shape = match layout {
466            crate::tensor::MemoryLayout::Contiguous => Shape::new(dims),
467            _ => Shape::with_strides(dims, strides),
468        };
469
470        // For view layouts, we need to set it as a view
471        if matches!(layout, crate::tensor::MemoryLayout::View) {
472            shape = Shape::as_view(shape.dims().to_vec(), shape.strides().to_vec());
473        }
474
475        Ok(shape)
476    }
477}
478
479// ===== Tensor Serialization =====
480
481impl StructSerializable for Tensor {
482    /// Convert Tensor to StructSerializer for serialization
483    ///
484    /// Serializes tensor data, shape, device, and gradtrack state.
485    /// Runtime state (id, grad, grad_fn, allocation_owner) is not serialized.
486    ///
487    /// # Returns
488    ///
489    /// StructSerializer containing all persistent tensor state
490    fn to_serializer(&self) -> StructSerializer {
491        // Extract tensor data as Vec<f32> - now uses efficient FieldValue implementation:
492        // - JSON format: Human-readable arrays of numbers
493        // - Binary format: Efficient byte representation with length header
494        let data: Vec<f32> =
495            unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size()).to_vec() };
496
497        StructSerializer::new()
498            .field("data", &data)
499            .field("shape", self.shape())
500            .field("device", &self.device())
501            .field("requires_grad", &self.requires_grad())
502    }
503
504    /// Create Tensor from StructDeserializer
505    ///
506    /// Reconstructs tensor from serialized data, shape, device, and gradtrack state.
507    /// Allocates new memory and generates new tensor ID.
508    ///
509    /// # Arguments
510    ///
511    /// * `deserializer` - StructDeserializer containing tensor data
512    ///
513    /// # Returns
514    ///
515    /// Reconstructed Tensor instance or error if deserialization fails
516    fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
517        let data: Vec<f32> = deserializer.field("data")?;
518        let shape: Shape = deserializer.field("shape")?;
519        let device: Device = deserializer.field("device")?;
520        let requires_grad: bool = deserializer.field("requires_grad")?;
521
522        // Validate data size matches shape
523        if data.len() != shape.size() {
524            return Err(SerializationError::ValidationFailed {
525                field: "tensor".to_string(),
526                message: format!(
527                    "Data length {} doesn't match shape size {}",
528                    data.len(),
529                    shape.size()
530                ),
531            });
532        }
533
534        // Create new tensor with the deserialized shape on the correct device
535        let mut tensor = Tensor::new_on_device(shape.dims().to_vec(), device);
536
537        // Copy data into tensor
538        if !data.is_empty() {
539            unsafe {
540                let dst = tensor.as_mut_ptr();
541                std::ptr::copy_nonoverlapping(data.as_ptr(), dst, data.len());
542            }
543        }
544
545        // Set gradtrack state
546        tensor.set_requires_grad(requires_grad);
547
548        // Validate that the reconstructed shape matches
549        if tensor.shape().dims() != shape.dims()
550            || tensor.shape().size() != shape.size()
551            || tensor.shape().strides() != shape.strides()
552        {
553            return Err(SerializationError::ValidationFailed {
554                field: "tensor".to_string(),
555                message: "Reconstructed tensor shape doesn't match serialized shape".to_string(),
556            });
557        }
558
559        Ok(tensor)
560    }
561}
562
563impl FromFieldValue for Tensor {
564    /// Convert FieldValue to Tensor for use as struct field
565    ///
566    /// # Arguments
567    ///
568    /// * `value` - FieldValue containing tensor data
569    /// * `field_name` - Name of the field for error reporting
570    ///
571    /// # Returns
572    ///
573    /// Tensor instance or error if deserialization fails
574    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
575        // Try binary object first (for when serialized as binary)
576        if let Ok(binary_data) = value.as_binary_object() {
577            return Tensor::from_binary(binary_data).map_err(|e| {
578                SerializationError::ValidationFailed {
579                    field: field_name.to_string(),
580                    message: format!("Failed to deserialize tensor from binary: {}", e),
581                }
582            });
583        }
584
585        // Try JSON object (for when serialized as JSON)
586        if let Ok(json_data) = value.as_json_object() {
587            return Tensor::from_json(json_data).map_err(|e| {
588                SerializationError::ValidationFailed {
589                    field: field_name.to_string(),
590                    message: format!("Failed to deserialize tensor from JSON: {}", e),
591                }
592            });
593        }
594
595        // Try object (for when serialized as structured object in JSON)
596        if let Ok(object) = value.as_object() {
597            // Convert object back to deserializer and use StructSerializable
598            let mut deserializer = StructDeserializer::from_fields(object.clone());
599            return Tensor::from_deserializer(&mut deserializer).map_err(|e| {
600                SerializationError::ValidationFailed {
601                    field: field_name.to_string(),
602                    message: format!("Failed to deserialize tensor from object: {}", e),
603                }
604            });
605        }
606
607        Err(SerializationError::ValidationFailed {
608            field: field_name.to_string(),
609            message: "Expected binary object, JSON object, or structured object for tensor field"
610                .to_string(),
611        })
612    }
613}
614
615// ===== Serializable Trait Implementation =====
616
617impl crate::serialization::Serializable for Tensor {
618    /// Serialize the tensor to JSON format
619    ///
620    /// This method converts the tensor into a human-readable JSON string representation
621    /// that includes all tensor data, shape information, device placement, and gradtrack state.
622    /// The JSON format is suitable for debugging, configuration files, and cross-language
623    /// interoperability.
624    ///
625    /// # Returns
626    ///
627    /// JSON string representation of the tensor on success, or `SerializationError` on failure
628    ///
629    /// # Examples
630    ///
631    /// ```
632    /// use train_station::Tensor;
633    /// use train_station::serialization::Serializable;
634    ///
635    /// let mut tensor = Tensor::zeros(vec![2, 3]);
636    /// tensor.set(&[0, 0], 1.0);
637    /// tensor.set(&[1, 2], 5.0);
638    ///
639    /// let json = tensor.to_json().unwrap();
640    /// assert!(!json.is_empty());
641    /// assert!(json.contains("data"));
642    /// assert!(json.contains("shape"));
643    /// ```
644    fn to_json(&self) -> SerializationResult<String> {
645        StructSerializable::to_json(self)
646    }
647
648    /// Deserialize a tensor from JSON format
649    ///
650    /// This method parses a JSON string and reconstructs a tensor with all its data,
651    /// shape information, device placement, and gradtrack state. The JSON must contain
652    /// all necessary fields in the expected format.
653    ///
654    /// # Arguments
655    ///
656    /// * `json` - JSON string containing serialized tensor data
657    ///
658    /// # Returns
659    ///
660    /// The deserialized tensor on success, or `SerializationError` on failure
661    ///
662    /// # Examples
663    ///
664    /// ```
665    /// use train_station::Tensor;
666    /// use train_station::serialization::Serializable;
667    ///
668    /// let mut original = Tensor::ones(vec![2, 2]);
669    /// original.set(&[0, 1], 3.0);
670    /// original.set_requires_grad(true);
671    ///
672    /// let json = original.to_json().unwrap();
673    /// let restored = Tensor::from_json(&json).unwrap();
674    ///
675    /// assert_eq!(original.shape().dims(), restored.shape().dims());
676    /// assert_eq!(original.get(&[0, 1]), restored.get(&[0, 1]));
677    /// assert_eq!(original.requires_grad(), restored.requires_grad());
678    /// ```
679    fn from_json(json: &str) -> SerializationResult<Self> {
680        StructSerializable::from_json(json)
681    }
682
683    /// Serialize the tensor to binary format
684    ///
685    /// This method converts the tensor into a compact binary representation optimized
686    /// for storage and transmission. The binary format provides maximum performance
687    /// and minimal file sizes, making it ideal for large tensors and production use.
688    ///
689    /// # Returns
690    ///
691    /// Binary representation of the tensor on success, or `SerializationError` on failure
692    ///
693    /// # Examples
694    ///
695    /// ```
696    /// use train_station::Tensor;
697    /// use train_station::serialization::Serializable;
698    ///
699    /// let mut tensor = Tensor::zeros(vec![100, 100]);
700    /// for i in 0..10 {
701    ///     tensor.set(&[i, i], i as f32);
702    /// }
703    ///
704    /// let binary = tensor.to_binary().unwrap();
705    /// assert!(!binary.is_empty());
706    /// // Binary format is more compact than JSON for large tensors
707    /// ```
708    fn to_binary(&self) -> SerializationResult<Vec<u8>> {
709        StructSerializable::to_binary(self)
710    }
711
712    /// Deserialize a tensor from binary format
713    ///
714    /// This method parses binary data and reconstructs a tensor with all its data,
715    /// shape information, device placement, and gradtrack state. The binary data
716    /// must contain complete serialized information in the expected format.
717    ///
718    /// # Arguments
719    ///
720    /// * `data` - Binary data containing serialized tensor information
721    ///
722    /// # Returns
723    ///
724    /// The deserialized tensor on success, or `SerializationError` on failure
725    ///
726    /// # Examples
727    ///
728    /// ```
729    /// use train_station::Tensor;
730    /// use train_station::serialization::Serializable;
731    ///
732    /// let mut original = Tensor::ones(vec![3, 4]);
733    /// original.set(&[2, 3], 7.5);
734    /// original.set_requires_grad(true);
735    ///
736    /// let binary = original.to_binary().unwrap();
737    /// let restored = Tensor::from_binary(&binary).unwrap();
738    ///
739    /// assert_eq!(original.shape().dims(), restored.shape().dims());
740    /// assert_eq!(original.get(&[2, 3]), restored.get(&[2, 3]));
741    /// assert_eq!(original.requires_grad(), restored.requires_grad());
742    /// ```
743    fn from_binary(data: &[u8]) -> SerializationResult<Self> {
744        StructSerializable::from_binary(data)
745    }
746}
747
748#[cfg(test)]
749mod tests {
750    //! Comprehensive tests for tensor serialization functionality
751    //!
752    //! Tests cover all serialization formats and usage patterns including:
753    //! - JSON and binary roundtrip serialization
754    //! - Tensor as field within structs  
755    //! - Edge cases and error conditions
756    //! - Device and shape serialization
757    //! - Large tensor serialization
758
759    use super::*;
760
761    // ===== Device Serialization Tests =====
762
763    #[test]
764    fn test_device_type_serialization() {
765        // Test CPU device type
766        let cpu_type = DeviceType::Cpu;
767        let field_value = cpu_type.to_field_value();
768        let deserialized = DeviceType::from_field_value(field_value, "device_type").unwrap();
769        assert_eq!(cpu_type, deserialized);
770
771        // Test CUDA device type
772        let cuda_type = DeviceType::Cuda;
773        let field_value = cuda_type.to_field_value();
774        let deserialized = DeviceType::from_field_value(field_value, "device_type").unwrap();
775        assert_eq!(cuda_type, deserialized);
776    }
777
778    #[test]
779    fn test_device_serialization() {
780        // Test CPU device
781        let cpu_device = Device::cpu();
782        let field_value = cpu_device.to_field_value();
783        let deserialized = Device::from_field_value(field_value, "device").unwrap();
784        assert_eq!(cpu_device, deserialized);
785        assert!(deserialized.is_cpu());
786        assert_eq!(deserialized.index(), 0);
787    }
788
789    #[test]
790    fn test_device_serialization_errors() {
791        // Test invalid device type
792        let invalid_device_type = FieldValue::from_string("invalid".to_string());
793        let result = DeviceType::from_field_value(invalid_device_type, "device_type");
794        assert!(result.is_err());
795
796        // Test missing device fields
797        let incomplete_device = FieldValue::from_object({
798            let mut obj = HashMap::new();
799            obj.insert(
800                "type".to_string(),
801                FieldValue::from_string("cpu".to_string()),
802            );
803            // Missing index field
804            obj
805        });
806        let result = Device::from_field_value(incomplete_device, "device");
807        assert!(result.is_err());
808    }
809
810    // ===== Shape Serialization Tests =====
811
812    #[test]
813    fn test_memory_layout_serialization() {
814        use crate::tensor::MemoryLayout;
815
816        let layouts = [
817            MemoryLayout::Contiguous,
818            MemoryLayout::Strided,
819            MemoryLayout::View,
820        ];
821
822        for layout in &layouts {
823            let field_value = layout.to_field_value();
824            let deserialized = MemoryLayout::from_field_value(field_value, "layout").unwrap();
825            assert_eq!(*layout, deserialized);
826        }
827    }
828
829    #[test]
830    fn test_shape_serialization() {
831        // Test contiguous shape
832        let shape = Shape::new(vec![2, 3, 4]);
833        let field_value = shape.to_field_value();
834        let deserialized = Shape::from_field_value(field_value, "shape").unwrap();
835        assert_eq!(shape, deserialized);
836        assert_eq!(deserialized.dims(), vec![2, 3, 4]);
837        assert_eq!(deserialized.size(), 24);
838        assert_eq!(deserialized.strides(), vec![12, 4, 1]);
839
840        // Test strided shape
841        let strided_shape = Shape::with_strides(vec![2, 3], vec![6, 2]);
842        let field_value = strided_shape.to_field_value();
843        let deserialized = Shape::from_field_value(field_value, "shape").unwrap();
844        assert_eq!(strided_shape, deserialized);
845    }
846
847    #[test]
848    fn test_shape_validation_errors() {
849        use crate::tensor::MemoryLayout;
850
851        // Test inconsistent size
852        let invalid_shape = FieldValue::from_object({
853            let mut obj = HashMap::new();
854            obj.insert("dims".to_string(), vec![2usize, 3].to_field_value());
855            obj.insert("size".to_string(), 10usize.to_field_value()); // Should be 6
856            obj.insert("strides".to_string(), vec![3usize, 1].to_field_value());
857            obj.insert(
858                "layout".to_string(),
859                MemoryLayout::Contiguous.to_field_value(),
860            );
861            obj
862        });
863        let result = Shape::from_field_value(invalid_shape, "shape");
864        assert!(result.is_err());
865
866        // Test mismatched dimensions and strides
867        let invalid_shape = FieldValue::from_object({
868            let mut obj = HashMap::new();
869            obj.insert("dims".to_string(), vec![2usize, 3].to_field_value());
870            obj.insert("size".to_string(), 6usize.to_field_value());
871            obj.insert("strides".to_string(), vec![3usize].to_field_value()); // Wrong length
872            obj.insert(
873                "layout".to_string(),
874                MemoryLayout::Contiguous.to_field_value(),
875            );
876            obj
877        });
878        let result = Shape::from_field_value(invalid_shape, "shape");
879        assert!(result.is_err());
880    }
881
882    // ===== Tensor Serialization Tests =====
883
884    #[test]
885    fn test_tensor_json_roundtrip() {
886        // Create test tensor with data
887        let mut tensor = Tensor::zeros(vec![2, 3]);
888        tensor.set(&[0, 0], 1.0);
889        tensor.set(&[0, 1], 2.0);
890        tensor.set(&[0, 2], 3.0);
891        tensor.set(&[1, 0], 4.0);
892        tensor.set(&[1, 1], 5.0);
893        tensor.set(&[1, 2], 6.0);
894        tensor.set_requires_grad(true);
895
896        // Serialize to JSON
897        let json = tensor.to_json().unwrap();
898        assert!(!json.is_empty());
899
900        // Deserialize from JSON
901        let loaded_tensor = Tensor::from_json(&json).unwrap();
902
903        // Verify tensor properties
904        assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
905        assert_eq!(tensor.size(), loaded_tensor.size());
906        assert_eq!(tensor.device(), loaded_tensor.device());
907        assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
908
909        // Verify tensor data
910        for i in 0..2 {
911            for j in 0..3 {
912                assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
913            }
914        }
915    }
916
917    #[test]
918    fn test_tensor_binary_roundtrip() {
919        // Create test tensor with gradient tracking
920        let mut tensor = Tensor::ones(vec![3, 4]).with_requires_grad();
921
922        // Modify some values
923        tensor.set(&[0, 0], 10.0);
924        tensor.set(&[1, 2], 20.0);
925        tensor.set(&[2, 3], 30.0);
926
927        // Serialize to binary
928        let binary = tensor.to_binary().unwrap();
929        assert!(!binary.is_empty());
930
931        // Deserialize from binary
932        let loaded_tensor = Tensor::from_binary(&binary).unwrap();
933
934        // Verify tensor properties
935        assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
936        assert_eq!(tensor.size(), loaded_tensor.size());
937        assert_eq!(tensor.device(), loaded_tensor.device());
938        assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
939
940        // Verify tensor data
941        for i in 0..3 {
942            for j in 0..4 {
943                assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
944            }
945        }
946    }
947
948    #[test]
949    fn test_empty_tensor_serialization() {
950        // Test zero-sized tensor
951        let tensor = Tensor::new(vec![0]);
952
953        // JSON roundtrip
954        let json = tensor.to_json().unwrap();
955        let loaded_tensor = Tensor::from_json(&json).unwrap();
956        assert_eq!(tensor.size(), loaded_tensor.size());
957        assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
958
959        // Binary roundtrip
960        let binary = tensor.to_binary().unwrap();
961        let loaded_tensor = Tensor::from_binary(&binary).unwrap();
962        assert_eq!(tensor.size(), loaded_tensor.size());
963        assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
964    }
965
966    #[test]
967    fn test_large_tensor_serialization() {
968        // Test larger tensor
969        let mut tensor = Tensor::zeros(vec![100, 100]).with_requires_grad();
970
971        // Set some values
972        for i in 0..10 {
973            for j in 0..10 {
974                tensor.set(&[i, j], (i * 10 + j) as f32);
975            }
976        }
977
978        // Binary roundtrip (more efficient for large tensors)
979        let binary = tensor.to_binary().unwrap();
980        let loaded_tensor = Tensor::from_binary(&binary).unwrap();
981
982        // Verify properties
983        assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
984        assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
985
986        // Verify a subset of data
987        for i in 0..10 {
988            for j in 0..10 {
989                assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
990            }
991        }
992    }
993
994    #[test]
995    fn test_tensor_as_field_in_struct() {
996        // Define a struct containing tensors
997        #[derive(Debug)]
998        struct ModelWeights {
999            weight_matrix: Tensor,
1000            bias_vector: Tensor,
1001            learning_rate: f32,
1002            name: String,
1003        }
1004
1005        impl StructSerializable for ModelWeights {
1006            fn to_serializer(&self) -> StructSerializer {
1007                StructSerializer::new()
1008                    .field("weight_matrix", &self.weight_matrix)
1009                    .field("bias_vector", &self.bias_vector)
1010                    .field("learning_rate", &self.learning_rate)
1011                    .field("name", &self.name)
1012            }
1013
1014            fn from_deserializer(
1015                deserializer: &mut StructDeserializer,
1016            ) -> SerializationResult<Self> {
1017                Ok(ModelWeights {
1018                    weight_matrix: deserializer.field("weight_matrix")?,
1019                    bias_vector: deserializer.field("bias_vector")?,
1020                    learning_rate: deserializer.field("learning_rate")?,
1021                    name: deserializer.field("name")?,
1022                })
1023            }
1024        }
1025
1026        // Create test struct with tensors
1027        let mut weights = ModelWeights {
1028            weight_matrix: Tensor::zeros(vec![10, 5]),
1029            bias_vector: Tensor::ones(vec![5]).with_requires_grad(),
1030            learning_rate: 0.001,
1031            name: "test_model".to_string(),
1032        };
1033
1034        // Set some values
1035        weights.weight_matrix.set(&[0, 0], 0.5);
1036        weights.weight_matrix.set(&[9, 4], -0.3);
1037        weights.bias_vector.set(&[2], 2.0);
1038
1039        // Test JSON serialization
1040        let json = weights.to_json().unwrap();
1041        let loaded_weights = ModelWeights::from_json(&json).unwrap();
1042
1043        assert_eq!(weights.learning_rate, loaded_weights.learning_rate);
1044        assert_eq!(weights.name, loaded_weights.name);
1045        assert_eq!(
1046            weights.weight_matrix.shape().dims(),
1047            loaded_weights.weight_matrix.shape().dims()
1048        );
1049        assert_eq!(
1050            weights.bias_vector.shape().dims(),
1051            loaded_weights.bias_vector.shape().dims()
1052        );
1053        assert_eq!(
1054            weights.bias_vector.requires_grad(),
1055            loaded_weights.bias_vector.requires_grad()
1056        );
1057
1058        // Verify tensor data
1059        assert_eq!(
1060            weights.weight_matrix.get(&[0, 0]),
1061            loaded_weights.weight_matrix.get(&[0, 0])
1062        );
1063        assert_eq!(
1064            weights.weight_matrix.get(&[9, 4]),
1065            loaded_weights.weight_matrix.get(&[9, 4])
1066        );
1067        assert_eq!(
1068            weights.bias_vector.get(&[2]),
1069            loaded_weights.bias_vector.get(&[2])
1070        );
1071
1072        // Test binary serialization
1073        let binary = weights.to_binary().unwrap();
1074        let loaded_weights = ModelWeights::from_binary(&binary).unwrap();
1075
1076        assert_eq!(weights.learning_rate, loaded_weights.learning_rate);
1077        assert_eq!(weights.name, loaded_weights.name);
1078        assert_eq!(
1079            weights.weight_matrix.shape().dims(),
1080            loaded_weights.weight_matrix.shape().dims()
1081        );
1082        assert_eq!(
1083            weights.bias_vector.requires_grad(),
1084            loaded_weights.bias_vector.requires_grad()
1085        );
1086    }
1087
1088    #[test]
1089    fn test_multiple_tensors_in_struct() {
1090        // Test struct with multiple tensors of different shapes
1091        #[derive(Debug)]
1092        struct MultiTensorStruct {
1093            tensor_1d: Tensor,
1094            tensor_2d: Tensor,
1095            tensor_3d: Tensor,
1096            metadata: HashMap<String, String>,
1097        }
1098
1099        impl StructSerializable for MultiTensorStruct {
1100            fn to_serializer(&self) -> StructSerializer {
1101                StructSerializer::new()
1102                    .field("tensor_1d", &self.tensor_1d)
1103                    .field("tensor_2d", &self.tensor_2d)
1104                    .field("tensor_3d", &self.tensor_3d)
1105                    .field("metadata", &self.metadata)
1106            }
1107
1108            fn from_deserializer(
1109                deserializer: &mut StructDeserializer,
1110            ) -> SerializationResult<Self> {
1111                Ok(MultiTensorStruct {
1112                    tensor_1d: deserializer.field("tensor_1d")?,
1113                    tensor_2d: deserializer.field("tensor_2d")?,
1114                    tensor_3d: deserializer.field("tensor_3d")?,
1115                    metadata: deserializer.field("metadata")?,
1116                })
1117            }
1118        }
1119
1120        // Create test struct
1121        let mut multi_tensor = MultiTensorStruct {
1122            tensor_1d: Tensor::zeros(vec![5]),
1123            tensor_2d: Tensor::ones(vec![3, 4]).with_requires_grad(),
1124            tensor_3d: Tensor::zeros(vec![2, 2, 2]),
1125            metadata: {
1126                let mut map = HashMap::new();
1127                map.insert("version".to_string(), "1.0".to_string());
1128                map.insert("type".to_string(), "test".to_string());
1129                map
1130            },
1131        };
1132
1133        // Set some values
1134        multi_tensor.tensor_1d.set(&[0], 10.0);
1135        multi_tensor.tensor_2d.set(&[0, 0], 5.0);
1136        multi_tensor.tensor_3d.set(&[1, 1, 1], 3.0);
1137
1138        // Test JSON roundtrip
1139        let json = multi_tensor.to_json().unwrap();
1140        let loaded = MultiTensorStruct::from_json(&json).unwrap();
1141
1142        assert_eq!(
1143            multi_tensor.tensor_1d.shape().dims(),
1144            loaded.tensor_1d.shape().dims()
1145        );
1146        assert_eq!(
1147            multi_tensor.tensor_2d.shape().dims(),
1148            loaded.tensor_2d.shape().dims()
1149        );
1150        assert_eq!(
1151            multi_tensor.tensor_3d.shape().dims(),
1152            loaded.tensor_3d.shape().dims()
1153        );
1154        assert_eq!(
1155            multi_tensor.tensor_2d.requires_grad(),
1156            loaded.tensor_2d.requires_grad()
1157        );
1158        assert_eq!(multi_tensor.metadata, loaded.metadata);
1159
1160        // Verify tensor values
1161        assert_eq!(multi_tensor.tensor_1d.get(&[0]), loaded.tensor_1d.get(&[0]));
1162        assert_eq!(
1163            multi_tensor.tensor_2d.get(&[0, 0]),
1164            loaded.tensor_2d.get(&[0, 0])
1165        );
1166        assert_eq!(
1167            multi_tensor.tensor_3d.get(&[1, 1, 1]),
1168            loaded.tensor_3d.get(&[1, 1, 1])
1169        );
1170
1171        // Test binary roundtrip
1172        let binary = multi_tensor.to_binary().unwrap();
1173        let loaded = MultiTensorStruct::from_binary(&binary).unwrap();
1174        assert_eq!(
1175            multi_tensor.tensor_1d.shape().dims(),
1176            loaded.tensor_1d.shape().dims()
1177        );
1178        assert_eq!(
1179            multi_tensor.tensor_2d.requires_grad(),
1180            loaded.tensor_2d.requires_grad()
1181        );
1182    }
1183
1184    #[test]
1185    fn test_tensor_serialization_errors() {
1186        // Test invalid data size
1187        let mut deserializer = StructDeserializer::from_json(
1188            r#"
1189        {
1190            "data": [1.0, 2.0, 3.0],
1191            "shape": {
1192                "dims": [2, 3],
1193                "size": 6,
1194                "strides": [3, 1],
1195                "layout": "contiguous"
1196            },
1197            "device": {"type": "cpu", "index": 0},
1198            "requires_grad": false
1199        }"#,
1200        )
1201        .unwrap();
1202
1203        let result = Tensor::from_deserializer(&mut deserializer);
1204        assert!(result.is_err()); // Data length (3) doesn't match shape size (6)
1205    }
1206
1207    #[test]
1208    fn test_field_value_tensor_roundtrip() {
1209        // Test tensor as FieldValue
1210        let mut tensor = Tensor::zeros(vec![2, 2]);
1211        tensor.set(&[0, 0], 1.0);
1212        tensor.set(&[1, 1], 2.0);
1213
1214        let field_value = tensor.to_field_value();
1215        let loaded_tensor = Tensor::from_field_value(field_value, "test_tensor").unwrap();
1216
1217        assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
1218        assert_eq!(tensor.get(&[0, 0]), loaded_tensor.get(&[0, 0]));
1219        assert_eq!(tensor.get(&[1, 1]), loaded_tensor.get(&[1, 1]));
1220    }
1221
1222    #[test]
1223    fn test_different_tensor_shapes() {
1224        let test_shapes = vec![
1225            vec![1],          // Scalar
1226            vec![10],         // 1D vector
1227            vec![3, 4],       // 2D matrix
1228            vec![2, 3, 4],    // 3D tensor
1229            vec![2, 2, 2, 2], // 4D tensor
1230        ];
1231
1232        for shape in test_shapes {
1233            let tensor = Tensor::zeros(shape.clone()).with_requires_grad();
1234
1235            // JSON roundtrip
1236            let json = tensor.to_json().unwrap();
1237            let loaded = Tensor::from_json(&json).unwrap();
1238            assert_eq!(tensor.shape().dims(), loaded.shape().dims());
1239            assert_eq!(tensor.requires_grad(), loaded.requires_grad());
1240
1241            // Binary roundtrip
1242            let binary = tensor.to_binary().unwrap();
1243            let loaded = Tensor::from_binary(&binary).unwrap();
1244            assert_eq!(tensor.shape().dims(), loaded.shape().dims());
1245            assert_eq!(tensor.requires_grad(), loaded.requires_grad());
1246        }
1247    }
1248
1249    // ===== Serializable Trait Tests =====
1250
1251    #[test]
1252    fn test_serializable_json_methods() {
1253        // Create and populate test tensor
1254        let mut tensor = Tensor::zeros(vec![2, 3]);
1255        tensor.set(&[0, 0], 1.0);
1256        tensor.set(&[0, 1], 2.0);
1257        tensor.set(&[1, 2], 5.0);
1258        tensor.set_requires_grad(true);
1259
1260        // Test to_json method
1261        let json = <Tensor as crate::serialization::Serializable>::to_json(&tensor).unwrap();
1262        assert!(!json.is_empty());
1263        assert!(json.contains("data"));
1264        assert!(json.contains("shape"));
1265        assert!(json.contains("device"));
1266        assert!(json.contains("requires_grad"));
1267
1268        // Test from_json method
1269        let restored = <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1270        assert_eq!(tensor.shape().dims(), restored.shape().dims());
1271        assert_eq!(tensor.size(), restored.size());
1272        assert_eq!(tensor.device(), restored.device());
1273        assert_eq!(tensor.requires_grad(), restored.requires_grad());
1274
1275        // Verify tensor data
1276        assert_eq!(tensor.get(&[0, 0]), restored.get(&[0, 0]));
1277        assert_eq!(tensor.get(&[0, 1]), restored.get(&[0, 1]));
1278        assert_eq!(tensor.get(&[1, 2]), restored.get(&[1, 2]));
1279    }
1280
1281    #[test]
1282    fn test_serializable_binary_methods() {
1283        // Create and populate test tensor
1284        let mut tensor = Tensor::ones(vec![3, 4]);
1285        tensor.set(&[0, 0], 10.0);
1286        tensor.set(&[1, 2], 20.0);
1287        tensor.set(&[2, 3], 30.0);
1288        tensor.set_requires_grad(true);
1289
1290        // Test to_binary method
1291        let binary = <Tensor as crate::serialization::Serializable>::to_binary(&tensor).unwrap();
1292        assert!(!binary.is_empty());
1293
1294        // Test from_binary method
1295        let restored =
1296            <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1297        assert_eq!(tensor.shape().dims(), restored.shape().dims());
1298        assert_eq!(tensor.size(), restored.size());
1299        assert_eq!(tensor.device(), restored.device());
1300        assert_eq!(tensor.requires_grad(), restored.requires_grad());
1301
1302        // Verify tensor data
1303        assert_eq!(tensor.get(&[0, 0]), restored.get(&[0, 0]));
1304        assert_eq!(tensor.get(&[1, 2]), restored.get(&[1, 2]));
1305        assert_eq!(tensor.get(&[2, 3]), restored.get(&[2, 3]));
1306    }
1307
1308    #[test]
1309    fn test_serializable_file_io_json() {
1310        use crate::serialization::{Format, Serializable};
1311        use std::fs;
1312        use std::path::Path;
1313
1314        // Create test tensor
1315        let mut tensor = Tensor::zeros(vec![2, 2]);
1316        tensor.set(&[0, 0], 1.0);
1317        tensor.set(&[0, 1], 2.0);
1318        tensor.set(&[1, 0], 3.0);
1319        tensor.set(&[1, 1], 4.0);
1320        tensor.set_requires_grad(true);
1321
1322        // Test file paths
1323        let json_path = "test_tensor_serializable.json";
1324        let json_path_2 = "test_tensor_serializable_2.json";
1325
1326        // Cleanup any existing files
1327        let _ = fs::remove_file(json_path);
1328        let _ = fs::remove_file(json_path_2);
1329
1330        // Test save method with JSON format
1331        Serializable::save(&tensor, json_path, Format::Json).unwrap();
1332        assert!(Path::new(json_path).exists());
1333
1334        // Test load method with JSON format
1335        let loaded_tensor = Tensor::load(json_path, Format::Json).unwrap();
1336        assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
1337        assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1338        assert_eq!(tensor.get(&[0, 0]), loaded_tensor.get(&[0, 0]));
1339        assert_eq!(tensor.get(&[1, 1]), loaded_tensor.get(&[1, 1]));
1340
1341        // Test save_to_writer and load_from_reader
1342        {
1343            let mut writer = std::fs::File::create(json_path_2).unwrap();
1344            Serializable::save_to_writer(&tensor, &mut writer, Format::Json).unwrap();
1345        }
1346        assert!(Path::new(json_path_2).exists());
1347
1348        {
1349            let mut reader = std::fs::File::open(json_path_2).unwrap();
1350            let loaded_tensor = Tensor::load_from_reader(&mut reader, Format::Json).unwrap();
1351            assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
1352            assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1353            assert_eq!(tensor.get(&[0, 1]), loaded_tensor.get(&[0, 1]));
1354            assert_eq!(tensor.get(&[1, 0]), loaded_tensor.get(&[1, 0]));
1355        }
1356
1357        // Cleanup test files
1358        let _ = fs::remove_file(json_path);
1359        let _ = fs::remove_file(json_path_2);
1360    }
1361
1362    #[test]
1363    fn test_serializable_file_io_binary() {
1364        use crate::serialization::{Format, Serializable};
1365        use std::fs;
1366        use std::path::Path;
1367
1368        // Create test tensor
1369        let mut tensor = Tensor::ones(vec![3, 3]);
1370        for i in 0..3 {
1371            for j in 0..3 {
1372                tensor.set(&[i, j], (i * 3 + j) as f32);
1373            }
1374        }
1375        tensor.set_requires_grad(true);
1376
1377        // Test file paths
1378        let binary_path = "test_tensor_serializable.bin";
1379        let binary_path_2 = "test_tensor_serializable_2.bin";
1380
1381        // Cleanup any existing files
1382        let _ = fs::remove_file(binary_path);
1383        let _ = fs::remove_file(binary_path_2);
1384
1385        // Test save method with binary format
1386        Serializable::save(&tensor, binary_path, Format::Binary).unwrap();
1387        assert!(Path::new(binary_path).exists());
1388
1389        // Test load method with binary format
1390        let loaded_tensor = Tensor::load(binary_path, Format::Binary).unwrap();
1391        assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
1392        assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1393
1394        // Verify all data
1395        for i in 0..3 {
1396            for j in 0..3 {
1397                assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
1398            }
1399        }
1400
1401        // Test save_to_writer and load_from_reader
1402        {
1403            let mut writer = std::fs::File::create(binary_path_2).unwrap();
1404            Serializable::save_to_writer(&tensor, &mut writer, Format::Binary).unwrap();
1405        }
1406        assert!(Path::new(binary_path_2).exists());
1407
1408        {
1409            let mut reader = std::fs::File::open(binary_path_2).unwrap();
1410            let loaded_tensor = Tensor::load_from_reader(&mut reader, Format::Binary).unwrap();
1411            assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
1412            assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1413
1414            // Verify all data
1415            for i in 0..3 {
1416                for j in 0..3 {
1417                    assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
1418                }
1419            }
1420        }
1421
1422        // Cleanup test files
1423        let _ = fs::remove_file(binary_path);
1424        let _ = fs::remove_file(binary_path_2);
1425    }
1426
1427    #[test]
1428    fn test_serializable_large_tensor_performance() {
1429        // Create a large tensor to test performance characteristics
1430        let mut tensor = Tensor::zeros(vec![50, 50]);
1431        for i in 0..25 {
1432            for j in 0..25 {
1433                tensor.set(&[i, j], (i * 25 + j) as f32);
1434            }
1435        }
1436        tensor.set_requires_grad(true);
1437
1438        // Test JSON serialization
1439        let json = <Tensor as crate::serialization::Serializable>::to_json(&tensor).unwrap();
1440        assert!(!json.is_empty());
1441        let restored_json =
1442            <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1443        assert_eq!(tensor.shape().dims(), restored_json.shape().dims());
1444        assert_eq!(tensor.requires_grad(), restored_json.requires_grad());
1445
1446        // Test binary serialization
1447        let binary = <Tensor as crate::serialization::Serializable>::to_binary(&tensor).unwrap();
1448        assert!(!binary.is_empty());
1449        // Binary format should be efficient (this is informational, not a requirement)
1450        println!(
1451            "JSON size: {} bytes, Binary size: {} bytes",
1452            json.len(),
1453            binary.len()
1454        );
1455
1456        let restored_binary =
1457            <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1458        assert_eq!(tensor.shape().dims(), restored_binary.shape().dims());
1459        assert_eq!(tensor.requires_grad(), restored_binary.requires_grad());
1460
1461        // Verify a sample of data values
1462        for i in 0..5 {
1463            for j in 0..5 {
1464                assert_eq!(tensor.get(&[i, j]), restored_json.get(&[i, j]));
1465                assert_eq!(tensor.get(&[i, j]), restored_binary.get(&[i, j]));
1466            }
1467        }
1468    }
1469
1470    #[test]
1471    fn test_serializable_error_handling() {
1472        // Test invalid JSON
1473        let invalid_json = r#"{"invalid": "json", "structure": true}"#;
1474        let result = <Tensor as crate::serialization::Serializable>::from_json(invalid_json);
1475        assert!(result.is_err());
1476
1477        // Test empty JSON
1478        let empty_json = "{}";
1479        let result = <Tensor as crate::serialization::Serializable>::from_json(empty_json);
1480        assert!(result.is_err());
1481
1482        // Test invalid binary data
1483        let invalid_binary = vec![1, 2, 3, 4, 5];
1484        let result = <Tensor as crate::serialization::Serializable>::from_binary(&invalid_binary);
1485        assert!(result.is_err());
1486
1487        // Test empty binary data
1488        let empty_binary = vec![];
1489        let result = <Tensor as crate::serialization::Serializable>::from_binary(&empty_binary);
1490        assert!(result.is_err());
1491    }
1492
1493    #[test]
1494    fn test_serializable_different_shapes_and_types() {
1495        let test_cases = vec![
1496            // Scalar (1-element tensor)
1497            (vec![1], vec![42.0]),
1498            // 1D vector
1499            (vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]),
1500            // 2D matrix
1501            (vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
1502            // 3D tensor
1503            (vec![2, 2, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
1504        ];
1505
1506        for (shape, expected_data) in test_cases {
1507            // Create tensor with specific shape and data
1508            let mut tensor = Tensor::zeros(shape.clone());
1509
1510            // Set data based on shape dimensions
1511            match shape.len() {
1512                1 => {
1513                    for (i, &value) in expected_data.iter().enumerate().take(shape[0]) {
1514                        tensor.set(&[i], value);
1515                    }
1516                }
1517                2 => {
1518                    let mut idx = 0;
1519                    for i in 0..shape[0] {
1520                        for j in 0..shape[1] {
1521                            if idx < expected_data.len() {
1522                                tensor.set(&[i, j], expected_data[idx]);
1523                                idx += 1;
1524                            }
1525                        }
1526                    }
1527                }
1528                3 => {
1529                    let mut idx = 0;
1530                    for i in 0..shape[0] {
1531                        for j in 0..shape[1] {
1532                            for k in 0..shape[2] {
1533                                if idx < expected_data.len() {
1534                                    tensor.set(&[i, j, k], expected_data[idx]);
1535                                    idx += 1;
1536                                }
1537                            }
1538                        }
1539                    }
1540                }
1541                _ => {}
1542            }
1543            tensor.set_requires_grad(true);
1544
1545            // Test JSON roundtrip
1546            let json = <Tensor as crate::serialization::Serializable>::to_json(&tensor).unwrap();
1547            let restored_json =
1548                <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1549            assert_eq!(tensor.shape().dims(), restored_json.shape().dims());
1550            assert_eq!(tensor.requires_grad(), restored_json.requires_grad());
1551
1552            // Test binary roundtrip
1553            let binary =
1554                <Tensor as crate::serialization::Serializable>::to_binary(&tensor).unwrap();
1555            let restored_binary =
1556                <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1557            assert_eq!(tensor.shape().dims(), restored_binary.shape().dims());
1558            assert_eq!(tensor.requires_grad(), restored_binary.requires_grad());
1559
1560            // Verify data for first few elements
1561            match shape.len() {
1562                1 => {
1563                    for i in 0..shape[0].min(3).min(expected_data.len()) {
1564                        assert_eq!(tensor.get(&[i]), restored_json.get(&[i]));
1565                        assert_eq!(tensor.get(&[i]), restored_binary.get(&[i]));
1566                    }
1567                }
1568                2 => {
1569                    let mut count = 0;
1570                    for i in 0..shape[0] {
1571                        for j in 0..shape[1] {
1572                            if count < 3 && count < expected_data.len() {
1573                                assert_eq!(tensor.get(&[i, j]), restored_json.get(&[i, j]));
1574                                assert_eq!(tensor.get(&[i, j]), restored_binary.get(&[i, j]));
1575                                count += 1;
1576                            }
1577                        }
1578                    }
1579                }
1580                3 => {
1581                    let mut count = 0;
1582                    for i in 0..shape[0] {
1583                        for j in 0..shape[1] {
1584                            for k in 0..shape[2] {
1585                                if count < 3 && count < expected_data.len() {
1586                                    assert_eq!(
1587                                        tensor.get(&[i, j, k]),
1588                                        restored_json.get(&[i, j, k])
1589                                    );
1590                                    assert_eq!(
1591                                        tensor.get(&[i, j, k]),
1592                                        restored_binary.get(&[i, j, k])
1593                                    );
1594                                    count += 1;
1595                                }
1596                            }
1597                        }
1598                    }
1599                }
1600                _ => {}
1601            }
1602        }
1603    }
1604
1605    #[test]
1606    fn test_serializable_edge_cases() {
1607        // Test zero-sized tensor
1608        let zero_tensor = Tensor::new(vec![0]);
1609        let json = <Tensor as crate::serialization::Serializable>::to_json(&zero_tensor).unwrap();
1610        let restored = <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1611        assert_eq!(zero_tensor.shape().dims(), restored.shape().dims());
1612        assert_eq!(zero_tensor.size(), restored.size());
1613
1614        let binary =
1615            <Tensor as crate::serialization::Serializable>::to_binary(&zero_tensor).unwrap();
1616        let restored =
1617            <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1618        assert_eq!(zero_tensor.shape().dims(), restored.shape().dims());
1619        assert_eq!(zero_tensor.size(), restored.size());
1620
1621        // Test tensor with special values (use reasonable large values instead of f32::MAX/MIN)
1622        let mut special_tensor = Tensor::zeros(vec![3]);
1623        special_tensor.set(&[0], 0.0); // Zero
1624        special_tensor.set(&[1], 1000000.0); // Large positive value
1625        special_tensor.set(&[2], -1000000.0); // Large negative value
1626
1627        let json =
1628            <Tensor as crate::serialization::Serializable>::to_json(&special_tensor).unwrap();
1629        let restored = <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1630        assert_eq!(special_tensor.get(&[0]), restored.get(&[0]));
1631        assert_eq!(special_tensor.get(&[1]), restored.get(&[1]));
1632        assert_eq!(special_tensor.get(&[2]), restored.get(&[2]));
1633
1634        let binary =
1635            <Tensor as crate::serialization::Serializable>::to_binary(&special_tensor).unwrap();
1636        let restored =
1637            <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1638        assert_eq!(special_tensor.get(&[0]), restored.get(&[0]));
1639        assert_eq!(special_tensor.get(&[1]), restored.get(&[1]));
1640        assert_eq!(special_tensor.get(&[2]), restored.get(&[2]));
1641    }
1642}