vyre_runtime/megakernel/planner/
geometry.rs1use std::time::Duration;
4
5use vyre_driver::backend::{BackendError, DispatchConfig};
6
7use super::grid::cached_geometry_from_slots;
8use super::sizing::MegakernelSizingPolicy;
9
/// Launch geometry for a megakernel dispatch: workgroup width, slot count and
/// the dispatch grid handed to the backend.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct MegakernelLaunchGeometry {
    /// Workgroup width along X; `dispatch_config` emits it as
    /// `[workgroup_size_x, 1, 1]`.
    pub workgroup_size_x: u32,
    /// Number of work-queue slots covered by this launch. `from_item_count`
    /// rejects geometry whose slot count is below the queued item count.
    pub slot_count: u32,
    /// Grid dimensions sent to the backend as the dispatch grid override.
    pub dispatch_grid: [u32; 3],
}
20
21impl MegakernelLaunchGeometry {
22 pub fn from_item_count(
29 item_count: usize,
30 worker_count: u32,
31 max_workgroup_size_x: u32,
32 ) -> Result<Self, BackendError> {
33 let item_count = u32::try_from(item_count).map_err(|_| {
34 BackendError::new(
35 "megakernel work queue length exceeds u32::MAX. Fix: shard the queue before dispatch.",
36 )
37 })?;
38 let geometry = Self::from_slots(item_count, worker_count, max_workgroup_size_x);
39 if geometry.slot_count < item_count {
40 return Err(BackendError::new(
41 "megakernel work queue cannot be padded inside the u32 ring protocol. Fix: shard the queue before dispatch.",
42 ));
43 }
44 Ok(geometry)
45 }
46
47 #[must_use]
49 pub fn from_slots(slot_count: u32, worker_count: u32, max_workgroup_size_x: u32) -> Self {
50 cached_geometry_from_slots(slot_count, worker_count, max_workgroup_size_x)
51 }
52
53 #[must_use]
55 pub const fn covering_worker_groups(&self) -> u32 {
56 self.slot_count / self.workgroup_size_x
57 }
58
59 #[must_use]
61 pub fn dispatch_config(&self, timeout: Option<Duration>) -> DispatchConfig {
62 let mut config = DispatchConfig::default();
63 config.timeout = timeout;
64 config.grid_override = Some(self.dispatch_grid);
65 config.workgroup_override = Some([self.workgroup_size_x, 1, 1]);
66 config
67 }
68}
69
70#[must_use]
73pub fn worker_workgroup_size(worker_count: u32, max_workgroup_size_x: u32) -> u32 {
74 MegakernelSizingPolicy::standard().worker_workgroup_size(worker_count, max_workgroup_size_x)
75}
76
77#[must_use]
79pub fn padded_slot_count(slot_count: u32, workgroup_size_x: u32) -> u32 {
80 MegakernelSizingPolicy::standard().padded_slot_count(slot_count, workgroup_size_x)
81}
82
83#[must_use]
85pub fn dispatch_grid_for(worker_count: u32, queue_len: u32, max_workgroup_size_x: u32) -> [u32; 3] {
86 MegakernelSizingPolicy::standard().dispatch_grid_for(
87 worker_count,
88 queue_len,
89 max_workgroup_size_x,
90 )
91}
92
93#[must_use]
99pub fn default_worker_groups_from_limits(
100 max_compute_workgroups_per_dimension: u32,
101 max_compute_invocations_per_workgroup: u32,
102) -> u32 {
103 MegakernelSizingPolicy::standard().default_worker_groups_from_limits(
104 max_compute_workgroups_per_dimension,
105 max_compute_invocations_per_workgroup,
106 )
107}