// Source file: scirs2_neural/utils/datasets.rs
//! Dataset utilities for neural network training
//!
//! This module provides utilities for loading, batching, and preprocessing
//! data for neural network training.
6use crate::error::{NeuralError, Result};
7use scirs2_core::ndarray::{s, Array, Array2, ArrayView2, Axis, IxDyn};
8use scirs2_core::numeric::{Float, NumAssign};
9use scirs2_core::random::{Rng, RngExt};
10use std::fmt::Debug;
11use std::marker::PhantomData;
12
/// A dataset for neural network training
///
/// Provides efficient batching and shuffling of training data.
///
/// # Examples
///
/// ```rust
/// use scirs2_neural::utils::datasets::Dataset;
/// use scirs2_core::ndarray::Array2;
///
/// // Create a dataset with 100 samples, 10 features
/// let features = Array2::<f64>::zeros((100, 10));
/// let labels = Array2::<f64>::zeros((100, 3));
///
/// let dataset = Dataset::new(features, labels).expect("failed to create dataset");
/// assert_eq!(dataset.len(), 100);
/// ```
#[derive(Debug, Clone)]
pub struct Dataset<F: Float + Debug + NumAssign> {
    /// Feature matrix [num_samples, num_features]
    features: Array2<F>,
    /// Label matrix [num_samples, num_labels]
    labels: Array2<F>,
    /// Indices for shuffling: a permutation of 0..num_samples that
    /// `get_batch` reads through, so shuffling permutes this vector
    /// instead of moving the matrix rows themselves.
    indices: Vec<usize>,
}
39
40impl<F: Float + Debug + NumAssign> Dataset<F> {
41    /// Create a new dataset from features and labels
42    ///
43    /// # Arguments
44    /// * `features` - Feature matrix [num_samples, num_features]
45    /// * `labels` - Label matrix [num_samples, num_labels]
46    ///
47    /// # Returns
48    /// A new Dataset instance
49    pub fn new(features: Array2<F>, labels: Array2<F>) -> Result<Self> {
50        if features.nrows() != labels.nrows() {
51            return Err(NeuralError::InvalidArchitecture(format!(
52                "Features and labels must have same number of samples: {} vs {}",
53                features.nrows(),
54                labels.nrows()
55            )));
56        }
57
58        let num_samples = features.nrows();
59        let indices: Vec<usize> = (0..num_samples).collect();
60
61        Ok(Self {
62            features,
63            labels,
64            indices,
65        })
66    }
67
68    /// Get the number of samples in the dataset
69    pub fn len(&self) -> usize {
70        self.features.nrows()
71    }
72
73    /// Check if the dataset is empty
74    pub fn is_empty(&self) -> bool {
75        self.features.nrows() == 0
76    }
77
78    /// Get the number of features
79    pub fn num_features(&self) -> usize {
80        self.features.ncols()
81    }
82
83    /// Get the number of labels/outputs
84    pub fn num_labels(&self) -> usize {
85        self.labels.ncols()
86    }
87
88    /// Get a reference to the features
89    pub fn features(&self) -> &Array2<F> {
90        &self.features
91    }
92
93    /// Get a reference to the labels
94    pub fn labels(&self) -> &Array2<F> {
95        &self.labels
96    }
97
98    /// Shuffle the dataset in place
99    ///
100    /// # Arguments
101    /// * `rng` - Random number generator
102    pub fn shuffle<R: Rng>(&mut self, rng: &mut R) {
103        let n = self.indices.len();
104        for i in (1..n).rev() {
105            let j = (rng.random::<f64>() * (i + 1) as f64) as usize;
106            self.indices.swap(i, j);
107        }
108    }
109
110    /// Get a batch of data at the specified indices
111    ///
112    /// # Arguments
113    /// * `start` - Starting index
114    /// * `end` - Ending index (exclusive)
115    ///
116    /// # Returns
117    /// A tuple of (features_batch, labels_batch)
118    pub fn get_batch(&self, start: usize, end: usize) -> Result<(Array2<F>, Array2<F>)> {
119        let end = end.min(self.len());
120        if start >= end {
121            return Err(NeuralError::InvalidArchitecture(format!(
122                "Invalid batch range: {}..{}",
123                start, end
124            )));
125        }
126
127        let batch_indices: Vec<usize> = self.indices[start..end].to_vec();
128        let batch_size = batch_indices.len();
129
130        // Extract features for batch
131        let mut features_batch = Array2::zeros((batch_size, self.num_features()));
132        let mut labels_batch = Array2::zeros((batch_size, self.num_labels()));
133
134        for (batch_idx, &sample_idx) in batch_indices.iter().enumerate() {
135            for f in 0..self.num_features() {
136                features_batch[[batch_idx, f]] = self.features[[sample_idx, f]];
137            }
138            for l in 0..self.num_labels() {
139                labels_batch[[batch_idx, l]] = self.labels[[sample_idx, l]];
140            }
141        }
142
143        Ok((features_batch, labels_batch))
144    }
145
146    /// Split the dataset into training and validation sets
147    ///
148    /// # Arguments
149    /// * `train_ratio` - Fraction of data to use for training (0.0 to 1.0)
150    /// * `rng` - Random number generator for shuffling before split
151    ///
152    /// # Returns
153    /// A tuple of (train_dataset, val_dataset)
154    pub fn train_val_split<R: Rng>(
155        mut self,
156        train_ratio: f64,
157        rng: &mut R,
158    ) -> Result<(Self, Self)> {
159        if !(0.0..=1.0).contains(&train_ratio) {
160            return Err(NeuralError::InvalidArchitecture(format!(
161                "train_ratio must be between 0 and 1, got {}",
162                train_ratio
163            )));
164        }
165
166        // Shuffle first
167        self.shuffle(rng);
168
169        let n = self.len();
170        let train_size = (n as f64 * train_ratio) as usize;
171
172        // Get indices for train and val
173        let train_indices: Vec<usize> = self.indices[..train_size].to_vec();
174        let val_indices: Vec<usize> = self.indices[train_size..].to_vec();
175
176        // Build train dataset
177        let mut train_features = Array2::zeros((train_size, self.num_features()));
178        let mut train_labels = Array2::zeros((train_size, self.num_labels()));
179        for (new_idx, &old_idx) in train_indices.iter().enumerate() {
180            for f in 0..self.num_features() {
181                train_features[[new_idx, f]] = self.features[[old_idx, f]];
182            }
183            for l in 0..self.num_labels() {
184                train_labels[[new_idx, l]] = self.labels[[old_idx, l]];
185            }
186        }
187
188        // Build val dataset
189        let val_size = n - train_size;
190        let mut val_features = Array2::zeros((val_size, self.num_features()));
191        let mut val_labels = Array2::zeros((val_size, self.num_labels()));
192        for (new_idx, &old_idx) in val_indices.iter().enumerate() {
193            for f in 0..self.num_features() {
194                val_features[[new_idx, f]] = self.features[[old_idx, f]];
195            }
196            for l in 0..self.num_labels() {
197                val_labels[[new_idx, l]] = self.labels[[old_idx, l]];
198            }
199        }
200
201        Ok((
202            Dataset::new(train_features, train_labels)?,
203            Dataset::new(val_features, val_labels)?,
204        ))
205    }
206}
207
/// Iterator for batching a dataset
///
/// Provides efficient iteration over batches of a dataset; each step
/// yields a `Result<(features_batch, labels_batch)>`.
pub struct BatchIterator<'a, F: Float + Debug + NumAssign> {
    /// Dataset being iterated (borrowed for the iterator's lifetime).
    dataset: &'a Dataset<F>,
    /// Number of samples per batch.
    batch_size: usize,
    /// Position (in the dataset's index order) where the next batch starts.
    current_idx: usize,
    /// When true, a trailing batch smaller than `batch_size` is skipped.
    drop_last: bool,
}
217
218impl<'a, F: Float + Debug + NumAssign> BatchIterator<'a, F> {
219    /// Create a new batch iterator
220    ///
221    /// # Arguments
222    /// * `dataset` - The dataset to iterate over
223    /// * `batch_size` - Size of each batch
224    /// * `drop_last` - Whether to drop the last batch if it's smaller than batch_size
225    pub fn new(dataset: &'a Dataset<F>, batch_size: usize, drop_last: bool) -> Self {
226        Self {
227            dataset,
228            batch_size,
229            current_idx: 0,
230            drop_last,
231        }
232    }
233
234    /// Get the number of batches
235    pub fn num_batches(&self) -> usize {
236        let n = self.dataset.len();
237        if self.drop_last {
238            n / self.batch_size
239        } else {
240            n.div_ceil(self.batch_size)
241        }
242    }
243}
244
245impl<'a, F: Float + Debug + NumAssign> Iterator for BatchIterator<'a, F> {
246    type Item = Result<(Array2<F>, Array2<F>)>;
247
248    fn next(&mut self) -> Option<Self::Item> {
249        if self.current_idx >= self.dataset.len() {
250            return None;
251        }
252
253        let start = self.current_idx;
254        let end = (start + self.batch_size).min(self.dataset.len());
255
256        // Check if we should drop this batch
257        if self.drop_last && end - start < self.batch_size {
258            return None;
259        }
260
261        self.current_idx = end;
262        Some(self.dataset.get_batch(start, end))
263    }
264}
265
/// Data loader for training neural networks
///
/// Provides shuffling, batching, and iteration over datasets.
///
/// # Examples
///
/// ```rust
/// use scirs2_neural::utils::datasets::{Dataset, DataLoader};
/// use scirs2_core::ndarray::Array2;
///
/// let features = Array2::<f64>::zeros((100, 10));
/// let labels = Array2::<f64>::zeros((100, 3));
/// let dataset = Dataset::new(features, labels).expect("failed to create dataset");
///
/// let mut loader = DataLoader::new(dataset, 16, true, true);
///
/// for epoch in 0..2 {
///     for batch_result in loader.iter() {
///         let (x, y) = batch_result.expect("batch failed");
///         // Process batch
///     }
///     loader.on_epoch_end(); // Shuffle for next epoch
/// }
/// ```
pub struct DataLoader<F: Float + Debug + NumAssign> {
    /// Owned dataset the loader batches over.
    dataset: Dataset<F>,
    /// Number of samples per batch.
    batch_size: usize,
    /// Whether `on_epoch_end` reshuffles the dataset.
    shuffle: bool,
    /// Whether a trailing incomplete batch is dropped.
    drop_last: bool,
}
296
297impl<F: Float + Debug + NumAssign> DataLoader<F> {
298    /// Create a new data loader
299    ///
300    /// # Arguments
301    /// * `dataset` - The dataset to load
302    /// * `batch_size` - Size of each batch
303    /// * `shuffle` - Whether to shuffle the data each epoch
304    /// * `drop_last` - Whether to drop the last batch if it's incomplete
305    pub fn new(dataset: Dataset<F>, batch_size: usize, shuffle: bool, drop_last: bool) -> Self {
306        Self {
307            dataset,
308            batch_size,
309            shuffle,
310            drop_last,
311        }
312    }
313
314    /// Get the number of batches per epoch
315    pub fn num_batches(&self) -> usize {
316        let n = self.dataset.len();
317        if self.drop_last {
318            n / self.batch_size
319        } else {
320            n.div_ceil(self.batch_size)
321        }
322    }
323
324    /// Get the dataset size
325    pub fn len(&self) -> usize {
326        self.dataset.len()
327    }
328
329    /// Check if the data loader is empty
330    pub fn is_empty(&self) -> bool {
331        self.dataset.is_empty()
332    }
333
334    /// Get an iterator over batches
335    pub fn iter(&self) -> BatchIterator<'_, F> {
336        BatchIterator::new(&self.dataset, self.batch_size, self.drop_last)
337    }
338
339    /// Call this at the end of each epoch to shuffle data
340    pub fn on_epoch_end(&mut self) {
341        if self.shuffle {
342            let mut rng = scirs2_core::random::rng();
343            self.dataset.shuffle(&mut rng);
344        }
345    }
346
347    /// Get a reference to the underlying dataset
348    pub fn dataset(&self) -> &Dataset<F> {
349        &self.dataset
350    }
351}
352
/// Normalization strategy for features.
///
/// Applied independently to each feature column by [`normalize_features`].
#[derive(Debug, Clone, Copy)]
pub enum Normalization {
    /// Standard normalization: (x - mean) / std
    StandardScaler,
    /// Min-max normalization: (x - min) / (max - min)
    MinMaxScaler,
    /// No normalization
    None,
}
363
364/// Normalize features according to the specified strategy
365///
366/// # Arguments
367/// * `features` - Feature matrix [num_samples, num_features]
368/// * `strategy` - Normalization strategy to apply
369///
370/// # Returns
371/// Normalized feature matrix
372pub fn normalize_features<F: Float + Debug + NumAssign>(
373    features: &Array2<F>,
374    strategy: Normalization,
375) -> Array2<F> {
376    match strategy {
377        Normalization::None => features.clone(),
378        Normalization::StandardScaler => {
379            let mut result = features.clone();
380            for j in 0..features.ncols() {
381                // Compute mean
382                let mut sum = F::zero();
383                for i in 0..features.nrows() {
384                    sum += features[[i, j]];
385                }
386                let mean = sum / F::from(features.nrows()).unwrap_or(F::one());
387
388                // Compute std
389                let mut var_sum = F::zero();
390                for i in 0..features.nrows() {
391                    let diff = features[[i, j]] - mean;
392                    var_sum += diff * diff;
393                }
394                let std = (var_sum / F::from(features.nrows()).unwrap_or(F::one())).sqrt();
395                let std = if std < F::from(1e-8).unwrap_or(F::zero()) {
396                    F::one()
397                } else {
398                    std
399                };
400
401                // Normalize
402                for i in 0..features.nrows() {
403                    result[[i, j]] = (features[[i, j]] - mean) / std;
404                }
405            }
406            result
407        }
408        Normalization::MinMaxScaler => {
409            let mut result = features.clone();
410            for j in 0..features.ncols() {
411                // Find min and max
412                let mut min_val = features[[0, j]];
413                let mut max_val = features[[0, j]];
414                for i in 1..features.nrows() {
415                    if features[[i, j]] < min_val {
416                        min_val = features[[i, j]];
417                    }
418                    if features[[i, j]] > max_val {
419                        max_val = features[[i, j]];
420                    }
421                }
422
423                let range = max_val - min_val;
424                let range = if range < F::from(1e-8).unwrap_or(F::zero()) {
425                    F::one()
426                } else {
427                    range
428                };
429
430                // Normalize
431                for i in 0..features.nrows() {
432                    result[[i, j]] = (features[[i, j]] - min_val) / range;
433                }
434            }
435            result
436        }
437    }
438}
439
440/// One-hot encode labels
441///
442/// # Arguments
443/// * `labels` - Integer label array
444/// * `num_classes` - Number of classes
445///
446/// # Returns
447/// One-hot encoded label matrix
448pub fn one_hot_encode<F: Float + Debug + NumAssign>(
449    labels: &[usize],
450    num_classes: usize,
451) -> Array2<F> {
452    let n = labels.len();
453    let mut encoded = Array2::zeros((n, num_classes));
454
455    for (i, &label) in labels.iter().enumerate() {
456        if label < num_classes {
457            encoded[[i, label]] = F::one();
458        }
459    }
460
461    encoded
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::rng;

    /// A well-formed dataset reports its sample/feature/label counts.
    #[test]
    fn test_dataset_creation() {
        let features = Array2::<f64>::zeros((100, 10));
        let labels = Array2::<f64>::zeros((100, 3));

        let dataset = Dataset::new(features, labels).expect("Operation failed");
        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_features(), 10);
        assert_eq!(dataset.num_labels(), 3);
    }

    /// A row-count mismatch between features and labels is rejected.
    #[test]
    fn test_dataset_mismatched_sizes() {
        let features = Array2::<f64>::zeros((100, 10));
        let labels = Array2::<f64>::zeros((50, 3)); // Wrong size

        let result = Dataset::new(features, labels);
        assert!(result.is_err());
    }

    /// Shuffling permutes the index vector, not the underlying matrices.
    #[test]
    fn test_dataset_shuffle() {
        let mut features = Array2::<f64>::zeros((10, 2));
        for i in 0..10 {
            features[[i, 0]] = i as f64;
        }
        let labels = Array2::<f64>::zeros((10, 1));

        let mut dataset = Dataset::new(features.clone(), labels).expect("Operation failed");
        let original_indices = dataset.indices.clone();

        let mut rng = rng();
        dataset.shuffle(&mut rng);

        // Indices should be different after shuffle (very unlikely to be same)
        // NOTE(review): probabilistic — this fails if the shuffle happens to
        // produce the identity permutation (~1 in 10! assuming uniformity).
        assert_ne!(dataset.indices, original_indices);
    }

    /// get_batch returns the requested number of rows for both matrices.
    #[test]
    fn test_get_batch() {
        let mut features = Array2::<f64>::zeros((10, 2));
        let mut labels = Array2::<f64>::zeros((10, 1));
        for i in 0..10 {
            features[[i, 0]] = i as f64;
            labels[[i, 0]] = i as f64;
        }

        let dataset = Dataset::new(features, labels).expect("Operation failed");
        let (batch_x, batch_y) = dataset.get_batch(0, 5).expect("Operation failed");

        assert_eq!(batch_x.nrows(), 5);
        assert_eq!(batch_y.nrows(), 5);
    }

    /// An 80/20 split of 100 samples yields 80 train and 20 val rows.
    #[test]
    fn test_train_val_split() {
        let features = Array2::<f64>::ones((100, 10));
        let labels = Array2::<f64>::zeros((100, 3));

        let dataset = Dataset::new(features, labels).expect("Operation failed");
        let mut rng = rng();
        let (train, val) = dataset
            .train_val_split(0.8, &mut rng)
            .expect("Operation failed");

        assert_eq!(train.len(), 80);
        assert_eq!(val.len(), 20);
    }

    /// Without drop_last, a partial final batch is counted and yielded.
    #[test]
    fn test_batch_iterator() {
        let features = Array2::<f64>::zeros((25, 5));
        let labels = Array2::<f64>::zeros((25, 2));

        let dataset = Dataset::new(features, labels).expect("Operation failed");
        let iter = BatchIterator::new(&dataset, 10, false);

        assert_eq!(iter.num_batches(), 3); // 25 / 10 = 2.5, rounded up to 3

        let batches: Vec<_> = iter.collect();
        assert_eq!(batches.len(), 3);
    }

    /// With drop_last, the partial final batch is skipped.
    #[test]
    fn test_batch_iterator_drop_last() {
        let features = Array2::<f64>::zeros((25, 5));
        let labels = Array2::<f64>::zeros((25, 2));

        let dataset = Dataset::new(features, labels).expect("Operation failed");
        let iter = BatchIterator::new(&dataset, 10, true);

        assert_eq!(iter.num_batches(), 2); // 25 / 10 = 2 (drop remainder)

        let batches: Vec<_> = iter.collect();
        assert_eq!(batches.len(), 2);
    }

    /// DataLoader reports dataset length and ceil-divided batch count.
    #[test]
    fn test_data_loader() {
        let features = Array2::<f64>::zeros((50, 10));
        let labels = Array2::<f64>::zeros((50, 3));

        let dataset = Dataset::new(features, labels).expect("Operation failed");
        let loader = DataLoader::new(dataset, 16, true, false);

        assert_eq!(loader.len(), 50);
        assert_eq!(loader.num_batches(), 4); // ceil(50/16) = 4
    }

    /// StandardScaler centers each column to (approximately) zero mean.
    #[test]
    fn test_standard_scaler() {
        let mut features = Array2::<f64>::zeros((100, 2));
        for i in 0..100 {
            features[[i, 0]] = i as f64;
            features[[i, 1]] = (i as f64) * 2.0;
        }

        let normalized = normalize_features(&features, Normalization::StandardScaler);

        // Check that mean is approximately 0
        let mean_col0: f64 = normalized.column(0).iter().sum::<f64>() / 100.0;
        assert!(mean_col0.abs() < 1e-10);
    }

    /// MinMaxScaler maps each column onto the closed interval [0, 1].
    #[test]
    fn test_minmax_scaler() {
        let mut features = Array2::<f64>::zeros((10, 1));
        for i in 0..10 {
            features[[i, 0]] = i as f64 * 10.0; // 0, 10, 20, ..., 90
        }

        let normalized = normalize_features(&features, Normalization::MinMaxScaler);

        // Check range is [0, 1]
        let min_val: f64 = normalized.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_val: f64 = normalized.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

        assert!((min_val - 0.0).abs() < 1e-10);
        assert!((max_val - 1.0).abs() < 1e-10);
    }

    /// Each label sets exactly the matching column of its row to 1.
    #[test]
    fn test_one_hot_encode() {
        let labels = vec![0, 1, 2, 0, 1];
        let encoded: Array2<f64> = one_hot_encode(&labels, 3);

        assert_eq!(encoded.nrows(), 5);
        assert_eq!(encoded.ncols(), 3);

        // Check encoding
        assert_eq!(encoded[[0, 0]], 1.0);
        assert_eq!(encoded[[1, 1]], 1.0);
        assert_eq!(encoded[[2, 2]], 1.0);
    }
}
623}