vyre_runtime/megakernel/planner/
sizing.rs1use vyre_driver::backend::BackendError;
2use vyre_foundation::execution_plan::SchedulingPolicy;
3
4use super::{
5 MegakernelGridLimits, MegakernelGridPlan, MegakernelGridRequest, MegakernelLaunchGeometry,
6};
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
13pub struct MegakernelSizingPolicy {
14 scheduling: SchedulingPolicy,
15}
16
17impl Default for MegakernelSizingPolicy {
18 fn default() -> Self {
19 Self::standard()
20 }
21}
22
23impl MegakernelSizingPolicy {
24 #[must_use]
26 pub const fn standard() -> Self {
27 Self {
28 scheduling: SchedulingPolicy::standard(),
29 }
30 }
31
32 #[must_use]
34 pub const fn from_scheduling(scheduling: SchedulingPolicy) -> Self {
35 Self { scheduling }
36 }
37
38 #[must_use]
40 pub const fn default_worker_count(&self) -> u32 {
41 self.scheduling.default_worker_count()
42 }
43
44 #[must_use]
46 pub const fn worker_workgroup_size(&self, worker_count: u32, max_workgroup_size_x: u32) -> u32 {
47 self.scheduling
48 .worker_workgroup_size(worker_count, max_workgroup_size_x)
49 }
50
51 #[must_use]
53 pub const fn padded_slot_count(&self, slot_count: u32, workgroup_size_x: u32) -> u32 {
54 self.scheduling
55 .padded_slot_count(slot_count, workgroup_size_x)
56 }
57
58 #[must_use]
60 pub const fn dispatch_grid_for(
61 &self,
62 worker_count: u32,
63 queue_len: u32,
64 max_workgroup_size_x: u32,
65 ) -> [u32; 3] {
66 self.scheduling
67 .dispatch_grid_for(worker_count, queue_len, max_workgroup_size_x)
68 }
69
70 #[must_use]
72 pub const fn default_worker_groups_from_limits(
73 &self,
74 max_compute_workgroups_per_dimension: u32,
75 max_compute_invocations_per_workgroup: u32,
76 ) -> u32 {
77 self.scheduling.default_worker_groups_from_limits(
78 max_compute_workgroups_per_dimension,
79 max_compute_invocations_per_workgroup,
80 )
81 }
82
83 pub fn calculate_optimal_grid(
89 &self,
90 request: MegakernelGridRequest,
91 limits: MegakernelGridLimits,
92 ) -> Result<MegakernelGridPlan, BackendError> {
93 limits.validate()?;
94
95 let occupancy_worker_groups = self
96 .default_worker_groups_from_limits(
97 limits.max_compute_workgroups_per_dimension,
98 limits.max_compute_invocations_per_workgroup,
99 )
100 .min(limits.max_compute_workgroups_per_dimension);
101
102 let worker_groups = if request.requested_worker_groups == 0 {
103 occupancy_worker_groups
104 } else {
105 request
106 .requested_worker_groups
107 .min(limits.max_compute_workgroups_per_dimension)
108 }
109 .max(1);
110
111 let geometry = self.geometry_from_slots(
112 request.queue_len.max(1),
113 worker_groups,
114 limits.max_workgroup_size_x,
115 );
116
117 Ok(MegakernelGridPlan {
118 geometry,
119 worker_groups,
120 })
121 }
122
123 #[must_use]
125 pub fn geometry_from_slots(
126 &self,
127 slot_count: u32,
128 worker_count: u32,
129 max_workgroup_size_x: u32,
130 ) -> MegakernelLaunchGeometry {
131 let workgroup_size_x = self.worker_workgroup_size(worker_count, max_workgroup_size_x);
132 let slot_count = self.padded_slot_count(slot_count, workgroup_size_x);
133 let dispatch_grid = self.dispatch_grid_for(worker_count, slot_count, workgroup_size_x);
134 MegakernelLaunchGeometry {
135 workgroup_size_x,
136 slot_count,
137 dispatch_grid,
138 }
139 }
140}