1use crate::{error_taxonomy::helpers as error_helpers, Dataset};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tenflowers_core::{Result, Tensor};
10
/// Configuration describing how a dataset is partitioned across distributed workers.
#[derive(Debug, Clone)]
pub struct ShardConfig {
    /// Total number of workers participating in the shard; must be > 0.
    pub world_size: usize,
    /// Zero-based index of this worker; must be strictly less than `world_size`.
    pub rank: usize,
    /// Strategy used to assign sample indices to workers.
    pub strategy: ShardStrategy,
    /// Optional seed for deterministic shuffling (used by the shuffled and
    /// stratified strategies).
    pub seed: Option<u64>,
    /// Whether to drop trailing samples that do not divide evenly across workers.
    /// NOTE(review): not consulted by the sharding logic visible in this file —
    /// confirm intended use.
    pub drop_last: bool,
    /// Number of replicas; validated to be > 0 by `validate`.
    /// NOTE(review): not consulted by the sharding logic visible in this file —
    /// confirm intended use.
    pub num_replicas: usize,
}
27
/// How sample indices are distributed across workers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShardStrategy {
    /// Worker `rank` takes every `world_size`-th sample starting at `rank`.
    RoundRobin,
    /// Each worker takes one contiguous slice of the dataset.
    Contiguous,
    /// Deterministically shuffle all indices (when a seed is set), then assign
    /// round-robin.
    ShuffledRoundRobin,
    /// Balance class labels across workers. Requires a label extractor
    /// (`ShardedDataset::new_stratified`); the plain index computation falls
    /// back to round-robin for this variant.
    Stratified,
}
40
41impl ShardConfig {
42 pub fn new(world_size: usize, rank: usize) -> Result<Self> {
44 if world_size == 0 {
45 return Err(error_helpers::invalid_configuration(
46 "ShardConfig::new",
47 "world_size",
48 "world_size must be > 0",
49 ));
50 }
51
52 if rank >= world_size {
53 return Err(error_helpers::invalid_configuration(
54 "ShardConfig::new",
55 "rank",
56 format!("rank {} must be < world_size {}", rank, world_size),
57 ));
58 }
59
60 Ok(Self {
61 world_size,
62 rank,
63 strategy: ShardStrategy::RoundRobin,
64 seed: None,
65 drop_last: false,
66 num_replicas: 1,
67 })
68 }
69
70 pub fn with_strategy(mut self, strategy: ShardStrategy) -> Self {
72 self.strategy = strategy;
73 self
74 }
75
76 pub fn with_seed(mut self, seed: u64) -> Self {
78 self.seed = Some(seed);
79 self
80 }
81
82 pub fn with_drop_last(mut self, drop_last: bool) -> Self {
84 self.drop_last = drop_last;
85 self
86 }
87
88 pub fn with_num_replicas(mut self, num_replicas: usize) -> Self {
90 self.num_replicas = num_replicas;
91 self
92 }
93
94 pub fn validate(&self) -> Result<()> {
96 if self.world_size == 0 {
97 return Err(error_helpers::invalid_configuration(
98 "ShardConfig::validate",
99 "world_size",
100 "world_size must be > 0",
101 ));
102 }
103
104 if self.rank >= self.world_size {
105 return Err(error_helpers::invalid_configuration(
106 "ShardConfig::validate",
107 "rank",
108 format!(
109 "rank {} must be < world_size {}",
110 self.rank, self.world_size
111 ),
112 ));
113 }
114
115 if self.num_replicas == 0 {
116 return Err(error_helpers::invalid_configuration(
117 "ShardConfig::validate",
118 "num_replicas",
119 "num_replicas must be > 0",
120 ));
121 }
122
123 Ok(())
124 }
125}
126
/// Datasets that know how to compute their own shard membership.
pub trait ShardableDataset<T>: Dataset<T> {
    /// Returns the global sample indices belonging to the shard described by
    /// `config`.
    fn get_shard_indices(&self, config: &ShardConfig) -> Result<Vec<usize>>;

    /// Number of shards; defaults to one shard per worker.
    fn num_shards(&self, config: &ShardConfig) -> usize {
        config.world_size
    }

    /// Number of samples in this worker's shard.
    ///
    /// NOTE(review): an error from `get_shard_indices` is silently treated as
    /// an empty shard here — confirm that is the intended behavior.
    fn shard_size(&self, config: &ShardConfig) -> usize {
        let indices = self.get_shard_indices(config).unwrap_or_default();
        indices.len()
    }
}
143
/// A view over a dataset restricted to the sample indices owned by one worker.
pub struct ShardedDataset<T, D: Dataset<T>> {
    /// The shared underlying dataset.
    dataset: Arc<D>,
    /// The (validated) sharding configuration this view was built from.
    config: ShardConfig,
    /// Precomputed global indices belonging to `config.rank`'s shard;
    /// local index `i` maps to `indices[i]` in the underlying dataset.
    indices: Vec<usize>,
    /// Marker tying the element type `T` to this view without storing a `T`.
    _phantom: std::marker::PhantomData<T>,
}
151
152impl<T, D: Dataset<T>> ShardedDataset<T, D> {
153 pub fn new(dataset: D, config: ShardConfig) -> Result<Self> {
155 config.validate()?;
156
157 let dataset = Arc::new(dataset);
158 let indices = Self::compute_indices(&dataset, &config)?;
159
160 Ok(Self {
161 dataset,
162 config,
163 indices,
164 _phantom: std::marker::PhantomData,
165 })
166 }
167
168 pub fn new_stratified<F>(dataset: D, config: ShardConfig, label_extractor: F) -> Result<Self>
171 where
172 F: Fn(&Tensor<T>) -> Result<usize>,
173 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
174 {
175 config.validate()?;
176
177 let dataset = Arc::new(dataset);
178 let indices = Self::compute_stratified_indices(&dataset, &config, label_extractor)?;
179
180 Ok(Self {
181 dataset,
182 config,
183 indices,
184 _phantom: std::marker::PhantomData,
185 })
186 }
187
188 fn compute_indices(dataset: &D, config: &ShardConfig) -> Result<Vec<usize>> {
190 let total_size = dataset.len();
191
192 if total_size == 0 {
193 return Ok(Vec::new());
194 }
195
196 let mut all_indices: Vec<usize> = (0..total_size).collect();
197
198 match &config.strategy {
200 ShardStrategy::RoundRobin => {
201 }
203 ShardStrategy::Contiguous => {
204 }
206 ShardStrategy::ShuffledRoundRobin => {
207 if let Some(seed) = config.seed {
209 Self::deterministic_shuffle(&mut all_indices, seed);
210 }
211 }
212 ShardStrategy::Stratified => {
213 }
217 }
218
219 let shard_indices = match &config.strategy {
221 ShardStrategy::RoundRobin | ShardStrategy::ShuffledRoundRobin => {
222 all_indices
224 .iter()
225 .enumerate()
226 .filter(|(i, _)| i % config.world_size == config.rank)
227 .map(|(_, &idx)| idx)
228 .collect()
229 }
230 ShardStrategy::Contiguous => {
231 let samples_per_worker = total_size / config.world_size;
233 let extra_samples = total_size % config.world_size;
234
235 let start = if config.rank < extra_samples {
236 config.rank * (samples_per_worker + 1)
237 } else {
238 config.rank * samples_per_worker + extra_samples
239 };
240
241 let count = if config.rank < extra_samples {
242 samples_per_worker + 1
243 } else {
244 samples_per_worker
245 };
246
247 all_indices[start..start + count].to_vec()
248 }
249 ShardStrategy::Stratified => {
250 all_indices
252 .iter()
253 .enumerate()
254 .filter(|(i, _)| i % config.world_size == config.rank)
255 .map(|(_, &idx)| idx)
256 .collect()
257 }
258 };
259
260 Ok(shard_indices)
261 }
262
263 fn deterministic_shuffle(indices: &mut [usize], seed: u64) {
265 let mut rng_state = seed;
266
267 for i in (1..indices.len()).rev() {
268 rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
270 let j = (rng_state as usize) % (i + 1);
271 indices.swap(i, j);
272 }
273 }
274
275 fn compute_stratified_indices<F>(
277 dataset: &D,
278 config: &ShardConfig,
279 label_extractor: F,
280 ) -> Result<Vec<usize>>
281 where
282 F: Fn(&Tensor<T>) -> Result<usize>,
283 T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
284 {
285 let total_size = dataset.len();
286
287 if total_size == 0 {
288 return Ok(Vec::new());
289 }
290
291 let mut class_to_indices: HashMap<usize, Vec<usize>> = HashMap::new();
293
294 for i in 0..total_size {
295 let (_, label_tensor) = dataset.get(i)?;
296 let class = label_extractor(&label_tensor)?;
297 class_to_indices.entry(class).or_default().push(i);
298 }
299
300 let mut worker_indices: Vec<Vec<usize>> = vec![Vec::new(); config.world_size];
302
303 let mut classes: Vec<_> = class_to_indices.keys().cloned().collect();
305 classes.sort_unstable();
306
307 for class in classes {
308 let mut indices = class_to_indices
309 .remove(&class)
310 .expect("class should exist in map since we got it from keys()");
311
312 if let Some(seed) = config.seed {
314 Self::deterministic_shuffle(&mut indices, seed.wrapping_add(class as u64));
316 }
317
318 for (idx_pos, &global_idx) in indices.iter().enumerate() {
320 let worker_id = idx_pos % config.world_size;
321 worker_indices[worker_id].push(global_idx);
322 }
323 }
324
325 let mut shard_indices = worker_indices[config.rank].clone();
327
328 if let Some(seed) = config.seed {
330 Self::deterministic_shuffle(&mut shard_indices, seed.wrapping_add(config.rank as u64));
331 }
332
333 Ok(shard_indices)
334 }
335
336 pub fn inner(&self) -> &D {
338 &self.dataset
339 }
340
341 pub fn config(&self) -> &ShardConfig {
343 &self.config
344 }
345
346 pub fn indices(&self) -> &[usize] {
348 &self.indices
349 }
350
351 pub fn shard_stats(&self) -> ShardStatistics {
353 let total_size = self.dataset.len();
354 let shard_size = self.indices.len();
355
356 let min_shard_size = total_size / self.config.world_size;
357 let max_shard_size = (total_size + self.config.world_size - 1) / self.config.world_size;
358
359 ShardStatistics {
360 total_samples: total_size,
361 shard_size,
362 min_shard_size,
363 max_shard_size,
364 world_size: self.config.world_size,
365 rank: self.config.rank,
366 imbalance_ratio: if min_shard_size > 0 {
367 max_shard_size as f64 / min_shard_size as f64
368 } else {
369 0.0
370 },
371 }
372 }
373}
374
375impl<T, D: Dataset<T>> Dataset<T> for ShardedDataset<T, D> {
376 fn get(
377 &self,
378 index: usize,
379 ) -> Result<(tenflowers_core::Tensor<T>, tenflowers_core::Tensor<T>)> {
380 if index >= self.indices.len() {
381 return Err(error_helpers::index_out_of_bounds(
382 "ShardedDataset::get",
383 index,
384 self.indices.len(),
385 ));
386 }
387
388 let actual_index = self.indices[index];
389 self.dataset.get(actual_index)
390 }
391
392 fn len(&self) -> usize {
393 self.indices.len()
394 }
395}
396
/// Summary of how a dataset is split across workers, as produced by
/// `ShardedDataset::shard_stats`.
#[derive(Debug, Clone)]
pub struct ShardStatistics {
    /// Total number of samples in the underlying dataset.
    pub total_samples: usize,
    /// Number of samples in this worker's shard.
    pub shard_size: usize,
    /// Theoretical minimum shard size under an even split.
    pub min_shard_size: usize,
    /// Theoretical maximum shard size under an even split.
    pub max_shard_size: usize,
    /// Number of workers.
    pub world_size: usize,
    /// This worker's rank.
    pub rank: usize,
    /// `max_shard_size / min_shard_size`, or 0.0 for an empty dataset.
    pub imbalance_ratio: f64,
}

impl ShardStatistics {
    /// A distribution counts as balanced when the largest shard is at most
    /// 10% bigger than the smallest.
    pub fn is_balanced(&self) -> bool {
        self.imbalance_ratio <= 1.1
    }

    /// Renders a human-readable multi-line summary of the shard layout.
    pub fn report(&self) -> String {
        let balanced = if self.is_balanced() { "Yes" } else { "No" };
        format!(
            "Shard Statistics:\n\
             - Total samples: {}\n\
             - World size: {} workers\n\
             - Rank: {}\n\
             - This shard size: {}\n\
             - Min shard size: {}\n\
             - Max shard size: {}\n\
             - Imbalance ratio: {:.2}\n\
             - Balanced: {}",
            self.total_samples,
            self.world_size,
            self.rank,
            self.shard_size,
            self.min_shard_size,
            self.max_shard_size,
            self.imbalance_ratio,
            balanced
        )
    }
}
445
/// Convenience constructors that shard any `Dataset<T>` by value.
pub trait DatasetShardingExt<T>: Dataset<T> + Sized {
    /// Shards this dataset with an explicit configuration.
    fn shard(self, config: ShardConfig) -> Result<ShardedDataset<T, Self>> {
        ShardedDataset::new(self, config)
    }

    /// Shards round-robin: worker `rank` takes every `world_size`-th sample.
    fn shard_round_robin(self, world_size: usize, rank: usize) -> Result<ShardedDataset<T, Self>> {
        let config = ShardConfig::new(world_size, rank)?;
        ShardedDataset::new(self, config)
    }

    /// Shards into contiguous per-worker slices.
    fn shard_contiguous(self, world_size: usize, rank: usize) -> Result<ShardedDataset<T, Self>> {
        let config = ShardConfig::new(world_size, rank)?.with_strategy(ShardStrategy::Contiguous);
        ShardedDataset::new(self, config)
    }

    /// Shards round-robin after a deterministic seeded shuffle.
    fn shard_shuffled(
        self,
        world_size: usize,
        rank: usize,
        seed: u64,
    ) -> Result<ShardedDataset<T, Self>> {
        let config = ShardConfig::new(world_size, rank)?
            .with_strategy(ShardStrategy::ShuffledRoundRobin)
            .with_seed(seed);
        ShardedDataset::new(self, config)
    }
}
478
/// Blanket implementation: every `Dataset<T>` gains the sharding helpers.
impl<T, D: Dataset<T>> DatasetShardingExt<T> for D {}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    // `new` defaults to the round-robin strategy.
    #[test]
    fn test_shard_config_creation() {
        let config = ShardConfig::new(4, 0).expect("config creation should succeed");
        assert_eq!(config.world_size, 4);
        assert_eq!(config.rank, 0);
        assert_eq!(config.strategy, ShardStrategy::RoundRobin);
    }

    // `new` rejects world_size == 0 and rank >= world_size.
    #[test]
    fn test_shard_config_validation() {
        assert!(ShardConfig::new(0, 0).is_err());
        assert!(ShardConfig::new(4, 4).is_err());
        assert!(ShardConfig::new(4, 5).is_err());
        assert!(ShardConfig::new(4, 3).is_ok());
    }

    // 10 samples, 3 workers: rank 0 takes positions 0, 3, 6, 9.
    #[test]
    fn test_round_robin_sharding() {
        let features = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            &[10, 1],
        )
        .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ShardConfig::new(3, 0).expect("config creation should succeed");
        let sharded =
            ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");

        assert_eq!(sharded.len(), 4);
        assert_eq!(sharded.indices(), &[0, 3, 6, 9]);
    }

    // 10 samples, 3 workers: rank 0 gets [0..4), rank 1 gets [4..7), rank 2 [7..10).
    #[test]
    fn test_contiguous_sharding() {
        let features = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            &[10, 1],
        )
        .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ShardConfig::new(3, 1)
            .expect("test: operation should succeed")
            .with_strategy(ShardStrategy::Contiguous);
        let sharded =
            ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");

        assert_eq!(sharded.len(), 3);
        assert_eq!(sharded.indices(), &[4, 5, 6]);
    }

    // Same seed + same config must yield identical shuffled shards.
    #[test]
    fn test_shuffled_sharding_deterministic() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 100], &[100, 1])
            .expect("tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![1.0; 100], &[100])
            .expect("tensor creation should succeed");
        let dataset1 = TensorDataset::new(features.clone(), labels.clone());
        let dataset2 = TensorDataset::new(features, labels);

        let config1 = ShardConfig::new(4, 0)
            .expect("config creation should succeed")
            .with_strategy(ShardStrategy::ShuffledRoundRobin)
            .with_seed(42);
        let config2 = ShardConfig::new(4, 0)
            .expect("config creation should succeed")
            .with_strategy(ShardStrategy::ShuffledRoundRobin)
            .with_seed(42);

        let sharded1 = ShardedDataset::new(dataset1, config1)
            .expect("sharded dataset creation should succeed");
        let sharded2 = ShardedDataset::new(dataset2, config2)
            .expect("sharded dataset creation should succeed");

        assert_eq!(sharded1.indices(), sharded2.indices());
    }

    // Basic sanity of the statistics fields; 100 % 3 != 0 so ratio >= 1.0.
    #[test]
    fn test_shard_statistics() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 100], &[100, 1])
            .expect("tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![1.0; 100], &[100])
            .expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ShardConfig::new(3, 0).expect("config creation should succeed");
        let sharded =
            ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");

        let stats = sharded.shard_stats();
        assert_eq!(stats.total_samples, 100);
        assert_eq!(stats.world_size, 3);
        assert_eq!(stats.rank, 0);
        assert!(stats.imbalance_ratio >= 1.0);
    }

    // Extension trait: 10 samples over 2 workers → 5 each.
    #[test]
    fn test_extension_trait_round_robin() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let sharded = dataset
            .shard_round_robin(2, 0)
            .expect("shard_round_robin should succeed");
        assert_eq!(sharded.len(), 5);
    }

    // Extension trait: contiguous rank 0 takes the first half.
    #[test]
    fn test_extension_trait_contiguous() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let sharded = dataset
            .shard_contiguous(2, 0)
            .expect("shard_contiguous should succeed");
        assert_eq!(sharded.len(), 5);
        assert_eq!(sharded.indices(), &[0, 1, 2, 3, 4]);
    }

    // Extension trait: shuffled sharding preserves shard size.
    #[test]
    fn test_extension_trait_shuffled() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let sharded = dataset
            .shard_shuffled(2, 0, 42)
            .expect("shard_shuffled should succeed");
        assert_eq!(sharded.len(), 5);
    }

    // Local index i maps to global index indices[i]; rank 0 of 2 holds
    // globals 0, 2, 4 → features 1.0, 3.0, 5.0 and labels 10.0, 30.0, 50.0.
    #[test]
    fn test_shard_access() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6, 1])
            .expect("tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[6])
            .expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ShardConfig::new(2, 0).expect("config creation should succeed");
        let sharded =
            ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");

        let (f0, l0) = sharded.get(0).expect("index should be in bounds");
        let (f1, l1) = sharded.get(1).expect("index should be in bounds");
        let (f2, l2) = sharded.get(2).expect("index should be in bounds");

        assert!((f0.to_vec().expect("to_vec should succeed")[0] - 1.0).abs() < 1e-6);
        assert!((l0.to_vec().expect("to_vec should succeed")[0] - 10.0).abs() < 1e-6);

        assert!((f1.to_vec().expect("to_vec should succeed")[0] - 3.0).abs() < 1e-6);
        assert!((l1.to_vec().expect("to_vec should succeed")[0] - 30.0).abs() < 1e-6);

        assert!((f2.to_vec().expect("to_vec should succeed")[0] - 5.0).abs() < 1e-6);
        assert!((l2.to_vec().expect("to_vec should succeed")[0] - 50.0).abs() < 1e-6);
    }

    // Accessing past the shard's local length is an error, not a panic.
    #[test]
    fn test_shard_out_of_bounds() {
        let features =
            Tensor::<f32>::from_vec(vec![1.0; 6], &[6, 1]).expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 6], &[6]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let sharded = dataset
            .shard_round_robin(2, 0)
            .expect("shard_round_robin should succeed");
        assert_eq!(sharded.len(), 3);
        assert!(sharded.get(3).is_err());
    }

    // An empty dataset shards to an empty shard without error.
    #[test]
    fn test_empty_dataset_sharding() {
        let features =
            Tensor::<f32>::from_vec(vec![], &[0, 1]).expect("empty tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![], &[0]).expect("empty tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let sharded = dataset
            .shard_round_robin(2, 0)
            .expect("shard_round_robin should succeed");
        assert_eq!(sharded.len(), 0);
    }

    // 12 samples split evenly across 3 workers → ratio exactly 1.0.
    #[test]
    fn test_shard_statistics_balanced() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 12], &[12, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 12], &[12]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ShardConfig::new(3, 0).expect("config creation should succeed");
        let sharded =
            ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");

        let stats = sharded.shard_stats();
        assert!(stats.is_balanced());
        assert_eq!(stats.imbalance_ratio, 1.0);
    }

    // The textual report mentions the key figures.
    #[test]
    fn test_shard_statistics_report() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ShardConfig::new(3, 0).expect("config creation should succeed");
        let sharded =
            ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");

        let report = sharded.shard_stats().report();
        assert!(report.contains("Total samples: 10"));
        assert!(report.contains("World size: 3"));
        assert!(report.contains("Rank: 0"));
    }

    // Stratified split over 3 classes × 4 samples, 2 workers → 6 each, and
    // every local index must be fetchable.
    #[test]
    fn test_stratified_sharding() {
        let features = Tensor::<f32>::from_vec(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            &[12, 1],
        )
        .expect("tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(
            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
            &[12],
        )
        .expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        // Label tensors hold the class id as their single f32 value.
        let label_extractor = |label_tensor: &Tensor<f32>| -> Result<usize> {
            let data = label_tensor
                .to_vec()
                .map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
            Ok(data[0] as usize)
        };

        let config = ShardConfig::new(2, 0)
            .expect("config creation should succeed")
            .with_strategy(ShardStrategy::Stratified)
            .with_seed(42);

        let sharded = ShardedDataset::new_stratified(dataset, config, label_extractor)
            .expect("stratified sharding should succeed");

        assert_eq!(sharded.len(), 6);
        for i in 0..sharded.len() {
            let (feature, label) = sharded.get(i).expect("get should succeed");
            assert!(feature.to_vec().is_ok());
            assert!(label.to_vec().is_ok());
        }
    }

    // 60 samples, 3 classes of 20, 3 workers: each shard is near 20 samples.
    #[test]
    fn test_stratified_sharding_balanced_classes() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 60], &[60, 1])
            .expect("tensor creation should succeed");
        let mut label_data = vec![0.0; 20];
        label_data.extend(vec![1.0; 20]);
        label_data.extend(vec![2.0; 20]);
        let labels =
            Tensor::<f32>::from_vec(label_data, &[60]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let label_extractor = |label_tensor: &Tensor<f32>| -> Result<usize> {
            let data = label_tensor
                .to_vec()
                .map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
            Ok(data[0] as usize)
        };

        let config = ShardConfig::new(3, 0)
            .expect("config creation should succeed")
            .with_strategy(ShardStrategy::Stratified)
            .with_seed(123);

        let sharded = ShardedDataset::new_stratified(dataset, config, label_extractor)
            .expect("stratified sharding should succeed");

        assert!(sharded.len() >= 18 && sharded.len() <= 21);
    }

    // Same seed + same labels must yield identical stratified shards.
    #[test]
    fn test_stratified_sharding_deterministic() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 30], &[30, 1])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(
            vec![
                0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0,
                1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0,
            ],
            &[30],
        )
        .expect("tensor creation should succeed");
        let dataset1 = TensorDataset::new(features.clone(), labels.clone());
        let dataset2 = TensorDataset::new(features, labels);

        let label_extractor1 = |label_tensor: &Tensor<f32>| -> Result<usize> {
            let data = label_tensor
                .to_vec()
                .map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
            Ok(data[0] as usize)
        };

        let label_extractor2 = |label_tensor: &Tensor<f32>| -> Result<usize> {
            let data = label_tensor
                .to_vec()
                .map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
            Ok(data[0] as usize)
        };

        let config1 = ShardConfig::new(2, 0)
            .expect("config creation should succeed")
            .with_strategy(ShardStrategy::Stratified)
            .with_seed(999);

        let config2 = ShardConfig::new(2, 0)
            .expect("config creation should succeed")
            .with_strategy(ShardStrategy::Stratified)
            .with_seed(999);

        let sharded1 = ShardedDataset::new_stratified(dataset1, config1, label_extractor1)
            .expect("stratified sharding should succeed");
        let sharded2 = ShardedDataset::new_stratified(dataset2, config2, label_extractor2)
            .expect("stratified sharding should succeed");

        assert_eq!(sharded1.indices(), sharded2.indices());
    }
}