// tenflowers_dataset/distributed_sharding.rs
1//! Deterministic shard loader for distributed training
2//!
3//! This module provides deterministic partitioning of datasets across multiple workers
4//! for distributed training, ensuring reproducibility and balanced distribution.
5
6use crate::{error_taxonomy::helpers as error_helpers, Dataset};
7use std::collections::HashMap;
8use std::sync::Arc;
9use tenflowers_core::{Result, Tensor};
10
/// Configuration for distributed sharding
///
/// Built with [`ShardConfig::new`] and refined via the `with_*` builder
/// methods; call `validate` to re-check invariants after building.
#[derive(Debug, Clone)]
pub struct ShardConfig {
    /// Total number of workers (world size); must be > 0
    pub world_size: usize,
    /// Current worker rank (0-indexed); must be < `world_size`
    pub rank: usize,
    /// Strategy for distributing samples
    pub strategy: ShardStrategy,
    /// Seed for deterministic shuffling; `None` skips shuffling even when
    /// the strategy is `ShuffledRoundRobin`
    pub seed: Option<u64>,
    /// Whether to drop last incomplete batch
    // NOTE(review): not consulted by `ShardedDataset` in this module —
    // confirm the batch loader honors it.
    pub drop_last: bool,
    /// Number of replicas per shard (for fault tolerance); must be > 0
    // NOTE(review): validated but not otherwise used in this module —
    // confirm consumers elsewhere.
    pub num_replicas: usize,
}
27
/// Strategy for distributing samples across workers
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShardStrategy {
    /// Round-robin distribution (sample i goes to worker i % world_size)
    RoundRobin,
    /// Contiguous blocks (each worker gets a contiguous range)
    Contiguous,
    /// Deterministic shuffle then round-robin; a `ShardConfig::seed` is
    /// required to actually shuffle (without one the order is unchanged)
    ShuffledRoundRobin,
    /// Stratified sampling (requires label information)
    ///
    /// Only honored by `ShardedDataset::new_stratified`; the plain `new`
    /// constructor falls back to round-robin for this variant.
    Stratified,
}
40
41impl ShardConfig {
42    /// Create a new shard configuration
43    pub fn new(world_size: usize, rank: usize) -> Result<Self> {
44        if world_size == 0 {
45            return Err(error_helpers::invalid_configuration(
46                "ShardConfig::new",
47                "world_size",
48                "world_size must be > 0",
49            ));
50        }
51
52        if rank >= world_size {
53            return Err(error_helpers::invalid_configuration(
54                "ShardConfig::new",
55                "rank",
56                format!("rank {} must be < world_size {}", rank, world_size),
57            ));
58        }
59
60        Ok(Self {
61            world_size,
62            rank,
63            strategy: ShardStrategy::RoundRobin,
64            seed: None,
65            drop_last: false,
66            num_replicas: 1,
67        })
68    }
69
70    /// Set the sharding strategy
71    pub fn with_strategy(mut self, strategy: ShardStrategy) -> Self {
72        self.strategy = strategy;
73        self
74    }
75
76    /// Set the seed for deterministic shuffling
77    pub fn with_seed(mut self, seed: u64) -> Self {
78        self.seed = Some(seed);
79        self
80    }
81
82    /// Set whether to drop the last incomplete batch
83    pub fn with_drop_last(mut self, drop_last: bool) -> Self {
84        self.drop_last = drop_last;
85        self
86    }
87
88    /// Set the number of replicas for fault tolerance
89    pub fn with_num_replicas(mut self, num_replicas: usize) -> Self {
90        self.num_replicas = num_replicas;
91        self
92    }
93
94    /// Validate the configuration
95    pub fn validate(&self) -> Result<()> {
96        if self.world_size == 0 {
97            return Err(error_helpers::invalid_configuration(
98                "ShardConfig::validate",
99                "world_size",
100                "world_size must be > 0",
101            ));
102        }
103
104        if self.rank >= self.world_size {
105            return Err(error_helpers::invalid_configuration(
106                "ShardConfig::validate",
107                "rank",
108                format!(
109                    "rank {} must be < world_size {}",
110                    self.rank, self.world_size
111                ),
112            ));
113        }
114
115        if self.num_replicas == 0 {
116            return Err(error_helpers::invalid_configuration(
117                "ShardConfig::validate",
118                "num_replicas",
119                "num_replicas must be > 0",
120            ));
121        }
122
123        Ok(())
124    }
125}
126
/// Trait for datasets that can be sharded for distributed training
///
/// NOTE(review): no implementor is visible in this module — `ShardedDataset`
/// wraps a dataset rather than implementing this trait; confirm intended
/// implementors elsewhere in the crate.
pub trait ShardableDataset<T>: Dataset<T> {
    /// Get indices for this worker's shard
    fn get_shard_indices(&self, config: &ShardConfig) -> Result<Vec<usize>>;

    /// Get the total number of shards (defaults to the world size)
    fn num_shards(&self, config: &ShardConfig) -> usize {
        config.world_size
    }

    /// Get shard size for this worker
    ///
    /// NOTE(review): errors from `get_shard_indices` are swallowed here and
    /// reported as an empty shard (size 0).
    fn shard_size(&self, config: &ShardConfig) -> usize {
        let indices = self.get_shard_indices(config).unwrap_or_default();
        indices.len()
    }
}
143
/// Wrapper that makes any dataset shardable
///
/// Holds the wrapped dataset behind an `Arc` together with the precomputed
/// list of global sample indices belonging to this worker's shard.
pub struct ShardedDataset<T, D: Dataset<T>> {
    // The wrapped full dataset (shared, never mutated here).
    dataset: Arc<D>,
    // Sharding configuration this view was built with.
    config: ShardConfig,
    // Global sample indices assigned to this rank, in iteration order.
    indices: Vec<usize>,
    // Marks the sample element type `T` without storing one.
    _phantom: std::marker::PhantomData<T>,
}
151
152impl<T, D: Dataset<T>> ShardedDataset<T, D> {
153    /// Create a new sharded dataset
154    pub fn new(dataset: D, config: ShardConfig) -> Result<Self> {
155        config.validate()?;
156
157        let dataset = Arc::new(dataset);
158        let indices = Self::compute_indices(&dataset, &config)?;
159
160        Ok(Self {
161            dataset,
162            config,
163            indices,
164            _phantom: std::marker::PhantomData,
165        })
166    }
167
168    /// Create a sharded dataset with stratified sampling
169    /// This requires a label extractor function to determine the class of each sample
170    pub fn new_stratified<F>(dataset: D, config: ShardConfig, label_extractor: F) -> Result<Self>
171    where
172        F: Fn(&Tensor<T>) -> Result<usize>,
173        T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
174    {
175        config.validate()?;
176
177        let dataset = Arc::new(dataset);
178        let indices = Self::compute_stratified_indices(&dataset, &config, label_extractor)?;
179
180        Ok(Self {
181            dataset,
182            config,
183            indices,
184            _phantom: std::marker::PhantomData,
185        })
186    }
187
188    /// Compute the indices for this shard
189    fn compute_indices(dataset: &D, config: &ShardConfig) -> Result<Vec<usize>> {
190        let total_size = dataset.len();
191
192        if total_size == 0 {
193            return Ok(Vec::new());
194        }
195
196        let mut all_indices: Vec<usize> = (0..total_size).collect();
197
198        // Apply strategy-specific ordering
199        match &config.strategy {
200            ShardStrategy::RoundRobin => {
201                // No reordering needed, will filter by rank
202            }
203            ShardStrategy::Contiguous => {
204                // Already in contiguous order
205            }
206            ShardStrategy::ShuffledRoundRobin => {
207                // Deterministically shuffle
208                if let Some(seed) = config.seed {
209                    Self::deterministic_shuffle(&mut all_indices, seed);
210                }
211            }
212            ShardStrategy::Stratified => {
213                // Stratified sampling requires label information
214                // Use `new_stratified` constructor instead of `new` for proper stratified sharding
215                // Falling back to round-robin for backwards compatibility
216            }
217        }
218
219        // Select indices for this rank
220        let shard_indices = match &config.strategy {
221            ShardStrategy::RoundRobin | ShardStrategy::ShuffledRoundRobin => {
222                // Every world_size'th element starting from rank
223                all_indices
224                    .iter()
225                    .enumerate()
226                    .filter(|(i, _)| i % config.world_size == config.rank)
227                    .map(|(_, &idx)| idx)
228                    .collect()
229            }
230            ShardStrategy::Contiguous => {
231                // Divide into contiguous blocks
232                let samples_per_worker = total_size / config.world_size;
233                let extra_samples = total_size % config.world_size;
234
235                let start = if config.rank < extra_samples {
236                    config.rank * (samples_per_worker + 1)
237                } else {
238                    config.rank * samples_per_worker + extra_samples
239                };
240
241                let count = if config.rank < extra_samples {
242                    samples_per_worker + 1
243                } else {
244                    samples_per_worker
245                };
246
247                all_indices[start..start + count].to_vec()
248            }
249            ShardStrategy::Stratified => {
250                // Fallback to round-robin for now
251                all_indices
252                    .iter()
253                    .enumerate()
254                    .filter(|(i, _)| i % config.world_size == config.rank)
255                    .map(|(_, &idx)| idx)
256                    .collect()
257            }
258        };
259
260        Ok(shard_indices)
261    }
262
263    /// Deterministically shuffle indices using Fisher-Yates with seeded RNG
264    fn deterministic_shuffle(indices: &mut [usize], seed: u64) {
265        let mut rng_state = seed;
266
267        for i in (1..indices.len()).rev() {
268            // Simple LCG for deterministic random numbers
269            rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
270            let j = (rng_state as usize) % (i + 1);
271            indices.swap(i, j);
272        }
273    }
274
275    /// Compute stratified indices ensuring balanced class distribution across workers
276    fn compute_stratified_indices<F>(
277        dataset: &D,
278        config: &ShardConfig,
279        label_extractor: F,
280    ) -> Result<Vec<usize>>
281    where
282        F: Fn(&Tensor<T>) -> Result<usize>,
283        T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
284    {
285        let total_size = dataset.len();
286
287        if total_size == 0 {
288            return Ok(Vec::new());
289        }
290
291        // Group indices by class
292        let mut class_to_indices: HashMap<usize, Vec<usize>> = HashMap::new();
293
294        for i in 0..total_size {
295            let (_, label_tensor) = dataset.get(i)?;
296            let class = label_extractor(&label_tensor)?;
297            class_to_indices.entry(class).or_default().push(i);
298        }
299
300        // For each class, shuffle deterministically and distribute across workers
301        let mut worker_indices: Vec<Vec<usize>> = vec![Vec::new(); config.world_size];
302
303        // Sort classes to ensure deterministic ordering
304        let mut classes: Vec<_> = class_to_indices.keys().cloned().collect();
305        classes.sort_unstable();
306
307        for class in classes {
308            let mut indices = class_to_indices
309                .remove(&class)
310                .expect("class should exist in map since we got it from keys()");
311
312            // Deterministically shuffle within class
313            if let Some(seed) = config.seed {
314                // Use class as additional seed component for reproducibility
315                Self::deterministic_shuffle(&mut indices, seed.wrapping_add(class as u64));
316            }
317
318            // Distribute class indices round-robin across workers
319            for (idx_pos, &global_idx) in indices.iter().enumerate() {
320                let worker_id = idx_pos % config.world_size;
321                worker_indices[worker_id].push(global_idx);
322            }
323        }
324
325        // Get indices for this worker's rank
326        let mut shard_indices = worker_indices[config.rank].clone();
327
328        // Optionally shuffle the final shard indices
329        if let Some(seed) = config.seed {
330            Self::deterministic_shuffle(&mut shard_indices, seed.wrapping_add(config.rank as u64));
331        }
332
333        Ok(shard_indices)
334    }
335
336    /// Get the underlying dataset
337    pub fn inner(&self) -> &D {
338        &self.dataset
339    }
340
341    /// Get the shard configuration
342    pub fn config(&self) -> &ShardConfig {
343        &self.config
344    }
345
346    /// Get the shard indices
347    pub fn indices(&self) -> &[usize] {
348        &self.indices
349    }
350
351    /// Get shard statistics
352    pub fn shard_stats(&self) -> ShardStatistics {
353        let total_size = self.dataset.len();
354        let shard_size = self.indices.len();
355
356        let min_shard_size = total_size / self.config.world_size;
357        let max_shard_size = (total_size + self.config.world_size - 1) / self.config.world_size;
358
359        ShardStatistics {
360            total_samples: total_size,
361            shard_size,
362            min_shard_size,
363            max_shard_size,
364            world_size: self.config.world_size,
365            rank: self.config.rank,
366            imbalance_ratio: if min_shard_size > 0 {
367                max_shard_size as f64 / min_shard_size as f64
368            } else {
369                0.0
370            },
371        }
372    }
373}
374
375impl<T, D: Dataset<T>> Dataset<T> for ShardedDataset<T, D> {
376    fn get(
377        &self,
378        index: usize,
379    ) -> Result<(tenflowers_core::Tensor<T>, tenflowers_core::Tensor<T>)> {
380        if index >= self.indices.len() {
381            return Err(error_helpers::index_out_of_bounds(
382                "ShardedDataset::get",
383                index,
384                self.indices.len(),
385            ));
386        }
387
388        let actual_index = self.indices[index];
389        self.dataset.get(actual_index)
390    }
391
392    fn len(&self) -> usize {
393        self.indices.len()
394    }
395}
396
/// Statistics about shard distribution
///
/// Produced by `ShardedDataset::shard_stats`; describes how evenly the
/// dataset is split across workers.
#[derive(Debug, Clone)]
pub struct ShardStatistics {
    /// Total number of samples across all shards
    pub total_samples: usize,
    /// Number of samples in this shard
    pub shard_size: usize,
    /// Minimum shard size across all workers (floor of the even split)
    pub min_shard_size: usize,
    /// Maximum shard size across all workers (ceiling of the even split)
    pub max_shard_size: usize,
    /// Total number of workers
    pub world_size: usize,
    /// Current worker rank
    pub rank: usize,
    /// Ratio of max to min shard size (measures imbalance; 0.0 when the
    /// minimum shard is empty)
    pub imbalance_ratio: f64,
}
415
416impl ShardStatistics {
417    /// Check if shards are balanced (imbalance ratio close to 1.0)
418    pub fn is_balanced(&self) -> bool {
419        self.imbalance_ratio <= 1.1 // Allow 10% imbalance
420    }
421
422    /// Generate a human-readable report
423    pub fn report(&self) -> String {
424        format!(
425            "Shard Statistics:\n\
426             - Total samples: {}\n\
427             - World size: {} workers\n\
428             - Rank: {}\n\
429             - This shard size: {}\n\
430             - Min shard size: {}\n\
431             - Max shard size: {}\n\
432             - Imbalance ratio: {:.2}\n\
433             - Balanced: {}",
434            self.total_samples,
435            self.world_size,
436            self.rank,
437            self.shard_size,
438            self.min_shard_size,
439            self.max_shard_size,
440            self.imbalance_ratio,
441            if self.is_balanced() { "Yes" } else { "No" }
442        )
443    }
444}
445
446/// Extension trait to add sharding capabilities to any dataset
447pub trait DatasetShardingExt<T>: Dataset<T> + Sized {
448    /// Shard this dataset for distributed training
449    fn shard(self, config: ShardConfig) -> Result<ShardedDataset<T, Self>> {
450        ShardedDataset::new(self, config)
451    }
452
453    /// Create a round-robin sharded dataset
454    fn shard_round_robin(self, world_size: usize, rank: usize) -> Result<ShardedDataset<T, Self>> {
455        let config = ShardConfig::new(world_size, rank)?;
456        ShardedDataset::new(self, config)
457    }
458
459    /// Create a contiguous sharded dataset
460    fn shard_contiguous(self, world_size: usize, rank: usize) -> Result<ShardedDataset<T, Self>> {
461        let config = ShardConfig::new(world_size, rank)?.with_strategy(ShardStrategy::Contiguous);
462        ShardedDataset::new(self, config)
463    }
464
465    /// Create a shuffled sharded dataset with a seed
466    fn shard_shuffled(
467        self,
468        world_size: usize,
469        rank: usize,
470        seed: u64,
471    ) -> Result<ShardedDataset<T, Self>> {
472        let config = ShardConfig::new(world_size, rank)?
473            .with_strategy(ShardStrategy::ShuffledRoundRobin)
474            .with_seed(seed);
475        ShardedDataset::new(self, config)
476    }
477}
478
// Blanket implementation for all datasets: any `Dataset<T>` gains the
// `shard*` extension methods for free.
impl<T, D: Dataset<T>> DatasetShardingExt<T> for D {}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    #[test]
    fn test_shard_config_creation() {
        let config = ShardConfig::new(4, 0).expect("config creation should succeed");
        assert_eq!(config.world_size, 4);
        assert_eq!(config.rank, 0);
        assert_eq!(config.strategy, ShardStrategy::RoundRobin);
    }

    #[test]
    fn test_shard_config_validation() {
        // zero world_size, rank == world_size, and rank > world_size must all fail
        assert!(ShardConfig::new(0, 0).is_err());
        assert!(ShardConfig::new(4, 4).is_err());
        assert!(ShardConfig::new(4, 5).is_err());
        assert!(ShardConfig::new(4, 3).is_ok());
    }

    #[test]
    fn test_round_robin_sharding() {
        let features = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            &[10, 1],
        )
        .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        // Shard into 3 workers
        let config = ShardConfig::new(3, 0).expect("config creation should succeed");
        let sharded =
            ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");

        // Rank 0 should get indices [0, 3, 6, 9]
        assert_eq!(sharded.len(), 4);
        assert_eq!(sharded.indices(), &[0, 3, 6, 9]);
    }

    #[test]
    fn test_contiguous_sharding() {
        let features = Tensor::<f32>::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            &[10, 1],
        )
        .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        // Shard into 3 workers with contiguous strategy
        let config = ShardConfig::new(3, 1)
            .expect("test: operation should succeed")
            .with_strategy(ShardStrategy::Contiguous);
        let sharded =
            ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");

        // Rank 1 should get a contiguous block
        // 10 samples / 3 workers = 3 base + 1 extra for first worker
        // Rank 0: [0,1,2,3], Rank 1: [4,5,6], Rank 2: [7,8,9]
        assert_eq!(sharded.len(), 3);
        assert_eq!(sharded.indices(), &[4, 5, 6]);
    }

    #[test]
    fn test_shuffled_sharding_deterministic() {
        // Only determinism is pinned here, not the particular permutation.
        let features = Tensor::<f32>::from_vec(vec![1.0; 100], &[100, 1])
            .expect("tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![1.0; 100], &[100])
            .expect("tensor creation should succeed");
        let dataset1 = TensorDataset::new(features.clone(), labels.clone());
        let dataset2 = TensorDataset::new(features, labels);

        let config1 = ShardConfig::new(4, 0)
            .expect("config creation should succeed")
            .with_strategy(ShardStrategy::ShuffledRoundRobin)
            .with_seed(42);
        let config2 = ShardConfig::new(4, 0)
            .expect("config creation should succeed")
            .with_strategy(ShardStrategy::ShuffledRoundRobin)
            .with_seed(42);

        let sharded1 = ShardedDataset::new(dataset1, config1)
            .expect("sharded dataset creation should succeed");
        let sharded2 = ShardedDataset::new(dataset2, config2)
            .expect("sharded dataset creation should succeed");

        // Same seed should produce same indices
        assert_eq!(sharded1.indices(), sharded2.indices());
    }

    #[test]
    fn test_shard_statistics() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 100], &[100, 1])
            .expect("tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![1.0; 100], &[100])
            .expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ShardConfig::new(3, 0).expect("config creation should succeed");
        let sharded =
            ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");

        let stats = sharded.shard_stats();
        assert_eq!(stats.total_samples, 100);
        assert_eq!(stats.world_size, 3);
        assert_eq!(stats.rank, 0);
        assert!(stats.imbalance_ratio >= 1.0);
    }

    #[test]
    fn test_extension_trait_round_robin() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let sharded = dataset
            .shard_round_robin(2, 0)
            .expect("shard_round_robin should succeed");
        assert_eq!(sharded.len(), 5);
    }

    #[test]
    fn test_extension_trait_contiguous() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let sharded = dataset
            .shard_contiguous(2, 0)
            .expect("shard_contiguous should succeed");
        assert_eq!(sharded.len(), 5);
        assert_eq!(sharded.indices(), &[0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_extension_trait_shuffled() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let sharded = dataset
            .shard_shuffled(2, 0, 42)
            .expect("shard_shuffled should succeed");
        assert_eq!(sharded.len(), 5);
    }

    #[test]
    fn test_shard_access() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6, 1])
            .expect("tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0], &[6])
            .expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ShardConfig::new(2, 0).expect("config creation should succeed");
        let sharded =
            ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");

        // Rank 0 should get indices [0, 2, 4]
        let (f0, l0) = sharded.get(0).expect("index should be in bounds");
        let (f1, l1) = sharded.get(1).expect("index should be in bounds");
        let (f2, l2) = sharded.get(2).expect("index should be in bounds");

        // Verify we're accessing the correct original indices
        assert!((f0.to_vec().expect("to_vec should succeed")[0] - 1.0).abs() < 1e-6);
        assert!((l0.to_vec().expect("to_vec should succeed")[0] - 10.0).abs() < 1e-6);

        assert!((f1.to_vec().expect("to_vec should succeed")[0] - 3.0).abs() < 1e-6);
        assert!((l1.to_vec().expect("to_vec should succeed")[0] - 30.0).abs() < 1e-6);

        assert!((f2.to_vec().expect("to_vec should succeed")[0] - 5.0).abs() < 1e-6);
        assert!((l2.to_vec().expect("to_vec should succeed")[0] - 50.0).abs() < 1e-6);
    }

    #[test]
    fn test_shard_out_of_bounds() {
        let features =
            Tensor::<f32>::from_vec(vec![1.0; 6], &[6, 1]).expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 6], &[6]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let sharded = dataset
            .shard_round_robin(2, 0)
            .expect("shard_round_robin should succeed");
        assert_eq!(sharded.len(), 3);
        // Index 3 is one past the end of this shard.
        assert!(sharded.get(3).is_err());
    }

    #[test]
    fn test_empty_dataset_sharding() {
        let features =
            Tensor::<f32>::from_vec(vec![], &[0, 1]).expect("empty tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![], &[0]).expect("empty tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let sharded = dataset
            .shard_round_robin(2, 0)
            .expect("shard_round_robin should succeed");
        assert_eq!(sharded.len(), 0);
    }

    #[test]
    fn test_shard_statistics_balanced() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 12], &[12, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 12], &[12]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ShardConfig::new(3, 0).expect("config creation should succeed"); // 12/3 = 4 each, perfectly balanced
        let sharded =
            ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");

        let stats = sharded.shard_stats();
        assert!(stats.is_balanced());
        assert_eq!(stats.imbalance_ratio, 1.0);
    }

    #[test]
    fn test_shard_statistics_report() {
        let features = Tensor::<f32>::from_vec(vec![1.0; 10], &[10, 1])
            .expect("tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0; 10], &[10]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let config = ShardConfig::new(3, 0).expect("config creation should succeed");
        let sharded =
            ShardedDataset::new(dataset, config).expect("sharded dataset creation should succeed");

        let report = sharded.shard_stats().report();
        assert!(report.contains("Total samples: 10"));
        assert!(report.contains("World size: 3"));
        assert!(report.contains("Rank: 0"));
    }

    #[test]
    fn test_stratified_sharding() {
        // Create dataset with 3 classes: [0,0,1,1,2,2] (repeated)
        let features = Tensor::<f32>::from_vec(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            &[12, 1],
        )
        .expect("tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(
            vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 0.0, 0.0, 1.0, 1.0, 2.0, 2.0],
            &[12],
        )
        .expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        // Label extractor: extract the scalar value from label tensor
        let label_extractor = |label_tensor: &Tensor<f32>| -> Result<usize> {
            let data = label_tensor
                .to_vec()
                .map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
            Ok(data[0] as usize)
        };

        // Shard into 2 workers with stratified strategy
        let config = ShardConfig::new(2, 0)
            .expect("config creation should succeed")
            .with_strategy(ShardStrategy::Stratified)
            .with_seed(42);

        let sharded = ShardedDataset::new_stratified(dataset, config, label_extractor)
            .expect("stratified sharding should succeed");

        // Each worker should get balanced class distribution
        // With 4 samples of each class and 2 workers, each worker should get 2 of each class
        assert_eq!(sharded.len(), 6); // 12 samples / 2 workers = 6 per worker

        // Verify that we can access samples
        for i in 0..sharded.len() {
            let (feature, label) = sharded.get(i).expect("get should succeed");
            assert!(feature.to_vec().is_ok());
            assert!(label.to_vec().is_ok());
        }
    }

    #[test]
    fn test_stratified_sharding_balanced_classes() {
        // Create dataset with balanced classes
        let features = Tensor::<f32>::from_vec(vec![1.0; 60], &[60, 1])
            .expect("tensor creation should succeed");
        // 20 samples each of class 0, 1, 2
        let mut label_data = vec![0.0; 20];
        label_data.extend(vec![1.0; 20]);
        label_data.extend(vec![2.0; 20]);
        let labels =
            Tensor::<f32>::from_vec(label_data, &[60]).expect("tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let label_extractor = |label_tensor: &Tensor<f32>| -> Result<usize> {
            let data = label_tensor
                .to_vec()
                .map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
            Ok(data[0] as usize)
        };

        // Shard into 3 workers
        let config = ShardConfig::new(3, 0)
            .expect("config creation should succeed")
            .with_strategy(ShardStrategy::Stratified)
            .with_seed(123);

        let sharded = ShardedDataset::new_stratified(dataset, config, label_extractor)
            .expect("stratified sharding should succeed");

        // Each worker should get approximately 20 samples (60 / 3)
        // Due to round-robin distribution within classes, the exact count may vary slightly
        // 20 samples of each class / 3 workers = 6-7 per class per worker
        // Total per worker = 6*3 + 1*3 = 18-21 samples (allowing for rounding)
        assert!(sharded.len() >= 18 && sharded.len() <= 21);
    }

    #[test]
    fn test_stratified_sharding_deterministic() {
        // Same seed must yield identical shard indices across two builds.
        let features = Tensor::<f32>::from_vec(vec![1.0; 30], &[30, 1])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(
            vec![
                0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0,
                1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0,
            ],
            &[30],
        )
        .expect("tensor creation should succeed");
        let dataset1 = TensorDataset::new(features.clone(), labels.clone());
        let dataset2 = TensorDataset::new(features, labels);

        let label_extractor1 = |label_tensor: &Tensor<f32>| -> Result<usize> {
            let data = label_tensor
                .to_vec()
                .map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
            Ok(data[0] as usize)
        };

        let label_extractor2 = |label_tensor: &Tensor<f32>| -> Result<usize> {
            let data = label_tensor
                .to_vec()
                .map_err(|e| tenflowers_core::TensorError::invalid_argument(e.to_string()))?;
            Ok(data[0] as usize)
        };

        // Same seed should produce same results
        let config1 = ShardConfig::new(2, 0)
            .expect("config creation should succeed")
            .with_strategy(ShardStrategy::Stratified)
            .with_seed(999);

        let config2 = ShardConfig::new(2, 0)
            .expect("config creation should succeed")
            .with_strategy(ShardStrategy::Stratified)
            .with_seed(999);

        let sharded1 = ShardedDataset::new_stratified(dataset1, config1, label_extractor1)
            .expect("stratified sharding should succeed");
        let sharded2 = ShardedDataset::new_stratified(dataset2, config2, label_extractor2)
            .expect("stratified sharding should succeed");

        // Same seed should produce same indices
        assert_eq!(sharded1.indices(), sharded2.indices());
    }
}