// train_station/tensor/core/serialization.rs

1//! Tensor serialization implementation
2//!
3//! This module provides comprehensive serialization support for tensors, shapes,
4//! and device information using the Train Station serialization framework.
5//! It supports both JSON and binary formats with full roundtrip fidelity.
6//!
7//! # Key Features
8//!
9//! - **Complete Tensor Serialization**: Tensor data, shape, device, and gradtrack state
10//! - **Shape Serialization**: Dimensions, strides, size, and memory layout information
11//! - **Device Serialization**: Device type and index for CPU/CUDA placement
12//! - **Efficient Binary Format**: Optimized binary serialization for performance
13//! - **Human-Readable JSON**: JSON format for debugging and interoperability
14//! - **Roundtrip Fidelity**: Perfect reconstruction of tensors from serialized data
15//! - **Struct Field Support**: Tensors can be serialized as fields within larger structures
16//!
17//! # Architecture
18//!
19//! The serialization system handles:
20//! - **Tensor Data**: Serialized as `Vec<f32>` for efficiency
21//! - **Shape Information**: Complete shape metadata including strides and layout
22//! - **Device Placement**: Device type and index for proper reconstruction
23//! - **GradTrack State**: requires_grad flag (runtime gradient state not serialized)
24//!
25//! # Non-Serialized Fields
26//!
27//! The following tensor fields are NOT serialized as they are runtime state:
28//! - `data` (raw pointer): Reconstructed from serialized `Vec<f32>`
29//! - `id`: Regenerated during deserialization for uniqueness
30//! - `grad`: Runtime gradient state, not persistent
31//! - `grad_fn`: Runtime gradient function, not persistent
32//! - `allocation_owner`: Internal memory management, reconstructed
33//! - `_phantom`: Zero-sized type, no serialization needed
34//!
35//! # Usage Example
36//!
37//! ## Basic Tensor Serialization
38//!
39//! ```rust
40//! use train_station::Tensor;
41//! use train_station::serialization::StructSerializable;
42//!
43//! // Create and populate a tensor
44//! let mut tensor = Tensor::zeros(vec![2, 3, 4]).with_requires_grad();
45//! tensor.fill(42.0);
46//!
47//! // Serialize to JSON
48//! let json = tensor.to_json().unwrap();
49//! assert!(!json.is_empty());
50//!
51//! // Deserialize from JSON
52//! let loaded_tensor = Tensor::from_json(&json).unwrap();
53//! assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
54//! assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
55//!
56//! // Serialize to binary
57//! let binary = tensor.to_binary().unwrap();
58//! assert!(!binary.is_empty());
59//!
60//! // Deserialize from binary
61//! let loaded_tensor = Tensor::from_binary(&binary).unwrap();
62//! assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
63//! ```
64//!
65//! ## Tensor as Struct Field
66//!
67//! ```rust
68//! use train_station::Tensor;
69//! use train_station::serialization::StructSerializable;
70//!
71//! // Define a struct containing tensors
72//! #[derive(Debug)]
73//! struct ModelWeights {
74//!     weight_matrix: Tensor,
75//!     bias_vector: Tensor,
76//!     learning_rate: f32,
77//!     name: String,
78//! }
79//!
80//! impl StructSerializable for ModelWeights {
81//!     fn to_serializer(&self) -> train_station::serialization::StructSerializer {
82//!         train_station::serialization::StructSerializer::new()
83//!             .field("weight_matrix", &self.weight_matrix)
84//!             .field("bias_vector", &self.bias_vector)
85//!             .field("learning_rate", &self.learning_rate)
86//!             .field("name", &self.name)
87//!     }
88//!
89//!     fn from_deserializer(
90//!         deserializer: &mut train_station::serialization::StructDeserializer,
91//!     ) -> train_station::serialization::SerializationResult<Self> {
92//!         Ok(ModelWeights {
93//!             weight_matrix: deserializer.field("weight_matrix")?,
94//!             bias_vector: deserializer.field("bias_vector")?,
95//!             learning_rate: deserializer.field("learning_rate")?,
96//!             name: deserializer.field("name")?,
97//!         })
98//!     }
99//! }
100//!
101//! // Create test struct with tensors
102//! let mut weights = ModelWeights {
103//!     weight_matrix: Tensor::zeros(vec![10, 5]),
104//!     bias_vector: Tensor::ones(vec![5]).with_requires_grad(),
105//!     learning_rate: 0.001,
106//!     name: "test_model".to_string(),
107//! };
108//!
109//! // Set some values
110//! weights.weight_matrix.set(&[0, 0], 0.5);
111//! weights.bias_vector.set(&[2], 2.0);
112//!
113//! // Test JSON serialization
114//! let json = weights.to_json().unwrap();
115//! let loaded_weights = ModelWeights::from_json(&json).unwrap();
116//!
117//! assert_eq!(weights.learning_rate, loaded_weights.learning_rate);
118//! assert_eq!(weights.name, loaded_weights.name);
119//! assert_eq!(
120//!     weights.weight_matrix.shape().dims,
121//!     loaded_weights.weight_matrix.shape().dims
122//! );
123//! ```
124//!
125//! ## Large Tensor Serialization
126//!
127//! ```rust
128//! use train_station::Tensor;
129//! use train_station::serialization::StructSerializable;
130//!
131//! // Create large tensor
132//! let mut tensor = Tensor::zeros(vec![100, 100]).with_requires_grad();
133//!
134//! // Set some values
135//! for i in 0..10 {
136//!     for j in 0..10 {
137//!         tensor.set(&[i, j], (i * 10 + j) as f32);
138//!     }
139//! }
140//!
141//! // Binary serialization is more efficient for large tensors
142//! let binary = tensor.to_binary().unwrap();
143//! let loaded_tensor = Tensor::from_binary(&binary).unwrap();
144//!
145//! // Verify properties
146//! assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
147//! assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
148//!
149//! // Verify data integrity
150//! for i in 0..10 {
151//!     for j in 0..10 {
152//!         assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
153//!     }
154//! }
155//! ```
156//!
157//! ## Error Handling
158//!
159//! ```rust
160//! use train_station::Tensor;
161//! use train_station::serialization::StructSerializable;
162//!
163//! // Test invalid serialization data
164//! let invalid_json = r#"{"invalid": "data"}"#;
165//! let result = Tensor::from_json(invalid_json);
166//! assert!(result.is_err());
167//!
168//! // Test empty binary data
169//! let empty_binary = vec![];
170//! let result = Tensor::from_binary(&empty_binary);
171//! assert!(result.is_err());
172//! ```
173//!
174//! # Performance Characteristics
175//!
176//! - **Binary Format**: Optimized for size and speed
177//! - **JSON Format**: Human-readable with reasonable performance
178//! - **Memory Efficient**: Minimal overhead during serialization
//! - **Single Data Copy**: Tensor elements are copied once into a `Vec<f32>` before encoding (this is not a zero-copy path)
180//! - **Type Safety**: Compile-time guarantees for serialization correctness
181
182use std::collections::HashMap;
183
184use crate::device::{Device, DeviceType};
185use crate::serialization::{
186    FieldValue, FromFieldValue, SerializationError, SerializationResult, StructDeserializer,
187    StructSerializable, StructSerializer, ToFieldValue,
188};
189use crate::tensor::{Shape, Tensor};
190
191// ===== Device Serialization =====
192
193impl ToFieldValue for DeviceType {
194    /// Convert DeviceType to FieldValue for serialization
195    ///
196    /// # Returns
197    ///
198    /// Enum FieldValue with variant name (proper enum serialization)
199    fn to_field_value(&self) -> FieldValue {
200        match self {
201            DeviceType::Cpu => FieldValue::from_enum_unit("Cpu".to_string()),
202            DeviceType::Cuda => FieldValue::from_enum_unit("Cuda".to_string()),
203        }
204    }
205}
206
207impl FromFieldValue for DeviceType {
208    /// Convert FieldValue to DeviceType for deserialization
209    ///
210    /// # Arguments
211    ///
212    /// * `value` - FieldValue containing enum data
213    /// * `field_name` - Name of the field for error reporting
214    ///
215    /// # Returns
216    ///
217    /// DeviceType enum value or error if invalid
218    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
219        let (variant, data) =
220            value
221                .as_enum()
222                .map_err(|_| SerializationError::ValidationFailed {
223                    field: field_name.to_string(),
224                    message: "Expected enum value for device type".to_string(),
225                })?;
226
227        // Ensure no data for unit variants
228        if data.is_some() {
229            return Err(SerializationError::ValidationFailed {
230                field: field_name.to_string(),
231                message: "DeviceType variants should not have associated data".to_string(),
232            });
233        }
234
235        match variant {
236            "Cpu" => Ok(DeviceType::Cpu),
237            "Cuda" => Ok(DeviceType::Cuda),
238            _ => Err(SerializationError::ValidationFailed {
239                field: field_name.to_string(),
240                message: format!("Unknown device type variant: {}", variant),
241            }),
242        }
243    }
244}
245
246impl ToFieldValue for Device {
247    /// Convert Device to FieldValue for serialization
248    ///
249    /// # Returns
250    ///
251    /// Object containing device type and index
252    fn to_field_value(&self) -> FieldValue {
253        let mut object = HashMap::new();
254        object.insert("type".to_string(), self.device_type().to_field_value());
255        object.insert("index".to_string(), self.index().to_field_value());
256        FieldValue::from_object(object)
257    }
258}
259
260impl FromFieldValue for Device {
261    /// Convert FieldValue to Device for deserialization
262    ///
263    /// # Arguments
264    ///
265    /// * `value` - FieldValue containing device object
266    /// * `field_name` - Name of the field for error reporting
267    ///
268    /// # Returns
269    ///
270    /// Device instance or error if invalid
271    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
272        let object = value
273            .as_object()
274            .map_err(|_| SerializationError::ValidationFailed {
275                field: field_name.to_string(),
276                message: "Expected object for device".to_string(),
277            })?;
278
279        let device_type = object
280            .get("type")
281            .ok_or_else(|| SerializationError::ValidationFailed {
282                field: field_name.to_string(),
283                message: "Missing device type field".to_string(),
284            })?
285            .clone();
286
287        let index = object
288            .get("index")
289            .ok_or_else(|| SerializationError::ValidationFailed {
290                field: field_name.to_string(),
291                message: "Missing device index field".to_string(),
292            })?
293            .clone();
294
295        let device_type = DeviceType::from_field_value(device_type, "type")?;
296        let index = usize::from_field_value(index, "index")?;
297
298        match device_type {
299            DeviceType::Cpu => Ok(Device::cpu()),
300            DeviceType::Cuda => Ok(Device::cuda(index)),
301        }
302    }
303}
304
305// ===== Memory Layout Serialization =====
306
307impl ToFieldValue for crate::tensor::MemoryLayout {
308    /// Convert MemoryLayout to FieldValue for serialization
309    ///
310    /// # Returns
311    ///
312    /// Enum FieldValue with variant name (proper enum serialization)
313    fn to_field_value(&self) -> FieldValue {
314        match self {
315            crate::tensor::MemoryLayout::Contiguous => {
316                FieldValue::from_enum_unit("Contiguous".to_string())
317            }
318            crate::tensor::core::MemoryLayout::Strided => {
319                FieldValue::from_enum_unit("Strided".to_string())
320            }
321            crate::tensor::core::MemoryLayout::View => {
322                FieldValue::from_enum_unit("View".to_string())
323            }
324        }
325    }
326}
327
328impl FromFieldValue for crate::tensor::MemoryLayout {
329    /// Convert FieldValue to MemoryLayout for deserialization
330    ///
331    /// # Arguments
332    ///
333    /// * `value` - FieldValue containing enum data
334    /// * `field_name` - Name of the field for error reporting
335    ///
336    /// # Returns
337    ///
338    /// MemoryLayout enum value or error if invalid
339    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
340        let (variant, data) =
341            value
342                .as_enum()
343                .map_err(|_| SerializationError::ValidationFailed {
344                    field: field_name.to_string(),
345                    message: "Expected enum value for memory layout".to_string(),
346                })?;
347
348        // Ensure no data for unit variants
349        if data.is_some() {
350            return Err(SerializationError::ValidationFailed {
351                field: field_name.to_string(),
352                message: "MemoryLayout variants should not have associated data".to_string(),
353            });
354        }
355
356        match variant {
357            "Contiguous" => Ok(crate::tensor::MemoryLayout::Contiguous),
358            "Strided" => Ok(crate::tensor::MemoryLayout::Strided),
359            "View" => Ok(crate::tensor::MemoryLayout::View),
360            _ => Err(SerializationError::ValidationFailed {
361                field: field_name.to_string(),
362                message: format!("Unknown memory layout variant: {}", variant),
363            }),
364        }
365    }
366}
367
368// ===== Shape Serialization =====
369
370impl ToFieldValue for Shape {
371    /// Convert Shape to FieldValue for serialization
372    ///
373    /// # Returns
374    ///
375    /// Object containing all shape metadata
376    fn to_field_value(&self) -> FieldValue {
377        let mut object = HashMap::new();
378        object.insert("dims".to_string(), self.dims.to_field_value());
379        object.insert("size".to_string(), self.size.to_field_value());
380        object.insert("strides".to_string(), self.strides.to_field_value());
381        object.insert("layout".to_string(), self.layout.to_field_value());
382        FieldValue::from_object(object)
383    }
384}
385
386impl FromFieldValue for Shape {
387    /// Convert FieldValue to Shape for deserialization
388    ///
389    /// # Arguments
390    ///
391    /// * `value` - FieldValue containing shape object
392    /// * `field_name` - Name of the field for error reporting
393    ///
394    /// # Returns
395    ///
396    /// Shape instance or error if invalid
397    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
398        let object = value
399            .as_object()
400            .map_err(|_| SerializationError::ValidationFailed {
401                field: field_name.to_string(),
402                message: "Expected object for shape".to_string(),
403            })?;
404
405        let dims = object
406            .get("dims")
407            .ok_or_else(|| SerializationError::ValidationFailed {
408                field: field_name.to_string(),
409                message: "Missing dims field in shape".to_string(),
410            })?
411            .clone();
412
413        let size = object
414            .get("size")
415            .ok_or_else(|| SerializationError::ValidationFailed {
416                field: field_name.to_string(),
417                message: "Missing size field in shape".to_string(),
418            })?
419            .clone();
420
421        let strides = object
422            .get("strides")
423            .ok_or_else(|| SerializationError::ValidationFailed {
424                field: field_name.to_string(),
425                message: "Missing strides field in shape".to_string(),
426            })?
427            .clone();
428
429        let layout = object
430            .get("layout")
431            .ok_or_else(|| SerializationError::ValidationFailed {
432                field: field_name.to_string(),
433                message: "Missing layout field in shape".to_string(),
434            })?
435            .clone();
436
437        let dims = Vec::<usize>::from_field_value(dims, "dims")?;
438        let size = usize::from_field_value(size, "size")?;
439        let strides = Vec::<usize>::from_field_value(strides, "strides")?;
440        let layout = crate::tensor::MemoryLayout::from_field_value(layout, "layout")?;
441
442        // Validate consistency
443        let expected_size: usize = dims.iter().product();
444        if size != expected_size {
445            return Err(SerializationError::ValidationFailed {
446                field: field_name.to_string(),
447                message: format!(
448                    "Shape size {} doesn't match computed size {}",
449                    size, expected_size
450                ),
451            });
452        }
453
454        if dims.len() != strides.len() {
455            return Err(SerializationError::ValidationFailed {
456                field: field_name.to_string(),
457                message: "Dimensions and strides must have same length".to_string(),
458            });
459        }
460
461        Ok(Shape {
462            dims,
463            size,
464            strides,
465            layout,
466        })
467    }
468}
469
470// ===== Tensor Serialization =====
471
472impl StructSerializable for Tensor {
473    /// Convert Tensor to StructSerializer for serialization
474    ///
475    /// Serializes tensor data, shape, device, and gradtrack state.
476    /// Runtime state (id, grad, grad_fn, allocation_owner) is not serialized.
477    ///
478    /// # Returns
479    ///
480    /// StructSerializer containing all persistent tensor state
481    fn to_serializer(&self) -> StructSerializer {
482        // Extract tensor data as Vec<f32> - now uses efficient FieldValue implementation:
483        // - JSON format: Human-readable arrays of numbers
484        // - Binary format: Efficient byte representation with length header
485        let data: Vec<f32> =
486            unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size()).to_vec() };
487
488        StructSerializer::new()
489            .field("data", &data)
490            .field("shape", self.shape())
491            .field("device", &self.device())
492            .field("requires_grad", &self.requires_grad())
493    }
494
495    /// Create Tensor from StructDeserializer
496    ///
497    /// Reconstructs tensor from serialized data, shape, device, and gradtrack state.
498    /// Allocates new memory and generates new tensor ID.
499    ///
500    /// # Arguments
501    ///
502    /// * `deserializer` - StructDeserializer containing tensor data
503    ///
504    /// # Returns
505    ///
506    /// Reconstructed Tensor instance or error if deserialization fails
507    fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
508        let data: Vec<f32> = deserializer.field("data")?;
509        let shape: Shape = deserializer.field("shape")?;
510        let device: Device = deserializer.field("device")?;
511        let requires_grad: bool = deserializer.field("requires_grad")?;
512
513        // Validate data size matches shape
514        if data.len() != shape.size {
515            return Err(SerializationError::ValidationFailed {
516                field: "tensor".to_string(),
517                message: format!(
518                    "Data length {} doesn't match shape size {}",
519                    data.len(),
520                    shape.size
521                ),
522            });
523        }
524
525        // Create new tensor with the deserialized shape on the correct device
526        let mut tensor = Tensor::new_on_device(shape.dims.clone(), device);
527
528        // Copy data into tensor
529        if !data.is_empty() {
530            unsafe {
531                let dst = tensor.as_mut_ptr();
532                std::ptr::copy_nonoverlapping(data.as_ptr(), dst, data.len());
533            }
534        }
535
536        // Set gradtrack state
537        tensor.set_requires_grad(requires_grad);
538
539        // Validate that the reconstructed shape matches
540        if tensor.shape().dims != shape.dims
541            || tensor.shape().size != shape.size
542            || tensor.shape().strides != shape.strides
543        {
544            return Err(SerializationError::ValidationFailed {
545                field: "tensor".to_string(),
546                message: "Reconstructed tensor shape doesn't match serialized shape".to_string(),
547            });
548        }
549
550        Ok(tensor)
551    }
552}
553
554impl FromFieldValue for Tensor {
555    /// Convert FieldValue to Tensor for use as struct field
556    ///
557    /// # Arguments
558    ///
559    /// * `value` - FieldValue containing tensor data
560    /// * `field_name` - Name of the field for error reporting
561    ///
562    /// # Returns
563    ///
564    /// Tensor instance or error if deserialization fails
565    fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
566        // Try binary object first (for when serialized as binary)
567        if let Ok(binary_data) = value.as_binary_object() {
568            return Tensor::from_binary(binary_data).map_err(|e| {
569                SerializationError::ValidationFailed {
570                    field: field_name.to_string(),
571                    message: format!("Failed to deserialize tensor from binary: {}", e),
572                }
573            });
574        }
575
576        // Try JSON object (for when serialized as JSON)
577        if let Ok(json_data) = value.as_json_object() {
578            return Tensor::from_json(json_data).map_err(|e| {
579                SerializationError::ValidationFailed {
580                    field: field_name.to_string(),
581                    message: format!("Failed to deserialize tensor from JSON: {}", e),
582                }
583            });
584        }
585
586        // Try object (for when serialized as structured object in JSON)
587        if let Ok(object) = value.as_object() {
588            // Convert object back to deserializer and use StructSerializable
589            let mut deserializer = StructDeserializer::from_fields(object.clone());
590            return Tensor::from_deserializer(&mut deserializer).map_err(|e| {
591                SerializationError::ValidationFailed {
592                    field: field_name.to_string(),
593                    message: format!("Failed to deserialize tensor from object: {}", e),
594                }
595            });
596        }
597
598        Err(SerializationError::ValidationFailed {
599            field: field_name.to_string(),
600            message: "Expected binary object, JSON object, or structured object for tensor field"
601                .to_string(),
602        })
603    }
604}
605
606// ===== Serializable Trait Implementation =====
607
608impl crate::serialization::Serializable for Tensor {
609    /// Serialize the tensor to JSON format
610    ///
611    /// This method converts the tensor into a human-readable JSON string representation
612    /// that includes all tensor data, shape information, device placement, and gradtrack state.
613    /// The JSON format is suitable for debugging, configuration files, and cross-language
614    /// interoperability.
615    ///
616    /// # Returns
617    ///
618    /// JSON string representation of the tensor on success, or `SerializationError` on failure
619    ///
620    /// # Examples
621    ///
622    /// ```
623    /// use train_station::Tensor;
624    /// use train_station::serialization::Serializable;
625    ///
626    /// let mut tensor = Tensor::zeros(vec![2, 3]);
627    /// tensor.set(&[0, 0], 1.0);
628    /// tensor.set(&[1, 2], 5.0);
629    ///
630    /// let json = tensor.to_json().unwrap();
631    /// assert!(!json.is_empty());
632    /// assert!(json.contains("data"));
633    /// assert!(json.contains("shape"));
634    /// ```
635    fn to_json(&self) -> SerializationResult<String> {
636        StructSerializable::to_json(self)
637    }
638
639    /// Deserialize a tensor from JSON format
640    ///
641    /// This method parses a JSON string and reconstructs a tensor with all its data,
642    /// shape information, device placement, and gradtrack state. The JSON must contain
643    /// all necessary fields in the expected format.
644    ///
645    /// # Arguments
646    ///
647    /// * `json` - JSON string containing serialized tensor data
648    ///
649    /// # Returns
650    ///
651    /// The deserialized tensor on success, or `SerializationError` on failure
652    ///
653    /// # Examples
654    ///
655    /// ```
656    /// use train_station::Tensor;
657    /// use train_station::serialization::Serializable;
658    ///
659    /// let mut original = Tensor::ones(vec![2, 2]);
660    /// original.set(&[0, 1], 3.0);
661    /// original.set_requires_grad(true);
662    ///
663    /// let json = original.to_json().unwrap();
664    /// let restored = Tensor::from_json(&json).unwrap();
665    ///
666    /// assert_eq!(original.shape().dims, restored.shape().dims);
667    /// assert_eq!(original.get(&[0, 1]), restored.get(&[0, 1]));
668    /// assert_eq!(original.requires_grad(), restored.requires_grad());
669    /// ```
670    fn from_json(json: &str) -> SerializationResult<Self> {
671        StructSerializable::from_json(json)
672    }
673
674    /// Serialize the tensor to binary format
675    ///
676    /// This method converts the tensor into a compact binary representation optimized
677    /// for storage and transmission. The binary format provides maximum performance
678    /// and minimal file sizes, making it ideal for large tensors and production use.
679    ///
680    /// # Returns
681    ///
682    /// Binary representation of the tensor on success, or `SerializationError` on failure
683    ///
684    /// # Examples
685    ///
686    /// ```
687    /// use train_station::Tensor;
688    /// use train_station::serialization::Serializable;
689    ///
690    /// let mut tensor = Tensor::zeros(vec![100, 100]);
691    /// for i in 0..10 {
692    ///     tensor.set(&[i, i], i as f32);
693    /// }
694    ///
695    /// let binary = tensor.to_binary().unwrap();
696    /// assert!(!binary.is_empty());
697    /// // Binary format is more compact than JSON for large tensors
698    /// ```
699    fn to_binary(&self) -> SerializationResult<Vec<u8>> {
700        StructSerializable::to_binary(self)
701    }
702
703    /// Deserialize a tensor from binary format
704    ///
705    /// This method parses binary data and reconstructs a tensor with all its data,
706    /// shape information, device placement, and gradtrack state. The binary data
707    /// must contain complete serialized information in the expected format.
708    ///
709    /// # Arguments
710    ///
711    /// * `data` - Binary data containing serialized tensor information
712    ///
713    /// # Returns
714    ///
715    /// The deserialized tensor on success, or `SerializationError` on failure
716    ///
717    /// # Examples
718    ///
719    /// ```
720    /// use train_station::Tensor;
721    /// use train_station::serialization::Serializable;
722    ///
723    /// let mut original = Tensor::ones(vec![3, 4]);
724    /// original.set(&[2, 3], 7.5);
725    /// original.set_requires_grad(true);
726    ///
727    /// let binary = original.to_binary().unwrap();
728    /// let restored = Tensor::from_binary(&binary).unwrap();
729    ///
730    /// assert_eq!(original.shape().dims, restored.shape().dims);
731    /// assert_eq!(original.get(&[2, 3]), restored.get(&[2, 3]));
732    /// assert_eq!(original.requires_grad(), restored.requires_grad());
733    /// ```
734    fn from_binary(data: &[u8]) -> SerializationResult<Self> {
735        StructSerializable::from_binary(data)
736    }
737}
738
739#[cfg(test)]
740mod tests {
741    //! Comprehensive tests for tensor serialization functionality
742    //!
743    //! Tests cover all serialization formats and usage patterns including:
744    //! - JSON and binary roundtrip serialization
745    //! - Tensor as field within structs  
746    //! - Edge cases and error conditions
747    //! - Device and shape serialization
748    //! - Large tensor serialization
749
750    use super::*;
751
752    // ===== Device Serialization Tests =====
753
754    #[test]
755    fn test_device_type_serialization() {
756        // Test CPU device type
757        let cpu_type = DeviceType::Cpu;
758        let field_value = cpu_type.to_field_value();
759        let deserialized = DeviceType::from_field_value(field_value, "device_type").unwrap();
760        assert_eq!(cpu_type, deserialized);
761
762        // Test CUDA device type
763        let cuda_type = DeviceType::Cuda;
764        let field_value = cuda_type.to_field_value();
765        let deserialized = DeviceType::from_field_value(field_value, "device_type").unwrap();
766        assert_eq!(cuda_type, deserialized);
767    }
768
769    #[test]
770    fn test_device_serialization() {
771        // Test CPU device
772        let cpu_device = Device::cpu();
773        let field_value = cpu_device.to_field_value();
774        let deserialized = Device::from_field_value(field_value, "device").unwrap();
775        assert_eq!(cpu_device, deserialized);
776        assert!(deserialized.is_cpu());
777        assert_eq!(deserialized.index(), 0);
778    }
779
780    #[test]
781    fn test_device_serialization_errors() {
782        // Test invalid device type
783        let invalid_device_type = FieldValue::from_string("invalid".to_string());
784        let result = DeviceType::from_field_value(invalid_device_type, "device_type");
785        assert!(result.is_err());
786
787        // Test missing device fields
788        let incomplete_device = FieldValue::from_object({
789            let mut obj = HashMap::new();
790            obj.insert(
791                "type".to_string(),
792                FieldValue::from_string("cpu".to_string()),
793            );
794            // Missing index field
795            obj
796        });
797        let result = Device::from_field_value(incomplete_device, "device");
798        assert!(result.is_err());
799    }
800
801    // ===== Shape Serialization Tests =====
802
803    #[test]
804    fn test_memory_layout_serialization() {
805        use crate::tensor::MemoryLayout;
806
807        let layouts = [
808            MemoryLayout::Contiguous,
809            MemoryLayout::Strided,
810            MemoryLayout::View,
811        ];
812
813        for layout in &layouts {
814            let field_value = layout.to_field_value();
815            let deserialized = MemoryLayout::from_field_value(field_value, "layout").unwrap();
816            assert_eq!(*layout, deserialized);
817        }
818    }
819
820    #[test]
821    fn test_shape_serialization() {
822        // Test contiguous shape
823        let shape = Shape::new(vec![2, 3, 4]);
824        let field_value = shape.to_field_value();
825        let deserialized = Shape::from_field_value(field_value, "shape").unwrap();
826        assert_eq!(shape, deserialized);
827        assert_eq!(deserialized.dims, vec![2, 3, 4]);
828        assert_eq!(deserialized.size, 24);
829        assert_eq!(deserialized.strides, vec![12, 4, 1]);
830
831        // Test strided shape
832        let strided_shape = Shape::with_strides(vec![2, 3], vec![6, 2]);
833        let field_value = strided_shape.to_field_value();
834        let deserialized = Shape::from_field_value(field_value, "shape").unwrap();
835        assert_eq!(strided_shape, deserialized);
836    }
837
838    #[test]
839    fn test_shape_validation_errors() {
840        use crate::tensor::MemoryLayout;
841
842        // Test inconsistent size
843        let invalid_shape = FieldValue::from_object({
844            let mut obj = HashMap::new();
845            obj.insert("dims".to_string(), vec![2usize, 3].to_field_value());
846            obj.insert("size".to_string(), 10usize.to_field_value()); // Should be 6
847            obj.insert("strides".to_string(), vec![3usize, 1].to_field_value());
848            obj.insert(
849                "layout".to_string(),
850                MemoryLayout::Contiguous.to_field_value(),
851            );
852            obj
853        });
854        let result = Shape::from_field_value(invalid_shape, "shape");
855        assert!(result.is_err());
856
857        // Test mismatched dimensions and strides
858        let invalid_shape = FieldValue::from_object({
859            let mut obj = HashMap::new();
860            obj.insert("dims".to_string(), vec![2usize, 3].to_field_value());
861            obj.insert("size".to_string(), 6usize.to_field_value());
862            obj.insert("strides".to_string(), vec![3usize].to_field_value()); // Wrong length
863            obj.insert(
864                "layout".to_string(),
865                MemoryLayout::Contiguous.to_field_value(),
866            );
867            obj
868        });
869        let result = Shape::from_field_value(invalid_shape, "shape");
870        assert!(result.is_err());
871    }
872
873    // ===== Tensor Serialization Tests =====
874
875    #[test]
876    fn test_tensor_json_roundtrip() {
877        // Create test tensor with data
878        let mut tensor = Tensor::zeros(vec![2, 3]);
879        tensor.set(&[0, 0], 1.0);
880        tensor.set(&[0, 1], 2.0);
881        tensor.set(&[0, 2], 3.0);
882        tensor.set(&[1, 0], 4.0);
883        tensor.set(&[1, 1], 5.0);
884        tensor.set(&[1, 2], 6.0);
885        tensor.set_requires_grad(true);
886
887        // Serialize to JSON
888        let json = tensor.to_json().unwrap();
889        assert!(!json.is_empty());
890
891        // Deserialize from JSON
892        let loaded_tensor = Tensor::from_json(&json).unwrap();
893
894        // Verify tensor properties
895        assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
896        assert_eq!(tensor.size(), loaded_tensor.size());
897        assert_eq!(tensor.device(), loaded_tensor.device());
898        assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
899
900        // Verify tensor data
901        for i in 0..2 {
902            for j in 0..3 {
903                assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
904            }
905        }
906    }
907
908    #[test]
909    fn test_tensor_binary_roundtrip() {
910        // Create test tensor with gradient tracking
911        let mut tensor = Tensor::ones(vec![3, 4]).with_requires_grad();
912
913        // Modify some values
914        tensor.set(&[0, 0], 10.0);
915        tensor.set(&[1, 2], 20.0);
916        tensor.set(&[2, 3], 30.0);
917
918        // Serialize to binary
919        let binary = tensor.to_binary().unwrap();
920        assert!(!binary.is_empty());
921
922        // Deserialize from binary
923        let loaded_tensor = Tensor::from_binary(&binary).unwrap();
924
925        // Verify tensor properties
926        assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
927        assert_eq!(tensor.size(), loaded_tensor.size());
928        assert_eq!(tensor.device(), loaded_tensor.device());
929        assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
930
931        // Verify tensor data
932        for i in 0..3 {
933            for j in 0..4 {
934                assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
935            }
936        }
937    }
938
939    #[test]
940    fn test_empty_tensor_serialization() {
941        // Test zero-sized tensor
942        let tensor = Tensor::new(vec![0]);
943
944        // JSON roundtrip
945        let json = tensor.to_json().unwrap();
946        let loaded_tensor = Tensor::from_json(&json).unwrap();
947        assert_eq!(tensor.size(), loaded_tensor.size());
948        assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
949
950        // Binary roundtrip
951        let binary = tensor.to_binary().unwrap();
952        let loaded_tensor = Tensor::from_binary(&binary).unwrap();
953        assert_eq!(tensor.size(), loaded_tensor.size());
954        assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
955    }
956
957    #[test]
958    fn test_large_tensor_serialization() {
959        // Test larger tensor
960        let mut tensor = Tensor::zeros(vec![100, 100]).with_requires_grad();
961
962        // Set some values
963        for i in 0..10 {
964            for j in 0..10 {
965                tensor.set(&[i, j], (i * 10 + j) as f32);
966            }
967        }
968
969        // Binary roundtrip (more efficient for large tensors)
970        let binary = tensor.to_binary().unwrap();
971        let loaded_tensor = Tensor::from_binary(&binary).unwrap();
972
973        // Verify properties
974        assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
975        assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
976
977        // Verify a subset of data
978        for i in 0..10 {
979            for j in 0..10 {
980                assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
981            }
982        }
983    }
984
985    #[test]
986    fn test_tensor_as_field_in_struct() {
987        // Define a struct containing tensors
988        #[derive(Debug)]
989        struct ModelWeights {
990            weight_matrix: Tensor,
991            bias_vector: Tensor,
992            learning_rate: f32,
993            name: String,
994        }
995
996        impl StructSerializable for ModelWeights {
997            fn to_serializer(&self) -> StructSerializer {
998                StructSerializer::new()
999                    .field("weight_matrix", &self.weight_matrix)
1000                    .field("bias_vector", &self.bias_vector)
1001                    .field("learning_rate", &self.learning_rate)
1002                    .field("name", &self.name)
1003            }
1004
1005            fn from_deserializer(
1006                deserializer: &mut StructDeserializer,
1007            ) -> SerializationResult<Self> {
1008                Ok(ModelWeights {
1009                    weight_matrix: deserializer.field("weight_matrix")?,
1010                    bias_vector: deserializer.field("bias_vector")?,
1011                    learning_rate: deserializer.field("learning_rate")?,
1012                    name: deserializer.field("name")?,
1013                })
1014            }
1015        }
1016
1017        // Create test struct with tensors
1018        let mut weights = ModelWeights {
1019            weight_matrix: Tensor::zeros(vec![10, 5]),
1020            bias_vector: Tensor::ones(vec![5]).with_requires_grad(),
1021            learning_rate: 0.001,
1022            name: "test_model".to_string(),
1023        };
1024
1025        // Set some values
1026        weights.weight_matrix.set(&[0, 0], 0.5);
1027        weights.weight_matrix.set(&[9, 4], -0.3);
1028        weights.bias_vector.set(&[2], 2.0);
1029
1030        // Test JSON serialization
1031        let json = weights.to_json().unwrap();
1032        let loaded_weights = ModelWeights::from_json(&json).unwrap();
1033
1034        assert_eq!(weights.learning_rate, loaded_weights.learning_rate);
1035        assert_eq!(weights.name, loaded_weights.name);
1036        assert_eq!(
1037            weights.weight_matrix.shape().dims,
1038            loaded_weights.weight_matrix.shape().dims
1039        );
1040        assert_eq!(
1041            weights.bias_vector.shape().dims,
1042            loaded_weights.bias_vector.shape().dims
1043        );
1044        assert_eq!(
1045            weights.bias_vector.requires_grad(),
1046            loaded_weights.bias_vector.requires_grad()
1047        );
1048
1049        // Verify tensor data
1050        assert_eq!(
1051            weights.weight_matrix.get(&[0, 0]),
1052            loaded_weights.weight_matrix.get(&[0, 0])
1053        );
1054        assert_eq!(
1055            weights.weight_matrix.get(&[9, 4]),
1056            loaded_weights.weight_matrix.get(&[9, 4])
1057        );
1058        assert_eq!(
1059            weights.bias_vector.get(&[2]),
1060            loaded_weights.bias_vector.get(&[2])
1061        );
1062
1063        // Test binary serialization
1064        let binary = weights.to_binary().unwrap();
1065        let loaded_weights = ModelWeights::from_binary(&binary).unwrap();
1066
1067        assert_eq!(weights.learning_rate, loaded_weights.learning_rate);
1068        assert_eq!(weights.name, loaded_weights.name);
1069        assert_eq!(
1070            weights.weight_matrix.shape().dims,
1071            loaded_weights.weight_matrix.shape().dims
1072        );
1073        assert_eq!(
1074            weights.bias_vector.requires_grad(),
1075            loaded_weights.bias_vector.requires_grad()
1076        );
1077    }
1078
1079    #[test]
1080    fn test_multiple_tensors_in_struct() {
1081        // Test struct with multiple tensors of different shapes
1082        #[derive(Debug)]
1083        struct MultiTensorStruct {
1084            tensor_1d: Tensor,
1085            tensor_2d: Tensor,
1086            tensor_3d: Tensor,
1087            metadata: HashMap<String, String>,
1088        }
1089
1090        impl StructSerializable for MultiTensorStruct {
1091            fn to_serializer(&self) -> StructSerializer {
1092                StructSerializer::new()
1093                    .field("tensor_1d", &self.tensor_1d)
1094                    .field("tensor_2d", &self.tensor_2d)
1095                    .field("tensor_3d", &self.tensor_3d)
1096                    .field("metadata", &self.metadata)
1097            }
1098
1099            fn from_deserializer(
1100                deserializer: &mut StructDeserializer,
1101            ) -> SerializationResult<Self> {
1102                Ok(MultiTensorStruct {
1103                    tensor_1d: deserializer.field("tensor_1d")?,
1104                    tensor_2d: deserializer.field("tensor_2d")?,
1105                    tensor_3d: deserializer.field("tensor_3d")?,
1106                    metadata: deserializer.field("metadata")?,
1107                })
1108            }
1109        }
1110
1111        // Create test struct
1112        let mut multi_tensor = MultiTensorStruct {
1113            tensor_1d: Tensor::zeros(vec![5]),
1114            tensor_2d: Tensor::ones(vec![3, 4]).with_requires_grad(),
1115            tensor_3d: Tensor::zeros(vec![2, 2, 2]),
1116            metadata: {
1117                let mut map = HashMap::new();
1118                map.insert("version".to_string(), "1.0".to_string());
1119                map.insert("type".to_string(), "test".to_string());
1120                map
1121            },
1122        };
1123
1124        // Set some values
1125        multi_tensor.tensor_1d.set(&[0], 10.0);
1126        multi_tensor.tensor_2d.set(&[0, 0], 5.0);
1127        multi_tensor.tensor_3d.set(&[1, 1, 1], 3.0);
1128
1129        // Test JSON roundtrip
1130        let json = multi_tensor.to_json().unwrap();
1131        let loaded = MultiTensorStruct::from_json(&json).unwrap();
1132
1133        assert_eq!(
1134            multi_tensor.tensor_1d.shape().dims,
1135            loaded.tensor_1d.shape().dims
1136        );
1137        assert_eq!(
1138            multi_tensor.tensor_2d.shape().dims,
1139            loaded.tensor_2d.shape().dims
1140        );
1141        assert_eq!(
1142            multi_tensor.tensor_3d.shape().dims,
1143            loaded.tensor_3d.shape().dims
1144        );
1145        assert_eq!(
1146            multi_tensor.tensor_2d.requires_grad(),
1147            loaded.tensor_2d.requires_grad()
1148        );
1149        assert_eq!(multi_tensor.metadata, loaded.metadata);
1150
1151        // Verify tensor values
1152        assert_eq!(multi_tensor.tensor_1d.get(&[0]), loaded.tensor_1d.get(&[0]));
1153        assert_eq!(
1154            multi_tensor.tensor_2d.get(&[0, 0]),
1155            loaded.tensor_2d.get(&[0, 0])
1156        );
1157        assert_eq!(
1158            multi_tensor.tensor_3d.get(&[1, 1, 1]),
1159            loaded.tensor_3d.get(&[1, 1, 1])
1160        );
1161
1162        // Test binary roundtrip
1163        let binary = multi_tensor.to_binary().unwrap();
1164        let loaded = MultiTensorStruct::from_binary(&binary).unwrap();
1165        assert_eq!(
1166            multi_tensor.tensor_1d.shape().dims,
1167            loaded.tensor_1d.shape().dims
1168        );
1169        assert_eq!(
1170            multi_tensor.tensor_2d.requires_grad(),
1171            loaded.tensor_2d.requires_grad()
1172        );
1173    }
1174
1175    #[test]
1176    fn test_tensor_serialization_errors() {
1177        // Test invalid data size
1178        let mut deserializer = StructDeserializer::from_json(
1179            r#"
1180        {
1181            "data": [1.0, 2.0, 3.0],
1182            "shape": {
1183                "dims": [2, 3],
1184                "size": 6,
1185                "strides": [3, 1],
1186                "layout": "contiguous"
1187            },
1188            "device": {"type": "cpu", "index": 0},
1189            "requires_grad": false
1190        }"#,
1191        )
1192        .unwrap();
1193
1194        let result = Tensor::from_deserializer(&mut deserializer);
1195        assert!(result.is_err()); // Data length (3) doesn't match shape size (6)
1196    }
1197
1198    #[test]
1199    fn test_field_value_tensor_roundtrip() {
1200        // Test tensor as FieldValue
1201        let mut tensor = Tensor::zeros(vec![2, 2]);
1202        tensor.set(&[0, 0], 1.0);
1203        tensor.set(&[1, 1], 2.0);
1204
1205        let field_value = tensor.to_field_value();
1206        let loaded_tensor = Tensor::from_field_value(field_value, "test_tensor").unwrap();
1207
1208        assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
1209        assert_eq!(tensor.get(&[0, 0]), loaded_tensor.get(&[0, 0]));
1210        assert_eq!(tensor.get(&[1, 1]), loaded_tensor.get(&[1, 1]));
1211    }
1212
1213    #[test]
1214    fn test_different_tensor_shapes() {
1215        let test_shapes = vec![
1216            vec![1],          // Scalar
1217            vec![10],         // 1D vector
1218            vec![3, 4],       // 2D matrix
1219            vec![2, 3, 4],    // 3D tensor
1220            vec![2, 2, 2, 2], // 4D tensor
1221        ];
1222
1223        for shape in test_shapes {
1224            let tensor = Tensor::zeros(shape.clone()).with_requires_grad();
1225
1226            // JSON roundtrip
1227            let json = tensor.to_json().unwrap();
1228            let loaded = Tensor::from_json(&json).unwrap();
1229            assert_eq!(tensor.shape().dims, loaded.shape().dims);
1230            assert_eq!(tensor.requires_grad(), loaded.requires_grad());
1231
1232            // Binary roundtrip
1233            let binary = tensor.to_binary().unwrap();
1234            let loaded = Tensor::from_binary(&binary).unwrap();
1235            assert_eq!(tensor.shape().dims, loaded.shape().dims);
1236            assert_eq!(tensor.requires_grad(), loaded.requires_grad());
1237        }
1238    }
1239
1240    // ===== Serializable Trait Tests =====
1241
1242    #[test]
1243    fn test_serializable_json_methods() {
1244        // Create and populate test tensor
1245        let mut tensor = Tensor::zeros(vec![2, 3]);
1246        tensor.set(&[0, 0], 1.0);
1247        tensor.set(&[0, 1], 2.0);
1248        tensor.set(&[1, 2], 5.0);
1249        tensor.set_requires_grad(true);
1250
1251        // Test to_json method
1252        let json = <Tensor as crate::serialization::Serializable>::to_json(&tensor).unwrap();
1253        assert!(!json.is_empty());
1254        assert!(json.contains("data"));
1255        assert!(json.contains("shape"));
1256        assert!(json.contains("device"));
1257        assert!(json.contains("requires_grad"));
1258
1259        // Test from_json method
1260        let restored = <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1261        assert_eq!(tensor.shape().dims, restored.shape().dims);
1262        assert_eq!(tensor.size(), restored.size());
1263        assert_eq!(tensor.device(), restored.device());
1264        assert_eq!(tensor.requires_grad(), restored.requires_grad());
1265
1266        // Verify tensor data
1267        assert_eq!(tensor.get(&[0, 0]), restored.get(&[0, 0]));
1268        assert_eq!(tensor.get(&[0, 1]), restored.get(&[0, 1]));
1269        assert_eq!(tensor.get(&[1, 2]), restored.get(&[1, 2]));
1270    }
1271
1272    #[test]
1273    fn test_serializable_binary_methods() {
1274        // Create and populate test tensor
1275        let mut tensor = Tensor::ones(vec![3, 4]);
1276        tensor.set(&[0, 0], 10.0);
1277        tensor.set(&[1, 2], 20.0);
1278        tensor.set(&[2, 3], 30.0);
1279        tensor.set_requires_grad(true);
1280
1281        // Test to_binary method
1282        let binary = <Tensor as crate::serialization::Serializable>::to_binary(&tensor).unwrap();
1283        assert!(!binary.is_empty());
1284
1285        // Test from_binary method
1286        let restored =
1287            <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1288        assert_eq!(tensor.shape().dims, restored.shape().dims);
1289        assert_eq!(tensor.size(), restored.size());
1290        assert_eq!(tensor.device(), restored.device());
1291        assert_eq!(tensor.requires_grad(), restored.requires_grad());
1292
1293        // Verify tensor data
1294        assert_eq!(tensor.get(&[0, 0]), restored.get(&[0, 0]));
1295        assert_eq!(tensor.get(&[1, 2]), restored.get(&[1, 2]));
1296        assert_eq!(tensor.get(&[2, 3]), restored.get(&[2, 3]));
1297    }
1298
1299    #[test]
1300    fn test_serializable_file_io_json() {
1301        use crate::serialization::{Format, Serializable};
1302        use std::fs;
1303        use std::path::Path;
1304
1305        // Create test tensor
1306        let mut tensor = Tensor::zeros(vec![2, 2]);
1307        tensor.set(&[0, 0], 1.0);
1308        tensor.set(&[0, 1], 2.0);
1309        tensor.set(&[1, 0], 3.0);
1310        tensor.set(&[1, 1], 4.0);
1311        tensor.set_requires_grad(true);
1312
1313        // Test file paths
1314        let json_path = "test_tensor_serializable.json";
1315        let json_path_2 = "test_tensor_serializable_2.json";
1316
1317        // Cleanup any existing files
1318        let _ = fs::remove_file(json_path);
1319        let _ = fs::remove_file(json_path_2);
1320
1321        // Test save method with JSON format
1322        Serializable::save(&tensor, json_path, Format::Json).unwrap();
1323        assert!(Path::new(json_path).exists());
1324
1325        // Test load method with JSON format
1326        let loaded_tensor = Tensor::load(json_path, Format::Json).unwrap();
1327        assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
1328        assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1329        assert_eq!(tensor.get(&[0, 0]), loaded_tensor.get(&[0, 0]));
1330        assert_eq!(tensor.get(&[1, 1]), loaded_tensor.get(&[1, 1]));
1331
1332        // Test save_to_writer and load_from_reader
1333        {
1334            let mut writer = std::fs::File::create(json_path_2).unwrap();
1335            Serializable::save_to_writer(&tensor, &mut writer, Format::Json).unwrap();
1336        }
1337        assert!(Path::new(json_path_2).exists());
1338
1339        {
1340            let mut reader = std::fs::File::open(json_path_2).unwrap();
1341            let loaded_tensor = Tensor::load_from_reader(&mut reader, Format::Json).unwrap();
1342            assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
1343            assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1344            assert_eq!(tensor.get(&[0, 1]), loaded_tensor.get(&[0, 1]));
1345            assert_eq!(tensor.get(&[1, 0]), loaded_tensor.get(&[1, 0]));
1346        }
1347
1348        // Cleanup test files
1349        let _ = fs::remove_file(json_path);
1350        let _ = fs::remove_file(json_path_2);
1351    }
1352
1353    #[test]
1354    fn test_serializable_file_io_binary() {
1355        use crate::serialization::{Format, Serializable};
1356        use std::fs;
1357        use std::path::Path;
1358
1359        // Create test tensor
1360        let mut tensor = Tensor::ones(vec![3, 3]);
1361        for i in 0..3 {
1362            for j in 0..3 {
1363                tensor.set(&[i, j], (i * 3 + j) as f32);
1364            }
1365        }
1366        tensor.set_requires_grad(true);
1367
1368        // Test file paths
1369        let binary_path = "test_tensor_serializable.bin";
1370        let binary_path_2 = "test_tensor_serializable_2.bin";
1371
1372        // Cleanup any existing files
1373        let _ = fs::remove_file(binary_path);
1374        let _ = fs::remove_file(binary_path_2);
1375
1376        // Test save method with binary format
1377        Serializable::save(&tensor, binary_path, Format::Binary).unwrap();
1378        assert!(Path::new(binary_path).exists());
1379
1380        // Test load method with binary format
1381        let loaded_tensor = Tensor::load(binary_path, Format::Binary).unwrap();
1382        assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
1383        assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1384
1385        // Verify all data
1386        for i in 0..3 {
1387            for j in 0..3 {
1388                assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
1389            }
1390        }
1391
1392        // Test save_to_writer and load_from_reader
1393        {
1394            let mut writer = std::fs::File::create(binary_path_2).unwrap();
1395            Serializable::save_to_writer(&tensor, &mut writer, Format::Binary).unwrap();
1396        }
1397        assert!(Path::new(binary_path_2).exists());
1398
1399        {
1400            let mut reader = std::fs::File::open(binary_path_2).unwrap();
1401            let loaded_tensor = Tensor::load_from_reader(&mut reader, Format::Binary).unwrap();
1402            assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
1403            assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1404
1405            // Verify all data
1406            for i in 0..3 {
1407                for j in 0..3 {
1408                    assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
1409                }
1410            }
1411        }
1412
1413        // Cleanup test files
1414        let _ = fs::remove_file(binary_path);
1415        let _ = fs::remove_file(binary_path_2);
1416    }
1417
1418    #[test]
1419    fn test_serializable_large_tensor_performance() {
1420        // Create a large tensor to test performance characteristics
1421        let mut tensor = Tensor::zeros(vec![50, 50]);
1422        for i in 0..25 {
1423            for j in 0..25 {
1424                tensor.set(&[i, j], (i * 25 + j) as f32);
1425            }
1426        }
1427        tensor.set_requires_grad(true);
1428
1429        // Test JSON serialization
1430        let json = <Tensor as crate::serialization::Serializable>::to_json(&tensor).unwrap();
1431        assert!(!json.is_empty());
1432        let restored_json =
1433            <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1434        assert_eq!(tensor.shape().dims, restored_json.shape().dims);
1435        assert_eq!(tensor.requires_grad(), restored_json.requires_grad());
1436
1437        // Test binary serialization
1438        let binary = <Tensor as crate::serialization::Serializable>::to_binary(&tensor).unwrap();
1439        assert!(!binary.is_empty());
1440        // Binary format should be efficient (this is informational, not a requirement)
1441        println!(
1442            "JSON size: {} bytes, Binary size: {} bytes",
1443            json.len(),
1444            binary.len()
1445        );
1446
1447        let restored_binary =
1448            <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1449        assert_eq!(tensor.shape().dims, restored_binary.shape().dims);
1450        assert_eq!(tensor.requires_grad(), restored_binary.requires_grad());
1451
1452        // Verify a sample of data values
1453        for i in 0..5 {
1454            for j in 0..5 {
1455                assert_eq!(tensor.get(&[i, j]), restored_json.get(&[i, j]));
1456                assert_eq!(tensor.get(&[i, j]), restored_binary.get(&[i, j]));
1457            }
1458        }
1459    }
1460
1461    #[test]
1462    fn test_serializable_error_handling() {
1463        // Test invalid JSON
1464        let invalid_json = r#"{"invalid": "json", "structure": true}"#;
1465        let result = <Tensor as crate::serialization::Serializable>::from_json(invalid_json);
1466        assert!(result.is_err());
1467
1468        // Test empty JSON
1469        let empty_json = "{}";
1470        let result = <Tensor as crate::serialization::Serializable>::from_json(empty_json);
1471        assert!(result.is_err());
1472
1473        // Test invalid binary data
1474        let invalid_binary = vec![1, 2, 3, 4, 5];
1475        let result = <Tensor as crate::serialization::Serializable>::from_binary(&invalid_binary);
1476        assert!(result.is_err());
1477
1478        // Test empty binary data
1479        let empty_binary = vec![];
1480        let result = <Tensor as crate::serialization::Serializable>::from_binary(&empty_binary);
1481        assert!(result.is_err());
1482    }
1483
1484    #[test]
1485    fn test_serializable_different_shapes_and_types() {
1486        let test_cases = vec![
1487            // Scalar (1-element tensor)
1488            (vec![1], vec![42.0]),
1489            // 1D vector
1490            (vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]),
1491            // 2D matrix
1492            (vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
1493            // 3D tensor
1494            (vec![2, 2, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
1495        ];
1496
1497        for (shape, expected_data) in test_cases {
1498            // Create tensor with specific shape and data
1499            let mut tensor = Tensor::zeros(shape.clone());
1500
1501            // Set data based on shape dimensions
1502            match shape.len() {
1503                1 => {
1504                    for (i, &value) in expected_data.iter().enumerate().take(shape[0]) {
1505                        tensor.set(&[i], value);
1506                    }
1507                }
1508                2 => {
1509                    let mut idx = 0;
1510                    for i in 0..shape[0] {
1511                        for j in 0..shape[1] {
1512                            if idx < expected_data.len() {
1513                                tensor.set(&[i, j], expected_data[idx]);
1514                                idx += 1;
1515                            }
1516                        }
1517                    }
1518                }
1519                3 => {
1520                    let mut idx = 0;
1521                    for i in 0..shape[0] {
1522                        for j in 0..shape[1] {
1523                            for k in 0..shape[2] {
1524                                if idx < expected_data.len() {
1525                                    tensor.set(&[i, j, k], expected_data[idx]);
1526                                    idx += 1;
1527                                }
1528                            }
1529                        }
1530                    }
1531                }
1532                _ => {}
1533            }
1534            tensor.set_requires_grad(true);
1535
1536            // Test JSON roundtrip
1537            let json = <Tensor as crate::serialization::Serializable>::to_json(&tensor).unwrap();
1538            let restored_json =
1539                <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1540            assert_eq!(tensor.shape().dims, restored_json.shape().dims);
1541            assert_eq!(tensor.requires_grad(), restored_json.requires_grad());
1542
1543            // Test binary roundtrip
1544            let binary =
1545                <Tensor as crate::serialization::Serializable>::to_binary(&tensor).unwrap();
1546            let restored_binary =
1547                <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1548            assert_eq!(tensor.shape().dims, restored_binary.shape().dims);
1549            assert_eq!(tensor.requires_grad(), restored_binary.requires_grad());
1550
1551            // Verify data for first few elements
1552            match shape.len() {
1553                1 => {
1554                    for i in 0..shape[0].min(3).min(expected_data.len()) {
1555                        assert_eq!(tensor.get(&[i]), restored_json.get(&[i]));
1556                        assert_eq!(tensor.get(&[i]), restored_binary.get(&[i]));
1557                    }
1558                }
1559                2 => {
1560                    let mut count = 0;
1561                    for i in 0..shape[0] {
1562                        for j in 0..shape[1] {
1563                            if count < 3 && count < expected_data.len() {
1564                                assert_eq!(tensor.get(&[i, j]), restored_json.get(&[i, j]));
1565                                assert_eq!(tensor.get(&[i, j]), restored_binary.get(&[i, j]));
1566                                count += 1;
1567                            }
1568                        }
1569                    }
1570                }
1571                3 => {
1572                    let mut count = 0;
1573                    for i in 0..shape[0] {
1574                        for j in 0..shape[1] {
1575                            for k in 0..shape[2] {
1576                                if count < 3 && count < expected_data.len() {
1577                                    assert_eq!(
1578                                        tensor.get(&[i, j, k]),
1579                                        restored_json.get(&[i, j, k])
1580                                    );
1581                                    assert_eq!(
1582                                        tensor.get(&[i, j, k]),
1583                                        restored_binary.get(&[i, j, k])
1584                                    );
1585                                    count += 1;
1586                                }
1587                            }
1588                        }
1589                    }
1590                }
1591                _ => {}
1592            }
1593        }
1594    }
1595
1596    #[test]
1597    fn test_serializable_edge_cases() {
1598        // Test zero-sized tensor
1599        let zero_tensor = Tensor::new(vec![0]);
1600        let json = <Tensor as crate::serialization::Serializable>::to_json(&zero_tensor).unwrap();
1601        let restored = <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1602        assert_eq!(zero_tensor.shape().dims, restored.shape().dims);
1603        assert_eq!(zero_tensor.size(), restored.size());
1604
1605        let binary =
1606            <Tensor as crate::serialization::Serializable>::to_binary(&zero_tensor).unwrap();
1607        let restored =
1608            <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1609        assert_eq!(zero_tensor.shape().dims, restored.shape().dims);
1610        assert_eq!(zero_tensor.size(), restored.size());
1611
1612        // Test tensor with special values (use reasonable large values instead of f32::MAX/MIN)
1613        let mut special_tensor = Tensor::zeros(vec![3]);
1614        special_tensor.set(&[0], 0.0); // Zero
1615        special_tensor.set(&[1], 1000000.0); // Large positive value
1616        special_tensor.set(&[2], -1000000.0); // Large negative value
1617
1618        let json =
1619            <Tensor as crate::serialization::Serializable>::to_json(&special_tensor).unwrap();
1620        let restored = <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1621        assert_eq!(special_tensor.get(&[0]), restored.get(&[0]));
1622        assert_eq!(special_tensor.get(&[1]), restored.get(&[1]));
1623        assert_eq!(special_tensor.get(&[2]), restored.get(&[2]));
1624
1625        let binary =
1626            <Tensor as crate::serialization::Serializable>::to_binary(&special_tensor).unwrap();
1627        let restored =
1628            <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1629        assert_eq!(special_tensor.get(&[0]), restored.get(&[0]));
1630        assert_eq!(special_tensor.get(&[1]), restored.get(&[1]));
1631        assert_eq!(special_tensor.get(&[2]), restored.get(&[2]));
1632    }
1633}