trident/neural/data/
replay.rs1use rkyv::{Archive, Deserialize, Serialize};
7use std::path::{Path, PathBuf};
8
9#[derive(Archive, Serialize, Deserialize, Clone, Debug)]
11#[rkyv(derive(Debug))]
12pub struct BuildResult {
13 pub tir_hash: [u8; 32],
15 pub generated_tasm: Vec<String>,
17 pub valid: bool,
19 pub clock_cycles: Option<u64>,
21 pub compiler_cycles: u64,
23 pub fallback_used: bool,
25 pub timestamp: u64,
27 pub model_version: u32,
29}
30
31#[derive(Archive, Serialize, Deserialize, Clone, Debug)]
33#[rkyv(derive(Debug))]
34struct ReplayArchive {
35 entries: Vec<BuildResult>,
36}
37
/// Returns the replay file location used when the caller supplies no path.
///
/// The path is relative, so it resolves against the process's current
/// working directory.
fn default_replay_path() -> PathBuf {
    let location: &Path = Path::new("model/general/v2/replay.rkyv");
    location.to_path_buf()
}
42
43pub struct ReplayBuffer {
45 entries: Vec<(f64, BuildResult)>,
46 capacity: usize,
47}
48
49impl ReplayBuffer {
50 pub fn new(capacity: usize) -> Self {
52 Self {
53 entries: Vec::new(),
54 capacity,
55 }
56 }
57
58 pub fn push(&mut self, result: BuildResult) {
60 let priority = Self::compute_priority(&result);
61 self.entries.push((priority, result));
62
63 if self.entries.len() > self.capacity {
65 self.entries
66 .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
67 self.entries.truncate(self.capacity);
68 }
69 }
70
71 pub fn len(&self) -> usize {
73 self.entries.len()
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.entries.is_empty()
79 }
80
81 pub fn valid_count(&self) -> usize {
83 self.entries
84 .iter()
85 .filter(|(_, r)| r.valid && !r.fallback_used)
86 .count()
87 }
88
89 pub fn sample(&self, batch_size: usize) -> Vec<&BuildResult> {
91 self.entries
92 .iter()
93 .take(batch_size)
94 .map(|(_, r)| r)
95 .collect()
96 }
97
98 fn compute_priority(result: &BuildResult) -> f64 {
100 if !result.valid {
101 return 0.001; }
103 if result.fallback_used {
104 return 0.01;
105 }
106 let improvement = result
108 .compiler_cycles
109 .saturating_sub(result.clock_cycles.unwrap_or(result.compiler_cycles));
110 if result.compiler_cycles == 0 {
111 return 1.0;
112 }
113 1.0 + (improvement as f64 / result.compiler_cycles as f64)
114 }
115
116 pub fn save(&self, path: Option<&Path>) -> Result<(), String> {
118 let path = path.map(PathBuf::from).unwrap_or_else(default_replay_path);
119 if let Some(parent) = path.parent() {
120 std::fs::create_dir_all(parent)
121 .map_err(|e| format!("mkdir {}: {}", parent.display(), e))?;
122 }
123
124 let archive = ReplayArchive {
125 entries: self.entries.iter().map(|(_, r)| r.clone()).collect(),
126 };
127
128 let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&archive)
129 .map_err(|e| format!("rkyv serialize: {}", e))?;
130 std::fs::write(&path, &bytes).map_err(|e| format!("write {}: {}", path.display(), e))?;
131
132 Ok(())
133 }
134
135 pub fn load(capacity: usize, path: Option<&Path>) -> Result<Self, String> {
137 let path = path.map(PathBuf::from).unwrap_or_else(default_replay_path);
138 if !path.exists() {
139 return Ok(Self::new(capacity));
140 }
141
142 let bytes = std::fs::read(&path).map_err(|e| format!("read {}: {}", path.display(), e))?;
143 let archive = rkyv::from_bytes::<ReplayArchive, rkyv::rancor::Error>(&bytes)
144 .map_err(|e| format!("rkyv deserialize: {}", e))?;
145
146 let mut buf = Self::new(capacity);
147 for result in archive.entries {
148 buf.push(result);
149 }
150 Ok(buf)
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-fallback fixture with the given validity and cycle counts.
    fn make_result(valid: bool, cycles: Option<u64>, compiler: u64) -> BuildResult {
        let tir_hash = [0u8; 32];
        let generated_tasm = vec!["push 1".into()];
        BuildResult {
            tir_hash,
            generated_tasm,
            valid,
            clock_cycles: cycles,
            compiler_cycles: compiler,
            fallback_used: false,
            timestamp: 0,
            model_version: 1,
        }
    }

    /// Pushing more results than the capacity keeps only `capacity` of them.
    #[test]
    fn replay_buffer_capacity() {
        let mut buffer = ReplayBuffer::new(3);
        (0..5).for_each(|step| buffer.push(make_result(true, Some(10 - step), 10)));
        assert_eq!(buffer.len(), 3);
    }

    /// Only valid, non-fallback results are counted.
    #[test]
    fn replay_buffer_valid_count() {
        let mut buffer = ReplayBuffer::new(10);
        let fixtures = [
            make_result(true, Some(5), 10),
            make_result(false, None, 10),
            make_result(true, Some(8), 10),
        ];
        for fixture in fixtures {
            buffer.push(fixture);
        }
        assert_eq!(buffer.valid_count(), 2);
    }

    /// Saving and loading preserves the number of results and valid results.
    #[test]
    fn replay_buffer_save_load_roundtrip() {
        let scratch_dir = std::env::temp_dir().join("trident_test_replay");
        let replay_file = scratch_dir.join("test_replay.rkyv");
        // Best-effort removal of leftovers from an earlier run.
        let _ = std::fs::remove_file(&replay_file);

        let mut original = ReplayBuffer::new(10);
        original.push(make_result(true, Some(5), 10));
        original.push(make_result(false, None, 10));
        original.push(make_result(true, Some(8), 10));
        original.save(Some(&replay_file)).unwrap();

        let restored = ReplayBuffer::load(10, Some(&replay_file)).unwrap();
        assert_eq!(restored.len(), 3);
        assert_eq!(restored.valid_count(), 2);

        // Best-effort cleanup; failures here do not affect the assertions.
        let _ = std::fs::remove_file(&replay_file);
        let _ = std::fs::remove_dir(&scratch_dir);
    }

    /// Loading from a path that does not exist yields an empty buffer.
    #[test]
    fn replay_buffer_load_missing_file() {
        let missing = std::env::temp_dir().join("trident_nonexistent_replay.rkyv");
        let restored = ReplayBuffer::load(10, Some(&missing)).unwrap();
        assert_eq!(restored.len(), 0);
    }

    /// The first sampled result must be a valid build.
    #[test]
    fn replay_buffer_priority_ordering() {
        let mut buffer = ReplayBuffer::new(10);
        buffer.push(make_result(true, Some(10), 10));
        buffer.push(make_result(true, Some(5), 10));
        buffer.push(make_result(false, None, 10));

        let batch = buffer.sample(3);
        assert!(batch[0].valid);
    }
}