Skip to main content

vyre_driver/program_walks/
dispatch_params.rs

//! Dispatch ABI parameter derivation from binding plans: the dispatch element count and the per-binding-slot parameter words.
2
3use vyre_foundation::ir::Program;
4
5use crate::binding::{Binding, BindingRole};
6
7/// Derive the dispatch element count from a binding plan.
8#[must_use]
9pub fn dispatch_element_count(bindings: &[Binding]) -> u32 {
10    dispatch_element_count_inner(bindings, false)
11}
12
13/// Derive the dispatch element count from a binding plan and Program body.
14#[must_use]
15pub fn dispatch_element_count_for_program(program: &Program, bindings: &[Binding]) -> u32 {
16    dispatch_element_count_inner(bindings, program_contains_atomic(program))
17}
18
19fn dispatch_element_count_inner(bindings: &[Binding], force_full_span: bool) -> u32 {
20    // Single pass over bindings: collect every fact the dispatch
21    // policy needs (any-shared / max non-shared / max output) in one
22    // scan. Previously up to three independent .iter() passes
23    // traversed the same slice  -  for launch shapes that carry 60+
24    // bindings each pass is real work.
25    let mut any_shared = false;
26    let mut max_non_shared: u32 = 0;
27    let mut max_output: u32 = 0;
28    for binding in bindings {
29        if binding.role == BindingRole::Shared {
30            any_shared = true;
31            continue;
32        }
33        if binding.element_count > max_non_shared {
34            max_non_shared = binding.element_count;
35        }
36        if matches!(binding.role, BindingRole::Output | BindingRole::InputOutput)
37            && binding.element_count > max_output
38        {
39            max_output = binding.element_count;
40        }
41    }
42    if any_shared || force_full_span {
43        return max_non_shared.max(1);
44    }
45    if max_output > 0 {
46        return max_output;
47    }
48    max_non_shared.max(1)
49}
50
51fn program_contains_atomic(program: &Program) -> bool {
52    // ProgramStats::atomic_op_count is incremented exactly once per
53    // Expr::Atomic during the cached single-pass stats walk. Reading
54    // the cached count replaces the recursive node + expr scan this
55    // function previously performed.
56    program.stats().atomic_op_count > 0
57}
58
59/// Build per-buffer element-count parameter words for a dispatch.
60#[must_use]
61pub fn dispatch_param_words(bindings: &[Binding], element_count: u32) -> Vec<u32> {
62    match try_dispatch_param_words(bindings, element_count) {
63        Ok(words) => words,
64        Err(_error) => Vec::new(),
65    }
66}
67
68/// Build per-buffer element-count parameter words for a dispatch with fallible
69/// host-staging allocation.
70pub fn try_dispatch_param_words(
71    bindings: &[Binding],
72    element_count: u32,
73) -> Result<Vec<u32>, String> {
74    let mut words = Vec::new();
75    try_dispatch_param_words_into(bindings, element_count, &mut words)?;
76    Ok(words)
77}
78
79/// Build per-buffer element-count parameter words into caller-owned storage.
80pub fn dispatch_param_words_into(bindings: &[Binding], element_count: u32, words: &mut Vec<u32>) {
81    if try_dispatch_param_words_into(bindings, element_count, words).is_err() {
82        words.clear();
83    }
84}
85
86/// Build per-buffer element-count parameter words into caller-owned storage
87/// with explicit allocation and ABI-width errors.
88pub fn try_dispatch_param_words_into(
89    bindings: &[Binding],
90    element_count: u32,
91    words: &mut Vec<u32>,
92) -> Result<(), String> {
93    let word_len = dispatch_param_word_len_for_bindings(bindings)?;
94    reserve_dispatch_param_words(words, word_len)?;
95    words.clear();
96    words.resize(word_len, 0);
97    words[0] = element_count;
98    for binding in bindings {
99        if binding.role == BindingRole::Shared {
100            continue;
101        }
102        let slot = dispatch_param_word_slot(binding)?;
103        words[slot] = if binding.element_count == 0 {
104            element_count
105        } else {
106            binding.element_count
107        };
108    }
109    Ok(())
110}
111
112fn dispatch_param_word_len_for_bindings(bindings: &[Binding]) -> Result<usize, String> {
113    let mut word_len = dispatch_param_word_len_checked(bindings.len())?;
114    for binding in bindings {
115        if binding.role == BindingRole::Shared {
116            continue;
117        }
118        let required = dispatch_param_word_slot(binding)?
119            .checked_add(1)
120            .ok_or_else(|| {
121                format!(
122                    "dispatch binding slot {} overflows ABI parameter word count. Fix: split the Program before launch-parameter planning.",
123                    binding.binding
124                )
125            })?;
126        if required > word_len {
127            word_len = required;
128        }
129    }
130    Ok(word_len)
131}
132
133fn dispatch_param_word_slot(binding: &Binding) -> Result<usize, String> {
134    let slot = usize::try_from(binding.binding).map_err(|error| {
135        format!(
136            "dispatch binding slot {} does not fit host usize ({error}). Fix: split the Program before launch-parameter planning.",
137            binding.binding
138        )
139    })?;
140    slot.checked_add(1).ok_or_else(|| {
141        format!(
142            "dispatch binding slot {} overflows ABI parameter slot. Fix: split the Program before launch-parameter planning.",
143            binding.binding
144        )
145    })
146}
147
/// Return `binding_count + 1`: one word per binding plus the leading
/// element-count word.
///
/// # Errors
///
/// Returns a descriptive `String` if the addition overflows `usize`.
fn dispatch_param_word_len_checked(binding_count: usize) -> Result<usize, String> {
    match binding_count.checked_add(1) {
        Some(len) => Ok(len),
        None => Err(format!(
            "dispatch binding count {binding_count} overflows ABI parameter word count. Fix: split the Program before launch-parameter planning."
        )),
    }
}
155
156fn reserve_dispatch_param_words(words: &mut Vec<u32>, word_len: usize) -> Result<(), String> {
157    crate::allocation::try_reserve_vec_to_capacity(words, word_len).map_err(|error| {
158        format!(
159            "dispatch parameter staging could not reserve {word_len} u32 word(s): {error}. Fix: split the Program before launch-parameter planning."
160        )
161    })
162}
163
#[cfg(test)]
mod tests {
    // `use super::*` already brings in the parent's `Binding` and `BindingRole`
    // imports, so no separate `crate::binding` import is needed here.
    use super::*;
    use std::sync::Arc;

    fn binding(buffer_index: usize, element_count: u32) -> Binding {
        Binding {
            name: Arc::from("test"),
            binding: u32::try_from(buffer_index).expect("Fix: test binding index fits u32"),
            buffer_index,
            role: BindingRole::Input,
            element_size: 4,
            preferred_alignment: 4,
            element_count,
            static_byte_len: None,
            input_index: Some(0),
            output_index: None,
        }
    }

    /// Same as [`binding`], but with an explicit role. The element-count policy
    /// depends only on `role` and `element_count`.
    fn binding_with_role(buffer_index: usize, element_count: u32, role: BindingRole) -> Binding {
        let mut item = binding(buffer_index, element_count);
        item.role = role;
        item
    }

    #[test]
    fn dispatch_params_support_sparse_binding_indices_without_repeated_growth() {
        let bindings = [binding(4, 9), binding(1, 0)];
        let words = try_dispatch_param_words(&bindings, 7)
            .expect("Fix: sparse binding parameter words should stage");

        assert_eq!(words, vec![7, 0, 7, 0, 0, 9]);
    }

    #[test]
    fn dispatch_params_are_indexed_by_binding_slot_not_program_buffer_index() {
        let mut sparse = binding(0, 4);
        sparse.binding = 9;
        let mut dynamic = binding(7, 0);
        dynamic.binding = 2;
        let words = try_dispatch_param_words(&[sparse, dynamic], 11)
            .expect("Fix: sparse binding-slot parameter words should stage");

        assert_eq!(words.len(), 11);
        assert_eq!(words[0], 11);
        assert_eq!(words[3], 11);
        assert_eq!(words[10], 4);
        assert_eq!(
            words[8], 0,
            "Fix: CUDA/PTX parameter words are indexed by binding slot, not buffer_index."
        );
    }

    #[test]
    fn generated_dispatch_params_cover_sparse_binding_slot_matrix() {
        let mut checked = 0usize;
        for seed in 0..4096u32 {
            let binding_count = (seed as usize % 8) + 1;
            let mut bindings = Vec::with_capacity(binding_count);
            for index in 0..binding_count {
                let mut item = binding(
                    index,
                    if index % 3 == 0 {
                        0
                    } else {
                        seed + index as u32
                    },
                );
                item.binding = ((seed.wrapping_mul(17) + (index as u32 * 97)) % 1024) + 1;
                item.buffer_index = binding_count - 1 - index;
                bindings.push(item);
            }
            let element_count = seed.wrapping_mul(31) | 1;
            let words = try_dispatch_param_words(&bindings, element_count)
                .expect("Fix: generated sparse binding-slot param words should stage.");
            assert_eq!(words[0], element_count, "seed {seed}");
            for item in &bindings {
                let slot = usize::try_from(item.binding).expect("Fix: test binding fits usize") + 1;
                let expected = if item.element_count == 0 {
                    element_count
                } else {
                    item.element_count
                };
                assert_eq!(
                    words[slot], expected,
                    "Fix: generated dispatch-param seed {seed} binding slot {} must map to words[slot+1], regardless of buffer_index {}.",
                    item.binding, item.buffer_index
                );
                checked += 1;
            }
        }
        assert!(
            checked >= 18_000,
            "Fix: generated dispatch-param ABI coverage should exercise thousands of sparse binding layouts, got {checked}."
        );
    }

    #[test]
    fn dispatch_params_skip_shared_bindings() {
        // Shared bindings still count toward the `len + 1` floor (3 words here),
        // but never write their own slot.
        let bindings = [
            binding(0, 4),
            binding_with_role(1, 8, BindingRole::Shared),
        ];
        let words = try_dispatch_param_words(&bindings, 2)
            .expect("Fix: shared bindings must not block parameter staging");

        assert_eq!(
            words,
            vec![2, 4, 0],
            "Fix: shared bindings must leave their parameter slot zeroed."
        );
    }

    #[test]
    fn dispatch_param_words_into_overwrites_stale_caller_storage() {
        // Caller storage is longer than needed and full of stale data. Every
        // word must be rewritten and the vector truncated to the planned length.
        let mut words = vec![99; 20];
        dispatch_param_words_into(&[binding(1, 5)], 3, &mut words);

        assert_eq!(
            words,
            vec![3, 0, 5],
            "Fix: reused parameter storage must be cleared before staging."
        );
    }

    #[test]
    fn dispatch_element_count_follows_shared_output_fallback_policy() {
        assert_eq!(
            dispatch_element_count(&[]),
            1,
            "Fix: an empty binding plan must still dispatch one element."
        );

        let plain = [
            binding_with_role(0, 64, BindingRole::Input),
            binding_with_role(1, 16, BindingRole::Output),
        ];
        assert_eq!(
            dispatch_element_count(&plain),
            16,
            "Fix: without shared bindings the largest output count drives dispatch."
        );

        let in_out = [
            binding_with_role(0, 64, BindingRole::Input),
            binding_with_role(1, 32, BindingRole::InputOutput),
        ];
        assert_eq!(
            dispatch_element_count(&in_out),
            32,
            "Fix: InputOutput bindings count as outputs."
        );

        let with_shared = [
            binding_with_role(0, 64, BindingRole::Input),
            binding_with_role(1, 16, BindingRole::Output),
            binding_with_role(2, 256, BindingRole::Shared),
        ];
        assert_eq!(
            dispatch_element_count(&with_shared),
            64,
            "Fix: a shared binding forces the largest non-shared span and is itself ignored."
        );

        let dynamic_output = [
            binding_with_role(0, 8, BindingRole::Input),
            binding_with_role(1, 0, BindingRole::Output),
        ];
        assert_eq!(
            dispatch_element_count(&dynamic_output),
            8,
            "Fix: zero-sized outputs fall back to the largest non-shared count."
        );

        let all_zero = [binding(0, 0), binding(1, 0)];
        assert_eq!(
            dispatch_element_count(&all_zero),
            1,
            "Fix: dispatch element count is clamped to at least one."
        );
    }

    #[test]
    fn dispatch_params_source_keeps_fallible_modular_staging() {
        let source = include_str!("dispatch_params.rs");
        let production = source
            .split("#[cfg(test)]")
            .next()
            .expect("Fix: dispatch-param source must contain production section before tests");

        assert!(
            production.contains("pub fn try_dispatch_param_words")
                && production.contains("pub fn try_dispatch_param_words_into")
                && production.contains("fn dispatch_param_word_len_for_bindings")
                && production.contains("fn reserve_dispatch_param_words"),
            "Fix: dispatch parameter planning must expose modular fallible staging APIs."
        );
        assert!(
            !production.contains("Vec::with_capacity")
                && !production.contains("words.resize(binding.buffer_index + 2, 0)")
                && !production.contains("panic!("),
            "Fix: dispatch parameter planning must not allocate infallibly, grow repeatedly, or panic in release-path helpers."
        );
    }
}