// tenflowers_dataset/dataset_core.rs
1//! Core dataset traits, types, and utility implementations.
2//!
3//! This module contains the fundamental `Dataset` trait, `DatasetUtilsExt` extension trait,
4//! and all the basic dataset wrapper types (`TensorDataset`, `BatchedDataset`, `ConcatDataset`,
5//! `FilteredDataset`, `SubsetDataset`, `MergedDataset`, `DatasetSplitter`, etc.).
6
7use scirs2_core::random::rng;
8use std::marker::PhantomData;
9use std::sync::Arc;
10use tenflowers_core::ops::slice;
11use tenflowers_core::{Result, Tensor, TensorError};
12
/// Core trait for datasets.
///
/// All dataset types must implement this trait to be usable with data loaders,
/// transforms, and other dataset utilities.
pub trait Dataset<T> {
    /// Number of samples in the dataset.
    fn len(&self) -> usize;
    /// Returns `true` when the dataset contains no samples.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
    /// Fetch the `(features, labels)` pair at `index`.
    ///
    /// # Errors
    /// Implementations in this module return an error for out-of-bounds indices.
    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)>;
    /// Consume the dataset and wrap it in a [`BatchedDataset`] iterator that
    /// yields batches of up to `batch_size` samples each.
    fn batch(self, batch_size: usize) -> BatchedDataset<T, Self>
    where
        Self: Sized,
    {
        BatchedDataset {
            dataset: self,
            batch_size,
            current_index: 0,
            _phantom: PhantomData,
        }
    }
}
35
36/// Implement Dataset for `Arc<D>` to allow shared ownership
37impl<T, D: Dataset<T>> Dataset<T> for Arc<D> {
38    fn len(&self) -> usize {
39        (**self).len()
40    }
41
42    fn is_empty(&self) -> bool {
43        (**self).is_empty()
44    }
45
46    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
47        (**self).get(index)
48    }
49}
50
/// Extension trait providing utility methods for datasets
///
/// Blanket-implemented for every [`Dataset`], so these helpers are available
/// on any dataset type without extra code.
pub trait DatasetUtilsExt<T>: Dataset<T> {
    /// Get multiple samples by their indices
    ///
    /// # Errors
    /// Propagates the first error returned by [`Dataset::get`].
    fn get_multiple(&self, indices: &[usize]) -> Result<Vec<(Tensor<T>, Tensor<T>)>> {
        let mut samples = Vec::with_capacity(indices.len());
        for &index in indices {
            samples.push(self.get(index)?);
        }
        Ok(samples)
    }

    /// Get a range of samples from start (inclusive) to end (exclusive)
    ///
    /// An empty (or inverted) range yields an empty `Vec` rather than an error.
    ///
    /// # Errors
    /// Fails when `end` exceeds the dataset length, or when any `get` fails.
    fn get_range(&self, start: usize, end: usize) -> Result<Vec<(Tensor<T>, Tensor<T>)>> {
        if start >= end {
            return Ok(Vec::new());
        }
        if end > self.len() {
            return Err(TensorError::invalid_argument(format!(
                "End index {} out of bounds for dataset of length {}",
                end,
                self.len()
            )));
        }

        let mut samples = Vec::with_capacity(end - start);
        for i in start..end {
            samples.push(self.get(i)?);
        }
        Ok(samples)
    }

    /// Get a random sample from the dataset
    ///
    /// # Errors
    /// Fails on an empty dataset, or when the underlying `get` fails.
    fn get_random(&self) -> Result<(Tensor<T>, Tensor<T>)> {
        use scirs2_core::random::rand_prelude::*;
        if self.is_empty() {
            return Err(TensorError::invalid_argument(
                "Cannot get random sample from empty dataset".to_string(),
            ));
        }
        let mut rng = rng();
        // Scale a uniform f64 in [0, 1) to an index; the clamp guards the
        // (theoretical) case where the product rounds up to `len`.
        let random_val: f64 = rng.random();
        let index = (random_val * self.len() as f64) as usize;
        let index = index.min(self.len() - 1); // Clamp to valid range
        self.get(index)
    }

    /// Get multiple random samples from the dataset (with replacement)
    ///
    /// # Errors
    /// Fails on an empty dataset, or when any underlying `get` fails.
    fn get_random_samples(&self, count: usize) -> Result<Vec<(Tensor<T>, Tensor<T>)>> {
        use scirs2_core::random::rand_prelude::*;
        if self.is_empty() {
            return Err(TensorError::invalid_argument(
                "Cannot get random samples from empty dataset".to_string(),
            ));
        }

        let mut rng = rng();
        let mut samples = Vec::with_capacity(count);
        for _ in 0..count {
            // Same uniform-scale-and-clamp scheme as `get_random`; samples
            // are drawn independently, so duplicates are possible.
            let random_val: f64 = rng.random();
            let index = (random_val * self.len() as f64) as usize;
            let index = index.min(self.len() - 1); // Clamp to valid range
            samples.push(self.get(index)?);
        }
        Ok(samples)
    }
}
117
/// Implement DatasetUtilsExt for all types that implement Dataset
/// (blanket impl: every dataset gets the utility methods for free).
impl<T, D: Dataset<T>> DatasetUtilsExt<T> for D {}
120
121#[derive(Clone)]
122pub struct TensorDataset<T> {
123    features: Tensor<T>,
124    #[allow(dead_code)]
125    labels: Tensor<T>,
126}
127
128impl<T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static> TensorDataset<T> {
129    pub fn new(features: Tensor<T>, labels: Tensor<T>) -> Self {
130        Self { features, labels }
131    }
132}
133
134impl<T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static> Dataset<T>
135    for TensorDataset<T>
136{
137    fn len(&self) -> usize {
138        self.features.shape().dims()[0]
139    }
140
141    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
142        if index >= self.len() {
143            return Err(TensorError::invalid_argument(format!(
144                "Index {} out of bounds for dataset of length {}",
145                index,
146                self.len()
147            )));
148        }
149
150        // Create slice ranges for the specific index
151        let mut feature_ranges = Vec::new();
152        let mut label_ranges = Vec::new();
153
154        // For the first dimension (batch dimension), slice to get single index
155        feature_ranges.push(index..index + 1);
156        label_ranges.push(index..index + 1);
157
158        // For remaining dimensions, take all elements
159        for i in 1..self.features.shape().rank() {
160            feature_ranges.push(0..self.features.shape().dims()[i]);
161        }
162        for i in 1..self.labels.shape().rank() {
163            label_ranges.push(0..self.labels.shape().dims()[i]);
164        }
165
166        // Slice the tensors
167        let feature_slice = slice(&self.features, &feature_ranges)?;
168        let label_slice = slice(&self.labels, &label_ranges)?;
169
170        // Squeeze the first dimension (remove batch dimension of size 1)
171        let feature_squeezed = squeeze_first_dim(&feature_slice)?;
172        let label_squeezed = squeeze_first_dim(&label_slice)?;
173
174        Ok((feature_squeezed, label_squeezed))
175    }
176}
177
178/// Helper function to squeeze the first dimension of a tensor
179fn squeeze_first_dim<T>(tensor: &Tensor<T>) -> Result<Tensor<T>>
180where
181    T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
182{
183    let shape = tensor.shape();
184    if shape.rank() == 0 {
185        return Ok(tensor.clone());
186    }
187
188    if shape.dims()[0] != 1 {
189        return Err(TensorError::invalid_argument(format!(
190            "Cannot squeeze dimension of size {}",
191            shape.dims()[0]
192        )));
193    }
194
195    let new_shape: Vec<usize> = shape.dims()[1..].to_vec();
196    tenflowers_core::ops::reshape(tensor, &new_shape)
197}
198
/// Iterator adapter that yields batches of samples from a dataset.
///
/// Produced by [`Dataset::batch`]; each `next` call returns up to
/// `batch_size` samples, with the final batch possibly smaller.
pub struct BatchedDataset<T, D: Dataset<T>> {
    dataset: D,
    // Maximum number of samples per yielded batch.
    batch_size: usize,
    // Index of the next sample to read; advances as batches are consumed.
    current_index: usize,
    _phantom: PhantomData<T>,
}
205
206impl<T, D: Dataset<T>> Iterator for BatchedDataset<T, D> {
207    type Item = Vec<(Tensor<T>, Tensor<T>)>;
208
209    fn next(&mut self) -> Option<Self::Item> {
210        if self.current_index >= self.dataset.len() {
211            return None;
212        }
213
214        let mut batch = Vec::new();
215        let end_index = (self.current_index + self.batch_size).min(self.dataset.len());
216
217        for i in self.current_index..end_index {
218            match self.dataset.get(i) {
219                Ok(sample) => batch.push(sample),
220                Err(_) => break, // Stop on error, return partial batch if any
221            }
222        }
223
224        self.current_index = end_index;
225
226        if batch.is_empty() {
227            None
228        } else {
229            Some(batch)
230        }
231    }
232}
233
/// Dataset concatenation - combines multiple datasets into one
///
/// Samples are addressed by a single global index that runs through the
/// child datasets in order.
pub struct ConcatDataset<T, D: Dataset<T>> {
    datasets: Vec<D>,
    // cumulative_lengths[i] holds the total sample count of datasets[0..=i];
    // used to map a global index to (dataset, local index).
    cumulative_lengths: Vec<usize>,
    // Sum of all child dataset lengths, cached so `len()` is O(1).
    total_length: usize,
    _phantom: PhantomData<T>,
}
241
242impl<T, D: Dataset<T>> ConcatDataset<T, D> {
243    pub fn new(datasets: Vec<D>) -> Self {
244        let mut cumulative_lengths = Vec::with_capacity(datasets.len());
245        let mut total_length = 0;
246
247        for dataset in &datasets {
248            total_length += dataset.len();
249            cumulative_lengths.push(total_length);
250        }
251
252        Self {
253            datasets,
254            cumulative_lengths,
255            total_length,
256            _phantom: PhantomData,
257        }
258    }
259
260    /// Find which dataset and local index for a global index
261    fn find_dataset_and_index(&self, global_index: usize) -> Result<(usize, usize)> {
262        for (dataset_idx, &cumulative_len) in self.cumulative_lengths.iter().enumerate() {
263            if global_index < cumulative_len {
264                let local_index = if dataset_idx == 0 {
265                    global_index
266                } else {
267                    global_index - self.cumulative_lengths[dataset_idx - 1]
268                };
269                return Ok((dataset_idx, local_index));
270            }
271        }
272        Err(TensorError::invalid_argument(format!(
273            "Index {} out of bounds for dataset of total length {}",
274            global_index, self.total_length
275        )))
276    }
277}
278
279impl<T, D: Dataset<T>> Dataset<T> for ConcatDataset<T, D> {
280    fn len(&self) -> usize {
281        self.total_length
282    }
283
284    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
285        if index >= self.total_length {
286            return Err(TensorError::invalid_argument(format!(
287                "Index {} out of bounds for dataset of length {}",
288                index, self.total_length
289            )));
290        }
291
292        let (dataset_idx, local_index) = self.find_dataset_and_index(index)?;
293        self.datasets[dataset_idx].get(local_index)
294    }
295}
296
/// Dataset filtering - creates a view of a dataset with only indices that match a predicate
///
/// The predicate is evaluated eagerly, once per sample, at construction
/// time; the view then stores only the surviving indices.
pub struct FilteredDataset<T, D: Dataset<T>, F: Fn(&(Tensor<T>, Tensor<T>)) -> bool> {
    dataset: D,
    // Indices (into the wrapped dataset) whose samples passed the predicate.
    valid_indices: Vec<usize>,
    _phantom: PhantomData<(T, F)>,
}
303
304impl<T, D: Dataset<T>, F: Fn(&(Tensor<T>, Tensor<T>)) -> bool> FilteredDataset<T, D, F> {
305    pub fn new(dataset: D, predicate: F) -> Result<Self> {
306        let mut valid_indices = Vec::new();
307
308        for i in 0..dataset.len() {
309            match dataset.get(i) {
310                Ok(sample) => {
311                    if predicate(&sample) {
312                        valid_indices.push(i);
313                    }
314                }
315                Err(_) => continue, // Skip invalid samples
316            }
317        }
318
319        Ok(Self {
320            dataset,
321            valid_indices,
322            _phantom: PhantomData,
323        })
324    }
325}
326
327impl<T, D: Dataset<T>, F: Fn(&(Tensor<T>, Tensor<T>)) -> bool> Dataset<T>
328    for FilteredDataset<T, D, F>
329{
330    fn len(&self) -> usize {
331        self.valid_indices.len()
332    }
333
334    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
335        if index >= self.valid_indices.len() {
336            return Err(TensorError::invalid_argument(format!(
337                "Index {} out of bounds for filtered dataset of length {}",
338                index,
339                self.valid_indices.len()
340            )));
341        }
342
343        let actual_index = self.valid_indices[index];
344        self.dataset.get(actual_index)
345    }
346}
347
/// Dataset splitting - splits a dataset into train/validation/test sets
///
/// Each split is a [`SubsetDataset`] view over an `Arc`-shared original
/// dataset, so no sample data is duplicated between splits.
pub struct DatasetSplit<T, D: Dataset<T>> {
    /// Training subset (always present).
    pub train: SubsetDataset<T, Arc<D>>,
    /// Validation subset; `None` when no validation ratio was requested.
    pub validation: Option<SubsetDataset<T, Arc<D>>>,
    /// Test subset; `None` when no test ratio was requested.
    pub test: Option<SubsetDataset<T, Arc<D>>>,
}
354
/// Subset dataset - creates a view of a dataset with only specified indices
///
/// `len()` reports the number of stored indices; `get(i)` forwards to the
/// wrapped dataset at `indices[i]`.
pub struct SubsetDataset<T, D: Dataset<T>> {
    dataset: D,
    // Indices into the wrapped dataset, in the order this view exposes them.
    indices: Vec<usize>,
    _phantom: PhantomData<T>,
}
361
362impl<T, D: Dataset<T>> SubsetDataset<T, D> {
363    pub fn new(dataset: D, indices: Vec<usize>) -> Self {
364        Self {
365            dataset,
366            indices,
367            _phantom: PhantomData,
368        }
369    }
370}
371
372impl<T, D: Dataset<T>> Dataset<T> for SubsetDataset<T, D> {
373    fn len(&self) -> usize {
374        self.indices.len()
375    }
376
377    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
378        if index >= self.indices.len() {
379            return Err(TensorError::invalid_argument(format!(
380                "Index {} out of bounds for subset dataset of length {}",
381                index,
382                self.indices.len()
383            )));
384        }
385
386        let actual_index = self.indices[index];
387        self.dataset.get(actual_index)
388    }
389}
390
/// Dataset merging - combines multiple datasets with different modalities
///
/// Both datasets must have the same length (checked at construction); how
/// per-sample features and labels are combined is governed by `merge_strategy`.
pub struct MergedDataset<T, D1: Dataset<T>, D2: Dataset<T>> {
    dataset1: D1,
    dataset2: D2,
    // Fixed at construction by the `new_*` constructors; consulted on every `get`.
    merge_strategy: MergeStrategy,
    _phantom: PhantomData<T>,
}
398
/// Strategy for merging datasets
///
/// NOTE(review): the docs for `FeatureFromFirst`/`FeatureFromSecond`
/// previously claimed labels came from the *other* dataset, but the
/// `Dataset::get` implementation returns labels from the same side as the
/// features for both variants. The docs below describe actual behavior;
/// confirm which was intended.
#[derive(Debug, Clone)]
pub enum MergeStrategy {
    /// Concatenate features horizontally
    FeatureConcatenation,
    /// Average features element-wise
    FeatureAverage,
    /// Use features from first dataset (labels are also taken from the first)
    FeatureFromFirst,
    /// Use features from second dataset (labels are also taken from the second)
    FeatureFromSecond,
    /// Custom merge function
    /// (not yet wired up; currently behaves like `FeatureFromFirst`)
    Custom,
}
413
414impl<T, D1: Dataset<T>, D2: Dataset<T>> MergedDataset<T, D1, D2> {
415    /// Create a new merged dataset with feature concatenation
416    pub fn new_concatenated(dataset1: D1, dataset2: D2) -> Result<Self> {
417        if dataset1.len() != dataset2.len() {
418            return Err(TensorError::invalid_argument(format!(
419                "Dataset lengths must match: {} vs {}",
420                dataset1.len(),
421                dataset2.len()
422            )));
423        }
424
425        Ok(Self {
426            dataset1,
427            dataset2,
428            merge_strategy: MergeStrategy::FeatureConcatenation,
429            _phantom: PhantomData,
430        })
431    }
432
433    /// Create a new merged dataset with feature averaging
434    pub fn new_averaged(dataset1: D1, dataset2: D2) -> Result<Self> {
435        if dataset1.len() != dataset2.len() {
436            return Err(TensorError::invalid_argument(format!(
437                "Dataset lengths must match: {} vs {}",
438                dataset1.len(),
439                dataset2.len()
440            )));
441        }
442
443        Ok(Self {
444            dataset1,
445            dataset2,
446            merge_strategy: MergeStrategy::FeatureAverage,
447            _phantom: PhantomData,
448        })
449    }
450
451    /// Create a new merged dataset using features from first dataset and labels from second
452    pub fn new_features_from_first(dataset1: D1, dataset2: D2) -> Result<Self> {
453        if dataset1.len() != dataset2.len() {
454            return Err(TensorError::invalid_argument(format!(
455                "Dataset lengths must match: {} vs {}",
456                dataset1.len(),
457                dataset2.len()
458            )));
459        }
460
461        Ok(Self {
462            dataset1,
463            dataset2,
464            merge_strategy: MergeStrategy::FeatureFromFirst,
465            _phantom: PhantomData,
466        })
467    }
468
469    /// Create a new merged dataset using features from second dataset and labels from first
470    pub fn new_features_from_second(dataset1: D1, dataset2: D2) -> Result<Self> {
471        if dataset1.len() != dataset2.len() {
472            return Err(TensorError::invalid_argument(format!(
473                "Dataset lengths must match: {} vs {}",
474                dataset1.len(),
475                dataset2.len()
476            )));
477        }
478
479        Ok(Self {
480            dataset1,
481            dataset2,
482            merge_strategy: MergeStrategy::FeatureFromSecond,
483            _phantom: PhantomData,
484        })
485    }
486
487    /// Merge two tensors based on the merge strategy
488    fn merge_tensors(&self, tensor1: &Tensor<T>, tensor2: &Tensor<T>) -> Result<Tensor<T>>
489    where
490        T: Clone
491            + Default
492            + scirs2_core::numeric::Zero
493            + scirs2_core::numeric::Float
494            + Send
495            + Sync
496            + 'static,
497    {
498        match self.merge_strategy {
499            MergeStrategy::FeatureConcatenation => {
500                // Concatenate tensors along the feature dimension
501                let data1 = tensor1.as_slice().ok_or_else(|| {
502                    TensorError::invalid_argument(
503                        "Cannot access tensor data (GPU tensor not supported)".to_string(),
504                    )
505                })?;
506                let data2 = tensor2.as_slice().ok_or_else(|| {
507                    TensorError::invalid_argument(
508                        "Cannot access tensor data (GPU tensor not supported)".to_string(),
509                    )
510                })?;
511                let mut merged_data = Vec::new();
512                merged_data.extend_from_slice(data1);
513                merged_data.extend_from_slice(data2);
514
515                let new_shape = vec![data1.len() + data2.len()];
516                Tensor::from_vec(merged_data, &new_shape)
517            }
518            MergeStrategy::FeatureAverage => {
519                // Average tensors element-wise
520                let data1 = tensor1.as_slice().ok_or_else(|| {
521                    TensorError::invalid_argument(
522                        "Cannot access tensor data (GPU tensor not supported)".to_string(),
523                    )
524                })?;
525                let data2 = tensor2.as_slice().ok_or_else(|| {
526                    TensorError::invalid_argument(
527                        "Cannot access tensor data (GPU tensor not supported)".to_string(),
528                    )
529                })?;
530
531                if data1.len() != data2.len() {
532                    return Err(TensorError::invalid_argument(
533                        "Cannot average tensors of different sizes".to_string(),
534                    ));
535                }
536
537                let mut averaged_data = Vec::new();
538                let two = T::from(2.0).expect("conversion of 2.0 to float type should succeed");
539                for (v1, v2) in data1.iter().zip(data2.iter()) {
540                    let avg = (*v1 + *v2) / two;
541                    averaged_data.push(avg);
542                }
543
544                Tensor::from_vec(averaged_data, tensor1.shape().dims())
545            }
546            MergeStrategy::FeatureFromFirst => Ok(tensor1.clone()),
547            MergeStrategy::FeatureFromSecond => Ok(tensor2.clone()),
548            MergeStrategy::Custom => {
549                // For custom merge, just return first tensor for now
550                // This could be extended to accept custom merge functions
551                Ok(tensor1.clone())
552            }
553        }
554    }
555}
556
557impl<T, D1: Dataset<T>, D2: Dataset<T>> Dataset<T> for MergedDataset<T, D1, D2>
558where
559    T: Clone
560        + Default
561        + scirs2_core::numeric::Zero
562        + scirs2_core::numeric::Float
563        + Send
564        + Sync
565        + 'static,
566{
567    fn len(&self) -> usize {
568        self.dataset1.len()
569    }
570
571    fn get(&self, index: usize) -> Result<(Tensor<T>, Tensor<T>)> {
572        if index >= self.dataset1.len() {
573            return Err(TensorError::invalid_argument(format!(
574                "Index {} out of bounds for merged dataset of length {}",
575                index,
576                self.dataset1.len()
577            )));
578        }
579
580        let (features1, labels1) = self.dataset1.get(index)?;
581        let (features2, labels2) = self.dataset2.get(index)?;
582
583        let merged_features = self.merge_tensors(&features1, &features2)?;
584
585        // For labels, use the strategy to determine which label to use
586        let merged_labels = match self.merge_strategy {
587            MergeStrategy::FeatureFromFirst => labels1,
588            MergeStrategy::FeatureFromSecond => labels2,
589            _ => labels1, // Default to first dataset's labels
590        };
591
592        Ok((merged_features, merged_labels))
593    }
594}
595
/// Dataset splitting utilities
///
/// Stateless namespace for ratio-based splits, k-fold cross-validation
/// splits, and stratified (class-balanced) splits.
pub struct DatasetSplitter;
598
599impl DatasetSplitter {
600    /// Split dataset into train/validation/test sets with given ratios
601    pub fn split<T, D: Dataset<T>>(
602        dataset: D,
603        train_ratio: f64,
604        val_ratio: Option<f64>,
605        test_ratio: Option<f64>,
606        shuffle: bool,
607    ) -> Result<DatasetSplit<T, D>> {
608        let total_len = dataset.len();
609        if total_len == 0 {
610            return Err(TensorError::invalid_argument(
611                "Cannot split empty dataset".to_string(),
612            ));
613        }
614
615        // Validate ratios
616        let val_ratio = val_ratio.unwrap_or(0.0);
617        let test_ratio = test_ratio.unwrap_or(0.0);
618
619        if train_ratio + val_ratio + test_ratio > 1.0 {
620            return Err(TensorError::invalid_argument(
621                "Sum of ratios cannot exceed 1.0".to_string(),
622            ));
623        }
624
625        // Create indices
626        let mut indices: Vec<usize> = (0..total_len).collect();
627
628        // Shuffle if requested
629        if shuffle {
630            use scirs2_core::random::rand_prelude::*;
631            let mut rng = rng();
632            indices.shuffle(&mut rng);
633        }
634
635        // Calculate split points
636        let train_end = (total_len as f64 * train_ratio) as usize;
637        let val_end = train_end + (total_len as f64 * val_ratio) as usize;
638        let test_end = val_end + (total_len as f64 * test_ratio) as usize;
639
640        // Create subset datasets using Arc for sharing
641        let dataset_arc = Arc::new(dataset);
642        let train_indices = indices[0..train_end].to_vec();
643        let train = SubsetDataset::new(dataset_arc.clone(), train_indices);
644
645        let validation = if val_ratio > 0.0 {
646            let val_indices = indices[train_end..val_end].to_vec();
647            Some(SubsetDataset::new(dataset_arc.clone(), val_indices))
648        } else {
649            None
650        };
651
652        let test = if test_ratio > 0.0 {
653            let test_indices = indices[val_end..test_end].to_vec();
654            Some(SubsetDataset::new(dataset_arc.clone(), test_indices))
655        } else {
656            None
657        };
658
659        Ok(DatasetSplit {
660            train,
661            validation,
662            test,
663        })
664    }
665
666    /// Split dataset into k-folds for cross-validation
667    #[allow(clippy::type_complexity)]
668    pub fn k_fold<T, D: Dataset<T>>(
669        dataset: D,
670        k: usize,
671        shuffle: bool,
672    ) -> Result<Vec<(SubsetDataset<T, Arc<D>>, SubsetDataset<T, Arc<D>>)>> {
673        if k <= 1 {
674            return Err(TensorError::invalid_argument(
675                "K must be greater than 1".to_string(),
676            ));
677        }
678
679        let total_len = dataset.len();
680        if total_len == 0 {
681            return Err(TensorError::invalid_argument(
682                "Cannot split empty dataset".to_string(),
683            ));
684        }
685
686        let mut indices: Vec<usize> = (0..total_len).collect();
687
688        if shuffle {
689            use scirs2_core::random::rand_prelude::*;
690            let mut rng = rng();
691            indices.shuffle(&mut rng);
692        }
693
694        let fold_size = total_len / k;
695        let mut folds = Vec::new();
696        let dataset_arc = Arc::new(dataset);
697
698        for i in 0..k {
699            let start = i * fold_size;
700            let end = if i == k - 1 {
701                total_len
702            } else {
703                (i + 1) * fold_size
704            };
705
706            let val_indices = indices[start..end].to_vec();
707            let train_indices: Vec<usize> = indices[0..start]
708                .iter()
709                .chain(indices[end..].iter())
710                .cloned()
711                .collect();
712
713            let train_dataset = SubsetDataset::new(dataset_arc.clone(), train_indices);
714            let val_dataset = SubsetDataset::new(dataset_arc.clone(), val_indices);
715
716            folds.push((train_dataset, val_dataset));
717        }
718
719        Ok(folds)
720    }
721
722    /// Stratified split - maintains class distribution across splits
723    pub fn stratified_split<T, D: Dataset<T>>(
724        dataset: D,
725        train_ratio: f64,
726        val_ratio: Option<f64>,
727        extract_class: fn(&(Tensor<T>, Tensor<T>)) -> usize,
728    ) -> Result<(Vec<usize>, Vec<usize>)> {
729        let total_len = dataset.len();
730        if total_len == 0 {
731            return Err(TensorError::invalid_argument(
732                "Cannot split empty dataset".to_string(),
733            ));
734        }
735
736        // Group indices by class
737        let mut class_indices: std::collections::HashMap<usize, Vec<usize>> =
738            std::collections::HashMap::new();
739
740        for i in 0..total_len {
741            if let Ok(sample) = dataset.get(i) {
742                let class = extract_class(&sample);
743                class_indices.entry(class).or_default().push(i);
744            }
745        }
746
747        let mut train_indices = Vec::new();
748        let mut val_indices = Vec::new();
749
750        // Split each class proportionally
751        for (_, mut indices) in class_indices {
752            // Shuffle class indices
753            use scirs2_core::random::rand_prelude::*;
754            let mut rng = rng();
755            indices.shuffle(&mut rng);
756
757            let class_len = indices.len();
758            let train_end = (class_len as f64 * train_ratio) as usize;
759
760            train_indices.extend(indices[0..train_end].iter());
761
762            if let Some(val_ratio) = val_ratio {
763                let val_end = train_end + (class_len as f64 * val_ratio) as usize;
764                val_indices.extend(indices[train_end..val_end].iter());
765            }
766        }
767
768        Ok((train_indices, val_indices))
769    }
770}
771
#[cfg(test)]
mod tests {
    // Unit tests exercising the in-memory dataset wrappers with tiny,
    // hand-computed fixtures (no I/O, no randomness).
    use super::*;
    use tenflowers_core::Tensor;

    #[test]
    fn test_tensor_dataset_creation() {
        // 3 samples of 2 features each, one scalar label per sample.
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0, 2.0], &[3])
            .expect("test: tensor creation should succeed");

        let dataset = TensorDataset::new(features, labels);
        assert_eq!(dataset.len(), 3);
        assert!(!dataset.is_empty());
    }

    #[test]
    fn test_tensor_dataset_get() {
        let features = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[3, 2], // 3 samples, 2 features each
        )
        .expect("test: operation should succeed");
        let labels = Tensor::<f32>::from_vec(
            vec![10.0, 20.0, 30.0],
            &[3], // 3 labels
        )
        .expect("test: operation should succeed");

        let dataset = TensorDataset::new(features, labels);

        // Test getting first sample
        let (feat, label) = dataset.get(0).expect("index should be in bounds");
        assert_eq!(feat.shape().dims(), &[2]); // Should be squeezed from [1, 2] to [2]
        assert_eq!(label.shape().dims(), &[] as &[usize]); // Should be squeezed from [1] to scalar

        // Test getting second sample
        let (feat2, label2) = dataset.get(1).expect("index should be in bounds");
        assert_eq!(feat2.shape().dims(), &[2]);
        assert_eq!(label2.shape().dims(), &[] as &[usize]);

        // Test out of bounds
        assert!(dataset.get(3).is_err());
    }

    #[test]
    fn test_batched_dataset() {
        // 4 samples with batch_size 2: expect exactly two full batches.
        let features = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            &[4, 2], // 4 samples, 2 features each
        )
        .expect("test: operation should succeed");
        let labels = Tensor::<f32>::from_vec(
            vec![10.0, 20.0, 30.0, 40.0],
            &[4], // 4 labels
        )
        .expect("test: operation should succeed");

        let dataset = TensorDataset::new(features, labels);
        let mut batched = dataset.batch(2);

        // First batch should have 2 samples
        let batch1 = batched.next().expect("test: iterator should have next");
        assert_eq!(batch1.len(), 2);

        // Second batch should have 2 samples
        let batch2 = batched.next().expect("test: iterator should have next");
        assert_eq!(batch2.len(), 2);

        // No more batches
        assert!(batched.next().is_none());
    }

    #[test]
    fn test_batched_dataset_partial_batch() {
        // 3 samples with batch_size 2: the final batch is a partial one.
        let features = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[3, 2], // 3 samples, 2 features each
        )
        .expect("test: operation should succeed");
        let labels = Tensor::<f32>::from_vec(
            vec![10.0, 20.0, 30.0],
            &[3], // 3 labels
        )
        .expect("test: operation should succeed");

        let dataset = TensorDataset::new(features, labels);
        let mut batched = dataset.batch(2);

        // First batch should have 2 samples
        let batch1 = batched.next().expect("test: iterator should have next");
        assert_eq!(batch1.len(), 2);

        // Second batch should have 1 sample (partial)
        let batch2 = batched.next().expect("test: iterator should have next");
        assert_eq!(batch2.len(), 1);

        // No more batches
        assert!(batched.next().is_none());
    }

    #[test]
    fn test_merged_dataset_concatenation() {
        // Create two datasets
        let features1 = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0],
            &[2, 2], // 2 samples, 2 features each
        )
        .expect("test: operation should succeed");
        let labels1 = Tensor::<f32>::from_vec(vec![10.0, 20.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset1 = TensorDataset::new(features1, labels1);

        let features2 = Tensor::<f32>::from_vec(
            vec![5.0, 6.0, 7.0, 8.0],
            &[2, 2], // 2 samples, 2 features each
        )
        .expect("test: operation should succeed");
        let labels2 = Tensor::<f32>::from_vec(vec![30.0, 40.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset2 = TensorDataset::new(features2, labels2);

        // Create merged dataset
        let merged = MergedDataset::new_concatenated(dataset1, dataset2)
            .expect("test: operation should succeed");

        assert_eq!(merged.len(), 2);

        // Test getting first sample
        let (features, labels) = merged.get(0).expect("index should be in bounds");
        assert_eq!(features.shape().dims(), &[4]); // 2 + 2 features concatenated
        assert_eq!(labels.shape().dims(), &[] as &[usize]);
    }

    #[test]
    fn test_merged_dataset_averaging() {
        // Create two datasets with same feature size
        let features1 = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0],
            &[2, 2], // 2 samples, 2 features each
        )
        .expect("test: operation should succeed");
        let labels1 = Tensor::<f32>::from_vec(vec![10.0, 20.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset1 = TensorDataset::new(features1, labels1);

        let features2 = Tensor::<f32>::from_vec(
            vec![5.0, 6.0, 7.0, 8.0],
            &[2, 2], // 2 samples, 2 features each
        )
        .expect("test: operation should succeed");
        let labels2 = Tensor::<f32>::from_vec(vec![30.0, 40.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset2 = TensorDataset::new(features2, labels2);

        // Create merged dataset with averaging
        let merged = MergedDataset::new_averaged(dataset1, dataset2)
            .expect("test: operation should succeed");

        assert_eq!(merged.len(), 2);

        // Test getting first sample - should be averaged
        let (features, _) = merged.get(0).expect("index should be in bounds");
        assert_eq!(features.shape().dims(), &[2]); // Same feature size
                                                   // First sample should be (1+5)/2=3, (2+6)/2=4
        let data = features.as_slice().expect("tensor should be contiguous");
        assert!((data[0] - 3.0).abs() < 1e-6);
        assert!((data[1] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_merged_dataset_mismatched_lengths() {
        // Create two datasets with different lengths
        let features1 = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0],
            &[2, 2], // 2 samples
        )
        .expect("test: operation should succeed");
        let labels1 = Tensor::<f32>::from_vec(vec![10.0, 20.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset1 = TensorDataset::new(features1, labels1);

        let features2 = Tensor::<f32>::from_vec(
            vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            &[3, 2], // 3 samples
        )
        .expect("test: operation should succeed");
        let labels2 = Tensor::<f32>::from_vec(vec![30.0, 40.0, 50.0], &[3])
            .expect("test: tensor creation should succeed");
        let dataset2 = TensorDataset::new(features2, labels2);

        // Should fail with mismatched lengths
        assert!(MergedDataset::new_concatenated(dataset1, dataset2).is_err());
    }
}