1use crate::error::{DatasetsError, Result};
4use ndarray::{Array1, Array2};
5use rand::prelude::*;
6use rand::rng;
7use rand::rngs::StdRng;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11mod serde_array {
13 use ndarray::{Array1, Array2};
14 use serde::{Deserialize, Deserializer, Serialize, Serializer};
15 use std::vec::Vec;
16
17 pub fn serialize_array2<S>(array: &Array2<f64>, serializer: S) -> Result<S::Ok, S::Error>
18 where
19 S: Serializer,
20 {
21 let shape = array.shape();
22 let mut vec = Vec::with_capacity(shape[0] * shape[1] + 2);
23
24 vec.push(shape[0] as f64);
26 vec.push(shape[1] as f64);
27
28 vec.extend(array.iter().cloned());
30
31 vec.serialize(serializer)
32 }
33
34 pub fn deserialize_array2<'de, D>(deserializer: D) -> Result<Array2<f64>, D::Error>
35 where
36 D: Deserializer<'de>,
37 {
38 let vec = Vec::<f64>::deserialize(deserializer)?;
39 if vec.len() < 2 {
40 return Err(serde::de::Error::custom("Invalid array2 serialization"));
41 }
42
43 let nrows = vec[0] as usize;
44 let ncols = vec[1] as usize;
45
46 if vec.len() != nrows * ncols + 2 {
47 return Err(serde::de::Error::custom("Invalid array2 serialization"));
48 }
49
50 let data = vec[2..].to_vec();
51 match Array2::from_shape_vec((nrows, ncols), data) {
52 Ok(array) => Ok(array),
53 Err(_) => Err(serde::de::Error::custom("Failed to reshape array2")),
54 }
55 }
56
57 #[allow(dead_code)]
58 pub fn serialize_array1<S>(array: &Array1<f64>, serializer: S) -> Result<S::Ok, S::Error>
59 where
60 S: Serializer,
61 {
62 let vec = array.to_vec();
63 vec.serialize(serializer)
64 }
65
66 pub fn deserialize_array1<'de, D>(deserializer: D) -> Result<Array1<f64>, D::Error>
67 where
68 D: Deserializer<'de>,
69 {
70 let vec = Vec::<f64>::deserialize(deserializer)?;
71 Ok(Array1::from(vec))
72 }
73}
74
/// An in-memory tabular dataset: a feature matrix plus optional targets and descriptive
/// information.
///
/// Build one with `Dataset::new` and the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dataset {
    /// Feature matrix; `n_samples()` is its row count and `n_features()` its column count.
    ///
    /// Serialized through the `serde_array` helpers as one flat `f64` sequence,
    /// `[nrows, ncols, elements...]`, with elements in logical row-major order.
    #[serde(
        serialize_with = "serde_array::serialize_array2",
        deserialize_with = "serde_array::deserialize_array2"
    )]
    pub data: Array2<f64>,

    /// Optional per-sample target values; `train_test_split` selects them with the same row
    /// indices as `data`. Uses `Array1`'s own serde implementation (presumably ndarray's
    /// `serde` feature). Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target: Option<Array1<f64>>,

    /// Optional names associated with the target (presumably class labels; not validated
    /// here). Omitted from output when `None`; a missing field deserializes as `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub target_names: Option<Vec<String>>,

    /// Optional per-feature names. Nothing in this type checks that the length matches
    /// `n_features()`. Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub feature_names: Option<Vec<String>>,

    /// Optional per-feature descriptions (length not validated). Omitted from output when
    /// `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub feature_descriptions: Option<Vec<String>>,

    /// Optional free-text description of the dataset. Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,

    /// Free-form string key/value pairs.
    ///
    /// NOTE(review): unlike the optional fields above, this has neither `skip_serializing_if`
    /// nor `#[serde(default)]`, so it is always serialized and must be present when
    /// deserializing. Confirm that is intended.
    pub metadata: HashMap<String, String>,
}
108
109mod optional_array1 {
111 use super::serde_array;
112 use ndarray::Array1;
113 use serde::{self, Deserialize, Deserializer, Serialize, Serializer};
114
115 #[allow(dead_code)]
116 pub fn serialize<S>(array_opt: &Option<Array1<f64>>, serializer: S) -> Result<S::Ok, S::Error>
117 where
118 S: Serializer,
119 {
120 match array_opt {
121 Some(array) => {
122 #[derive(Serialize)]
123 struct Helper<'a>(&'a Array1<f64>);
124
125 #[derive(Serialize)]
126 struct Wrapper<'a> {
127 #[serde(
128 serialize_with = "serde_array::serialize_array1",
129 deserialize_with = "serde_array::deserialize_array1"
130 )]
131 value: &'a Array1<f64>,
132 }
133
134 Wrapper { value: array }.serialize(serializer)
135 }
136 None => serializer.serialize_none(),
137 }
138 }
139
140 #[allow(dead_code)]
141 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Array1<f64>>, D::Error>
142 where
143 D: Deserializer<'de>,
144 {
145 #[derive(Deserialize)]
146 struct Wrapper {
147 #[serde(
148 serialize_with = "serde_array::serialize_array1",
149 deserialize_with = "serde_array::deserialize_array1"
150 )]
151 #[allow(dead_code)]
152 value: Array1<f64>,
153 }
154
155 Option::<Wrapper>::deserialize(deserializer).map(|opt_wrapper| opt_wrapper.map(|w| w.value))
156 }
157}
158
159impl Dataset {
160 pub fn new(data: Array2<f64>, target: Option<Array1<f64>>) -> Self {
162 Dataset {
163 data,
164 target,
165 target_names: None,
166 feature_names: None,
167 feature_descriptions: None,
168 description: None,
169 metadata: HashMap::new(),
170 }
171 }
172
173 pub fn with_target_names(mut self, target_names: Vec<String>) -> Self {
175 self.target_names = Some(target_names);
176 self
177 }
178
179 pub fn with_feature_names(mut self, feature_names: Vec<String>) -> Self {
181 self.feature_names = Some(feature_names);
182 self
183 }
184
185 pub fn with_feature_descriptions(mut self, feature_descriptions: Vec<String>) -> Self {
187 self.feature_descriptions = Some(feature_descriptions);
188 self
189 }
190
191 pub fn with_description(mut self, description: String) -> Self {
193 self.description = Some(description);
194 self
195 }
196
197 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
199 self.metadata.insert(key.to_string(), value.to_string());
200 self
201 }
202
203 pub fn n_samples(&self) -> usize {
205 self.data.nrows()
206 }
207
208 pub fn n_features(&self) -> usize {
210 self.data.ncols()
211 }
212
213 pub fn train_test_split(
215 &self,
216 test_size: f64,
217 random_seed: Option<u64>,
218 ) -> Result<(Dataset, Dataset)> {
219 if test_size <= 0.0 || test_size >= 1.0 {
220 return Err(DatasetsError::InvalidFormat(
221 "test_size must be between 0 and 1".to_string(),
222 ));
223 }
224
225 let n_samples = self.n_samples();
226 let n_test = (n_samples as f64 * test_size).round() as usize;
227 let n_train = n_samples - n_test;
228
229 if n_train == 0 || n_test == 0 {
230 return Err(DatasetsError::InvalidFormat(
231 "Both train and test sets must have at least one sample".to_string(),
232 ));
233 }
234
235 let mut indices: Vec<usize> = (0..n_samples).collect();
237 let mut rng = match random_seed {
238 Some(seed) => StdRng::seed_from_u64(seed),
239 None => {
240 let mut r = rng();
241 StdRng::seed_from_u64(r.next_u64())
242 }
243 };
244 indices.shuffle(&mut rng);
245
246 let train_indices = &indices[0..n_train];
247 let test_indices = &indices[n_train..];
248
249 let train_data = self.data.select(ndarray::Axis(0), train_indices);
251 let train_target = self
252 .target
253 .as_ref()
254 .map(|t| t.select(ndarray::Axis(0), train_indices));
255
256 let mut train_dataset = Dataset::new(train_data, train_target);
257 if let Some(feature_names) = &self.feature_names {
258 train_dataset = train_dataset.with_feature_names(feature_names.clone());
259 }
260 if let Some(description) = &self.description {
261 train_dataset = train_dataset.with_description(description.clone());
262 }
263
264 let test_data = self.data.select(ndarray::Axis(0), test_indices);
266 let test_target = self
267 .target
268 .as_ref()
269 .map(|t| t.select(ndarray::Axis(0), test_indices));
270
271 let mut test_dataset = Dataset::new(test_data, test_target);
272 if let Some(feature_names) = &self.feature_names {
273 test_dataset = test_dataset.with_feature_names(feature_names.clone());
274 }
275 if let Some(description) = &self.description {
276 test_dataset = test_dataset.with_description(description.clone());
277 }
278
279 Ok((train_dataset, test_dataset))
280 }
281}
282
283pub fn normalize(data: &mut Array2<f64>) {
285 let n_features = data.ncols();
286
287 for j in 0..n_features {
288 let mut column = data.column_mut(j);
289
290 let mean = column.mean().unwrap_or(0.0);
292 let std = column.std(0.0);
293
294 if std > 1e-10 {
296 column.mapv_inplace(|x| (x - mean) / std);
297 }
298 }
299}
300
301#[allow(dead_code)]
303trait StatsExt {
304 fn mean(&self) -> Option<f64>;
305 fn std(&self, ddof: f64) -> f64;
306}
307
308impl StatsExt for ndarray::ArrayView1<'_, f64> {
309 fn mean(&self) -> Option<f64> {
310 if self.is_empty() {
311 return None;
312 }
313
314 let sum: f64 = self.sum();
315 Some(sum / self.len() as f64)
316 }
317
318 fn std(&self, ddof: f64) -> f64 {
319 if self.is_empty() {
320 return 0.0;
321 }
322
323 let n = self.len() as f64;
324 let mean = self.mean().unwrap_or(0.0);
325
326 let mut sum_sq = 0.0;
327 for &x in self.iter() {
328 let diff = x - mean;
329 sum_sq += diff * diff;
330 }
331
332 let divisor = n - ddof;
333 if divisor <= 0.0 {
334 return 0.0;
335 }
336
337 (sum_sq / divisor).sqrt()
338 }
339}