tensor_rs/serde/
tensor.rs1#[cfg(feature = "use-serde")]
2use serde::{Serialize, Deserialize, Serializer, Deserializer,
3 ser::SerializeStruct,
4 de, de::Visitor, de::SeqAccess, de::MapAccess};
5use crate::tensor::Tensor;
6use std::fmt;
7
8impl Serialize for Tensor {
9 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
10 where S: Serializer, {
11 let mut state = serializer.serialize_struct("Tensor", 1)?;
13 state.serialize_field("v", &self.inner().borrow().clone())?;
14 state.end()
15 }
16}
17
18impl<'de> Deserialize<'de> for Tensor {
19 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
20 where D: Deserializer<'de>, {
21
22 enum Field { V }
23
24 impl<'de> Deserialize<'de> for Field {
25 fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
26 where D: Deserializer<'de>, {
27 struct FieldVisitor;
28
29 impl<'de> Visitor<'de> for FieldVisitor {
30 type Value = Field;
31
32 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
33 formatter.write_str("v")
34 }
35
36 fn visit_str<E>(self, value: &str) -> Result<Field, E>
37 where E: de::Error, {
38 match value {
39 "v" => Ok(Field::V),
40 _ => Err(de::Error::unknown_field(value, &FIELDS)),
41 }
42 }
43 }
44
45 deserializer.deserialize_identifier(FieldVisitor)
46 }
47 }
48
49 struct TensorVisitor;
50
51 impl<'de> Visitor<'de> for TensorVisitor {
52 type Value = Tensor;
53
54 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
55 formatter.write_str("struct Tensor")
56 }
57
58 fn visit_map<V>(self, mut map: V) -> Result<Tensor, V::Error>
59 where V: MapAccess<'de>, {
60 let mut v = None;
61 while let Some(key) = map.next_key()? {
62 match key {
63 Field::V => {
64 if v.is_some() {
65 return Err(de::Error::duplicate_field("v"));
66 }
67 v = Some(map.next_value()?);
68 }
69 }
70 }
71 let v = v.ok_or_else(|| de::Error::missing_field("ok"))?;
72 Ok(Tensor::set_inner(v))
73 }
74
75 fn visit_seq<V>(self, mut seq: V) -> Result<Tensor, V::Error>
76 where V: SeqAccess<'de>, {
77 let tt = seq.next_element()?
78 .ok_or_else(|| de::Error::invalid_length(0, &self))?;
79 Ok(Tensor::set_inner(tt))
80 }
81 }
82
83 const FIELDS: [&str; 1] = ["v"];
84 deserializer.deserialize_struct("Duration", &FIELDS, TensorVisitor)
85 }
86}
87
88
#[cfg(all(test, feature = "use-serde"))]
mod tests {
    use crate::tensor::Tensor;

    /// Round-trips a 3x3 identity tensor through pickle encoding and checks
    /// that the decoded tensor compares equal to the original.
    #[test]
    fn test_serde() {
        let original = Tensor::eye(3, 3);

        let bytes = serde_pickle::to_vec(&original, true).unwrap();
        let restored: Tensor = serde_pickle::from_slice(&bytes).unwrap();

        assert_eq!(original, restored);
    }
}