1use crate::error::{NeuralError, Result};
7use scirs2_core::ndarray::{s, Array, Array2, ArrayView2, Axis, IxDyn};
8use scirs2_core::numeric::{Float, NumAssign};
9use scirs2_core::random::{Rng, RngExt};
10use std::fmt::Debug;
11use std::marker::PhantomData;
12
13#[derive(Debug, Clone)]
31pub struct Dataset<F: Float + Debug + NumAssign> {
32 features: Array2<F>,
34 labels: Array2<F>,
36 indices: Vec<usize>,
38}
39
40impl<F: Float + Debug + NumAssign> Dataset<F> {
41 pub fn new(features: Array2<F>, labels: Array2<F>) -> Result<Self> {
50 if features.nrows() != labels.nrows() {
51 return Err(NeuralError::InvalidArchitecture(format!(
52 "Features and labels must have same number of samples: {} vs {}",
53 features.nrows(),
54 labels.nrows()
55 )));
56 }
57
58 let num_samples = features.nrows();
59 let indices: Vec<usize> = (0..num_samples).collect();
60
61 Ok(Self {
62 features,
63 labels,
64 indices,
65 })
66 }
67
68 pub fn len(&self) -> usize {
70 self.features.nrows()
71 }
72
73 pub fn is_empty(&self) -> bool {
75 self.features.nrows() == 0
76 }
77
78 pub fn num_features(&self) -> usize {
80 self.features.ncols()
81 }
82
83 pub fn num_labels(&self) -> usize {
85 self.labels.ncols()
86 }
87
88 pub fn features(&self) -> &Array2<F> {
90 &self.features
91 }
92
93 pub fn labels(&self) -> &Array2<F> {
95 &self.labels
96 }
97
98 pub fn shuffle<R: Rng>(&mut self, rng: &mut R) {
103 let n = self.indices.len();
104 for i in (1..n).rev() {
105 let j = (rng.random::<f64>() * (i + 1) as f64) as usize;
106 self.indices.swap(i, j);
107 }
108 }
109
110 pub fn get_batch(&self, start: usize, end: usize) -> Result<(Array2<F>, Array2<F>)> {
119 let end = end.min(self.len());
120 if start >= end {
121 return Err(NeuralError::InvalidArchitecture(format!(
122 "Invalid batch range: {}..{}",
123 start, end
124 )));
125 }
126
127 let batch_indices: Vec<usize> = self.indices[start..end].to_vec();
128 let batch_size = batch_indices.len();
129
130 let mut features_batch = Array2::zeros((batch_size, self.num_features()));
132 let mut labels_batch = Array2::zeros((batch_size, self.num_labels()));
133
134 for (batch_idx, &sample_idx) in batch_indices.iter().enumerate() {
135 for f in 0..self.num_features() {
136 features_batch[[batch_idx, f]] = self.features[[sample_idx, f]];
137 }
138 for l in 0..self.num_labels() {
139 labels_batch[[batch_idx, l]] = self.labels[[sample_idx, l]];
140 }
141 }
142
143 Ok((features_batch, labels_batch))
144 }
145
146 pub fn train_val_split<R: Rng>(
155 mut self,
156 train_ratio: f64,
157 rng: &mut R,
158 ) -> Result<(Self, Self)> {
159 if !(0.0..=1.0).contains(&train_ratio) {
160 return Err(NeuralError::InvalidArchitecture(format!(
161 "train_ratio must be between 0 and 1, got {}",
162 train_ratio
163 )));
164 }
165
166 self.shuffle(rng);
168
169 let n = self.len();
170 let train_size = (n as f64 * train_ratio) as usize;
171
172 let train_indices: Vec<usize> = self.indices[..train_size].to_vec();
174 let val_indices: Vec<usize> = self.indices[train_size..].to_vec();
175
176 let mut train_features = Array2::zeros((train_size, self.num_features()));
178 let mut train_labels = Array2::zeros((train_size, self.num_labels()));
179 for (new_idx, &old_idx) in train_indices.iter().enumerate() {
180 for f in 0..self.num_features() {
181 train_features[[new_idx, f]] = self.features[[old_idx, f]];
182 }
183 for l in 0..self.num_labels() {
184 train_labels[[new_idx, l]] = self.labels[[old_idx, l]];
185 }
186 }
187
188 let val_size = n - train_size;
190 let mut val_features = Array2::zeros((val_size, self.num_features()));
191 let mut val_labels = Array2::zeros((val_size, self.num_labels()));
192 for (new_idx, &old_idx) in val_indices.iter().enumerate() {
193 for f in 0..self.num_features() {
194 val_features[[new_idx, f]] = self.features[[old_idx, f]];
195 }
196 for l in 0..self.num_labels() {
197 val_labels[[new_idx, l]] = self.labels[[old_idx, l]];
198 }
199 }
200
201 Ok((
202 Dataset::new(train_features, train_labels)?,
203 Dataset::new(val_features, val_labels)?,
204 ))
205 }
206}
207
/// Borrowing iterator over a `Dataset` that yields `(features, labels)`
/// batches via `Dataset::get_batch`.
pub struct BatchIterator<'a, F: Float + Debug + NumAssign> {
    /// Dataset being iterated; never mutated through this iterator.
    dataset: &'a Dataset<F>,
    /// Maximum number of rows per yielded batch.
    batch_size: usize,
    /// Row offset (in shuffle order) where the next batch starts.
    current_idx: usize,
    /// When true, a trailing batch smaller than `batch_size` is skipped.
    drop_last: bool,
}
217
218impl<'a, F: Float + Debug + NumAssign> BatchIterator<'a, F> {
219 pub fn new(dataset: &'a Dataset<F>, batch_size: usize, drop_last: bool) -> Self {
226 Self {
227 dataset,
228 batch_size,
229 current_idx: 0,
230 drop_last,
231 }
232 }
233
234 pub fn num_batches(&self) -> usize {
236 let n = self.dataset.len();
237 if self.drop_last {
238 n / self.batch_size
239 } else {
240 n.div_ceil(self.batch_size)
241 }
242 }
243}
244
245impl<'a, F: Float + Debug + NumAssign> Iterator for BatchIterator<'a, F> {
246 type Item = Result<(Array2<F>, Array2<F>)>;
247
248 fn next(&mut self) -> Option<Self::Item> {
249 if self.current_idx >= self.dataset.len() {
250 return None;
251 }
252
253 let start = self.current_idx;
254 let end = (start + self.batch_size).min(self.dataset.len());
255
256 if self.drop_last && end - start < self.batch_size {
258 return None;
259 }
260
261 self.current_idx = end;
262 Some(self.dataset.get_batch(start, end))
263 }
264}
265
/// Owning wrapper around a `Dataset` that hands out batch iterators and can
/// reshuffle the sample order between epochs.
pub struct DataLoader<F: Float + Debug + NumAssign> {
    /// Owned dataset served by this loader.
    dataset: Dataset<F>,
    /// Rows per batch handed to `BatchIterator`.
    batch_size: usize,
    /// Whether `on_epoch_end` reshuffles the dataset.
    shuffle: bool,
    /// Whether a trailing partial batch is dropped by iterators.
    drop_last: bool,
}
296
297impl<F: Float + Debug + NumAssign> DataLoader<F> {
298 pub fn new(dataset: Dataset<F>, batch_size: usize, shuffle: bool, drop_last: bool) -> Self {
306 Self {
307 dataset,
308 batch_size,
309 shuffle,
310 drop_last,
311 }
312 }
313
314 pub fn num_batches(&self) -> usize {
316 let n = self.dataset.len();
317 if self.drop_last {
318 n / self.batch_size
319 } else {
320 n.div_ceil(self.batch_size)
321 }
322 }
323
324 pub fn len(&self) -> usize {
326 self.dataset.len()
327 }
328
329 pub fn is_empty(&self) -> bool {
331 self.dataset.is_empty()
332 }
333
334 pub fn iter(&self) -> BatchIterator<'_, F> {
336 BatchIterator::new(&self.dataset, self.batch_size, self.drop_last)
337 }
338
339 pub fn on_epoch_end(&mut self) {
341 if self.shuffle {
342 let mut rng = scirs2_core::random::rng();
343 self.dataset.shuffle(&mut rng);
344 }
345 }
346
347 pub fn dataset(&self) -> &Dataset<F> {
349 &self.dataset
350 }
351}
352
/// Feature-scaling strategy used by `normalize_features`.
#[derive(Debug, Clone, Copy)]
pub enum Normalization {
    /// Per-column z-score: `(x - mean) / std`.
    StandardScaler,
    /// Per-column rescale to `[0, 1]`: `(x - min) / (max - min)`.
    MinMaxScaler,
    /// Leave the features untouched (copy returned as-is).
    None,
}
363
364pub fn normalize_features<F: Float + Debug + NumAssign>(
373 features: &Array2<F>,
374 strategy: Normalization,
375) -> Array2<F> {
376 match strategy {
377 Normalization::None => features.clone(),
378 Normalization::StandardScaler => {
379 let mut result = features.clone();
380 for j in 0..features.ncols() {
381 let mut sum = F::zero();
383 for i in 0..features.nrows() {
384 sum += features[[i, j]];
385 }
386 let mean = sum / F::from(features.nrows()).unwrap_or(F::one());
387
388 let mut var_sum = F::zero();
390 for i in 0..features.nrows() {
391 let diff = features[[i, j]] - mean;
392 var_sum += diff * diff;
393 }
394 let std = (var_sum / F::from(features.nrows()).unwrap_or(F::one())).sqrt();
395 let std = if std < F::from(1e-8).unwrap_or(F::zero()) {
396 F::one()
397 } else {
398 std
399 };
400
401 for i in 0..features.nrows() {
403 result[[i, j]] = (features[[i, j]] - mean) / std;
404 }
405 }
406 result
407 }
408 Normalization::MinMaxScaler => {
409 let mut result = features.clone();
410 for j in 0..features.ncols() {
411 let mut min_val = features[[0, j]];
413 let mut max_val = features[[0, j]];
414 for i in 1..features.nrows() {
415 if features[[i, j]] < min_val {
416 min_val = features[[i, j]];
417 }
418 if features[[i, j]] > max_val {
419 max_val = features[[i, j]];
420 }
421 }
422
423 let range = max_val - min_val;
424 let range = if range < F::from(1e-8).unwrap_or(F::zero()) {
425 F::one()
426 } else {
427 range
428 };
429
430 for i in 0..features.nrows() {
432 result[[i, j]] = (features[[i, j]] - min_val) / range;
433 }
434 }
435 result
436 }
437 }
438}
439
440pub fn one_hot_encode<F: Float + Debug + NumAssign>(
449 labels: &[usize],
450 num_classes: usize,
451) -> Array2<F> {
452 let n = labels.len();
453 let mut encoded = Array2::zeros((n, num_classes));
454
455 for (i, &label) in labels.iter().enumerate() {
456 if label < num_classes {
457 encoded[[i, label]] = F::one();
458 }
459 }
460
461 encoded
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::random::rng;

    /// A well-formed dataset reports the expected sample/feature/label counts.
    #[test]
    fn test_dataset_creation() {
        let features = Array2::<f64>::zeros((100, 10));
        let labels = Array2::<f64>::zeros((100, 3));

        let dataset = Dataset::new(features, labels).expect("Operation failed");
        assert_eq!(dataset.len(), 100);
        assert_eq!(dataset.num_features(), 10);
        assert_eq!(dataset.num_labels(), 3);
    }

    /// Mismatched feature/label row counts must be rejected at construction.
    #[test]
    fn test_dataset_mismatched_sizes() {
        let features = Array2::<f64>::zeros((100, 10));
        let labels = Array2::<f64>::zeros((50, 3));
        let result = Dataset::new(features, labels);
        assert!(result.is_err());
    }

    /// Shuffling permutes the index vector without touching the data arrays.
    #[test]
    fn test_dataset_shuffle() {
        let mut features = Array2::<f64>::zeros((10, 2));
        for i in 0..10 {
            features[[i, 0]] = i as f64;
        }
        let labels = Array2::<f64>::zeros((10, 1));

        let mut dataset = Dataset::new(features.clone(), labels).expect("Operation failed");
        let original_indices = dataset.indices.clone();

        let mut rng = rng();
        dataset.shuffle(&mut rng);

        // NOTE(review): theoretically flaky — fails if the shuffle happens to
        // produce the identity permutation (probability 1/10! ≈ 2.8e-7).
        assert_ne!(dataset.indices, original_indices);
    }

    /// A batch request returns exactly the requested number of rows.
    #[test]
    fn test_get_batch() {
        let mut features = Array2::<f64>::zeros((10, 2));
        let mut labels = Array2::<f64>::zeros((10, 1));
        for i in 0..10 {
            features[[i, 0]] = i as f64;
            labels[[i, 0]] = i as f64;
        }

        let dataset = Dataset::new(features, labels).expect("Operation failed");
        let (batch_x, batch_y) = dataset.get_batch(0, 5).expect("Operation failed");

        assert_eq!(batch_x.nrows(), 5);
        assert_eq!(batch_y.nrows(), 5);
    }

    /// An 80/20 split of 100 samples yields 80 train and 20 validation rows.
    #[test]
    fn test_train_val_split() {
        let features = Array2::<f64>::ones((100, 10));
        let labels = Array2::<f64>::zeros((100, 3));

        let dataset = Dataset::new(features, labels).expect("Operation failed");
        let mut rng = rng();
        let (train, val) = dataset
            .train_val_split(0.8, &mut rng)
            .expect("Operation failed");

        assert_eq!(train.len(), 80);
        assert_eq!(val.len(), 20);
    }

    /// Without drop_last, 25 samples at batch size 10 yield 3 batches
    /// (the last one partial).
    #[test]
    fn test_batch_iterator() {
        let features = Array2::<f64>::zeros((25, 5));
        let labels = Array2::<f64>::zeros((25, 2));

        let dataset = Dataset::new(features, labels).expect("Operation failed");
        let iter = BatchIterator::new(&dataset, 10, false);

        assert_eq!(iter.num_batches(), 3);
        let batches: Vec<_> = iter.collect();
        assert_eq!(batches.len(), 3);
    }

    /// With drop_last, the trailing partial batch of 5 is skipped: 2 batches.
    #[test]
    fn test_batch_iterator_drop_last() {
        let features = Array2::<f64>::zeros((25, 5));
        let labels = Array2::<f64>::zeros((25, 2));

        let dataset = Dataset::new(features, labels).expect("Operation failed");
        let iter = BatchIterator::new(&dataset, 10, true);

        assert_eq!(iter.num_batches(), 2);
        let batches: Vec<_> = iter.collect();
        assert_eq!(batches.len(), 2);
    }

    /// Loader reports the dataset length and ceil(50/16) = 4 batches.
    #[test]
    fn test_data_loader() {
        let features = Array2::<f64>::zeros((50, 10));
        let labels = Array2::<f64>::zeros((50, 3));

        let dataset = Dataset::new(features, labels).expect("Operation failed");
        let loader = DataLoader::new(dataset, 16, true, false);

        assert_eq!(loader.len(), 50);
        assert_eq!(loader.num_batches(), 4);
    }

    /// Standard scaling should leave each column with (numerically) zero mean.
    #[test]
    fn test_standard_scaler() {
        let mut features = Array2::<f64>::zeros((100, 2));
        for i in 0..100 {
            features[[i, 0]] = i as f64;
            features[[i, 1]] = (i as f64) * 2.0;
        }

        let normalized = normalize_features(&features, Normalization::StandardScaler);

        let mean_col0: f64 = normalized.column(0).iter().sum::<f64>() / 100.0;
        assert!(mean_col0.abs() < 1e-10);
    }

    /// Min-max scaling should map the column onto exactly [0, 1].
    #[test]
    fn test_minmax_scaler() {
        let mut features = Array2::<f64>::zeros((10, 1));
        for i in 0..10 {
            features[[i, 0]] = i as f64 * 10.0;
        }

        let normalized = normalize_features(&features, Normalization::MinMaxScaler);

        let min_val: f64 = normalized.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_val: f64 = normalized.iter().cloned().fold(f64::NEG_INFINITY, f64::max);

        assert!((min_val - 0.0).abs() < 1e-10);
        assert!((max_val - 1.0).abs() < 1e-10);
    }

    /// Each label sets exactly its own column to 1 in the encoded matrix.
    #[test]
    fn test_one_hot_encode() {
        let labels = vec![0, 1, 2, 0, 1];
        let encoded: Array2<f64> = one_hot_encode(&labels, 3);

        assert_eq!(encoded.nrows(), 5);
        assert_eq!(encoded.ncols(), 3);

        assert_eq!(encoded[[0, 0]], 1.0);
        assert_eq!(encoded[[1, 1]], 1.0);
        assert_eq!(encoded[[2, 2]], 1.0);
    }
}