1use std::collections::HashMap;
2use yscv_tensor::Tensor;
3
4use crate::{ImageAugmentationPipeline, ModelError};
5
6use super::helpers::{
7 class_labels_from_targets, gather_rows, shuffle_indices, validate_split_ratios,
8};
9use super::iter::{
10 CutMixConfig, MixUpConfig, apply_cutmix_batch, apply_mixup_batch, build_sample_order,
11 validate_augmentation_compatibility, validate_cutmix_compatibility, validate_cutmix_config,
12 validate_mixup_config,
13};
14
/// In-memory supervised dataset: paired input and target tensors whose
/// leading dimension indexes samples. `new` enforces that both tensors are
/// non-scalar and agree on the sample count.
#[derive(Debug, Clone, PartialEq)]
pub struct SupervisedDataset {
    // Rank >= 1; shape()[0] is the number of samples (checked in `new`).
    inputs: Tensor,
    // Rank >= 1; leading dimension matches `inputs` (checked in `new`).
    targets: Tensor,
}
21
22impl SupervisedDataset {
23 pub fn new(inputs: Tensor, targets: Tensor) -> Result<Self, ModelError> {
24 if inputs.rank() == 0 || targets.rank() == 0 {
25 return Err(ModelError::InvalidDatasetRank {
26 inputs_rank: inputs.rank(),
27 targets_rank: targets.rank(),
28 });
29 }
30 if inputs.shape()[0] != targets.shape()[0] {
31 return Err(ModelError::DatasetShapeMismatch {
32 inputs: inputs.shape().to_vec(),
33 targets: targets.shape().to_vec(),
34 });
35 }
36 Ok(Self { inputs, targets })
37 }
38
39 pub fn len(&self) -> usize {
40 self.inputs.shape()[0]
41 }
42
43 pub fn is_empty(&self) -> bool {
44 self.len() == 0
45 }
46
47 pub fn inputs(&self) -> &Tensor {
48 &self.inputs
49 }
50
51 pub fn targets(&self) -> &Tensor {
52 &self.targets
53 }
54
55 pub fn batches(&self, batch_size: usize) -> Result<MiniBatchIter<'_>, ModelError> {
56 self.batches_with_options(batch_size, BatchIterOptions::default())
57 }
58
59 pub fn batches_with_options(
60 &self,
61 batch_size: usize,
62 options: BatchIterOptions,
63 ) -> Result<MiniBatchIter<'_>, ModelError> {
64 if batch_size == 0 {
65 return Err(ModelError::InvalidBatchSize { batch_size });
66 }
67 if let Some(pipeline) = options.augmentation.as_ref() {
68 validate_augmentation_compatibility(self.inputs(), pipeline)?;
69 }
70 if let Some(mixup) = options.mixup.as_ref() {
71 validate_mixup_config(mixup)?;
72 }
73 if let Some(cutmix) = options.cutmix.as_ref() {
74 validate_cutmix_config(cutmix)?;
75 validate_cutmix_compatibility(self.inputs())?;
76 }
77
78 let order = build_sample_order(self, &options)?;
79
80 Ok(MiniBatchIter {
81 dataset: self,
82 batch_size,
83 cursor: 0,
84 order,
85 drop_last: options.drop_last,
86 augmentation: options.augmentation,
87 augmentation_seed: options.augmentation_seed,
88 mixup: options.mixup,
89 mixup_seed: options.mixup_seed,
90 cutmix: options.cutmix,
91 cutmix_seed: options.cutmix_seed,
92 emitted_batches: 0,
93 })
94 }
95
96 pub fn split_by_counts(
97 &self,
98 train_count: usize,
99 validation_count: usize,
100 shuffle: bool,
101 seed: u64,
102 ) -> Result<DatasetSplit, ModelError> {
103 if train_count
104 .checked_add(validation_count)
105 .is_none_or(|sum| sum > self.len())
106 {
107 return Err(ModelError::InvalidSplitCounts {
108 train_count,
109 validation_count,
110 dataset_len: self.len(),
111 });
112 }
113
114 let mut order = (0..self.len()).collect::<Vec<_>>();
115 if shuffle {
116 shuffle_indices(&mut order, seed);
117 }
118
119 let train = self.subset_by_indices(&order[..train_count])?;
120 let validation_end = train_count + validation_count;
121 let validation = self.subset_by_indices(&order[train_count..validation_end])?;
122 let test = self.subset_by_indices(&order[validation_end..])?;
123 Ok(DatasetSplit {
124 train,
125 validation,
126 test,
127 })
128 }
129
130 pub fn split_by_ratio(
131 &self,
132 train_ratio: f32,
133 validation_ratio: f32,
134 shuffle: bool,
135 seed: u64,
136 ) -> Result<DatasetSplit, ModelError> {
137 validate_split_ratios(train_ratio, validation_ratio)?;
138
139 let len = self.len();
140 let train_count = ((len as f64) * train_ratio as f64).floor() as usize;
141 let remaining = len.saturating_sub(train_count);
142 let validation_count =
143 (((len as f64) * validation_ratio as f64).floor() as usize).min(remaining);
144 self.split_by_counts(train_count, validation_count, shuffle, seed)
145 }
146
147 pub fn split_by_class_ratio(
148 &self,
149 train_ratio: f32,
150 validation_ratio: f32,
151 shuffle: bool,
152 seed: u64,
153 ) -> Result<DatasetSplit, ModelError> {
154 validate_split_ratios(train_ratio, validation_ratio)?;
155
156 let class_labels = class_labels_from_targets(self.targets())?;
157 let mut indices_by_class = HashMap::<usize, Vec<usize>>::new();
158 for (sample_index, class_id) in class_labels.into_iter().enumerate() {
159 indices_by_class
160 .entry(class_id)
161 .or_default()
162 .push(sample_index);
163 }
164
165 let mut class_order = indices_by_class.keys().copied().collect::<Vec<_>>();
166 class_order.sort_unstable();
167
168 let mut train_indices = Vec::new();
169 let mut validation_indices = Vec::new();
170 let mut test_indices = Vec::new();
171
172 for class_id in class_order {
173 let mut class_indices = indices_by_class
174 .remove(&class_id)
175 .ok_or(ModelError::InvalidSamplingDistribution)?;
176 if shuffle {
177 let class_seed = seed ^ (class_id as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15);
178 shuffle_indices(&mut class_indices, class_seed);
179 }
180
181 let class_len = class_indices.len();
182 let class_train_count = ((class_len as f64) * train_ratio as f64).floor() as usize;
183 let class_remaining = class_len.saturating_sub(class_train_count);
184 let class_validation_count = (((class_len as f64) * validation_ratio as f64).floor()
185 as usize)
186 .min(class_remaining);
187
188 let validation_start = class_train_count;
189 let validation_end = validation_start + class_validation_count;
190 train_indices.extend_from_slice(&class_indices[..class_train_count]);
191 validation_indices.extend_from_slice(&class_indices[validation_start..validation_end]);
192 test_indices.extend_from_slice(&class_indices[validation_end..]);
193 }
194
195 if shuffle {
196 shuffle_indices(&mut train_indices, seed ^ 0xA55A_5AA5_1234_5678);
197 shuffle_indices(&mut validation_indices, seed ^ 0xBEE5_F00D_89AB_CDEF);
198 shuffle_indices(&mut test_indices, seed ^ 0xDEAD_BEEF_CAFE_BABE);
199 }
200
201 let train = self.subset_by_indices(&train_indices)?;
202 let validation = self.subset_by_indices(&validation_indices)?;
203 let test = self.subset_by_indices(&test_indices)?;
204 Ok(DatasetSplit {
205 train,
206 validation,
207 test,
208 })
209 }
210
211 fn subset_by_indices(&self, indices: &[usize]) -> Result<SupervisedDataset, ModelError> {
212 let inputs = gather_rows(self.inputs(), indices)?;
213 let targets = gather_rows(self.targets(), indices)?;
214 SupervisedDataset::new(inputs, targets)
215 }
216}
217
/// One mini-batch of paired tensors as yielded by `MiniBatchIter`.
#[derive(Debug, Clone, PartialEq)]
pub struct Batch {
    /// Input rows gathered from the dataset, possibly transformed by the
    /// augmentation pipeline and/or MixUp/CutMix.
    pub inputs: Tensor,
    /// Targets aligned row-for-row with `inputs`; MixUp/CutMix replace them
    /// with mixed targets when configured.
    pub targets: Tensor,
}
224
/// How the per-epoch sample order should be drawn.
///
/// NOTE(review): the concrete semantics are implemented in
/// `super::iter::build_sample_order` (not visible here); the per-variant notes
/// below are inferred from the names — confirm against that helper.
#[derive(Debug, Clone, PartialEq)]
pub enum SamplingPolicy {
    /// Presumably visits samples in dataset order.
    Sequential,
    /// Presumably a seeded pseudo-random permutation of the dataset.
    Shuffled {
        seed: u64,
    },
    /// Presumably samples classes evenly, optionally with replacement.
    BalancedByClass {
        seed: u64,
        with_replacement: bool,
    },
    /// Presumably draws samples according to per-sample `weights`.
    Weighted {
        weights: Vec<f32>,
        seed: u64,
        with_replacement: bool,
    },
}
242
/// Options accepted by `SupervisedDataset::batches_with_options`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct BatchIterOptions {
    // Shuffle flag and seed are consumed by `build_sample_order` (see
    // `super::iter`) when computing the epoch's sample order.
    pub shuffle: bool,
    pub shuffle_seed: u64,
    // When true, a final batch smaller than `batch_size` is dropped
    // (see `MiniBatchIter::next`).
    pub drop_last: bool,
    // Optional image augmentation; validated against the inputs and applied
    // per batch via `apply_nhwc`, so inputs are treated as NHWC images here.
    pub augmentation: Option<ImageAugmentationPipeline>,
    pub augmentation_seed: u64,
    // Optional MixUp; config is validated before iteration starts.
    pub mixup: Option<MixUpConfig>,
    pub mixup_seed: u64,
    // Optional CutMix; both the config and input compatibility are validated.
    pub cutmix: Option<CutMixConfig>,
    pub cutmix_seed: u64,
    // Optional sampling policy; interpreted by `build_sample_order`.
    pub sampling: Option<SamplingPolicy>,
}
257
/// Result of the `split_by_*` methods: three disjoint subsets of the
/// original dataset.
#[derive(Debug, Clone, PartialEq)]
pub struct DatasetSplit {
    pub train: SupervisedDataset,
    pub validation: SupervisedDataset,
    /// Receives every sample not assigned to train or validation.
    pub test: SupervisedDataset,
}
265
/// Iterator over mini-batches of a `SupervisedDataset`, constructed by
/// `batches_with_options`. Applies (in order) augmentation, MixUp, and
/// CutMix to each emitted batch when configured.
#[derive(Debug)]
pub struct MiniBatchIter<'a> {
    pub(super) dataset: &'a SupervisedDataset,
    pub(super) batch_size: usize,
    // Position of the next unconsumed entry in `order`.
    pub(super) cursor: usize,
    // Precomputed sample visiting order for the whole epoch.
    pub(super) order: Vec<usize>,
    // When true, a trailing batch shorter than `batch_size` is not emitted.
    pub(super) drop_last: bool,
    pub(super) augmentation: Option<ImageAugmentationPipeline>,
    pub(super) augmentation_seed: u64,
    pub(super) mixup: Option<MixUpConfig>,
    pub(super) mixup_seed: u64,
    pub(super) cutmix: Option<CutMixConfig>,
    pub(super) cutmix_seed: u64,
    // Count of batches emitted so far; mixed into per-batch transform seeds.
    pub(super) emitted_batches: usize,
}
282
283impl Iterator for MiniBatchIter<'_> {
284 type Item = Batch;
285
286 fn next(&mut self) -> Option<Self::Item> {
287 if self.cursor >= self.order.len() {
288 return None;
289 }
290 let start = self.cursor;
291 let end = (self.cursor + self.batch_size).min(self.order.len());
292 if self.drop_last && (end - start) < self.batch_size {
293 self.cursor = self.order.len();
294 return None;
295 }
296 self.cursor = end;
297
298 let batch_indices = &self.order[start..end];
299 let mut inputs = gather_rows(&self.dataset.inputs, batch_indices).ok()?;
300 let mut targets = gather_rows(&self.dataset.targets, batch_indices).ok()?;
301 if let Some(pipeline) = self.augmentation.as_ref() {
302 let seed = self.augmentation_seed
303 ^ (self.emitted_batches as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15);
304 inputs = pipeline.apply_nhwc(&inputs, seed).ok()?;
305 }
306 if let Some(mixup) = self.mixup.as_ref() {
307 let seed =
308 self.mixup_seed ^ (self.emitted_batches as u64).wrapping_mul(0xD134_2543_DE82_EF95);
309 let mixed = apply_mixup_batch(&inputs, &targets, mixup, seed).ok()?;
310 inputs = mixed.inputs;
311 targets = mixed.targets;
312 }
313 if let Some(cutmix) = self.cutmix.as_ref() {
314 let seed = self.cutmix_seed
315 ^ (self.emitted_batches as u64).wrapping_mul(0x94D0_49BB_1331_11EB);
316 let mixed = apply_cutmix_batch(&inputs, &targets, cutmix, seed).ok()?;
317 inputs = mixed.inputs;
318 targets = mixed.targets;
319 }
320 self.emitted_batches = self.emitted_batches.saturating_add(1);
321 Some(Batch { inputs, targets })
322 }
323}