1use std::collections::HashMap;
183
184use crate::device::{Device, DeviceType};
185use crate::serialization::{
186 FieldValue, FromFieldValue, SerializationError, SerializationResult, StructDeserializer,
187 StructSerializable, StructSerializer, ToFieldValue,
188};
189use crate::tensor::{Shape, Tensor};
190
191impl ToFieldValue for DeviceType {
194 fn to_field_value(&self) -> FieldValue {
200 match self {
201 DeviceType::Cpu => FieldValue::from_enum_unit("Cpu".to_string()),
202 DeviceType::Cuda => FieldValue::from_enum_unit("Cuda".to_string()),
203 }
204 }
205}
206
207impl FromFieldValue for DeviceType {
208 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
219 let (variant, data) =
220 value
221 .as_enum()
222 .map_err(|_| SerializationError::ValidationFailed {
223 field: field_name.to_string(),
224 message: "Expected enum value for device type".to_string(),
225 })?;
226
227 if data.is_some() {
229 return Err(SerializationError::ValidationFailed {
230 field: field_name.to_string(),
231 message: "DeviceType variants should not have associated data".to_string(),
232 });
233 }
234
235 match variant {
236 "Cpu" => Ok(DeviceType::Cpu),
237 "Cuda" => Ok(DeviceType::Cuda),
238 _ => Err(SerializationError::ValidationFailed {
239 field: field_name.to_string(),
240 message: format!("Unknown device type variant: {}", variant),
241 }),
242 }
243 }
244}
245
246impl ToFieldValue for Device {
247 fn to_field_value(&self) -> FieldValue {
253 let mut object = HashMap::new();
254 object.insert("type".to_string(), self.device_type().to_field_value());
255 object.insert("index".to_string(), self.index().to_field_value());
256 FieldValue::from_object(object)
257 }
258}
259
260impl FromFieldValue for Device {
261 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
272 let object = value
273 .as_object()
274 .map_err(|_| SerializationError::ValidationFailed {
275 field: field_name.to_string(),
276 message: "Expected object for device".to_string(),
277 })?;
278
279 let device_type = object
280 .get("type")
281 .ok_or_else(|| SerializationError::ValidationFailed {
282 field: field_name.to_string(),
283 message: "Missing device type field".to_string(),
284 })?
285 .clone();
286
287 let index = object
288 .get("index")
289 .ok_or_else(|| SerializationError::ValidationFailed {
290 field: field_name.to_string(),
291 message: "Missing device index field".to_string(),
292 })?
293 .clone();
294
295 let device_type = DeviceType::from_field_value(device_type, "type")?;
296 let index = usize::from_field_value(index, "index")?;
297
298 match device_type {
299 DeviceType::Cpu => Ok(Device::cpu()),
300 DeviceType::Cuda => Ok(Device::cuda(index)),
301 }
302 }
303}
304
305impl ToFieldValue for crate::tensor::MemoryLayout {
308 fn to_field_value(&self) -> FieldValue {
314 match self {
315 crate::tensor::MemoryLayout::Contiguous => {
316 FieldValue::from_enum_unit("Contiguous".to_string())
317 }
318 crate::tensor::core::MemoryLayout::Strided => {
319 FieldValue::from_enum_unit("Strided".to_string())
320 }
321 crate::tensor::core::MemoryLayout::View => {
322 FieldValue::from_enum_unit("View".to_string())
323 }
324 }
325 }
326}
327
328impl FromFieldValue for crate::tensor::MemoryLayout {
329 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
340 let (variant, data) =
341 value
342 .as_enum()
343 .map_err(|_| SerializationError::ValidationFailed {
344 field: field_name.to_string(),
345 message: "Expected enum value for memory layout".to_string(),
346 })?;
347
348 if data.is_some() {
350 return Err(SerializationError::ValidationFailed {
351 field: field_name.to_string(),
352 message: "MemoryLayout variants should not have associated data".to_string(),
353 });
354 }
355
356 match variant {
357 "Contiguous" => Ok(crate::tensor::MemoryLayout::Contiguous),
358 "Strided" => Ok(crate::tensor::MemoryLayout::Strided),
359 "View" => Ok(crate::tensor::MemoryLayout::View),
360 _ => Err(SerializationError::ValidationFailed {
361 field: field_name.to_string(),
362 message: format!("Unknown memory layout variant: {}", variant),
363 }),
364 }
365 }
366}
367
368impl ToFieldValue for Shape {
371 fn to_field_value(&self) -> FieldValue {
377 let mut object = HashMap::new();
378 object.insert("dims".to_string(), self.dims().to_vec().to_field_value());
379 object.insert("size".to_string(), self.size().to_field_value());
380 object.insert(
381 "strides".to_string(),
382 self.strides().to_vec().to_field_value(),
383 );
384 object.insert("layout".to_string(), self.layout().to_field_value());
385 FieldValue::from_object(object)
386 }
387}
388
389impl FromFieldValue for Shape {
390 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
401 let object = value
402 .as_object()
403 .map_err(|_| SerializationError::ValidationFailed {
404 field: field_name.to_string(),
405 message: "Expected object for shape".to_string(),
406 })?;
407
408 let dims = object
409 .get("dims")
410 .ok_or_else(|| SerializationError::ValidationFailed {
411 field: field_name.to_string(),
412 message: "Missing dims field in shape".to_string(),
413 })?
414 .clone();
415
416 let size = object
417 .get("size")
418 .ok_or_else(|| SerializationError::ValidationFailed {
419 field: field_name.to_string(),
420 message: "Missing size field in shape".to_string(),
421 })?
422 .clone();
423
424 let strides = object
425 .get("strides")
426 .ok_or_else(|| SerializationError::ValidationFailed {
427 field: field_name.to_string(),
428 message: "Missing strides field in shape".to_string(),
429 })?
430 .clone();
431
432 let layout = object
433 .get("layout")
434 .ok_or_else(|| SerializationError::ValidationFailed {
435 field: field_name.to_string(),
436 message: "Missing layout field in shape".to_string(),
437 })?
438 .clone();
439
440 let dims = Vec::<usize>::from_field_value(dims, "dims")?;
441 let size = usize::from_field_value(size, "size")?;
442 let strides = Vec::<usize>::from_field_value(strides, "strides")?;
443 let layout = crate::tensor::MemoryLayout::from_field_value(layout, "layout")?;
444
445 let expected_size: usize = dims.iter().product();
447 if size != expected_size {
448 return Err(SerializationError::ValidationFailed {
449 field: field_name.to_string(),
450 message: format!(
451 "Shape size {} doesn't match computed size {}",
452 size, expected_size
453 ),
454 });
455 }
456
457 if dims.len() != strides.len() {
458 return Err(SerializationError::ValidationFailed {
459 field: field_name.to_string(),
460 message: "Dimensions and strides must have same length".to_string(),
461 });
462 }
463
464 let mut shape = match layout {
466 crate::tensor::MemoryLayout::Contiguous => Shape::new(dims),
467 _ => Shape::with_strides(dims, strides),
468 };
469
470 if matches!(layout, crate::tensor::MemoryLayout::View) {
472 shape = Shape::as_view(shape.dims().to_vec(), shape.strides().to_vec());
473 }
474
475 Ok(shape)
476 }
477}
478
479impl StructSerializable for Tensor {
482 fn to_serializer(&self) -> StructSerializer {
491 let data: Vec<f32> =
495 unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size()).to_vec() };
496
497 StructSerializer::new()
498 .field("data", &data)
499 .field("shape", self.shape())
500 .field("device", &self.device())
501 .field("requires_grad", &self.requires_grad())
502 }
503
504 fn from_deserializer(deserializer: &mut StructDeserializer) -> SerializationResult<Self> {
517 let data: Vec<f32> = deserializer.field("data")?;
518 let shape: Shape = deserializer.field("shape")?;
519 let device: Device = deserializer.field("device")?;
520 let requires_grad: bool = deserializer.field("requires_grad")?;
521
522 if data.len() != shape.size() {
524 return Err(SerializationError::ValidationFailed {
525 field: "tensor".to_string(),
526 message: format!(
527 "Data length {} doesn't match shape size {}",
528 data.len(),
529 shape.size()
530 ),
531 });
532 }
533
534 let mut tensor = Tensor::new_on_device(shape.dims().to_vec(), device);
536
537 if !data.is_empty() {
539 unsafe {
540 let dst = tensor.as_mut_ptr();
541 std::ptr::copy_nonoverlapping(data.as_ptr(), dst, data.len());
542 }
543 }
544
545 tensor.set_requires_grad(requires_grad);
547
548 if tensor.shape().dims() != shape.dims()
550 || tensor.shape().size() != shape.size()
551 || tensor.shape().strides() != shape.strides()
552 {
553 return Err(SerializationError::ValidationFailed {
554 field: "tensor".to_string(),
555 message: "Reconstructed tensor shape doesn't match serialized shape".to_string(),
556 });
557 }
558
559 Ok(tensor)
560 }
561}
562
563impl FromFieldValue for Tensor {
564 fn from_field_value(value: FieldValue, field_name: &str) -> SerializationResult<Self> {
575 if let Ok(binary_data) = value.as_binary_object() {
577 return Tensor::from_binary(binary_data).map_err(|e| {
578 SerializationError::ValidationFailed {
579 field: field_name.to_string(),
580 message: format!("Failed to deserialize tensor from binary: {}", e),
581 }
582 });
583 }
584
585 if let Ok(json_data) = value.as_json_object() {
587 return Tensor::from_json(json_data).map_err(|e| {
588 SerializationError::ValidationFailed {
589 field: field_name.to_string(),
590 message: format!("Failed to deserialize tensor from JSON: {}", e),
591 }
592 });
593 }
594
595 if let Ok(object) = value.as_object() {
597 let mut deserializer = StructDeserializer::from_fields(object.clone());
599 return Tensor::from_deserializer(&mut deserializer).map_err(|e| {
600 SerializationError::ValidationFailed {
601 field: field_name.to_string(),
602 message: format!("Failed to deserialize tensor from object: {}", e),
603 }
604 });
605 }
606
607 Err(SerializationError::ValidationFailed {
608 field: field_name.to_string(),
609 message: "Expected binary object, JSON object, or structured object for tensor field"
610 .to_string(),
611 })
612 }
613}
614
615impl crate::serialization::Serializable for Tensor {
618 fn to_json(&self) -> SerializationResult<String> {
645 StructSerializable::to_json(self)
646 }
647
648 fn from_json(json: &str) -> SerializationResult<Self> {
680 StructSerializable::from_json(json)
681 }
682
683 fn to_binary(&self) -> SerializationResult<Vec<u8>> {
709 StructSerializable::to_binary(self)
710 }
711
712 fn from_binary(data: &[u8]) -> SerializationResult<Self> {
744 StructSerializable::from_binary(data)
745 }
746}
747
748#[cfg(test)]
749mod tests {
750 use super::*;
760
761 #[test]
764 fn test_device_type_serialization() {
765 let cpu_type = DeviceType::Cpu;
767 let field_value = cpu_type.to_field_value();
768 let deserialized = DeviceType::from_field_value(field_value, "device_type").unwrap();
769 assert_eq!(cpu_type, deserialized);
770
771 let cuda_type = DeviceType::Cuda;
773 let field_value = cuda_type.to_field_value();
774 let deserialized = DeviceType::from_field_value(field_value, "device_type").unwrap();
775 assert_eq!(cuda_type, deserialized);
776 }
777
778 #[test]
779 fn test_device_serialization() {
780 let cpu_device = Device::cpu();
782 let field_value = cpu_device.to_field_value();
783 let deserialized = Device::from_field_value(field_value, "device").unwrap();
784 assert_eq!(cpu_device, deserialized);
785 assert!(deserialized.is_cpu());
786 assert_eq!(deserialized.index(), 0);
787 }
788
789 #[test]
790 fn test_device_serialization_errors() {
791 let invalid_device_type = FieldValue::from_string("invalid".to_string());
793 let result = DeviceType::from_field_value(invalid_device_type, "device_type");
794 assert!(result.is_err());
795
796 let incomplete_device = FieldValue::from_object({
798 let mut obj = HashMap::new();
799 obj.insert(
800 "type".to_string(),
801 FieldValue::from_string("cpu".to_string()),
802 );
803 obj
805 });
806 let result = Device::from_field_value(incomplete_device, "device");
807 assert!(result.is_err());
808 }
809
810 #[test]
813 fn test_memory_layout_serialization() {
814 use crate::tensor::MemoryLayout;
815
816 let layouts = [
817 MemoryLayout::Contiguous,
818 MemoryLayout::Strided,
819 MemoryLayout::View,
820 ];
821
822 for layout in &layouts {
823 let field_value = layout.to_field_value();
824 let deserialized = MemoryLayout::from_field_value(field_value, "layout").unwrap();
825 assert_eq!(*layout, deserialized);
826 }
827 }
828
829 #[test]
830 fn test_shape_serialization() {
831 let shape = Shape::new(vec![2, 3, 4]);
833 let field_value = shape.to_field_value();
834 let deserialized = Shape::from_field_value(field_value, "shape").unwrap();
835 assert_eq!(shape, deserialized);
836 assert_eq!(deserialized.dims(), vec![2, 3, 4]);
837 assert_eq!(deserialized.size(), 24);
838 assert_eq!(deserialized.strides(), vec![12, 4, 1]);
839
840 let strided_shape = Shape::with_strides(vec![2, 3], vec![6, 2]);
842 let field_value = strided_shape.to_field_value();
843 let deserialized = Shape::from_field_value(field_value, "shape").unwrap();
844 assert_eq!(strided_shape, deserialized);
845 }
846
847 #[test]
848 fn test_shape_validation_errors() {
849 use crate::tensor::MemoryLayout;
850
851 let invalid_shape = FieldValue::from_object({
853 let mut obj = HashMap::new();
854 obj.insert("dims".to_string(), vec![2usize, 3].to_field_value());
855 obj.insert("size".to_string(), 10usize.to_field_value()); obj.insert("strides".to_string(), vec![3usize, 1].to_field_value());
857 obj.insert(
858 "layout".to_string(),
859 MemoryLayout::Contiguous.to_field_value(),
860 );
861 obj
862 });
863 let result = Shape::from_field_value(invalid_shape, "shape");
864 assert!(result.is_err());
865
866 let invalid_shape = FieldValue::from_object({
868 let mut obj = HashMap::new();
869 obj.insert("dims".to_string(), vec![2usize, 3].to_field_value());
870 obj.insert("size".to_string(), 6usize.to_field_value());
871 obj.insert("strides".to_string(), vec![3usize].to_field_value()); obj.insert(
873 "layout".to_string(),
874 MemoryLayout::Contiguous.to_field_value(),
875 );
876 obj
877 });
878 let result = Shape::from_field_value(invalid_shape, "shape");
879 assert!(result.is_err());
880 }
881
882 #[test]
885 fn test_tensor_json_roundtrip() {
886 let mut tensor = Tensor::zeros(vec![2, 3]);
888 tensor.set(&[0, 0], 1.0);
889 tensor.set(&[0, 1], 2.0);
890 tensor.set(&[0, 2], 3.0);
891 tensor.set(&[1, 0], 4.0);
892 tensor.set(&[1, 1], 5.0);
893 tensor.set(&[1, 2], 6.0);
894 tensor.set_requires_grad(true);
895
896 let json = tensor.to_json().unwrap();
898 assert!(!json.is_empty());
899
900 let loaded_tensor = Tensor::from_json(&json).unwrap();
902
903 assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
905 assert_eq!(tensor.size(), loaded_tensor.size());
906 assert_eq!(tensor.device(), loaded_tensor.device());
907 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
908
909 for i in 0..2 {
911 for j in 0..3 {
912 assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
913 }
914 }
915 }
916
917 #[test]
918 fn test_tensor_binary_roundtrip() {
919 let mut tensor = Tensor::ones(vec![3, 4]).with_requires_grad();
921
922 tensor.set(&[0, 0], 10.0);
924 tensor.set(&[1, 2], 20.0);
925 tensor.set(&[2, 3], 30.0);
926
927 let binary = tensor.to_binary().unwrap();
929 assert!(!binary.is_empty());
930
931 let loaded_tensor = Tensor::from_binary(&binary).unwrap();
933
934 assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
936 assert_eq!(tensor.size(), loaded_tensor.size());
937 assert_eq!(tensor.device(), loaded_tensor.device());
938 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
939
940 for i in 0..3 {
942 for j in 0..4 {
943 assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
944 }
945 }
946 }
947
948 #[test]
949 fn test_empty_tensor_serialization() {
950 let tensor = Tensor::new(vec![0]);
952
953 let json = tensor.to_json().unwrap();
955 let loaded_tensor = Tensor::from_json(&json).unwrap();
956 assert_eq!(tensor.size(), loaded_tensor.size());
957 assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
958
959 let binary = tensor.to_binary().unwrap();
961 let loaded_tensor = Tensor::from_binary(&binary).unwrap();
962 assert_eq!(tensor.size(), loaded_tensor.size());
963 assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
964 }
965
966 #[test]
967 fn test_large_tensor_serialization() {
968 let mut tensor = Tensor::zeros(vec![100, 100]).with_requires_grad();
970
971 for i in 0..10 {
973 for j in 0..10 {
974 tensor.set(&[i, j], (i * 10 + j) as f32);
975 }
976 }
977
978 let binary = tensor.to_binary().unwrap();
980 let loaded_tensor = Tensor::from_binary(&binary).unwrap();
981
982 assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
984 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
985
986 for i in 0..10 {
988 for j in 0..10 {
989 assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
990 }
991 }
992 }
993
994 #[test]
995 fn test_tensor_as_field_in_struct() {
996 #[derive(Debug)]
998 struct ModelWeights {
999 weight_matrix: Tensor,
1000 bias_vector: Tensor,
1001 learning_rate: f32,
1002 name: String,
1003 }
1004
1005 impl StructSerializable for ModelWeights {
1006 fn to_serializer(&self) -> StructSerializer {
1007 StructSerializer::new()
1008 .field("weight_matrix", &self.weight_matrix)
1009 .field("bias_vector", &self.bias_vector)
1010 .field("learning_rate", &self.learning_rate)
1011 .field("name", &self.name)
1012 }
1013
1014 fn from_deserializer(
1015 deserializer: &mut StructDeserializer,
1016 ) -> SerializationResult<Self> {
1017 Ok(ModelWeights {
1018 weight_matrix: deserializer.field("weight_matrix")?,
1019 bias_vector: deserializer.field("bias_vector")?,
1020 learning_rate: deserializer.field("learning_rate")?,
1021 name: deserializer.field("name")?,
1022 })
1023 }
1024 }
1025
1026 let mut weights = ModelWeights {
1028 weight_matrix: Tensor::zeros(vec![10, 5]),
1029 bias_vector: Tensor::ones(vec![5]).with_requires_grad(),
1030 learning_rate: 0.001,
1031 name: "test_model".to_string(),
1032 };
1033
1034 weights.weight_matrix.set(&[0, 0], 0.5);
1036 weights.weight_matrix.set(&[9, 4], -0.3);
1037 weights.bias_vector.set(&[2], 2.0);
1038
1039 let json = weights.to_json().unwrap();
1041 let loaded_weights = ModelWeights::from_json(&json).unwrap();
1042
1043 assert_eq!(weights.learning_rate, loaded_weights.learning_rate);
1044 assert_eq!(weights.name, loaded_weights.name);
1045 assert_eq!(
1046 weights.weight_matrix.shape().dims(),
1047 loaded_weights.weight_matrix.shape().dims()
1048 );
1049 assert_eq!(
1050 weights.bias_vector.shape().dims(),
1051 loaded_weights.bias_vector.shape().dims()
1052 );
1053 assert_eq!(
1054 weights.bias_vector.requires_grad(),
1055 loaded_weights.bias_vector.requires_grad()
1056 );
1057
1058 assert_eq!(
1060 weights.weight_matrix.get(&[0, 0]),
1061 loaded_weights.weight_matrix.get(&[0, 0])
1062 );
1063 assert_eq!(
1064 weights.weight_matrix.get(&[9, 4]),
1065 loaded_weights.weight_matrix.get(&[9, 4])
1066 );
1067 assert_eq!(
1068 weights.bias_vector.get(&[2]),
1069 loaded_weights.bias_vector.get(&[2])
1070 );
1071
1072 let binary = weights.to_binary().unwrap();
1074 let loaded_weights = ModelWeights::from_binary(&binary).unwrap();
1075
1076 assert_eq!(weights.learning_rate, loaded_weights.learning_rate);
1077 assert_eq!(weights.name, loaded_weights.name);
1078 assert_eq!(
1079 weights.weight_matrix.shape().dims(),
1080 loaded_weights.weight_matrix.shape().dims()
1081 );
1082 assert_eq!(
1083 weights.bias_vector.requires_grad(),
1084 loaded_weights.bias_vector.requires_grad()
1085 );
1086 }
1087
1088 #[test]
1089 fn test_multiple_tensors_in_struct() {
1090 #[derive(Debug)]
1092 struct MultiTensorStruct {
1093 tensor_1d: Tensor,
1094 tensor_2d: Tensor,
1095 tensor_3d: Tensor,
1096 metadata: HashMap<String, String>,
1097 }
1098
1099 impl StructSerializable for MultiTensorStruct {
1100 fn to_serializer(&self) -> StructSerializer {
1101 StructSerializer::new()
1102 .field("tensor_1d", &self.tensor_1d)
1103 .field("tensor_2d", &self.tensor_2d)
1104 .field("tensor_3d", &self.tensor_3d)
1105 .field("metadata", &self.metadata)
1106 }
1107
1108 fn from_deserializer(
1109 deserializer: &mut StructDeserializer,
1110 ) -> SerializationResult<Self> {
1111 Ok(MultiTensorStruct {
1112 tensor_1d: deserializer.field("tensor_1d")?,
1113 tensor_2d: deserializer.field("tensor_2d")?,
1114 tensor_3d: deserializer.field("tensor_3d")?,
1115 metadata: deserializer.field("metadata")?,
1116 })
1117 }
1118 }
1119
1120 let mut multi_tensor = MultiTensorStruct {
1122 tensor_1d: Tensor::zeros(vec![5]),
1123 tensor_2d: Tensor::ones(vec![3, 4]).with_requires_grad(),
1124 tensor_3d: Tensor::zeros(vec![2, 2, 2]),
1125 metadata: {
1126 let mut map = HashMap::new();
1127 map.insert("version".to_string(), "1.0".to_string());
1128 map.insert("type".to_string(), "test".to_string());
1129 map
1130 },
1131 };
1132
1133 multi_tensor.tensor_1d.set(&[0], 10.0);
1135 multi_tensor.tensor_2d.set(&[0, 0], 5.0);
1136 multi_tensor.tensor_3d.set(&[1, 1, 1], 3.0);
1137
1138 let json = multi_tensor.to_json().unwrap();
1140 let loaded = MultiTensorStruct::from_json(&json).unwrap();
1141
1142 assert_eq!(
1143 multi_tensor.tensor_1d.shape().dims(),
1144 loaded.tensor_1d.shape().dims()
1145 );
1146 assert_eq!(
1147 multi_tensor.tensor_2d.shape().dims(),
1148 loaded.tensor_2d.shape().dims()
1149 );
1150 assert_eq!(
1151 multi_tensor.tensor_3d.shape().dims(),
1152 loaded.tensor_3d.shape().dims()
1153 );
1154 assert_eq!(
1155 multi_tensor.tensor_2d.requires_grad(),
1156 loaded.tensor_2d.requires_grad()
1157 );
1158 assert_eq!(multi_tensor.metadata, loaded.metadata);
1159
1160 assert_eq!(multi_tensor.tensor_1d.get(&[0]), loaded.tensor_1d.get(&[0]));
1162 assert_eq!(
1163 multi_tensor.tensor_2d.get(&[0, 0]),
1164 loaded.tensor_2d.get(&[0, 0])
1165 );
1166 assert_eq!(
1167 multi_tensor.tensor_3d.get(&[1, 1, 1]),
1168 loaded.tensor_3d.get(&[1, 1, 1])
1169 );
1170
1171 let binary = multi_tensor.to_binary().unwrap();
1173 let loaded = MultiTensorStruct::from_binary(&binary).unwrap();
1174 assert_eq!(
1175 multi_tensor.tensor_1d.shape().dims(),
1176 loaded.tensor_1d.shape().dims()
1177 );
1178 assert_eq!(
1179 multi_tensor.tensor_2d.requires_grad(),
1180 loaded.tensor_2d.requires_grad()
1181 );
1182 }
1183
1184 #[test]
1185 fn test_tensor_serialization_errors() {
1186 let mut deserializer = StructDeserializer::from_json(
1188 r#"
1189 {
1190 "data": [1.0, 2.0, 3.0],
1191 "shape": {
1192 "dims": [2, 3],
1193 "size": 6,
1194 "strides": [3, 1],
1195 "layout": "contiguous"
1196 },
1197 "device": {"type": "cpu", "index": 0},
1198 "requires_grad": false
1199 }"#,
1200 )
1201 .unwrap();
1202
1203 let result = Tensor::from_deserializer(&mut deserializer);
1204 assert!(result.is_err()); }
1206
1207 #[test]
1208 fn test_field_value_tensor_roundtrip() {
1209 let mut tensor = Tensor::zeros(vec![2, 2]);
1211 tensor.set(&[0, 0], 1.0);
1212 tensor.set(&[1, 1], 2.0);
1213
1214 let field_value = tensor.to_field_value();
1215 let loaded_tensor = Tensor::from_field_value(field_value, "test_tensor").unwrap();
1216
1217 assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
1218 assert_eq!(tensor.get(&[0, 0]), loaded_tensor.get(&[0, 0]));
1219 assert_eq!(tensor.get(&[1, 1]), loaded_tensor.get(&[1, 1]));
1220 }
1221
1222 #[test]
1223 fn test_different_tensor_shapes() {
1224 let test_shapes = vec![
1225 vec![1], vec![10], vec![3, 4], vec![2, 3, 4], vec![2, 2, 2, 2], ];
1231
1232 for shape in test_shapes {
1233 let tensor = Tensor::zeros(shape.clone()).with_requires_grad();
1234
1235 let json = tensor.to_json().unwrap();
1237 let loaded = Tensor::from_json(&json).unwrap();
1238 assert_eq!(tensor.shape().dims(), loaded.shape().dims());
1239 assert_eq!(tensor.requires_grad(), loaded.requires_grad());
1240
1241 let binary = tensor.to_binary().unwrap();
1243 let loaded = Tensor::from_binary(&binary).unwrap();
1244 assert_eq!(tensor.shape().dims(), loaded.shape().dims());
1245 assert_eq!(tensor.requires_grad(), loaded.requires_grad());
1246 }
1247 }
1248
1249 #[test]
1252 fn test_serializable_json_methods() {
1253 let mut tensor = Tensor::zeros(vec![2, 3]);
1255 tensor.set(&[0, 0], 1.0);
1256 tensor.set(&[0, 1], 2.0);
1257 tensor.set(&[1, 2], 5.0);
1258 tensor.set_requires_grad(true);
1259
1260 let json = <Tensor as crate::serialization::Serializable>::to_json(&tensor).unwrap();
1262 assert!(!json.is_empty());
1263 assert!(json.contains("data"));
1264 assert!(json.contains("shape"));
1265 assert!(json.contains("device"));
1266 assert!(json.contains("requires_grad"));
1267
1268 let restored = <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1270 assert_eq!(tensor.shape().dims(), restored.shape().dims());
1271 assert_eq!(tensor.size(), restored.size());
1272 assert_eq!(tensor.device(), restored.device());
1273 assert_eq!(tensor.requires_grad(), restored.requires_grad());
1274
1275 assert_eq!(tensor.get(&[0, 0]), restored.get(&[0, 0]));
1277 assert_eq!(tensor.get(&[0, 1]), restored.get(&[0, 1]));
1278 assert_eq!(tensor.get(&[1, 2]), restored.get(&[1, 2]));
1279 }
1280
1281 #[test]
1282 fn test_serializable_binary_methods() {
1283 let mut tensor = Tensor::ones(vec![3, 4]);
1285 tensor.set(&[0, 0], 10.0);
1286 tensor.set(&[1, 2], 20.0);
1287 tensor.set(&[2, 3], 30.0);
1288 tensor.set_requires_grad(true);
1289
1290 let binary = <Tensor as crate::serialization::Serializable>::to_binary(&tensor).unwrap();
1292 assert!(!binary.is_empty());
1293
1294 let restored =
1296 <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1297 assert_eq!(tensor.shape().dims(), restored.shape().dims());
1298 assert_eq!(tensor.size(), restored.size());
1299 assert_eq!(tensor.device(), restored.device());
1300 assert_eq!(tensor.requires_grad(), restored.requires_grad());
1301
1302 assert_eq!(tensor.get(&[0, 0]), restored.get(&[0, 0]));
1304 assert_eq!(tensor.get(&[1, 2]), restored.get(&[1, 2]));
1305 assert_eq!(tensor.get(&[2, 3]), restored.get(&[2, 3]));
1306 }
1307
1308 #[test]
1309 fn test_serializable_file_io_json() {
1310 use crate::serialization::{Format, Serializable};
1311 use std::fs;
1312 use std::path::Path;
1313
1314 let mut tensor = Tensor::zeros(vec![2, 2]);
1316 tensor.set(&[0, 0], 1.0);
1317 tensor.set(&[0, 1], 2.0);
1318 tensor.set(&[1, 0], 3.0);
1319 tensor.set(&[1, 1], 4.0);
1320 tensor.set_requires_grad(true);
1321
1322 let json_path = "test_tensor_serializable.json";
1324 let json_path_2 = "test_tensor_serializable_2.json";
1325
1326 let _ = fs::remove_file(json_path);
1328 let _ = fs::remove_file(json_path_2);
1329
1330 Serializable::save(&tensor, json_path, Format::Json).unwrap();
1332 assert!(Path::new(json_path).exists());
1333
1334 let loaded_tensor = Tensor::load(json_path, Format::Json).unwrap();
1336 assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
1337 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1338 assert_eq!(tensor.get(&[0, 0]), loaded_tensor.get(&[0, 0]));
1339 assert_eq!(tensor.get(&[1, 1]), loaded_tensor.get(&[1, 1]));
1340
1341 {
1343 let mut writer = std::fs::File::create(json_path_2).unwrap();
1344 Serializable::save_to_writer(&tensor, &mut writer, Format::Json).unwrap();
1345 }
1346 assert!(Path::new(json_path_2).exists());
1347
1348 {
1349 let mut reader = std::fs::File::open(json_path_2).unwrap();
1350 let loaded_tensor = Tensor::load_from_reader(&mut reader, Format::Json).unwrap();
1351 assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
1352 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1353 assert_eq!(tensor.get(&[0, 1]), loaded_tensor.get(&[0, 1]));
1354 assert_eq!(tensor.get(&[1, 0]), loaded_tensor.get(&[1, 0]));
1355 }
1356
1357 let _ = fs::remove_file(json_path);
1359 let _ = fs::remove_file(json_path_2);
1360 }
1361
1362 #[test]
1363 fn test_serializable_file_io_binary() {
1364 use crate::serialization::{Format, Serializable};
1365 use std::fs;
1366 use std::path::Path;
1367
1368 let mut tensor = Tensor::ones(vec![3, 3]);
1370 for i in 0..3 {
1371 for j in 0..3 {
1372 tensor.set(&[i, j], (i * 3 + j) as f32);
1373 }
1374 }
1375 tensor.set_requires_grad(true);
1376
1377 let binary_path = "test_tensor_serializable.bin";
1379 let binary_path_2 = "test_tensor_serializable_2.bin";
1380
1381 let _ = fs::remove_file(binary_path);
1383 let _ = fs::remove_file(binary_path_2);
1384
1385 Serializable::save(&tensor, binary_path, Format::Binary).unwrap();
1387 assert!(Path::new(binary_path).exists());
1388
1389 let loaded_tensor = Tensor::load(binary_path, Format::Binary).unwrap();
1391 assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
1392 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1393
1394 for i in 0..3 {
1396 for j in 0..3 {
1397 assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
1398 }
1399 }
1400
1401 {
1403 let mut writer = std::fs::File::create(binary_path_2).unwrap();
1404 Serializable::save_to_writer(&tensor, &mut writer, Format::Binary).unwrap();
1405 }
1406 assert!(Path::new(binary_path_2).exists());
1407
1408 {
1409 let mut reader = std::fs::File::open(binary_path_2).unwrap();
1410 let loaded_tensor = Tensor::load_from_reader(&mut reader, Format::Binary).unwrap();
1411 assert_eq!(tensor.shape().dims(), loaded_tensor.shape().dims());
1412 assert_eq!(tensor.requires_grad(), loaded_tensor.requires_grad());
1413
1414 for i in 0..3 {
1416 for j in 0..3 {
1417 assert_eq!(tensor.get(&[i, j]), loaded_tensor.get(&[i, j]));
1418 }
1419 }
1420 }
1421
1422 let _ = fs::remove_file(binary_path);
1424 let _ = fs::remove_file(binary_path_2);
1425 }
1426
1427 #[test]
1428 fn test_serializable_large_tensor_performance() {
1429 let mut tensor = Tensor::zeros(vec![50, 50]);
1431 for i in 0..25 {
1432 for j in 0..25 {
1433 tensor.set(&[i, j], (i * 25 + j) as f32);
1434 }
1435 }
1436 tensor.set_requires_grad(true);
1437
1438 let json = <Tensor as crate::serialization::Serializable>::to_json(&tensor).unwrap();
1440 assert!(!json.is_empty());
1441 let restored_json =
1442 <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
1443 assert_eq!(tensor.shape().dims(), restored_json.shape().dims());
1444 assert_eq!(tensor.requires_grad(), restored_json.requires_grad());
1445
1446 let binary = <Tensor as crate::serialization::Serializable>::to_binary(&tensor).unwrap();
1448 assert!(!binary.is_empty());
1449 println!(
1451 "JSON size: {} bytes, Binary size: {} bytes",
1452 json.len(),
1453 binary.len()
1454 );
1455
1456 let restored_binary =
1457 <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
1458 assert_eq!(tensor.shape().dims(), restored_binary.shape().dims());
1459 assert_eq!(tensor.requires_grad(), restored_binary.requires_grad());
1460
1461 for i in 0..5 {
1463 for j in 0..5 {
1464 assert_eq!(tensor.get(&[i, j]), restored_json.get(&[i, j]));
1465 assert_eq!(tensor.get(&[i, j]), restored_binary.get(&[i, j]));
1466 }
1467 }
1468 }
1469
1470 #[test]
1471 fn test_serializable_error_handling() {
1472 let invalid_json = r#"{"invalid": "json", "structure": true}"#;
1474 let result = <Tensor as crate::serialization::Serializable>::from_json(invalid_json);
1475 assert!(result.is_err());
1476
1477 let empty_json = "{}";
1479 let result = <Tensor as crate::serialization::Serializable>::from_json(empty_json);
1480 assert!(result.is_err());
1481
1482 let invalid_binary = vec![1, 2, 3, 4, 5];
1484 let result = <Tensor as crate::serialization::Serializable>::from_binary(&invalid_binary);
1485 assert!(result.is_err());
1486
1487 let empty_binary = vec![];
1489 let result = <Tensor as crate::serialization::Serializable>::from_binary(&empty_binary);
1490 assert!(result.is_err());
1491 }
1492
/// Round-trips tensors of several shapes through JSON and binary serialization.
///
/// Each case pairs a shape with row-major element data whose length must equal the product of
/// the dimensions. The data is written into a `Tensor::zeros` tensor with `requires_grad` enabled.
/// Each restored tensor's shape, `requires_grad` flag and *every* element are then compared with
/// the original. Earlier versions compared only the first three elements, which let corruption
/// in later elements go unnoticed.
#[test]
fn test_serializable_different_shapes_and_types() {
    /// Converts a row-major flat offset into a multi-dimensional index for `dims`.
    ///
    /// Callers only pass `flat < product(dims)`, which implies every dimension is non-zero.
    fn unravel(mut flat: usize, dims: &[usize]) -> Vec<usize> {
        let mut index = vec![0; dims.len()];
        // Walk from the innermost (fastest-varying) axis outwards.
        for (slot, &dim) in index.iter_mut().zip(dims).rev() {
            *slot = flat % dim;
            flat /= dim;
        }
        index
    }

    let test_cases: Vec<(Vec<usize>, Vec<_>)> = vec![
        (vec![1], vec![42.0]),
        (vec![5], vec![1.0, 2.0, 3.0, 4.0, 5.0]),
        (vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
        (vec![2, 2, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
    ];

    for (shape, expected_data) in test_cases {
        let numel: usize = shape.iter().product();
        // Guard against typos in the table above: every element must be specified.
        assert_eq!(
            numel,
            expected_data.len(),
            "test data length must match shape {:?}",
            shape
        );

        let mut tensor = Tensor::zeros(shape.clone());
        for (flat, &value) in expected_data.iter().enumerate() {
            tensor.set(&unravel(flat, &shape), value);
        }
        tensor.set_requires_grad(true);

        let json = <Tensor as crate::serialization::Serializable>::to_json(&tensor).unwrap();
        let restored_json =
            <Tensor as crate::serialization::Serializable>::from_json(&json).unwrap();
        assert_eq!(tensor.shape().dims(), restored_json.shape().dims());
        assert_eq!(tensor.requires_grad(), restored_json.requires_grad());

        let binary =
            <Tensor as crate::serialization::Serializable>::to_binary(&tensor).unwrap();
        let restored_binary =
            <Tensor as crate::serialization::Serializable>::from_binary(&binary).unwrap();
        assert_eq!(tensor.shape().dims(), restored_binary.shape().dims());
        assert_eq!(tensor.requires_grad(), restored_binary.requires_grad());

        // Compare every element, not just a prefix, for both formats.
        for flat in 0..numel {
            let index = unravel(flat, &shape);
            assert_eq!(
                tensor.get(&index),
                restored_json.get(&index),
                "JSON round-trip mismatch at {:?} for shape {:?}",
                index,
                shape
            );
            assert_eq!(
                tensor.get(&index),
                restored_binary.get(&index),
                "binary round-trip mismatch at {:?} for shape {:?}",
                index,
                shape
            );
        }
    }
}
1604
/// Round-trips edge-case tensors through JSON and then binary serialization.
///
/// Two tensors are covered:
/// - One built with `Tensor::new(vec![0])`. Only its shape and `size()` are compared, presumably
///   because it holds no elements.
/// - A three-element tensor holding `0.0`, `1e6` and `-1e6`, compared element by element.
#[test]
fn test_serializable_edge_cases() {
    /// Serializes `tensor` to JSON and parses it straight back.
    fn via_json(tensor: &Tensor) -> Tensor {
        let text = <Tensor as crate::serialization::Serializable>::to_json(tensor).unwrap();
        <Tensor as crate::serialization::Serializable>::from_json(&text).unwrap()
    }

    /// Serializes `tensor` to bytes and decodes it straight back.
    fn via_binary(tensor: &Tensor) -> Tensor {
        let bytes = <Tensor as crate::serialization::Serializable>::to_binary(tensor).unwrap();
        <Tensor as crate::serialization::Serializable>::from_binary(&bytes).unwrap()
    }

    // JSON first, then binary. Each round trip runs only after the previous one's checks pass.
    let round_trips: [fn(&Tensor) -> Tensor; 2] = [via_json, via_binary];

    let empty = Tensor::new(vec![0]);
    for round_trip in round_trips.iter() {
        let restored = round_trip(&empty);
        assert_eq!(empty.shape().dims(), restored.shape().dims());
        assert_eq!(empty.size(), restored.size());
    }

    let mut extremes = Tensor::zeros(vec![3]);
    for (i, &value) in [0.0, 1_000_000.0, -1_000_000.0].iter().enumerate() {
        extremes.set(&[i], value);
    }
    for round_trip in round_trips.iter() {
        let restored = round_trip(&extremes);
        for i in 0..3 {
            assert_eq!(extremes.get(&[i]), restored.get(&[i]));
        }
    }
}
1642}