1use crate::dataset::{Dataset, Sample};
4use crate::transform::Transform;
5
/// A view over `inner` restricted to the samples selected by `indices`.
///
/// Index `i` of the subset maps to sample `indices[i]` of the inner dataset.
pub struct SubsetDataset<D: Dataset> {
    // Wrapped dataset; owned so the subset is self-contained.
    inner: D,
    // Positions into `inner`; nothing prevents repeats or reordering.
    indices: Vec<usize>,
}
15
16impl<D: Dataset> SubsetDataset<D> {
17 pub fn new(inner: D, indices: Vec<usize>) -> Self {
22 Self { inner, indices }
23 }
24}
25
26impl<D: Dataset> Dataset for SubsetDataset<D> {
27 fn len(&self) -> usize {
28 self.indices.len()
29 }
30
31 fn get(&self, index: usize) -> Sample {
32 self.inner.get(self.indices[index])
33 }
34
35 fn feature_shape(&self) -> &[usize] {
36 self.inner.feature_shape()
37 }
38
39 fn target_shape(&self) -> &[usize] {
40 self.inner.target_shape()
41 }
42
43 fn name(&self) -> &str {
44 self.inner.name()
45 }
46}
47
/// Chains several boxed datasets into one, indexed end to end.
///
/// `cumulative_sizes[i]` is the total sample count of datasets `0..=i`,
/// letting lookups map a global index to a (dataset, local index) pair.
pub struct ConcatDataset {
    datasets: Vec<Box<dyn Dataset>>,
    // Running totals of dataset lengths; monotonically non-decreasing.
    cumulative_sizes: Vec<usize>,
    // Shapes copied from the first dataset at construction time.
    feature_shape: Vec<usize>,
    target_shape: Vec<usize>,
}
59
60impl ConcatDataset {
61 pub fn new(datasets: Vec<Box<dyn Dataset>>) -> Self {
66 assert!(
67 !datasets.is_empty(),
68 "ConcatDataset: need at least one dataset"
69 );
70
71 let feature_shape = datasets[0].feature_shape().to_vec();
72 let target_shape = datasets[0].target_shape().to_vec();
73
74 let mut cumulative_sizes = Vec::with_capacity(datasets.len());
75 let mut total = 0;
76 for ds in &datasets {
77 total += ds.len();
78 cumulative_sizes.push(total);
79 }
80
81 Self {
82 datasets,
83 cumulative_sizes,
84 feature_shape,
85 target_shape,
86 }
87 }
88
89 fn locate(&self, index: usize) -> (usize, usize) {
91 for (ds_idx, &cum) in self.cumulative_sizes.iter().enumerate() {
92 if index < cum {
93 let offset = if ds_idx == 0 {
94 0
95 } else {
96 self.cumulative_sizes[ds_idx - 1]
97 };
98 return (ds_idx, index - offset);
99 }
100 }
101 panic!(
102 "ConcatDataset: index {} out of range (total {})",
103 index,
104 self.cumulative_sizes.last().unwrap_or(&0)
105 );
106 }
107}
108
109impl Dataset for ConcatDataset {
110 fn len(&self) -> usize {
111 *self.cumulative_sizes.last().unwrap_or(&0)
112 }
113
114 fn get(&self, index: usize) -> Sample {
115 let (ds_idx, local_idx) = self.locate(index);
116 self.datasets[ds_idx].get(local_idx)
117 }
118
119 fn feature_shape(&self) -> &[usize] {
120 &self.feature_shape
121 }
122
123 fn target_shape(&self) -> &[usize] {
124 &self.target_shape
125 }
126
127 fn name(&self) -> &str {
128 "concat"
129 }
130}
131
/// Lazily applies a `Transform` to every sample of `inner`.
pub struct MapDataset<D: Dataset> {
    inner: D,
    // Applied on every `get`; results are not cached.
    transform: Box<dyn Transform>,
    // Shapes reported after the transform; supplied by the caller because
    // a transform may change sample shapes.
    feat_shape: Vec<usize>,
    tgt_shape: Vec<usize>,
}
143
144impl<D: Dataset> MapDataset<D> {
145 pub fn new(
151 inner: D,
152 transform: Box<dyn Transform>,
153 feat_shape: Vec<usize>,
154 tgt_shape: Vec<usize>,
155 ) -> Self {
156 Self {
157 inner,
158 transform,
159 feat_shape,
160 tgt_shape,
161 }
162 }
163
164 pub fn same_shape(inner: D, transform: Box<dyn Transform>) -> Self {
166 let feat_shape = inner.feature_shape().to_vec();
167 let tgt_shape = inner.target_shape().to_vec();
168 Self {
169 inner,
170 transform,
171 feat_shape,
172 tgt_shape,
173 }
174 }
175}
176
177impl<D: Dataset> Dataset for MapDataset<D> {
178 fn len(&self) -> usize {
179 self.inner.len()
180 }
181
182 fn get(&self, index: usize) -> Sample {
183 let sample = self.inner.get(index);
184 self.transform.apply(sample)
185 }
186
187 fn feature_shape(&self) -> &[usize] {
188 &self.feat_shape
189 }
190
191 fn target_shape(&self) -> &[usize] {
192 &self.tgt_shape
193 }
194
195 fn name(&self) -> &str {
196 self.inner.name()
197 }
198}
199
/// An in-memory dataset backed by a `Vec` of pre-built samples.
pub struct VecDataset {
    samples: Vec<Sample>,
    // Shapes copied from the first sample (or the flat-buffer shapes)
    // at construction time.
    feature_shape: Vec<usize>,
    target_shape: Vec<usize>,
    // Caller-chosen name reported by `Dataset::name`.
    dataset_name: String,
}
211
212impl VecDataset {
213 pub fn new(samples: Vec<Sample>, name: &str) -> Self {
218 assert!(!samples.is_empty(), "VecDataset: need at least one sample");
219 let feature_shape = samples[0].feature_shape.clone();
220 let target_shape = samples[0].target_shape.clone();
221 Self {
222 samples,
223 feature_shape,
224 target_shape,
225 dataset_name: name.to_string(),
226 }
227 }
228
229 pub fn from_flat(
234 features: &[f64],
235 feature_shape: &[usize],
236 targets: &[f64],
237 target_shape: &[usize],
238 name: &str,
239 ) -> Self {
240 let feat_per_sample: usize = feature_shape.iter().product();
241 let tgt_per_sample: usize = target_shape.iter().product();
242 let n = features.len() / feat_per_sample;
243 assert_eq!(features.len(), n * feat_per_sample);
244 assert_eq!(targets.len(), n * tgt_per_sample);
245
246 let samples: Vec<Sample> = (0..n)
247 .map(|i| Sample {
248 features: features[i * feat_per_sample..(i + 1) * feat_per_sample].to_vec(),
249 feature_shape: feature_shape.to_vec(),
250 target: targets[i * tgt_per_sample..(i + 1) * tgt_per_sample].to_vec(),
251 target_shape: target_shape.to_vec(),
252 })
253 .collect();
254
255 Self {
256 samples,
257 feature_shape: feature_shape.to_vec(),
258 target_shape: target_shape.to_vec(),
259 dataset_name: name.to_string(),
260 }
261 }
262}
263
264impl Dataset for VecDataset {
265 fn len(&self) -> usize {
266 self.samples.len()
267 }
268
269 fn get(&self, index: usize) -> Sample {
270 self.samples[index].clone()
271 }
272
273 fn feature_shape(&self) -> &[usize] {
274 &self.feature_shape
275 }
276
277 fn target_shape(&self) -> &[usize] {
278 &self.target_shape
279 }
280
281 fn name(&self) -> &str {
282 &self.dataset_name
283 }
284}
285
286pub fn train_test_split<D>(dataset: D, ratios: &[f64], seed: u64) -> Vec<SubsetDataset<D>>
298where
299 D: Dataset + Clone,
300{
301 use rand::rngs::StdRng;
302 use rand::seq::SliceRandom;
303 use rand::SeedableRng;
304
305 assert!(
306 ratios.len() >= 2 && ratios.len() <= 3,
307 "train_test_split: ratios must have 2 or 3 elements"
308 );
309 let sum: f64 = ratios.iter().sum();
310 assert!(
311 (sum - 1.0).abs() < 1e-6,
312 "train_test_split: ratios must sum to 1.0, got {}",
313 sum
314 );
315
316 let n = dataset.len();
317 let mut indices: Vec<usize> = (0..n).collect();
318 let mut rng = StdRng::seed_from_u64(seed);
319 indices.shuffle(&mut rng);
320
321 let mut splits = Vec::new();
322 let mut offset = 0;
323 for (i, &ratio) in ratios.iter().enumerate() {
324 let count = if i == ratios.len() - 1 {
325 n - offset } else {
327 (n as f64 * ratio).round() as usize
328 };
329 let end = (offset + count).min(n);
330 splits.push(SubsetDataset::new(
331 dataset.clone(),
332 indices[offset..end].to_vec(),
333 ));
334 offset = end;
335 }
336
337 splits
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal in-memory dataset: feature = index, target = index % 3.
    #[derive(Clone)]
    struct TinyDataset {
        n: usize,
    }

    impl Dataset for TinyDataset {
        fn len(&self) -> usize {
            self.n
        }
        fn get(&self, idx: usize) -> Sample {
            Sample {
                features: vec![idx as f64],
                feature_shape: vec![1],
                target: vec![(idx % 3) as f64],
                target_shape: vec![1],
            }
        }
        fn feature_shape(&self) -> &[usize] {
            &[1]
        }
        fn target_shape(&self) -> &[usize] {
            &[1]
        }
    }

    #[test]
    fn subset_dataset() {
        let base = TinyDataset { n: 10 };
        let subset = SubsetDataset::new(base, vec![2, 5, 7]);
        assert_eq!(subset.len(), 3);
        // Subset positions map through the index table.
        assert_eq!(subset.get(0).features[0], 2.0);
        assert_eq!(subset.get(1).features[0], 5.0);
        assert_eq!(subset.get(2).features[0], 7.0);
    }

    #[test]
    fn concat_dataset() {
        let first = TinyDataset { n: 5 };
        let second = TinyDataset { n: 3 };
        let concat = ConcatDataset::new(vec![Box::new(first), Box::new(second)]);
        assert_eq!(concat.len(), 8);
        // Indices 0..5 hit the first dataset, 5..8 the second.
        assert_eq!(concat.get(0).features[0], 0.0);
        assert_eq!(concat.get(4).features[0], 4.0);
        assert_eq!(concat.get(5).features[0], 0.0);
        assert_eq!(concat.get(7).features[0], 2.0);
    }

    #[test]
    fn map_dataset() {
        use crate::transform::Normalize;
        let base = TinyDataset { n: 4 };
        let mapped = MapDataset::same_shape(base, Box::new(Normalize::new(10.0)));
        assert_eq!(mapped.len(), 4);
        let sample = mapped.get(2);
        assert!((sample.features[0] - 0.2).abs() < 1e-10);
    }

    #[test]
    fn vec_dataset() {
        let features = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let targets = vec![0.0, 1.0, 0.0];
        let ds = VecDataset::from_flat(&features, &[2], &targets, &[1], "test");
        assert_eq!(ds.len(), 3);
        assert_eq!(ds.get(0).features, vec![1.0, 2.0]);
        assert_eq!(ds.get(1).features, vec![3.0, 4.0]);
        assert_eq!(ds.get(2).target, vec![0.0]);
    }

    #[test]
    fn train_test_split_two_way() {
        let base = TinyDataset { n: 100 };
        let splits = train_test_split(base, &[0.8, 0.2], 42);
        assert_eq!(splits.len(), 2);
        assert_eq!(splits[0].len() + splits[1].len(), 100);
        assert_eq!(splits[0].len(), 80);
        assert_eq!(splits[1].len(), 20);
    }

    #[test]
    fn train_test_split_three_way() {
        let base = TinyDataset { n: 100 };
        let splits = train_test_split(base, &[0.7, 0.15, 0.15], 42);
        assert_eq!(splits.len(), 3);
        assert_eq!(splits[0].len() + splits[1].len() + splits[2].len(), 100);
    }

    #[test]
    fn train_test_split_reproducible() {
        let a = TinyDataset { n: 50 };
        let b = TinyDataset { n: 50 };
        let split_a = train_test_split(a, &[0.8, 0.2], 123);
        let split_b = train_test_split(b, &[0.8, 0.2], 123);
        // Same seed must yield the same shuffled assignment.
        for i in 0..split_a[0].len() {
            assert_eq!(split_a[0].get(i).features, split_b[0].get(i).features);
        }
    }
}