// vyre_driver/shape_prediction.rs

//! N8 substrate: predicted-next-shape fingerprint API.
//!
//! The async dispatch path already exists (D3 / D7), but the wait window
//! between submission and completion is dead CPU time. This module
//! owns the *prediction*: given recent dispatch fingerprints, what
//! is the most likely next dispatch? The runtime can then prefetch
//! the predicted pipeline cache key during the wait.
//!
//! Prediction rules, in order of preference:
//!
//! 1. **Repeat**  -  if the two most recent fingerprints are equal, predict
//!    the same fingerprint again (covers tight loops dispatching the same
//!    kernel).
//! 2. **Cycle of length N (N >= 2)**  -  predict the fingerprint recorded N
//!    dispatches before the upcoming one, even when only a partial second
//!    cycle has been observed (covers attention's Q, K, V, scale, softmax,
//!    attend cycle before the second full cycle completes).
//! 3. **None**  -  history too sparse or no pattern matches; the runtime skips
//!    the prefetch this iteration.
//!
//! Pure analysis with no heap allocation: the history lives in a fixed-size
//! inline ring buffer.

/// Opaque 8-word dispatch fingerprint consumed by the predictor.
///
/// Same shape as [`crate::launch::program_vsa_fingerprint_words`]
/// returns; within this module fingerprints are only ever compared for
/// equality.
pub type ShapeFingerprint = [u32; 8];

/// Maximum number of historical fingerprints retained for prediction.
///
/// Sixteen slots hold an attention-style 6-step cycle plus its repeat
/// with room to spare, while keeping the O(N²) cycle scan in
/// [`ShapeHistory::predict_next`] negligibly cheap.
pub const MAX_HISTORY: usize = 16;

/// Fixed-capacity ring buffer of the most recent dispatch fingerprints.
///
/// Once [`MAX_HISTORY`] entries are held, each new record overwrites the
/// oldest one, so prediction only ever sees the latest window.
#[derive(Debug, Clone)]
pub struct ShapeHistory {
    /// Backing storage; only `len` slots starting at `start` (wrapping) are live.
    entries: [ShapeFingerprint; MAX_HISTORY],
    /// Physical index of the oldest live entry.
    start: usize,
    /// Number of live entries; never exceeds `MAX_HISTORY`.
    len: usize,
}

impl Default for ShapeHistory {
    /// Equivalent to [`ShapeHistory::new`]: an empty history.
    fn default() -> Self {
        ShapeHistory::new()
    }
}

impl ShapeHistory {
    /// Creates an empty history. [`Self::predict_next`] returns `None`
    /// until enough fingerprints have been recorded.
    #[must_use]
    pub fn new() -> Self {
        ShapeHistory {
            entries: [[0; 8]; MAX_HISTORY],
            start: 0,
            len: 0,
        }
    }

    /// Maps a logical position (0 = oldest live entry) to its slot in the
    /// backing array.
    fn physical(&self, logical_idx: usize) -> usize {
        (self.start + logical_idx) % MAX_HISTORY
    }

    /// Appends `fingerprint` as the newest entry. When the buffer already
    /// holds [`MAX_HISTORY`] entries, the oldest one is evicted.
    pub fn record(&mut self, fingerprint: ShapeFingerprint) {
        // With a full buffer, the slot one past the newest entry wraps
        // around onto the oldest entry, which is exactly the one to evict.
        let slot = self.physical(self.len);
        self.entries[slot] = fingerprint;
        if self.len == MAX_HISTORY {
            self.start = (self.start + 1) % MAX_HISTORY;
        } else {
            self.len += 1;
        }
    }

    /// Returns the most recently recorded fingerprint, or `None` when
    /// nothing has been recorded yet.
    #[must_use]
    pub fn latest(&self) -> Option<&ShapeFingerprint> {
        let newest = self.len.checked_sub(1)?;
        Some(&self.entries[self.physical(newest)])
    }

    /// Number of fingerprints currently retained (at most [`MAX_HISTORY`]).
    #[must_use]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` while no fingerprint has been recorded.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` when `fingerprint` is still inside the retained window.
    ///
    /// Intended for backend-side prediction caches, which can evict cloned
    /// predicted programs that the bounded history can no longer predict.
    #[must_use]
    pub fn contains(&self, fingerprint: &ShapeFingerprint) -> bool {
        self.iter_oldest_first().any(|entry| entry == fingerprint)
    }

    /// Iterates over the live entries from oldest to newest.
    fn iter_oldest_first(&self) -> impl Iterator<Item = &ShapeFingerprint> + '_ {
        (0..self.len).map(move |logical_idx| &self.entries[self.physical(logical_idx)])
    }

    /// Copies out the entry at logical position `logical_idx`
    /// (0 = oldest live entry).
    fn get(&self, logical_idx: usize) -> ShapeFingerprint {
        debug_assert!(logical_idx < self.len);
        self.entries[self.physical(logical_idx)]
    }

    /// Predicts the fingerprint of the next dispatch. Returns `None` when
    /// the history is empty, too short, or shows no pattern.
    ///
    /// Rules, tried in order:
    /// 1. **Repeat**: the two newest entries are equal, so predict that
    ///    fingerprint again.
    /// 2. **Cycle**: take the smallest period `p >= 2` for which every
    ///    retained entry with a partner `p` slots earlier equals that
    ///    partner, and predict the entry `p` slots before the next one.
    ///    One overlapping pair is already enough, so `A,B,C,A` predicts `B`
    ///    and `A,B,C,A,B` predicts `C` without waiting for a full second
    ///    cycle.
    /// 3. Otherwise, no prediction.
    #[must_use]
    pub fn predict_next(&self) -> Option<ShapeFingerprint> {
        let count = self.len;
        let newest = self.get(count.checked_sub(1)?);

        if count >= 2 && self.get(count - 2) == newest {
            return Some(newest);
        }

        // Partial-cycle matching matters for prefetch: waiting for a full
        // second cycle would forfeit one dispatch worth of overlap.
        (2..count)
            .find(|&period| (period..count).all(|i| self.get(i) == self.get(i - period)))
            .map(|period| self.get(count - period))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic fingerprint that is distinct for each `seed`.
    fn fp(seed: u32) -> ShapeFingerprint {
        let mut a = [0u32; 8];
        for (i, slot) in a.iter_mut().enumerate() {
            *slot = seed.wrapping_mul(31).wrapping_add(i as u32);
        }
        a
    }

    #[test]
    fn empty_history_predicts_nothing() {
        let h = ShapeHistory::new();
        assert!(h.predict_next().is_none());
    }

    #[test]
    fn default_is_empty_like_new() {
        let h = ShapeHistory::default();
        assert!(h.is_empty());
        assert_eq!(h.len(), 0);
        assert!(h.latest().is_none());
        assert!(h.predict_next().is_none());
    }

    #[test]
    fn single_entry_history_cannot_predict() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        assert!(h.predict_next().is_none());
    }

    #[test]
    fn repeated_fingerprint_predicts_repeat() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(1));
        assert_eq!(h.predict_next(), Some(fp(1)));
    }

    #[test]
    fn repeat_takes_precedence_over_cycle() {
        // A,A,B,A,A: a 3-cycle would predict B, but the repeat rule is
        // checked first and predicts A.
        let mut h = ShapeHistory::new();
        for seed in [1, 1, 2, 1, 1] {
            h.record(fp(seed));
        }
        assert_eq!(h.predict_next(), Some(fp(1)));
    }

    #[test]
    fn two_step_cycle_is_predicted() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(1));
        h.record(fp(2));
        assert_eq!(h.predict_next(), Some(fp(1)));
    }

    #[test]
    fn three_step_cycle_is_predicted() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(3));
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(3));
        assert_eq!(h.predict_next(), Some(fp(1)));
    }

    #[test]
    fn partial_three_step_cycle_is_predicted_before_second_cycle_completes() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(3));
        h.record(fp(1));
        h.record(fp(2));
        assert_eq!(h.predict_next(), Some(fp(3)));
    }

    #[test]
    fn partial_long_cycle_prefetches_next_phase() {
        let mut h = ShapeHistory::new();
        for seed in [10, 20, 30, 40, 10, 20, 30] {
            h.record(fp(seed));
        }
        assert_eq!(h.predict_next(), Some(fp(40)));
    }

    #[test]
    fn no_pattern_means_no_prediction() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(3));
        h.record(fp(4));
        assert!(h.predict_next().is_none());
    }

    #[test]
    fn stale_prefix_blocks_cycle_prediction() {
        // Pins the documented rule that *every* retained entry must agree
        // with its lag partner. One unrelated leading dispatch therefore
        // suppresses the 2-cycle until that dispatch is evicted.
        let mut h = ShapeHistory::new();
        for seed in [9, 1, 2, 1, 2] {
            h.record(fp(seed));
        }
        assert!(h.predict_next().is_none());
    }

    #[test]
    fn history_caps_at_max_entries() {
        let mut h = ShapeHistory::new();
        for i in 0..(MAX_HISTORY + 5) {
            h.record(fp(i as u32));
        }
        assert_eq!(h.len(), MAX_HISTORY);
        // Earliest entry is fp(5), latest is fp(MAX_HISTORY+4).
        assert_eq!(h.latest(), Some(&fp((MAX_HISTORY + 4) as u32)));
        assert!(!h.contains(&fp(0)));
        assert!(h.contains(&fp(5)));
        assert!(h.contains(&fp((MAX_HISTORY + 4) as u32)));
    }

    #[test]
    fn full_history_of_distinct_entries_predicts_nothing() {
        let mut h = ShapeHistory::new();
        for i in 0..(MAX_HISTORY + 5) {
            h.record(fp(i as u32));
        }
        assert!(h.predict_next().is_none());
    }

    #[test]
    fn cycle_is_predicted_across_ring_wraparound() {
        // Recording more than MAX_HISTORY entries advances `start`, so this
        // exercises the logical-to-physical index mapping in predict_next.
        let mut h = ShapeHistory::new();
        let total = MAX_HISTORY + 3;
        for i in 0..total {
            h.record(fp((i % 3) as u32));
        }
        assert_eq!(h.len(), MAX_HISTORY);
        assert_eq!(h.predict_next(), Some(fp((total % 3) as u32)));
    }
}