1use std::collections::HashMap;
183
184use crate::device::{Device, DeviceType};
185use crate::serialization::{
186 FieldValue, FromFieldValue, SerializationError, SerializationResult, StructDeserializer,
187 StructSerializable, StructSerializer, ToFieldValue,
188};
189use crate::tensor::{Shape, Tensor};
190
191impl ToFieldValue for DeviceType {
194 fn to_field_value(&self) -> FieldValue {
200 match self {
201 DeviceType::Cpu => FieldValue::from_enum_unit("Cpu".to_string()),
202 DeviceType::Cuda => FieldValue::from_enum_unit("Cuda".to_string()),
203 }
204 }
205}
206
207impl FromFieldValue for DeviceType {
208 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
219 let (variant, data) =
220 value
221 .as_enum()
222 .map_err(|_| SerializationError::ValidationFailed {
223 field: field_name.to_string(),
224 message: "Expected enum value for device type".to_string(),
225 })?;
226
227 if data.is_some() {
229 return Err(SerializationError::ValidationFailed {
230 field: field_name.to_string(),
231 message: "DeviceType variants should not have associated data".to_string(),
232 });
233 }
234
235 match variant {
236 "Cpu" => Ok(DeviceType::Cpu),
237 "Cuda" => Ok(DeviceType::Cuda),
238 _ => Err(SerializationError::ValidationFailed {
239 field: field_name.to_string(),
240 message: format!("Unknown device type variant: {}", variant),
241 }),
242 }
243 }
244}
245
246impl ToFieldValue for Device {
247 fn to_field_value(&self) -> FieldValue {
253 let mut object = HashMap::new();
254 object.insert("type".to_string(), self.device_type().to_field_value());
255 object.insert("index".to_string(), self.index().to_field_value());
256 FieldValue::from_object(object)
257 }
258}
259
260impl FromFieldValue for Device {
261 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
272 let object = value
273 .as_object()
274 .map_err(|_| SerializationError::ValidationFailed {
275 field: field_name.to_string(),
276 message: "Expected object for device".to_string(),
277 })?;
278
279 let device_type = object
280 .get("type")
281 .ok_or_else(|| SerializationError::ValidationFailed {
282 field: field_name.to_string(),
283 message: "Missing device type field".to_string(),
284 })?
285 .clone();
286
287 let index = object
288 .get("index")
289 .ok_or_else(|| SerializationError::ValidationFailed {
290 field: field_name.to_string(),
291 message: "Missing device index field".to_string(),
292 })?
293 .clone();
294
295 let device_type = DeviceType::from_field_value(device_type, "type")?;
296 let index = usize::from_field_value(index, "index")?;
297
298 match device_type {
299 DeviceType::Cpu => Ok(Device::cpu()),
300 DeviceType::Cuda => Ok(Device::cuda(index)),
301 }
302 }
303}
304
305impl ToFieldValue for crate::tensor::MemoryLayout {
308 fn to_field_value(&self) -> FieldValue {
314 match self {
315 crate::tensor::MemoryLayout::Contiguous => {
316 FieldValue::from_enum_unit("Contiguous".to_string())
317 }
318 crate::tensor::core::MemoryLayout::Strided => {
319 FieldValue::from_enum_unit("Strided".to_string())
320 }
321 crate::tensor::core::MemoryLayout::View => {
322 FieldValue::from_enum_unit("View".to_string())
323 }
324 }
325 }
326}
327
328impl FromFieldValue for crate::tensor::MemoryLayout {
329 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
340 let (variant, data) =
341 value
342 .as_enum()
343 .map_err(|_| SerializationError::ValidationFailed {
344 field: field_name.to_string(),
345 message: "Expected enum value for memory layout".to_string(),
346 })?;
347
348 if data.is_some() {
350 return Err(SerializationError::ValidationFailed {
351 field: field_name.to_string(),
352 message: "MemoryLayout variants should not have associated data".to_string(),
353 });
354 }
355
356 match variant {
357 "Contiguous" => Ok(crate::tensor::MemoryLayout::Contiguous),
358 "Strided" => Ok(crate::tensor::MemoryLayout::Strided),
359 "View" => Ok(crate::tensor::MemoryLayout::View),
360 _ => Err(SerializationError::ValidationFailed {
361 field: field_name.to_string(),
362 message: format!("Unknown memory layout variant: {}", variant),
363 }),
364 }
365 }
366}
367
368impl ToFieldValue for Shape {
371 fn to_field_value(&self) -> FieldValue {
377 let mut object = HashMap::new();
378 object.insert("dims".to_string(), self.dims.to_field_value());
379 object.insert("size".to_string(), self.size.to_field_value());
380 object.insert("strides".to_string(), self.strides.to_field_value());
381 object.insert("layout".to_string(), self.layout.to_field_value());
382 FieldValue::from_object(object)
383 }
384}
385
386impl FromFieldValue for Shape {
387 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
398 let object = value
399 .as_object()
400 .map_err(|_| SerializationError::ValidationFailed {
401 field: field_name.to_string(),
402 message: "Expected object for shape".to_string(),
403 })?;
404
405 let dims = object
406 .get("dims")
407 .ok_or_else(|| SerializationError::ValidationFailed {
408 field: field_name.to_string(),
409 message: "Missing dims field in shape".to_string(),
410 })?
411 .clone();
412
413 let size = object
414 .get("size")
415 .ok_or_else(|| SerializationError::ValidationFailed {
416 field: field_name.to_string(),
417 message: "Missing size field in shape".to_string(),
418 })?
419 .clone();
420
421 let strides = object
422 .get("strides")
423 .ok_or_else(|| SerializationError::ValidationFailed {
424 field: field_name.to_string(),
425 message: "Missing strides field in shape".to_string(),
426 })?
427 .clone();
428
429 let layout = object
430 .get("layout")
431 .ok_or_else(|| SerializationError::ValidationFailed {
432 field: field_name.to_string(),
433 message: "Missing layout field in shape".to_string(),
434 })?
435 .clone();
436
437 let dims = Vec::<usize>::from_field_value(dims, "dims")?;
438 let size = usize::from_field_value(size, "size")?;
439 let strides = Vec::<usize>::from_field_value(strides, "strides")?;
440 let layout = crate::tensor::MemoryLayout::from_field_value(layout, "layout")?;
441
442 let expected_size: usize = dims.iter().product();
444 if size != expected_size {
445 return Err(SerializationError::ValidationFailed {
446 field: field_name.to_string(),
447 message: format!(
448 "Shape size {} doesn't match computed size {}",
449 size, expected_size
450 ),
451 });
452 }
453
454 if dims.len() != strides.len() {
455 return Err(SerializationError::ValidationFailed {
456 field: field_name.to_string(),
457 message: "Dimensions and strides must have same length".to_string(),
458 });
459 }
460
461 Ok(Shape {
462 dims,
463 size,
464 strides,
465 layout,
466 })
467 }
468}
469
470impl StructSerializable for Tensor {
473 fn to_serializer(&self) -> StructSerializer {
482 let data: Vec<f32> =
486 unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size()).to_vec() };
487
488 StructSerializer::new()
489 .field("data", &data)
490 .field("shape", self.shape())
491 .field("device", &self.device())
492 .field("requires_grad", &self.requires_grad())
493 }
494
495 fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
508 let data: Vec<f32> = deserializer.field("data")?;
509 let shape: Shape = deserializer.field("shape")?;
510 let device: Device = deserializer.field("device")?;
511 let requires_grad: bool = deserializer.field("requires_grad")?;
512
513 if data.len() != shape.size {
515 return Err(SerializationError::ValidationFailed {
516 field: "tensor".to_string(),
517 message: format!(
518 "Data length {} doesn't match shape size {}",
519 data.len(),
520 shape.size
521 ),
522 });
523 }
524
525 let mut tensor = Tensor::new_on_device(shape.dims.clone(), device);
527
528 if !data.is_empty() {
530 unsafe {
531 let dst = tensor.as_mut_ptr();
532 std::ptr::copy_nonoverlapping(data.as_ptr(), dst, data.len());
533 }
534 }
535
536 tensor.set_requires_grad(requires_grad);
538
539 if tensor.shape().dims != shape.dims
541 || tensor.shape().size != shape.size
542 || tensor.shape().strides != shape.strides
543 {
544 return Err(SerializationError::ValidationFailed {
545 field: "tensor".to_string(),
546 message: "Reconstructed tensor shape doesn't match serialized shape".to_string(),
547 });
548 }
549
550 Ok(tensor)
551 }
552}
553
554impl FromFieldValue for Tensor {
555 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
566 if let Ok(binary_data) = value.as_binary_object() {
568 return Tensor::from_binary(binary_data).map_err(|e| {
569 SerializationError::ValidationFailed {
570 field: field_name.to_string(),
571 message: format!("Failed to deserialize tensor from binary: {}", e),
572 }
573 });
574 }
575
576 if let Ok(json_data) = value.as_json_object() {
578 return Tensor::from_json(json_data).map_err(|e| {
579 SerializationError::ValidationFailed {
580 field: field_name.to_string(),
581 message: format!("Failed to deserialize tensor from JSON: {}", e),
582 }
583 });
584 }
585
586 if let Ok(object) = value.as_object() {
588 let mut deserializer = StructDeserializer::from_fields(object.clone());
590 return Tensor::from_deserializer(&mut deserializer).map_err(|e| {
591 SerializationError::ValidationFailed {
592 field: field_name.to_string(),
593 message: format!("Failed to deserialize tensor from object: {}", e),
594 }
595 });
596 }
597
598 Err(SerializationError::ValidationFailed {
599 field: field_name.to_string(),
600 message: "Expected binary object, JSON object, or structured object for tensor field"
601 .to_string(),
602 })
603 }
604}
605
606impl crate::serialization::Serializable for Tensor {
609 fn to_json(&self) -> SerializationResult<String> {
636 StructSerializable::to_json(self)
637 }
638
639 fn from_json(json: &str) -> SerializationResult<Self> {
671 StructSerializable::from_json(json)
672 }
673
674 fn to_binary(&self) -> SerializationResult<Vec<u8>> {
700 StructSerializable::to_binary(self)
701 }
702
703 fn from_binary(data: &[u8]) -> SerializationResult<Self> {
735 StructSerializable::from_binary(data)
736 }
737}
738
739#[cfg(test)]
740mod tests {
741 use super::*;
751
752 #[test]
755 fn test_device_type_serialization() {
756 let cpu_type = DeviceType::Cpu;
758 let field_value = cpu_type.to_field_value();
759 let deserialized = DeviceType::from_field_value(field_value, "device_type").unwrap();
760 assert_eq!(cpu_type, deserialized);
761
762 let cuda_type = DeviceType::Cuda;
764 let field_value = cuda_type.to_field_value();
765 let deserialized = DeviceType::from_field_value(field_value, "device_type").unwrap();
766 assert_eq!(cuda_type, deserialized);
767 }
768
769 #[test]
770 fn test_device_serialization() {
771 let cpu_device = Device::cpu();
773 let field_value = cpu_device.to_field_value();
774 let deserialized = Device::from_field_value(field_value, "device").unwrap();
775 assert_eq!(cpu_device, deserialized);
776 assert!(deserialized.is_cpu());
777 assert_eq!(deserialized.index(), 0);
778 }
779
780 #[test]
781 fn test_device_serialization_errors() {
782 let invalid_device_type = FieldValue::from_string("invalid".to_string());
784 let result = DeviceType::from_field_value(invalid_device_type, "device_type");
785 assert!(result.is_err());
786
787 let incomplete_device = FieldValue::from_object({
789 let mut obj = HashMap::new();
790 obj.insert(
791 "type".to_string(),
792 FieldValue::from_string("cpu".to_string()),
793 );
794 obj
796 });
797 let result = Device::from_field_value(incomplete_device, "device");
798 assert!(result.is_err());
799 }
800
801 #[test]
804 fn test_memory_layout_serialization() {
805 use crate::tensor::MemoryLayout;
806
807 let layouts = [
808 MemoryLayout::Contiguous,
809 MemoryLayout::Strided,
810 MemoryLayout::View,
811 ];
812
813 for layout in &layouts {
814 let field_value = layout.to_field_value();
815 let deserialized = MemoryLayout::from_field_value(field_value, "layout").unwrap();
816 assert_eq!(*layout, deserialized);
817 }
818 }
819
820 #[test]
821 fn test_shape_serialization() {
822 let shape = Shape::new(vec![2, 3, 4]);
824 let field_value = shape.to_field_value();
825 let deserialized = Shape::from_field_value(field_value, "shape").unwrap();
826 assert_eq!(shape, deserialized);
827 assert_eq!(deserialized.dims, vec![2, 3, 4]);
828 assert_eq!(deserialized.size, 24);
829 assert_eq!(deserialized.strides, vec![12, 4, 1]);
830
831 let strided_shape = Shape::with_strides(vec![2, 3], vec![6, 2]);
833 let field_value = strided_shape.to_field_value();
834 let deserialized = Shape::from_field_value(field_value, "shape").unwrap();
835 assert_eq!(strided_shape, deserialized);
836 }
837
838 #[test]
839 fn test_shape_validation_errors() {
840 use crate::tensor::MemoryLayout;
841
842 let invalid_shape = FieldValue::from_object({
844 let mut obj = HashMap::new();
845 obj.insert("dims".to_string(), vec![2usize, 3].to_field_value());
846 obj.insert("size".to_string(), 10usize.to_field_value()); obj.insert("strides".to_string(), vec![3usize, 1].to_field_value());
848 obj.insert(
849 "layout".to_string(),
850 MemoryLayout::Contiguous.to_field_value(),
851 );
852 obj
853 });
854 let result = Shape::from_field_value(invalid_shape, "shape");
855 assert!(result.is_err());
856
857 let invalid_shape = FieldValue::from_object({
859 let mut obj = HashMap::new();
860 obj.insert("dims".to_string(), vec![2usize, 3].to_field_value());
861 obj.insert("size".to_string(), 6usize.to_field_value());
862 obj.insert("strides".to_string(), vec![3usize].to_field_value()); obj.insert(
864 "layout".to_string(),
865 MemoryLayout::Contiguous.to_field_value(),
866 );
867 obj
868 });
869 let result = Shape::from_field_value(invalid_shape, "shape");
870 assert!(result.is_err());
871 }
872
873 #[test]
876 fn test_tensor_json_roundtrip() {
877 let mut tensor = Tensor::zeros(vec![2, 3]);
879 tensor.set(&[0, 0], 1.0);
880 tensor.set(&[0, 1], 2.0);
881 tensor.set(&[0, 2], 3.0);
882 tensor.set(&[1, 0], 4.0);
883 tensor.set(&[1, 1], 5.0);
884 tensor.set(&[1, 2], 6.0);
885 tensor.set_requires_grad(true);
886
887 let json = tensor.to_json().unwrap();
889 assert!(!json.is_empty());
890
891 let loaded_tensor = Tensor::from_json(&json).unwrap();
893
894 assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
896 assert_eq!(tensor.size(), loaded_tensor.size());
897 assert_eq!(tensor.device(), loaded_tensor.device());
898 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
899
900 for i in 0..2 {
902 for j in 0..3 {
903 assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
904 }
905 }
906 }
907
908 #[test]
909 fn test_tensor_binary_roundtrip() {
910 let mut tensor = Tensor::ones(vec![3, 4]).with_requires_grad();
912
913 tensor.set(&[0, 0], 10.0);
915 tensor.set(&[1, 2], 20.0);
916 tensor.set(&[2, 3], 30.0);
917
918 let binary = tensor.to_binary().unwrap();
920 assert!(!binary.is_empty());
921
922 let loaded_tensor = Tensor::from_binary(&binary).unwrap();
924
925 assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
927 assert_eq!(tensor.size(), loaded_tensor.size());
928 assert_eq!(tensor.device(), loaded_tensor.device());
929 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
930
931 for i in 0..3 {
933 for j in 0..4 {
934 assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
935 }
936 }
937 }
938
939 #[test]
940 fn test_empty_tensor_serialization() {
941 let tensor = Tensor::new(vec![0]);
943
944 let json = tensor.to_json().unwrap();
946 let loaded_tensor = Tensor::from_json(&json).unwrap();
947 assert_eq!(tensor.size(), loaded_tensor.size());
948 assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
949
950 let binary = tensor.to_binary().unwrap();
952 let loaded_tensor = Tensor::from_binary(&binary).unwrap();
953 assert_eq!(tensor.size(), loaded_tensor.size());
954 assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
955 }
956
957 #[test]
958 fn test_large_tensor_serialization() {
959 let mut tensor = Tensor::zeros(vec![100, 100]).with_requires_grad();
961
962 for i in 0..10 {
964 for j in 0..10 {
965 tensor.set(&[i, j], (i * 10 + j) as f32);
966 }
967 }
968
969 let binary = tensor.to_binary().unwrap();
971 let loaded_tensor = Tensor::from_binary(&binary).unwrap();
972
973 assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
975 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
976
977 for i in 0..10 {
979 for j in 0..10 {
980 assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
981 }
982 }
983 }
984
985 #[test]
986 fn test_tensor_as_field_in_struct() {
987 #[derive(Debug)]
989 struct ModelWeights {
990 weight_matrix: Tensor,
991 bias_vector: Tensor,
992 learning_rate: f32,
993 name: String,
994 }
995
996 impl StructSerializable for ModelWeights {
997 fn to_serializer(&self) -> StructSerializer {
998 StructSerializer::new()
999 .field("weight_matrix", &self.weight_matrix)
1000 .field("bias_vector", &self.bias_vector)
1001 .field("learning_rate", &self.learning_rate)
1002 .field("name", &self.name)
1003 }
1004
1005 fn from_deserializer(
1006 deserializer: &mut StructDeserializer,
1007 ) -> SerializationResult<Self> {
1008 Ok(ModelWeights {
1009 weight_matrix: deserializer.field("weight_matrix")?,
1010 bias_vector: deserializer.field("bias_vector")?,
1011 learning_rate: deserializer.field("learning_rate")?,
1012 name: deserializer.field("name")?,
1013 })
1014 }
1015 }
1016
1017 let mut weights = ModelWeights {
1019 weight_matrix: Tensor::zeros(vec![10, 5]),
1020 bias_vector: Tensor::ones(vec![5]).with_requires_grad(),
1021 learning_rate: 0.001,
1022 name: "test_model".to_string(),
1023 };
1024
1025 weights.weight_matrix.set(&[0, 0], 0.5);
1027 weights.weight_matrix.set(&[9, 4], -0.3);
1028 weights.bias_vector.set(&[2], 2.0);
1029
1030 let json = weights.to_json().unwrap();
1032 let loaded_weights = ModelWeights::from_json(&json).unwrap();
1033
1034 assert_eq!(weights.learning_rate, loaded_weights.learning_rate);
1035 assert_eq!(weights.name, loaded_weights.name);
1036 assert_eq!(
1037 weights.weight_matrix.shape().dims,
1038 loaded_weights.weight_matrix.shape().dims
1039 );
1040 assert_eq!(
1041 weights.bias_vector.shape().dims,
1042 loaded_weights.bias_vector.shape().dims
1043 );
1044 assert_eq!(
1045 weights.bias_vector.requires_grad(),
1046 loaded_weights.bias_vector.requires_grad()
1047 );
1048
1049 assert_eq!(
1051 weights.weight_matrix.get(&[0, 0]),
1052 loaded_weights.weight_matrix.get(&[0, 0])
1053 );
1054 assert_eq!(
1055 weights.weight_matrix.get(&[9, 4]),
1056 loaded_weights.weight_matrix.get(&[9, 4])
1057 );
1058 assert_eq!(
1059 weights.bias_vector.get(&[2]),
1060 loaded_weights.bias_vector.get(&[2])
1061 );
1062
1063 let binary = weights.to_binary().unwrap();
1065 let loaded_weights = ModelWeights::from_binary(&binary).unwrap();
1066
1067 assert_eq!(weights.learning_rate, loaded_weights.learning_rate);
1068 assert_eq!(weights.name, loaded_weights.name);
1069 assert_eq!(
1070 weights.weight_matrix.shape().dims,
1071 loaded_weights.weight_matrix.shape().dims
1072 );
1073 assert_eq!(
1074 weights.bias_vector.requires_grad(),
1075 loaded_weights.bias_vector.requires_grad()
1076 );
1077 }
1078
1079 #[test]
1080 fn test_multiple_tensors_in_struct() {
1081 #[derive(Debug)]
1083 struct MultiTensorStruct {
1084 tensor_1d: Tensor,
1085 tensor_2d: Tensor,
1086 tensor_3d: Tensor,
1087 metadata: HashMap<String, String>,
1088 }
1089
1090 impl StructSerializable for MultiTensorStruct {
1091 fn to_serializer(&self) -> StructSerializer {
1092 StructSerializer::new()
1093 .field("tensor_1d", &self.tensor_1d)
1094 .field("tensor_2d", &self.tensor_2d)
1095 .field("tensor_3d", &self.tensor_3d)
1096 .field("metadata", &self.metadata)
1097 }
1098
1099 fn from_deserializer(
1100 deserializer: &mut StructDeserializer,
1101 ) -> SerializationResult<Self> {
1102 Ok(MultiTensorStruct {
1103 tensor_1d: deserializer.field("tensor_1d")?,
1104 tensor_2d: deserializer.field("tensor_2d")?,
1105 tensor_3d: deserializer.field("tensor_3d")?,
1106 metadata: deserializer.field("metadata")?,
1107 })
1108 }
1109 }
1110
1111 let mut multi_tensor = MultiTensorStruct {
1113 tensor_1d: Tensor::zeros(vec![5]),
1114 tensor_2d: Tensor::ones(vec![3, 4]).with_requires_grad(),
1115 tensor_3d: Tensor::zeros(vec![2, 2, 2]),
1116 metadata: {
1117 let mut map = HashMap::new();
1118 map.insert("version".to_string(), "1.0".to_string());
1119 map.insert("type".to_string(), "test".to_string());
1120 map
1121 },
1122 };
1123
1124 multi_tensor.tensor_1d.set(&[0], 10.0);
1126 multi_tensor.tensor_2d.set(&[0, 0], 5.0);
1127 multi_tensor.tensor_3d.set(&[1, 1, 1], 3.0);
1128
1129 let json = multi_tensor.to_json().unwrap();
1131 let loaded = MultiTensorStruct::from_json(&json).unwrap();
1132
1133 assert_eq!(
1134 multi_tensor.tensor_1d.shape().dims,
1135 loaded.tensor_1d.shape().dims
1136 );
1137 assert_eq!(
1138 multi_tensor.tensor_2d.shape().dims,
1139 loaded.tensor_2d.shape().dims
1140 );
1141 assert_eq!(
1142 multi_tensor.tensor_3d.shape().dims,
1143 loaded.tensor_3d.shape().dims
1144 );
1145 assert_eq!(
1146 multi_tensor.tensor_2d.requires_grad(),
1147 loaded.tensor_2d.requires_grad()
1148 );
1149 assert_eq!(multi_tensor.metadata, loaded.metadata);
1150
1151 assert_eq!(multi_tensor.tensor_1d.get(&[0]), loaded.tensor_1d.get(&[0]));
1153 assert_eq!(
1154 multi_tensor.tensor_2d.get(&[0, 0]),
1155 loaded.tensor_2d.get(&[0, 0])
1156 );
1157 assert_eq!(
1158 multi_tensor.tensor_3d.get(&[1, 1, 1]),
1159 loaded.tensor_3d.get(&[1, 1, 1])
1160 );
1161
1162 let binary = multi_tensor.to_binary().unwrap();
1164 let loaded = MultiTensorStruct::from_binary(&binary).unwrap();
1165 assert_eq!(
1166 multi_tensor.tensor_1d.shape().dims,
1167 loaded.tensor_1d.shape().dims
1168 );
1169 assert_eq!(
1170 multi_tensor.tensor_2d.requires_grad(),
1171 loaded.tensor_2d.requires_grad()
1172 );
1173 }
1174
1175 #[test]
1176 fn test_tensor_serialization_errors() {
1177 let mut deserializer = StructDeserializer::from_json(
1179 r#"
1180 {
1181 "data": [1.0, 2.0, 3.0],
1182 "shape": {
1183 "dims": [2, 3],
1184 "size": 6,
1185 "strides": [3, 1],
1186 "layout": "contiguous"
1187 },
1188 "device": {"type": "cpu", "index": 0},
1189 "requires_grad": false
1190 }"#,
1191 )
1192 .unwrap();
1193
1194 let result = Tensor::from_deserializer(&mut deserializer);
1195 assert!(result.is_err()); }
1197
1198 #[test]
1199 fn test_field_value_tensor_roundtrip() {
1200 let mut tensor = Tensor::zeros(vec![2, 2]);
1202 tensor.set(&[0, 0], 1.0);
1203 tensor.set(&[1, 1], 2.0);
1204
1205 let field_value = tensor.to_field_value();
1206 let loaded_tensor = Tensor::from_field_value(field_value, "test_tensor").unwrap();
1207
1208 assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
1209 assert_eq!(tensor.get(&[0, 0]), loaded_tensor.get(&[0, 0]));
1210 assert_eq!(tensor.get(&[1, 1]), loaded_tensor.get(&[1, 1]));
1211 }
1212
1213 #[test]
1214 fn test_different_tensor_shapes() {
1215 let test_shapes = vec![
1216 vec![1], vec![10], vec![3, 4], vec![2, 3, 4], vec![2, 2, 2, 2], ];
1222
1223 for shape in test_shapes {
1224 let tensor = Tensor::zeros(shape.clone()).with_requires_grad();
1225
1226 let json = tensor.to_json().unwrap();
1228 let loaded = Tensor::from_json(&json).unwrap();
1229 assert_eq!(tensor.shape().dims, loaded.shape().dims);
1230 assert_eq!(tensor.requires_grad(), loaded.requires_grad());
1231
1232 let binary = tensor.to_binary().unwrap();
1234 let loaded = Tensor::from_binary(&binary).unwrap();
1235 assert_eq!(tensor.shape().dims, loaded.shape().dims);
1236 assert_eq!(tensor.requires_grad(), loaded.requires_grad());
1237 }
1238 }
1239
1240 #[test]
1243 fn test_serializable_json_methods() {
1244 let mut tensor = Tensor::zeros(vec![2, 3]);
1246 tensor.set(&[0, 0], 1.0);
1247 tensor.set(&[0, 1], 2.0);
1248 tensor.set(&[1, 2], 5.0);
1249 tensor.set_requires_grad(true);
1250
1251 let json = <Tensor as crate::serialization::Serializable>::to_json(&tensor).unwrap();
1253 assert!(!json.is_empty());
1254 assert!(json.contains("data"));
1255 assert!(json.contains("shape"));
1256 assert!(json.contains("device"));
1257 assert!(json.contains("requires_grad"));
1258
1259 let restored = <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1261 assert_eq!(tensor.shape().dims, restored.shape().dims);
1262 assert_eq!(tensor.size(), restored.size());
1263 assert_eq!(tensor.device(), restored.device());
1264 assert_eq!(tensor.requires_grad(), restored.requires_grad());
1265
1266 assert_eq!(tensor.get(&[0, 0]), restored.get(&[0, 0]));
1268 assert_eq!(tensor.get(&[0, 1]), restored.get(&[0, 1]));
1269 assert_eq!(tensor.get(&[1, 2]), restored.get(&[1, 2]));
1270 }
1271
1272 #[test]
1273 fn test_serializable_binary_methods() {
1274 let mut tensor = Tensor::ones(vec![3, 4]);
1276 tensor.set(&[0, 0], 10.0);
1277 tensor.set(&[1, 2], 20.0);
1278 tensor.set(&[2, 3], 30.0);
1279 tensor.set_requires_grad(true);
1280
1281 let binary = <Tensor as crate::serialization::Serializable>::to_binary(&tensor).unwrap();
1283 assert!(!binary.is_empty());
1284
1285 let restored =
1287 <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1288 assert_eq!(tensor.shape().dims, restored.shape().dims);
1289 assert_eq!(tensor.size(), restored.size());
1290 assert_eq!(tensor.device(), restored.device());
1291 assert_eq!(tensor.requires_grad(), restored.requires_grad());
1292
1293 assert_eq!(tensor.get(&[0, 0]), restored.get(&[0, 0]));
1295 assert_eq!(tensor.get(&[1, 2]), restored.get(&[1, 2]));
1296 assert_eq!(tensor.get(&[2, 3]), restored.get(&[2, 3]));
1297 }
1298
1299 #[test]
1300 fn test_serializable_file_io_json() {
1301 use crate::serialization::{Format, Serializable};
1302 use std::fs;
1303 use std::path::Path;
1304
1305 let mut tensor = Tensor::zeros(vec![2, 2]);
1307 tensor.set(&[0, 0], 1.0);
1308 tensor.set(&[0, 1], 2.0);
1309 tensor.set(&[1, 0], 3.0);
1310 tensor.set(&[1, 1], 4.0);
1311 tensor.set_requires_grad(true);
1312
1313 let json_path = "test_tensor_serializable.json";
1315 let json_path_2 = "test_tensor_serializable_2.json";
1316
1317 let _ = fs::remove_file(json_path);
1319 let _ = fs::remove_file(json_path_2);
1320
1321 Serializable::save(&tensor, json_path, Format::Json).unwrap();
1323 assert!(Path::new(json_path).exists());
1324
1325 let loaded_tensor = Tensor::load(json_path, Format::Json).unwrap();
1327 assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
1328 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1329 assert_eq!(tensor.get(&[0, 0]), loaded_tensor.get(&[0, 0]));
1330 assert_eq!(tensor.get(&[1, 1]), loaded_tensor.get(&[1, 1]));
1331
1332 {
1334 let mut writer = std::fs::File::create(json_path_2).unwrap();
1335 Serializable::save_to_writer(&tensor, &mut writer, Format::Json).unwrap();
1336 }
1337 assert!(Path::new(json_path_2).exists());
1338
1339 {
1340 let mut reader = std::fs::File::open(json_path_2).unwrap();
1341 let loaded_tensor = Tensor::load_from_reader(&mut reader, Format::Json).unwrap();
1342 assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
1343 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1344 assert_eq!(tensor.get(&[0, 1]), loaded_tensor.get(&[0, 1]));
1345 assert_eq!(tensor.get(&[1, 0]), loaded_tensor.get(&[1, 0]));
1346 }
1347
1348 let _ = fs::remove_file(json_path);
1350 let _ = fs::remove_file(json_path_2);
1351 }
1352
1353 #[test]
1354 fn test_serializable_file_io_binary() {
1355 use crate::serialization::{Format, Serializable};
1356 use std::fs;
1357 use std::path::Path;
1358
1359 let mut tensor = Tensor::ones(vec![3, 3]);
1361 for i in 0..3 {
1362 for j in 0..3 {
1363 tensor.set(&[i, j], (i * 3 + j) as f32);
1364 }
1365 }
1366 tensor.set_requires_grad(true);
1367
1368 let binary_path = "test_tensor_serializable.bin";
1370 let binary_path_2 = "test_tensor_serializable_2.bin";
1371
1372 let _ = fs::remove_file(binary_path);
1374 let _ = fs::remove_file(binary_path_2);
1375
1376 Serializable::save(&tensor, binary_path, Format::Binary).unwrap();
1378 assert!(Path::new(binary_path).exists());
1379
1380 let loaded_tensor = Tensor::load(binary_path, Format::Binary).unwrap();
1382 assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
1383 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1384
1385 for i in 0..3 {
1387 for j in 0..3 {
1388 assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
1389 }
1390 }
1391
1392 {
1394 let mut writer = std::fs::File::create(binary_path_2).unwrap();
1395 Serializable::save_to_writer(&tensor, &mut writer, Format::Binary).unwrap();
1396 }
1397 assert!(Path::new(binary_path_2).exists());
1398
1399 {
1400 let mut reader = std::fs::File::open(binary_path_2).unwrap();
1401 let loaded_tensor = Tensor::load_from_reader(&mut reader, Format::Binary).unwrap();
1402 assert_eq!(tensor.shape().dims, loaded_tensor.shape().dims);
1403 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1404
1405 for i in 0..3 {
1407 for j in 0..3 {
1408 assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
1409 }
1410 }
1411 }
1412
1413 let _ = fs::remove_file(binary_path);
1415 let _ = fs::remove_file(binary_path_2);
1416 }
1417
1418 #[test]
1419 fn test_serializable_large_tensor_performance() {
1420 let mut tensor = Tensor::zeros(vec![50, 50]);
1422 for i in 0..25 {
1423 for j in 0..25 {
1424 tensor.set(&[i, j], (i * 25 + j) as f32);
1425 }
1426 }
1427 tensor.set_requires_grad(true);
1428
1429 let json = <Tensor as crate::serialization::Serializable>::to_json(&tensor).unwrap();
1431 assert!(!json.is_empty());
1432 let restored_json =
1433 <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1434 assert_eq!(tensor.shape().dims, restored_json.shape().dims);
1435 assert_eq!(tensor.requires_grad(), restored_json.requires_grad());
1436
1437 let binary = <Tensor as crate::serialization::Serializable>::to_binary(&tensor).unwrap();
1439 assert!(!binary.is_empty());
1440 println!(
1442 "JSON size: {} bytes, Binary size: {} bytes",
1443 json.len(),
1444 binary.len()
1445 );
1446
1447 let restored_binary =
1448 <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1449 assert_eq!(tensor.shape().dims, restored_binary.shape().dims);
1450 assert_eq!(tensor.requires_grad(), restored_binary.requires_grad());
1451
1452 for i in 0..5 {
1454 for j in 0..5 {
1455 assert_eq!(tensor.get(&[i, j]), restored_json.get(&[i, j]));
1456 assert_eq!(tensor.get(&[i, j]), restored_binary.get(&[i, j]));
1457 }
1458 }
1459 }
1460
1461 #[test]
1462 fn test_serializable_error_handling() {
1463 let invalid_json = r#"{"invalid": "json", "structure": true}"#;
1465 let result = <Tensor as crate::serialization::Serializable>::from_json(invalid_json);
1466 assert!(result.is_err());
1467
1468 let empty_json = "{}";
1470 let result = <Tensor as crate::serialization::Serializable>::from_json(empty_json);
1471 assert!(result.is_err());
1472
1473 let invalid_binary = vec![1, 2, 3, 4, 5];
1475 let result = <Tensor as crate::serialization::Serializable>::from_binary(&invalid_binary);
1476 assert!(result.is_err());
1477
1478 let empty_binary = vec![];
1480 let result = <Tensor as crate::serialization::Serializable>::from_binary(&empty_binary);
1481 assert!(result.is_err());
1482 }
1483
1484 #[test]
1485 fn test_serializable_different_shapes_and_types() {
1486 let test_cases = vec![
1487 (vec![1], vec![42.0]),
1489 (vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]),
1491 (vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
1493 (vec![2, 2, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
1495 ];
1496
1497 for (shape, expected_data) in test_cases {
1498 let mut tensor = Tensor::zeros(shape.clone());
1500
1501 match shape.len() {
1503 1 => {
1504 for (i, &value) in expected_data.iter().enumerate().take(shape[0]) {
1505 tensor.set(&[i], value);
1506 }
1507 }
1508 2 => {
1509 let mut idx = 0;
1510 for i in 0..shape[0] {
1511 for j in 0..shape[1] {
1512 if idx < expected_data.len() {
1513 tensor.set(&[i, j], expected_data[idx]);
1514 idx += 1;
1515 }
1516 }
1517 }
1518 }
1519 3 => {
1520 let mut idx = 0;
1521 for i in 0..shape[0] {
1522 for j in 0..shape[1] {
1523 for k in 0..shape[2] {
1524 if idx < expected_data.len() {
1525 tensor.set(&[i, j, k], expected_data[idx]);
1526 idx += 1;
1527 }
1528 }
1529 }
1530 }
1531 }
1532 _ => {}
1533 }
1534 tensor.set_requires_grad(true);
1535
1536 let json = <Tensor as crate::serialization::Serializable>::to_json(&tensor).unwrap();
1538 let restored_json =
1539 <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1540 assert_eq!(tensor.shape().dims, restored_json.shape().dims);
1541 assert_eq!(tensor.requires_grad(), restored_json.requires_grad());
1542
1543 let binary =
1545 <Tensor as crate::serialization::Serializable>::to_binary(&tensor).unwrap();
1546 let restored_binary =
1547 <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1548 assert_eq!(tensor.shape().dims, restored_binary.shape().dims);
1549 assert_eq!(tensor.requires_grad(), restored_binary.requires_grad());
1550
1551 match shape.len() {
1553 1 => {
1554 for i in 0..shape[0].min(3).min(expected_data.len()) {
1555 assert_eq!(tensor.get(&[i]), restored_json.get(&[i]));
1556 assert_eq!(tensor.get(&[i]), restored_binary.get(&[i]));
1557 }
1558 }
1559 2 => {
1560 let mut count = 0;
1561 for i in 0..shape[0] {
1562 for j in 0..shape[1] {
1563 if count < 3 && count < expected_data.len() {
1564 assert_eq!(tensor.get(&[i, j]), restored_json.get(&[i, j]));
1565 assert_eq!(tensor.get(&[i, j]), restored_binary.get(&[i, j]));
1566 count += 1;
1567 }
1568 }
1569 }
1570 }
1571 3 => {
1572 let mut count = 0;
1573 for i in 0..shape[0] {
1574 for j in 0..shape[1] {
1575 for k in 0..shape[2] {
1576 if count < 3 && count < expected_data.len() {
1577 assert_eq!(
1578 tensor.get(&[i, j, k]),
1579 restored_json.get(&[i, j, k])
1580 );
1581 assert_eq!(
1582 tensor.get(&[i, j, k]),
1583 restored_binary.get(&[i, j, k])
1584 );
1585 count += 1;
1586 }
1587 }
1588 }
1589 }
1590 }
1591 _ => {}
1592 }
1593 }
1594 }
1595
1596 #[test]
1597 fn test_serializable_edge_cases() {
1598 let zero_tensor = Tensor::new(vec![0]);
1600 let json = <Tensor as crate::serialization::Serializable>::to_json(&zero_tensor).unwrap();
1601 let restored = <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1602 assert_eq!(zero_tensor.shape().dims, restored.shape().dims);
1603 assert_eq!(zero_tensor.size(), restored.size());
1604
1605 let binary =
1606 <Tensor as crate::serialization::Serializable>::to_binary(&zero_tensor).unwrap();
1607 let restored =
1608 <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1609 assert_eq!(zero_tensor.shape().dims, restored.shape().dims);
1610 assert_eq!(zero_tensor.size(), restored.size());
1611
1612 let mut special_tensor = Tensor::zeros(vec![3]);
1614 special_tensor.set(&[0], 0.0); special_tensor.set(&[1], 1000000.0); special_tensor.set(&[2], -1000000.0); let json =
1619 <Tensor as crate::serialization::Serializable>::to_json(&special_tensor).unwrap();
1620 let restored = <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1621 assert_eq!(special_tensor.get(&[0]), restored.get(&[0]));
1622 assert_eq!(special_tensor.get(&[1]), restored.get(&[1]));
1623 assert_eq!(special_tensor.get(&[2]), restored.get(&[2]));
1624
1625 let binary =
1626 <Tensor as crate::serialization::Serializable>::to_binary(&special_tensor).unwrap();
1627 let restored =
1628 <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1629 assert_eq!(special_tensor.get(&[0]), restored.get(&[0]));
1630 assert_eq!(special_tensor.get(&[1]), restored.get(&[1]));
1631 assert_eq!(special_tensor.get(&[2]), restored.get(&[2]));
1632 }
1633}