1use std::collections::BTreeMap;
4use std::fmt::Write as _;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use vyre_foundation::ir::Program;
9
/// Workgroup x-sizes considered by [`Tuner::candidates_for`], smallest first.
pub const WORKGROUP_CANDIDATES: &[u32] = &[32, 64, 128, 256, 512, 1024];
/// Environment variable read by [`Mode::from_env`].
const AUTOTUNER_ENV: &str = "VYRE_AUTOTUNER";
/// Upper bound (4 MiB) on the size of a tuner cache file accepted when loading.
const MAX_TUNER_CACHE_BYTES: u64 = 4 << 20;
15
/// Autotuner operating mode, selectable through the `VYRE_AUTOTUNER` variable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Mode {
    /// Selected by `VYRE_AUTOTUNER=on`.
    On,
    /// Natural-gradient policy (`natural` / `ng`); also the production default.
    NaturalGradient,
    /// Selected by `off` / `default`; presumably callers then use
    /// `Tuner::default_workgroup_size` — confirm at the call sites.
    OffUseDefault,
}
27
28impl Mode {
29 #[must_use]
35 pub const fn production_default() -> Self {
36 Mode::NaturalGradient
37 }
38
39 #[must_use]
41 pub fn from_env() -> Self {
42 match std::env::var(AUTOTUNER_ENV).ok() {
43 Some(value) => Self::from_env_value(Some(value.as_str())),
44 None => Self::production_default(),
45 }
46 }
47
48 fn from_env_value(value: Option<&str>) -> Self {
49 match value {
50 Some("on") => Mode::On,
51 Some("natural" | "ng") => Mode::NaturalGradient,
52 Some("off" | "default") => Mode::OffUseDefault,
53 Some(_) => Self::production_default(),
54 None => Self::production_default(),
55 }
56 }
57}
58
59pub trait BackendTimer {
61 type Error;
63
64 fn measure_candidate_ns(
71 &mut self,
72 program: &Program,
73 workgroup_size: [u32; 3],
74 ) -> Result<u64, Self::Error>;
75}
76
/// In-memory map from program key to the chosen `[x, y, z]` workgroup size.
///
/// Entries are kept in a `BTreeMap`, so iteration (and therefore the saved file) is
/// ordered by key.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct TunerCache {
    /// Recorded decisions keyed by program fingerprint / [`TunerProgramKey`] string.
    pub entries: BTreeMap<String, [u32; 3]>,
}
83
/// Dispatch-shape facts hashed into a [`TunerProgramKey`] alongside the program
/// fingerprint.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct StaticProgramShape {
    /// Workgroup size declared by the program.
    pub workgroup_size: [u32; 3],
    /// Dispatch grid, presumably `None` when it is not known statically — confirm.
    pub workgroup_count: Option<[u32; 3]>,
    /// Output size in bytes.
    pub output_bytes: u64,
}
94
95impl StaticProgramShape {
96 #[must_use]
98 pub fn new(program: &Program, workgroup_count: Option<[u32; 3]>, output_bytes: u64) -> Self {
99 Self {
100 workgroup_size: program.workgroup_size(),
101 workgroup_count,
102 output_bytes,
103 }
104 }
105}
106
/// Versioned, hex-encoded BLAKE3 cache key (`v1-` + 64 hex digits) for one program
/// and dispatch shape.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TunerProgramKey(String);
110
111impl TunerProgramKey {
112 #[must_use]
114 pub fn from_program(program: &Program, shape: StaticProgramShape) -> Self {
115 let mut hasher = blake3::Hasher::new();
116 hasher.update(b"vyre-driver-workgroup-tuner-v1\0program\0");
117 hasher.update(&program.fingerprint());
118 hasher.update(b"\0workgroup-size\0");
119 for axis in shape.workgroup_size {
120 hasher.update(&axis.to_le_bytes());
121 }
122 hasher.update(b"\0workgroup-count\0");
123 match shape.workgroup_count {
124 Some(count) => {
125 hasher.update(&[1]);
126 for axis in count {
127 hasher.update(&axis.to_le_bytes());
128 }
129 }
130 None => {
131 hasher.update(&[0]);
132 }
133 }
134 hasher.update(b"\0output-bytes\0");
135 hasher.update(&shape.output_bytes.to_le_bytes());
136 let digest = hasher.finalize();
137 let mut key = String::with_capacity(67);
138 key.push_str("v1-");
139 push_hex(digest.as_bytes(), &mut key);
140 Self(key)
141 }
142
143 #[must_use]
145 pub fn as_str(&self) -> &str {
146 &self.0
147 }
148}
149
/// Appends the lowercase hexadecimal encoding of `bytes` to `out`, high nibble first.
fn push_hex(bytes: &[u8], out: &mut String) {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    bytes.iter().for_each(|&byte| {
        out.push(DIGITS[usize::from(byte >> 4)]);
        out.push(DIGITS[usize::from(byte & 0x0f)]);
    });
}
157
158impl AsRef<str> for TunerProgramKey {
159 fn as_ref(&self) -> &str {
160 self.as_str()
161 }
162}
163
164impl TunerCache {
165 #[must_use]
167 pub fn get(&self, program_fp: &str) -> Option<[u32; 3]> {
168 self.entries.get(program_fp).copied()
169 }
170
171 #[must_use]
173 pub fn get_key(&self, key: &TunerProgramKey) -> Option<[u32; 3]> {
174 self.get(key.as_str())
175 }
176
177 pub fn set(&mut self, program_fp: impl Into<String>, size: [u32; 3]) {
179 self.entries.insert(program_fp.into(), size);
180 }
181
182 pub fn set_key(&mut self, key: TunerProgramKey, size: [u32; 3]) {
187 self.entries.insert(key.0, size);
188 }
189
190 pub fn load(path: &Path) -> Result<Self, String> {
196 let Ok(contents) = read_tuner_cache_bounded(path) else {
197 return Ok(Self::default());
198 };
199 let parsed: toml::Value = toml::from_str(&contents).map_err(|error| {
200 format!(
201 "Fix: tuner cache `{}` is not valid TOML: {error}",
202 path.display()
203 )
204 })?;
205 let mut entries = BTreeMap::new();
206 if let Some(table) = parsed.as_table() {
207 for (key, value) in table {
208 if let Some(array) = value.as_array() {
209 if array.len() == 3 {
210 let mut triple = [0u32; 3];
211 for (index, value) in array.iter().enumerate() {
212 if let Some(number) = value.as_integer() {
213 if let Ok(converted) = u32::try_from(number) {
214 triple[index] = converted;
215 }
216 }
217 }
218 entries.insert(key.clone(), triple);
219 }
220 }
221 }
222 }
223 Ok(Self { entries })
224 }
225
226 pub fn save(&self, path: &Path) -> Result<(), String> {
233 if let Some(parent) = path.parent() {
234 fs::create_dir_all(parent).map_err(|error| {
235 format!(
236 "Fix: could not create tuner cache directory {}: {error}",
237 parent.display()
238 )
239 })?;
240 }
241 let mut out = String::with_capacity(tuner_cache_string_capacity(self.entries.len()));
242 for (key, size) in &self.entries {
243 let _ = writeln!(out, "\"{}\" = [{}, {}, {}]", key, size[0], size[1], size[2]);
244 }
245 fs::write(path, &out).map_err(|error| {
246 format!(
247 "Fix: could not write tuner cache {}: {error}",
248 path.display()
249 )
250 })
251 }
252}
253
254fn read_tuner_cache_bounded(path: &Path) -> std::io::Result<String> {
255 use std::io::Read as _;
256
257 let mut file = fs::File::open(path)?;
258 let metadata = file.metadata()?;
259 if metadata.len() > MAX_TUNER_CACHE_BYTES {
260 return Err(std::io::Error::new(
261 std::io::ErrorKind::InvalidData,
262 format!("tuner cache exceeds {MAX_TUNER_CACHE_BYTES} byte limit"),
263 ));
264 }
265 let mut text = String::with_capacity(metadata.len() as usize);
266 file.by_ref()
267 .take(MAX_TUNER_CACHE_BYTES + 1)
268 .read_to_string(&mut text)?;
269 if text.len() as u64 > MAX_TUNER_CACHE_BYTES {
270 return Err(std::io::Error::new(
271 std::io::ErrorKind::InvalidData,
272 "tuner cache exceeded bounded read limit",
273 ));
274 }
275 Ok(text)
276}
277
/// One timed candidate: the workgroup size tried and how long it took.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TuningMeasurement {
    /// Candidate workgroup size that was dispatched.
    pub workgroup_size: [u32; 3],
    /// Elapsed time in nanoseconds, as reported by `BackendTimer::measure_candidate_ns`.
    pub elapsed_ns: u64,
}
286
/// The value 1.0 in unsigned 16.16 fixed point.
pub const Q16_ONE: u32 = 0x0001_0000;
289
/// Latency-softmax exploration policy preconditioned by an inverse-Fisher matrix.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NaturalGradientPolicy {
    /// Softmax temperature in nanoseconds. Each candidate's weight is
    /// `exp(-(elapsed - best) / temperature_ns)`, so larger values flatten the policy.
    /// Must be non-zero.
    pub temperature_ns: u64,
}

impl Default for NaturalGradientPolicy {
    /// Uses a 10 µs temperature.
    fn default() -> Self {
        const DEFAULT_TEMPERATURE_NS: u64 = 10_000;
        Self {
            temperature_ns: DEFAULT_TEMPERATURE_NS,
        }
    }
}
312
/// Outcome of one natural-gradient policy update.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NaturalGradientTuningStep {
    /// Candidate the policy proposes next: the one with the largest preconditioned
    /// gradient entry.
    pub selected_workgroup_size: [u32; 3],
    /// Fastest measured candidate.
    pub best_measured_workgroup_size: [u32; 3],
    /// Latency of the fastest measured candidate, in nanoseconds.
    pub best_measured_elapsed_ns: u64,
    /// Softmax policy weights in 16.16 fixed point, one per measurement, summing to
    /// `Q16_ONE`.
    pub policy_weights_q16: Vec<u32>,
    /// Policy weights multiplied by the inverse-Fisher matrix, in 16.16 fixed point.
    pub natural_gradient_q16: Vec<u32>,
}
327
/// Reasons a natural-gradient policy update is rejected.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum NaturalGradientTuningError {
    /// No measurements were supplied.
    EmptyMeasurements,
    /// The inverse-Fisher slice is not `measurements * measurements` cells long.
    FisherMatrixShape {
        /// Number of measurements supplied.
        measurements: usize,
        /// Number of matrix cells actually supplied.
        cells: usize,
    },
    /// `temperature_ns` was zero.
    ZeroTemperature,
}
344
345impl std::fmt::Display for NaturalGradientTuningError {
346 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347 match self {
348 Self::EmptyMeasurements => {
349 write!(
350 f,
351 "natural-gradient tuner received no measurements. Fix: measure at least one candidate before policy update."
352 )
353 }
354 Self::FisherMatrixShape {
355 measurements,
356 cells,
357 } => write!(
358 f,
359 "natural-gradient tuner expected an inverse-Fisher matrix with {} cells for {measurements} measurement(s), got {cells}. Fix: pass an n*n 16.16 matrix.",
360 measurements.saturating_mul(*measurements)
361 ),
362 Self::ZeroTemperature => {
363 write!(
364 f,
365 "natural-gradient tuner temperature is zero. Fix: use a positive temperature_ns."
366 )
367 }
368 }
369 }
370}
371
372impl std::error::Error for NaturalGradientTuningError {}
373
374impl NaturalGradientPolicy {
375 pub fn suggest(
389 &self,
390 measurements: &[TuningMeasurement],
391 fisher_inv_sqrt_q16: &[u32],
392 ) -> Result<NaturalGradientTuningStep, NaturalGradientTuningError> {
393 if measurements.is_empty() {
394 return Err(NaturalGradientTuningError::EmptyMeasurements);
395 }
396 if self.temperature_ns == 0 {
397 return Err(NaturalGradientTuningError::ZeroTemperature);
398 }
399 let expected_cells = measurements.len().checked_mul(measurements.len()).ok_or(
400 NaturalGradientTuningError::FisherMatrixShape {
401 measurements: measurements.len(),
402 cells: fisher_inv_sqrt_q16.len(),
403 },
404 )?;
405 if fisher_inv_sqrt_q16.len() != expected_cells {
406 return Err(NaturalGradientTuningError::FisherMatrixShape {
407 measurements: measurements.len(),
408 cells: fisher_inv_sqrt_q16.len(),
409 });
410 }
411
412 let mut best_index = 0usize;
413 let mut best_elapsed = measurements[0].elapsed_ns;
414 for (index, measurement) in measurements.iter().enumerate().skip(1) {
415 if measurement.elapsed_ns < best_elapsed {
416 best_index = index;
417 best_elapsed = measurement.elapsed_ns;
418 }
419 }
420
421 let policy_weights_q16 =
422 latency_softmax_weights_q16(measurements, best_elapsed, self.temperature_ns);
423 let natural_gradient_q16 =
424 precondition_q16(fisher_inv_sqrt_q16, &policy_weights_q16, measurements.len());
425 let selected_index = natural_gradient_q16
426 .iter()
427 .enumerate()
428 .max_by_key(|(_, value)| *value)
429 .map(|(index, _)| index)
430 .unwrap_or(best_index);
431
432 Ok(NaturalGradientTuningStep {
433 selected_workgroup_size: measurements[selected_index].workgroup_size,
434 best_measured_workgroup_size: measurements[best_index].workgroup_size,
435 best_measured_elapsed_ns: best_elapsed,
436 policy_weights_q16,
437 natural_gradient_q16,
438 })
439 }
440}
441
442#[must_use]
444pub fn identity_fisher_q16(candidate_count: usize) -> Vec<u32> {
445 let mut out = Vec::new();
446 identity_fisher_q16_into(candidate_count, &mut out);
447 out
448}
449
450pub fn identity_fisher_q16_into(candidate_count: usize, out: &mut Vec<u32>) {
454 let Some(cells) = candidate_count.checked_mul(candidate_count) else {
455 out.clear();
456 return;
457 };
458 out.clear();
459 out.resize(cells, 0);
460 for index in 0..candidate_count {
461 out[index * candidate_count + index] = Q16_ONE;
462 }
463}
464
465fn latency_softmax_weights_q16(
466 measurements: &[TuningMeasurement],
467 best_elapsed: u64,
468 temperature_ns: u64,
469) -> Vec<u32> {
470 let temperature = temperature_ns as f64;
471 let mut weights = Vec::with_capacity(measurements.len());
472 let mut sum = 0.0f64;
473 for measurement in measurements {
474 let penalty = measurement.elapsed_ns.saturating_sub(best_elapsed) as f64;
475 let weight = (-penalty / temperature).exp();
476 weights.push(weight);
477 sum += weight;
478 }
479 let mut out = Vec::with_capacity(measurements.len());
480 let mut assigned = 0u32;
481 for (index, weight) in weights.iter().enumerate() {
482 if index + 1 == weights.len() {
483 out.push(Q16_ONE.saturating_sub(assigned));
484 break;
485 }
486 let q16 = ((*weight / sum) * f64::from(Q16_ONE)).round() as u32;
487 let remaining = Q16_ONE.saturating_sub(assigned);
488 let q16 = q16.min(remaining);
489 assigned = assigned.saturating_add(q16);
490 out.push(q16);
491 }
492 out
493}
494
/// Multiplies the row-major `n x n` 16.16 matrix by the 16.16 vector `gradient_q16`.
///
/// Each product is shifted right by 16 *before* being accumulated, so sub-unit
/// contributions are truncated per term. The accumulation saturates, and each output
/// is clamped to `u32::MAX`. Panics if either slice is shorter than the shape
/// implies.
fn precondition_q16(matrix_q16: &[u32], gradient_q16: &[u32], n: usize) -> Vec<u32> {
    (0..n)
        .map(|row| {
            let row_cells = &matrix_q16[row * n..row * n + n];
            let total = row_cells
                .iter()
                .zip(&gradient_q16[..n])
                .fold(0u64, |acc, (&cell, &grad)| {
                    let term = u64::from(cell).saturating_mul(u64::from(grad)) >> 16;
                    acc.saturating_add(term)
                });
            u32::try_from(total).unwrap_or(u32::MAX)
        })
        .collect()
}
508
509pub struct Tuner {
511 mode: Mode,
512 cache: TunerCache,
513 cache_path: PathBuf,
514}
515
516impl Tuner {
517 #[must_use]
519 pub fn new(adapter_fp: &str, mode: Mode) -> Self {
520 let cache_path = Self::cache_path_for_adapter(adapter_fp);
521 let cache = TunerCache::load(&cache_path).unwrap_or_default();
522 Self {
523 mode,
524 cache,
525 cache_path,
526 }
527 }
528
529 #[must_use]
531 pub fn cache_path_for_adapter(adapter_fp: &str) -> PathBuf {
532 let mut home = dirs_cache_root();
533 home.push("vyre");
534 home.push("tuner");
535 home.push(format!("{adapter_fp}.toml"));
536 home
537 }
538
539 #[must_use]
541 pub fn candidates_for(&self, max_invocations: u32) -> Vec<u32> {
542 let mut candidates = Vec::new();
543 let _ = candidates.try_reserve_exact(WORKGROUP_CANDIDATES.len());
544 candidates.extend(
545 WORKGROUP_CANDIDATES
546 .iter()
547 .copied()
548 .filter(|candidate| *candidate <= max_invocations),
549 );
550 candidates
551 }
552
553 #[must_use]
555 pub const fn default_workgroup_size() -> [u32; 3] {
556 crate::pipeline::DEFAULT_1D_WORKGROUP_SIZE
557 }
558
559 #[must_use]
561 pub const fn mode(&self) -> Mode {
562 self.mode
563 }
564
565 #[must_use]
567 pub fn resolve(&self, program_fp: &str) -> [u32; 3] {
568 self.cache
569 .get(program_fp)
570 .unwrap_or_else(Self::default_workgroup_size)
571 }
572
573 #[must_use]
575 pub fn resolve_key(&self, key: &TunerProgramKey) -> [u32; 3] {
576 self.resolve(key.as_str())
577 }
578
579 pub fn record_decision(&mut self, program_fp: impl Into<String>, size: [u32; 3]) {
581 self.cache.set(program_fp, size);
582 }
583
584 pub fn record_key_decision(&mut self, key: TunerProgramKey, size: [u32; 3]) {
586 self.cache.set_key(key, size);
587 }
588
589 pub fn best_of<T: BackendTimer>(
595 &self,
596 program: &Program,
597 candidates: impl IntoIterator<Item = [u32; 3]>,
598 timer: &mut T,
599 ) -> Result<Option<TuningMeasurement>, T::Error> {
600 let mut best = None;
601 for workgroup_size in candidates {
602 let elapsed_ns = timer.measure_candidate_ns(program, workgroup_size)?;
603 let measurement = TuningMeasurement {
604 workgroup_size,
605 elapsed_ns,
606 };
607 if best
608 .map(|current: TuningMeasurement| elapsed_ns < current.elapsed_ns)
609 .unwrap_or(true)
610 {
611 best = Some(measurement);
612 }
613 }
614 Ok(best)
615 }
616
617 pub fn best_of_natural_gradient<T: BackendTimer>(
631 &self,
632 program: &Program,
633 candidates: impl IntoIterator<Item = [u32; 3]>,
634 timer: &mut T,
635 fisher_inv_sqrt_q16: &[u32],
636 policy: NaturalGradientPolicy,
637 ) -> Result<Result<NaturalGradientTuningStep, NaturalGradientTuningError>, T::Error> {
638 let mut measurements = Vec::new();
639 for workgroup_size in candidates {
640 let elapsed_ns = timer.measure_candidate_ns(program, workgroup_size)?;
641 measurements.push(TuningMeasurement {
642 workgroup_size,
643 elapsed_ns,
644 });
645 }
646 Ok(policy.suggest(&measurements, fisher_inv_sqrt_q16))
647 }
648
649 pub fn natural_gradient_step(
661 &self,
662 measurements: &[TuningMeasurement],
663 fisher_inv_sqrt_q16: &[u32],
664 policy: NaturalGradientPolicy,
665 ) -> Result<NaturalGradientTuningStep, NaturalGradientTuningError> {
666 policy.suggest(measurements, fisher_inv_sqrt_q16)
667 }
668
669 pub fn persist(&self) -> Result<(), String> {
675 self.cache.save(&self.cache_path)
676 }
677}
678
/// Runtime observations fed to `DefaultPolicy::suggest_resize`.
#[derive(Clone, Debug)]
pub struct TunerFeedback {
    /// Pairs of `u32`s, presumably `(opcode, count)` — not read by code in this file;
    /// confirm at the producer.
    pub per_opcode_counts: Vec<(u32, u32)>,
    /// Total wall time in microseconds (not read by `suggest_resize`).
    pub wall_time_us: u64,
    /// Idle time in microseconds; compared against `idle_shrink_us`.
    pub idle_us: u64,
    /// Workgroup x-size in effect when the feedback was gathered.
    pub observed_workgroup_size_x: u32,
    /// Observed throughput per microsecond; compared against the saturation threshold.
    pub observed_throughput_per_us: f64,
}
693
/// Heuristic halve/double resize policy driven by [`TunerFeedback`].
#[derive(Clone, Debug)]
pub struct DefaultPolicy {
    /// Largest x-size the policy may grow to.
    pub adapter_max_workgroup_size_x: u32,
    /// Smallest x-size the policy may shrink to.
    pub minimum_workgroup_size_x: u32,
    /// Throughput (per µs) below which the policy grows the workgroup.
    pub saturation_threshold_per_us: f64,
    /// Idle time (µs) above which the policy shrinks the workgroup.
    pub idle_shrink_us: u64,
}

impl Default for DefaultPolicy {
    /// 32..=1024 size range, 1.0/µs saturation threshold, 100 ms idle threshold.
    fn default() -> Self {
        DefaultPolicy {
            minimum_workgroup_size_x: 32,
            adapter_max_workgroup_size_x: 1024,
            idle_shrink_us: 100_000,
            saturation_threshold_per_us: 1.0,
        }
    }
}
717
718impl DefaultPolicy {
719 #[must_use]
721 pub fn suggest_resize(&self, feedback: &TunerFeedback) -> Option<u32> {
722 let current = feedback.observed_workgroup_size_x.max(1);
723 if feedback.idle_us > self.idle_shrink_us {
724 let shrunk = current / 2;
725 if shrunk >= self.minimum_workgroup_size_x && shrunk != current {
726 return Some(shrunk);
727 }
728 return None;
729 }
730 if feedback.observed_throughput_per_us < self.saturation_threshold_per_us {
731 let grown = current.checked_mul(2)?;
732 if grown <= self.adapter_max_workgroup_size_x && grown != current {
733 return Some(grown);
734 }
735 }
736 None
737 }
738}
739
/// Initial `String` capacity for a serialised cache of `entries` lines: 96 bytes per
/// entry, saturating at `usize::MAX`. Presumably sized for a 67-character `v1-` key
/// plus its value — confirm.
fn tuner_cache_string_capacity(entries: usize) -> usize {
    const BYTES_PER_ENTRY: usize = 96;
    entries.checked_mul(BYTES_PER_ENTRY).unwrap_or(usize::MAX)
}
743
/// Resolves the per-user cache directory following the XDG Base Directory spec.
///
/// Resolution order:
/// 1. `$XDG_CACHE_HOME`, used only when it is an absolute path. The spec says an
///    empty value must be treated as unset and a relative one ignored as invalid.
/// 2. `$HOME/.cache`, when `$HOME` is non-empty.
/// 3. The current directory (`.`) as a last resort.
fn dirs_cache_root() -> PathBuf {
    let xdg = std::env::var_os("XDG_CACHE_HOME")
        .map(PathBuf::from)
        .filter(|path| path.is_absolute());
    if let Some(xdg) = xdg {
        return xdg;
    }
    if let Some(home) = std::env::var_os("HOME").filter(|home| !home.is_empty()) {
        return PathBuf::from(home).join(".cache");
    }
    PathBuf::from(".")
}
755
// Unit tests for the autotuner: env-mode parsing, the natural-gradient policy
// (selection, Fisher steering, fixed-point mass conservation, input validation) and
// the measured sweep driven through a scripted `BackendTimer`.
#[cfg(test)]
mod tests {
    use super::*;

    // Three candidates whose fastest entry ([128, 1, 1] at 8 µs) sits in the middle,
    // so "first" and "last" selection bugs are both detectable.
    fn measurements() -> Vec<TuningMeasurement> {
        vec![
            TuningMeasurement {
                workgroup_size: [64, 1, 1],
                elapsed_ns: 12_000,
            },
            TuningMeasurement {
                workgroup_size: [128, 1, 1],
                elapsed_ns: 8_000,
            },
            TuningMeasurement {
                workgroup_size: [256, 1, 1],
                elapsed_ns: 10_000,
            },
        ]
    }

    // Scripted timer: records every candidate it is asked to time and fails when the
    // x-size equals `fail_on`.
    struct StaticTimer {
        fail_on: Option<u32>,
        measured: Vec<[u32; 3]>,
    }

    impl StaticTimer {
        fn new() -> Self {
            Self {
                fail_on: None,
                measured: Vec::new(),
            }
        }

        fn failing(fail_on: u32) -> Self {
            Self {
                fail_on: Some(fail_on),
                measured: Vec::new(),
            }
        }
    }

    impl BackendTimer for StaticTimer {
        type Error = &'static str;

        // Latencies mirror `measurements()`; unknown sizes are deliberately slow.
        // The candidate is recorded before the failure check, so a failing candidate
        // still shows up in `measured`.
        fn measure_candidate_ns(
            &mut self,
            _program: &Program,
            workgroup_size: [u32; 3],
        ) -> Result<u64, Self::Error> {
            self.measured.push(workgroup_size);
            if self.fail_on == Some(workgroup_size[0]) {
                return Err("timer failed");
            }
            Ok(match workgroup_size[0] {
                64 => 12_000,
                128 => 8_000,
                256 => 10_000,
                _ => 50_000,
            })
        }
    }

    // Minimal program; the scripted timer never inspects it.
    fn empty_program() -> Program {
        Program::wrapped(Vec::new(), [64, 1, 1], Vec::new())
    }

    #[test]
    fn unset_autotuner_mode_defaults_to_natural_gradient_release_path() {
        assert_eq!(Mode::production_default(), Mode::NaturalGradient);
        assert_eq!(Mode::from_env_value(None), Mode::NaturalGradient);
    }

    #[test]
    fn explicit_env_modes_preserve_escape_hatches() {
        assert_eq!(Mode::from_env_value(Some("natural")), Mode::NaturalGradient);
        assert_eq!(Mode::from_env_value(Some("ng")), Mode::NaturalGradient);
        assert_eq!(Mode::from_env_value(Some("on")), Mode::On);
        assert_eq!(Mode::from_env_value(Some("off")), Mode::OffUseDefault);
        assert_eq!(Mode::from_env_value(Some("default")), Mode::OffUseDefault);
    }

    // With an identity Fisher the preconditioned gradient equals the softmax
    // weights, so the fastest sample must also be the selected one.
    #[test]
    fn identity_fisher_preserves_fastest_candidate_policy_gradient() {
        let policy = NaturalGradientPolicy {
            temperature_ns: 4_000,
        };
        let samples = measurements();
        let step = policy
            .suggest(&samples, &identity_fisher_q16(samples.len()))
            .expect("Fix: identity Fisher natural-gradient update should be valid");

        assert_eq!(step.best_measured_workgroup_size, [128, 1, 1]);
        assert_eq!(step.selected_workgroup_size, [128, 1, 1]);
        assert_eq!(step.best_measured_elapsed_ns, 8_000);
    }

    // Scaling the first diagonal cell by 8 amplifies the slowest candidate's weight
    // enough to outrank the fastest one, while the measured winner is unchanged.
    #[test]
    fn anisotropic_fisher_can_redirect_next_probe_without_changing_measurement_winner() {
        let policy = NaturalGradientPolicy {
            temperature_ns: 4_000,
        };
        let samples = measurements();
        let mut fisher = identity_fisher_q16(samples.len());
        fisher[0] = Q16_ONE * 8;

        let step = policy
            .suggest(&samples, &fisher)
            .expect("Fix: diagonal Fisher natural-gradient update should be valid");

        assert_eq!(step.best_measured_workgroup_size, [128, 1, 1]);
        assert_eq!(
            step.selected_workgroup_size,
            [64, 1, 1],
            "Fix: Fisher geometry must be able to steer exploration away from the raw fastest sample."
        );
        assert!(
            step.natural_gradient_q16[0] > step.natural_gradient_q16[1],
            "Fix: preconditioned gradient should reflect the anisotropic Fisher block."
        );
    }

    // Temperature 1 with 1-ns gaps produces steep weights; `u64::MAX - 2` exercises
    // the saturating latency arithmetic at the top of the range.
    #[test]
    fn softmax_weights_conserve_q16_probability_mass_across_hostile_latencies() {
        let policy = NaturalGradientPolicy { temperature_ns: 1 };
        for base in [0_u64, 1, 10, 1_000, u64::MAX - 2] {
            let samples = vec![
                TuningMeasurement {
                    workgroup_size: [32, 1, 1],
                    elapsed_ns: base,
                },
                TuningMeasurement {
                    workgroup_size: [64, 1, 1],
                    elapsed_ns: base.saturating_add(1),
                },
                TuningMeasurement {
                    workgroup_size: [128, 1, 1],
                    elapsed_ns: base.saturating_add(2),
                },
            ];
            let step = policy
                .suggest(&samples, &identity_fisher_q16(samples.len()))
                .expect("Fix: hostile latency range should still produce a normalized policy");
            let total: u32 = step.policy_weights_q16.iter().sum();
            assert_eq!(
                total, Q16_ONE,
                "Fix: fixed-point policy weights must conserve probability mass for base={base}."
            );
        }
    }

    #[test]
    fn rejects_empty_measurements_zero_temperature_and_bad_fisher_shape() {
        let policy = NaturalGradientPolicy::default();
        assert_eq!(
            policy.suggest(&[], &[]),
            Err(NaturalGradientTuningError::EmptyMeasurements)
        );

        let samples = measurements();
        let zero_temp = NaturalGradientPolicy { temperature_ns: 0 };
        assert_eq!(
            zero_temp.suggest(&samples, &identity_fisher_q16(samples.len())),
            Err(NaturalGradientTuningError::ZeroTemperature)
        );
        assert_eq!(
            policy.suggest(&samples, &[Q16_ONE]),
            Err(NaturalGradientTuningError::FisherMatrixShape {
                measurements: samples.len(),
                cells: 1,
            })
        );
    }

    // NOTE(review): `Tuner::new` reads the real user cache directory for this adapter
    // name; the assertion does not depend on its contents.
    #[test]
    fn tuner_exposes_natural_gradient_step_surface() {
        let tuner = Tuner::new("natural-gradient-test-adapter", Mode::OffUseDefault);
        let samples = measurements();
        let step = tuner
            .natural_gradient_step(
                &samples,
                &identity_fisher_q16(samples.len()),
                NaturalGradientPolicy::default(),
            )
            .expect("Fix: tuner natural-gradient policy surface should accept identity Fisher");

        assert_eq!(step.selected_workgroup_size, [128, 1, 1]);
    }

    // End-to-end sweep: every candidate is timed in order, then the anisotropic
    // Fisher redirects selection exactly as in the unit-level test above.
    #[test]
    fn measured_natural_gradient_sweep_uses_backend_timer_and_fisher_policy() {
        let tuner = Tuner::new(
            "measured-natural-gradient-test-adapter",
            Mode::NaturalGradient,
        );
        let mut timer = StaticTimer::new();
        let mut fisher = identity_fisher_q16(3);
        fisher[0] = Q16_ONE * 8;

        let step = tuner
            .best_of_natural_gradient(
                &empty_program(),
                [[64, 1, 1], [128, 1, 1], [256, 1, 1]],
                &mut timer,
                &fisher,
                NaturalGradientPolicy {
                    temperature_ns: 4_000,
                },
            )
            .expect("Fix: backend timer should succeed")
            .expect("Fix: natural-gradient policy should accept measured candidates");

        assert_eq!(
            timer.measured,
            vec![[64, 1, 1], [128, 1, 1], [256, 1, 1]],
            "Fix: natural-gradient sweep must measure every supplied candidate."
        );
        assert_eq!(step.best_measured_workgroup_size, [128, 1, 1]);
        assert_eq!(
            step.selected_workgroup_size,
            [64, 1, 1],
            "Fix: measured natural-gradient sweep must use Fisher policy, not raw fastest-only selection."
        );
    }

    // The timer fails on the second candidate: the error surfaces as the outer `Err`
    // and the third candidate is never timed.
    #[test]
    fn measured_natural_gradient_sweep_propagates_timer_failures() {
        let tuner = Tuner::new(
            "measured-natural-gradient-error-test-adapter",
            Mode::NaturalGradient,
        );
        let mut timer = StaticTimer::failing(128);
        let err = tuner
            .best_of_natural_gradient(
                &empty_program(),
                [[64, 1, 1], [128, 1, 1], [256, 1, 1]],
                &mut timer,
                &identity_fisher_q16(3),
                NaturalGradientPolicy::default(),
            )
            .expect_err("Fix: backend timer failures must propagate before policy update");

        assert_eq!(err, "timer failed");
        assert_eq!(
            timer.measured,
            vec![[64, 1, 1], [128, 1, 1]],
            "Fix: failed measurements must stop the sweep instead of producing a fake policy result."
        );
    }
}