vyre_driver/program_walks/
dispatch_params.rs1use vyre_foundation::ir::Program;
4
5use crate::binding::{Binding, BindingRole};
6
7#[must_use]
9pub fn dispatch_element_count(bindings: &[Binding]) -> u32 {
10 dispatch_element_count_inner(bindings, false)
11}
12
13#[must_use]
15pub fn dispatch_element_count_for_program(program: &Program, bindings: &[Binding]) -> u32 {
16 dispatch_element_count_inner(bindings, program_contains_atomic(program))
17}
18
19fn dispatch_element_count_inner(bindings: &[Binding], force_full_span: bool) -> u32 {
20 let mut any_shared = false;
26 let mut max_non_shared: u32 = 0;
27 let mut max_output: u32 = 0;
28 for binding in bindings {
29 if binding.role == BindingRole::Shared {
30 any_shared = true;
31 continue;
32 }
33 if binding.element_count > max_non_shared {
34 max_non_shared = binding.element_count;
35 }
36 if matches!(binding.role, BindingRole::Output | BindingRole::InputOutput)
37 && binding.element_count > max_output
38 {
39 max_output = binding.element_count;
40 }
41 }
42 if any_shared || force_full_span {
43 return max_non_shared.max(1);
44 }
45 if max_output > 0 {
46 return max_output;
47 }
48 max_non_shared.max(1)
49}
50
51fn program_contains_atomic(program: &Program) -> bool {
52 program.stats().atomic_op_count > 0
57}
58
59#[must_use]
61pub fn dispatch_param_words(bindings: &[Binding], element_count: u32) -> Vec<u32> {
62 match try_dispatch_param_words(bindings, element_count) {
63 Ok(words) => words,
64 Err(_error) => Vec::new(),
65 }
66}
67
68pub fn try_dispatch_param_words(
71 bindings: &[Binding],
72 element_count: u32,
73) -> Result<Vec<u32>, String> {
74 let mut words = Vec::new();
75 try_dispatch_param_words_into(bindings, element_count, &mut words)?;
76 Ok(words)
77}
78
79pub fn dispatch_param_words_into(bindings: &[Binding], element_count: u32, words: &mut Vec<u32>) {
81 if try_dispatch_param_words_into(bindings, element_count, words).is_err() {
82 words.clear();
83 }
84}
85
86pub fn try_dispatch_param_words_into(
89 bindings: &[Binding],
90 element_count: u32,
91 words: &mut Vec<u32>,
92) -> Result<(), String> {
93 let word_len = dispatch_param_word_len_for_bindings(bindings)?;
94 reserve_dispatch_param_words(words, word_len)?;
95 words.clear();
96 words.resize(word_len, 0);
97 words[0] = element_count;
98 for binding in bindings {
99 if binding.role == BindingRole::Shared {
100 continue;
101 }
102 let slot = dispatch_param_word_slot(binding)?;
103 words[slot] = if binding.element_count == 0 {
104 element_count
105 } else {
106 binding.element_count
107 };
108 }
109 Ok(())
110}
111
112fn dispatch_param_word_len_for_bindings(bindings: &[Binding]) -> Result<usize, String> {
113 let mut word_len = dispatch_param_word_len_checked(bindings.len())?;
114 for binding in bindings {
115 if binding.role == BindingRole::Shared {
116 continue;
117 }
118 let required = dispatch_param_word_slot(binding)?
119 .checked_add(1)
120 .ok_or_else(|| {
121 format!(
122 "dispatch binding slot {} overflows ABI parameter word count. Fix: split the Program before launch-parameter planning.",
123 binding.binding
124 )
125 })?;
126 if required > word_len {
127 word_len = required;
128 }
129 }
130 Ok(word_len)
131}
132
133fn dispatch_param_word_slot(binding: &Binding) -> Result<usize, String> {
134 let slot = usize::try_from(binding.binding).map_err(|error| {
135 format!(
136 "dispatch binding slot {} does not fit host usize ({error}). Fix: split the Program before launch-parameter planning.",
137 binding.binding
138 )
139 })?;
140 slot.checked_add(1).ok_or_else(|| {
141 format!(
142 "dispatch binding slot {} overflows ABI parameter slot. Fix: split the Program before launch-parameter planning.",
143 binding.binding
144 )
145 })
146}
147
/// Returns `binding_count + 1`: one word per binding plus the leading word.
///
/// # Errors
///
/// Returns a message when the addition overflows `usize`.
fn dispatch_param_word_len_checked(binding_count: usize) -> Result<usize, String> {
    match binding_count.checked_add(1) {
        Some(word_len) => Ok(word_len),
        None => Err(format!(
            "dispatch binding count {binding_count} overflows ABI parameter word count. Fix: split the Program before launch-parameter planning."
        )),
    }
}
155
156fn reserve_dispatch_param_words(words: &mut Vec<u32>, word_len: usize) -> Result<(), String> {
157 crate::allocation::try_reserve_vec_to_capacity(words, word_len).map_err(|error| {
158 format!(
159 "dispatch parameter staging could not reserve {word_len} u32 word(s): {error}. Fix: split the Program before launch-parameter planning."
160 )
161 })
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::binding::BindingRole;
    use std::sync::Arc;

    /// Builds an `Input` binding whose binding slot mirrors `buffer_index`.
    fn binding(buffer_index: usize, element_count: u32) -> Binding {
        let slot = u32::try_from(buffer_index).expect("Fix: test binding index fits u32");
        Binding {
            name: Arc::from("test"),
            binding: slot,
            buffer_index,
            role: BindingRole::Input,
            element_size: 4,
            preferred_alignment: 4,
            element_count,
            static_byte_len: None,
            input_index: Some(0),
            output_index: None,
        }
    }

    /// Slots 4 and 1 are used; the table must still span slot 4 and zero-fill
    /// the gaps, and the zero-count binding inherits the dispatch count.
    #[test]
    fn dispatch_params_support_sparse_binding_indices_without_repeated_growth() {
        let staged = try_dispatch_param_words(&[binding(4, 9), binding(1, 0)], 7)
            .expect("Fix: sparse binding parameter words should stage");

        assert_eq!(staged, vec![7, 0, 7, 0, 0, 9]);
    }

    /// The word index comes from `binding`, never from `buffer_index`.
    #[test]
    fn dispatch_params_are_indexed_by_binding_slot_not_program_buffer_index() {
        let sparse = Binding {
            binding: 9,
            ..binding(0, 4)
        };
        let dynamic = Binding {
            binding: 2,
            ..binding(7, 0)
        };
        let words = try_dispatch_param_words(&[sparse, dynamic], 11)
            .expect("Fix: sparse binding-slot parameter words should stage");

        assert_eq!(words.len(), 11);
        assert_eq!(words[0], 11);
        assert_eq!(words[3], 11);
        assert_eq!(words[10], 4);
        assert_eq!(
            words[8], 0,
            "Fix: CUDA/PTX parameter words are indexed by binding slot, not buffer_index."
        );
    }

    /// Sweeps 4096 generated layouts with scattered binding slots and reversed
    /// buffer indices, checking every binding lands in `words[binding + 1]`.
    #[test]
    fn generated_dispatch_params_cover_sparse_binding_slot_matrix() {
        let mut checked = 0usize;
        for seed in 0..4096u32 {
            let binding_count = (seed as usize % 8) + 1;
            let bindings: Vec<Binding> = (0..binding_count)
                .map(|index| {
                    // Every third binding is dynamic (zero count).
                    let own_count = if index % 3 == 0 {
                        0
                    } else {
                        seed + index as u32
                    };
                    Binding {
                        binding: ((seed.wrapping_mul(17) + (index as u32 * 97)) % 1024) + 1,
                        buffer_index: binding_count - 1 - index,
                        ..binding(index, own_count)
                    }
                })
                .collect();

            let element_count = seed.wrapping_mul(31) | 1;
            let words = try_dispatch_param_words(&bindings, element_count)
                .expect("Fix: generated sparse binding-slot param words should stage.");
            assert_eq!(words[0], element_count, "seed {seed}");

            for item in &bindings {
                let slot = usize::try_from(item.binding).expect("Fix: test binding fits usize") + 1;
                let expected = match item.element_count {
                    0 => element_count,
                    own => own,
                };
                assert_eq!(
                    words[slot], expected,
                    "Fix: generated dispatch-param seed {seed} binding slot {} must map to words[slot+1], regardless of buffer_index {}.",
                    item.binding, item.buffer_index
                );
            }
            checked += bindings.len();
        }
        assert!(
            checked >= 18_000,
            "Fix: generated dispatch-param ABI coverage should exercise thousands of sparse binding layouts, got {checked}."
        );
    }

    /// Guards the shape of the production section of this file: the fallible
    /// staging helpers must exist and the listed constructs must be absent.
    #[test]
    fn dispatch_params_source_keeps_fallible_modular_staging() {
        let production = include_str!("dispatch_params.rs")
            .split("#[cfg(test)]")
            .next()
            .expect("Fix: dispatch-param source must contain production section before tests");

        let required = [
            "pub fn try_dispatch_param_words",
            "pub fn try_dispatch_param_words_into",
            "fn dispatch_param_word_len_for_bindings",
            "fn reserve_dispatch_param_words",
        ];
        assert!(
            required.iter().all(|needle| production.contains(*needle)),
            "Fix: dispatch parameter planning must expose modular fallible staging APIs."
        );

        let forbidden = [
            "Vec::with_capacity",
            "words.resize(binding.buffer_index + 2, 0)",
            "panic!(",
        ];
        assert!(
            !forbidden.iter().any(|needle| production.contains(*needle)),
            "Fix: dispatch parameter planning must not allocate infallibly, grow repeatedly, or panic in release-path helpers."
        );
    }
}