1use scirs2_core::random::rng;
8use std::marker::PhantomData;
9use std::sync::Arc;
10use tenflowers_core::ops::slice;
11use tenflowers_core::{Result, Tensor, TensorError};
12
/// A random-access dataset yielding `(features, labels)` sample pairs.
pub trait Dataset<T> {
    /// Number of samples in the dataset.
    fn len(&self) -> usize;
    /// Returns `true` when the dataset contains no samples.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Fetches the sample at `index` as a `(features, labels)` pair.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)>;
    /// Consumes the dataset and wraps it in an iterator that yields
    /// `Vec`s of up to `batch_size` consecutive samples.
    fn batch(self, batch_size: usize) -> BatchedDataset<T, Self>
    where
        Self: Sized,
    {
        BatchedDataset {
            dataset: self,
            batch_size,
            current_index: 0,
            _phantom: PhantomData,
        }
    }
}
35
/// Forwarding impl so shared (`Arc`) datasets can be used wherever a
/// `Dataset` is expected (e.g. the subsets produced by `DatasetSplitter`).
impl<T, D: Dataset<T>> Dataset<T> for Arc<D> {
    fn len(&self) -> usize {
        (**self).len()
    }

    fn is_empty(&self) -> bool {
        (**self).is_empty()
    }

    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
        (**self).get(index)
    }
}
50
51pub trait DatasetUtilsExt<T>: Dataset<T> {
53 fn get_multiple(&self, indices: &[usize]) -> Result<Vec<(Tensor<T>, Tensor<T>)>> {
55 let mut samples = Vec::with_capacity(indices.len());
56 for &index in indices {
57 samples.push(self.get(index)?);
58 }
59 Ok(samples)
60 }
61
62 fn get_range(&self, start: usize, end: usize) -> Result<Vec<(Tensor<T>, Tensor<T>)>> {
64 if start >= end {
65 return Ok(Vec::new());
66 }
67 if end > self.len() {
68 return Err(TensorError::invalid_argument(format!(
69 "End index {} out of bounds for dataset of length {}",
70 end,
71 self.len()
72 )));
73 }
74
75 let mut samples = Vec::with_capacity(end - start);
76 for i in start..end {
77 samples.push(self.get(i)?);
78 }
79 Ok(samples)
80 }
81
82 fn get_random(&self) -> Result<(Tensor<T>, Tensor<T>)> {
84 use scirs2_core::random::rand_prelude::*;
85 if self.is_empty() {
86 return Err(TensorError::invalid_argument(
87 "Cannot get random sample from empty dataset".to_string(),
88 ));
89 }
90 let mut rng = rng();
91 let random_val: f64 = rng.random();
92 let index = (random_val * self.len() as f64) as usize;
93 let index = index.min(self.len() - 1); self.get(index)
95 }
96
97 fn get_random_samples(&self, count: usize) -> Result<Vec<(Tensor<T>, Tensor<T>)>> {
99 use scirs2_core::random::rand_prelude::*;
100 if self.is_empty() {
101 return Err(TensorError::invalid_argument(
102 "Cannot get random samples from empty dataset".to_string(),
103 ));
104 }
105
106 let mut rng = rng();
107 let mut samples = Vec::with_capacity(count);
108 for _ in 0..count {
109 let random_val: f64 = rng.random();
110 let index = (random_val * self.len() as f64) as usize;
111 let index = index.min(self.len() - 1); samples.push(self.get(index)?);
113 }
114 Ok(samples)
115 }
116}
117
/// Blanket impl: every `Dataset` automatically gets the utility helpers.
impl<T, D: Dataset<T>> DatasetUtilsExt<T> for D {}
120
121#[derive(Clone)]
122pub struct TensorDataset<T> {
123 features: Tensor<T>,
124 #[allow(dead_code)]
125 labels: Tensor<T>,
126}
127
impl<T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static> TensorDataset<T> {
    /// Creates a dataset from `features` and `labels` tensors.
    ///
    /// Both tensors are expected to share the same first-dimension length
    /// (the sample count); this is not validated here — TODO confirm
    /// whether callers guarantee it.
    pub fn new(features: Tensor<T>, labels: Tensor<T>) -> Self {
        Self { features, labels }
    }
}
133
134impl<T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static> Dataset<T>
135 for TensorDataset<T>
136{
137 fn len(&self) -> usize {
138 self.features.shape().dims()[0]
139 }
140
141 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
142 if index >= self.len() {
143 return Err(TensorError::invalid_argument(format!(
144 "Index {} out of bounds for dataset of length {}",
145 index,
146 self.len()
147 )));
148 }
149
150 let mut feature_ranges = Vec::new();
152 let mut label_ranges = Vec::new();
153
154 feature_ranges.push(index..index + 1);
156 label_ranges.push(index..index + 1);
157
158 for i in 1..self.features.shape().rank() {
160 feature_ranges.push(0..self.features.shape().dims()[i]);
161 }
162 for i in 1..self.labels.shape().rank() {
163 label_ranges.push(0..self.labels.shape().dims()[i]);
164 }
165
166 let feature_slice = slice(&self.features, &feature_ranges)?;
168 let label_slice = slice(&self.labels, &label_ranges)?;
169
170 let feature_squeezed = squeeze_first_dim(&feature_slice)?;
172 let label_squeezed = squeeze_first_dim(&label_slice)?;
173
174 Ok((feature_squeezed, label_squeezed))
175 }
176}
177
178fn squeeze_first_dim<T>(tensor: &Tensor<T>) -> Result<Tensor<T>>
180where
181 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
182{
183 let shape = tensor.shape();
184 if shape.rank() == 0 {
185 return Ok(tensor.clone());
186 }
187
188 if shape.dims()[0] != 1 {
189 return Err(TensorError::invalid_argument(format!(
190 "Cannot squeeze dimension of size {}",
191 shape.dims()[0]
192 )));
193 }
194
195 let new_shape: Vec<usize> = shape.dims()[1..].to_vec();
196 tenflowers_core::ops::reshape(tensor, &new_shape)
197}
198
/// Iterator adapter produced by [`Dataset::batch`]; yields samples in
/// consecutive `Vec` batches of up to `batch_size` elements.
pub struct BatchedDataset<T, D: Dataset<T>> {
    dataset: D,
    batch_size: usize,
    // Index of the first sample of the next batch to yield.
    current_index: usize,
    _phantom: PhantomData<T>,
}
205
impl<T, D: Dataset<T>> Iterator for BatchedDataset<T, D> {
    type Item = Vec<(Tensor<T>, Tensor<T>)>;

    /// Yields the next batch, or `None` once the dataset is exhausted.
    /// The final batch may hold fewer than `batch_size` samples.
    ///
    /// NOTE(review): a failing `get` truncates the current batch, yet the
    /// cursor still advances past the whole batch window — the remaining
    /// samples of that window are skipped and the error is never surfaced.
    /// Confirm this best-effort behavior is intended.
    fn next(&mut self) -> Option<Self::Item> {
        if self.current_index >= self.dataset.len() {
            return None;
        }

        let mut batch = Vec::new();
        let end_index = (self.current_index + self.batch_size).min(self.dataset.len());

        for i in self.current_index..end_index {
            match self.dataset.get(i) {
                Ok(sample) => batch.push(sample),
                // Drop the rest of this window on error.
                Err(_) => break,
            }
        }

        self.current_index = end_index;

        if batch.is_empty() {
            None
        } else {
            Some(batch)
        }
    }
}
233
/// Concatenation of several datasets of the same type, presented as one
/// dataset whose indices run through each constituent in order.
pub struct ConcatDataset<T, D: Dataset<T>> {
    datasets: Vec<D>,
    // cumulative_lengths[i] = total samples in datasets[0..=i]; sorted
    // non-decreasing, used to translate global indices.
    cumulative_lengths: Vec<usize>,
    total_length: usize,
    _phantom: PhantomData<T>,
}
241
242impl<T, D: Dataset<T>> ConcatDataset<T, D> {
243 pub fn new(datasets: Vec<D>) -> Self {
244 let mut cumulative_lengths = Vec::with_capacity(datasets.len());
245 let mut total_length = 0;
246
247 for dataset in &datasets {
248 total_length += dataset.len();
249 cumulative_lengths.push(total_length);
250 }
251
252 Self {
253 datasets,
254 cumulative_lengths,
255 total_length,
256 _phantom: PhantomData,
257 }
258 }
259
260 fn find_dataset_and_index(&self, global_index: usize) -> Result<(usize, usize)> {
262 for (dataset_idx, &cumulative_len) in self.cumulative_lengths.iter().enumerate() {
263 if global_index < cumulative_len {
264 let local_index = if dataset_idx == 0 {
265 global_index
266 } else {
267 global_index - self.cumulative_lengths[dataset_idx - 1]
268 };
269 return Ok((dataset_idx, local_index));
270 }
271 }
272 Err(TensorError::invalid_argument(format!(
273 "Index {} out of bounds for dataset of total length {}",
274 global_index, self.total_length
275 )))
276 }
277}
278
279impl<T, D: Dataset<T>> Dataset<T> for ConcatDataset<T, D> {
280 fn len(&self) -> usize {
281 self.total_length
282 }
283
284 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
285 if index >= self.total_length {
286 return Err(TensorError::invalid_argument(format!(
287 "Index {} out of bounds for dataset of length {}",
288 index, self.total_length
289 )));
290 }
291
292 let (dataset_idx, local_index) = self.find_dataset_and_index(index)?;
293 self.datasets[dataset_idx].get(local_index)
294 }
295}
296
/// A view over `dataset` restricted to the samples accepted by a
/// predicate; passing indices are precomputed at construction time.
pub struct FilteredDataset<T, D: Dataset<T>, F: Fn(&(Tensor<T>, Tensor<T>)) -> bool> {
    dataset: D,
    // Indices (into the wrapped dataset) of the samples that passed.
    valid_indices: Vec<usize>,
    _phantom: PhantomData<(T, F)>,
}
303
304impl<T, D: Dataset<T>, F: Fn(&(Tensor<T>, Tensor<T>)) -> bool> FilteredDataset<T, D, F> {
305 pub fn new(dataset: D, predicate: F) -> Result<Self> {
306 let mut valid_indices = Vec::new();
307
308 for i in 0..dataset.len() {
309 match dataset.get(i) {
310 Ok(sample) => {
311 if predicate(&sample) {
312 valid_indices.push(i);
313 }
314 }
315 Err(_) => continue, }
317 }
318
319 Ok(Self {
320 dataset,
321 valid_indices,
322 _phantom: PhantomData,
323 })
324 }
325}
326
327impl<T, D: Dataset<T>, F: Fn(&(Tensor<T>, Tensor<T>)) -> bool> Dataset<T>
328 for FilteredDataset<T, D, F>
329{
330 fn len(&self) -> usize {
331 self.valid_indices.len()
332 }
333
334 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
335 if index >= self.valid_indices.len() {
336 return Err(TensorError::invalid_argument(format!(
337 "Index {} out of bounds for filtered dataset of length {}",
338 index,
339 self.valid_indices.len()
340 )));
341 }
342
343 let actual_index = self.valid_indices[index];
344 self.dataset.get(actual_index)
345 }
346}
347
/// The result of [`DatasetSplitter::split`]: a train subset plus optional
/// validation/test subsets, all sharing the source dataset via `Arc`.
pub struct DatasetSplit<T, D: Dataset<T>> {
    pub train: SubsetDataset<T, Arc<D>>,
    pub validation: Option<SubsetDataset<T, Arc<D>>>,
    pub test: Option<SubsetDataset<T, Arc<D>>>,
}
354
/// A view over `dataset` exposing only the samples at `indices`,
/// re-numbered from 0.
pub struct SubsetDataset<T, D: Dataset<T>> {
    dataset: D,
    indices: Vec<usize>,
    _phantom: PhantomData<T>,
}
361
impl<T, D: Dataset<T>> SubsetDataset<T, D> {
    /// Creates a subset exposing `indices` of `dataset`.
    ///
    /// Indices are not validated here; out-of-range ones surface as
    /// errors from the wrapped dataset's `get`.
    pub fn new(dataset: D, indices: Vec<usize>) -> Self {
        Self {
            dataset,
            indices,
            _phantom: PhantomData,
        }
    }
}
371
372impl<T, D: Dataset<T>> Dataset<T> for SubsetDataset<T, D> {
373 fn len(&self) -> usize {
374 self.indices.len()
375 }
376
377 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
378 if index >= self.indices.len() {
379 return Err(TensorError::invalid_argument(format!(
380 "Index {} out of bounds for subset dataset of length {}",
381 index,
382 self.indices.len()
383 )));
384 }
385
386 let actual_index = self.indices[index];
387 self.dataset.get(actual_index)
388 }
389}
390
/// Pairs two equally-sized datasets and combines their feature tensors
/// sample-by-sample according to a [`MergeStrategy`].
pub struct MergedDataset<T, D1: Dataset<T>, D2: Dataset<T>> {
    dataset1: D1,
    dataset2: D2,
    merge_strategy: MergeStrategy,
    _phantom: PhantomData<T>,
}
398
/// How [`MergedDataset`] combines the two feature tensors of a sample.
#[derive(Debug, Clone)]
pub enum MergeStrategy {
    /// Flatten both tensors and concatenate them into one 1-D tensor.
    FeatureConcatenation,
    /// Element-wise average of the two (equally sized) tensors.
    FeatureAverage,
    /// Keep the first dataset's features (and labels) unchanged.
    FeatureFromFirst,
    /// Keep the second dataset's features (and labels) unchanged.
    FeatureFromSecond,
    /// Placeholder; currently behaves like `FeatureFromFirst`.
    Custom,
}
413
414impl<T, D1: Dataset<T>, D2: Dataset<T>> MergedDataset<T, D1, D2> {
415 pub fn new_concatenated(dataset1: D1, dataset2: D2) -> Result<Self> {
417 if dataset1.len() != dataset2.len() {
418 return Err(TensorError::invalid_argument(format!(
419 "Dataset lengths must match: {} vs {}",
420 dataset1.len(),
421 dataset2.len()
422 )));
423 }
424
425 Ok(Self {
426 dataset1,
427 dataset2,
428 merge_strategy: MergeStrategy::FeatureConcatenation,
429 _phantom: PhantomData,
430 })
431 }
432
433 pub fn new_averaged(dataset1: D1, dataset2: D2) -> Result<Self> {
435 if dataset1.len() != dataset2.len() {
436 return Err(TensorError::invalid_argument(format!(
437 "Dataset lengths must match: {} vs {}",
438 dataset1.len(),
439 dataset2.len()
440 )));
441 }
442
443 Ok(Self {
444 dataset1,
445 dataset2,
446 merge_strategy: MergeStrategy::FeatureAverage,
447 _phantom: PhantomData,
448 })
449 }
450
451 pub fn new_features_from_first(dataset1: D1, dataset2: D2) -> Result<Self> {
453 if dataset1.len() != dataset2.len() {
454 return Err(TensorError::invalid_argument(format!(
455 "Dataset lengths must match: {} vs {}",
456 dataset1.len(),
457 dataset2.len()
458 )));
459 }
460
461 Ok(Self {
462 dataset1,
463 dataset2,
464 merge_strategy: MergeStrategy::FeatureFromFirst,
465 _phantom: PhantomData,
466 })
467 }
468
469 pub fn new_features_from_second(dataset1: D1, dataset2: D2) -> Result<Self> {
471 if dataset1.len() != dataset2.len() {
472 return Err(TensorError::invalid_argument(format!(
473 "Dataset lengths must match: {} vs {}",
474 dataset1.len(),
475 dataset2.len()
476 )));
477 }
478
479 Ok(Self {
480 dataset1,
481 dataset2,
482 merge_strategy: MergeStrategy::FeatureFromSecond,
483 _phantom: PhantomData,
484 })
485 }
486
487 fn merge_tensors(&self, tensor1: &Tensor<T>, tensor2: &Tensor<T>) -> Result<Tensor<T>>
489 where
490 T: Clone
491 + Default
492 + scirs2_core::numeric::Zero
493 + scirs2_core::numeric::Float
494 + Send
495 + Sync
496 + 'static,
497 {
498 match self.merge_strategy {
499 MergeStrategy::FeatureConcatenation => {
500 let data1 = tensor1.as_slice().ok_or_else(|| {
502 TensorError::invalid_argument(
503 "Cannot access tensor data (GPU tensor not supported)".to_string(),
504 )
505 })?;
506 let data2 = tensor2.as_slice().ok_or_else(|| {
507 TensorError::invalid_argument(
508 "Cannot access tensor data (GPU tensor not supported)".to_string(),
509 )
510 })?;
511 let mut merged_data = Vec::new();
512 merged_data.extend_from_slice(data1);
513 merged_data.extend_from_slice(data2);
514
515 let new_shape = vec![data1.len() + data2.len()];
516 Tensor::from_vec(merged_data, &new_shape)
517 }
518 MergeStrategy::FeatureAverage => {
519 let data1 = tensor1.as_slice().ok_or_else(|| {
521 TensorError::invalid_argument(
522 "Cannot access tensor data (GPU tensor not supported)".to_string(),
523 )
524 })?;
525 let data2 = tensor2.as_slice().ok_or_else(|| {
526 TensorError::invalid_argument(
527 "Cannot access tensor data (GPU tensor not supported)".to_string(),
528 )
529 })?;
530
531 if data1.len() != data2.len() {
532 return Err(TensorError::invalid_argument(
533 "Cannot average tensors of different sizes".to_string(),
534 ));
535 }
536
537 let mut averaged_data = Vec::new();
538 let two = T::from(2.0).expect("conversion of 2.0 to float type should succeed");
539 for (v1, v2) in data1.iter().zip(data2.iter()) {
540 let avg = (*v1 + *v2) / two;
541 averaged_data.push(avg);
542 }
543
544 Tensor::from_vec(averaged_data, tensor1.shape().dims())
545 }
546 MergeStrategy::FeatureFromFirst => Ok(tensor1.clone()),
547 MergeStrategy::FeatureFromSecond => Ok(tensor2.clone()),
548 MergeStrategy::Custom => {
549 Ok(tensor1.clone())
552 }
553 }
554 }
555}
556
557impl<T, D1: Dataset<T>, D2: Dataset<T>> Dataset<T> for MergedDataset<T, D1, D2>
558where
559 T: Clone
560 + Default
561 + scirs2_core::numeric::Zero
562 + scirs2_core::numeric::Float
563 + Send
564 + Sync
565 + 'static,
566{
567 fn len(&self) -> usize {
568 self.dataset1.len()
569 }
570
571 fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
572 if index >= self.dataset1.len() {
573 return Err(TensorError::invalid_argument(format!(
574 "Index {} out of bounds for merged dataset of length {}",
575 index,
576 self.dataset1.len()
577 )));
578 }
579
580 let (features1, labels1) = self.dataset1.get(index)?;
581 let (features2, labels2) = self.dataset2.get(index)?;
582
583 let merged_features = self.merge_tensors(&features1, &features2)?;
584
585 let merged_labels = match self.merge_strategy {
587 MergeStrategy::FeatureFromFirst => labels1,
588 MergeStrategy::FeatureFromSecond => labels2,
589 _ => labels1, };
591
592 Ok((merged_features, merged_labels))
593 }
594}
595
/// Namespace for dataset splitting utilities (train/val/test splits,
/// k-fold cross-validation, stratified splits).
pub struct DatasetSplitter;
598
599impl DatasetSplitter {
600 pub fn split<T, D: Dataset<T>>(
602 dataset: D,
603 train_ratio: f64,
604 val_ratio: Option<f64>,
605 test_ratio: Option<f64>,
606 shuffle: bool,
607 ) -> Result<DatasetSplit<T, D>> {
608 let total_len = dataset.len();
609 if total_len == 0 {
610 return Err(TensorError::invalid_argument(
611 "Cannot split empty dataset".to_string(),
612 ));
613 }
614
615 let val_ratio = val_ratio.unwrap_or(0.0);
617 let test_ratio = test_ratio.unwrap_or(0.0);
618
619 if train_ratio + val_ratio + test_ratio > 1.0 {
620 return Err(TensorError::invalid_argument(
621 "Sum of ratios cannot exceed 1.0".to_string(),
622 ));
623 }
624
625 let mut indices: Vec<usize> = (0..total_len).collect();
627
628 if shuffle {
630 use scirs2_core::random::rand_prelude::*;
631 let mut rng = rng();
632 indices.shuffle(&mut rng);
633 }
634
635 let train_end = (total_len as f64 * train_ratio) as usize;
637 let val_end = train_end + (total_len as f64 * val_ratio) as usize;
638 let test_end = val_end + (total_len as f64 * test_ratio) as usize;
639
640 let dataset_arc = Arc::new(dataset);
642 let train_indices = indices[0..train_end].to_vec();
643 let train = SubsetDataset::new(dataset_arc.clone(), train_indices);
644
645 let validation = if val_ratio > 0.0 {
646 let val_indices = indices[train_end..val_end].to_vec();
647 Some(SubsetDataset::new(dataset_arc.clone(), val_indices))
648 } else {
649 None
650 };
651
652 let test = if test_ratio > 0.0 {
653 let test_indices = indices[val_end..test_end].to_vec();
654 Some(SubsetDataset::new(dataset_arc.clone(), test_indices))
655 } else {
656 None
657 };
658
659 Ok(DatasetSplit {
660 train,
661 validation,
662 test,
663 })
664 }
665
666 #[allow(clippy::type_complexity)]
668 pub fn k_fold<T, D: Dataset<T>>(
669 dataset: D,
670 k: usize,
671 shuffle: bool,
672 ) -> Result<Vec<(SubsetDataset<T, Arc<D>>, SubsetDataset<T, Arc<D>>)>> {
673 if k <= 1 {
674 return Err(TensorError::invalid_argument(
675 "K must be greater than 1".to_string(),
676 ));
677 }
678
679 let total_len = dataset.len();
680 if total_len == 0 {
681 return Err(TensorError::invalid_argument(
682 "Cannot split empty dataset".to_string(),
683 ));
684 }
685
686 let mut indices: Vec<usize> = (0..total_len).collect();
687
688 if shuffle {
689 use scirs2_core::random::rand_prelude::*;
690 let mut rng = rng();
691 indices.shuffle(&mut rng);
692 }
693
694 let fold_size = total_len / k;
695 let mut folds = Vec::new();
696 let dataset_arc = Arc::new(dataset);
697
698 for i in 0..k {
699 let start = i * fold_size;
700 let end = if i == k - 1 {
701 total_len
702 } else {
703 (i + 1) * fold_size
704 };
705
706 let val_indices = indices[start..end].to_vec();
707 let train_indices: Vec<usize> = indices[0..start]
708 .iter()
709 .chain(indices[end..].iter())
710 .cloned()
711 .collect();
712
713 let train_dataset = SubsetDataset::new(dataset_arc.clone(), train_indices);
714 let val_dataset = SubsetDataset::new(dataset_arc.clone(), val_indices);
715
716 folds.push((train_dataset, val_dataset));
717 }
718
719 Ok(folds)
720 }
721
722 pub fn stratified_split<T, D: Dataset<T>>(
724 dataset: D,
725 train_ratio: f64,
726 val_ratio: Option<f64>,
727 extract_class: fn(&(Tensor<T>, Tensor<T>)) -> usize,
728 ) -> Result<(Vec<usize>, Vec<usize>)> {
729 let total_len = dataset.len();
730 if total_len == 0 {
731 return Err(TensorError::invalid_argument(
732 "Cannot split empty dataset".to_string(),
733 ));
734 }
735
736 let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
738 std::collections::HashMap::new();
739
740 for i in 0..total_len {
741 if let Ok(sample) = dataset.get(i) {
742 let class = extract_class(&sample);
743 class_indices.entry(class).or_default().push(i);
744 }
745 }
746
747 let mut train_indices = Vec::new();
748 let mut val_indices = Vec::new();
749
750 for (_, mut indices) in class_indices {
752 use scirs2_core::random::rand_prelude::*;
754 let mut rng = rng();
755 indices.shuffle(&mut rng);
756
757 let class_len = indices.len();
758 let train_end = (class_len as f64 * train_ratio) as usize;
759
760 train_indices.extend(indices[0..train_end].iter());
761
762 if let Some(val_ratio) = val_ratio {
763 let val_end = train_end + (class_len as f64 * val_ratio) as usize;
764 val_indices.extend(indices[train_end..val_end].iter());
765 }
766 }
767
768 Ok((train_indices, val_indices))
769 }
770}
771
772#[cfg(test)]
773mod tests {
774 use super::*;
775 use tenflowers_core::Tensor;
776
    #[test]
    fn test_tensor_dataset_creation() {
        // 3 samples of 2 features each, one scalar label per sample.
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0, 2.0], &[3])
            .expect("test: tensor creation should succeed");

        let dataset = TensorDataset::new(features, labels);
        // Length comes from the first feature dimension.
        assert_eq!(dataset.len(), 3);
        assert!(!dataset.is_empty());
    }
788
    #[test]
    fn test_tensor_dataset_get() {
        let features = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[3, 2],
        )
        .expect("test: operation should succeed");
        let labels = Tensor::<f32>::from_vec(
            vec![10.0, 20.0, 30.0],
            &[3],
        )
        .expect("test: operation should succeed");

        let dataset = TensorDataset::new(features, labels);

        // Each sample drops the leading batch dimension: features become
        // rank-1 of length 2, labels become scalars (rank 0).
        let (feat, label) = dataset.get(0).expect("index should be in bounds");
        assert_eq!(feat.shape().dims(), &[2]);
        assert_eq!(label.shape().dims(), &[] as &[usize]);
        let (feat2, label2) = dataset.get(1).expect("index should be in bounds");
        assert_eq!(feat2.shape().dims(), &[2]);
        assert_eq!(label2.shape().dims(), &[] as &[usize]);

        // Out-of-bounds access must error rather than panic.
        assert!(dataset.get(3).is_err());
    }
817
    #[test]
    fn test_batched_dataset() {
        // 4 samples batched in twos -> exactly two full batches.
        let features = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            &[4, 2],
        )
        .expect("test: operation should succeed");
        let labels = Tensor::<f32>::from_vec(
            vec![10.0, 20.0, 30.0, 40.0],
            &[4],
        )
        .expect("test: operation should succeed");

        let dataset = TensorDataset::new(features, labels);
        let mut batched = dataset.batch(2);

        let batch1 = batched.next().expect("test: iterator should have next");
        assert_eq!(batch1.len(), 2);

        let batch2 = batched.next().expect("test: iterator should have next");
        assert_eq!(batch2.len(), 2);

        // Iterator is exhausted after the last batch.
        assert!(batched.next().is_none());
    }
845
    #[test]
    fn test_batched_dataset_partial_batch() {
        // 3 samples batched in twos -> one full batch plus one partial.
        let features = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[3, 2],
        )
        .expect("test: operation should succeed");
        let labels = Tensor::<f32>::from_vec(
            vec![10.0, 20.0, 30.0],
            &[3],
        )
        .expect("test: operation should succeed");

        let dataset = TensorDataset::new(features, labels);
        let mut batched = dataset.batch(2);

        let batch1 = batched.next().expect("test: iterator should have next");
        assert_eq!(batch1.len(), 2);

        // Trailing batch holds the single leftover sample.
        let batch2 = batched.next().expect("test: iterator should have next");
        assert_eq!(batch2.len(), 1);

        assert!(batched.next().is_none());
    }
873
    #[test]
    fn test_merged_dataset_concatenation() {
        let features1 = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0],
            &[2, 2],
        )
        .expect("test: operation should succeed");
        let labels1 = Tensor::<f32>::from_vec(vec![10.0, 20.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset1 = TensorDataset::new(features1, labels1);

        let features2 = Tensor::<f32>::from_vec(
            vec![5.0, 6.0, 7.0, 8.0],
            &[2, 2],
        )
        .expect("test: operation should succeed");
        let labels2 = Tensor::<f32>::from_vec(vec![30.0, 40.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset2 = TensorDataset::new(features2, labels2);

        let merged = MergedDataset::new_concatenated(dataset1, dataset2)
            .expect("test: operation should succeed");

        assert_eq!(merged.len(), 2);

        // Concatenation flattens: 2 + 2 features -> one rank-1 tensor of 4;
        // labels stay scalar (taken from the first dataset).
        let (features, labels) = merged.get(0).expect("index should be in bounds");
        assert_eq!(features.shape().dims(), &[4]);
        assert_eq!(labels.shape().dims(), &[] as &[usize]);
    }
906
    #[test]
    fn test_merged_dataset_averaging() {
        let features1 = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0],
            &[2, 2],
        )
        .expect("test: operation should succeed");
        let labels1 = Tensor::<f32>::from_vec(vec![10.0, 20.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset1 = TensorDataset::new(features1, labels1);

        let features2 = Tensor::<f32>::from_vec(
            vec![5.0, 6.0, 7.0, 8.0],
            &[2, 2],
        )
        .expect("test: operation should succeed");
        let labels2 = Tensor::<f32>::from_vec(vec![30.0, 40.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset2 = TensorDataset::new(features2, labels2);

        let merged = MergedDataset::new_averaged(dataset1, dataset2)
            .expect("test: operation should succeed");

        assert_eq!(merged.len(), 2);

        // Averaging preserves the per-sample shape ([2]).
        let (features, _) = merged.get(0).expect("index should be in bounds");
        assert_eq!(features.shape().dims(), &[2]);
        let data = features.as_slice().expect("tensor should be contiguous");
        // Sample 0: mean of (1,2) and (5,6) is (3,4).
        assert!((data[0] - 3.0).abs() < 1e-6);
        assert!((data[1] - 4.0).abs() < 1e-6);
    }
942
    #[test]
    fn test_merged_dataset_mismatched_lengths() {
        // 2 samples vs 3 samples: merging must be rejected at construction.
        let features1 = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0],
            &[2, 2],
        )
        .expect("test: operation should succeed");
        let labels1 = Tensor::<f32>::from_vec(vec![10.0, 20.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset1 = TensorDataset::new(features1, labels1);

        let features2 = Tensor::<f32>::from_vec(
            vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            &[3, 2],
        )
        .expect("test: operation should succeed");
        let labels2 = Tensor::<f32>::from_vec(vec![30.0, 40.0, 50.0], &[3])
            .expect("test: tensor creation should succeed");
        let dataset2 = TensorDataset::new(features2, labels2);

        assert!(MergedDataset::new_concatenated(dataset1, dataset2).is_err());
    }
967}