Skip to main content

vyre_runtime/megakernel/planner/sizing.rs

1use vyre_driver::backend::BackendError;
2use vyre_foundation::execution_plan::SchedulingPolicy;
3
4use super::{
5    MegakernelGridLimits, MegakernelGridPlan, MegakernelGridRequest, MegakernelLaunchGeometry,
6};
7
8/// Shared worker-grid sizing policy for megakernel dispatch.
9///
10/// This is the host-side policy surface for persistent worker counts,
11/// workgroup width, slot padding, and backend grid geometry.
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
13pub struct MegakernelSizingPolicy {
14    scheduling: SchedulingPolicy,
15}
16
17impl Default for MegakernelSizingPolicy {
18    fn default() -> Self {
19        Self::standard()
20    }
21}
22
23impl MegakernelSizingPolicy {
24    /// Standard megakernel sizing policy used by built-in dispatch paths.
25    #[must_use]
26    pub const fn standard() -> Self {
27        Self {
28            scheduling: SchedulingPolicy::standard(),
29        }
30    }
31
32    /// Build from a shared backend-neutral scheduling policy.
33    #[must_use]
34    pub const fn from_scheduling(scheduling: SchedulingPolicy) -> Self {
35        Self { scheduling }
36    }
37
38    /// Default persistent worker workgroup count.
39    #[must_use]
40    pub const fn default_worker_count(&self) -> u32 {
41        self.scheduling.default_worker_count()
42    }
43
44    /// Clamp a requested worker count into the legal workgroup x dimension.
45    #[must_use]
46    pub const fn worker_workgroup_size(&self, worker_count: u32, max_workgroup_size_x: u32) -> u32 {
47        self.scheduling
48            .worker_workgroup_size(worker_count, max_workgroup_size_x)
49    }
50
51    /// Round a logical slot count up to a whole worker workgroup.
52    #[must_use]
53    pub const fn padded_slot_count(&self, slot_count: u32, workgroup_size_x: u32) -> u32 {
54        self.scheduling
55            .padded_slot_count(slot_count, workgroup_size_x)
56    }
57
58    /// Compute the backend dispatch grid for a logical queue length.
59    #[must_use]
60    pub const fn dispatch_grid_for(
61        &self,
62        worker_count: u32,
63        queue_len: u32,
64        max_workgroup_size_x: u32,
65    ) -> [u32; 3] {
66        self.scheduling
67            .dispatch_grid_for(worker_count, queue_len, max_workgroup_size_x)
68    }
69
70    /// Compute a persistent-worker ceiling from adapter limits.
71    #[must_use]
72    pub const fn default_worker_groups_from_limits(
73        &self,
74        max_compute_workgroups_per_dimension: u32,
75        max_compute_invocations_per_workgroup: u32,
76    ) -> u32 {
77        self.scheduling.default_worker_groups_from_limits(
78            max_compute_workgroups_per_dimension,
79            max_compute_invocations_per_workgroup,
80        )
81    }
82
83    /// Resolve worker groups, workgroup width, slot padding, and dispatch grid.
84    ///
85    /// # Errors
86    ///
87    /// Returns [`BackendError`] when adapter limits are malformed.
88    pub fn calculate_optimal_grid(
89        &self,
90        request: MegakernelGridRequest,
91        limits: MegakernelGridLimits,
92    ) -> Result<MegakernelGridPlan, BackendError> {
93        limits.validate()?;
94
95        let occupancy_worker_groups = self
96            .default_worker_groups_from_limits(
97                limits.max_compute_workgroups_per_dimension,
98                limits.max_compute_invocations_per_workgroup,
99            )
100            .min(limits.max_compute_workgroups_per_dimension);
101
102        let worker_groups = if request.requested_worker_groups == 0 {
103            occupancy_worker_groups
104        } else {
105            request
106                .requested_worker_groups
107                .min(limits.max_compute_workgroups_per_dimension)
108        }
109        .max(1);
110
111        let geometry = self.geometry_from_slots(
112            request.queue_len.max(1),
113            worker_groups,
114            limits.max_workgroup_size_x,
115        );
116
117        Ok(MegakernelGridPlan {
118            geometry,
119            worker_groups,
120        })
121    }
122
123    /// Build geometry for an already-sized ring.
124    #[must_use]
125    pub fn geometry_from_slots(
126        &self,
127        slot_count: u32,
128        worker_count: u32,
129        max_workgroup_size_x: u32,
130    ) -> MegakernelLaunchGeometry {
131        let workgroup_size_x = self.worker_workgroup_size(worker_count, max_workgroup_size_x);
132        let slot_count = self.padded_slot_count(slot_count, workgroup_size_x);
133        let dispatch_grid = self.dispatch_grid_for(worker_count, slot_count, workgroup_size_x);
134        MegakernelLaunchGeometry {
135            workgroup_size_x,
136            slot_count,
137            dispatch_grid,
138        }
139    }
140}