// stv_rs/parallelism/range.rs

1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use log::debug;
16#[cfg(feature = "log_parallelism")]
17use log::{info, trace};
18#[cfg(feature = "log_parallelism")]
19use std::ops::AddAssign;
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::sync::Arc;
22#[cfg(feature = "log_parallelism")]
23use std::sync::Mutex;
24
/// A factory for handing out ranges of items to various threads.
pub trait RangeFactory {
    /// Per-thread range type produced by [`RangeFactory::range`].
    type Rn: Range;
    /// Companion orchestrator type produced by [`RangeFactory::orchestrator`].
    type Orchestrator: RangeOrchestrator;

    /// Creates a new factory for a range with the given number of elements
    /// split across the given number of threads.
    fn new(num_elements: usize, num_threads: usize) -> Self;

    /// Returns the orchestrator object for all the ranges created by this
    /// factory. This consumes the factory, so fetch all per-thread ranges
    /// with [`RangeFactory::range`] first.
    fn orchestrator(self) -> Self::Orchestrator;

    /// Returns the range for the given thread.
    fn range(&self, thread_id: usize) -> Self::Rn;
}
41
/// An orchestrator for the ranges given to all the threads.
pub trait RangeOrchestrator {
    /// Resets all the ranges to prepare a new computation round.
    fn reset_ranges(&self);

    /// Hook to display various debugging statistics. The default
    /// implementation does nothing; only available with the
    /// `log_parallelism` feature.
    #[cfg(feature = "log_parallelism")]
    fn print_statistics(&self) {}
}
51
/// A range of items similar to [`std::ops::Range`], but that can steal from or
/// be stolen by other threads.
pub trait Range {
    /// Iterator type yielding the indices of the items in this range.
    type Iter: Iterator<Item = usize>;

    /// Returns an iterator over the items in this range. The items can be
    /// dynamically stolen from/by other threads, but the iterator provides
    /// a safe abstraction over that.
    fn iter(&self) -> Self::Iter;
}
62
/// A factory that hands out a fixed range to each thread, without any
/// stealing: each thread iterates over a static, contiguous slice of the
/// elements.
pub struct FixedRangeFactory {
    /// Total number of elements to iterate over.
    num_elements: usize,
    /// Number of threads that iterate.
    num_threads: usize,
}
70
71impl RangeFactory for FixedRangeFactory {
72    type Rn = FixedRange;
73    type Orchestrator = FixedRangeOrchestrator;
74
75    fn new(num_elements: usize, num_threads: usize) -> Self {
76        Self {
77            num_elements,
78            num_threads,
79        }
80    }
81
82    fn orchestrator(self) -> FixedRangeOrchestrator {
83        FixedRangeOrchestrator {}
84    }
85
86    fn range(&self, thread_id: usize) -> FixedRange {
87        let start = (thread_id * self.num_elements) / self.num_threads;
88        let end = ((thread_id + 1) * self.num_elements) / self.num_threads;
89        FixedRange(start..end)
90    }
91}
92
/// An orchestrator for the [`FixedRangeFactory`].
///
/// Fixed ranges are static: each thread owns an immutable slice, so there is
/// no per-round state to coordinate.
pub struct FixedRangeOrchestrator {}

impl RangeOrchestrator for FixedRangeOrchestrator {
    fn reset_ranges(&self) {
        // Nothing to do: fixed ranges are defined once by the factory and
        // never mutated, so each round re-iterates the same slice.
    }
}
101
102/// A fixed range.
103#[derive(Debug, PartialEq, Eq)]
104pub struct FixedRange(std::ops::Range<usize>);
105
106impl Range for FixedRange {
107    type Iter = std::ops::Range<usize>;
108
109    fn iter(&self) -> Self::Iter {
110        self.0.clone()
111    }
112}
113
/// A factory for ranges that implement work stealing among threads.
///
/// Whenever a thread finishes processing its range, it looks for another range
/// to steal from. It then divides that range into two and steals a half, to
/// continue processing items.
pub struct WorkStealingRangeFactory {
    /// Total number of elements to iterate over.
    num_elements: usize,
    /// Handle to the ranges of all the threads: one atomic slot per thread,
    /// shared between the factory, the orchestrator and every range.
    ranges: Arc<Vec<AtomicRange>>,
    /// Handle to the work-stealing statistics, aggregated across threads.
    #[cfg(feature = "log_parallelism")]
    stats: Arc<Mutex<WorkStealingStats>>,
}
128
129impl RangeFactory for WorkStealingRangeFactory {
130    type Rn = WorkStealingRange;
131    type Orchestrator = WorkStealingRangeOrchestrator;
132
133    fn new(num_elements: usize, num_threads: usize) -> Self {
134        Self {
135            num_elements,
136            ranges: Arc::new((0..num_threads).map(|_| AtomicRange::default()).collect()),
137            #[cfg(feature = "log_parallelism")]
138            stats: Arc::new(Mutex::new(WorkStealingStats::default())),
139        }
140    }
141
142    fn orchestrator(self) -> WorkStealingRangeOrchestrator {
143        WorkStealingRangeOrchestrator {
144            num_elements: self.num_elements,
145            ranges: self.ranges,
146            #[cfg(feature = "log_parallelism")]
147            stats: self.stats,
148        }
149    }
150
151    fn range(&self, thread_id: usize) -> WorkStealingRange {
152        WorkStealingRange {
153            id: thread_id,
154            ranges: self.ranges.clone(),
155            #[cfg(feature = "log_parallelism")]
156            stats: self.stats.clone(),
157        }
158    }
159}
160
/// An orchestrator for the [`WorkStealingRangeFactory`].
pub struct WorkStealingRangeOrchestrator {
    /// Total number of elements to iterate over.
    num_elements: usize,
    /// Handle to the ranges of all the threads, shared with every
    /// [`WorkStealingRange`] handed out by the factory.
    ranges: Arc<Vec<AtomicRange>>,
    /// Handle to the work-stealing statistics, aggregated across threads.
    #[cfg(feature = "log_parallelism")]
    stats: Arc<Mutex<WorkStealingStats>>,
}
171
172impl RangeOrchestrator for WorkStealingRangeOrchestrator {
173    fn reset_ranges(&self) {
174        debug!("Resetting ranges.");
175        let num_threads = self.ranges.len();
176        for (i, range) in self.ranges.iter().enumerate() {
177            let start = (i * self.num_elements) / num_threads;
178            let end = ((i + 1) * self.num_elements) / num_threads;
179            range.store(PackedRange::new(start as u32, end as u32));
180        }
181    }
182
183    #[cfg(feature = "log_parallelism")]
184    fn print_statistics(&self) {
185        let stats = self.stats.lock().unwrap();
186        info!("Work-stealing statistics:");
187        info!("- increments: {}", stats.increments);
188        info!("- failed_increments: {}", stats.failed_increments);
189        info!("- other_loads: {}", stats.other_loads);
190        info!("- thefts: {}", stats.thefts);
191        info!("- failed_thefts: {}", stats.failed_thefts);
192        info!("- increments + thefts: {}", stats.increments + stats.thefts);
193    }
194}
195
/// A range that implements work stealing.
pub struct WorkStealingRange {
    /// Index of the thread that owns this range, i.e. its slot in `ranges`.
    id: usize,
    /// Handle to the ranges of all the threads.
    ranges: Arc<Vec<AtomicRange>>,
    /// Handle to the global work-stealing statistics.
    #[cfg(feature = "log_parallelism")]
    stats: Arc<Mutex<WorkStealingStats>>,
}
206
207impl Range for WorkStealingRange {
208    type Iter = WorkStealingRangeIterator;
209
210    fn iter(&self) -> Self::Iter {
211        WorkStealingRangeIterator {
212            id: self.id,
213            ranges: self.ranges.clone(),
214            #[cfg(feature = "log_parallelism")]
215            stats: WorkStealingStats::default(),
216            #[cfg(feature = "log_parallelism")]
217            global_stats: self.stats.clone(),
218        }
219    }
220}
221
/// A [start, end) pair that can atomically be modified.
///
/// Aligned to 64 bytes so that each slot sits on its own cache line on common
/// architectures, reducing false sharing between threads — NOTE(review):
/// cache-line size is target-dependent, confirm for exotic targets.
#[repr(align(64))]
struct AtomicRange(AtomicU64);
225
226impl Default for AtomicRange {
227    #[inline(always)]
228    fn default() -> Self {
229        AtomicRange::new(PackedRange::default())
230    }
231}
232
233impl AtomicRange {
234    /// Creates a new atomic range.
235    #[inline(always)]
236    fn new(range: PackedRange) -> Self {
237        AtomicRange(AtomicU64::new(range.0))
238    }
239
240    /// Atomically loads the range.
241    #[inline(always)]
242    fn load(&self) -> PackedRange {
243        PackedRange(self.0.load(Ordering::SeqCst))
244    }
245
246    /// Atomically stores the range.
247    #[inline(always)]
248    fn store(&self, range: PackedRange) {
249        self.0.store(range.0, Ordering::SeqCst)
250    }
251
252    /// Atomically compares and exchanges the range. In case of failure, the
253    /// range contained in the atomic variable is returned.
254    #[inline(always)]
255    fn compare_exchange(&self, before: PackedRange, after: PackedRange) -> Result<(), PackedRange> {
256        match self
257            .0
258            .compare_exchange(before.0, after.0, Ordering::SeqCst, Ordering::SeqCst)
259        {
260            Ok(_) => Ok(()),
261            Err(e) => Err(PackedRange(e)),
262        }
263    }
264}
265
/// A [start, end) range that fits into a `u64`, and can therefore be
/// loaded/stored atomically. The start lives in the low 32 bits, the end in
/// the high 32 bits.
#[derive(Clone, Copy, Default)]
struct PackedRange(u64);

impl PackedRange {
    /// Creates a range with the given [start, end) pair.
    #[inline(always)]
    fn new(start: u32, end: u32) -> Self {
        Self((start as u64) | ((end as u64) << 32))
    }

    /// Reads the start of the range (inclusive), from the low 32 bits.
    #[inline(always)]
    fn start(self) -> u32 {
        self.0 as u32
    }

    /// Reads the end of the range (exclusive), from the high 32 bits.
    #[inline(always)]
    fn end(self) -> u32 {
        (self.0 >> 32) as u32
    }

    /// Reads the length of the range.
    #[inline(always)]
    fn len(self) -> u32 {
        self.end() - self.start()
    }

    /// Increments the start of the range, returning the consumed element
    /// together with the shrunk range.
    ///
    /// # Panics
    ///
    /// Panics if the range is empty. Given that precondition, `start < end <=
    /// u32::MAX` holds, so `start + 1` cannot overflow.
    #[inline(always)]
    fn increment_start(self) -> (u32, Self) {
        assert!(self.start() < self.end());
        (self.start(), PackedRange::new(self.start() + 1, self.end()))
    }

    /// Splits the range into two halves. If the input range is non-empty, the
    /// second half is guaranteed to be non-empty.
    #[inline(always)]
    fn split(self) -> (Self, Self) {
        let start = self.start();
        let end = self.end();
        // Written as `start + len / 2` rather than `(start + end) / 2`: the
        // two are mathematically equal for integers, but the latter overflows
        // u32 when start + end > u32::MAX.
        let middle = start + (end - start) / 2;
        (
            PackedRange::new(start, middle),
            PackedRange::new(middle, end),
        )
    }

    /// Checks if the range is empty.
    #[inline(always)]
    fn is_empty(self) -> bool {
        self.start() == self.end()
    }
}
324
/// Per-thread counters describing work-stealing activity, used for debugging
/// and tuning. Only compiled with the `log_parallelism` feature.
#[cfg(feature = "log_parallelism")]
#[derive(Default)]
pub struct WorkStealingStats {
    /// Number of times this thread successfully incremented its range.
    increments: u64,
    /// Number of times this thread failed to increment its range, because
    /// another thread stole it.
    failed_increments: u64,
    /// Number of times this thread loaded the range of another thread
    /// (excluding compare-exchanges).
    other_loads: u64,
    /// Number of times this thread has stolen a range from another thread.
    thefts: u64,
    /// Number of times this thread failed to steal a range because another
    /// thread modified it in the meantime.
    failed_thefts: u64,
}
342
#[cfg(feature = "log_parallelism")]
impl AddAssign<&WorkStealingStats> for WorkStealingStats {
    /// Accumulates `other`'s counters into `self`, field by field.
    fn add_assign(&mut self, other: &WorkStealingStats) {
        // Exhaustive destructuring: adding a new counter to the struct
        // without summing it here becomes a compile error instead of the
        // counter being silently dropped from the aggregate.
        let WorkStealingStats {
            increments,
            failed_increments,
            other_loads,
            thefts,
            failed_thefts,
        } = other;
        self.increments += increments;
        self.failed_increments += failed_increments;
        self.other_loads += other_loads;
        self.thefts += thefts;
        self.failed_thefts += failed_thefts;
    }
}
353
/// An iterator for the [`WorkStealingRange`].
pub struct WorkStealingRangeIterator {
    /// Index of the thread that owns this range.
    id: usize,
    /// Handle to the ranges of all the threads.
    ranges: Arc<Vec<AtomicRange>>,
    /// Local work-stealing statistics, accumulated without locking and merged
    /// into `global_stats` when the iterator runs out of work.
    #[cfg(feature = "log_parallelism")]
    stats: WorkStealingStats,
    /// Handle to the global work-stealing statistics.
    #[cfg(feature = "log_parallelism")]
    global_stats: Arc<Mutex<WorkStealingStats>>,
}
367
impl Iterator for WorkStealingRangeIterator {
    type Item = usize;

    /// Yields the next item index. First tries to advance this thread's own
    /// range via compare-exchange; once that range is empty, attempts to
    /// steal half of the largest range held by another thread. Returns `None`
    /// only when no thread has any work left to take.
    fn next(&mut self) -> Option<usize> {
        let my_atomic_range: &AtomicRange = &self.ranges[self.id];
        let mut my_range: PackedRange = my_atomic_range.load();
        loop {
            if !my_range.is_empty() {
                // Fast path: claim the first item of our own range. The CAS
                // can fail if another thread stole from this slot since we
                // loaded it; in that case we retry with the updated value.
                let (taken, my_new_range) = my_range.increment_start();
                match my_atomic_range.compare_exchange(my_range, my_new_range) {
                    Ok(()) => {
                        #[cfg(feature = "log_parallelism")]
                        {
                            self.stats.increments += 1;
                            trace!(
                                "[thread {}] Incremented range to {}..{}.",
                                self.id,
                                my_new_range.start(),
                                my_new_range.end()
                            );
                        }
                        return Some(taken as usize);
                    }
                    Err(range) => {
                        // CAS failure returns the current slot contents, so
                        // the retry doesn't need another load.
                        my_range = range;
                        #[cfg(feature = "log_parallelism")]
                        {
                            self.stats.failed_increments += 1;
                            debug!(
                                "[thread {}] Failed to increment range, new range is {}..{}.",
                                self.id,
                                range.start(),
                                range.end()
                            );
                        }
                        continue;
                    }
                }
            } else {
                // Slow path: our own range is exhausted, scan the other
                // threads for work to steal.
                #[cfg(feature = "log_parallelism")]
                debug!(
                    "[thread {}] Range {}..{} is empty, scanning other threads.",
                    self.id,
                    my_range.start(),
                    my_range.end()
                );
                let range_count = self.ranges.len();

                #[cfg(feature = "log_parallelism")]
                {
                    self.stats.other_loads += range_count as u64 - 1;
                }
                // Snapshot every other thread's range. Our own slot keeps the
                // default (empty) value and is skipped below. NOTE(review):
                // this allocates a fresh Vec on every steal attempt —
                // presumably acceptable because this path is only reached
                // when the local range is exhausted; confirm if profiling
                // shows otherwise.
                let mut other_ranges = vec![PackedRange::default(); range_count];
                for (i, range) in other_ranges.iter_mut().enumerate() {
                    if i == self.id {
                        continue;
                    }
                    *range = self.ranges[i].load();
                }

                // Pick the victim with the longest snapshotted range; an
                // all-empty snapshot leaves max_range empty and we give up.
                let mut max_index = 0;
                let mut max_range = PackedRange::default();
                for (i, range) in other_ranges.iter().enumerate() {
                    if i == self.id {
                        continue;
                    }
                    if range.len() > max_range.len() {
                        max_index = i;
                        max_range = *range;
                    }
                }

                while !max_range.is_empty() {
                    // Steal some work: leave the first half with the victim
                    // and take the second half (guaranteed non-empty by
                    // split()) if the victim's slot is unchanged.
                    let (remaining, stolen) = max_range.split();
                    match self.ranges[max_index].compare_exchange(max_range, remaining) {
                        Ok(()) => {
                            // Consume one item from the stolen half and
                            // publish the rest as our own range.
                            let (taken, my_new_range) = stolen.increment_start();
                            my_atomic_range.store(my_new_range);
                            #[cfg(feature = "log_parallelism")]
                            {
                                self.stats.thefts += 1;
                            }
                            return Some(taken as usize);
                        }
                        Err(range) => {
                            // The victim's range changed under us; refresh
                            // the snapshot entry and try the next-best victim.
                            other_ranges[max_index] = range;
                            #[cfg(feature = "log_parallelism")]
                            {
                                self.stats.failed_thefts += 1;
                            }
                            // Re-compute max_index.
                            max_range = range;
                            for (i, range) in other_ranges.iter().enumerate() {
                                if i == self.id {
                                    continue;
                                }
                                if range.len() > max_range.len() {
                                    max_index = i;
                                    max_range = *range;
                                }
                            }
                        }
                    }
                }

                #[cfg(feature = "log_parallelism")]
                {
                    debug!("[thread {}] Didn't find anything to steal", self.id);
                    // NOTE(review): local stats are merged here but not
                    // cleared, so calling next() again after exhaustion would
                    // merge them a second time — confirm iterators are never
                    // resumed after returning None.
                    *self.global_stats.lock().unwrap() += &self.stats;
                }
                // Didn't manage to steal anything.
                return None;
            }
        }
    }
}
485
#[cfg(test)]
mod test {
    use super::*;

    /// The fixed factory must tile [0, num_elements) with contiguous slices
    /// whose sizes differ by at most one, even when the division isn't exact.
    #[test]
    fn test_fixed_range_factory_splits_evenly() {
        let factory = FixedRangeFactory::new(100, 4);
        assert_eq!(factory.range(0), FixedRange(0..25));
        assert_eq!(factory.range(1), FixedRange(25..50));
        assert_eq!(factory.range(2), FixedRange(50..75));
        assert_eq!(factory.range(3), FixedRange(75..100));

        let factory = FixedRangeFactory::new(100, 7);
        assert_eq!(factory.range(0), FixedRange(0..14));
        assert_eq!(factory.range(1), FixedRange(14..28));
        assert_eq!(factory.range(2), FixedRange(28..42));
        assert_eq!(factory.range(3), FixedRange(42..57));
        assert_eq!(factory.range(4), FixedRange(57..71));
        assert_eq!(factory.range(5), FixedRange(71..85));
        assert_eq!(factory.range(6), FixedRange(85..100));
    }

    /// End-to-end check of fixed ranges across threads: each thread must see
    /// exactly its own slice, in order, on every round.
    #[test]
    fn test_fixed_range() {
        let factory = FixedRangeFactory::new(100, 4);
        let ranges: [_; 4] = std::array::from_fn(|i| factory.range(i));
        let orchestrator = factory.orchestrator();

        std::thread::scope(|s| {
            // Repeat to exercise reset_ranges() between rounds.
            for _ in 0..10 {
                orchestrator.reset_ranges();
                let handles = ranges
                    .each_ref()
                    .map(|range| s.spawn(move || range.iter().collect::<Vec<_>>()));
                let values: [Vec<usize>; 4] = handles.map(|handle| handle.join().unwrap());

                // The fixed range implementation always yields the same items in order.
                for (i, set) in values.iter().enumerate() {
                    assert_eq!(*set, (i * 25..(i + 1) * 25).collect::<Vec<_>>());
                }
            }
        });
    }

    /// End-to-end check of work stealing across threads: the per-thread item
    /// sets vary from run to run, but must always partition the whole range.
    #[test]
    fn test_work_stealing_range() {
        const NUM_THREADS: usize = 4;
        const NUM_ELEMENTS: usize = 10000;

        let factory = WorkStealingRangeFactory::new(NUM_ELEMENTS, NUM_THREADS);
        let ranges: [_; NUM_THREADS] = std::array::from_fn(|i| factory.range(i));
        let orchestrator = factory.orchestrator();

        std::thread::scope(|s| {
            for _ in 0..10 {
                orchestrator.reset_ranges();
                let handles = ranges
                    .each_ref()
                    .map(|range| s.spawn(move || range.iter().collect::<Vec<_>>()));
                let values: [Vec<usize>; NUM_THREADS] =
                    handles.map(|handle| handle.join().unwrap());

                // This checks that:
                // - all ranges yield disjoint elements,
                // - each range never yields the same element twice.
                let mut all_values = vec![false; NUM_ELEMENTS];
                for set in values {
                    println!("Values: {set:?}");
                    for x in set {
                        assert!(!all_values[x]);
                        all_values[x] = true;
                    }
                }
                // Check that the whole range is covered.
                assert!(all_values.iter().all(|x| *x));
            }
        });
    }

    #[test]
    fn test_default_packed_range_is_empty() {
        let range = PackedRange::default();
        assert!(range.is_empty());
        assert_eq!(range.start(), 0);
        assert_eq!(range.end(), 0);
    }

    /// Round-trips [start, end) pairs through the packed u64 representation.
    #[test]
    fn test_packed_range_is_consistent() {
        for i in 0..30 {
            for j in i..30 {
                let range = PackedRange::new(i, j);
                assert_eq!(range.start(), i);
                assert_eq!(range.end(), j);
            }
        }
    }

    #[test]
    fn test_packed_range_increment_start() {
        let mut range = PackedRange::new(0, 10);

        for i in 1..=10 {
            let (j, new_range) = range.increment_start();
            range = new_range;
            assert_eq!(j, i - 1);
            assert_eq!((range.start(), range.end()), (i, 10));
        }
    }

    /// Splitting an empty or single-element range must keep the (possibly
    /// empty) left half first and the non-empty half second.
    #[test]
    fn test_packed_range_split() {
        let (left, right) = PackedRange::new(0, 0).split();
        assert!(left.is_empty());
        assert_eq!((left.start(), left.end()), (0, 0));
        assert!(right.is_empty());
        assert_eq!((right.start(), right.end()), (0, 0));

        let (left, right) = PackedRange::new(0, 1).split();
        assert!(left.is_empty());
        assert_eq!((left.start(), left.end()), (0, 0));
        assert!(!right.is_empty());
        assert_eq!((right.start(), right.end()), (0, 1));
    }

    /// The two halves of a split must exactly tile the input range.
    #[test]
    fn test_packed_range_split_is_exhaustive() {
        for i in 0..100 {
            for j in i..100 {
                let (left, right) = PackedRange::new(i, j).split();
                assert!(left.start() <= left.end());
                assert!(right.start() <= right.end());
                assert_eq!(left.start(), i);
                assert_eq!(left.end(), right.start());
                assert_eq!(right.end(), j);
            }
        }
    }

    /// The halves must be balanced (sizes differ by at most one) and the
    /// second half non-empty whenever the input is non-empty.
    #[test]
    fn test_packed_range_split_is_fair() {
        for i in 0..100 {
            for j in i..100 {
                let (left, right) = PackedRange::new(i, j).split();
                assert!(left.end() - left.start() <= right.end() - right.start());
                assert!(right.end() - right.start() <= left.end() - left.start() + 1);
                if i != j {
                    assert!(!right.is_empty());
                }
            }
        }
    }
}