Skip to main content

vyre_driver/
tuner.rs

1//! Backend-neutral autotuner framework and cache metadata.
2
3use std::collections::BTreeMap;
4use std::fmt::Write as _;
5use std::fs;
6use std::path::{Path, PathBuf};
7
8use vyre_foundation::ir::Program;
9
/// Canonical 1D workgroup-size probes shared by live dispatch tuning and
/// backend timer sweeps.
pub const WORKGROUP_CANDIDATES: &[u32] = &[32, 64, 128, 256, 512, 1024];

/// Environment variable consulted by [`Mode::from_env`].
const AUTOTUNER_ENV: &str = "VYRE_AUTOTUNER";

/// Largest tuner cache file (4 MiB) the bounded reader will accept.
const MAX_TUNER_CACHE_BYTES: u64 = 4 << 20;
15
16/// Tuner runtime mode.
17#[derive(Debug, Clone, Copy, PartialEq, Eq)]
18#[non_exhaustive]
19pub enum Mode {
20    /// Sweep candidate sizes on first dispatch.
21    On,
22    /// Sweep candidate sizes and use Fisher-preconditioned policy updates.
23    NaturalGradient,
24    /// Use cached decisions when present, otherwise the default workgroup.
25    OffUseDefault,
26}
27
28impl Mode {
29    /// Production default when `VYRE_AUTOTUNER` is unset.
30    ///
31    /// Explicit `VYRE_AUTOTUNER=off` or `default` still gives the stable
32    /// cached/default path for deterministic bisects, but the release path
33    /// exercises the Fisher-preconditioned autotuner by default.
34    #[must_use]
35    pub const fn production_default() -> Self {
36        Mode::NaturalGradient
37    }
38
39    /// Resolve mode from `VYRE_AUTOTUNER`.
40    #[must_use]
41    pub fn from_env() -> Self {
42        match std::env::var(AUTOTUNER_ENV).ok() {
43            Some(value) => Self::from_env_value(Some(value.as_str())),
44            None => Self::production_default(),
45        }
46    }
47
48    fn from_env_value(value: Option<&str>) -> Self {
49        match value {
50            Some("on") => Mode::On,
51            Some("natural" | "ng") => Mode::NaturalGradient,
52            Some("off" | "default") => Mode::OffUseDefault,
53            Some(_) => Self::production_default(),
54            None => Self::production_default(),
55        }
56    }
57}
58
59/// Backend timing hook used by the generic best-of-N framework.
60pub trait BackendTimer {
61    /// Error type returned by a concrete timing implementation.
62    type Error;
63
64    /// Measure one workgroup-size candidate and return elapsed nanoseconds.
65    ///
66    /// # Errors
67    ///
68    /// Returns the concrete backend timing error when the dispatch or timer
69    /// instrumentation fails.
70    fn measure_candidate_ns(
71        &mut self,
72        program: &Program,
73        workgroup_size: [u32; 3],
74    ) -> Result<u64, Self::Error>;
75}
76
/// Per-adapter tuner decisions keyed by program fingerprint.
///
/// Backed by a [`BTreeMap`], so entries iterate in sorted key order.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct TunerCache {
    /// Maps each `program_fingerprint` to its best `[x, y, z]` workgroup size.
    pub entries: BTreeMap<String, [u32; 3]>,
}
83
84/// Static program shape used to disambiguate autotuner decisions.
85#[derive(Debug, Clone, Copy, PartialEq, Eq)]
86pub struct StaticProgramShape {
87    /// Declared or overridden workgroup shape.
88    pub workgroup_size: [u32; 3],
89    /// Static workgroup-count override when known.
90    pub workgroup_count: Option<[u32; 3]>,
91    /// Static visible output byte count used by the dispatch.
92    pub output_bytes: u64,
93}
94
95impl StaticProgramShape {
96    /// Build a shape record from a program and caller-known launch facts.
97    #[must_use]
98    pub fn new(program: &Program, workgroup_count: Option<[u32; 3]>, output_bytes: u64) -> Self {
99        Self {
100            workgroup_size: program.workgroup_size(),
101            workgroup_count,
102            output_bytes,
103        }
104    }
105}
106
107/// Stable key for per-adapter workgroup autotuning decisions.
108#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
109pub struct TunerProgramKey(String);
110
111impl TunerProgramKey {
112    /// Build a key from the canonical program fingerprint plus static shape.
113    #[must_use]
114    pub fn from_program(program: &Program, shape: StaticProgramShape) -> Self {
115        let mut hasher = blake3::Hasher::new();
116        hasher.update(b"vyre-driver-workgroup-tuner-v1\0program\0");
117        hasher.update(&program.fingerprint());
118        hasher.update(b"\0workgroup-size\0");
119        for axis in shape.workgroup_size {
120            hasher.update(&axis.to_le_bytes());
121        }
122        hasher.update(b"\0workgroup-count\0");
123        match shape.workgroup_count {
124            Some(count) => {
125                hasher.update(&[1]);
126                for axis in count {
127                    hasher.update(&axis.to_le_bytes());
128                }
129            }
130            None => {
131                hasher.update(&[0]);
132            }
133        }
134        hasher.update(b"\0output-bytes\0");
135        hasher.update(&shape.output_bytes.to_le_bytes());
136        let digest = hasher.finalize();
137        let mut key = String::with_capacity(67);
138        key.push_str("v1-");
139        push_hex(digest.as_bytes(), &mut key);
140        Self(key)
141    }
142
143    /// String form used in the TOML cache.
144    #[must_use]
145    pub fn as_str(&self) -> &str {
146        &self.0
147    }
148}
149
/// Append the lowercase hexadecimal form of `bytes` to `out`, two digits per
/// byte with the high nibble first.
fn push_hex(bytes: &[u8], out: &mut String) {
    for byte in bytes {
        for nibble in [byte >> 4, byte & 0x0f] {
            // A nibble is always < 16, so `from_digit` cannot fail here.
            out.push(char::from_digit(u32::from(nibble), 16).unwrap_or('0'));
        }
    }
}
157
158impl AsRef<str> for TunerProgramKey {
159    fn as_ref(&self) -> &str {
160        self.as_str()
161    }
162}
163
164impl TunerCache {
165    /// Return the best workgroup size for the given key, if cached.
166    #[must_use]
167    pub fn get(&self, program_fp: &str) -> Option<[u32; 3]> {
168        self.entries.get(program_fp).copied()
169    }
170
171    /// Return the cached decision for a typed tuner key.
172    #[must_use]
173    pub fn get_key(&self, key: &TunerProgramKey) -> Option<[u32; 3]> {
174        self.get(key.as_str())
175    }
176
177    /// Record a decision.
178    pub fn set(&mut self, program_fp: impl Into<String>, size: [u32; 3]) {
179        self.entries.insert(program_fp.into(), size);
180    }
181
182    /// Record a decision under a typed key.
183    ///
184    /// HOT PATH (autotuner cache write): takes ownership of `key` so the fingerprint `String`
185    /// moves into the map  -  `set(key.as_str(), …)` would allocate a second copy of the same bytes.
186    pub fn set_key(&mut self, key: TunerProgramKey, size: [u32; 3]) {
187        self.entries.insert(key.0, size);
188    }
189
190    /// Load from a TOML file. Missing file returns an empty cache.
191    ///
192    /// # Errors
193    ///
194    /// Returns when the file exists but contains invalid TOML.
195    pub fn load(path: &Path) -> Result<Self, String> {
196        let Ok(contents) = read_tuner_cache_bounded(path) else {
197            return Ok(Self::default());
198        };
199        let parsed: toml::Value = toml::from_str(&contents).map_err(|error| {
200            format!(
201                "Fix: tuner cache `{}` is not valid TOML: {error}",
202                path.display()
203            )
204        })?;
205        let mut entries = BTreeMap::new();
206        if let Some(table) = parsed.as_table() {
207            for (key, value) in table {
208                if let Some(array) = value.as_array() {
209                    if array.len() == 3 {
210                        let mut triple = [0u32; 3];
211                        for (index, value) in array.iter().enumerate() {
212                            if let Some(number) = value.as_integer() {
213                                if let Ok(converted) = u32::try_from(number) {
214                                    triple[index] = converted;
215                                }
216                            }
217                        }
218                        entries.insert(key.clone(), triple);
219                    }
220                }
221            }
222        }
223        Ok(Self { entries })
224    }
225
226    /// Persist to disk. Creates parent directories as needed.
227    ///
228    /// # Errors
229    ///
230    /// Returns when the parent directory cannot be created or the file cannot
231    /// be written.
232    pub fn save(&self, path: &Path) -> Result<(), String> {
233        if let Some(parent) = path.parent() {
234            fs::create_dir_all(parent).map_err(|error| {
235                format!(
236                    "Fix: could not create tuner cache directory {}: {error}",
237                    parent.display()
238                )
239            })?;
240        }
241        let mut out = String::with_capacity(tuner_cache_string_capacity(self.entries.len()));
242        for (key, size) in &self.entries {
243            let _ = writeln!(out, "\"{}\" = [{}, {}, {}]", key, size[0], size[1], size[2]);
244        }
245        fs::write(path, &out).map_err(|error| {
246            format!(
247                "Fix: could not write tuner cache {}: {error}",
248                path.display()
249            )
250        })
251    }
252}
253
254fn read_tuner_cache_bounded(path: &Path) -> std::io::Result<String> {
255    use std::io::Read as _;
256
257    let mut file = fs::File::open(path)?;
258    let metadata = file.metadata()?;
259    if metadata.len() > MAX_TUNER_CACHE_BYTES {
260        return Err(std::io::Error::new(
261            std::io::ErrorKind::InvalidData,
262            format!("tuner cache exceeds {MAX_TUNER_CACHE_BYTES} byte limit"),
263        ));
264    }
265    let mut text = String::with_capacity(metadata.len() as usize);
266    file.by_ref()
267        .take(MAX_TUNER_CACHE_BYTES + 1)
268        .read_to_string(&mut text)?;
269    if text.len() as u64 > MAX_TUNER_CACHE_BYTES {
270        return Err(std::io::Error::new(
271            std::io::ErrorKind::InvalidData,
272            "tuner cache exceeded bounded read limit",
273        ));
274    }
275    Ok(text)
276}
277
/// One timed candidate: the workgroup size that ran and how long it took.
///
/// The best-of-N search returns the fastest of these.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TuningMeasurement {
    /// Workgroup size that was measured.
    pub workgroup_size: [u32; 3],
    /// Elapsed time for that workgroup size, in nanoseconds.
    pub elapsed_ns: u64,
}
286
/// The value 1.0 in 16.16 fixed point (integer part in the high 16 bits).
pub const Q16_ONE: u32 = 0x0001_0000;
289
/// Natural-gradient policy for choosing the next autotune probe from
/// measured latency samples.
///
/// The policy treats the candidate set as a discrete distribution over
/// launch configurations. Latency samples become a softmax over
/// `-elapsed_ns / temperature_ns`. The supplied inverse-Fisher square-root
/// matrix then preconditions that probability/gradient vector before the
/// driver picks the next candidate. CUDA/self-substrate can produce the same
/// fixed-point matrix through the primitive-backed natural-gradient path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NaturalGradientPolicy {
    /// Softmax temperature in nanoseconds; larger values explore more.
    pub temperature_ns: u64,
}

impl Default for NaturalGradientPolicy {
    /// A 10 µs (10 000 ns) softmax temperature.
    fn default() -> Self {
        const DEFAULT_TEMPERATURE_NS: u64 = 10 * 1_000;
        NaturalGradientPolicy {
            temperature_ns: DEFAULT_TEMPERATURE_NS,
        }
    }
}
312
/// Outcome of one natural-gradient policy update.
///
/// Carries both the Fisher-directed next probe and the raw fastest sample,
/// together with the intermediate fixed-point vectors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NaturalGradientTuningStep {
    /// Candidate chosen after Fisher preconditioning.
    pub selected_workgroup_size: [u32; 3],
    /// Fastest candidate in the raw measurement window.
    pub best_measured_workgroup_size: [u32; 3],
    /// Elapsed time of that fastest candidate, in nanoseconds.
    pub best_measured_elapsed_ns: u64,
    /// Softmax policy weights, one per sample, in 16.16 fixed point.
    pub policy_weights_q16: Vec<u32>,
    /// Fisher-preconditioned gradient magnitudes, one per sample, in 16.16
    /// fixed point.
    pub natural_gradient_q16: Vec<u32>,
}
327
/// Reasons a natural-gradient autotune policy update can be rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum NaturalGradientTuningError {
    /// The measurement slice was empty.
    EmptyMeasurements,
    /// The inverse-Fisher square-root matrix did not hold `n * n` cells.
    FisherMatrixShape {
        /// Number of latency samples (`n`).
        measurements: usize,
        /// Number of fixed-point cells actually supplied.
        cells: usize,
    },
    /// `temperature_ns` was zero.
    ZeroTemperature,
}

impl std::fmt::Display for NaturalGradientTuningError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (measurements, cells) = match *self {
            Self::EmptyMeasurements => {
                return f.write_str(
                    "natural-gradient tuner received no measurements. Fix: measure at least one candidate before policy update.",
                );
            }
            Self::ZeroTemperature => {
                return f.write_str(
                    "natural-gradient tuner temperature is zero. Fix: use a positive temperature_ns.",
                );
            }
            Self::FisherMatrixShape {
                measurements,
                cells,
            } => (measurements, cells),
        };
        // Saturate so a pathological `measurements` cannot overflow while formatting.
        let expected = measurements.saturating_mul(measurements);
        write!(
            f,
            "natural-gradient tuner expected an inverse-Fisher matrix with {expected} cells for {measurements} measurement(s), got {cells}. Fix: pass an n*n 16.16 matrix."
        )
    }
}

impl std::error::Error for NaturalGradientTuningError {}
373
374impl NaturalGradientPolicy {
375    /// Suggest the next workgroup-size candidate from latency samples and an
376    /// inverse-Fisher square-root matrix.
377    ///
378    /// `fisher_inv_sqrt_q16` is row-major `n x n`, 16.16 fixed-point. Passing
379    /// an identity matrix makes the policy reduce to the softmax-gradient
380    /// candidate. Non-identity blocks let the runtime bias exploration by the
381    /// local latency manifold instead of blindly reusing the single fastest
382    /// point.
383    ///
384    /// # Errors
385    ///
386    /// Returns [`NaturalGradientTuningError`] when the measurement set is
387    /// empty, temperature is zero, or the Fisher matrix shape does not match.
388    pub fn suggest(
389        &self,
390        measurements: &[TuningMeasurement],
391        fisher_inv_sqrt_q16: &[u32],
392    ) -> Result<NaturalGradientTuningStep, NaturalGradientTuningError> {
393        if measurements.is_empty() {
394            return Err(NaturalGradientTuningError::EmptyMeasurements);
395        }
396        if self.temperature_ns == 0 {
397            return Err(NaturalGradientTuningError::ZeroTemperature);
398        }
399        let expected_cells = measurements.len().checked_mul(measurements.len()).ok_or(
400            NaturalGradientTuningError::FisherMatrixShape {
401                measurements: measurements.len(),
402                cells: fisher_inv_sqrt_q16.len(),
403            },
404        )?;
405        if fisher_inv_sqrt_q16.len() != expected_cells {
406            return Err(NaturalGradientTuningError::FisherMatrixShape {
407                measurements: measurements.len(),
408                cells: fisher_inv_sqrt_q16.len(),
409            });
410        }
411
412        let mut best_index = 0usize;
413        let mut best_elapsed = measurements[0].elapsed_ns;
414        for (index, measurement) in measurements.iter().enumerate().skip(1) {
415            if measurement.elapsed_ns < best_elapsed {
416                best_index = index;
417                best_elapsed = measurement.elapsed_ns;
418            }
419        }
420
421        let policy_weights_q16 =
422            latency_softmax_weights_q16(measurements, best_elapsed, self.temperature_ns);
423        let natural_gradient_q16 =
424            precondition_q16(fisher_inv_sqrt_q16, &policy_weights_q16, measurements.len());
425        let selected_index = natural_gradient_q16
426            .iter()
427            .enumerate()
428            .max_by_key(|(_, value)| *value)
429            .map(|(index, _)| index)
430            .unwrap_or(best_index);
431
432        Ok(NaturalGradientTuningStep {
433            selected_workgroup_size: measurements[selected_index].workgroup_size,
434            best_measured_workgroup_size: measurements[best_index].workgroup_size,
435            best_measured_elapsed_ns: best_elapsed,
436            policy_weights_q16,
437            natural_gradient_q16,
438        })
439    }
440}
441
442/// Build an identity inverse-Fisher square-root matrix in 16.16 fixed point.
443#[must_use]
444pub fn identity_fisher_q16(candidate_count: usize) -> Vec<u32> {
445    let mut out = Vec::new();
446    identity_fisher_q16_into(candidate_count, &mut out);
447    out
448}
449
450/// Write an identity inverse-Fisher square-root matrix into caller-owned
451/// storage.
452
453pub fn identity_fisher_q16_into(candidate_count: usize, out: &mut Vec<u32>) {
454    let Some(cells) = candidate_count.checked_mul(candidate_count) else {
455        out.clear();
456        return;
457    };
458    out.clear();
459    out.resize(cells, 0);
460    for index in 0..candidate_count {
461        out[index * candidate_count + index] = Q16_ONE;
462    }
463}
464
465fn latency_softmax_weights_q16(
466    measurements: &[TuningMeasurement],
467    best_elapsed: u64,
468    temperature_ns: u64,
469) -> Vec<u32> {
470    let temperature = temperature_ns as f64;
471    let mut weights = Vec::with_capacity(measurements.len());
472    let mut sum = 0.0f64;
473    for measurement in measurements {
474        let penalty = measurement.elapsed_ns.saturating_sub(best_elapsed) as f64;
475        let weight = (-penalty / temperature).exp();
476        weights.push(weight);
477        sum += weight;
478    }
479    let mut out = Vec::with_capacity(measurements.len());
480    let mut assigned = 0u32;
481    for (index, weight) in weights.iter().enumerate() {
482        if index + 1 == weights.len() {
483            out.push(Q16_ONE.saturating_sub(assigned));
484            break;
485        }
486        let q16 = ((*weight / sum) * f64::from(Q16_ONE)).round() as u32;
487        let remaining = Q16_ONE.saturating_sub(assigned);
488        let q16 = q16.min(remaining);
489        assigned = assigned.saturating_add(q16);
490        out.push(q16);
491    }
492    out
493}
494
/// Multiply the row-major `n x n` 16.16 matrix by the 16.16 gradient vector.
///
/// Each product is shifted back to 16.16 before accumulation. The row sum
/// saturates and is then clamped to `u32::MAX`.
fn precondition_q16(matrix_q16: &[u32], gradient_q16: &[u32], n: usize) -> Vec<u32> {
    (0..n)
        .map(|row| {
            let start = row * n;
            let row_cells = &matrix_q16[start..start + n];
            let dot = row_cells
                .iter()
                .zip(&gradient_q16[..n])
                .fold(0u64, |acc, (&cell, &grad)| {
                    acc.saturating_add(u64::from(cell).saturating_mul(u64::from(grad)) >> 16)
                });
            dot.min(u64::from(u32::MAX)) as u32
        })
        .collect()
}
508
509/// Workgroup-size autotuner.
510pub struct Tuner {
511    mode: Mode,
512    cache: TunerCache,
513    cache_path: PathBuf,
514}
515
516impl Tuner {
517    /// Build a new tuner for the adapter fingerprinted as `adapter_fp`.
518    #[must_use]
519    pub fn new(adapter_fp: &str, mode: Mode) -> Self {
520        let cache_path = Self::cache_path_for_adapter(adapter_fp);
521        let cache = TunerCache::load(&cache_path).unwrap_or_default();
522        Self {
523            mode,
524            cache,
525            cache_path,
526        }
527    }
528
529    /// Cache file path for a given adapter fingerprint.
530    #[must_use]
531    pub fn cache_path_for_adapter(adapter_fp: &str) -> PathBuf {
532        let mut home = dirs_cache_root();
533        home.push("vyre");
534        home.push("tuner");
535        home.push(format!("{adapter_fp}.toml"));
536        home
537    }
538
539    /// Candidate workgroup sizes bounded by `max_invocations`.
540    #[must_use]
541    pub fn candidates_for(&self, max_invocations: u32) -> Vec<u32> {
542        let mut candidates = Vec::new();
543        let _ = candidates.try_reserve_exact(WORKGROUP_CANDIDATES.len());
544        candidates.extend(
545            WORKGROUP_CANDIDATES
546                .iter()
547                .copied()
548                .filter(|candidate| *candidate <= max_invocations),
549        );
550        candidates
551    }
552
553    /// Default workgroup size used without cache data.
554    #[must_use]
555    pub const fn default_workgroup_size() -> [u32; 3] {
556        crate::pipeline::DEFAULT_1D_WORKGROUP_SIZE
557    }
558
559    /// Mode this tuner is running in.
560    #[must_use]
561    pub const fn mode(&self) -> Mode {
562        self.mode
563    }
564
565    /// Resolve the workgroup size for a program key.
566    #[must_use]
567    pub fn resolve(&self, program_fp: &str) -> [u32; 3] {
568        self.cache
569            .get(program_fp)
570            .unwrap_or_else(Self::default_workgroup_size)
571    }
572
573    /// Resolve the workgroup size for a typed program/static-shape key.
574    #[must_use]
575    pub fn resolve_key(&self, key: &TunerProgramKey) -> [u32; 3] {
576        self.resolve(key.as_str())
577    }
578
579    /// Record a sweep outcome in memory.
580    pub fn record_decision(&mut self, program_fp: impl Into<String>, size: [u32; 3]) {
581        self.cache.set(program_fp, size);
582    }
583
584    /// Record a sweep outcome for a typed key.
585    pub fn record_key_decision(&mut self, key: TunerProgramKey, size: [u32; 3]) {
586        self.cache.set_key(key, size);
587    }
588
589    /// Measure candidate sizes and choose the fastest one.
590    ///
591    /// # Errors
592    ///
593    /// Returns a backend timing error from [`BackendTimer`].
594    pub fn best_of<T: BackendTimer>(
595        &self,
596        program: &Program,
597        candidates: impl IntoIterator<Item = [u32; 3]>,
598        timer: &mut T,
599    ) -> Result<Option<TuningMeasurement>, T::Error> {
600        let mut best = None;
601        for workgroup_size in candidates {
602            let elapsed_ns = timer.measure_candidate_ns(program, workgroup_size)?;
603            let measurement = TuningMeasurement {
604                workgroup_size,
605                elapsed_ns,
606            };
607            if best
608                .map(|current: TuningMeasurement| elapsed_ns < current.elapsed_ns)
609                .unwrap_or(true)
610            {
611                best = Some(measurement);
612            }
613        }
614        Ok(best)
615    }
616
617    /// Measure candidates, then choose the next probe with a
618    /// Fisher-preconditioned natural-gradient policy.
619    ///
620    /// This is the concrete runtime handoff for `VYRE_AUTOTUNER=natural`.
621    /// It reuses the same backend timer as [`Self::best_of`], records every
622    /// measured candidate, and feeds those measurements into
623    /// [`NaturalGradientPolicy`]. The returned step includes both the raw
624    /// fastest measurement and the Fisher-directed next candidate.
625    ///
626    /// # Errors
627    ///
628    /// Returns backend timing errors from [`BackendTimer`] or policy errors
629    /// from [`NaturalGradientPolicy`].
630    pub fn best_of_natural_gradient<T: BackendTimer>(
631        &self,
632        program: &Program,
633        candidates: impl IntoIterator<Item = [u32; 3]>,
634        timer: &mut T,
635        fisher_inv_sqrt_q16: &[u32],
636        policy: NaturalGradientPolicy,
637    ) -> Result<Result<NaturalGradientTuningStep, NaturalGradientTuningError>, T::Error> {
638        let mut measurements = Vec::new();
639        for workgroup_size in candidates {
640            let elapsed_ns = timer.measure_candidate_ns(program, workgroup_size)?;
641            measurements.push(TuningMeasurement {
642                workgroup_size,
643                elapsed_ns,
644            });
645        }
646        Ok(policy.suggest(&measurements, fisher_inv_sqrt_q16))
647    }
648
649    /// Convert measured candidates into a Fisher-preconditioned next probe.
650    ///
651    /// This keeps the best-of-N timing hook compatible while giving CUDA and
652    /// other GPU backends a richer update rule than "pick the current fastest
653    /// sample forever." Backends can feed `fisher_inv_sqrt_q16` from the
654    /// primitive-backed natural-gradient self-substrate path.
655    ///
656    /// # Errors
657    ///
658    /// Returns [`NaturalGradientTuningError`] when the policy input is
659    /// malformed.
660    pub fn natural_gradient_step(
661        &self,
662        measurements: &[TuningMeasurement],
663        fisher_inv_sqrt_q16: &[u32],
664        policy: NaturalGradientPolicy,
665    ) -> Result<NaturalGradientTuningStep, NaturalGradientTuningError> {
666        policy.suggest(measurements, fisher_inv_sqrt_q16)
667    }
668
669    /// Write the cache to disk.
670    ///
671    /// # Errors
672    ///
673    /// Returns the structured error from [`TunerCache::save`].
674    pub fn persist(&self) -> Result<(), String> {
675        self.cache.save(&self.cache_path)
676    }
677}
678
/// Live-behaviour snapshot from one feedback window, consumed by the tuner
/// when deciding whether to resize.
#[derive(Debug, Clone)]
pub struct TunerFeedback {
    /// `(opcode_id, execution_count)` pairs reported by backend metrics.
    pub per_opcode_counts: Vec<(u32, u32)>,
    /// Total wall time of the window, in microseconds.
    pub wall_time_us: u64,
    /// Idle time inside the window, in microseconds.
    pub idle_us: u64,
    /// Workgroup size `x` in effect while this feedback was gathered.
    pub observed_workgroup_size_x: u32,
    /// Observed throughput per microsecond.
    pub observed_throughput_per_us: f64,
}
693
694/// Hysteresis-based default resize policy.
695#[derive(Debug, Clone)]
696pub struct DefaultPolicy {
697    /// Upper bound from the adapter capability probe.
698    pub adapter_max_workgroup_size_x: u32,
699    /// Floor below which we never shrink.
700    pub minimum_workgroup_size_x: u32,
701    /// Throughput below which we grow.
702    pub saturation_threshold_per_us: f64,
703    /// Idle time above which we shrink.
704    pub idle_shrink_us: u64,
705}
706
707impl Default for DefaultPolicy {
708    fn default() -> Self {
709        Self {
710            adapter_max_workgroup_size_x: 1024,
711            minimum_workgroup_size_x: 32,
712            saturation_threshold_per_us: 1.0,
713            idle_shrink_us: 100_000,
714        }
715    }
716}
717
718impl DefaultPolicy {
719    /// Suggest a new workgroup size for the next feedback window.
720    #[must_use]
721    pub fn suggest_resize(&self, feedback: &TunerFeedback) -> Option<u32> {
722        let current = feedback.observed_workgroup_size_x.max(1);
723        if feedback.idle_us > self.idle_shrink_us {
724            let shrunk = current / 2;
725            if shrunk >= self.minimum_workgroup_size_x && shrunk != current {
726                return Some(shrunk);
727            }
728            return None;
729        }
730        if feedback.observed_throughput_per_us < self.saturation_threshold_per_us {
731            let grown = current.checked_mul(2)?;
732            if grown <= self.adapter_max_workgroup_size_x && grown != current {
733                return Some(grown);
734            }
735        }
736        None
737    }
738}
739
/// Initial `String` capacity for serialising `entries` cache rows, budgeting
/// 96 bytes per row and saturating rather than overflowing.
fn tuner_cache_string_capacity(entries: usize) -> usize {
    const BYTES_PER_ENTRY: usize = 96;
    BYTES_PER_ENTRY.saturating_mul(entries)
}
743
/// Resolve the per-user cache root directory.
///
/// The lookup follows the XDG Base Directory rules:
/// 1. `XDG_CACHE_HOME`, but only when it is a non-empty absolute path. The
///    spec says relative values are invalid and must be ignored, and an empty
///    value means "unset".
/// 2. `$HOME/.cache`, when `HOME` is non-empty.
/// 3. The current directory (`.`).
fn dirs_cache_root() -> PathBuf {
    let xdg = std::env::var_os("XDG_CACHE_HOME")
        .map(PathBuf::from)
        .filter(|path| path.is_absolute());
    if let Some(xdg) = xdg {
        return xdg;
    }
    if let Some(home) = std::env::var_os("HOME").filter(|home| !home.is_empty()) {
        let mut path = PathBuf::from(home);
        path.push(".cache");
        return path;
    }
    PathBuf::from(".")
}
755
756#[cfg(test)]
757mod tests {
758    use super::*;
759
760    fn measurements() -> Vec<TuningMeasurement> {
761        vec![
762            TuningMeasurement {
763                workgroup_size: [64, 1, 1],
764                elapsed_ns: 12_000,
765            },
766            TuningMeasurement {
767                workgroup_size: [128, 1, 1],
768                elapsed_ns: 8_000,
769            },
770            TuningMeasurement {
771                workgroup_size: [256, 1, 1],
772                elapsed_ns: 10_000,
773            },
774        ]
775    }
776
777    struct StaticTimer {
778        fail_on: Option<u32>,
779        measured: Vec<[u32; 3]>,
780    }
781
782    impl StaticTimer {
783        fn new() -> Self {
784            Self {
785                fail_on: None,
786                measured: Vec::new(),
787            }
788        }
789
790        fn failing(fail_on: u32) -> Self {
791            Self {
792                fail_on: Some(fail_on),
793                measured: Vec::new(),
794            }
795        }
796    }
797
798    impl BackendTimer for StaticTimer {
799        type Error = &'static str;
800
801        fn measure_candidate_ns(
802            &mut self,
803            _program: &Program,
804            workgroup_size: [u32; 3],
805        ) -> Result<u64, Self::Error> {
806            self.measured.push(workgroup_size);
807            if self.fail_on == Some(workgroup_size[0]) {
808                return Err("timer failed");
809            }
810            Ok(match workgroup_size[0] {
811                64 => 12_000,
812                128 => 8_000,
813                256 => 10_000,
814                _ => 50_000,
815            })
816        }
817    }
818
819    fn empty_program() -> Program {
820        Program::wrapped(Vec::new(), [64, 1, 1], Vec::new())
821    }
822
823    #[test]
824    fn unset_autotuner_mode_defaults_to_natural_gradient_release_path() {
825        assert_eq!(Mode::production_default(), Mode::NaturalGradient);
826        assert_eq!(Mode::from_env_value(None), Mode::NaturalGradient);
827    }
828
829    #[test]
830    fn explicit_env_modes_preserve_escape_hatches() {
831        assert_eq!(Mode::from_env_value(Some("natural")), Mode::NaturalGradient);
832        assert_eq!(Mode::from_env_value(Some("ng")), Mode::NaturalGradient);
833        assert_eq!(Mode::from_env_value(Some("on")), Mode::On);
834        assert_eq!(Mode::from_env_value(Some("off")), Mode::OffUseDefault);
835        assert_eq!(Mode::from_env_value(Some("default")), Mode::OffUseDefault);
836    }
837
838    #[test]
839    fn identity_fisher_preserves_fastest_candidate_policy_gradient() {
840        let policy = NaturalGradientPolicy {
841            temperature_ns: 4_000,
842        };
843        let samples = measurements();
844        let step = policy
845            .suggest(&samples, &identity_fisher_q16(samples.len()))
846            .expect("Fix: identity Fisher natural-gradient update should be valid");
847
848        assert_eq!(step.best_measured_workgroup_size, [128, 1, 1]);
849        assert_eq!(step.selected_workgroup_size, [128, 1, 1]);
850        assert_eq!(step.best_measured_elapsed_ns, 8_000);
851    }
852
853    #[test]
854    fn anisotropic_fisher_can_redirect_next_probe_without_changing_measurement_winner() {
855        let policy = NaturalGradientPolicy {
856            temperature_ns: 4_000,
857        };
858        let samples = measurements();
859        let mut fisher = identity_fisher_q16(samples.len());
860        fisher[0] = Q16_ONE * 8;
861
862        let step = policy
863            .suggest(&samples, &fisher)
864            .expect("Fix: diagonal Fisher natural-gradient update should be valid");
865
866        assert_eq!(step.best_measured_workgroup_size, [128, 1, 1]);
867        assert_eq!(
868            step.selected_workgroup_size,
869            [64, 1, 1],
870            "Fix: Fisher geometry must be able to steer exploration away from the raw fastest sample."
871        );
872        assert!(
873            step.natural_gradient_q16[0] > step.natural_gradient_q16[1],
874            "Fix: preconditioned gradient should reflect the anisotropic Fisher block."
875        );
876    }
877
878    #[test]
879    fn softmax_weights_conserve_q16_probability_mass_across_hostile_latencies() {
880        let policy = NaturalGradientPolicy { temperature_ns: 1 };
881        for base in [0_u64, 1, 10, 1_000, u64::MAX - 2] {
882            let samples = vec![
883                TuningMeasurement {
884                    workgroup_size: [32, 1, 1],
885                    elapsed_ns: base,
886                },
887                TuningMeasurement {
888                    workgroup_size: [64, 1, 1],
889                    elapsed_ns: base.saturating_add(1),
890                },
891                TuningMeasurement {
892                    workgroup_size: [128, 1, 1],
893                    elapsed_ns: base.saturating_add(2),
894                },
895            ];
896            let step = policy
897                .suggest(&samples, &identity_fisher_q16(samples.len()))
898                .expect("Fix: hostile latency range should still produce a normalized policy");
899            let total: u32 = step.policy_weights_q16.iter().sum();
900            assert_eq!(
901                total, Q16_ONE,
902                "Fix: fixed-point policy weights must conserve probability mass for base={base}."
903            );
904        }
905    }
906
907    #[test]
908    fn rejects_empty_measurements_zero_temperature_and_bad_fisher_shape() {
909        let policy = NaturalGradientPolicy::default();
910        assert_eq!(
911            policy.suggest(&[], &[]),
912            Err(NaturalGradientTuningError::EmptyMeasurements)
913        );
914
915        let samples = measurements();
916        let zero_temp = NaturalGradientPolicy { temperature_ns: 0 };
917        assert_eq!(
918            zero_temp.suggest(&samples, &identity_fisher_q16(samples.len())),
919            Err(NaturalGradientTuningError::ZeroTemperature)
920        );
921        assert_eq!(
922            policy.suggest(&samples, &[Q16_ONE]),
923            Err(NaturalGradientTuningError::FisherMatrixShape {
924                measurements: samples.len(),
925                cells: 1,
926            })
927        );
928    }
929
930    #[test]
931    fn tuner_exposes_natural_gradient_step_surface() {
932        let tuner = Tuner::new("natural-gradient-test-adapter", Mode::OffUseDefault);
933        let samples = measurements();
934        let step = tuner
935            .natural_gradient_step(
936                &samples,
937                &identity_fisher_q16(samples.len()),
938                NaturalGradientPolicy::default(),
939            )
940            .expect("Fix: tuner natural-gradient policy surface should accept identity Fisher");
941
942        assert_eq!(step.selected_workgroup_size, [128, 1, 1]);
943    }
944
945    #[test]
946    fn measured_natural_gradient_sweep_uses_backend_timer_and_fisher_policy() {
947        let tuner = Tuner::new(
948            "measured-natural-gradient-test-adapter",
949            Mode::NaturalGradient,
950        );
951        let mut timer = StaticTimer::new();
952        let mut fisher = identity_fisher_q16(3);
953        fisher[0] = Q16_ONE * 8;
954
955        let step = tuner
956            .best_of_natural_gradient(
957                &empty_program(),
958                [[64, 1, 1], [128, 1, 1], [256, 1, 1]],
959                &mut timer,
960                &fisher,
961                NaturalGradientPolicy {
962                    temperature_ns: 4_000,
963                },
964            )
965            .expect("Fix: backend timer should succeed")
966            .expect("Fix: natural-gradient policy should accept measured candidates");
967
968        assert_eq!(
969            timer.measured,
970            vec![[64, 1, 1], [128, 1, 1], [256, 1, 1]],
971            "Fix: natural-gradient sweep must measure every supplied candidate."
972        );
973        assert_eq!(step.best_measured_workgroup_size, [128, 1, 1]);
974        assert_eq!(
975            step.selected_workgroup_size,
976            [64, 1, 1],
977            "Fix: measured natural-gradient sweep must use Fisher policy, not raw fastest-only selection."
978        );
979    }
980
981    #[test]
982    fn measured_natural_gradient_sweep_propagates_timer_failures() {
983        let tuner = Tuner::new(
984            "measured-natural-gradient-error-test-adapter",
985            Mode::NaturalGradient,
986        );
987        let mut timer = StaticTimer::failing(128);
988        let err = tuner
989            .best_of_natural_gradient(
990                &empty_program(),
991                [[64, 1, 1], [128, 1, 1], [256, 1, 1]],
992                &mut timer,
993                &identity_fisher_q16(3),
994                NaturalGradientPolicy::default(),
995            )
996            .expect_err("Fix: backend timer failures must propagate before policy update");
997
998        assert_eq!(err, "timer failed");
999        assert_eq!(
1000            timer.measured,
1001            vec![[64, 1, 1], [128, 1, 1]],
1002            "Fix: failed measurements must stop the sweep instead of producing a fake policy result."
1003        );
1004    }
1005}