Skip to main content

vyre_runtime/megakernel/planner/
config.rs

1//! Megakernel launch configuration and policy request construction.
2
3use std::time::Duration;
4
5use vyre_driver::backend::BackendError;
6
7use super::super::policy::{
8    MegakernelLaunchPolicy, MegakernelLaunchRecommendation, MegakernelLaunchRequest,
9};
10use super::super::task::{TaskQueueSnapshot, TaskWorkItem};
11use super::geometry::dispatch_grid_for;
12use super::sizing::MegakernelSizingPolicy;
13
/// Workload-shape signals a caller may attach to a megakernel launch.
///
/// They let the launch policy pick between sparse, dense, hybrid, fused, or
/// memory-constrained execution based on the real shape of the work rather
/// than on queue length alone. Every field is optional; the `Default` value
/// leaves all of them at zero.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct MegakernelWorkloadHints {
    /// How many opcodes were observed hot enough to be considered for promotion.
    pub hot_opcode_count: u32,
    /// How many ticketed route windows were observed hot enough for promotion.
    pub hot_window_count: u32,
    /// Number of nodes in the resident dependency graph.
    pub graph_node_count: u32,
    /// Number of edges in the resident dependency graph.
    pub graph_edge_count: u32,
    /// Density of the active frontier, in basis points.
    ///
    /// Zero means "not supplied": the value is inferred when possible.
    pub frontier_density_bps: u16,
    /// Pressure on device memory, in basis points.
    ///
    /// Zero means "not supplied": the value is inferred when possible.
    pub memory_pressure_bps: u16,
    /// Bytes this dispatch family already requires to be resident on the device.
    pub resident_device_bytes: u64,
    /// Hard ceiling on device memory, in bytes. Zero means there is no bound.
    pub device_memory_budget_bytes: u64,
}
36
37/// Configuration for one megakernel dispatch invocation.
38#[derive(Debug, Clone)]
39pub struct MegakernelConfig {
40    /// Number of persistent worker workgroups.
41    pub worker_count: u32,
42    /// Maximum wall-clock time the megakernel runs before draining
43    /// queued work and exiting.
44    pub max_wall_time: Duration,
45    /// Hint to the scheduler about expected items per worker.
46    pub expected_items_per_worker: u32,
47    /// Optional workload-shape hints consumed by the launch policy.
48    pub workload: MegakernelWorkloadHints,
49}
50
51impl Default for MegakernelConfig {
52    fn default() -> Self {
53        Self {
54            worker_count: MegakernelSizingPolicy::standard().default_worker_count(),
55            max_wall_time: Duration::from_secs(60),
56            expected_items_per_worker: 0,
57            workload: MegakernelWorkloadHints::default(),
58        }
59    }
60}
61
62impl MegakernelConfig {
63    /// Validate the config and surface actionable errors.
64    ///
65    /// # Errors
66    ///
67    /// Returns an error when the worker count is zero or the wall-clock budget
68    /// is empty, because either condition would make persistent dispatch
69    /// unschedulable.
70    pub fn validate(&self) -> Result<(), BackendError> {
71        if self.worker_count == 0 {
72            return Err(BackendError::new(
73                "megakernel worker_count must be non-zero. Fix: provide at least one worker workgroup.",
74            ));
75        }
76        if self.max_wall_time.is_zero() {
77            return Err(BackendError::new(
78                "megakernel max_wall_time must be non-zero. Fix: supply a positive Duration budget.",
79            ));
80        }
81        Ok(())
82    }
83
84    /// Compute the direct-dispatch grid for `queue_len` logical work slots.
85    ///
86    /// `worker_count` is the caller's persistent worker-workgroup ceiling; the
87    /// returned grid never launches more workgroups than that ceiling or the
88    /// backend occupancy cap.
89    #[must_use]
90    pub fn dispatch_grid(&self, queue_len: u32, max_workgroup_size_x: u32) -> [u32; 3] {
91        dispatch_grid_for(self.worker_count, queue_len, max_workgroup_size_x)
92    }
93
94    /// Build a policy request from this config and adapter limits.
95    #[must_use]
96    pub const fn launch_request(
97        &self,
98        queue_len: u32,
99        max_workgroup_size_x: u32,
100        max_compute_workgroups_per_dimension: u32,
101        max_compute_invocations_per_workgroup: u32,
102    ) -> MegakernelLaunchRequest {
103        MegakernelLaunchRequest {
104            queue_len,
105            requested_worker_groups: self.worker_count,
106            max_workgroup_size_x,
107            max_compute_workgroups_per_dimension,
108            max_compute_invocations_per_workgroup,
109            requested_hit_capacity: 0,
110            expected_hits_per_item: if self.expected_items_per_worker > 1 {
111                self.expected_items_per_worker
112            } else {
113                1
114            },
115            hot_opcode_count: self.workload.hot_opcode_count,
116            hot_window_count: self.workload.hot_window_count,
117            requeue_count: 0,
118            max_priority_age: 0,
119            graph_node_count: self.workload.graph_node_count,
120            graph_edge_count: self.workload.graph_edge_count,
121            frontier_density_bps: self.workload.frontier_density_bps,
122            memory_pressure_bps: self.workload.memory_pressure_bps,
123            resident_device_bytes: self.workload.resident_device_bytes,
124            device_memory_budget_bytes: self.workload.device_memory_budget_bytes,
125        }
126    }
127
128    /// Build a policy request from device-visible continuation task slots.
129    ///
130    /// Paused, completed, empty, running, and faulted tasks do not add launch
131    /// lanes. Yielded and requeued tasks stay schedulable so the GPU can resume
132    /// them without a CPU-side republish loop.
133    ///
134    /// # Errors
135    ///
136    /// Returns [`BackendError`] when a task slot contains an invalid state word.
137    pub fn launch_request_for_tasks(
138        &self,
139        tasks: &[TaskWorkItem],
140        max_workgroup_size_x: u32,
141        max_compute_workgroups_per_dimension: u32,
142        max_compute_invocations_per_workgroup: u32,
143    ) -> Result<MegakernelLaunchRequest, BackendError> {
144        let snapshot = TaskQueueSnapshot::from_tasks(tasks)?;
145        let schedulable_count = snapshot.try_schedulable_count()?;
146        let request = self.launch_request(
147            schedulable_count,
148            max_workgroup_size_x,
149            max_compute_workgroups_per_dimension,
150            max_compute_invocations_per_workgroup,
151        );
152        snapshot.try_apply_to_launch_request(request)
153    }
154
155    /// Recommend one launch shape through the shared megakernel policy.
156    ///
157    /// # Errors
158    ///
159    /// Returns [`BackendError`] when adapter limits are malformed.
160    pub fn launch_recommendation(
161        &self,
162        queue_len: u32,
163        max_workgroup_size_x: u32,
164        max_compute_workgroups_per_dimension: u32,
165        max_compute_invocations_per_workgroup: u32,
166    ) -> Result<MegakernelLaunchRecommendation, BackendError> {
167        MegakernelLaunchPolicy::standard().recommend(self.launch_request(
168            queue_len,
169            max_workgroup_size_x,
170            max_compute_workgroups_per_dimension,
171            max_compute_invocations_per_workgroup,
172        ))
173    }
174
175    /// Recommend one launch shape for a continuation task queue.
176    ///
177    /// # Errors
178    ///
179    /// Returns [`BackendError`] when adapter limits are malformed or any task
180    /// slot contains an invalid state word.
181    pub fn launch_recommendation_for_tasks(
182        &self,
183        tasks: &[TaskWorkItem],
184        max_workgroup_size_x: u32,
185        max_compute_workgroups_per_dimension: u32,
186        max_compute_invocations_per_workgroup: u32,
187    ) -> Result<MegakernelLaunchRecommendation, BackendError> {
188        MegakernelLaunchPolicy::standard().recommend(self.launch_request_for_tasks(
189            tasks,
190            max_workgroup_size_x,
191            max_compute_workgroups_per_dimension,
192            max_compute_invocations_per_workgroup,
193        )?)
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// Every workload hint placed on the config must show up unchanged on the
    /// launch request built from it.
    #[test]
    fn launch_request_preserves_workload_hints() {
        let hints = MegakernelWorkloadHints {
            hot_opcode_count: 7,
            hot_window_count: 11,
            graph_node_count: 1_000,
            graph_edge_count: 4_000,
            frontier_density_bps: 7_500,
            memory_pressure_bps: 8_000,
            resident_device_bytes: 1 << 20,
            device_memory_budget_bytes: 1 << 24,
        };
        let config = MegakernelConfig {
            workload: hints,
            ..MegakernelConfig::default()
        };

        let request = config.launch_request(128, 256, 65_535, 1_024);

        assert_eq!(request.hot_opcode_count, hints.hot_opcode_count);
        assert_eq!(request.hot_window_count, hints.hot_window_count);
        assert_eq!(request.graph_node_count, hints.graph_node_count);
        assert_eq!(request.graph_edge_count, hints.graph_edge_count);
        assert_eq!(request.frontier_density_bps, hints.frontier_density_bps);
        assert_eq!(request.memory_pressure_bps, hints.memory_pressure_bps);
        assert_eq!(request.resident_device_bytes, hints.resident_device_bytes);
        assert_eq!(
            request.device_memory_budget_bytes,
            hints.device_memory_budget_bytes
        );
    }
}