vyre_runtime/megakernel/planner/
grid.rs1use std::cell::RefCell;
4
5use rustc_hash::FxHashMap;
6use vyre_driver::backend::BackendError;
7
8use super::geometry::MegakernelLaunchGeometry;
9use super::sizing::MegakernelSizingPolicy;
10
/// Adapter limits that bound megakernel grid planning.
///
/// The type is `Copy` and hashable; together with a `MegakernelGridRequest`
/// it forms the key of the grid-plan cache. `validate` rejects a zero in any
/// of the three fields.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MegakernelGridLimits {
    /// Adapter limit on the workgroup size along X. Must be non-zero.
    pub max_workgroup_size_x: u32,
    /// Adapter limit on the number of compute workgroups per dispatch
    /// dimension. Must be non-zero.
    pub max_compute_workgroups_per_dimension: u32,
    /// Adapter limit on the number of compute invocations per workgroup.
    /// Must be non-zero.
    pub max_compute_invocations_per_workgroup: u32,
}
21
/// Maximum number of entries retained in each planner cache map.
///
/// The same cap is applied to both the grid-plan map and the geometry map;
/// once a map grows past it, the entries with the smallest `last_seen` tick
/// are evicted.
const GRID_PLAN_CACHE_CAP: usize = 128;
23
/// Key of the launch-geometry cache.
///
/// Holds exactly the three inputs passed to `cached_geometry_from_slots`, so
/// two calls with equal arguments share one cache entry.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct GeometryCacheKey {
    /// Slot count the geometry was requested for.
    slot_count: u32,
    /// Worker count the geometry was requested for.
    worker_count: u32,
    /// Workgroup-size limit along X supplied with the request.
    max_workgroup_size_x: u32,
}
30
31struct MegakernelPlannerCache {
32 grid_plans:
33 FxHashMap<(MegakernelGridRequest, MegakernelGridLimits), CacheEntry<MegakernelGridPlan>>,
34 geometries: FxHashMap<GeometryCacheKey, CacheEntry<MegakernelLaunchGeometry>>,
35 clock: u64,
36}
37
/// A cached value together with the logical tick at which it was last
/// inserted or read.
struct CacheEntry<T> {
    /// The cached payload.
    value: T,
    /// Clock value recorded on the most recent insert or hit; the entry with
    /// the smallest stamp is the first eviction candidate.
    last_seen: u64,
}
42
43impl MegakernelPlannerCache {
44 fn get_grid_plan(
45 &mut self,
46 key: &(MegakernelGridRequest, MegakernelGridLimits),
47 ) -> Option<MegakernelGridPlan> {
48 self.prepare_cache_hit_tick();
49 let entry = self.grid_plans.get_mut(key)?;
50 self.clock += 1;
51 entry.last_seen = self.clock;
52 Some(entry.value)
53 }
54
55 fn insert_grid_plan(
56 &mut self,
57 key: (MegakernelGridRequest, MegakernelGridLimits),
58 value: MegakernelGridPlan,
59 ) {
60 let tick = self.next_tick();
61 self.grid_plans.insert(
62 key,
63 CacheEntry {
64 value,
65 last_seen: tick,
66 },
67 );
68 self.evict_grid_plans_to_cap();
69 }
70
71 fn get_geometry(&mut self, key: &GeometryCacheKey) -> Option<MegakernelLaunchGeometry> {
72 self.prepare_cache_hit_tick();
73 let entry = self.geometries.get_mut(key)?;
74 self.clock += 1;
75 entry.last_seen = self.clock;
76 Some(entry.value)
77 }
78
79 fn insert_geometry(&mut self, key: GeometryCacheKey, value: MegakernelLaunchGeometry) {
80 let tick = self.next_tick();
81 self.geometries.insert(
82 key,
83 CacheEntry {
84 value,
85 last_seen: tick,
86 },
87 );
88 self.evict_geometries_to_cap();
89 }
90
91 fn evict_grid_plans_to_cap(&mut self) {
92 while self.grid_plans.len() > GRID_PLAN_CACHE_CAP {
93 let Some(evicted) = self
94 .grid_plans
95 .iter()
96 .min_by_key(|(_, entry)| entry.last_seen)
97 .map(|(key, _)| *key)
98 else {
99 break;
100 };
101 self.grid_plans.remove(&evicted);
102 }
103 }
104
105 fn evict_geometries_to_cap(&mut self) {
106 while self.geometries.len() > GRID_PLAN_CACHE_CAP {
107 let Some(evicted) = self
108 .geometries
109 .iter()
110 .min_by_key(|(_, entry)| entry.last_seen)
111 .map(|(key, _)| *key)
112 else {
113 break;
114 };
115 self.geometries.remove(&evicted);
116 }
117 }
118
119 fn next_tick(&mut self) -> u64 {
120 self.prepare_cache_hit_tick();
121 self.clock += 1;
122 self.clock
123 }
124
125 fn prepare_cache_hit_tick(&mut self) {
126 if self.clock == u64::MAX {
127 self.clock = 0;
128 for entry in self.grid_plans.values_mut() {
129 entry.last_seen = 0;
130 }
131 for entry in self.geometries.values_mut() {
132 entry.last_seen = 0;
133 }
134 }
135 }
136}
137
138impl Default for MegakernelPlannerCache {
139 fn default() -> Self {
140 Self {
141 grid_plans: FxHashMap::with_capacity_and_hasher(
142 GRID_PLAN_CACHE_CAP,
143 Default::default(),
144 ),
145 geometries: FxHashMap::with_capacity_and_hasher(
146 GRID_PLAN_CACHE_CAP,
147 Default::default(),
148 ),
149 clock: 0,
150 }
151 }
152}
153
154thread_local! {
155 static PLANNER_CACHE: RefCell<MegakernelPlannerCache> = RefCell::new(MegakernelPlannerCache::default());
156}
157
158fn cached_grid_plan(
159 request: MegakernelGridRequest,
160 limits: MegakernelGridLimits,
161) -> Result<MegakernelGridPlan, BackendError> {
162 if let Some(plan) =
163 PLANNER_CACHE.with(|cache| cache.borrow_mut().get_grid_plan(&(request, limits)))
164 {
165 return Ok(plan);
166 }
167
168 let plan = MegakernelSizingPolicy::standard().calculate_optimal_grid(request, limits)?;
169 PLANNER_CACHE.with(|cache| {
170 cache.borrow_mut().insert_grid_plan((request, limits), plan);
171 });
172 Ok(plan)
173}
174
175pub(super) fn cached_geometry_from_slots(
176 slot_count: u32,
177 worker_count: u32,
178 max_workgroup_size_x: u32,
179) -> MegakernelLaunchGeometry {
180 let key = GeometryCacheKey {
181 slot_count,
182 worker_count,
183 max_workgroup_size_x,
184 };
185 if let Some(geometry) = PLANNER_CACHE.with(|cache| cache.borrow_mut().get_geometry(&key)) {
186 return geometry;
187 }
188
189 let geometry = MegakernelSizingPolicy::standard().geometry_from_slots(
190 slot_count,
191 worker_count,
192 max_workgroup_size_x,
193 );
194 PLANNER_CACHE.with(|cache| {
195 cache.borrow_mut().insert_geometry(key, geometry);
196 });
197 geometry
198}
199
200impl MegakernelGridLimits {
201 #[must_use]
203 pub const fn new(
204 max_workgroup_size_x: u32,
205 max_compute_workgroups_per_dimension: u32,
206 max_compute_invocations_per_workgroup: u32,
207 ) -> Self {
208 Self {
209 max_workgroup_size_x,
210 max_compute_workgroups_per_dimension,
211 max_compute_invocations_per_workgroup,
212 }
213 }
214
215 pub(super) fn validate(self) -> Result<(), BackendError> {
216 if self.max_workgroup_size_x == 0 {
217 return Err(BackendError::new(
218 "megakernel max_workgroup_size_x must be non-zero. Fix: pass live adapter limits instead of a zero limit.",
219 ));
220 }
221 if self.max_compute_workgroups_per_dimension == 0 {
222 return Err(BackendError::new(
223 "megakernel max_compute_workgroups_per_dimension must be non-zero. Fix: pass live adapter limits instead of a zero limit.",
224 ));
225 }
226 if self.max_compute_invocations_per_workgroup == 0 {
227 return Err(BackendError::new(
228 "megakernel max_compute_invocations_per_workgroup must be non-zero. Fix: pass live adapter limits instead of a zero limit.",
229 ));
230 }
231 Ok(())
232 }
233}
234
/// Caller-supplied inputs for a grid plan.
///
/// The type is `Copy` and hashable; together with `MegakernelGridLimits` it
/// forms the key of the grid-plan cache.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MegakernelGridRequest {
    /// Queue length the plan is requested for.
    pub queue_len: u32,
    /// Number of worker groups requested by the caller; how this value is
    /// interpreted is decided by the sizing policy.
    pub requested_worker_groups: u32,
}

impl MegakernelGridRequest {
    /// Bundles a queue length and a requested worker-group count into a
    /// request. No validation is performed.
    #[must_use]
    pub const fn new(queue_len: u32, requested_worker_groups: u32) -> Self {
        MegakernelGridRequest {
            queue_len,
            requested_worker_groups,
        }
    }
}
254
255#[derive(Debug, Clone, Copy, PartialEq, Eq)]
257pub struct MegakernelGridPlan {
258 pub geometry: MegakernelLaunchGeometry,
260 pub worker_groups: u32,
262}
263
264impl MegakernelGridPlan {
265 pub fn recommend(
271 request: MegakernelGridRequest,
272 limits: MegakernelGridLimits,
273 ) -> Result<Self, BackendError> {
274 cached_grid_plan(request, limits)
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    /// Limits shared by every grid-plan key in these tests.
    fn limits() -> MegakernelGridLimits {
        MegakernelGridLimits::new(256, 65_535, 256)
    }

    /// Request with the given queue length and zero requested worker groups.
    fn request(queue_len: u32) -> MegakernelGridRequest {
        MegakernelGridRequest::new(queue_len, 0)
    }

    /// Minimal geometry distinguished only by its slot count.
    fn geometry(slot_count: u32) -> MegakernelLaunchGeometry {
        MegakernelLaunchGeometry {
            workgroup_size_x: 1,
            slot_count,
            dispatch_grid: [1, 1, 1],
        }
    }

    /// Grid plan wrapping `geometry(slot_count)` with a single worker group.
    fn plan(slot_count: u32) -> MegakernelGridPlan {
        MegakernelGridPlan {
            geometry: geometry(slot_count),
            worker_groups: 1,
        }
    }

    /// Geometry cache key distinguished only by its slot count.
    fn geometry_key(slot_count: u32) -> GeometryCacheKey {
        GeometryCacheKey {
            slot_count,
            worker_count: 1,
            max_workgroup_size_x: 256,
        }
    }

    #[test]
    fn planner_grid_cache_refreshes_hot_plan_on_hit() {
        let cap = GRID_PLAN_CACHE_CAP as u32;
        let mut cache = MegakernelPlannerCache::default();
        let hot_key = (request(1), limits());
        let hot_plan = plan(1);

        // Fill the map exactly to the cap; the hot entry goes in first and is
        // therefore the oldest one.
        cache.insert_grid_plan(hot_key, hot_plan);
        for queue_len in 2..=cap {
            cache.insert_grid_plan((request(queue_len), limits()), plan(queue_len));
        }

        // The hit must refresh the hot entry's stamp...
        assert_eq!(cache.get_grid_plan(&hot_key), Some(hot_plan));

        // ...so that pushing the map over the cap evicts some other entry.
        cache.insert_grid_plan((request(cap + 1), limits()), plan(cap + 1));
        assert_eq!(cache.get_grid_plan(&hot_key), Some(hot_plan));
    }

    #[test]
    fn planner_geometry_cache_refreshes_hot_geometry_on_hit() {
        let cap = GRID_PLAN_CACHE_CAP as u32;
        let mut cache = MegakernelPlannerCache::default();
        let hot_key = geometry_key(1);
        let hot_geometry = geometry(1);

        // Fill the map exactly to the cap; the hot entry goes in first and is
        // therefore the oldest one.
        cache.insert_geometry(hot_key, hot_geometry);
        for slot_count in 2..=cap {
            cache.insert_geometry(geometry_key(slot_count), geometry(slot_count));
        }

        // The hit must refresh the hot entry's stamp...
        assert_eq!(cache.get_geometry(&hot_key), Some(hot_geometry));

        // ...so that pushing the map over the cap evicts some other entry.
        cache.insert_geometry(geometry_key(cap + 1), geometry(cap + 1));
        assert_eq!(cache.get_geometry(&hot_key), Some(hot_geometry));
    }
}