1use std::collections::BTreeMap;
4use std::path::{Path, PathBuf};
5use std::sync::{Mutex, OnceLock};
6
7use vyre_foundation::ir::{MemoryKind, Node, Program};
8
9use crate::binding::Binding;
10use crate::program_walks::{
11 dispatch_element_count_for_program, dispatch_param_words_into, infer_dispatch_grid_for_count,
12 program_uses_launch_geometry_ids,
13};
14use crate::tuner::{
15 identity_fisher_q16, Mode, NaturalGradientPolicy, Tuner, TunerCache, TuningMeasurement,
16 WORKGROUP_CANDIDATES,
17};
18use crate::validation::{validate_launch_geometry, LaunchGeometryLimits};
19use crate::{BackendError, DispatchConfig};
20
21const COLD_START_GRID_STEP_NS: u64 = 1_024;
22const COLD_START_IDLE_LANE_NS: u64 = 8;
23const COLD_START_TEMPERATURE_NS: u64 = 4_096;
24const MAX_NATURAL_LAUNCH_CACHE_ENTRIES: usize = 4_096;
25
26static NATURAL_LAUNCH_CACHE: OnceLock<Mutex<BTreeMap<NaturalLaunchCacheKey, NaturalLaunchEntry>>> =
27 OnceLock::new();
28
/// Fully resolved launch geometry and dispatch parameters for one program invocation.
///
/// Built by `LaunchPlan::from_bindings` or refreshed in place by `LaunchPlan::prepare_into`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LaunchPlan {
    /// Logical element count. It is either inferred from the program and bindings, or
    /// `grid_override.x * workgroup.x` when an explicit grid is configured.
    pub element_count: u32,
    /// Workgroup (block) size used for the launch, after overrides and automatic tuning.
    pub workgroup: [u32; 3],
    /// Grid size: the configured `grid_override`, or inferred from `element_count` for 1D shapes.
    pub grid: [u32; 3],
    /// Parameter words written by `dispatch_param_words_into`; the buffer is reused across
    /// `prepare_into` calls.
    pub param_words: Vec<u32>,
    /// Largest `preferred_alignment` across the bindings, or 1 when there are no bindings.
    pub max_binding_alignment: usize,
}
46
47impl LaunchPlan {
48 #[must_use]
50 pub fn new() -> Self {
51 Self {
52 element_count: 1,
53 workgroup: [1, 1, 1],
54 grid: [1, 1, 1],
55 param_words: Vec::new(),
56 max_binding_alignment: 1,
57 }
58 }
59
60 pub fn from_bindings(
67 program: &Program,
68 bindings: &[Binding],
69 config: &DispatchConfig,
70 limits: LaunchGeometryLimits,
71 ) -> Result<Self, BackendError> {
72 let mut plan = Self::new();
73 plan.prepare_into(program, bindings, config, limits)?;
74 Ok(plan)
75 }
76
77 pub fn prepare_into(
84 &mut self,
85 program: &Program,
86 bindings: &[Binding],
87 config: &DispatchConfig,
88 limits: LaunchGeometryLimits,
89 ) -> Result<(), BackendError> {
90 self.prepare_into_for_mode(program, bindings, config, limits, Mode::from_env())
91 }
92
93 fn prepare_into_for_mode(
94 &mut self,
95 program: &Program,
96 bindings: &[Binding],
97 config: &DispatchConfig,
98 limits: LaunchGeometryLimits,
99 mode: Mode,
100 ) -> Result<(), BackendError> {
101 let workgroup =
102 effective_launch_workgroup_for_mode(program, bindings, config, limits, mode);
103 validate_launch_geometry(workgroup, [1, 1, 1], limits)?;
104 let element_count = launch_element_count(program, bindings, workgroup, config, limits)?;
105 let grid = match config.grid_override {
106 Some(grid) => grid,
107 None => {
108 if workgroup[1] != 1 || workgroup[2] != 1 {
114 return Err(BackendError::InvalidProgram {
115 fix: format!(
116 "Fix: backend `{}` requires DispatchConfig::grid_override for non-1D workgroups. \
117 workgroup={:?} has no unambiguous default grid; set grid_override to the logical [x, y, z] you want.",
118 limits.backend, workgroup,
119 ),
120 });
121 }
122 infer_dispatch_grid_for_count(element_count, workgroup)?
123 }
124 };
125 validate_launch_geometry(workgroup, grid, limits)?;
126 self.element_count = element_count;
127 self.workgroup = workgroup;
128 self.grid = grid;
129 self.max_binding_alignment = bindings
130 .iter()
131 .map(|binding| binding.preferred_alignment)
132 .max()
133 .unwrap_or(1);
134 dispatch_param_words_into(bindings, element_count, &mut self.param_words);
135 Ok(())
136 }
137}
138
139impl Default for LaunchPlan {
140 fn default() -> Self {
141 Self::new()
142 }
143}
144
145fn launch_element_count(
146 program: &Program,
147 bindings: &[Binding],
148 workgroup: [u32; 3],
149 config: &DispatchConfig,
150 limits: LaunchGeometryLimits,
151) -> Result<u32, BackendError> {
152 let inferred = dispatch_element_count_for_program(program, bindings);
153 let Some(grid) = config.grid_override else {
154 return Ok(inferred);
155 };
156 if workgroup.contains(&0) || grid.contains(&0) {
157 return Err(BackendError::InvalidProgram {
158 fix: format!(
159 "Fix: {} grid_override and workgroup dimensions must all be non-zero.",
160 limits.backend
161 ),
162 });
163 }
164 grid[0]
165 .checked_mul(workgroup[0])
166 .filter(|count| *count != 0)
167 .ok_or_else(|| BackendError::InvalidProgram {
168 fix: format!(
169 "Fix: {} grid_override.x * workgroup_size.x must fit in u32.",
170 limits.backend
171 ),
172 })
173}
174
175fn effective_launch_workgroup_for_mode(
176 program: &Program,
177 bindings: &[Binding],
178 config: &DispatchConfig,
179 limits: LaunchGeometryLimits,
180 mode: Mode,
181) -> [u32; 3] {
182 let element_count = dispatch_element_count_for_program(program, bindings);
183 resolve_launch_workgroup_for_mode(program, config, limits, element_count, mode)
184}
185
186#[must_use]
193pub fn resolve_launch_workgroup(
194 program: &Program,
195 config: &DispatchConfig,
196 limits: LaunchGeometryLimits,
197 element_count: u32,
198) -> [u32; 3] {
199 resolve_launch_workgroup_for_mode(program, config, limits, element_count, Mode::from_env())
200}
201
202#[must_use]
207pub fn resolve_launch_workgroup_for_mode(
208 program: &Program,
209 config: &DispatchConfig,
210 limits: LaunchGeometryLimits,
211 element_count: u32,
212 mode: Mode,
213) -> [u32; 3] {
214 if let Some(workgroup) = config.workgroup_override {
215 return workgroup;
216 }
217 let declared = program.workgroup_size();
218 if mode != Mode::NaturalGradient || config.grid_override.is_some() {
219 return declared;
220 }
221 natural_gradient_cold_start_workgroup(program, declared, element_count, limits)
222 .unwrap_or(declared)
223}
224
225#[must_use]
234pub fn record_launch_measurement(
235 program: &Program,
236 config: &DispatchConfig,
237 limits: LaunchGeometryLimits,
238 element_count: u32,
239 observed_workgroup: [u32; 3],
240 elapsed_ns: u64,
241) -> bool {
242 record_launch_measurement_for_mode(
243 program,
244 config,
245 limits,
246 element_count,
247 observed_workgroup,
248 elapsed_ns,
249 Mode::from_env(),
250 )
251}
252
253fn record_launch_measurement_for_mode(
254 program: &Program,
255 config: &DispatchConfig,
256 limits: LaunchGeometryLimits,
257 element_count: u32,
258 observed_workgroup: [u32; 3],
259 elapsed_ns: u64,
260 mode: Mode,
261) -> bool {
262 record_launch_measurement_for_mode_with_store(
263 program,
264 config,
265 limits,
266 element_count,
267 observed_workgroup,
268 elapsed_ns,
269 mode,
270 None,
271 )
272}
273
274fn record_launch_measurement_for_mode_with_store(
275 program: &Program,
276 config: &DispatchConfig,
277 limits: LaunchGeometryLimits,
278 element_count: u32,
279 observed_workgroup: [u32; 3],
280 elapsed_ns: u64,
281 mode: Mode,
282 persistent_path: Option<&Path>,
283) -> bool {
284 if mode != Mode::NaturalGradient
285 || elapsed_ns == 0
286 || config.workgroup_override.is_some()
287 || config.grid_override.is_some()
288 || observed_workgroup[1] != 1
289 || observed_workgroup[2] != 1
290 || !candidate_x_fits_limits(observed_workgroup[0], limits)
291 {
292 return false;
293 }
294 let declared = program.workgroup_size();
295 if !is_natural_gradient_launch_tunable(program, declared, element_count) {
296 return false;
297 }
298 let cache_key = NaturalLaunchCacheKey::new(program, declared, element_count, limits);
299 let mut measurements = natural_launch_cache_measurements(cache_key).unwrap_or_default();
300 measurements
301 .entry(observed_workgroup)
302 .and_modify(|best_ns| *best_ns = (*best_ns).min(elapsed_ns))
303 .or_insert(elapsed_ns);
304 let Some(selected) =
305 select_natural_launch_workgroup(declared, element_count, limits, Some(&measurements))
306 else {
307 return false;
308 };
309 natural_launch_cache_set(
310 cache_key,
311 NaturalLaunchEntry {
312 selected,
313 measurements,
314 },
315 );
316 if let Err(error) =
317 persist_natural_launch_selection(cache_key, limits, selected, persistent_path)
318 {
319 tracing::debug!(
320 error,
321 "natural-gradient launch feedback accepted in memory but could not persist"
322 );
323 }
324 true
325}
326
327fn natural_gradient_cold_start_workgroup(
328 program: &Program,
329 declared: [u32; 3],
330 element_count: u32,
331 limits: LaunchGeometryLimits,
332) -> Option<[u32; 3]> {
333 natural_gradient_cold_start_workgroup_with_store(program, declared, element_count, limits, None)
334}
335
336fn natural_gradient_cold_start_workgroup_with_store(
337 program: &Program,
338 declared: [u32; 3],
339 element_count: u32,
340 limits: LaunchGeometryLimits,
341 persistent_path: Option<&Path>,
342) -> Option<[u32; 3]> {
343 if !is_natural_gradient_launch_tunable(program, declared, element_count) {
344 return None;
345 }
346 let cache_key = NaturalLaunchCacheKey::new(program, declared, element_count, limits);
347 if let Some(cached) = natural_launch_cache_get(cache_key) {
348 return Some(cached);
349 }
350 if let Some(persisted) = natural_launch_cache_get_persisted(cache_key, limits, persistent_path)
351 {
352 natural_launch_cache_set(
353 cache_key,
354 NaturalLaunchEntry {
355 selected: persisted,
356 measurements: BTreeMap::new(),
357 },
358 );
359 return Some(persisted);
360 }
361
362 let selected = select_natural_launch_workgroup(declared, element_count, limits, None)?;
363 natural_launch_cache_set(
364 cache_key,
365 NaturalLaunchEntry {
366 selected,
367 measurements: BTreeMap::new(),
368 },
369 );
370 Some(selected)
371}
372
373fn select_natural_launch_workgroup(
374 declared: [u32; 3],
375 element_count: u32,
376 limits: LaunchGeometryLimits,
377 measurements: Option<&BTreeMap<[u32; 3], u64>>,
378) -> Option<[u32; 3]> {
379 let mut samples = Vec::with_capacity(WORKGROUP_CANDIDATES.len() + 1);
380 for candidate_x in WORKGROUP_CANDIDATES
381 .iter()
382 .copied()
383 .chain(std::iter::once(declared[0]))
384 {
385 if !candidate_x_fits_limits(candidate_x, limits)
386 || samples
387 .iter()
388 .any(|sample: &TuningMeasurement| sample.workgroup_size[0] == candidate_x)
389 {
390 continue;
391 }
392 let workgroup_size = [candidate_x, 1, 1];
393 let elapsed_ns = measurements
394 .and_then(|measured| measured.get(&workgroup_size).copied())
395 .unwrap_or_else(|| estimate_cold_start_latency_ns(element_count, candidate_x));
396 samples.push(TuningMeasurement {
397 workgroup_size,
398 elapsed_ns,
399 });
400 }
401 if let Some(measured) = measurements {
402 for (&workgroup_size, &elapsed_ns) in measured {
403 if workgroup_size[1] != 1
404 || workgroup_size[2] != 1
405 || elapsed_ns == 0
406 || !candidate_x_fits_limits(workgroup_size[0], limits)
407 || samples
408 .iter()
409 .any(|sample| sample.workgroup_size == workgroup_size)
410 {
411 continue;
412 }
413 samples.push(TuningMeasurement {
414 workgroup_size,
415 elapsed_ns,
416 });
417 }
418 }
419
420 if samples.len() < 2 {
421 return None;
422 }
423 NaturalGradientPolicy {
424 temperature_ns: COLD_START_TEMPERATURE_NS,
425 }
426 .suggest(&samples, &identity_fisher_q16(samples.len()))
427 .ok()
428 .map(|step| step.selected_workgroup_size)
429}
430
431#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
432struct NaturalLaunchCacheKey {
433 fingerprint: [u8; 32],
434 declared: [u32; 3],
435 element_count: u32,
436 max_threads_per_block: u32,
437 max_block_dim: [u32; 3],
438 max_grid_dim: [u32; 3],
439}
440
441impl NaturalLaunchCacheKey {
442 fn new(
443 program: &Program,
444 declared: [u32; 3],
445 element_count: u32,
446 limits: LaunchGeometryLimits,
447 ) -> Self {
448 Self {
449 fingerprint: program.fingerprint(),
450 declared,
451 element_count,
452 max_threads_per_block: limits.max_threads_per_block,
453 max_block_dim: limits.max_block_dim,
454 max_grid_dim: limits.max_grid_dim,
455 }
456 }
457
458 fn persistent_key(self) -> String {
459 let mut hasher = blake3::Hasher::new();
460 hasher.update(b"vyre-natural-launch-feedback-v1\0");
461 hasher.update(&self.fingerprint);
462 for axis in self.declared {
463 hasher.update(&axis.to_le_bytes());
464 }
465 hasher.update(&self.element_count.to_le_bytes());
466 hasher.update(&self.max_threads_per_block.to_le_bytes());
467 for axis in self.max_block_dim {
468 hasher.update(&axis.to_le_bytes());
469 }
470 for axis in self.max_grid_dim {
471 hasher.update(&axis.to_le_bytes());
472 }
473 let digest = hasher.finalize();
474 let mut key = String::with_capacity(74);
475 key.push_str("launch-v1-");
476 push_hex(digest.as_bytes(), &mut key);
477 key
478 }
479}
480
/// In-process record for one launch identity: the workgroup currently chosen, plus the fastest
/// elapsed time observed for every shape that reported feedback.
#[derive(Clone, Debug, Eq, PartialEq)]
struct NaturalLaunchEntry {
    /// Workgroup returned to callers on a cache hit.
    selected: [u32; 3],
    /// Best (minimum) observed `elapsed_ns` per workgroup shape. Empty for heuristic or
    /// rehydrated selections.
    measurements: BTreeMap<[u32; 3], u64>,
}
487
488fn natural_launch_cache_get(key: NaturalLaunchCacheKey) -> Option<[u32; 3]> {
489 let cache = NATURAL_LAUNCH_CACHE.get_or_init(|| Mutex::new(BTreeMap::new()));
490 cache
491 .lock()
492 .ok()
493 .and_then(|guard| guard.get(&key).map(|entry| entry.selected))
494}
495
496fn natural_launch_cache_measurements(
497 key: NaturalLaunchCacheKey,
498) -> Option<BTreeMap<[u32; 3], u64>> {
499 let cache = NATURAL_LAUNCH_CACHE.get_or_init(|| Mutex::new(BTreeMap::new()));
500 cache
501 .lock()
502 .ok()
503 .and_then(|guard| guard.get(&key).map(|entry| entry.measurements.clone()))
504}
505
506fn natural_launch_cache_set(key: NaturalLaunchCacheKey, value: NaturalLaunchEntry) {
507 let cache = NATURAL_LAUNCH_CACHE.get_or_init(|| Mutex::new(BTreeMap::new()));
508 if let Ok(mut guard) = cache.lock() {
509 if guard.len() >= MAX_NATURAL_LAUNCH_CACHE_ENTRIES && !guard.contains_key(&key) {
510 if let Some(oldest) = guard.keys().next().copied() {
511 guard.remove(&oldest);
512 }
513 }
514 guard.insert(key, value);
515 }
516}
517
/// Test helper: drops `key` from the in-process cache so a test starts from a clean state.
/// Does nothing if the cache was never initialised or its mutex is poisoned.
#[cfg(test)]
fn natural_launch_cache_remove(key: NaturalLaunchCacheKey) {
    let Some(mut guard) = NATURAL_LAUNCH_CACHE
        .get()
        .and_then(|cache| cache.lock().ok())
    else {
        return;
    };
    guard.remove(&key);
}
526
527fn natural_launch_cache_get_persisted(
528 key: NaturalLaunchCacheKey,
529 limits: LaunchGeometryLimits,
530 persistent_path: Option<&Path>,
531) -> Option<[u32; 3]> {
532 let path = persistent_path
533 .map(Path::to_path_buf)
534 .unwrap_or_else(|| natural_launch_persistent_cache_path(limits));
535 let selected = TunerCache::load(&path).ok()?.get(&key.persistent_key())?;
536 valid_persisted_launch_selection(selected, limits).then_some(selected)
537}
538
539fn persist_natural_launch_selection(
540 key: NaturalLaunchCacheKey,
541 limits: LaunchGeometryLimits,
542 selected: [u32; 3],
543 persistent_path: Option<&Path>,
544) -> Result<(), String> {
545 let path = persistent_path
546 .map(Path::to_path_buf)
547 .unwrap_or_else(|| natural_launch_persistent_cache_path(limits));
548 persist_natural_launch_selection_to_path(key, selected, &path)
549}
550
551fn persist_natural_launch_selection_to_path(
552 key: NaturalLaunchCacheKey,
553 selected: [u32; 3],
554 path: &Path,
555) -> Result<(), String> {
556 let mut cache = TunerCache::load(path)?;
557 while cache.entries.len() >= MAX_NATURAL_LAUNCH_CACHE_ENTRIES {
558 let Some(oldest) = cache.entries.keys().next().cloned() else {
559 break;
560 };
561 cache.entries.remove(&oldest);
562 }
563 cache.set(key.persistent_key(), selected);
564 cache.save(path)
565}
566
567fn natural_launch_persistent_cache_path(limits: LaunchGeometryLimits) -> PathBuf {
568 Tuner::cache_path_for_adapter(&natural_launch_persistent_adapter_key(limits))
569}
570
571fn natural_launch_persistent_adapter_key(limits: LaunchGeometryLimits) -> String {
572 let mut hasher = blake3::Hasher::new();
573 hasher.update(b"vyre-natural-launch-adapter-v1\0");
574 hasher.update(limits.backend.as_bytes());
575 hasher.update(&limits.max_threads_per_block.to_le_bytes());
576 for axis in limits.max_block_dim {
577 hasher.update(&axis.to_le_bytes());
578 }
579 for axis in limits.max_grid_dim {
580 hasher.update(&axis.to_le_bytes());
581 }
582 let digest = hasher.finalize();
583 let mut key = String::with_capacity(92);
584 key.push_str("natural-launch-feedback-v1-");
585 push_hex(digest.as_bytes(), &mut key);
586 key
587}
588
589fn valid_persisted_launch_selection(selected: [u32; 3], limits: LaunchGeometryLimits) -> bool {
590 selected[1] == 1 && selected[2] == 1 && candidate_x_fits_limits(selected[0], limits)
591}
592
/// Appends `bytes` to `out` as lowercase hexadecimal, two digits per byte, high nibble first.
fn push_hex(bytes: &[u8], out: &mut String) {
    let digit = |nibble: u8| -> char {
        let nibble = nibble & 0x0f;
        if nibble < 10 {
            char::from(b'0' + nibble)
        } else {
            char::from(b'a' + (nibble - 10))
        }
    };
    out.extend(bytes.iter().flat_map(|&byte| [digit(byte >> 4), digit(byte)]));
}
600
601fn is_natural_gradient_launch_tunable(
602 program: &Program,
603 declared: [u32; 3],
604 element_count: u32,
605) -> bool {
606 declared[0] != 0
607 && declared[1] == 1
608 && declared[2] == 1
609 && element_count != 0
610 && program
611 .entry
612 .iter()
613 .any(|node| !matches!(node, Node::Return))
614 && !program.non_composable_with_self
615 && !program_uses_launch_geometry_ids(program)
616 && program
617 .buffers
618 .iter()
619 .all(|buffer| buffer.kind() != MemoryKind::Shared)
620}
621
622fn candidate_x_fits_limits(candidate_x: u32, limits: LaunchGeometryLimits) -> bool {
623 candidate_x != 0
624 && candidate_x <= limits.max_threads_per_block
625 && candidate_x <= limits.max_block_dim[0]
626}
627
628fn estimate_cold_start_latency_ns(element_count: u32, candidate_x: u32) -> u64 {
629 let groups = u64::from(element_count.div_ceil(candidate_x));
630 let scheduled_lanes = groups.saturating_mul(u64::from(candidate_x));
631 let idle_lanes = scheduled_lanes.saturating_sub(u64::from(element_count));
632 groups
633 .saturating_mul(COLD_START_GRID_STEP_NS)
634 .saturating_add(idle_lanes.saturating_mul(COLD_START_IDLE_LANE_NS))
635}
636
637#[must_use]
639pub fn program_vsa_fingerprint(program: &Program) -> Vec<u32> {
640 program_vsa_fingerprint_words(program).to_vec()
641}
642
643#[must_use]
645pub fn program_vsa_fingerprint_words(program: &Program) -> [u32; 8] {
646 let fingerprint = program.fingerprint();
647 let mut words = [0u32; 8];
648 for (word, chunk) in words.iter_mut().zip(fingerprint.chunks_exact(4)) {
649 *word = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
650 }
651 words
652}
653
/// Unit tests for launch planning and natural-gradient launch tuning.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::binding::BindingRole;
    use vyre_foundation::ir::{BufferDecl, DataType, Expr, Node, Program};

    /// The word form must equal a little-endian decode of each 4-byte fingerprint chunk, and
    /// the `Vec` form must match the array form.
    #[test]
    fn program_vsa_fingerprint_words_match_wire_decoder() {
        let program = Program::wrapped(vec![], [64, 1, 1], vec![]);
        let words = program_vsa_fingerprint_words(&program);
        let fingerprint = program.fingerprint();

        for (index, chunk) in fingerprint.chunks_exact(4).enumerate() {
            assert_eq!(
                words[index],
                u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])
            );
        }
        assert_eq!(program_vsa_fingerprint(&program), words.to_vec());
    }

    /// `prepare_into` must fill the plan in place without reallocating a pre-sized
    /// `param_words` buffer.
    /// NOTE(review): goes through `prepare_into`, so the tuning mode comes from
    /// `Mode::from_env()`; presumably the environment leaves tuning off here — confirm.
    #[test]
    fn launch_plan_prepare_into_reuses_param_words() {
        let program = Program::wrapped(vec![], [64, 1, 1], vec![]);
        let bindings = vec![Binding {
            name: std::sync::Arc::from("input"),
            binding: 0,
            buffer_index: 0,
            role: BindingRole::Input,
            element_size: 4,
            preferred_alignment: 64,
            element_count: 7,
            static_byte_len: Some(28),
            input_index: Some(0),
            output_index: None,
        }];
        let limits = LaunchGeometryLimits {
            backend: "test",
            max_threads_per_block: 1024,
            max_block_dim: [1024, 1024, 64],
            max_grid_dim: [u32::MAX, u32::MAX, u32::MAX],
        };
        // Capacity 8 is enough for the expected words, so no reallocation should happen.
        let mut plan = LaunchPlan {
            param_words: Vec::with_capacity(8),
            ..LaunchPlan::new()
        };
        let ptr = plan.param_words.as_ptr();
        plan.prepare_into(&program, &bindings, &DispatchConfig::default(), limits)
            .unwrap();
        assert_eq!(plan.element_count, 7);
        assert_eq!(plan.grid, [1, 1, 1]);
        assert_eq!(plan.param_words, vec![7, 7]);
        assert_eq!(plan.max_binding_alignment, 64);
        // Same pointer => the existing allocation was reused.
        assert_eq!(plan.param_words.as_ptr(), ptr);
    }

    /// A plain 1D storage kernel declared at 32 lanes should be widened by the cold-start
    /// heuristic to 1024 lanes, giving a 4-workgroup grid for 4096 elements.
    #[test]
    fn natural_gradient_launch_tunes_safe_1d_storage_program() {
        let program = Program::wrapped(
            vec![BufferDecl::output("out", 0, DataType::U32).with_count(4096)],
            [32, 1, 1],
            vec![],
        );
        let bindings = vec![Binding {
            name: std::sync::Arc::from("out"),
            binding: 0,
            buffer_index: 0,
            role: BindingRole::Output,
            element_size: 4,
            preferred_alignment: 128,
            element_count: 4096,
            static_byte_len: Some(16_384),
            input_index: None,
            output_index: Some(0),
        }];
        let limits = LaunchGeometryLimits {
            backend: "test",
            max_threads_per_block: 1024,
            max_block_dim: [1024, 1024, 64],
            max_grid_dim: [u32::MAX, u32::MAX, u32::MAX],
        };
        let mut plan = LaunchPlan::new();

        plan.prepare_into_for_mode(
            &program,
            &bindings,
            &DispatchConfig::default(),
            limits,
            Mode::NaturalGradient,
        )
        .expect("Fix: safe 1D storage launch should accept natural-gradient cold start");

        assert_eq!(plan.workgroup, [1024, 1, 1]);
        assert_eq!(plan.grid, [4, 1, 1]);
        assert_eq!(plan.element_count, 4096);
    }

    /// Kernels that compute indices from LocalId/WorkgroupId depend on the exact block width,
    /// so tuning must leave the declared shape untouched.
    #[test]
    fn natural_gradient_launch_preserves_declared_shape_for_local_workgroup_ids() {
        let program = Program::wrapped(
            vec![BufferDecl::output("out_local_ids", 0, DataType::U32).with_count(4096)],
            [1024, 1, 1],
            vec![
                Node::let_bind("lane", Expr::LocalId { axis: 0 }),
                Node::let_bind("block", Expr::WorkgroupId { axis: 0 }),
                Node::let_bind(
                    "global",
                    Expr::add(
                        Expr::mul(Expr::var("block"), Expr::u32(1024)),
                        Expr::var("lane"),
                    ),
                ),
                Node::store("out_local_ids", Expr::var("global"), Expr::var("lane")),
            ],
        );
        let bindings = vec![Binding {
            name: std::sync::Arc::from("out_local_ids"),
            binding: 0,
            buffer_index: 0,
            role: BindingRole::Output,
            element_size: 4,
            preferred_alignment: 128,
            element_count: 4096,
            static_byte_len: Some(16_384),
            input_index: None,
            output_index: Some(0),
        }];
        let limits = LaunchGeometryLimits {
            backend: "test",
            max_threads_per_block: 1024,
            max_block_dim: [1024, 1024, 64],
            max_grid_dim: [u32::MAX, u32::MAX, u32::MAX],
        };

        assert_eq!(
            effective_launch_workgroup_for_mode(
                &program,
                &bindings,
                &DispatchConfig::default(),
                limits,
                Mode::NaturalGradient,
            ),
            [1024, 1, 1],
            "Fix: automatic launch tuning must not change kernels whose LocalId/WorkgroupId arithmetic makes workgroup shape semantic."
        );
    }

    /// Measured timings must beat the heuristic.
    ///
    /// Flow:
    /// 1. The cold start first picks 1024 lanes.
    /// 2. A fast 1 ns observation for 64 lanes is recorded.
    /// 3. The next resolution must return 64 lanes.
    ///
    /// Uses an isolated temp file for persistence and evicts this key from the process-wide
    /// cache first.
    #[test]
    fn measured_launch_feedback_overrides_heuristic_cold_start() {
        let dir = tempfile::tempdir()
            .expect("Fix: measured launch feedback test needs an isolated tuner cache");
        let path = dir.path().join("launch-feedback.toml");
        let program = Program::wrapped(
            vec![BufferDecl::output("out_feedback_isolated", 0, DataType::U32).with_count(8192)],
            [32, 1, 1],
            vec![],
        );
        let config = DispatchConfig::default();
        let limits = LaunchGeometryLimits {
            backend: "test",
            max_threads_per_block: 1024,
            max_block_dim: [1024, 1024, 64],
            max_grid_dim: [u32::MAX, u32::MAX, u32::MAX],
        };
        let key = NaturalLaunchCacheKey::new(&program, [32, 1, 1], 8192, limits);
        // The in-process cache is global; clear this key so earlier runs cannot leak in.
        natural_launch_cache_remove(key);

        assert_eq!(
            natural_gradient_cold_start_workgroup_with_store(
                &program,
                [32, 1, 1],
                8192,
                limits,
                Some(&path),
            ),
            Some([1024, 1, 1]),
            "Fix: baseline heuristic should pick the occupancy-efficient cold-start shape."
        );
        assert!(
            record_launch_measurement_for_mode_with_store(
                &program,
                &config,
                limits,
                8192,
                [64, 1, 1],
                1,
                Mode::NaturalGradient,
                Some(&path),
            ),
            "Fix: natural-gradient resolver must accept measured backend timing for safe 1D launches."
        );
        assert_eq!(
            natural_gradient_cold_start_workgroup_with_store(
                &program,
                [32, 1, 1],
                8192,
                limits,
                Some(&path),
            ),
            Some([64, 1, 1]),
            "Fix: measured launch feedback must steer future automatic launch choices."
        );
    }

    /// A selection written to disk must be picked up on the next resolution when the
    /// in-process cache has no entry for the key.
    #[test]
    fn persisted_launch_feedback_rehydrates_measured_selection() {
        let dir = tempfile::tempdir()
            .expect("Fix: launch feedback persistence test needs a temporary cache directory");
        let path = dir.path().join("launch-feedback.toml");
        let program = Program::wrapped(
            vec![BufferDecl::output("out_persisted", 0, DataType::U32).with_count(16_384)],
            [32, 1, 1],
            vec![],
        );
        let limits = LaunchGeometryLimits {
            backend: "test",
            max_threads_per_block: 1024,
            max_block_dim: [1024, 1024, 64],
            max_grid_dim: [u32::MAX, u32::MAX, u32::MAX],
        };
        let key = NaturalLaunchCacheKey::new(&program, [32, 1, 1], 16_384, limits);
        // Force the lookup past the in-process cache and onto the persisted file.
        natural_launch_cache_remove(key);

        persist_natural_launch_selection_to_path(key, [64, 1, 1], &path)
            .expect("Fix: measured launch feedback should persist through the tuner cache format");

        assert_eq!(
            natural_gradient_cold_start_workgroup_with_store(
                &program,
                [32, 1, 1],
                16_384,
                limits,
                Some(&path),
            ),
            Some([64, 1, 1]),
            "Fix: automatic launch resolution must rehydrate measured feedback from the bounded tuner cache before falling back to heuristics."
        );
    }

    /// An explicit workgroup override always wins. A kernel declaring a workgroup-shared
    /// buffer keeps its declared shape even under natural-gradient tuning.
    #[test]
    fn natural_gradient_launch_preserves_explicit_and_shared_memory_shapes() {
        let program = Program::wrapped(
            vec![
                BufferDecl::output("out", 0, DataType::U32).with_count(4096),
                BufferDecl::workgroup("scratch", 64, DataType::U32),
            ],
            [64, 1, 1],
            vec![],
        );
        let bindings = vec![Binding {
            name: std::sync::Arc::from("out"),
            binding: 0,
            buffer_index: 0,
            role: BindingRole::Output,
            element_size: 4,
            preferred_alignment: 128,
            element_count: 4096,
            static_byte_len: Some(16_384),
            input_index: None,
            output_index: Some(0),
        }];
        let limits = LaunchGeometryLimits {
            backend: "test",
            max_threads_per_block: 1024,
            max_block_dim: [1024, 1024, 64],
            max_grid_dim: [u32::MAX, u32::MAX, u32::MAX],
        };
        let mut config = DispatchConfig::default();
        config.workgroup_override = Some([256, 1, 1]);

        assert_eq!(
            effective_launch_workgroup_for_mode(
                &program,
                &bindings,
                &config,
                limits,
                Mode::NaturalGradient,
            ),
            [256, 1, 1],
            "Fix: explicit dispatch workgroup overrides must remain authoritative."
        );

        let default_config = DispatchConfig::default();
        assert_eq!(
            effective_launch_workgroup_for_mode(
                &program,
                &bindings,
                &default_config,
                limits,
                Mode::NaturalGradient,
            ),
            [64, 1, 1],
            "Fix: workgroup-local scratch kernels must keep their declared shape."
        );
    }
}