Skip to main content

vyre_runtime/
scheduler.rs

1//! Multi-GPU work stealing scheduler (Innovation I.7).
2//!
3//! Partitions a large Program or batch of Programs across all
4//! registered physical devices.
5
6use std::ops::Range;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9use vyre_driver::{BackendError, VyreBackend};
10
/// One contiguous slice of a workload, tagged with the backend it is assigned to.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Shard {
    /// Identifier reported by the backend that should process this shard.
    pub backend_id: &'static str,
    /// Half-open `start..end` byte/item range covered by this shard.
    pub work_range: Range<usize>,
}
19
20/// Dynamic work-stealing scheduler.
21pub struct WorkStealingScheduler {
22    backends: Vec<Arc<dyn VyreBackend>>,
23    /// Atomic work index used by dispatch loops to let fast backends
24    /// steal more fine-grained work units. Worker threads call
25    /// [`Self::claim_next_unit`] which atomically increments the index;
26    /// the returned value is the unit index they own. This is the
27    /// work-stealing primitive  -  fast backends pull more units, slow
28    /// backends pull fewer.
29    work_index: AtomicUsize,
30}
31
32impl WorkStealingScheduler {
33    /// Create a scheduler over the live runtime backends available to the process.
34    pub fn new(backends: Vec<Arc<dyn VyreBackend>>) -> Self {
35        Self {
36            backends,
37            work_index: AtomicUsize::new(0),
38        }
39    }
40
41    /// Partition a large haystack across available GPUs.
42    pub fn partition(&self, total_len: usize) -> Vec<Shard> {
43        match self.try_partition(total_len) {
44            Ok(shards) => shards,
45            Err(_error) => Vec::new(),
46        }
47    }
48
49    /// Partition a large haystack across available GPUs with explicit staging
50    /// allocation failure reporting.
51    pub fn try_partition(&self, total_len: usize) -> Result<Vec<Shard>, BackendError> {
52        let mut shards = Vec::new();
53        self.try_partition_into(total_len, &mut shards)?;
54        Ok(shards)
55    }
56
57    /// Atomically claim the next fine-grained work unit. Worker threads
58    /// call this in a loop; the returned value is the unit index they
59    /// own. When the returned index is `>= num_units`, the worker is
60    /// done. This is the work-stealing primitive: fast backends call
61    /// `claim_next_unit` more times in the same wall-clock window.
62    ///
63    /// # Examples
64    ///
65    /// ```
66    /// use vyre_runtime::scheduler::WorkStealingScheduler;
67    /// let scheduler = WorkStealingScheduler::new(Vec::new());
68    /// assert_eq!(scheduler.claim_next_unit(), 0);
69    /// assert_eq!(scheduler.claim_next_unit(), 1);
70    /// scheduler.reset_unit_cursor();
71    /// assert_eq!(scheduler.claim_next_unit(), 0);
72    /// ```
73    #[must_use]
74    pub fn claim_next_unit(&self) -> usize {
75        self.work_index.fetch_add(1, Ordering::AcqRel)
76    }
77
78    /// Reset the work-unit cursor to zero. Call between dispatches that
79    /// reuse the same scheduler.
80    pub fn reset_unit_cursor(&self) {
81        self.work_index.store(0, Ordering::Release);
82    }
83
84    /// Partition a large haystack into many fine-grained work units
85    /// assigned round-robin to backends. A caller-side dispatch loop
86    /// uses [`Self::claim_next_unit`] to let worker threads atomically
87    /// claim units so fast backends steal more work.
88    pub fn partition_into(&self, total_len: usize, out: &mut Vec<Shard>) {
89        if self.try_partition_into(total_len, out).is_err() {
90            out.clear();
91        }
92    }
93
94    /// Partition into caller-owned storage with explicit staging allocation
95    /// failure reporting.
96    pub fn try_partition_into(
97        &self,
98        total_len: usize,
99        out: &mut Vec<Shard>,
100    ) -> Result<(), BackendError> {
101        let n = self.backends.len();
102        out.clear();
103        if n == 0 || total_len == 0 {
104            return Ok(());
105        }
106        let work_unit_size = partition_work_unit_size(total_len, n);
107        let num_units = total_len.div_ceil(work_unit_size);
108        vyre_foundation::allocation::try_reserve_vec_to_capacity(out, num_units).map_err(
109            |error| BackendError::InvalidProgram {
110                fix: format!(
111                    "Fix: scheduler could not reserve {num_units} GPU work shard(s): {error}. Shard the workload before work-stealing partitioning."
112                ),
113            },
114        )?;
115        let mut start = 0;
116        for i in 0..num_units {
117            let end = (start + work_unit_size).min(total_len);
118            out.push(Shard {
119                backend_id: self.backends[i % n].id(),
120                work_range: start..end,
121            });
122            start = end;
123        }
124        Ok(())
125    }
126}
127
/// Pick the size of one fine-grained work unit.
///
/// Aims for four units per backend: the result is
/// `total_len / (backend_count * 4)`, with the multiplication saturating
/// instead of overflowing, and never less than 1. Returns 1 when either
/// argument is zero.
fn partition_work_unit_size(total_len: usize, backend_count: usize) -> usize {
    match (total_len, backend_count) {
        (0, _) | (_, 0) => 1,
        (len, backends) => {
            // `backends >= 1` here, so the divisor is at least 4 and never zero.
            let target_units = backends.saturating_mul(4);
            (len / target_units).max(1)
        }
    }
}
135
/// Test-only helper that computes just the ranges the scheduler would emit
/// for `total_len` items and `backend_count` backends, without needing real
/// backend instances.
///
/// Returns an empty vector when either argument is zero. Otherwise returns
/// contiguous half-open ranges covering `0..total_len`; each has the length
/// given by `partition_work_unit_size` except possibly the last.
#[cfg(test)]
fn partition_ranges(total_len: usize, backend_count: usize) -> Vec<Range<usize>> {
    if backend_count == 0 || total_len == 0 {
        return Vec::new();
    }
    let work_unit_size = partition_work_unit_size(total_len, backend_count);
    (0..total_len)
        .step_by(work_unit_size)
        // Saturate before clamping so the final unit cannot overflow when
        // `total_len` is close to `usize::MAX`.
        .map(|start| start..start.saturating_add(work_unit_size).min(total_len))
        .collect()
}
152
#[cfg(test)]
mod tests {
    use super::{partition_ranges, WorkStealingScheduler};
    use std::sync::Arc;
    use vyre_driver::backend::{DispatchConfig, VyreBackend};
    use vyre_foundation::ir::Program;

    /// Stub backend that only reports a fixed identifier; dispatch returns
    /// no outputs.
    struct TestBackend(&'static str);

    impl vyre_driver::backend::private::Sealed for TestBackend {}

    impl VyreBackend for TestBackend {
        fn id(&self) -> &'static str {
            self.0
        }

        fn dispatch(
            &self,
            _program: &Program,
            _inputs: &[Vec<u8>],
            _config: &DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, vyre_driver::BackendError> {
            Ok(Vec::new())
        }
    }

    /// Ten items over three backends yields ten single-item units.
    #[test]
    fn partition_ranges_produces_fine_grained_units() {
        let actual = partition_ranges(10, 3);
        assert_eq!(actual.len(), 10);
        assert_eq!(
            actual,
            vec![0..1, 1..2, 2..3, 3..4, 4..5, 5..6, 6..7, 7..8, 8..9, 9..10]
        );
    }

    /// More backends than items must not produce zero-length ranges.
    #[test]
    fn partition_ranges_never_emits_empty_shards() {
        assert_eq!(partition_ranges(2, 8), vec![0..1, 1..2]);
    }

    /// A backend count whose `* 4` would overflow still yields unit-sized work.
    #[test]
    fn partition_ranges_uses_overflow_safe_work_unit_math() {
        let actual = partition_ranges(2, usize::MAX);
        assert_eq!(actual[0], 0..1);
        assert_eq!(actual[1], 1..2);
        assert_eq!(
            super::partition_work_unit_size(2, usize::MAX),
            1,
            "backend_count * 4 overflow must not panic or enlarge the work unit"
        );
    }

    /// Re-partitioning into a pre-sized vector keeps the same allocation,
    /// assigns backends round-robin, and does not move the unit cursor.
    #[test]
    fn scheduler_partition_into_reuses_output_storage() {
        let scheduler = WorkStealingScheduler::new(vec![
            Arc::new(TestBackend("a")),
            Arc::new(TestBackend("b")),
            Arc::new(TestBackend("c")),
        ]);
        let mut shards = Vec::with_capacity(10);

        scheduler.partition_into(10, &mut shards);
        let first_allocation = shards.as_ptr();
        scheduler.partition_into(10, &mut shards);

        assert_eq!(shards.as_ptr(), first_allocation);
        assert_eq!(shards.len(), 10);
        for (index, backend_id, range) in [(0usize, "a", 0..1), (1, "b", 1..2), (9, "a", 9..10)] {
            assert_eq!(shards[index].backend_id, backend_id);
            assert_eq!(shards[index].work_range, range);
        }

        let cursor = scheduler
            .work_index
            .load(std::sync::atomic::Ordering::Relaxed);
        assert_eq!(cursor, 0);
    }
}
235}