scirs2_datasets/utils/
serialization.rs1use ndarray::{Array1, Array2};
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9use std::vec::Vec;
10
11pub fn serialize_array2<S>(array: &Array2<f64>, serializer: S) -> Result<S::Ok, S::Error>
25where
26 S: Serializer,
27{
28 let shape = array.shape();
29 let mut vec = Vec::with_capacity(shape[0] * shape[1] + 2);
30
31 vec.push(shape[0] as f64);
33 vec.push(shape[1] as f64);
34
35 vec.extend(array.iter().cloned());
37
38 vec.serialize(serializer)
39}
40
41pub fn deserialize_array2<'de, D>(deserializer: D) -> Result<Array2<f64>, D::Error>
55where
56 D: Deserializer<'de>,
57{
58 let vec = Vec::<f64>::deserialize(deserializer)?;
59 if vec.len() < 2 {
60 return Err(serde::de::Error::custom("Invalid array2 serialization"));
61 }
62
63 let nrows = vec[0] as usize;
64 let ncols = vec[1] as usize;
65
66 if vec.len() != nrows * ncols + 2 {
67 return Err(serde::de::Error::custom("Invalid array2 serialization"));
68 }
69
70 let data = vec[2..].to_vec();
71 match Array2::from_shape_vec((nrows, ncols), data) {
72 Ok(array) => Ok(array),
73 Err(_) => Err(serde::de::Error::custom("Failed to reshape array2")),
74 }
75}
76
77#[allow(dead_code)]
90pub fn serialize_array1<S>(array: &Array1<f64>, serializer: S) -> Result<S::Ok, S::Error>
91where
92 S: Serializer,
93{
94 let vec = array.to_vec();
95 vec.serialize(serializer)
96}
97
98pub fn deserialize_array1<'de, D>(deserializer: D) -> Result<Array1<f64>, D::Error>
110where
111 D: Deserializer<'de>,
112{
113 let vec = Vec::<f64>::deserialize(deserializer)?;
114 Ok(Array1::from(vec))
115}
116
117pub mod optional_array1 {
119 use super::*;
120
121 #[allow(dead_code)]
132 pub fn serialize<S>(array_opt: &Option<Array1<f64>>, serializer: S) -> Result<S::Ok, S::Error>
133 where
134 S: Serializer,
135 {
136 match array_opt {
137 Some(array) => {
138 #[derive(Serialize)]
139 struct Wrapper<'a> {
140 #[serde(
141 serialize_with = "super::serialize_array1",
142 deserialize_with = "super::deserialize_array1"
143 )]
144 value: &'a Array1<f64>,
145 }
146
147 Wrapper { value: array }.serialize(serializer)
148 }
149 None => serializer.serialize_none(),
150 }
151 }
152
153 #[allow(dead_code)]
163 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Array1<f64>>, D::Error>
164 where
165 D: Deserializer<'de>,
166 {
167 #[derive(Deserialize)]
168 struct Wrapper {
169 #[serde(
170 serialize_with = "super::serialize_array1",
171 deserialize_with = "super::deserialize_array1"
172 )]
173 #[allow(dead_code)]
174 value: Array1<f64>,
175 }
176
177 Option::<Wrapper>::deserialize(deserializer).map(|opt_wrapper| opt_wrapper.map(|w| w.value))
178 }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Routes a 2-D array through `serialize_array2` / `deserialize_array2`.
    #[derive(Serialize, Deserialize)]
    struct Matrix(
        #[serde(
            serialize_with = "super::serialize_array2",
            deserialize_with = "super::deserialize_array2"
        )]
        Array2<f64>,
    );

    /// Routes a 1-D array through `serialize_array1` / `deserialize_array1`.
    #[derive(Serialize, Deserialize)]
    struct Vector(
        #[serde(
            serialize_with = "super::serialize_array1",
            deserialize_with = "super::deserialize_array1"
        )]
        Array1<f64>,
    );

    /// Routes an optional 1-D array through the `optional_array1` module.
    #[derive(Serialize, Deserialize)]
    struct MaybeVector {
        #[serde(with = "super::optional_array1")]
        data: Option<Array1<f64>>,
    }

    #[test]
    fn test_array2_serialization_roundtrip() {
        let original = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let json = serde_json::to_string(&Matrix(original.clone())).unwrap();
        // Shape header first, then the elements in row-major order.
        assert_eq!(json, "[2.0,3.0,1.0,2.0,3.0,4.0,5.0,6.0]");

        let Matrix(reconstructed): Matrix = serde_json::from_str(&json).unwrap();
        assert_eq!(original, reconstructed);
    }

    #[test]
    fn test_array2_empty_shape() {
        let Matrix(empty): Matrix = serde_json::from_str("[0.0,4.0]").unwrap();
        assert_eq!(empty.dim(), (0, 4));
    }

    #[test]
    fn test_array1_serialization_roundtrip() {
        let original = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let json = serde_json::to_string(&Vector(original.clone())).unwrap();
        assert_eq!(json, "[1.0,2.0,3.0,4.0,5.0]");

        let Vector(reconstructed): Vector = serde_json::from_str(&json).unwrap();
        assert_eq!(original, reconstructed);
    }

    #[test]
    fn test_invalid_array2_deserialization() {
        // Too short to hold a shape header.
        assert!(serde_json::from_str::<Matrix>("[]").is_err());
        assert!(serde_json::from_str::<Matrix>("[2.0]").is_err());
        // Header promises 2x3 = 6 elements but only one follows.
        assert!(serde_json::from_str::<Matrix>("[2.0,3.0,1.0]").is_err());
        // One element more than the header allows.
        assert!(serde_json::from_str::<Matrix>("[1.0,1.0,1.0,2.0]").is_err());
    }

    #[test]
    fn test_optional_array1_roundtrip() {
        let some = MaybeVector {
            data: Some(array![1.0, 2.0]),
        };
        let json = serde_json::to_string(&some).unwrap();
        assert_eq!(json, r#"{"data":{"value":[1.0,2.0]}}"#);
        let back: MaybeVector = serde_json::from_str(&json).unwrap();
        assert_eq!(back.data, some.data);

        let none = MaybeVector { data: None };
        let json = serde_json::to_string(&none).unwrap();
        assert_eq!(json, r#"{"data":null}"#);
        let back: MaybeVector = serde_json::from_str(&json).unwrap();
        assert!(back.data.is_none());
    }
}