sparse_bin_mat/matrix/
ser_de.rs1use super::SparseBinMat;
2use serde;
3use serde::de::{Deserializer, MapAccess, Visitor};
4use serde::ser::{Serialize, SerializeStruct, Serializer};
5use serde::Deserialize;
6use std::fmt;
7
8impl Serialize for SparseBinMat {
9 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
10 where
11 S: Serializer,
12 {
13 let rows: Vec<_> = self
14 .rows()
15 .map(|row| row.to_vec().to_positions_vec())
16 .collect();
17 let mut state = serializer.serialize_struct("SparseBinMat", 2)?;
18 state.serialize_field("number_of_columns", &self.number_of_columns())?;
19 state.serialize_field("rows", &rows)?;
20 state.end()
21 }
22}
23
24#[derive(Deserialize)]
25#[serde(field_identifier, rename_all = "snake_case")]
26enum Field {
27 NumberOfColumns,
28 Rows,
29}
30
/// Field names of the serialized `SparseBinMat` struct, in the order they are written.
///
/// Handed to `deserialize_struct` as the list of expected fields.
const FIELDS: &[&str] = &["number_of_columns", "rows"];
32
/// Serde visitor whose output value is a `SparseBinMat`.
struct MatrixVisitor;
34
35impl<'de> Visitor<'de> for MatrixVisitor {
36 type Value = SparseBinMat;
37
38 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
39 formatter.write_str("struct SparseBinMat")
40 }
41
42 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
43 where
44 A: MapAccess<'de>,
45 {
46 let mut number_of_columns = None;
47 let mut rows = None;
48 while let Some(key) = map.next_key()? {
49 match key {
50 Field::NumberOfColumns => {
51 if number_of_columns.is_some() {
52 return Err(serde::de::Error::duplicate_field("number_of_columns"));
53 }
54 number_of_columns = Some(map.next_value()?);
55 }
56 Field::Rows => {
57 if rows.is_some() {
58 return Err(serde::de::Error::duplicate_field("rows"));
59 }
60 rows = Some(map.next_value()?);
61 }
62 }
63 }
64 let number_of_columns: usize = number_of_columns
65 .ok_or_else(|| serde::de::Error::missing_field("number_of_columns"))?;
66 let rows: Vec<Vec<usize>> = rows.ok_or_else(|| serde::de::Error::missing_field("rows"))?;
67 if number_of_columns == 0 && rows.len() == 0 {
68 Ok(SparseBinMat::empty())
69 } else {
70 SparseBinMat::try_new(number_of_columns, rows)
71 .map_err(|error| serde::de::Error::custom(&error.to_string()))
72 }
73 }
74}
75
76impl<'de> Deserialize<'de> for SparseBinMat {
77 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
78 where
79 D: Deserializer<'de>,
80 {
81 deserializer.deserialize_struct("SparseBinMat", FIELDS, MatrixVisitor)
82 }
83}
84
#[cfg(test)]
mod test {
    use super::*;
    use serde_test::{assert_de_tokens_error, assert_tokens, Token};

    /// Opening token shared by every token stream in these tests.
    fn struct_start() -> Token {
        Token::Struct {
            name: "SparseBinMat",
            len: 2,
        }
    }

    /// The empty matrix round-trips as zero columns and an empty row list.
    #[test]
    fn ser_de_empty_matrix() {
        let expected = [
            struct_start(),
            Token::String("number_of_columns"),
            Token::U64(0),
            Token::String("rows"),
            Token::Seq { len: Some(0) },
            Token::SeqEnd,
            Token::StructEnd,
        ];
        assert_tokens(&SparseBinMat::empty(), &expected);
    }

    /// A 2x5 matrix round-trips with one positions sequence per row.
    #[test]
    fn ser_de_2_by_5_matrix() {
        let matrix = SparseBinMat::new(5, vec![vec![0, 2, 4], vec![1, 3]]);
        let expected = [
            struct_start(),
            Token::String("number_of_columns"),
            Token::U64(5),
            Token::String("rows"),
            Token::Seq { len: Some(2) },
            // First row.
            Token::Seq { len: Some(3) },
            Token::U64(0),
            Token::U64(2),
            Token::U64(4),
            Token::SeqEnd,
            // Second row.
            Token::Seq { len: Some(2) },
            Token::U64(1),
            Token::U64(3),
            Token::SeqEnd,
            Token::SeqEnd,
            Token::StructEnd,
        ];
        assert_tokens(&matrix, &expected);
    }

    /// A row whose positions are out of order is rejected on deserialization.
    #[test]
    fn de_unsorted_rows() {
        let tokens = [
            struct_start(),
            Token::String("number_of_columns"),
            Token::U64(5),
            Token::String("rows"),
            Token::Seq { len: Some(1) },
            Token::Seq { len: Some(3) },
            Token::U64(0),
            Token::U64(4),
            Token::U64(2),
            Token::SeqEnd,
            Token::SeqEnd,
            Token::StructEnd,
        ];
        assert_de_tokens_error::<SparseBinMat>(&tokens, "some positions are not sorted");
    }
}
162}