scirs2_datasets/utils/
serialization.rs

1//! Serialization utilities for ndarray types with serde
2//!
3//! This module provides helper functions for serializing and deserializing
4//! ndarray Array1 and Array2 types with serde, enabling JSON and other format
5//! compatibility for dataset structures.
6
7use scirs2_core::ndarray::{Array1, Array2};
8use serde::{Deserialize, Deserializer, Serialize, Serializer};
9use std::vec::Vec;
10
11/// Serialize a 2D array to a format compatible with serde
12///
13/// The serialization format stores the shape information (rows, cols) at the
14/// beginning of a flat vector, followed by the array data in row-major order.
15///
16/// # Arguments
17///
18/// * `array` - The 2D array to serialize
19/// * `serializer` - The serde serializer to use
20///
21/// # Returns
22///
23/// * `Result<S::Ok, S::Error>` - Serialization result
24#[allow(dead_code)]
25pub fn serialize_array2<S>(array: &Array2<f64>, serializer: S) -> Result<S::Ok, S::Error>
26where
27    S: Serializer,
28{
29    let shape = array.shape();
30    let mut vec = Vec::with_capacity(shape[0] * shape[1] + 2);
31
32    // Store shape at the beginning
33    vec.push(shape[0] as f64);
34    vec.push(shape[1] as f64);
35
36    // Store data
37    vec.extend(array.iter().cloned());
38
39    vec.serialize(serializer)
40}
41
42/// Deserialize a 2D array from a serde-compatible format
43///
44/// Reconstructs an Array2 from the flattened format created by serialize_array2.
45/// The first two elements are interpreted as the shape (rows, cols), and the
46/// remaining elements are reshaped into the 2D array.
47///
48/// # Arguments
49///
50/// * `deserializer` - The serde deserializer to use
51///
52/// # Returns
53///
54/// * `Result<Array2<f64>, D::Error>` - Deserialized 2D array
55#[allow(dead_code)]
56pub fn deserialize_array2<'de, D>(deserializer: D) -> Result<Array2<f64>, D::Error>
57where
58    D: Deserializer<'de>,
59{
60    let vec = Vec::<f64>::deserialize(deserializer)?;
61    if vec.len() < 2 {
62        return Err(serde::de::Error::custom("Invalid array2 serialization"));
63    }
64
65    let nrows = vec[0] as usize;
66    let ncols = vec[1] as usize;
67
68    if vec.len() != nrows * ncols + 2 {
69        return Err(serde::de::Error::custom("Invalid array2 serialization"));
70    }
71
72    let data = vec[2..].to_vec();
73    match Array2::from_shape_vec((nrows, ncols), data) {
74        Ok(array) => Ok(array),
75        Err(_) => Err(serde::de::Error::custom("Failed to reshape array2")),
76    }
77}
78
79/// Serialize a 1D array to a format compatible with serde
80///
81/// Simply converts the Array1 to a Vec for JSON serialization.
82///
83/// # Arguments
84///
85/// * `array` - The 1D array to serialize
86/// * `serializer` - The serde serializer to use
87///
88/// # Returns
89///
90/// * `Result<S::Ok, S::Error>` - Serialization result
91///   Serialize a 1D array to a serde-compatible format
92///
93/// This function converts an Array1<f64> to a Vec<f64> for serialization.
94/// Useful for saving datasets or individual arrays to JSON, YAML, etc.
95///
96/// # Arguments
97///
98/// * `array` - The 1D array to serialize
99/// * `serializer` - The serde serializer to use
100///
101/// # Returns
102///
103/// * `Result<S::Ok, S::Error>` - Serialization result
104#[allow(dead_code)]
105pub fn serialize_array1<S>(array: &Array1<f64>, serializer: S) -> Result<S::Ok, S::Error>
106where
107    S: Serializer,
108{
109    let vec = array.to_vec();
110    vec.serialize(serializer)
111}
112
113/// Deserialize a 1D array from a serde-compatible format
114///
115/// Reconstructs an Array1 from a Vec<f64>.
116///
117/// # Arguments
118///
119/// * `deserializer` - The serde deserializer to use
120///
121/// # Returns
122///
123/// * `Result<Array1<f64>, D::Error>` - Deserialized 1D array
124#[allow(dead_code)]
125pub fn deserialize_array1<'de, D>(deserializer: D) -> Result<Array1<f64>, D::Error>
126where
127    D: Deserializer<'de>,
128{
129    let vec = Vec::<f64>::deserialize(deserializer)?;
130    Ok(Array1::from(vec))
131}
132
133/// Helper functions for serializing Option<Array1<f64>> types
134pub mod optional_array1 {
135    use super::*;
136
137    /// Serialize an optional 1D array
138    ///
139    /// # Arguments
140    ///
141    /// * `array_opt` - The optional array to serialize
142    /// * `serializer` - The serde serializer to use
143    ///
144    /// # Returns
145    ///
146    /// * `Result<S::Ok, S::Error>` - Serialization result
147    ///   Serialize an optional 1D array to a serde-compatible format
148    ///
149    /// This function handles serialization of optional arrays, serializing None as null
150    /// and Some(array) using the array1 serializer.
151    ///
152    /// # Arguments
153    ///
154    /// * `array_opt` - The optional array to serialize
155    /// * `serializer` - The serde serializer to use
156    ///
157    /// # Returns
158    ///
159    /// * `Result<S::Ok, S::Error>` - Serialization result
160    pub fn serialize<S>(_arrayopt: &Option<Array1<f64>>, serializer: S) -> Result<S::Ok, S::Error>
161    where
162        S: Serializer,
163    {
164        match _arrayopt {
165            Some(array) => {
166                #[derive(Serialize)]
167                struct Wrapper<'a> {
168                    #[serde(
169                        serialize_with = "super::serialize_array1",
170                        deserialize_with = "super::deserialize_array1"
171                    )]
172                    value: &'a Array1<f64>,
173                }
174
175                Wrapper { value: array }.serialize(serializer)
176            }
177            None => serializer.serialize_none(),
178        }
179    }
180
181    /// Deserialize an optional 1D array
182    ///
183    /// # Arguments
184    ///
185    /// * `deserializer` - The serde deserializer to use
186    ///
187    /// # Returns
188    ///
189    /// * `Result<Option<Array1<f64>>, D::Error>` - Deserialized optional array
190    ///   Deserialize an optional 1D array from a serde-compatible format
191    ///
192    /// This function handles deserialization of optional arrays, converting null to None
193    /// and valid data to Some(array).
194    ///
195    /// # Arguments
196    ///
197    /// * `deserializer` - The serde deserializer to use
198    ///
199    /// # Returns
200    ///
201    /// * `Result<Option<Array1<f64>>, D::Error>` - Deserialized optional array
202    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Array1<f64>>, D::Error>
203    where
204        D: Deserializer<'de>,
205    {
206        #[derive(Deserialize)]
207        struct Wrapper {
208            #[serde(
209                serialize_with = "super::serialize_array1",
210                deserialize_with = "super::deserialize_array1"
211            )]
212            #[allow(dead_code)]
213            value: Array1<f64>,
214        }
215
216        Option::<Wrapper>::deserialize(deserializer).map(|opt_wrapper| opt_wrapper.map(|w| w.value))
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // The helpers are meant to be used from serde field attributes, so the
    // tests drive them exactly that way instead of re-implementing their logic.
    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Matrix {
        #[serde(
            serialize_with = "serialize_array2",
            deserialize_with = "deserialize_array2"
        )]
        data: Array2<f64>,
    }

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Vector {
        #[serde(
            serialize_with = "serialize_array1",
            deserialize_with = "deserialize_array1"
        )]
        data: Array1<f64>,
    }

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct MaybeVector {
        #[serde(with = "optional_array1")]
        data: Option<Array1<f64>>,
    }

    #[test]
    fn test_array2_serialization_roundtrip() {
        let original = Matrix {
            data: array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        };

        let json = serde_json::to_string(&original).unwrap();
        // Shape header (rows, cols) followed by the data in row-major order.
        assert_eq!(json, r#"{"data":[2.0,3.0,1.0,2.0,3.0,4.0,5.0,6.0]}"#);

        let decoded: Matrix = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_array2_serialization_uses_logical_order() {
        // `reversed_axes` produces a column-major layout; the encoding must
        // still list elements in logical row-major order.
        let original = Matrix {
            data: array![[1.0, 2.0], [3.0, 4.0]].reversed_axes(),
        };

        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"data":[2.0,2.0,1.0,3.0,2.0,4.0]}"#);

        let decoded: Matrix = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_empty_array2_roundtrip() {
        let original = Matrix {
            data: Array2::<f64>::zeros((0, 3)),
        };

        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"data":[0.0,3.0]}"#);

        let decoded: Matrix = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_array1_serialization_roundtrip() {
        let original = Vector {
            data: array![1.0, 2.0, 3.0, 4.0, 5.0],
        };

        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"data":[1.0,2.0,3.0,4.0,5.0]}"#);

        let decoded: Vector = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_optional_array1_roundtrip() {
        let some = MaybeVector {
            data: Some(array![1.0, 2.0]),
        };
        let json = serde_json::to_string(&some).unwrap();
        assert_eq!(json, r#"{"data":{"value":[1.0,2.0]}}"#);
        assert_eq!(serde_json::from_str::<MaybeVector>(&json).unwrap(), some);

        let none = MaybeVector { data: None };
        let json = serde_json::to_string(&none).unwrap();
        assert_eq!(json, r#"{"data":null}"#);
        assert_eq!(serde_json::from_str::<MaybeVector>(&json).unwrap(), none);
    }

    #[test]
    fn test_invalid_array2_deserialization() {
        // Too short to contain the (rows, cols) header.
        assert!(serde_json::from_str::<Matrix>(r#"{"data":[2.0]}"#).is_err());
        // Claims 2x3 but carries only one element.
        assert!(serde_json::from_str::<Matrix>(r#"{"data":[2.0,3.0,1.0]}"#).is_err());
        // Claims 1x1 but carries two elements.
        assert!(serde_json::from_str::<Matrix>(r#"{"data":[1.0,1.0,1.0,2.0]}"#).is_err());
    }
}
256}