Skip to main content

vyre_runtime/megakernel/planner/geometry.rs

1//! Megakernel launch geometry helpers.
2
3use std::time::Duration;
4
5use vyre_driver::backend::{BackendError, DispatchConfig};
6
7use super::grid::cached_geometry_from_slots;
8use super::sizing::MegakernelSizingPolicy;
9
/// Launch geometry computed on the host for one finite megakernel dispatch.
///
/// Bundles the workgroup width the program was compiled for, the number of
/// ring slots backing the dispatch, and the grid that is submitted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MegakernelLaunchGeometry {
    /// Number of lanes in each worker workgroup the program was compiled with.
    pub workgroup_size_x: u32,
    /// Ring slots allocated for this dispatch, rounded up to a full workgroup.
    pub slot_count: u32,
    /// Workgroup grid handed to the backend at submission time.
    pub dispatch_grid: [u32; 3],
}
20
21impl MegakernelLaunchGeometry {
22    /// Build geometry for `item_count` host work items.
23    ///
24    /// # Errors
25    ///
26    /// Returns [`BackendError`] when the host queue cannot be represented by
27    /// the u32 ring protocol.
28    pub fn from_item_count(
29        item_count: usize,
30        worker_count: u32,
31        max_workgroup_size_x: u32,
32    ) -> Result<Self, BackendError> {
33        let item_count = u32::try_from(item_count).map_err(|_| {
34            BackendError::new(
35                "megakernel work queue length exceeds u32::MAX. Fix: shard the queue before dispatch.",
36            )
37        })?;
38        let geometry = Self::from_slots(item_count, worker_count, max_workgroup_size_x);
39        if geometry.slot_count < item_count {
40            return Err(BackendError::new(
41                "megakernel work queue cannot be padded inside the u32 ring protocol. Fix: shard the queue before dispatch.",
42            ));
43        }
44        Ok(geometry)
45    }
46
47    /// Build geometry for an already-sized ring.
48    #[must_use]
49    pub fn from_slots(slot_count: u32, worker_count: u32, max_workgroup_size_x: u32) -> Self {
50        cached_geometry_from_slots(slot_count, worker_count, max_workgroup_size_x)
51    }
52
53    /// Number of worker workgroups needed to cover every ring slot exactly once.
54    #[must_use]
55    pub const fn covering_worker_groups(&self) -> u32 {
56        self.slot_count / self.workgroup_size_x
57    }
58
59    /// Build the backend dispatch config that matches this launch geometry.
60    #[must_use]
61    pub fn dispatch_config(&self, timeout: Option<Duration>) -> DispatchConfig {
62        let mut config = DispatchConfig::default();
63        config.timeout = timeout;
64        config.grid_override = Some(self.dispatch_grid);
65        config.workgroup_override = Some([self.workgroup_size_x, 1, 1]);
66        config
67    }
68}
69
70/// Clamp the caller's worker setting into the legal x dimension used by the
71/// current megakernel ABI.
72#[must_use]
73pub fn worker_workgroup_size(worker_count: u32, max_workgroup_size_x: u32) -> u32 {
74    MegakernelSizingPolicy::standard().worker_workgroup_size(worker_count, max_workgroup_size_x)
75}
76
77/// Round a logical slot count up to a whole workgroup.
78#[must_use]
79pub fn padded_slot_count(slot_count: u32, workgroup_size_x: u32) -> u32 {
80    MegakernelSizingPolicy::standard().padded_slot_count(slot_count, workgroup_size_x)
81}
82
83/// Compute the backend dispatch grid for a logical queue length.
84#[must_use]
85pub fn dispatch_grid_for(worker_count: u32, queue_len: u32, max_workgroup_size_x: u32) -> [u32; 3] {
86    MegakernelSizingPolicy::standard().dispatch_grid_for(
87        worker_count,
88        queue_len,
89        max_workgroup_size_x,
90    )
91}
92
93/// Compute a persistent-worker ceiling from adapter limits.
94///
95/// This is the single host-side policy used by runtime batch dispatchers and
96/// direct megakernel dispatch. Callers can still clamp further through
97/// `MegakernelConfig::worker_count`, but occupancy heuristics live here.
98#[must_use]
99pub fn default_worker_groups_from_limits(
100    max_compute_workgroups_per_dimension: u32,
101    max_compute_invocations_per_workgroup: u32,
102) -> u32 {
103    MegakernelSizingPolicy::standard().default_worker_groups_from_limits(
104        max_compute_workgroups_per_dimension,
105        max_compute_invocations_per_workgroup,
106    )
107}