Skip to main content

vyre_driver/
launch.rs

1//! Backend-neutral dispatch launch preparation.
2
3use std::collections::BTreeMap;
4use std::path::{Path, PathBuf};
5use std::sync::{Mutex, OnceLock};
6
7use vyre_foundation::ir::{MemoryKind, Node, Program};
8
9use crate::binding::Binding;
10use crate::program_walks::{
11    dispatch_element_count_for_program, dispatch_param_words_into, infer_dispatch_grid_for_count,
12    program_uses_launch_geometry_ids,
13};
14use crate::tuner::{
15    identity_fisher_q16, Mode, NaturalGradientPolicy, Tuner, TunerCache, TuningMeasurement,
16    WORKGROUP_CANDIDATES,
17};
18use crate::validation::{validate_launch_geometry, LaunchGeometryLimits};
19use crate::{BackendError, DispatchConfig};
20
21const COLD_START_GRID_STEP_NS: u64 = 1_024;
22const COLD_START_IDLE_LANE_NS: u64 = 8;
23const COLD_START_TEMPERATURE_NS: u64 = 4_096;
24const MAX_NATURAL_LAUNCH_CACHE_ENTRIES: usize = 4_096;
25
26static NATURAL_LAUNCH_CACHE: OnceLock<Mutex<BTreeMap<NaturalLaunchCacheKey, NaturalLaunchEntry>>> =
27    OnceLock::new();
28
/// Launch metadata resolved once and shared by every concrete driver.
///
/// Build one with `LaunchPlan::from_bindings`, or keep a plan alive and
/// refresh it with `LaunchPlan::prepare_into` so `param_words` is recycled.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LaunchPlan {
    /// Logical problem size handed to the lowered kernel.
    pub element_count: u32,
    /// Workgroup (block) shape in effect after dispatch overrides are applied.
    pub workgroup: [u32; 3],
    /// Grid shape in effect after dispatch overrides or grid inference.
    pub grid: [u32; 3],
    /// Per-buffer element-count words uploaded as the shared params buffer.
    pub param_words: Vec<u32>,
    /// Largest preferred alignment among all launch bindings.
    ///
    /// Drivers consult this to choose upload-staging and device-buffer
    /// allocation paths without revisiting the Program's buffer declarations.
    pub max_binding_alignment: usize,
}
46
47impl LaunchPlan {
48    /// Empty launch plan with reusable parameter-word storage.
49    #[must_use]
50    pub fn new() -> Self {
51        Self {
52            element_count: 1,
53            workgroup: [1, 1, 1],
54            grid: [1, 1, 1],
55            param_words: Vec::new(),
56            max_binding_alignment: 1,
57        }
58    }
59
60    /// Prepare dispatch geometry and parameter words from a validated binding plan.
61    ///
62    /// # Errors
63    ///
64    /// Returns when caller overrides produce zero dimensions, overflow the
65    /// logical launch element count, or exceed backend-reported launch limits.
66    pub fn from_bindings(
67        program: &Program,
68        bindings: &[Binding],
69        config: &DispatchConfig,
70        limits: LaunchGeometryLimits,
71    ) -> Result<Self, BackendError> {
72        let mut plan = Self::new();
73        plan.prepare_into(program, bindings, config, limits)?;
74        Ok(plan)
75    }
76
77    /// Prepare dispatch geometry and parameter words, reusing this plan's buffers.
78    ///
79    /// # Errors
80    ///
81    /// Returns when caller overrides produce zero dimensions, overflow the
82    /// logical launch element count, or exceed backend-reported launch limits.
83    pub fn prepare_into(
84        &mut self,
85        program: &Program,
86        bindings: &[Binding],
87        config: &DispatchConfig,
88        limits: LaunchGeometryLimits,
89    ) -> Result<(), BackendError> {
90        self.prepare_into_for_mode(program, bindings, config, limits, Mode::from_env())
91    }
92
93    fn prepare_into_for_mode(
94        &mut self,
95        program: &Program,
96        bindings: &[Binding],
97        config: &DispatchConfig,
98        limits: LaunchGeometryLimits,
99        mode: Mode,
100    ) -> Result<(), BackendError> {
101        let workgroup =
102            effective_launch_workgroup_for_mode(program, bindings, config, limits, mode);
103        validate_launch_geometry(workgroup, [1, 1, 1], limits)?;
104        let element_count = launch_element_count(program, bindings, workgroup, config, limits)?;
105        let grid = match config.grid_override {
106            Some(grid) => grid,
107            None => {
108                // Non-1D workgroups need an explicit grid_override  -
109                // there's no single right way to map an unknown
110                // element_count across N×M (or N×M×K) thread tiles,
111                // and silently picking one produces silently-wrong
112                // results. Force the caller to make the choice.
113                if workgroup[1] != 1 || workgroup[2] != 1 {
114                    return Err(BackendError::InvalidProgram {
115                        fix: format!(
116                            "Fix: backend `{}` requires DispatchConfig::grid_override for non-1D workgroups. \
117                             workgroup={:?} has no unambiguous default grid; set grid_override to the logical [x, y, z] you want.",
118                            limits.backend, workgroup,
119                        ),
120                    });
121                }
122                infer_dispatch_grid_for_count(element_count, workgroup)?
123            }
124        };
125        validate_launch_geometry(workgroup, grid, limits)?;
126        self.element_count = element_count;
127        self.workgroup = workgroup;
128        self.grid = grid;
129        self.max_binding_alignment = bindings
130            .iter()
131            .map(|binding| binding.preferred_alignment)
132            .max()
133            .unwrap_or(1);
134        dispatch_param_words_into(bindings, element_count, &mut self.param_words);
135        Ok(())
136    }
137}
138
139impl Default for LaunchPlan {
140    fn default() -> Self {
141        Self::new()
142    }
143}
144
145fn launch_element_count(
146    program: &Program,
147    bindings: &[Binding],
148    workgroup: [u32; 3],
149    config: &DispatchConfig,
150    limits: LaunchGeometryLimits,
151) -> Result<u32, BackendError> {
152    let inferred = dispatch_element_count_for_program(program, bindings);
153    let Some(grid) = config.grid_override else {
154        return Ok(inferred);
155    };
156    if workgroup.contains(&0) || grid.contains(&0) {
157        return Err(BackendError::InvalidProgram {
158            fix: format!(
159                "Fix: {} grid_override and workgroup dimensions must all be non-zero.",
160                limits.backend
161            ),
162        });
163    }
164    grid[0]
165        .checked_mul(workgroup[0])
166        .filter(|count| *count != 0)
167        .ok_or_else(|| BackendError::InvalidProgram {
168            fix: format!(
169                "Fix: {} grid_override.x * workgroup_size.x must fit in u32.",
170                limits.backend
171            ),
172        })
173}
174
175fn effective_launch_workgroup_for_mode(
176    program: &Program,
177    bindings: &[Binding],
178    config: &DispatchConfig,
179    limits: LaunchGeometryLimits,
180    mode: Mode,
181) -> [u32; 3] {
182    let element_count = dispatch_element_count_for_program(program, bindings);
183    resolve_launch_workgroup_for_mode(program, config, limits, element_count, mode)
184}
185
186/// Resolve the backend-visible workgroup shape for a dispatch.
187///
188/// Explicit caller overrides remain authoritative. When no override is
189/// supplied and `VYRE_AUTOTUNER` resolves to natural-gradient mode, eligible
190/// 1D storage-only kernels receive a deterministic natural-gradient cold-start
191/// workgroup before grid inference.
192#[must_use]
193pub fn resolve_launch_workgroup(
194    program: &Program,
195    config: &DispatchConfig,
196    limits: LaunchGeometryLimits,
197    element_count: u32,
198) -> [u32; 3] {
199    resolve_launch_workgroup_for_mode(program, config, limits, element_count, Mode::from_env())
200}
201
202/// Resolve the backend-visible workgroup shape with an explicit tuner mode.
203///
204/// This is public so backends whose shader/pipeline compilation must include
205/// the selected workgroup size can derive the same shape before lowering.
206#[must_use]
207pub fn resolve_launch_workgroup_for_mode(
208    program: &Program,
209    config: &DispatchConfig,
210    limits: LaunchGeometryLimits,
211    element_count: u32,
212    mode: Mode,
213) -> [u32; 3] {
214    if let Some(workgroup) = config.workgroup_override {
215        return workgroup;
216    }
217    let declared = program.workgroup_size();
218    if mode != Mode::NaturalGradient || config.grid_override.is_some() {
219        return declared;
220    }
221    natural_gradient_cold_start_workgroup(program, declared, element_count, limits)
222        .unwrap_or(declared)
223}
224
225/// Record a measured launch result for the natural-gradient launch resolver.
226///
227/// Backends should call this only after a real dispatch timing is available.
228/// The function returns `true` when the measurement was accepted into the
229/// bounded feedback cache. Explicit caller overrides, explicit grid launches,
230/// non-natural tuner modes, non-1D kernels, workgroup-local scratch kernels,
231/// zero timings, and out-of-limit candidates are ignored so measured feedback
232/// never changes kernel semantics.
233#[must_use]
234pub fn record_launch_measurement(
235    program: &Program,
236    config: &DispatchConfig,
237    limits: LaunchGeometryLimits,
238    element_count: u32,
239    observed_workgroup: [u32; 3],
240    elapsed_ns: u64,
241) -> bool {
242    record_launch_measurement_for_mode(
243        program,
244        config,
245        limits,
246        element_count,
247        observed_workgroup,
248        elapsed_ns,
249        Mode::from_env(),
250    )
251}
252
253fn record_launch_measurement_for_mode(
254    program: &Program,
255    config: &DispatchConfig,
256    limits: LaunchGeometryLimits,
257    element_count: u32,
258    observed_workgroup: [u32; 3],
259    elapsed_ns: u64,
260    mode: Mode,
261) -> bool {
262    record_launch_measurement_for_mode_with_store(
263        program,
264        config,
265        limits,
266        element_count,
267        observed_workgroup,
268        elapsed_ns,
269        mode,
270        None,
271    )
272}
273
274fn record_launch_measurement_for_mode_with_store(
275    program: &Program,
276    config: &DispatchConfig,
277    limits: LaunchGeometryLimits,
278    element_count: u32,
279    observed_workgroup: [u32; 3],
280    elapsed_ns: u64,
281    mode: Mode,
282    persistent_path: Option<&Path>,
283) -> bool {
284    if mode != Mode::NaturalGradient
285        || elapsed_ns == 0
286        || config.workgroup_override.is_some()
287        || config.grid_override.is_some()
288        || observed_workgroup[1] != 1
289        || observed_workgroup[2] != 1
290        || !candidate_x_fits_limits(observed_workgroup[0], limits)
291    {
292        return false;
293    }
294    let declared = program.workgroup_size();
295    if !is_natural_gradient_launch_tunable(program, declared, element_count) {
296        return false;
297    }
298    let cache_key = NaturalLaunchCacheKey::new(program, declared, element_count, limits);
299    let mut measurements = natural_launch_cache_measurements(cache_key).unwrap_or_default();
300    measurements
301        .entry(observed_workgroup)
302        .and_modify(|best_ns| *best_ns = (*best_ns).min(elapsed_ns))
303        .or_insert(elapsed_ns);
304    let Some(selected) =
305        select_natural_launch_workgroup(declared, element_count, limits, Some(&measurements))
306    else {
307        return false;
308    };
309    natural_launch_cache_set(
310        cache_key,
311        NaturalLaunchEntry {
312            selected,
313            measurements,
314        },
315    );
316    if let Err(error) =
317        persist_natural_launch_selection(cache_key, limits, selected, persistent_path)
318    {
319        tracing::debug!(
320            error,
321            "natural-gradient launch feedback accepted in memory but could not persist"
322        );
323    }
324    true
325}
326
327fn natural_gradient_cold_start_workgroup(
328    program: &Program,
329    declared: [u32; 3],
330    element_count: u32,
331    limits: LaunchGeometryLimits,
332) -> Option<[u32; 3]> {
333    natural_gradient_cold_start_workgroup_with_store(program, declared, element_count, limits, None)
334}
335
336fn natural_gradient_cold_start_workgroup_with_store(
337    program: &Program,
338    declared: [u32; 3],
339    element_count: u32,
340    limits: LaunchGeometryLimits,
341    persistent_path: Option<&Path>,
342) -> Option<[u32; 3]> {
343    if !is_natural_gradient_launch_tunable(program, declared, element_count) {
344        return None;
345    }
346    let cache_key = NaturalLaunchCacheKey::new(program, declared, element_count, limits);
347    if let Some(cached) = natural_launch_cache_get(cache_key) {
348        return Some(cached);
349    }
350    if let Some(persisted) = natural_launch_cache_get_persisted(cache_key, limits, persistent_path)
351    {
352        natural_launch_cache_set(
353            cache_key,
354            NaturalLaunchEntry {
355                selected: persisted,
356                measurements: BTreeMap::new(),
357            },
358        );
359        return Some(persisted);
360    }
361
362    let selected = select_natural_launch_workgroup(declared, element_count, limits, None)?;
363    natural_launch_cache_set(
364        cache_key,
365        NaturalLaunchEntry {
366            selected,
367            measurements: BTreeMap::new(),
368        },
369    );
370    Some(selected)
371}
372
373fn select_natural_launch_workgroup(
374    declared: [u32; 3],
375    element_count: u32,
376    limits: LaunchGeometryLimits,
377    measurements: Option<&BTreeMap<[u32; 3], u64>>,
378) -> Option<[u32; 3]> {
379    let mut samples = Vec::with_capacity(WORKGROUP_CANDIDATES.len() + 1);
380    for candidate_x in WORKGROUP_CANDIDATES
381        .iter()
382        .copied()
383        .chain(std::iter::once(declared[0]))
384    {
385        if !candidate_x_fits_limits(candidate_x, limits)
386            || samples
387                .iter()
388                .any(|sample: &TuningMeasurement| sample.workgroup_size[0] == candidate_x)
389        {
390            continue;
391        }
392        let workgroup_size = [candidate_x, 1, 1];
393        let elapsed_ns = measurements
394            .and_then(|measured| measured.get(&workgroup_size).copied())
395            .unwrap_or_else(|| estimate_cold_start_latency_ns(element_count, candidate_x));
396        samples.push(TuningMeasurement {
397            workgroup_size,
398            elapsed_ns,
399        });
400    }
401    if let Some(measured) = measurements {
402        for (&workgroup_size, &elapsed_ns) in measured {
403            if workgroup_size[1] != 1
404                || workgroup_size[2] != 1
405                || elapsed_ns == 0
406                || !candidate_x_fits_limits(workgroup_size[0], limits)
407                || samples
408                    .iter()
409                    .any(|sample| sample.workgroup_size == workgroup_size)
410            {
411                continue;
412            }
413            samples.push(TuningMeasurement {
414                workgroup_size,
415                elapsed_ns,
416            });
417        }
418    }
419
420    if samples.len() < 2 {
421        return None;
422    }
423    NaturalGradientPolicy {
424        temperature_ns: COLD_START_TEMPERATURE_NS,
425    }
426    .suggest(&samples, &identity_fisher_q16(samples.len()))
427    .ok()
428    .map(|step| step.selected_workgroup_size)
429}
430
431#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
432struct NaturalLaunchCacheKey {
433    fingerprint: [u8; 32],
434    declared: [u32; 3],
435    element_count: u32,
436    max_threads_per_block: u32,
437    max_block_dim: [u32; 3],
438    max_grid_dim: [u32; 3],
439}
440
441impl NaturalLaunchCacheKey {
442    fn new(
443        program: &Program,
444        declared: [u32; 3],
445        element_count: u32,
446        limits: LaunchGeometryLimits,
447    ) -> Self {
448        Self {
449            fingerprint: program.fingerprint(),
450            declared,
451            element_count,
452            max_threads_per_block: limits.max_threads_per_block,
453            max_block_dim: limits.max_block_dim,
454            max_grid_dim: limits.max_grid_dim,
455        }
456    }
457
458    fn persistent_key(self) -> String {
459        let mut hasher = blake3::Hasher::new();
460        hasher.update(b"vyre-natural-launch-feedback-v1\0");
461        hasher.update(&self.fingerprint);
462        for axis in self.declared {
463            hasher.update(&axis.to_le_bytes());
464        }
465        hasher.update(&self.element_count.to_le_bytes());
466        hasher.update(&self.max_threads_per_block.to_le_bytes());
467        for axis in self.max_block_dim {
468            hasher.update(&axis.to_le_bytes());
469        }
470        for axis in self.max_grid_dim {
471            hasher.update(&axis.to_le_bytes());
472        }
473        let digest = hasher.finalize();
474        let mut key = String::with_capacity(74);
475        key.push_str("launch-v1-");
476        push_hex(digest.as_bytes(), &mut key);
477        key
478    }
479}
480
/// In-memory state stored for one natural-gradient launch cache key.
#[derive(Clone, Debug, Eq, PartialEq)]
struct NaturalLaunchEntry {
    /// Workgroup shape currently returned for this key.
    selected: [u32; 3],
    /// Fastest recorded `elapsed_ns` for each workgroup shape observed via feedback.
    measurements: BTreeMap<[u32; 3], u64>,
}
487
488fn natural_launch_cache_get(key: NaturalLaunchCacheKey) -> Option<[u32; 3]> {
489    let cache = NATURAL_LAUNCH_CACHE.get_or_init(|| Mutex::new(BTreeMap::new()));
490    cache
491        .lock()
492        .ok()
493        .and_then(|guard| guard.get(&key).map(|entry| entry.selected))
494}
495
496fn natural_launch_cache_measurements(
497    key: NaturalLaunchCacheKey,
498) -> Option<BTreeMap<[u32; 3], u64>> {
499    let cache = NATURAL_LAUNCH_CACHE.get_or_init(|| Mutex::new(BTreeMap::new()));
500    cache
501        .lock()
502        .ok()
503        .and_then(|guard| guard.get(&key).map(|entry| entry.measurements.clone()))
504}
505
506fn natural_launch_cache_set(key: NaturalLaunchCacheKey, value: NaturalLaunchEntry) {
507    let cache = NATURAL_LAUNCH_CACHE.get_or_init(|| Mutex::new(BTreeMap::new()));
508    if let Ok(mut guard) = cache.lock() {
509        if guard.len() >= MAX_NATURAL_LAUNCH_CACHE_ENTRIES && !guard.contains_key(&key) {
510            if let Some(oldest) = guard.keys().next().copied() {
511                guard.remove(&oldest);
512            }
513        }
514        guard.insert(key, value);
515    }
516}
517
/// Test helper that drops `key` from the in-memory launch cache.
///
/// Does nothing if the cache was never created.
#[cfg(test)]
fn natural_launch_cache_remove(key: NaturalLaunchCacheKey) {
    let Some(cache) = NATURAL_LAUNCH_CACHE.get() else {
        return;
    };
    if let Ok(mut entries) = cache.lock() {
        entries.remove(&key);
    }
}
526
527fn natural_launch_cache_get_persisted(
528    key: NaturalLaunchCacheKey,
529    limits: LaunchGeometryLimits,
530    persistent_path: Option<&Path>,
531) -> Option<[u32; 3]> {
532    let path = persistent_path
533        .map(Path::to_path_buf)
534        .unwrap_or_else(|| natural_launch_persistent_cache_path(limits));
535    let selected = TunerCache::load(&path).ok()?.get(&key.persistent_key())?;
536    valid_persisted_launch_selection(selected, limits).then_some(selected)
537}
538
539fn persist_natural_launch_selection(
540    key: NaturalLaunchCacheKey,
541    limits: LaunchGeometryLimits,
542    selected: [u32; 3],
543    persistent_path: Option<&Path>,
544) -> Result<(), String> {
545    let path = persistent_path
546        .map(Path::to_path_buf)
547        .unwrap_or_else(|| natural_launch_persistent_cache_path(limits));
548    persist_natural_launch_selection_to_path(key, selected, &path)
549}
550
551fn persist_natural_launch_selection_to_path(
552    key: NaturalLaunchCacheKey,
553    selected: [u32; 3],
554    path: &Path,
555) -> Result<(), String> {
556    let mut cache = TunerCache::load(path)?;
557    while cache.entries.len() >= MAX_NATURAL_LAUNCH_CACHE_ENTRIES {
558        let Some(oldest) = cache.entries.keys().next().cloned() else {
559            break;
560        };
561        cache.entries.remove(&oldest);
562    }
563    cache.set(key.persistent_key(), selected);
564    cache.save(path)
565}
566
567fn natural_launch_persistent_cache_path(limits: LaunchGeometryLimits) -> PathBuf {
568    Tuner::cache_path_for_adapter(&natural_launch_persistent_adapter_key(limits))
569}
570
571fn natural_launch_persistent_adapter_key(limits: LaunchGeometryLimits) -> String {
572    let mut hasher = blake3::Hasher::new();
573    hasher.update(b"vyre-natural-launch-adapter-v1\0");
574    hasher.update(limits.backend.as_bytes());
575    hasher.update(&limits.max_threads_per_block.to_le_bytes());
576    for axis in limits.max_block_dim {
577        hasher.update(&axis.to_le_bytes());
578    }
579    for axis in limits.max_grid_dim {
580        hasher.update(&axis.to_le_bytes());
581    }
582    let digest = hasher.finalize();
583    let mut key = String::with_capacity(92);
584    key.push_str("natural-launch-feedback-v1-");
585    push_hex(digest.as_bytes(), &mut key);
586    key
587}
588
589fn valid_persisted_launch_selection(selected: [u32; 3], limits: LaunchGeometryLimits) -> bool {
590    selected[1] == 1 && selected[2] == 1 && candidate_x_fits_limits(selected[0], limits)
591}
592
/// Appends `bytes` to `out` as lowercase hexadecimal, high nibble first.
fn push_hex(bytes: &[u8], out: &mut String) {
    for byte in bytes {
        for nibble in [byte >> 4, byte & 0x0f] {
            let digit = char::from_digit(u32::from(nibble), 16)
                .expect("a nibble is always below 16");
            out.push(digit);
        }
    }
}
600
601fn is_natural_gradient_launch_tunable(
602    program: &Program,
603    declared: [u32; 3],
604    element_count: u32,
605) -> bool {
606    declared[0] != 0
607        && declared[1] == 1
608        && declared[2] == 1
609        && element_count != 0
610        && program
611            .entry
612            .iter()
613            .any(|node| !matches!(node, Node::Return))
614        && !program.non_composable_with_self
615        && !program_uses_launch_geometry_ids(program)
616        && program
617            .buffers
618            .iter()
619            .all(|buffer| buffer.kind() != MemoryKind::Shared)
620}
621
622fn candidate_x_fits_limits(candidate_x: u32, limits: LaunchGeometryLimits) -> bool {
623    candidate_x != 0
624        && candidate_x <= limits.max_threads_per_block
625        && candidate_x <= limits.max_block_dim[0]
626}
627
628fn estimate_cold_start_latency_ns(element_count: u32, candidate_x: u32) -> u64 {
629    let groups = u64::from(element_count.div_ceil(candidate_x));
630    let scheduled_lanes = groups.saturating_mul(u64::from(candidate_x));
631    let idle_lanes = scheduled_lanes.saturating_sub(u64::from(element_count));
632    groups
633        .saturating_mul(COLD_START_GRID_STEP_NS)
634        .saturating_add(idle_lanes.saturating_mul(COLD_START_IDLE_LANE_NS))
635}
636
637/// Compute the shared VSA program fingerprint used by backend caches.
638#[must_use]
639pub fn program_vsa_fingerprint(program: &Program) -> Vec<u32> {
640    program_vsa_fingerprint_words(program).to_vec()
641}
642
643/// Compute the shared VSA program fingerprint without heap allocation.
644#[must_use]
645pub fn program_vsa_fingerprint_words(program: &Program) -> [u32; 8] {
646    let fingerprint = program.fingerprint();
647    let mut words = [0u32; 8];
648    for (word, chunk) in words.iter_mut().zip(fingerprint.chunks_exact(4)) {
649        *word = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
650    }
651    words
652}
653
#[cfg(test)]
mod tests {
    use super::*;
    use crate::binding::BindingRole;
    use vyre_foundation::ir::{BufferDecl, DataType, Expr, Node, Program};

    /// Device limits shared by every test: 1024 threads per block, unbounded grid.
    fn test_limits() -> LaunchGeometryLimits {
        LaunchGeometryLimits {
            backend: "test",
            max_threads_per_block: 1024,
            max_block_dim: [1024, 1024, 64],
            max_grid_dim: [u32::MAX; 3],
        }
    }

    /// Slot-0 output binding describing 4096 four-byte elements (16 KiB).
    fn output_binding_4096(name: &str) -> Binding {
        Binding {
            name: std::sync::Arc::from(name),
            binding: 0,
            buffer_index: 0,
            role: BindingRole::Output,
            element_size: 4,
            preferred_alignment: 128,
            element_count: 4096,
            static_byte_len: Some(16_384),
            input_index: None,
            output_index: Some(0),
        }
    }

    #[test]
    fn program_vsa_fingerprint_words_match_wire_decoder() {
        let program = Program::wrapped(vec![], [64, 1, 1], vec![]);
        let words = program_vsa_fingerprint_words(&program);
        let fingerprint = program.fingerprint();

        let decoded = fingerprint
            .chunks_exact(4)
            .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]));
        for (index, expected) in decoded.enumerate() {
            assert_eq!(words[index], expected);
        }
        assert_eq!(program_vsa_fingerprint(&program), words.to_vec());
    }

    #[test]
    fn launch_plan_prepare_into_reuses_param_words() {
        let program = Program::wrapped(vec![], [64, 1, 1], vec![]);
        let input = Binding {
            name: std::sync::Arc::from("input"),
            binding: 0,
            buffer_index: 0,
            role: BindingRole::Input,
            element_size: 4,
            preferred_alignment: 64,
            element_count: 7,
            static_byte_len: Some(28),
            input_index: Some(0),
            output_index: None,
        };
        let bindings = vec![input];
        let mut plan = LaunchPlan {
            param_words: Vec::with_capacity(8),
            ..LaunchPlan::new()
        };
        let reserved = plan.param_words.as_ptr();

        plan.prepare_into(&program, &bindings, &DispatchConfig::default(), test_limits())
            .unwrap();

        assert_eq!(plan.element_count, 7);
        assert_eq!(plan.grid, [1, 1, 1]);
        assert_eq!(plan.param_words, vec![7, 7]);
        assert_eq!(plan.max_binding_alignment, 64);
        assert_eq!(plan.param_words.as_ptr(), reserved);
    }

    #[test]
    fn natural_gradient_launch_tunes_safe_1d_storage_program() {
        let program = Program::wrapped(
            vec![BufferDecl::output("out", 0, DataType::U32).with_count(4096)],
            [32, 1, 1],
            vec![],
        );
        let bindings = vec![output_binding_4096("out")];
        let mut plan = LaunchPlan::new();

        plan.prepare_into_for_mode(
            &program,
            &bindings,
            &DispatchConfig::default(),
            test_limits(),
            Mode::NaturalGradient,
        )
        .expect("Fix: safe 1D storage launch should accept natural-gradient cold start");

        assert_eq!(plan.workgroup, [1024, 1, 1]);
        assert_eq!(plan.grid, [4, 1, 1]);
        assert_eq!(plan.element_count, 4096);
    }

    #[test]
    fn natural_gradient_launch_preserves_declared_shape_for_local_workgroup_ids() {
        let global_index = Expr::add(
            Expr::mul(Expr::var("block"), Expr::u32(1024)),
            Expr::var("lane"),
        );
        let program = Program::wrapped(
            vec![BufferDecl::output("out_local_ids", 0, DataType::U32).with_count(4096)],
            [1024, 1, 1],
            vec![
                Node::let_bind("lane", Expr::LocalId { axis: 0 }),
                Node::let_bind("block", Expr::WorkgroupId { axis: 0 }),
                Node::let_bind("global", global_index),
                Node::store("out_local_ids", Expr::var("global"), Expr::var("lane")),
            ],
        );
        let bindings = vec![output_binding_4096("out_local_ids")];

        let resolved = effective_launch_workgroup_for_mode(
            &program,
            &bindings,
            &DispatchConfig::default(),
            test_limits(),
            Mode::NaturalGradient,
        );
        assert_eq!(
            resolved,
            [1024, 1, 1],
            "Fix: automatic launch tuning must not change kernels whose LocalId/WorkgroupId arithmetic makes workgroup shape semantic."
        );
    }

    #[test]
    fn measured_launch_feedback_overrides_heuristic_cold_start() {
        let dir = tempfile::tempdir()
            .expect("Fix: measured launch feedback test needs an isolated tuner cache");
        let path = dir.path().join("launch-feedback.toml");
        let program = Program::wrapped(
            vec![BufferDecl::output("out_feedback_isolated", 0, DataType::U32).with_count(8192)],
            [32, 1, 1],
            vec![],
        );
        let config = DispatchConfig::default();
        let limits = test_limits();
        natural_launch_cache_remove(NaturalLaunchCacheKey::new(&program, [32, 1, 1], 8192, limits));

        let baseline = natural_gradient_cold_start_workgroup_with_store(
            &program,
            [32, 1, 1],
            8192,
            limits,
            Some(&path),
        );
        assert_eq!(
            baseline,
            Some([1024, 1, 1]),
            "Fix: baseline heuristic should pick the occupancy-efficient cold-start shape."
        );

        let accepted = record_launch_measurement_for_mode_with_store(
            &program,
            &config,
            limits,
            8192,
            [64, 1, 1],
            1,
            Mode::NaturalGradient,
            Some(&path),
        );
        assert!(
            accepted,
            "Fix: natural-gradient resolver must accept measured backend timing for safe 1D launches."
        );

        let steered = natural_gradient_cold_start_workgroup_with_store(
            &program,
            [32, 1, 1],
            8192,
            limits,
            Some(&path),
        );
        assert_eq!(
            steered,
            Some([64, 1, 1]),
            "Fix: measured launch feedback must steer future automatic launch choices."
        );
    }

    #[test]
    fn persisted_launch_feedback_rehydrates_measured_selection() {
        let dir = tempfile::tempdir()
            .expect("Fix: launch feedback persistence test needs a temporary cache directory");
        let path = dir.path().join("launch-feedback.toml");
        let program = Program::wrapped(
            vec![BufferDecl::output("out_persisted", 0, DataType::U32).with_count(16_384)],
            [32, 1, 1],
            vec![],
        );
        let limits = test_limits();
        let key = NaturalLaunchCacheKey::new(&program, [32, 1, 1], 16_384, limits);
        natural_launch_cache_remove(key);

        persist_natural_launch_selection_to_path(key, [64, 1, 1], &path)
            .expect("Fix: measured launch feedback should persist through the tuner cache format");

        let rehydrated = natural_gradient_cold_start_workgroup_with_store(
            &program,
            [32, 1, 1],
            16_384,
            limits,
            Some(&path),
        );
        assert_eq!(
            rehydrated,
            Some([64, 1, 1]),
            "Fix: automatic launch resolution must rehydrate measured feedback from the bounded tuner cache before falling back to heuristics."
        );
    }

    #[test]
    fn natural_gradient_launch_preserves_explicit_and_shared_memory_shapes() {
        let program = Program::wrapped(
            vec![
                BufferDecl::output("out", 0, DataType::U32).with_count(4096),
                BufferDecl::workgroup("scratch", 64, DataType::U32),
            ],
            [64, 1, 1],
            vec![],
        );
        let bindings = vec![output_binding_4096("out")];
        let limits = test_limits();

        let mut overridden = DispatchConfig::default();
        overridden.workgroup_override = Some([256, 1, 1]);
        let explicit = effective_launch_workgroup_for_mode(
            &program,
            &bindings,
            &overridden,
            limits,
            Mode::NaturalGradient,
        );
        assert_eq!(
            explicit,
            [256, 1, 1],
            "Fix: explicit dispatch workgroup overrides must remain authoritative."
        );

        let default_config = DispatchConfig::default();
        let declared = effective_launch_workgroup_for_mode(
            &program,
            &bindings,
            &default_config,
            limits,
            Mode::NaturalGradient,
        );
        assert_eq!(
            declared,
            [64, 1, 1],
            "Fix: workgroup-local scratch kernels must keep their declared shape."
        );
    }
}