scirs2_datasets/
utils.rs

//! Utility functions and data structures for datasets
2
3use crate::error::{DatasetsError, Result};
4use ndarray::{Array1, Array2};
5use rand::prelude::*;
6use rand::rng;
7use rand::rngs::StdRng;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11// Helper module for serializing ndarray types with serde
12mod serde_array {
13    use ndarray::{Array1, Array2};
14    use serde::{Deserialize, Deserializer, Serialize, Serializer};
15    use std::vec::Vec;
16
17    pub fn serialize_array2<S>(array: &Array2<f64>, serializer: S) -> Result<S::Ok, S::Error>
18    where
19        S: Serializer,
20    {
21        let shape = array.shape();
22        let mut vec = Vec::with_capacity(shape[0] * shape[1] + 2);
23
24        // Store shape at the beginning
25        vec.push(shape[0] as f64);
26        vec.push(shape[1] as f64);
27
28        // Store data
29        vec.extend(array.iter().cloned());
30
31        vec.serialize(serializer)
32    }
33
34    pub fn deserialize_array2<'de, D>(deserializer: D) -> Result<Array2<f64>, D::Error>
35    where
36        D: Deserializer<'de>,
37    {
38        let vec = Vec::<f64>::deserialize(deserializer)?;
39        if vec.len() < 2 {
40            return Err(serde::de::Error::custom("Invalid array2 serialization"));
41        }
42
43        let nrows = vec[0] as usize;
44        let ncols = vec[1] as usize;
45
46        if vec.len() != nrows * ncols + 2 {
47            return Err(serde::de::Error::custom("Invalid array2 serialization"));
48        }
49
50        let data = vec[2..].to_vec();
51        match Array2::from_shape_vec((nrows, ncols), data) {
52            Ok(array) => Ok(array),
53            Err(_) => Err(serde::de::Error::custom("Failed to reshape array2")),
54        }
55    }
56
57    #[allow(dead_code)]
58    pub fn serialize_array1<S>(array: &Array1<f64>, serializer: S) -> Result<S::Ok, S::Error>
59    where
60        S: Serializer,
61    {
62        let vec = array.to_vec();
63        vec.serialize(serializer)
64    }
65
66    pub fn deserialize_array1<'de, D>(deserializer: D) -> Result<Array1<f64>, D::Error>
67    where
68        D: Deserializer<'de>,
69    {
70        let vec = Vec::<f64>::deserialize(deserializer)?;
71        Ok(Array1::from(vec))
72    }
73}
74
75/// Represents a dataset with features, optional targets, and metadata
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct Dataset {
78    /// Features/data matrix (n_samples, n_features)
79    #[serde(
80        serialize_with = "serde_array::serialize_array2",
81        deserialize_with = "serde_array::deserialize_array2"
82    )]
83    pub data: Array2<f64>,
84
85    /// Optional target values
86    #[serde(skip_serializing_if = "Option::is_none")]
87    pub target: Option<Array1<f64>>,
88
89    /// Optional target names for classification problems
90    #[serde(skip_serializing_if = "Option::is_none")]
91    pub target_names: Option<Vec<String>>,
92
93    /// Optional feature names
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub feature_names: Option<Vec<String>>,
96
97    /// Optional descriptions for each feature
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub feature_descriptions: Option<Vec<String>>,
100
101    /// Optional dataset description
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub description: Option<String>,
104
105    /// Optional dataset metadata
106    pub metadata: HashMap<String, String>,
107}
108
109// Helper module for serializing Option<Array1<f64>>
110mod optional_array1 {
111    use super::serde_array;
112    use ndarray::Array1;
113    use serde::{self, Deserialize, Deserializer, Serialize, Serializer};
114
115    #[allow(dead_code)]
116    pub fn serialize<S>(array_opt: &Option<Array1<f64>>, serializer: S) -> Result<S::Ok, S::Error>
117    where
118        S: Serializer,
119    {
120        match array_opt {
121            Some(array) => {
122                #[derive(Serialize)]
123                struct Helper<'a>(&'a Array1<f64>);
124
125                #[derive(Serialize)]
126                struct Wrapper<'a> {
127                    #[serde(
128                        serialize_with = "serde_array::serialize_array1",
129                        deserialize_with = "serde_array::deserialize_array1"
130                    )]
131                    value: &'a Array1<f64>,
132                }
133
134                Wrapper { value: array }.serialize(serializer)
135            }
136            None => serializer.serialize_none(),
137        }
138    }
139
140    #[allow(dead_code)]
141    pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Array1<f64>>, D::Error>
142    where
143        D: Deserializer<'de>,
144    {
145        #[derive(Deserialize)]
146        struct Wrapper {
147            #[serde(
148                serialize_with = "serde_array::serialize_array1",
149                deserialize_with = "serde_array::deserialize_array1"
150            )]
151            #[allow(dead_code)]
152            value: Array1<f64>,
153        }
154
155        Option::<Wrapper>::deserialize(deserializer).map(|opt_wrapper| opt_wrapper.map(|w| w.value))
156    }
157}
158
159impl Dataset {
160    /// Create a new dataset with the given data and target
161    pub fn new(data: Array2<f64>, target: Option<Array1<f64>>) -> Self {
162        Dataset {
163            data,
164            target,
165            target_names: None,
166            feature_names: None,
167            feature_descriptions: None,
168            description: None,
169            metadata: HashMap::new(),
170        }
171    }
172
173    /// Add target names to the dataset
174    pub fn with_target_names(mut self, target_names: Vec<String>) -> Self {
175        self.target_names = Some(target_names);
176        self
177    }
178
179    /// Add feature names to the dataset
180    pub fn with_feature_names(mut self, feature_names: Vec<String>) -> Self {
181        self.feature_names = Some(feature_names);
182        self
183    }
184
185    /// Add feature descriptions to the dataset
186    pub fn with_feature_descriptions(mut self, feature_descriptions: Vec<String>) -> Self {
187        self.feature_descriptions = Some(feature_descriptions);
188        self
189    }
190
191    /// Add a description to the dataset
192    pub fn with_description(mut self, description: String) -> Self {
193        self.description = Some(description);
194        self
195    }
196
197    /// Add metadata to the dataset
198    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
199        self.metadata.insert(key.to_string(), value.to_string());
200        self
201    }
202
203    /// Get the number of samples in the dataset
204    pub fn n_samples(&self) -> usize {
205        self.data.nrows()
206    }
207
208    /// Get the number of features in the dataset
209    pub fn n_features(&self) -> usize {
210        self.data.ncols()
211    }
212
213    /// Split the dataset into training and test sets
214    pub fn train_test_split(
215        &self,
216        test_size: f64,
217        random_seed: Option<u64>,
218    ) -> Result<(Dataset, Dataset)> {
219        if test_size <= 0.0 || test_size >= 1.0 {
220            return Err(DatasetsError::InvalidFormat(
221                "test_size must be between 0 and 1".to_string(),
222            ));
223        }
224
225        let n_samples = self.n_samples();
226        let n_test = (n_samples as f64 * test_size).round() as usize;
227        let n_train = n_samples - n_test;
228
229        if n_train == 0 || n_test == 0 {
230            return Err(DatasetsError::InvalidFormat(
231                "Both train and test sets must have at least one sample".to_string(),
232            ));
233        }
234
235        // Create shuffled indices
236        let mut indices: Vec<usize> = (0..n_samples).collect();
237        let mut rng = match random_seed {
238            Some(seed) => StdRng::seed_from_u64(seed),
239            None => {
240                let mut r = rng();
241                StdRng::seed_from_u64(r.next_u64())
242            }
243        };
244        indices.shuffle(&mut rng);
245
246        let train_indices = &indices[0..n_train];
247        let test_indices = &indices[n_train..];
248
249        // Create training dataset
250        let train_data = self.data.select(ndarray::Axis(0), train_indices);
251        let train_target = self
252            .target
253            .as_ref()
254            .map(|t| t.select(ndarray::Axis(0), train_indices));
255
256        let mut train_dataset = Dataset::new(train_data, train_target);
257        if let Some(feature_names) = &self.feature_names {
258            train_dataset = train_dataset.with_feature_names(feature_names.clone());
259        }
260        if let Some(description) = &self.description {
261            train_dataset = train_dataset.with_description(description.clone());
262        }
263
264        // Create test dataset
265        let test_data = self.data.select(ndarray::Axis(0), test_indices);
266        let test_target = self
267            .target
268            .as_ref()
269            .map(|t| t.select(ndarray::Axis(0), test_indices));
270
271        let mut test_dataset = Dataset::new(test_data, test_target);
272        if let Some(feature_names) = &self.feature_names {
273            test_dataset = test_dataset.with_feature_names(feature_names.clone());
274        }
275        if let Some(description) = &self.description {
276            test_dataset = test_dataset.with_description(description.clone());
277        }
278
279        Ok((train_dataset, test_dataset))
280    }
281}
282
283/// Helper function to normalize data (zero mean, unit variance)
284pub fn normalize(data: &mut Array2<f64>) {
285    let n_features = data.ncols();
286
287    for j in 0..n_features {
288        let mut column = data.column_mut(j);
289
290        // Calculate mean and std
291        let mean = column.mean().unwrap_or(0.0);
292        let std = column.std(0.0);
293
294        // Avoid division by zero
295        if std > 1e-10 {
296            column.mapv_inplace(|x| (x - mean) / std);
297        }
298    }
299}
300
301/// Trait extension for Array2 to calculate mean and standard deviation
302#[allow(dead_code)]
303trait StatsExt {
304    fn mean(&self) -> Option<f64>;
305    fn std(&self, ddof: f64) -> f64;
306}
307
308impl StatsExt for ndarray::ArrayView1<'_, f64> {
309    fn mean(&self) -> Option<f64> {
310        if self.is_empty() {
311            return None;
312        }
313
314        let sum: f64 = self.sum();
315        Some(sum / self.len() as f64)
316    }
317
318    fn std(&self, ddof: f64) -> f64 {
319        if self.is_empty() {
320            return 0.0;
321        }
322
323        let n = self.len() as f64;
324        let mean = self.mean().unwrap_or(0.0);
325
326        let mut sum_sq = 0.0;
327        for &x in self.iter() {
328            let diff = x - mean;
329            sum_sq += diff * diff;
330        }
331
332        let divisor = n - ddof;
333        if divisor <= 0.0 {
334            return 0.0;
335        }
336
337        (sum_sq / divisor).sqrt()
338    }
339}