scirs2_datasets/utils/
serialization.rs

1//! Serialization utilities for ndarray types with serde
2//!
3//! This module provides helper functions for serializing and deserializing
4//! ndarray Array1 and Array2 types with serde, enabling JSON and other format
5//! compatibility for dataset structures.
6
7use ndarray::{Array1, Array2};
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9use std::vec::Vec;
10
11/// Serialize a 2D array to a format compatible with serde
12///
13/// The serialization format stores the shape information (rows, cols) at the
14/// beginning of a flat vector, followed by the array data in row-major order.
15///
16/// # Arguments
17///
18/// * `array` - The 2D array to serialize
19/// * `serializer` - The serde serializer to use
20///
21/// # Returns
22///
23/// * `Result<S::Ok, S::Error>` - Serialization result
24pub fn serialize_array2<S>(array: &Array2<f64>, serializer: S) -> Result<S::Ok, S::Error>
25where
26    S: Serializer,
27{
28    let shape = array.shape();
29    let mut vec = Vec::with_capacity(shape[0] * shape[1] + 2);
30
31    // Store shape at the beginning
32    vec.push(shape[0] as f64);
33    vec.push(shape[1] as f64);
34
35    // Store data
36    vec.extend(array.iter().cloned());
37
38    vec.serialize(serializer)
39}
40
41/// Deserialize a 2D array from a serde-compatible format
42///
43/// Reconstructs an Array2 from the flattened format created by serialize_array2.
44/// The first two elements are interpreted as the shape (rows, cols), and the
45/// remaining elements are reshaped into the 2D array.
46///
47/// # Arguments
48///
49/// * `deserializer` - The serde deserializer to use
50///
51/// # Returns
52///
53/// * `Result<Array2<f64>, D::Error>` - Deserialized 2D array
54pub fn deserialize_array2<'de, D>(deserializer: D) -> Result<Array2<f64>, D::Error>
55where
56    D: Deserializer<'de>,
57{
58    let vec = Vec::<f64>::deserialize(deserializer)?;
59    if vec.len() < 2 {
60        return Err(serde::de::Error::custom("Invalid array2 serialization"));
61    }
62
63    let nrows = vec[0] as usize;
64    let ncols = vec[1] as usize;
65
66    if vec.len() != nrows * ncols + 2 {
67        return Err(serde::de::Error::custom("Invalid array2 serialization"));
68    }
69
70    let data = vec[2..].to_vec();
71    match Array2::from_shape_vec((nrows, ncols), data) {
72        Ok(array) => Ok(array),
73        Err(_) => Err(serde::de::Error::custom("Failed to reshape array2")),
74    }
75}
76
77/// Serialize a 1D array to a format compatible with serde
78///
79/// Simply converts the Array1 to a Vec for JSON serialization.
80///
81/// # Arguments
82///
83/// * `array` - The 1D array to serialize
84/// * `serializer` - The serde serializer to use
85///
86/// # Returns
87///
88/// * `Result<S::Ok, S::Error>` - Serialization result
89#[allow(dead_code)]
90pub fn serialize_array1<S>(array: &Array1<f64>, serializer: S) -> Result<S::Ok, S::Error>
91where
92    S: Serializer,
93{
94    let vec = array.to_vec();
95    vec.serialize(serializer)
96}
97
98/// Deserialize a 1D array from a serde-compatible format
99///
100/// Reconstructs an Array1 from a Vec<f64>.
101///
102/// # Arguments
103///
104/// * `deserializer` - The serde deserializer to use
105///
106/// # Returns
107///
108/// * `Result<Array1<f64>, D::Error>` - Deserialized 1D array
109pub fn deserialize_array1<'de, D>(deserializer: D) -> Result<Array1<f64>, D::Error>
110where
111    D: Deserializer<'de>,
112{
113    let vec = Vec::<f64>::deserialize(deserializer)?;
114    Ok(Array1::from(vec))
115}
116
117/// Helper functions for serializing Option<Array1<f64>> types
118pub mod optional_array1 {
119    use super::*;
120
121    /// Serialize an optional 1D array
122    ///
123    /// # Arguments
124    ///
125    /// * `array_opt` - The optional array to serialize
126    /// * `serializer` - The serde serializer to use
127    ///
128    /// # Returns
129    ///
130    /// * `Result<S::Ok, S::Error>` - Serialization result
131    #[allow(dead_code)]
132    pub fn serialize<S>(array_opt: &Option<Array1<f64>>, serializer: S) -> Result<S::Ok, S::Error>
133    where
134        S: Serializer,
135    {
136        match array_opt {
137            Some(array) => {
138                #[derive(Serialize)]
139                struct Wrapper<'a> {
140                    #[serde(
141                        serialize_with = "super::serialize_array1",
142                        deserialize_with = "super::deserialize_array1"
143                    )]
144                    value: &'a Array1<f64>,
145                }
146
147                Wrapper { value: array }.serialize(serializer)
148            }
149            None => serializer.serialize_none(),
150        }
151    }
152
153    /// Deserialize an optional 1D array
154    ///
155    /// # Arguments
156    ///
157    /// * `deserializer` - The serde deserializer to use
158    ///
159    /// # Returns
160    ///
161    /// * `Result<Option<Array1<f64>>, D::Error>` - Deserialized optional array
162    #[allow(dead_code)]
163    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Array1<f64>>, D::Error>
164    where
165        D: Deserializer<'de>,
166    {
167        #[derive(Deserialize)]
168        struct Wrapper {
169            #[serde(
170                serialize_with = "super::serialize_array1",
171                deserialize_with = "super::deserialize_array1"
172            )]
173            #[allow(dead_code)]
174            value: Array1<f64>,
175        }
176
177        Option::<Wrapper>::deserialize(deserializer).map(|opt_wrapper| opt_wrapper.map(|w| w.value))
178    }
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;
    use serde_json::json;

    // The helpers are driven directly through `serde_json::Value`:
    // `serde_json::value::Serializer` turns a serialization into a `Value`,
    // and a `Value` can itself be used as a deserializer.

    #[test]
    fn test_array2_serialization_roundtrip() {
        let original = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        let encoded = serialize_array2(&original, serde_json::value::Serializer).unwrap();
        // Shape header (2 rows, 3 cols) followed by the row-major data.
        assert_eq!(encoded, json!([2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]));

        let reconstructed = deserialize_array2(encoded).unwrap();
        assert_eq!(original, reconstructed);
    }

    #[test]
    fn test_array1_serialization_roundtrip() {
        let original = array![1.0, 2.0, 3.0, 4.0, 5.0];

        let encoded = serialize_array1(&original, serde_json::value::Serializer).unwrap();
        assert_eq!(encoded, json!([1.0, 2.0, 3.0, 4.0, 5.0]));

        let reconstructed = deserialize_array1(encoded).unwrap();
        assert_eq!(original, reconstructed);
    }

    #[test]
    fn test_invalid_array2_deserialization() {
        // Claims 2x3 but only carries a single data element.
        assert!(deserialize_array2(json!([2.0, 3.0, 1.0])).is_err());

        // Too short to even contain the shape header.
        assert!(deserialize_array2(json!([2.0])).is_err());
    }

    #[test]
    fn test_optional_array1_some_roundtrip() {
        let original = Some(array![1.0, 2.0]);

        let encoded =
            optional_array1::serialize(&original, serde_json::value::Serializer).unwrap();
        assert_eq!(encoded, json!({ "value": [1.0, 2.0] }));

        let decoded = optional_array1::deserialize(encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_optional_array1_none_roundtrip() {
        let missing: Option<Array1<f64>> = None;

        let encoded =
            optional_array1::serialize(&missing, serde_json::value::Serializer).unwrap();
        assert_eq!(encoded, json!(null));

        let decoded = optional_array1::deserialize(encoded).unwrap();
        assert_eq!(decoded, missing);
    }
}