vyre_runtime/megakernel/planner/
config.rs1use std::time::Duration;
4
5use vyre_driver::backend::BackendError;
6
7use super::super::policy::{
8 MegakernelLaunchPolicy, MegakernelLaunchRecommendation, MegakernelLaunchRequest,
9};
10use super::super::task::{TaskQueueSnapshot, TaskWorkItem};
11use super::geometry::dispatch_grid_for;
12use super::sizing::MegakernelSizingPolicy;
13
/// Optional description of the workload's shape, attached to a [`MegakernelConfig`].
///
/// Every field is copied unchanged into the identically named field of the
/// `MegakernelLaunchRequest` built by [`MegakernelConfig::launch_request`]; this
/// module never inspects or range-checks the values itself.
///
/// The derived `Default` sets every field to zero. NOTE(review): presumably a zero
/// means "no hint supplied" to the launch policy — confirm against
/// `MegakernelLaunchPolicy` before relying on that.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MegakernelWorkloadHints {
    /// Number of "hot" (frequently executed) opcodes reported for the workload.
    pub hot_opcode_count: u32,
    /// Number of "hot" windows reported for the workload (window meaning is
    /// defined by the launch policy — TODO confirm).
    pub hot_window_count: u32,
    /// Number of nodes in the workload graph.
    pub graph_node_count: u32,
    /// Number of edges in the workload graph.
    pub graph_edge_count: u32,
    /// Frontier density in basis points (presumably `0..=10_000` for 0–100%;
    /// not validated by `MegakernelConfig::validate`).
    pub frontier_density_bps: u16,
    /// Memory pressure in basis points (presumably `0..=10_000`; not validated
    /// by `MegakernelConfig::validate`).
    pub memory_pressure_bps: u16,
    /// Bytes reported as already resident on the device.
    pub resident_device_bytes: u64,
    /// Device memory budget in bytes.
    pub device_memory_budget_bytes: u64,
}
36
37#[derive(Debug, Clone)]
39pub struct MegakernelConfig {
40 pub worker_count: u32,
42 pub max_wall_time: Duration,
45 pub expected_items_per_worker: u32,
47 pub workload: MegakernelWorkloadHints,
49}
50
51impl Default for MegakernelConfig {
52 fn default() -> Self {
53 Self {
54 worker_count: MegakernelSizingPolicy::standard().default_worker_count(),
55 max_wall_time: Duration::from_secs(60),
56 expected_items_per_worker: 0,
57 workload: MegakernelWorkloadHints::default(),
58 }
59 }
60}
61
62impl MegakernelConfig {
63 pub fn validate(&self) -> Result<(), BackendError> {
71 if self.worker_count == 0 {
72 return Err(BackendError::new(
73 "megakernel worker_count must be non-zero. Fix: provide at least one worker workgroup.",
74 ));
75 }
76 if self.max_wall_time.is_zero() {
77 return Err(BackendError::new(
78 "megakernel max_wall_time must be non-zero. Fix: supply a positive Duration budget.",
79 ));
80 }
81 Ok(())
82 }
83
84 #[must_use]
90 pub fn dispatch_grid(&self, queue_len: u32, max_workgroup_size_x: u32) -> [u32; 3] {
91 dispatch_grid_for(self.worker_count, queue_len, max_workgroup_size_x)
92 }
93
94 #[must_use]
96 pub const fn launch_request(
97 &self,
98 queue_len: u32,
99 max_workgroup_size_x: u32,
100 max_compute_workgroups_per_dimension: u32,
101 max_compute_invocations_per_workgroup: u32,
102 ) -> MegakernelLaunchRequest {
103 MegakernelLaunchRequest {
104 queue_len,
105 requested_worker_groups: self.worker_count,
106 max_workgroup_size_x,
107 max_compute_workgroups_per_dimension,
108 max_compute_invocations_per_workgroup,
109 requested_hit_capacity: 0,
110 expected_hits_per_item: if self.expected_items_per_worker > 1 {
111 self.expected_items_per_worker
112 } else {
113 1
114 },
115 hot_opcode_count: self.workload.hot_opcode_count,
116 hot_window_count: self.workload.hot_window_count,
117 requeue_count: 0,
118 max_priority_age: 0,
119 graph_node_count: self.workload.graph_node_count,
120 graph_edge_count: self.workload.graph_edge_count,
121 frontier_density_bps: self.workload.frontier_density_bps,
122 memory_pressure_bps: self.workload.memory_pressure_bps,
123 resident_device_bytes: self.workload.resident_device_bytes,
124 device_memory_budget_bytes: self.workload.device_memory_budget_bytes,
125 }
126 }
127
128 pub fn launch_request_for_tasks(
138 &self,
139 tasks: &[TaskWorkItem],
140 max_workgroup_size_x: u32,
141 max_compute_workgroups_per_dimension: u32,
142 max_compute_invocations_per_workgroup: u32,
143 ) -> Result<MegakernelLaunchRequest, BackendError> {
144 let snapshot = TaskQueueSnapshot::from_tasks(tasks)?;
145 let schedulable_count = snapshot.try_schedulable_count()?;
146 let request = self.launch_request(
147 schedulable_count,
148 max_workgroup_size_x,
149 max_compute_workgroups_per_dimension,
150 max_compute_invocations_per_workgroup,
151 );
152 snapshot.try_apply_to_launch_request(request)
153 }
154
155 pub fn launch_recommendation(
161 &self,
162 queue_len: u32,
163 max_workgroup_size_x: u32,
164 max_compute_workgroups_per_dimension: u32,
165 max_compute_invocations_per_workgroup: u32,
166 ) -> Result<MegakernelLaunchRecommendation, BackendError> {
167 MegakernelLaunchPolicy::standard().recommend(self.launch_request(
168 queue_len,
169 max_workgroup_size_x,
170 max_compute_workgroups_per_dimension,
171 max_compute_invocations_per_workgroup,
172 ))
173 }
174
175 pub fn launch_recommendation_for_tasks(
182 &self,
183 tasks: &[TaskWorkItem],
184 max_workgroup_size_x: u32,
185 max_compute_workgroups_per_dimension: u32,
186 max_compute_invocations_per_workgroup: u32,
187 ) -> Result<MegakernelLaunchRecommendation, BackendError> {
188 MegakernelLaunchPolicy::standard().recommend(self.launch_request_for_tasks(
189 tasks,
190 max_workgroup_size_x,
191 max_compute_workgroups_per_dimension,
192 max_compute_invocations_per_workgroup,
193 )?)
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Common device limits used by every launch-request test.
    const QUEUE_LEN: u32 = 128;
    const MAX_WORKGROUP_SIZE_X: u32 = 256;
    const MAX_WORKGROUPS_PER_DIMENSION: u32 = 65_535;
    const MAX_INVOCATIONS_PER_WORKGROUP: u32 = 1_024;

    /// Every workload hint on the config must be forwarded verbatim into the request.
    #[test]
    fn launch_request_preserves_workload_hints() {
        let config = MegakernelConfig {
            workload: MegakernelWorkloadHints {
                hot_opcode_count: 7,
                hot_window_count: 11,
                graph_node_count: 1_000,
                graph_edge_count: 4_000,
                frontier_density_bps: 7_500,
                memory_pressure_bps: 8_000,
                resident_device_bytes: 1 << 20,
                device_memory_budget_bytes: 1 << 24,
            },
            ..MegakernelConfig::default()
        };

        let request = config.launch_request(
            QUEUE_LEN,
            MAX_WORKGROUP_SIZE_X,
            MAX_WORKGROUPS_PER_DIMENSION,
            MAX_INVOCATIONS_PER_WORKGROUP,
        );

        assert_eq!(request.hot_opcode_count, 7);
        assert_eq!(request.hot_window_count, 11);
        assert_eq!(request.graph_node_count, 1_000);
        assert_eq!(request.graph_edge_count, 4_000);
        assert_eq!(request.frontier_density_bps, 7_500);
        assert_eq!(request.memory_pressure_bps, 8_000);
        assert_eq!(request.resident_device_bytes, 1 << 20);
        assert_eq!(request.device_memory_budget_bytes, 1 << 24);
    }

    /// `expected_items_per_worker` of 0 is raised to 1; any other value passes through.
    #[test]
    fn launch_request_clamps_expected_hits_to_at_least_one() {
        let hits_for = |expected_items_per_worker: u32| {
            MegakernelConfig {
                expected_items_per_worker,
                ..MegakernelConfig::default()
            }
            .launch_request(
                QUEUE_LEN,
                MAX_WORKGROUP_SIZE_X,
                MAX_WORKGROUPS_PER_DIMENSION,
                MAX_INVOCATIONS_PER_WORKGROUP,
            )
            .expected_hits_per_item
        };

        assert_eq!(hits_for(0), 1);
        assert_eq!(hits_for(1), 1);
        assert_eq!(hits_for(5), 5);
    }

    /// A zero worker count is rejected even when the wall-time budget is valid.
    #[test]
    fn validate_rejects_zero_worker_count() {
        let config = MegakernelConfig {
            worker_count: 0,
            max_wall_time: std::time::Duration::from_secs(1),
            ..MegakernelConfig::default()
        };
        assert!(config.validate().is_err());
    }

    /// A zero wall-time budget is rejected even when the worker count is valid.
    #[test]
    fn validate_rejects_zero_wall_time() {
        let config = MegakernelConfig {
            worker_count: 1,
            max_wall_time: std::time::Duration::from_secs(0),
            ..MegakernelConfig::default()
        };
        assert!(config.validate().is_err());
    }

    /// A non-zero worker count with a non-zero budget passes validation.
    #[test]
    fn validate_accepts_positive_workers_and_wall_time() {
        let config = MegakernelConfig {
            worker_count: 1,
            max_wall_time: std::time::Duration::from_millis(1),
            ..MegakernelConfig::default()
        };
        assert!(config.validate().is_ok());
    }
}