torsh_data/sampler/
distributed.rs

1//! Distributed sampling for multi-process training
2//!
3//! This module provides samplers for distributed training scenarios where
4//! multiple processes need to sample different subsets of the data to avoid
5//! overlap and ensure balanced workloads.
6
7#[cfg(not(feature = "std"))]
8use alloc::vec::Vec;
9
10// ✅ SciRS2 Policy Compliant - Using scirs2_core for all random operations
11use scirs2_core::random::{Random, Rng};
12
13use super::core::{Sampler, SamplerIterator};
14
/// Wrapper that makes any sampler work in a distributed setting
///
/// This wrapper takes an underlying sampler and distributes its indices across
/// multiple replicas (processes). Each replica gets a disjoint subset of the
/// indices, ensuring no overlap between processes.
///
/// Shuffling is enabled by default; disable it with
/// [`DistributedWrapper::with_shuffle`] to obtain the deterministic
/// round-robin split shown in the example below.
///
/// # Examples
///
/// ```rust
/// use torsh_data::sampler::{SequentialSampler, DistributedWrapper, Sampler};
///
/// let base_sampler = SequentialSampler::new(10);
/// // Process 0 of 2 processes; shuffle disabled for a predictable split
/// let distributed = DistributedWrapper::new(base_sampler, 2, 0).with_shuffle(false);
///
/// let indices: Vec<usize> = distributed.iter().collect();
/// // Process 0 gets: [0, 2, 4, 6, 8]
/// // Process 1 would get: [1, 3, 5, 7, 9]
/// ```
#[derive(Debug, Clone)]
pub struct DistributedWrapper<S: Sampler> {
    /// Underlying sampler whose indices are partitioned across replicas
    sampler: S,
    /// Total number of participating processes
    num_replicas: usize,
    /// Rank of this process (0-based, strictly less than `num_replicas`)
    rank: usize,
    /// Whether indices are shuffled before distribution (defaults to `true`)
    shuffle: bool,
    /// Optional shuffle seed; when `None`, `iter()` falls back to a fixed seed of 42,
    /// so every epoch produces the same permutation unless a seed is set explicitly
    generator: Option<u64>,
}
42
43impl<S: Sampler> DistributedWrapper<S> {
44    /// Create a new distributed wrapper
45    ///
46    /// # Arguments
47    ///
48    /// * `sampler` - The underlying sampler to distribute
49    /// * `num_replicas` - Total number of processes
50    /// * `rank` - Current process rank (0-based)
51    ///
52    /// # Panics
53    ///
54    /// Panics if `num_replicas` is 0 or `rank` >= `num_replicas`
55    pub fn new(sampler: S, num_replicas: usize, rank: usize) -> Self {
56        assert!(num_replicas > 0, "Number of replicas must be positive");
57        assert!(rank < num_replicas, "Rank must be less than num_replicas");
58
59        Self {
60            sampler,
61            num_replicas,
62            rank,
63            shuffle: true,
64            generator: None,
65        }
66    }
67
68    /// Enable or disable shuffling
69    ///
70    /// When shuffling is enabled, the indices are shuffled before distribution.
71    /// This ensures different ordering across epochs.
72    pub fn with_shuffle(mut self, shuffle: bool) -> Self {
73        self.shuffle = shuffle;
74        self
75    }
76
77    /// Set the random number generator seed
78    ///
79    /// # Arguments
80    ///
81    /// * `seed` - Random seed for reproducible shuffling
82    pub fn with_generator(mut self, seed: u64) -> Self {
83        self.generator = Some(seed);
84        self
85    }
86
87    /// Get the number of replicas
88    pub fn num_replicas(&self) -> usize {
89        self.num_replicas
90    }
91
92    /// Get the current rank
93    pub fn rank(&self) -> usize {
94        self.rank
95    }
96
97    /// Check if shuffling is enabled
98    pub fn shuffle(&self) -> bool {
99        self.shuffle
100    }
101
102    /// Get the generator seed if set
103    pub fn generator(&self) -> Option<u64> {
104        self.generator
105    }
106
107    /// Get a reference to the underlying sampler
108    pub fn sampler(&self) -> &S {
109        &self.sampler
110    }
111
112    /// Get the underlying sampler by value
113    pub fn into_sampler(self) -> S {
114        self.sampler
115    }
116
117    /// Calculate the number of samples this replica will receive
118    fn calculate_num_samples(&self) -> usize {
119        let total_samples = self.sampler.len();
120        // Each replica gets roughly equal number of samples
121        // If total doesn't divide evenly, some replicas get one extra sample
122        let base_samples = total_samples / self.num_replicas;
123        let extra_samples = total_samples % self.num_replicas;
124
125        if self.rank < extra_samples {
126            base_samples + 1
127        } else {
128            base_samples
129        }
130    }
131}
132
133impl<S: Sampler> Sampler for DistributedWrapper<S> {
134    type Iter = SamplerIterator;
135
136    fn iter(&self) -> Self::Iter {
137        // Get all indices from the underlying sampler
138        let mut all_indices: Vec<usize> = self.sampler.iter().collect();
139
140        // Shuffle if enabled
141        if self.shuffle {
142            // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
143            let mut rng = match self.generator {
144                Some(seed) => Random::seed(seed),
145                None => Random::seed(42),
146            };
147
148            // Fisher-Yates shuffle
149            for i in (1..all_indices.len()).rev() {
150                let j = rng.gen_range(0..=i);
151                all_indices.swap(i, j);
152            }
153        }
154
155        // Distribute indices across replicas
156        let replica_indices: Vec<usize> = all_indices
157            .into_iter()
158            .enumerate()
159            .filter_map(|(i, idx)| {
160                if i % self.num_replicas == self.rank {
161                    Some(idx)
162                } else {
163                    None
164                }
165            })
166            .collect();
167
168        SamplerIterator::new(replica_indices)
169    }
170
171    fn len(&self) -> usize {
172        self.calculate_num_samples()
173    }
174}
175
/// Distributed sampler for balanced data distribution
///
/// Unlike DistributedWrapper which wraps an existing sampler, DistributedSampler
/// is a standalone sampler designed specifically for distributed training.
/// It provides more control over the distribution strategy.
///
/// # Examples
///
/// ```rust
/// use torsh_data::sampler::{DistributedSampler, Sampler};
///
/// // Process 1 of 4 processes, working with dataset of size 100
/// let sampler = DistributedSampler::new(100, 4, 1, true).with_generator(42);
///
/// let indices: Vec<usize> = sampler.iter().collect();
/// assert_eq!(indices.len(), 25); // 100 / 4 = 25 samples per process
/// ```
#[derive(Debug, Clone)]
pub struct DistributedSampler {
    /// Total number of samples in the dataset being split
    dataset_size: usize,
    /// Total number of participating processes
    num_replicas: usize,
    /// Rank of this process (0-based, strictly less than `num_replicas`)
    rank: usize,
    /// Whether indices are shuffled before the per-replica slice is taken
    shuffle: bool,
    /// Optional shuffle seed; when `None`, `iter()` falls back to a fixed seed of 42
    generator: Option<u64>,
    /// If true, trailing samples are dropped so the dataset divides evenly;
    /// if false (default), indices wrap around to pad up to an even split
    drop_last: bool,
}
202
203impl DistributedSampler {
204    /// Create a new distributed sampler
205    ///
206    /// # Arguments
207    ///
208    /// * `dataset_size` - Total size of the dataset
209    /// * `num_replicas` - Total number of processes
210    /// * `rank` - Current process rank (0-based)
211    /// * `shuffle` - Whether to shuffle indices
212    ///
213    /// # Panics
214    ///
215    /// Panics if `dataset_size` is 0, `num_replicas` is 0, or `rank` >= `num_replicas`
216    pub fn new(dataset_size: usize, num_replicas: usize, rank: usize, shuffle: bool) -> Self {
217        assert!(dataset_size > 0, "Dataset size must be positive");
218        assert!(num_replicas > 0, "Number of replicas must be positive");
219        assert!(rank < num_replicas, "Rank must be less than num_replicas");
220
221        Self {
222            dataset_size,
223            num_replicas,
224            rank,
225            shuffle,
226            generator: None,
227            drop_last: false,
228        }
229    }
230
231    /// Set the random number generator seed
232    pub fn with_generator(mut self, seed: u64) -> Self {
233        self.generator = Some(seed);
234        self
235    }
236
237    /// Set whether to drop the last samples to make dataset evenly divisible
238    pub fn with_drop_last(mut self, drop_last: bool) -> Self {
239        self.drop_last = drop_last;
240        self
241    }
242
243    /// Get the dataset size
244    pub fn dataset_size(&self) -> usize {
245        self.dataset_size
246    }
247
248    /// Get the number of replicas
249    pub fn num_replicas(&self) -> usize {
250        self.num_replicas
251    }
252
253    /// Get the current rank
254    pub fn rank(&self) -> usize {
255        self.rank
256    }
257
258    /// Check if shuffling is enabled
259    pub fn shuffle(&self) -> bool {
260        self.shuffle
261    }
262
263    /// Check if dropping last samples
264    pub fn drop_last(&self) -> bool {
265        self.drop_last
266    }
267
268    /// Get the generator seed if set
269    pub fn generator(&self) -> Option<u64> {
270        self.generator
271    }
272
273    /// Calculate the effective dataset size after potential padding
274    fn effective_dataset_size(&self) -> usize {
275        if self.drop_last {
276            // Drop samples to make evenly divisible
277            (self.dataset_size / self.num_replicas) * self.num_replicas
278        } else {
279            // Pad with duplicates to make evenly divisible
280            let samples_per_replica =
281                (self.dataset_size + self.num_replicas - 1) / self.num_replicas;
282            samples_per_replica * self.num_replicas
283        }
284    }
285
286    /// Calculate the number of samples this replica will receive
287    fn calculate_num_samples(&self) -> usize {
288        if self.drop_last {
289            self.dataset_size / self.num_replicas
290        } else {
291            (self.dataset_size + self.num_replicas - 1) / self.num_replicas
292        }
293    }
294}
295
296impl Sampler for DistributedSampler {
297    type Iter = SamplerIterator;
298
299    fn iter(&self) -> Self::Iter {
300        let effective_size = self.effective_dataset_size();
301        let samples_per_replica = self.calculate_num_samples();
302
303        // Create base indices
304        let mut indices: Vec<usize> = if self.drop_last {
305            (0..effective_size).collect()
306        } else {
307            // Pad with duplicates if needed
308            (0..effective_size).map(|i| i % self.dataset_size).collect()
309        };
310
311        // Shuffle if enabled
312        if self.shuffle {
313            // ✅ SciRS2 Policy Compliant - Using scirs2_core for random operations
314            let mut rng = match self.generator {
315                Some(seed) => Random::seed(seed),
316                None => Random::seed(42),
317            };
318
319            // Fisher-Yates shuffle
320            for i in (1..indices.len()).rev() {
321                let j = rng.gen_range(0..=i);
322                indices.swap(i, j);
323            }
324        }
325
326        // Extract this replica's portion
327        let start_idx = self.rank * samples_per_replica;
328        let end_idx = start_idx + samples_per_replica;
329        let replica_indices = indices[start_idx..end_idx.min(indices.len())].to_vec();
330
331        SamplerIterator::new(replica_indices)
332    }
333
334    fn len(&self) -> usize {
335        self.calculate_num_samples()
336    }
337}
338
339/// Create a distributed wrapper for any sampler
340///
341/// Convenience function for creating a distributed wrapper.
342///
343/// # Arguments
344///
345/// * `sampler` - The underlying sampler
346/// * `num_replicas` - Total number of processes
347/// * `rank` - Current process rank
348pub fn distributed<S: Sampler>(
349    sampler: S,
350    num_replicas: usize,
351    rank: usize,
352) -> DistributedWrapper<S> {
353    DistributedWrapper::new(sampler, num_replicas, rank)
354}
355
356/// Create a distributed sampler
357///
358/// Convenience function for creating a distributed sampler.
359///
360/// # Arguments
361///
362/// * `dataset_size` - Total size of the dataset
363/// * `num_replicas` - Total number of processes
364/// * `rank` - Current process rank
365/// * `shuffle` - Whether to shuffle indices
366pub fn distributed_sampler(
367    dataset_size: usize,
368    num_replicas: usize,
369    rank: usize,
370    shuffle: bool,
371) -> DistributedSampler {
372    DistributedSampler::new(dataset_size, num_replicas, rank, shuffle)
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampler::basic::SequentialSampler;

    // With shuffling disabled, rank r of n replicas receives every n-th index
    // starting at r (round-robin split).
    #[test]
    fn test_distributed_wrapper_basic() {
        let base_sampler = SequentialSampler::new(10);
        let distributed = DistributedWrapper::new(base_sampler, 2, 0).with_shuffle(false);

        assert_eq!(distributed.num_replicas(), 2);
        assert_eq!(distributed.rank(), 0);
        assert!(!distributed.shuffle());
        assert_eq!(distributed.len(), 5); // 10 / 2 = 5 samples per process

        let indices: Vec<usize> = distributed.iter().collect();
        assert_eq!(indices, vec![0, 2, 4, 6, 8]); // Even indices for rank 0
    }

    // The complementary rank receives the other half of the round-robin split.
    #[test]
    fn test_distributed_wrapper_rank_1() {
        let base_sampler = SequentialSampler::new(10);
        let distributed = DistributedWrapper::new(base_sampler, 2, 1).with_shuffle(false);

        assert_eq!(distributed.rank(), 1);
        assert_eq!(distributed.len(), 5);

        let indices: Vec<usize> = distributed.iter().collect();
        assert_eq!(indices, vec![1, 3, 5, 7, 9]); // Odd indices for rank 1
    }

    // When the total doesn't divide evenly, the lowest ranks get one extra
    // index each, and the union of all replicas covers every index exactly once.
    #[test]
    fn test_distributed_wrapper_uneven_split() {
        let base_sampler = SequentialSampler::new(7); // 7 doesn't divide evenly by 3

        let dist0 = DistributedWrapper::new(base_sampler.clone(), 3, 0).with_shuffle(false);
        let dist1 = DistributedWrapper::new(base_sampler.clone(), 3, 1).with_shuffle(false);
        let dist2 = DistributedWrapper::new(base_sampler, 3, 2).with_shuffle(false);

        // First rank gets extra sample: 7 / 3 = 2 remainder 1
        assert_eq!(dist0.len(), 3); // 2 + 1 = 3
        assert_eq!(dist1.len(), 2); // 2
        assert_eq!(dist2.len(), 2); // 2

        let indices0: Vec<usize> = dist0.iter().collect();
        let indices1: Vec<usize> = dist1.iter().collect();
        let indices2: Vec<usize> = dist2.iter().collect();

        assert_eq!(indices0, vec![0, 3, 6]);
        assert_eq!(indices1, vec![1, 4]);
        assert_eq!(indices2, vec![2, 5]);

        // Verify all indices are covered
        let mut all_indices = indices0;
        all_indices.extend(indices1);
        all_indices.extend(indices2);
        all_indices.sort();
        assert_eq!(all_indices, vec![0, 1, 2, 3, 4, 5, 6]);
    }

    // Shuffled output stays within the dataset's index range, and the same
    // seed reproduces the same permutation.
    #[test]
    fn test_distributed_wrapper_with_shuffle() {
        let base_sampler = SequentialSampler::new(10);
        let distributed = DistributedWrapper::new(base_sampler, 2, 0)
            .with_shuffle(true)
            .with_generator(42);

        let indices: Vec<usize> = distributed.iter().collect();
        assert_eq!(indices.len(), 5);

        // Verify indices are from original dataset
        for &idx in &indices {
            assert!(idx < 10);
        }

        // Should be deterministic with same seed
        let distributed2 = DistributedWrapper::new(SequentialSampler::new(10), 2, 0)
            .with_shuffle(true)
            .with_generator(42);
        let indices2: Vec<usize> = distributed2.iter().collect();
        assert_eq!(indices, indices2);
    }

    // Constructor precondition: at least one replica is required.
    #[test]
    #[should_panic(expected = "Number of replicas must be positive")]
    fn test_distributed_wrapper_zero_replicas() {
        let base_sampler = SequentialSampler::new(10);
        DistributedWrapper::new(base_sampler, 0, 0);
    }

    // Constructor precondition: rank must be strictly below num_replicas.
    #[test]
    #[should_panic(expected = "Rank must be less than num_replicas")]
    fn test_distributed_wrapper_invalid_rank() {
        let base_sampler = SequentialSampler::new(10);
        DistributedWrapper::new(base_sampler, 2, 2);
    }

    // DistributedSampler assigns each replica a contiguous block of indices
    // when shuffling is off.
    #[test]
    fn test_distributed_sampler_basic() {
        let sampler = DistributedSampler::new(12, 3, 1, false);

        assert_eq!(sampler.dataset_size(), 12);
        assert_eq!(sampler.num_replicas(), 3);
        assert_eq!(sampler.rank(), 1);
        assert!(!sampler.shuffle());
        assert!(!sampler.drop_last());
        assert_eq!(sampler.len(), 4); // 12 / 3 = 4

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices, vec![4, 5, 6, 7]); // Rank 1 gets indices 4-7
    }

    // Without drop_last, an uneven dataset is padded with duplicate indices
    // so every replica gets the same (rounded-up) count.
    #[test]
    fn test_distributed_sampler_with_padding() {
        let sampler = DistributedSampler::new(10, 3, 0, false); // 10 doesn't divide by 3

        // With padding, each replica gets 4 samples (total 12, padded from 10)
        assert_eq!(sampler.len(), 4);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 4);

        // All indices should be valid (0-9, with possible duplicates due to padding)
        for &idx in &indices {
            assert!(idx < 10);
        }
    }

    // With drop_last, the leftover samples are discarded instead of padded.
    #[test]
    fn test_distributed_sampler_drop_last() {
        let sampler = DistributedSampler::new(10, 3, 0, false).with_drop_last(true);

        // With drop_last, 10 / 3 = 3 samples per replica (drops 1 sample)
        assert_eq!(sampler.len(), 3);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices, vec![0, 1, 2]);
    }

    // Seeded shuffling is deterministic across sampler instances.
    #[test]
    fn test_distributed_sampler_shuffle() {
        let sampler = DistributedSampler::new(12, 3, 0, true).with_generator(42);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 4);

        // Should be deterministic with same seed
        let sampler2 = DistributedSampler::new(12, 3, 0, true).with_generator(42);
        let indices2: Vec<usize> = sampler2.iter().collect();
        assert_eq!(indices, indices2);
    }

    // The free-function shorthands behave like the corresponding constructors.
    #[test]
    fn test_convenience_functions() {
        let base_sampler = SequentialSampler::new(8);
        let dist_wrapper = distributed(base_sampler, 2, 0);
        assert_eq!(dist_wrapper.len(), 4);

        let dist_sampler = distributed_sampler(8, 2, 1, false);
        assert_eq!(dist_sampler.len(), 4);
    }

    // Degenerate configurations: a single replica sees everything; with more
    // replicas than samples, padding still gives each replica one valid index.
    #[test]
    fn test_distributed_sampler_edge_cases() {
        // Single replica (should get all data)
        let sampler = DistributedSampler::new(10, 1, 0, false);
        assert_eq!(sampler.len(), 10);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices, (0..10).collect::<Vec<_>>());

        // More replicas than data points
        let sampler = DistributedSampler::new(2, 5, 3, false);
        assert_eq!(sampler.len(), 1);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 1);
        assert!(indices[0] < 2);
    }

    // into_sampler returns the original wrapped sampler intact.
    #[test]
    fn test_into_sampler() {
        let base_sampler = SequentialSampler::new(5);
        let distributed = DistributedWrapper::new(base_sampler, 2, 0);

        let recovered = distributed.into_sampler();
        assert_eq!(recovered.len(), 5);
    }
}