1#[cfg(not(feature = "std"))]
8use alloc::vec::Vec;
9
10use scirs2_core::random::{Random, Rng};
12
13use super::core::{Sampler, SamplerIterator};
14
15#[derive(Debug, Clone)]
35pub struct DistributedWrapper<S: Sampler> {
36 sampler: S,
37 num_replicas: usize,
38 rank: usize,
39 shuffle: bool,
40 generator: Option<u64>,
41}
42
43impl<S: Sampler> DistributedWrapper<S> {
44 pub fn new(sampler: S, num_replicas: usize, rank: usize) -> Self {
56 assert!(num_replicas > 0, "Number of replicas must be positive");
57 assert!(rank < num_replicas, "Rank must be less than num_replicas");
58
59 Self {
60 sampler,
61 num_replicas,
62 rank,
63 shuffle: true,
64 generator: None,
65 }
66 }
67
68 pub fn with_shuffle(mut self, shuffle: bool) -> Self {
73 self.shuffle = shuffle;
74 self
75 }
76
77 pub fn with_generator(mut self, seed: u64) -> Self {
83 self.generator = Some(seed);
84 self
85 }
86
87 pub fn num_replicas(&self) -> usize {
89 self.num_replicas
90 }
91
92 pub fn rank(&self) -> usize {
94 self.rank
95 }
96
97 pub fn shuffle(&self) -> bool {
99 self.shuffle
100 }
101
102 pub fn generator(&self) -> Option<u64> {
104 self.generator
105 }
106
107 pub fn sampler(&self) -> &S {
109 &self.sampler
110 }
111
112 pub fn into_sampler(self) -> S {
114 self.sampler
115 }
116
117 fn calculate_num_samples(&self) -> usize {
119 let total_samples = self.sampler.len();
120 let base_samples = total_samples / self.num_replicas;
123 let extra_samples = total_samples % self.num_replicas;
124
125 if self.rank < extra_samples {
126 base_samples + 1
127 } else {
128 base_samples
129 }
130 }
131}
132
133impl<S: Sampler> Sampler for DistributedWrapper<S> {
134 type Iter = SamplerIterator;
135
136 fn iter(&self) -> Self::Iter {
137 let mut all_indices: Vec<usize> = self.sampler.iter().collect();
139
140 if self.shuffle {
142 let mut rng = match self.generator {
144 Some(seed) => Random::seed(seed),
145 None => Random::seed(42),
146 };
147
148 for i in (1..all_indices.len()).rev() {
150 let j = rng.gen_range(0..=i);
151 all_indices.swap(i, j);
152 }
153 }
154
155 let replica_indices: Vec<usize> = all_indices
157 .into_iter()
158 .enumerate()
159 .filter_map(|(i, idx)| {
160 if i % self.num_replicas == self.rank {
161 Some(idx)
162 } else {
163 None
164 }
165 })
166 .collect();
167
168 SamplerIterator::new(replica_indices)
169 }
170
171 fn len(&self) -> usize {
172 self.calculate_num_samples()
173 }
174}
175
/// Stand-alone sampler that hands each replica a contiguous, equally sized
/// slice of the index range `0..dataset_size`.
///
/// When the dataset does not divide evenly among the replicas, the index
/// list is either padded by wrapping around to the start (the default) or
/// cut down to the largest evenly divisible length (`drop_last`).
#[derive(Debug, Clone)]
pub struct DistributedSampler {
    /// Number of samples in the dataset (always > 0).
    dataset_size: usize,
    /// Number of replicas sharing the dataset (always > 0).
    num_replicas: usize,
    /// Zero-based position of this replica (always < `num_replicas`).
    rank: usize,
    /// Whether the index list is shuffled before it is sliced.
    shuffle: bool,
    /// Seed for the shuffle; a fixed fallback seed is used when `None`.
    generator: Option<u64>,
    /// Truncate instead of pad when the dataset does not divide evenly.
    drop_last: bool,
}
202
203impl DistributedSampler {
204 pub fn new(dataset_size: usize, num_replicas: usize, rank: usize, shuffle: bool) -> Self {
217 assert!(dataset_size > 0, "Dataset size must be positive");
218 assert!(num_replicas > 0, "Number of replicas must be positive");
219 assert!(rank < num_replicas, "Rank must be less than num_replicas");
220
221 Self {
222 dataset_size,
223 num_replicas,
224 rank,
225 shuffle,
226 generator: None,
227 drop_last: false,
228 }
229 }
230
231 pub fn with_generator(mut self, seed: u64) -> Self {
233 self.generator = Some(seed);
234 self
235 }
236
237 pub fn with_drop_last(mut self, drop_last: bool) -> Self {
239 self.drop_last = drop_last;
240 self
241 }
242
243 pub fn dataset_size(&self) -> usize {
245 self.dataset_size
246 }
247
248 pub fn num_replicas(&self) -> usize {
250 self.num_replicas
251 }
252
253 pub fn rank(&self) -> usize {
255 self.rank
256 }
257
258 pub fn shuffle(&self) -> bool {
260 self.shuffle
261 }
262
263 pub fn drop_last(&self) -> bool {
265 self.drop_last
266 }
267
268 pub fn generator(&self) -> Option<u64> {
270 self.generator
271 }
272
273 fn effective_dataset_size(&self) -> usize {
275 if self.drop_last {
276 (self.dataset_size / self.num_replicas) * self.num_replicas
278 } else {
279 let samples_per_replica =
281 (self.dataset_size + self.num_replicas - 1) / self.num_replicas;
282 samples_per_replica * self.num_replicas
283 }
284 }
285
286 fn calculate_num_samples(&self) -> usize {
288 if self.drop_last {
289 self.dataset_size / self.num_replicas
290 } else {
291 (self.dataset_size + self.num_replicas - 1) / self.num_replicas
292 }
293 }
294}
295
296impl Sampler for DistributedSampler {
297 type Iter = SamplerIterator;
298
299 fn iter(&self) -> Self::Iter {
300 let effective_size = self.effective_dataset_size();
301 let samples_per_replica = self.calculate_num_samples();
302
303 let mut indices: Vec<usize> = if self.drop_last {
305 (0..effective_size).collect()
306 } else {
307 (0..effective_size).map(|i| i % self.dataset_size).collect()
309 };
310
311 if self.shuffle {
313 let mut rng = match self.generator {
315 Some(seed) => Random::seed(seed),
316 None => Random::seed(42),
317 };
318
319 for i in (1..indices.len()).rev() {
321 let j = rng.gen_range(0..=i);
322 indices.swap(i, j);
323 }
324 }
325
326 let start_idx = self.rank * samples_per_replica;
328 let end_idx = start_idx + samples_per_replica;
329 let replica_indices = indices[start_idx..end_idx.min(indices.len())].to_vec();
330
331 SamplerIterator::new(replica_indices)
332 }
333
334 fn len(&self) -> usize {
335 self.calculate_num_samples()
336 }
337}
338
339pub fn distributed<S: Sampler>(
349 sampler: S,
350 num_replicas: usize,
351 rank: usize,
352) -> DistributedWrapper<S> {
353 DistributedWrapper::new(sampler, num_replicas, rank)
354}
355
356pub fn distributed_sampler(
367 dataset_size: usize,
368 num_replicas: usize,
369 rank: usize,
370 shuffle: bool,
371) -> DistributedSampler {
372 DistributedSampler::new(dataset_size, num_replicas, rank, shuffle)
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampler::basic::SequentialSampler;

    // ---- DistributedWrapper ------------------------------------------------

    /// Rank 0 of 2, unshuffled, receives every second index starting at 0.
    #[test]
    fn test_distributed_wrapper_basic() {
        let wrapper = DistributedWrapper::new(SequentialSampler::new(10), 2, 0).with_shuffle(false);

        assert_eq!(wrapper.num_replicas(), 2);
        assert_eq!(wrapper.rank(), 0);
        assert!(!wrapper.shuffle());
        assert_eq!(wrapper.len(), 5);

        let picked: Vec<usize> = wrapper.iter().collect();
        assert_eq!(picked, vec![0, 2, 4, 6, 8]);
    }

    /// Rank 1 of 2, unshuffled, receives every second index starting at 1.
    #[test]
    fn test_distributed_wrapper_rank_1() {
        let wrapper = DistributedWrapper::new(SequentialSampler::new(10), 2, 1).with_shuffle(false);

        assert_eq!(wrapper.rank(), 1);
        assert_eq!(wrapper.len(), 5);

        let picked: Vec<usize> = wrapper.iter().collect();
        assert_eq!(picked, vec![1, 3, 5, 7, 9]);
    }

    /// Seven indices over three replicas: the first rank gets the extra one,
    /// and together the replicas cover every index exactly once.
    #[test]
    fn test_distributed_wrapper_uneven_split() {
        let base_sampler = SequentialSampler::new(7);
        let dist0 = DistributedWrapper::new(base_sampler.clone(), 3, 0).with_shuffle(false);
        let dist1 = DistributedWrapper::new(base_sampler.clone(), 3, 1).with_shuffle(false);
        let dist2 = DistributedWrapper::new(base_sampler, 3, 2).with_shuffle(false);

        assert_eq!(dist0.len(), 3);
        assert_eq!(dist1.len(), 2);
        assert_eq!(dist2.len(), 2);

        let indices0: Vec<usize> = dist0.iter().collect();
        let indices1: Vec<usize> = dist1.iter().collect();
        let indices2: Vec<usize> = dist2.iter().collect();

        assert_eq!(indices0, vec![0, 3, 6]);
        assert_eq!(indices1, vec![1, 4]);
        assert_eq!(indices2, vec![2, 5]);

        let mut combined = indices0;
        combined.extend(indices1);
        combined.extend(indices2);
        combined.sort();
        assert_eq!(combined, vec![0, 1, 2, 3, 4, 5, 6]);
    }

    /// A seeded shuffle stays in range and is reproducible.
    #[test]
    fn test_distributed_wrapper_with_shuffle() {
        let first = DistributedWrapper::new(SequentialSampler::new(10), 2, 0)
            .with_shuffle(true)
            .with_generator(42);

        let indices: Vec<usize> = first.iter().collect();
        assert_eq!(indices.len(), 5);
        assert!(indices.iter().all(|&idx| idx < 10));

        let second = DistributedWrapper::new(SequentialSampler::new(10), 2, 0)
            .with_shuffle(true)
            .with_generator(42);
        let indices2: Vec<usize> = second.iter().collect();
        assert_eq!(indices, indices2);
    }

    #[test]
    #[should_panic(expected = "Number of replicas must be positive")]
    fn test_distributed_wrapper_zero_replicas() {
        let base_sampler = SequentialSampler::new(10);
        DistributedWrapper::new(base_sampler, 0, 0);
    }

    #[test]
    #[should_panic(expected = "Rank must be less than num_replicas")]
    fn test_distributed_wrapper_invalid_rank() {
        let base_sampler = SequentialSampler::new(10);
        DistributedWrapper::new(base_sampler, 2, 2);
    }

    // ---- DistributedSampler ------------------------------------------------

    /// An evenly divisible dataset is cut into contiguous chunks.
    #[test]
    fn test_distributed_sampler_basic() {
        let sampler = DistributedSampler::new(12, 3, 1, false);

        assert_eq!(sampler.dataset_size(), 12);
        assert_eq!(sampler.num_replicas(), 3);
        assert_eq!(sampler.rank(), 1);
        assert!(!sampler.shuffle());
        assert!(!sampler.drop_last());
        assert_eq!(sampler.len(), 4);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices, vec![4, 5, 6, 7]);
    }

    /// Ten samples over three replicas are padded to four per replica.
    #[test]
    fn test_distributed_sampler_with_padding() {
        let sampler = DistributedSampler::new(10, 3, 0, false);
        assert_eq!(sampler.len(), 4);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 4);
        assert!(indices.iter().all(|&idx| idx < 10));
    }

    /// With `drop_last`, ten samples over three replicas give three each.
    #[test]
    fn test_distributed_sampler_drop_last() {
        let sampler = DistributedSampler::new(10, 3, 0, false).with_drop_last(true);

        assert_eq!(sampler.len(), 3);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices, vec![0, 1, 2]);
    }

    /// Two samplers built with the same seed yield the same indices.
    #[test]
    fn test_distributed_sampler_shuffle() {
        let sampler = DistributedSampler::new(12, 3, 0, true).with_generator(42);

        let indices: Vec<usize> = sampler.iter().collect();
        assert_eq!(indices.len(), 4);

        let sampler2 = DistributedSampler::new(12, 3, 0, true).with_generator(42);
        let indices2: Vec<usize> = sampler2.iter().collect();
        assert_eq!(indices, indices2);
    }

    // ---- Free functions and edge cases -------------------------------------

    #[test]
    fn test_convenience_functions() {
        let dist_wrapper = distributed(SequentialSampler::new(8), 2, 0);
        assert_eq!(dist_wrapper.len(), 4);

        let dist_sampler = distributed_sampler(8, 2, 1, false);
        assert_eq!(dist_sampler.len(), 4);
    }

    /// A single replica sees everything; more replicas than samples still
    /// yields one valid index per replica thanks to padding.
    #[test]
    fn test_distributed_sampler_edge_cases() {
        let single = DistributedSampler::new(10, 1, 0, false);
        assert_eq!(single.len(), 10);

        let indices: Vec<usize> = single.iter().collect();
        assert_eq!(indices, (0..10).collect::<Vec<_>>());

        let oversubscribed = DistributedSampler::new(2, 5, 3, false);
        assert_eq!(oversubscribed.len(), 1);

        let indices: Vec<usize> = oversubscribed.iter().collect();
        assert_eq!(indices.len(), 1);
        assert!(indices[0] < 2);
    }

    #[test]
    fn test_into_sampler() {
        let wrapper = DistributedWrapper::new(SequentialSampler::new(5), 2, 0);

        let recovered = wrapper.into_sampler();
        assert_eq!(recovered.len(), 5);
    }
}