// vyre_driver/shape_prediction.rs

/// Fixed-size fingerprint of a shape, stored as eight `u32` words.
///
/// Within this module fingerprints are only compared for equality; how the
/// eight words are derived from a shape is decided by the caller (not visible
/// here).
pub type ShapeFingerprint = [u32; 8];
26#[derive(Debug, Clone)]
29pub struct ShapeHistory {
30 entries: [ShapeFingerprint; MAX_HISTORY],
31 start: usize,
32 len: usize,
33}
34
/// Maximum number of fingerprints retained by a [`ShapeHistory`].
///
/// Once this many entries are stored, each new recording overwrites the
/// oldest one.
pub const MAX_HISTORY: usize = 16;

// The ring buffer indexes with `% MAX_HISTORY`, so a zero capacity would
// panic at runtime (remainder by zero); reject it at compile time instead.
const _: () = assert!(MAX_HISTORY > 0);
41impl Default for ShapeHistory {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl ShapeHistory {
48 #[must_use]
50 pub fn new() -> Self {
51 Self {
52 entries: [[0u32; 8]; MAX_HISTORY],
53 start: 0,
54 len: 0,
55 }
56 }
57
58 pub fn record(&mut self, fingerprint: ShapeFingerprint) {
61 if self.len < MAX_HISTORY {
62 let idx = (self.start + self.len) % MAX_HISTORY;
63 self.entries[idx] = fingerprint;
64 self.len += 1;
65 } else {
66 self.entries[self.start] = fingerprint;
67 self.start = (self.start + 1) % MAX_HISTORY;
68 }
69 }
70
71 #[must_use]
73 pub fn latest(&self) -> Option<&ShapeFingerprint> {
74 if self.len == 0 {
75 return None;
76 }
77 Some(&self.entries[(self.start + self.len - 1) % MAX_HISTORY])
78 }
79
80 #[must_use]
82 pub fn len(&self) -> usize {
83 self.len
84 }
85
86 #[must_use]
88 pub fn is_empty(&self) -> bool {
89 self.len == 0
90 }
91
92 #[must_use]
97 pub fn contains(&self, fingerprint: &ShapeFingerprint) -> bool {
98 (0..self.len).any(|idx| self.get(idx) == *fingerprint)
99 }
100
101 fn get(&self, logical_idx: usize) -> ShapeFingerprint {
102 debug_assert!(logical_idx < self.len);
103 self.entries[(self.start + logical_idx) % MAX_HISTORY]
104 }
105
106 #[must_use]
117 pub fn predict_next(&self) -> Option<ShapeFingerprint> {
118 let n = self.len;
119 if n == 0 {
120 return None;
121 }
122 if n >= 2 && self.get(n - 1) == self.get(n - 2) {
124 return Some(self.get(n - 1));
125 }
126 for cycle in 2..n {
130 let mut matches = true;
131 for i in cycle..n {
132 if self.get(i) != self.get(i - cycle) {
133 matches = false;
134 break;
135 }
136 }
137 if matches {
138 return Some(self.get(n - cycle));
139 }
140 }
141 None
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a distinct fingerprint per `seed`. Multiplying by the odd
    /// constant 31 is a bijection on `u32`, so different seeds never collide.
    fn fp(seed: u32) -> ShapeFingerprint {
        let mut a = [0u32; 8];
        for (i, slot) in a.iter_mut().enumerate() {
            *slot = seed.wrapping_mul(31).wrapping_add(i as u32);
        }
        a
    }

    #[test]
    fn empty_history_predicts_nothing() {
        let h = ShapeHistory::new();
        assert!(h.predict_next().is_none());
    }

    #[test]
    fn single_entry_history_cannot_predict() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        assert!(h.predict_next().is_none());
    }

    #[test]
    fn repeated_fingerprint_predicts_repeat() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(1));
        assert_eq!(h.predict_next(), Some(fp(1)));
    }

    #[test]
    fn two_step_cycle_is_predicted() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(1));
        h.record(fp(2));
        assert_eq!(h.predict_next(), Some(fp(1)));
    }

    #[test]
    fn three_step_cycle_is_predicted() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(3));
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(3));
        assert_eq!(h.predict_next(), Some(fp(1)));
    }

    #[test]
    fn partial_three_step_cycle_is_predicted_before_second_cycle_completes() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(3));
        h.record(fp(1));
        h.record(fp(2));
        assert_eq!(h.predict_next(), Some(fp(3)));
    }

    #[test]
    fn partial_long_cycle_prefetches_next_phase() {
        let mut h = ShapeHistory::new();
        for seed in [10, 20, 30, 40, 10, 20, 30] {
            h.record(fp(seed));
        }
        assert_eq!(h.predict_next(), Some(fp(40)));
    }

    #[test]
    fn no_pattern_means_no_prediction() {
        let mut h = ShapeHistory::new();
        h.record(fp(1));
        h.record(fp(2));
        h.record(fp(3));
        h.record(fp(4));
        assert!(h.predict_next().is_none());
    }

    #[test]
    fn history_caps_at_max_entries() {
        let mut h = ShapeHistory::new();
        for i in 0..(MAX_HISTORY + 5) {
            h.record(fp(i as u32));
        }
        assert_eq!(h.len(), MAX_HISTORY);
        assert_eq!(h.latest(), Some(&fp((MAX_HISTORY + 4) as u32)));
        // The five oldest recordings (seeds 0..=4) have been overwritten.
        assert!(!h.contains(&fp(0)));
        assert!(h.contains(&fp(5)));
        assert!(h.contains(&fp((MAX_HISTORY + 4) as u32)));
    }

    #[test]
    fn latest_and_is_empty_track_recordings() {
        let mut h = ShapeHistory::new();
        assert!(h.is_empty());
        assert_eq!(h.latest(), None);
        h.record(fp(3));
        h.record(fp(4));
        assert!(!h.is_empty());
        assert_eq!(h.len(), 2);
        assert_eq!(h.latest(), Some(&fp(4)));
    }

    #[test]
    fn cycle_is_detected_after_wraparound() {
        let mut h = ShapeHistory::new();
        // Alternate 1, 2, 1, 2, ... past capacity so the ring buffer wraps;
        // the last recording (index MAX_HISTORY + 2, even) is fp(1).
        for i in 0..(MAX_HISTORY + 3) {
            h.record(fp(if i % 2 == 0 { 1 } else { 2 }));
        }
        assert_eq!(h.latest(), Some(&fp(1)));
        assert_eq!(h.predict_next(), Some(fp(2)));
    }

    #[test]
    fn immediate_repeat_takes_precedence_over_longer_cycle() {
        // Pins current behavior: in an A, A, B cycle the repeat rule fires
        // after the A, A phase, so A (not B) is predicted.
        let mut h = ShapeHistory::new();
        for seed in [1, 1, 2, 1, 1] {
            h.record(fp(seed));
        }
        assert_eq!(h.predict_next(), Some(fp(1)));
    }
}