vyre_runtime/
scheduler.rs1use std::ops::Range;
7use std::sync::atomic::{AtomicUsize, Ordering};
8use std::sync::Arc;
9use vyre_driver::{BackendError, VyreBackend};
10
/// One contiguous slice of a workload together with the backend it is assigned to.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Shard {
    /// Identifier of the backend this shard is assigned to, as reported by
    /// that backend's `id()`.
    pub backend_id: &'static str,
    /// Half-open range of work-item indices covered by this shard.
    pub work_range: Range<usize>,
}
19
20pub struct WorkStealingScheduler {
22 backends: Vec<Arc<dyn VyreBackend>>,
23 work_index: AtomicUsize,
30}
31
32impl WorkStealingScheduler {
33 pub fn new(backends: Vec<Arc<dyn VyreBackend>>) -> Self {
35 Self {
36 backends,
37 work_index: AtomicUsize::new(0),
38 }
39 }
40
41 pub fn partition(&self, total_len: usize) -> Vec<Shard> {
43 match self.try_partition(total_len) {
44 Ok(shards) => shards,
45 Err(_error) => Vec::new(),
46 }
47 }
48
49 pub fn try_partition(&self, total_len: usize) -> Result<Vec<Shard>, BackendError> {
52 let mut shards = Vec::new();
53 self.try_partition_into(total_len, &mut shards)?;
54 Ok(shards)
55 }
56
57 #[must_use]
74 pub fn claim_next_unit(&self) -> usize {
75 self.work_index.fetch_add(1, Ordering::AcqRel)
76 }
77
78 pub fn reset_unit_cursor(&self) {
81 self.work_index.store(0, Ordering::Release);
82 }
83
84 pub fn partition_into(&self, total_len: usize, out: &mut Vec<Shard>) {
89 if self.try_partition_into(total_len, out).is_err() {
90 out.clear();
91 }
92 }
93
94 pub fn try_partition_into(
97 &self,
98 total_len: usize,
99 out: &mut Vec<Shard>,
100 ) -> Result<(), BackendError> {
101 let n = self.backends.len();
102 out.clear();
103 if n == 0 || total_len == 0 {
104 return Ok(());
105 }
106 let work_unit_size = partition_work_unit_size(total_len, n);
107 let num_units = total_len.div_ceil(work_unit_size);
108 vyre_foundation::allocation::try_reserve_vec_to_capacity(out, num_units).map_err(
109 |error| BackendError::InvalidProgram {
110 fix: format!(
111 "Fix: scheduler could not reserve {num_units} GPU work shard(s): {error}. Shard the workload before work-stealing partitioning."
112 ),
113 },
114 )?;
115 let mut start = 0;
116 for i in 0..num_units {
117 let end = (start + work_unit_size).min(total_len);
118 out.push(Shard {
119 backend_id: self.backends[i % n].id(),
120 work_range: start..end,
121 });
122 start = end;
123 }
124 Ok(())
125 }
126}
127
/// Returns the number of work items per unit for a workload of `total_len`
/// items spread over `backend_count` backends.
///
/// The workload is divided by `4 * backend_count` (saturating on overflow),
/// rounded down, and never allowed to drop below `1`. Degenerate inputs
/// (`total_len == 0` or `backend_count == 0`) also yield `1`.
fn partition_work_unit_size(total_len: usize, backend_count: usize) -> usize {
    match (total_len, backend_count) {
        (0, _) | (_, 0) => 1,
        (len, backends) => {
            // `backends >= 1`, so the divisor is at least 4 and never zero.
            let divisor = backends.saturating_mul(4);
            (len / divisor).max(1)
        }
    }
}
135
/// Test-only helper that computes the work ranges for `total_len` items over
/// `backend_count` backends, without needing real backends.
///
/// Returns consecutive, non-empty, half-open ranges that together cover
/// `0..total_len`; each spans `partition_work_unit_size(total_len, backend_count)`
/// items except possibly the last. Returns an empty vector when either
/// argument is zero.
#[cfg(test)]
fn partition_ranges(total_len: usize, backend_count: usize) -> Vec<Range<usize>> {
    if backend_count == 0 || total_len == 0 {
        return Vec::new();
    }
    let work_unit_size = partition_work_unit_size(total_len, backend_count);
    let num_units = total_len.div_ceil(work_unit_size);
    let mut ranges = Vec::with_capacity(num_units);
    let mut start = 0usize;
    for _ in 0..num_units {
        // Saturate before clamping so the final unit cannot overflow when
        // `total_len` is close to `usize::MAX`.
        let end = start.saturating_add(work_unit_size).min(total_len);
        ranges.push(start..end);
        start = end;
    }
    ranges
}
152
#[cfg(test)]
mod tests {
    use super::{partition_ranges, WorkStealingScheduler};
    use std::sync::Arc;
    use vyre_driver::backend::{DispatchConfig, VyreBackend};
    use vyre_foundation::ir::Program;

    /// Minimal backend double: reports a fixed id and produces no outputs.
    struct TestBackend(&'static str);

    impl vyre_driver::backend::private::Sealed for TestBackend {}

    impl VyreBackend for TestBackend {
        fn id(&self) -> &'static str {
            self.0
        }

        fn dispatch(
            &self,
            _program: &Program,
            _inputs: &[Vec<u8>],
            _config: &DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, vyre_driver::BackendError> {
            Ok(Vec::new())
        }
    }

    /// Ten items over three backends: the unit size bottoms out at one item.
    #[test]
    fn partition_ranges_produces_fine_grained_units() {
        let single_item_units: Vec<std::ops::Range<usize>> =
            (0..10).map(|item| item..item + 1).collect();

        let actual = partition_ranges(10, 3);

        assert_eq!(actual.len(), 10);
        assert_eq!(actual, single_item_units);
    }

    /// More backends than items must not produce zero-length ranges.
    #[test]
    fn partition_ranges_never_emits_empty_shards() {
        let actual = partition_ranges(2, 8);
        assert_eq!(actual, vec![0..1, 1..2]);
    }

    /// `backend_count * 4` overflowing must neither panic nor change the result.
    #[test]
    fn partition_ranges_uses_overflow_safe_work_unit_math() {
        let actual = partition_ranges(2, usize::MAX);
        assert_eq!(actual[0], 0..1);
        assert_eq!(actual[1], 1..2);

        let unit_size = super::partition_work_unit_size(2, usize::MAX);
        assert_eq!(
            unit_size, 1,
            "backend_count * 4 overflow must not panic or enlarge the work unit"
        );
    }

    /// Re-partitioning into the same vector keeps its allocation, assigns
    /// backends round-robin, and leaves the unit cursor untouched.
    #[test]
    fn scheduler_partition_into_reuses_output_storage() {
        let backends: Vec<Arc<dyn VyreBackend>> = vec![
            Arc::new(TestBackend("a")),
            Arc::new(TestBackend("b")),
            Arc::new(TestBackend("c")),
        ];
        let scheduler = WorkStealingScheduler::new(backends);
        let mut shards = Vec::with_capacity(10);

        scheduler.partition_into(10, &mut shards);
        let first_pass_ptr = shards.as_ptr();
        scheduler.partition_into(10, &mut shards);

        assert_eq!(shards.as_ptr(), first_pass_ptr);
        assert_eq!(shards.len(), 10);
        for (index, expected_backend) in [(0usize, "a"), (1, "b"), (9, "a")] {
            assert_eq!(shards[index].backend_id, expected_backend);
            assert_eq!(shards[index].work_range, index..index + 1);
        }

        let cursor = scheduler
            .work_index
            .load(std::sync::atomic::Ordering::Relaxed);
        assert_eq!(cursor, 0);
    }
}