// tensor_rs/serde/tensor.rs

1#[cfg(feature = "use-serde")]
2use serde::{Serialize, Deserialize, Serializer, Deserializer,
3	    ser::SerializeStruct,
4	    de, de::Visitor, de::SeqAccess, de::MapAccess};
5use crate::tensor::Tensor;
6use std::fmt;
7
/// Serializes a [`Tensor`] as a struct named `"Tensor"` with a single field
/// `v` that holds the value stored inside the tensor's inner cell.
#[cfg(feature = "use-serde")]
impl Serialize for Tensor {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Exactly one field (`v`); this must match the field name accepted by
        // `Deserialize for Tensor`.
        let mut state = serializer.serialize_struct("Tensor", 1)?;
        // Serialize straight through the borrow guard instead of cloning the
        // whole inner value first; the guard lives until the end of this
        // statement, which covers the `serialize_field` call.
        state.serialize_field("v", &*self.inner().borrow())?;
        state.end()
    }
}
17
/// Deserializes a [`Tensor`] from a struct named `"Tensor"` with a single
/// field `v` holding the inner value, which is passed to `Tensor::set_inner`.
///
/// Accepts either map-shaped input (`{ v: ... }`) or sequence-shaped input
/// whose first element is the inner value.
///
/// # Errors
///
/// Returns the deserializer's error if `v` is missing or duplicated, if an
/// unknown field name is encountered in map form, if a sequence is empty, or
/// if the inner value itself fails to deserialize.
#[cfg(feature = "use-serde")]
impl<'de> Deserialize<'de> for Tensor {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Field names accepted in map form; must match `Serialize for Tensor`.
        const FIELDS: &[&str] = &["v"];

        // Identifier for the single serialized field.
        enum Field {
            V,
        }

        impl<'de> Deserialize<'de> for Field {
            fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
            where
                D: Deserializer<'de>,
            {
                struct FieldVisitor;

                impl<'de> Visitor<'de> for FieldVisitor {
                    type Value = Field;

                    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                        formatter.write_str("`v`")
                    }

                    fn visit_str<E>(self, value: &str) -> Result<Field, E>
                    where
                        E: de::Error,
                    {
                        match value {
                            "v" => Ok(Field::V),
                            _ => Err(de::Error::unknown_field(value, FIELDS)),
                        }
                    }
                }

                deserializer.deserialize_identifier(FieldVisitor)
            }
        }

        struct TensorVisitor;

        impl<'de> Visitor<'de> for TensorVisitor {
            type Value = Tensor;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("struct Tensor")
            }

            fn visit_map<V>(self, mut map: V) -> Result<Tensor, V::Error>
            where
                V: MapAccess<'de>,
            {
                let mut v = None;
                while let Some(key) = map.next_key()? {
                    match key {
                        Field::V => {
                            if v.is_some() {
                                return Err(de::Error::duplicate_field("v"));
                            }
                            v = Some(map.next_value()?);
                        }
                    }
                }
                // Report the actual field name so the error message is accurate.
                let v = v.ok_or_else(|| de::Error::missing_field("v"))?;
                Ok(Tensor::set_inner(v))
            }

            fn visit_seq<V>(self, mut seq: V) -> Result<Tensor, V::Error>
            where
                V: SeqAccess<'de>,
            {
                let v = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(0, &self))?;
                Ok(Tensor::set_inner(v))
            }
        }

        // The struct name must match the one passed to `serialize_struct` in
        // `Serialize for Tensor`; formats that record struct names rely on it.
        deserializer.deserialize_struct("Tensor", FIELDS, TensorVisitor)
    }
}
87
88
#[cfg(all(test, feature = "use-serde"))]
mod tests {
    use crate::tensor::Tensor;

    /// Round-trips a 3x3 `Tensor::eye` through `serde_pickle` and checks that
    /// the deserialized tensor compares equal to the original.
    #[test]
    fn test_serde() {
        let m1 = Tensor::eye(3, 3);

        let serialized =
            serde_pickle::to_vec(&m1, true).expect("serializing a Tensor to pickle failed");
        // Name the target type explicitly instead of relying on inference
        // through `assert_eq!`.
        let deserialized: Tensor = serde_pickle::from_slice(&serialized)
            .expect("deserializing a Tensor from pickle failed");
        assert_eq!(m1, deserialized);
    }
}