Skip to main content

vyre_runtime/megakernel/planner/
grid.rs

1//! Megakernel grid request, limits, plan cache, and recommendation surface.
2
3use std::cell::RefCell;
4
5use rustc_hash::FxHashMap;
6use vyre_driver::backend::BackendError;
7
8use super::geometry::MegakernelLaunchGeometry;
9use super::sizing::MegakernelSizingPolicy;
10
/// Hardware ceilings, taken from the backend adapter, that a megakernel worker-grid
/// recommendation must stay within.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MegakernelGridLimits {
    /// Largest workgroup size the adapter allows along x.
    pub max_workgroup_size_x: u32,
    /// Largest number of compute workgroups the adapter allows per dispatch dimension.
    pub max_compute_workgroups_per_dimension: u32,
    /// Largest number of invocations the adapter allows in a single compute workgroup.
    pub max_compute_invocations_per_workgroup: u32,
}
21
22const GRID_PLAN_CACHE_CAP: usize = 128;
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
25struct GeometryCacheKey {
26    slot_count: u32,
27    worker_count: u32,
28    max_workgroup_size_x: u32,
29}
30
31struct MegakernelPlannerCache {
32    grid_plans:
33        FxHashMap<(MegakernelGridRequest, MegakernelGridLimits), CacheEntry<MegakernelGridPlan>>,
34    geometries: FxHashMap<GeometryCacheKey, CacheEntry<MegakernelLaunchGeometry>>,
35    clock: u64,
36}
37
38struct CacheEntry<T> {
39    value: T,
40    last_seen: u64,
41}
42
43impl MegakernelPlannerCache {
44    fn get_grid_plan(
45        &mut self,
46        key: &(MegakernelGridRequest, MegakernelGridLimits),
47    ) -> Option<MegakernelGridPlan> {
48        self.prepare_cache_hit_tick();
49        let entry = self.grid_plans.get_mut(key)?;
50        self.clock += 1;
51        entry.last_seen = self.clock;
52        Some(entry.value)
53    }
54
55    fn insert_grid_plan(
56        &mut self,
57        key: (MegakernelGridRequest, MegakernelGridLimits),
58        value: MegakernelGridPlan,
59    ) {
60        let tick = self.next_tick();
61        self.grid_plans.insert(
62            key,
63            CacheEntry {
64                value,
65                last_seen: tick,
66            },
67        );
68        self.evict_grid_plans_to_cap();
69    }
70
71    fn get_geometry(&mut self, key: &GeometryCacheKey) -> Option<MegakernelLaunchGeometry> {
72        self.prepare_cache_hit_tick();
73        let entry = self.geometries.get_mut(key)?;
74        self.clock += 1;
75        entry.last_seen = self.clock;
76        Some(entry.value)
77    }
78
79    fn insert_geometry(&mut self, key: GeometryCacheKey, value: MegakernelLaunchGeometry) {
80        let tick = self.next_tick();
81        self.geometries.insert(
82            key,
83            CacheEntry {
84                value,
85                last_seen: tick,
86            },
87        );
88        self.evict_geometries_to_cap();
89    }
90
91    fn evict_grid_plans_to_cap(&mut self) {
92        while self.grid_plans.len() > GRID_PLAN_CACHE_CAP {
93            let Some(evicted) = self
94                .grid_plans
95                .iter()
96                .min_by_key(|(_, entry)| entry.last_seen)
97                .map(|(key, _)| *key)
98            else {
99                break;
100            };
101            self.grid_plans.remove(&evicted);
102        }
103    }
104
105    fn evict_geometries_to_cap(&mut self) {
106        while self.geometries.len() > GRID_PLAN_CACHE_CAP {
107            let Some(evicted) = self
108                .geometries
109                .iter()
110                .min_by_key(|(_, entry)| entry.last_seen)
111                .map(|(key, _)| *key)
112            else {
113                break;
114            };
115            self.geometries.remove(&evicted);
116        }
117    }
118
119    fn next_tick(&mut self) -> u64 {
120        self.prepare_cache_hit_tick();
121        self.clock += 1;
122        self.clock
123    }
124
125    fn prepare_cache_hit_tick(&mut self) {
126        if self.clock == u64::MAX {
127            self.clock = 0;
128            for entry in self.grid_plans.values_mut() {
129                entry.last_seen = 0;
130            }
131            for entry in self.geometries.values_mut() {
132                entry.last_seen = 0;
133            }
134        }
135    }
136}
137
138impl Default for MegakernelPlannerCache {
139    fn default() -> Self {
140        Self {
141            grid_plans: FxHashMap::with_capacity_and_hasher(
142                GRID_PLAN_CACHE_CAP,
143                Default::default(),
144            ),
145            geometries: FxHashMap::with_capacity_and_hasher(
146                GRID_PLAN_CACHE_CAP,
147                Default::default(),
148            ),
149            clock: 0,
150        }
151    }
152}
153
154thread_local! {
155    static PLANNER_CACHE: RefCell<MegakernelPlannerCache> = RefCell::new(MegakernelPlannerCache::default());
156}
157
158fn cached_grid_plan(
159    request: MegakernelGridRequest,
160    limits: MegakernelGridLimits,
161) -> Result<MegakernelGridPlan, BackendError> {
162    if let Some(plan) =
163        PLANNER_CACHE.with(|cache| cache.borrow_mut().get_grid_plan(&(request, limits)))
164    {
165        return Ok(plan);
166    }
167
168    let plan = MegakernelSizingPolicy::standard().calculate_optimal_grid(request, limits)?;
169    PLANNER_CACHE.with(|cache| {
170        cache.borrow_mut().insert_grid_plan((request, limits), plan);
171    });
172    Ok(plan)
173}
174
175pub(super) fn cached_geometry_from_slots(
176    slot_count: u32,
177    worker_count: u32,
178    max_workgroup_size_x: u32,
179) -> MegakernelLaunchGeometry {
180    let key = GeometryCacheKey {
181        slot_count,
182        worker_count,
183        max_workgroup_size_x,
184    };
185    if let Some(geometry) = PLANNER_CACHE.with(|cache| cache.borrow_mut().get_geometry(&key)) {
186        return geometry;
187    }
188
189    let geometry = MegakernelSizingPolicy::standard().geometry_from_slots(
190        slot_count,
191        worker_count,
192        max_workgroup_size_x,
193    );
194    PLANNER_CACHE.with(|cache| {
195        cache.borrow_mut().insert_geometry(key, geometry);
196    });
197    geometry
198}
199
200impl MegakernelGridLimits {
201    /// Construct megakernel grid limits from backend adapter limits.
202    #[must_use]
203    pub const fn new(
204        max_workgroup_size_x: u32,
205        max_compute_workgroups_per_dimension: u32,
206        max_compute_invocations_per_workgroup: u32,
207    ) -> Self {
208        Self {
209            max_workgroup_size_x,
210            max_compute_workgroups_per_dimension,
211            max_compute_invocations_per_workgroup,
212        }
213    }
214
215    pub(super) fn validate(self) -> Result<(), BackendError> {
216        if self.max_workgroup_size_x == 0 {
217            return Err(BackendError::new(
218                "megakernel max_workgroup_size_x must be non-zero. Fix: pass live adapter limits instead of a zero limit.",
219            ));
220        }
221        if self.max_compute_workgroups_per_dimension == 0 {
222            return Err(BackendError::new(
223                "megakernel max_compute_workgroups_per_dimension must be non-zero. Fix: pass live adapter limits instead of a zero limit.",
224            ));
225        }
226        if self.max_compute_invocations_per_workgroup == 0 {
227            return Err(BackendError::new(
228                "megakernel max_compute_invocations_per_workgroup must be non-zero. Fix: pass live adapter limits instead of a zero limit.",
229            ));
230        }
231        Ok(())
232    }
233}
234
/// Logical work shape a caller submits when asking for a megakernel worker-grid
/// recommendation.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MegakernelGridRequest {
    /// Number of logical ring slots (work items) queued for this launch.
    pub queue_len: u32,
    /// Upper bound on worker workgroups requested by the caller; zero means derive it from
    /// occupancy.
    pub requested_worker_groups: u32,
}

impl MegakernelGridRequest {
    /// Creates a request for `queue_len` queued items with the given worker-group ceiling
    /// (zero means derive from occupancy).
    #[must_use]
    pub const fn new(queue_len: u32, requested_worker_groups: u32) -> Self {
        Self {
            queue_len,
            requested_worker_groups,
        }
    }
}
254
255/// Resolved worker-grid plan shared by direct and policy-driven megakernel paths.
256#[derive(Debug, Clone, Copy, PartialEq, Eq)]
257pub struct MegakernelGridPlan {
258    /// Padded launch geometry for the ring protocol.
259    pub geometry: MegakernelLaunchGeometry,
260    /// Worker workgroups selected for the dispatch.
261    pub worker_groups: u32,
262}
263
264impl MegakernelGridPlan {
265    /// Resolve worker groups, workgroup width, slot padding, and dispatch grid.
266    ///
267    /// # Errors
268    ///
269    /// Returns [`BackendError`] when adapter limits are malformed.
270    pub fn recommend(
271        request: MegakernelGridRequest,
272        limits: MegakernelGridLimits,
273    ) -> Result<Self, BackendError> {
274        cached_grid_plan(request, limits)
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    fn limits() -> MegakernelGridLimits {
        MegakernelGridLimits::new(256, 65_535, 256)
    }

    fn request(queue_len: u32) -> MegakernelGridRequest {
        MegakernelGridRequest::new(queue_len, 0)
    }

    fn geometry(slot_count: u32) -> MegakernelLaunchGeometry {
        MegakernelLaunchGeometry {
            workgroup_size_x: 1,
            slot_count,
            dispatch_grid: [1, 1, 1],
        }
    }

    /// Minimal plan whose identity is carried by its geometry's slot count.
    fn plan(slot_count: u32) -> MegakernelGridPlan {
        MegakernelGridPlan {
            geometry: geometry(slot_count),
            worker_groups: 1,
        }
    }

    /// Geometry-cache key that differs from its siblings only by `slot_count`.
    fn geometry_key(slot_count: u32) -> GeometryCacheKey {
        GeometryCacheKey {
            slot_count,
            worker_count: 1,
            max_workgroup_size_x: 256,
        }
    }

    #[test]
    fn planner_grid_cache_refreshes_hot_plan_on_hit() {
        let mut cache = MegakernelPlannerCache::default();
        let limits = limits();
        let hot_key = (request(1), limits);
        let hot_plan = plan(1);
        cache.insert_grid_plan(hot_key, hot_plan);
        for queue_len in 2..=GRID_PLAN_CACHE_CAP as u32 {
            cache.insert_grid_plan((request(queue_len), limits), plan(queue_len));
        }
        assert_eq!(cache.grid_plans.len(), GRID_PLAN_CACHE_CAP);
        assert_eq!(cache.get_grid_plan(&hot_key), Some(hot_plan));

        let overflow = (GRID_PLAN_CACHE_CAP + 1) as u32;
        cache.insert_grid_plan((request(overflow), limits), plan(overflow));

        // Exactly one entry must be dropped, and it must be the least recently used one
        // (queue_len 2), not the hot entry that was just refreshed.
        assert_eq!(cache.grid_plans.len(), GRID_PLAN_CACHE_CAP);
        assert_eq!(cache.get_grid_plan(&hot_key), Some(hot_plan));
        assert_eq!(cache.get_grid_plan(&(request(2), limits)), None);
        assert_eq!(
            cache.get_grid_plan(&(request(overflow), limits)),
            Some(plan(overflow))
        );
    }

    #[test]
    fn planner_geometry_cache_refreshes_hot_geometry_on_hit() {
        let mut cache = MegakernelPlannerCache::default();
        let hot_key = geometry_key(1);
        let hot_geometry = geometry(1);
        cache.insert_geometry(hot_key, hot_geometry);
        for slot_count in 2..=GRID_PLAN_CACHE_CAP as u32 {
            cache.insert_geometry(geometry_key(slot_count), geometry(slot_count));
        }
        assert_eq!(cache.geometries.len(), GRID_PLAN_CACHE_CAP);
        assert_eq!(cache.get_geometry(&hot_key), Some(hot_geometry));

        let overflow = (GRID_PLAN_CACHE_CAP + 1) as u32;
        cache.insert_geometry(geometry_key(overflow), geometry(overflow));

        // Exactly one entry must be dropped, and it must be the least recently used one
        // (slot_count 2), not the hot entry that was just refreshed.
        assert_eq!(cache.geometries.len(), GRID_PLAN_CACHE_CAP);
        assert_eq!(cache.get_geometry(&hot_key), Some(hot_geometry));
        assert_eq!(cache.get_geometry(&geometry_key(2)), None);
        assert_eq!(
            cache.get_geometry(&geometry_key(overflow)),
            Some(geometry(overflow))
        );
    }

    #[test]
    fn planner_cache_hit_survives_clock_wraparound() {
        let mut cache = MegakernelPlannerCache::default();
        let grid_key = (request(7), limits());
        let geo_key = geometry_key(7);
        cache.insert_grid_plan(grid_key, plan(7));
        cache.insert_geometry(geo_key, geometry(7));

        // Force the wrap-around path: the next lookup must reset the clock to zero before
        // recording the hit, and must neither overflow nor lose the cached entries.
        cache.clock = u64::MAX;
        assert_eq!(cache.get_grid_plan(&grid_key), Some(plan(7)));
        assert_eq!(cache.clock, 1);
        assert_eq!(cache.get_geometry(&geo_key), Some(geometry(7)));
        assert_eq!(cache.clock, 2);
    }
}