sparse_bin_mat/matrix/
ser_de.rs

1use super::SparseBinMat;
2use serde;
3use serde::de::{Deserializer, MapAccess, Visitor};
4use serde::ser::{Serialize, SerializeStruct, Serializer};
5use serde::Deserialize;
6use std::fmt;
7
8impl Serialize for SparseBinMat {
9    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
10    where
11        S: Serializer,
12    {
13        let rows: Vec<_> = self
14            .rows()
15            .map(|row| row.to_vec().to_positions_vec())
16            .collect();
17        let mut state = serializer.serialize_struct("SparseBinMat", 2)?;
18        state.serialize_field("number_of_columns", &self.number_of_columns())?;
19        state.serialize_field("rows", &rows)?;
20        state.end()
21    }
22}
23
/// Identifier of a field in the serialized form of a matrix.
///
/// Because of `rename_all = "snake_case"`, the variants are matched against
/// the keys `number_of_columns` and `rows`.
#[derive(Deserialize)]
#[serde(field_identifier, rename_all = "snake_case")]
enum Field {
    /// Key of the field holding the number of columns.
    NumberOfColumns,
    /// Key of the field holding the rows.
    Rows,
}
30
/// Names of the serialized fields, handed to `deserialize_struct` as the
/// set of expected keys.
// `'static` is implied for references in a `const`, so it is not spelled out.
const FIELDS: &[&str] = &["number_of_columns", "rows"];
32
/// Visitor that assembles a `SparseBinMat` from its deserialized fields.
struct MatrixVisitor;
34
35impl<'de> Visitor<'de> for MatrixVisitor {
36    type Value = SparseBinMat;
37
38    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
39        formatter.write_str("struct SparseBinMat")
40    }
41
42    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
43    where
44        A: MapAccess<'de>,
45    {
46        let mut number_of_columns = None;
47        let mut rows = None;
48        while let Some(key) = map.next_key()? {
49            match key {
50                Field::NumberOfColumns => {
51                    if number_of_columns.is_some() {
52                        return Err(serde::de::Error::duplicate_field("number_of_columns"));
53                    }
54                    number_of_columns = Some(map.next_value()?);
55                }
56                Field::Rows => {
57                    if rows.is_some() {
58                        return Err(serde::de::Error::duplicate_field("rows"));
59                    }
60                    rows = Some(map.next_value()?);
61                }
62            }
63        }
64        let number_of_columns: usize = number_of_columns
65            .ok_or_else(|| serde::de::Error::missing_field("number_of_columns"))?;
66        let rows: Vec<Vec<usize>> = rows.ok_or_else(|| serde::de::Error::missing_field("rows"))?;
67        if number_of_columns == 0 && rows.len() == 0 {
68            Ok(SparseBinMat::empty())
69        } else {
70            SparseBinMat::try_new(number_of_columns, rows)
71                .map_err(|error| serde::de::Error::custom(&error.to_string()))
72        }
73    }
74}
75
76impl<'de> Deserialize<'de> for SparseBinMat {
77    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
78    where
79        D: Deserializer<'de>,
80    {
81        deserializer.deserialize_struct("SparseBinMat", FIELDS, MatrixVisitor)
82    }
83}
84
#[cfg(test)]
mod test {
    use super::*;
    use serde_test::{assert_de_tokens_error, assert_tokens, Token};

    /// Builds the token stream of a serialized matrix with the given number
    /// of columns and the given positions in each row.
    fn matrix_tokens(number_of_columns: u64, rows: &[&[u64]]) -> Vec<Token> {
        let mut tokens = vec![
            Token::Struct {
                name: "SparseBinMat",
                len: 2,
            },
            Token::String("number_of_columns"),
            Token::U64(number_of_columns),
            Token::String("rows"),
            Token::Seq {
                len: Some(rows.len()),
            },
        ];
        for row in rows {
            tokens.push(Token::Seq {
                len: Some(row.len()),
            });
            tokens.extend(row.iter().copied().map(Token::U64));
            tokens.push(Token::SeqEnd);
        }
        tokens.push(Token::SeqEnd);
        tokens.push(Token::StructEnd);
        tokens
    }

    // The empty matrix round-trips as zero columns and an empty list of rows.
    #[test]
    fn ser_de_empty_matrix() {
        let expected = matrix_tokens(0, &[]);
        assert_tokens(&SparseBinMat::empty(), &expected);
    }

    // A non-empty matrix round-trips with one positions sequence per row.
    #[test]
    fn ser_de_2_by_5_matrix() {
        let expected = matrix_tokens(5, &[&[0, 2, 4], &[1, 3]]);
        let matrix = SparseBinMat::new(5, vec![vec![0, 2, 4], vec![1, 3]]);
        assert_tokens(&matrix, &expected);
    }

    // Rows whose positions are out of order are rejected on deserialization.
    #[test]
    fn de_unsorted_rows() {
        let unsorted = matrix_tokens(5, &[&[0, 4, 2]]);
        assert_de_tokens_error::<SparseBinMat>(&unsorted, "some positions are not sorted");
    }
}