//! `yscv_model/dataset/types.rs` — supervised dataset container, batch
//! iteration options, and the deterministic mini-batch iterator.

1use std::collections::HashMap;
2use yscv_tensor::Tensor;
3
4use crate::{ImageAugmentationPipeline, ModelError};
5
6use super::helpers::{
7    class_labels_from_targets, gather_rows, shuffle_indices, validate_split_ratios,
8};
9use super::iter::{
10    CutMixConfig, MixUpConfig, apply_cutmix_batch, apply_mixup_batch, build_sample_order,
11    validate_augmentation_compatibility, validate_cutmix_compatibility, validate_cutmix_config,
12    validate_mixup_config,
13};
14
/// Supervised dataset with aligned input/target sample axis at position 0.
///
/// Invariant (enforced by [`SupervisedDataset::new`]): both tensors have
/// rank >= 1 and `inputs.shape()[0] == targets.shape()[0]`.
#[derive(Debug, Clone, PartialEq)]
pub struct SupervisedDataset {
    /// Model inputs; axis 0 indexes samples.
    inputs: Tensor,
    /// Targets aligned with `inputs` along axis 0.
    targets: Tensor,
}
21
22impl SupervisedDataset {
23    pub fn new(inputs: Tensor, targets: Tensor) -> Result<Self, ModelError> {
24        if inputs.rank() == 0 || targets.rank() == 0 {
25            return Err(ModelError::InvalidDatasetRank {
26                inputs_rank: inputs.rank(),
27                targets_rank: targets.rank(),
28            });
29        }
30        if inputs.shape()[0] != targets.shape()[0] {
31            return Err(ModelError::DatasetShapeMismatch {
32                inputs: inputs.shape().to_vec(),
33                targets: targets.shape().to_vec(),
34            });
35        }
36        Ok(Self { inputs, targets })
37    }
38
39    pub fn len(&self) -> usize {
40        self.inputs.shape()[0]
41    }
42
43    pub fn is_empty(&self) -> bool {
44        self.len() == 0
45    }
46
47    pub fn inputs(&self) -> &Tensor {
48        &self.inputs
49    }
50
51    pub fn targets(&self) -> &Tensor {
52        &self.targets
53    }
54
55    pub fn batches(&self, batch_size: usize) -> Result<MiniBatchIter<'_>, ModelError> {
56        self.batches_with_options(batch_size, BatchIterOptions::default())
57    }
58
59    pub fn batches_with_options(
60        &self,
61        batch_size: usize,
62        options: BatchIterOptions,
63    ) -> Result<MiniBatchIter<'_>, ModelError> {
64        if batch_size == 0 {
65            return Err(ModelError::InvalidBatchSize { batch_size });
66        }
67        if let Some(pipeline) = options.augmentation.as_ref() {
68            validate_augmentation_compatibility(self.inputs(), pipeline)?;
69        }
70        if let Some(mixup) = options.mixup.as_ref() {
71            validate_mixup_config(mixup)?;
72        }
73        if let Some(cutmix) = options.cutmix.as_ref() {
74            validate_cutmix_config(cutmix)?;
75            validate_cutmix_compatibility(self.inputs())?;
76        }
77
78        let order = build_sample_order(self, &options)?;
79
80        Ok(MiniBatchIter {
81            dataset: self,
82            batch_size,
83            cursor: 0,
84            order,
85            drop_last: options.drop_last,
86            augmentation: options.augmentation,
87            augmentation_seed: options.augmentation_seed,
88            mixup: options.mixup,
89            mixup_seed: options.mixup_seed,
90            cutmix: options.cutmix,
91            cutmix_seed: options.cutmix_seed,
92            emitted_batches: 0,
93        })
94    }
95
96    pub fn split_by_counts(
97        &self,
98        train_count: usize,
99        validation_count: usize,
100        shuffle: bool,
101        seed: u64,
102    ) -> Result<DatasetSplit, ModelError> {
103        if train_count
104            .checked_add(validation_count)
105            .is_none_or(|sum| sum > self.len())
106        {
107            return Err(ModelError::InvalidSplitCounts {
108                train_count,
109                validation_count,
110                dataset_len: self.len(),
111            });
112        }
113
114        let mut order = (0..self.len()).collect::<Vec<_>>();
115        if shuffle {
116            shuffle_indices(&mut order, seed);
117        }
118
119        let train = self.subset_by_indices(&order[..train_count])?;
120        let validation_end = train_count + validation_count;
121        let validation = self.subset_by_indices(&order[train_count..validation_end])?;
122        let test = self.subset_by_indices(&order[validation_end..])?;
123        Ok(DatasetSplit {
124            train,
125            validation,
126            test,
127        })
128    }
129
130    pub fn split_by_ratio(
131        &self,
132        train_ratio: f32,
133        validation_ratio: f32,
134        shuffle: bool,
135        seed: u64,
136    ) -> Result<DatasetSplit, ModelError> {
137        validate_split_ratios(train_ratio, validation_ratio)?;
138
139        let len = self.len();
140        let train_count = ((len as f64) * train_ratio as f64).floor() as usize;
141        let remaining = len.saturating_sub(train_count);
142        let validation_count =
143            (((len as f64) * validation_ratio as f64).floor() as usize).min(remaining);
144        self.split_by_counts(train_count, validation_count, shuffle, seed)
145    }
146
147    pub fn split_by_class_ratio(
148        &self,
149        train_ratio: f32,
150        validation_ratio: f32,
151        shuffle: bool,
152        seed: u64,
153    ) -> Result<DatasetSplit, ModelError> {
154        validate_split_ratios(train_ratio, validation_ratio)?;
155
156        let class_labels = class_labels_from_targets(self.targets())?;
157        let mut indices_by_class = HashMap::<usize, Vec<usize>>::new();
158        for (sample_index, class_id) in class_labels.into_iter().enumerate() {
159            indices_by_class
160                .entry(class_id)
161                .or_default()
162                .push(sample_index);
163        }
164
165        let mut class_order = indices_by_class.keys().copied().collect::<Vec<_>>();
166        class_order.sort_unstable();
167
168        let mut train_indices = Vec::new();
169        let mut validation_indices = Vec::new();
170        let mut test_indices = Vec::new();
171
172        for class_id in class_order {
173            let mut class_indices = indices_by_class
174                .remove(&class_id)
175                .ok_or(ModelError::InvalidSamplingDistribution)?;
176            if shuffle {
177                let class_seed = seed ^ (class_id as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15);
178                shuffle_indices(&mut class_indices, class_seed);
179            }
180
181            let class_len = class_indices.len();
182            let class_train_count = ((class_len as f64) * train_ratio as f64).floor() as usize;
183            let class_remaining = class_len.saturating_sub(class_train_count);
184            let class_validation_count = (((class_len as f64) * validation_ratio as f64).floor()
185                as usize)
186                .min(class_remaining);
187
188            let validation_start = class_train_count;
189            let validation_end = validation_start + class_validation_count;
190            train_indices.extend_from_slice(&class_indices[..class_train_count]);
191            validation_indices.extend_from_slice(&class_indices[validation_start..validation_end]);
192            test_indices.extend_from_slice(&class_indices[validation_end..]);
193        }
194
195        if shuffle {
196            shuffle_indices(&mut train_indices, seed ^ 0xA55A_5AA5_1234_5678);
197            shuffle_indices(&mut validation_indices, seed ^ 0xBEE5_F00D_89AB_CDEF);
198            shuffle_indices(&mut test_indices, seed ^ 0xDEAD_BEEF_CAFE_BABE);
199        }
200
201        let train = self.subset_by_indices(&train_indices)?;
202        let validation = self.subset_by_indices(&validation_indices)?;
203        let test = self.subset_by_indices(&test_indices)?;
204        Ok(DatasetSplit {
205            train,
206            validation,
207            test,
208        })
209    }
210
211    fn subset_by_indices(&self, indices: &[usize]) -> Result<SupervisedDataset, ModelError> {
212        let inputs = gather_rows(self.inputs(), indices)?;
213        let targets = gather_rows(self.targets(), indices)?;
214        SupervisedDataset::new(inputs, targets)
215    }
216}
217
/// One deterministic batch from dataset iterator.
#[derive(Debug, Clone, PartialEq)]
pub struct Batch {
    /// Batch inputs gathered from the dataset (possibly augmented/mixed).
    pub inputs: Tensor,
    /// Targets aligned with `inputs` along axis 0 (after MixUp/CutMix these
    /// are blended — presumably soft labels; confirm against the mix helpers).
    pub targets: Tensor,
}
224
/// Sample-order policy used by `BatchIterOptions`.
///
/// Concrete ordering semantics are implemented by `build_sample_order` in
/// `super::iter`.
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingPolicy {
    /// Visit samples in their stored dataset order.
    Sequential,
    /// Seeded pseudo-random permutation of the sample order.
    Shuffled {
        seed: u64,
    },
    /// Class-balanced sampling (exact behavior defined by
    /// `build_sample_order`).
    BalancedByClass {
        seed: u64,
        with_replacement: bool,
    },
    /// Weighted sampling; `weights` presumably holds one weight per sample —
    /// validity is checked downstream (TODO confirm in `build_sample_order`).
    Weighted {
        weights: Vec<f32>,
        seed: u64,
        with_replacement: bool,
    },
}
242
/// Controls mini-batch order, truncation behavior, and optional per-batch regularization.
///
/// `Default` yields sequential order, keeps a trailing partial batch, and
/// applies no augmentation/MixUp/CutMix.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct BatchIterOptions {
    /// Request a shuffled sample order (consumed by `build_sample_order`;
    /// precedence relative to `sampling` is decided there — TODO confirm).
    pub shuffle: bool,
    /// Seed for the shuffled order.
    pub shuffle_seed: u64,
    /// Discard a trailing batch smaller than the requested batch size.
    pub drop_last: bool,
    /// Optional image augmentation applied per batch via `apply_nhwc`.
    pub augmentation: Option<ImageAugmentationPipeline>,
    /// Base seed mixed with the batch counter for augmentation.
    pub augmentation_seed: u64,
    /// Optional MixUp regularization, applied after augmentation.
    pub mixup: Option<MixUpConfig>,
    /// Base seed mixed with the batch counter for MixUp.
    pub mixup_seed: u64,
    /// Optional CutMix regularization, applied after MixUp.
    pub cutmix: Option<CutMixConfig>,
    /// Base seed mixed with the batch counter for CutMix.
    pub cutmix_seed: u64,
    /// Optional explicit sampling policy; interaction with `shuffle` is
    /// handled in `build_sample_order`.
    pub sampling: Option<SamplingPolicy>,
}
257
/// Deterministic dataset split produced by `split_by_counts` / `split_by_ratio`
/// (and `split_by_class_ratio`); `test` holds every sample not assigned to
/// `train` or `validation`.
#[derive(Debug, Clone, PartialEq)]
pub struct DatasetSplit {
    /// Training subset.
    pub train: SupervisedDataset,
    /// Validation subset.
    pub validation: SupervisedDataset,
    /// Test subset (the remainder).
    pub test: SupervisedDataset,
}
265
/// Deterministic sequential mini-batch iterator.
///
/// Created by `SupervisedDataset::batches` / `batches_with_options`; borrows
/// the dataset for its lifetime. Fields are `pub(super)` so sibling modules
/// of the dataset module can construct/inspect it.
#[derive(Debug)]
pub struct MiniBatchIter<'a> {
    /// Dataset whose rows are gathered into batches.
    pub(super) dataset: &'a SupervisedDataset,
    /// Requested samples per batch (> 0; validated at construction).
    pub(super) batch_size: usize,
    /// Position of the next unread sample within `order`.
    pub(super) cursor: usize,
    /// Precomputed sample visitation order (from `build_sample_order`).
    pub(super) order: Vec<usize>,
    /// When set, a trailing batch smaller than `batch_size` is discarded.
    pub(super) drop_last: bool,
    /// Optional per-batch image augmentation.
    pub(super) augmentation: Option<ImageAugmentationPipeline>,
    /// Base seed mixed with the batch counter for augmentation.
    pub(super) augmentation_seed: u64,
    /// Optional MixUp configuration.
    pub(super) mixup: Option<MixUpConfig>,
    /// Base seed mixed with the batch counter for MixUp.
    pub(super) mixup_seed: u64,
    /// Optional CutMix configuration.
    pub(super) cutmix: Option<CutMixConfig>,
    /// Base seed mixed with the batch counter for CutMix.
    pub(super) cutmix_seed: u64,
    /// Batches yielded so far; drives per-batch seed derivation.
    pub(super) emitted_batches: usize,
}
282
283impl Iterator for MiniBatchIter<'_> {
284    type Item = Batch;
285
286    fn next(&mut self) -> Option<Self::Item> {
287        if self.cursor >= self.order.len() {
288            return None;
289        }
290        let start = self.cursor;
291        let end = (self.cursor + self.batch_size).min(self.order.len());
292        if self.drop_last && (end - start) < self.batch_size {
293            self.cursor = self.order.len();
294            return None;
295        }
296        self.cursor = end;
297
298        let batch_indices = &self.order[start..end];
299        let mut inputs = gather_rows(&self.dataset.inputs, batch_indices).ok()?;
300        let mut targets = gather_rows(&self.dataset.targets, batch_indices).ok()?;
301        if let Some(pipeline) = self.augmentation.as_ref() {
302            let seed = self.augmentation_seed
303                ^ (self.emitted_batches as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15);
304            inputs = pipeline.apply_nhwc(&inputs, seed).ok()?;
305        }
306        if let Some(mixup) = self.mixup.as_ref() {
307            let seed =
308                self.mixup_seed ^ (self.emitted_batches as u64).wrapping_mul(0xD134_2543_DE82_EF95);
309            let mixed = apply_mixup_batch(&inputs, &targets, mixup, seed).ok()?;
310            inputs = mixed.inputs;
311            targets = mixed.targets;
312        }
313        if let Some(cutmix) = self.cutmix.as_ref() {
314            let seed = self.cutmix_seed
315                ^ (self.emitted_batches as u64).wrapping_mul(0x94D0_49BB_1331_11EB);
316            let mixed = apply_cutmix_batch(&inputs, &targets, cutmix, seed).ok()?;
317            inputs = mixed.inputs;
318            targets = mixed.targets;
319        }
320        self.emitted_batches = self.emitted_batches.saturating_add(1);
321        Some(Batch { inputs, targets })
322    }
323}