Skip to main content

trident/neural/data/
replay.rs

1//! Replay buffer for online learning (Stage 3).
2//!
3//! Stores build results with prioritized experience replay.
4//! Persistence via rkyv zero-copy archives.
5
6use rkyv::{Archive, Deserialize, Serialize};
7use std::path::{Path, PathBuf};
8
9/// Result of a neural compilation attempt.
10#[derive(Archive, Serialize, Deserialize, Clone, Debug)]
11#[rkyv(derive(Debug))]
12pub struct BuildResult {
13    /// Poseidon2 CID of the TIR input.
14    pub tir_hash: [u8; 32],
15    /// Generated TASM instructions.
16    pub generated_tasm: Vec<String>,
17    /// Whether the output passed stack verification.
18    pub valid: bool,
19    /// Clock cycles if valid (None if invalid).
20    pub clock_cycles: Option<u64>,
21    /// Compiler baseline cycles.
22    pub compiler_cycles: u64,
23    /// Whether fallback to compiler was used.
24    pub fallback_used: bool,
25    /// Unix timestamp.
26    pub timestamp: u64,
27    /// Model checkpoint version.
28    pub model_version: u32,
29}
30
31/// Serializable wrapper for the replay buffer entries.
32#[derive(Archive, Serialize, Deserialize, Clone, Debug)]
33#[rkyv(derive(Debug))]
34struct ReplayArchive {
35    entries: Vec<BuildResult>,
36}
37
/// Location used for replay persistence when the caller passes no explicit path.
///
/// The path is relative, so it resolves against the current working directory.
fn default_replay_path() -> PathBuf {
    Path::new("model/general/v2/replay.rkyv").to_path_buf()
}
42
43/// Priority-based replay buffer.
44pub struct ReplayBuffer {
45    entries: Vec<(f64, BuildResult)>,
46    capacity: usize,
47}
48
49impl ReplayBuffer {
50    /// Create a new replay buffer with the given capacity.
51    pub fn new(capacity: usize) -> Self {
52        Self {
53            entries: Vec::new(),
54            capacity,
55        }
56    }
57
58    /// Add a build result with computed priority.
59    pub fn push(&mut self, result: BuildResult) {
60        let priority = Self::compute_priority(&result);
61        self.entries.push((priority, result));
62
63        // If over capacity, remove lowest-priority entry
64        if self.entries.len() > self.capacity {
65            self.entries
66                .sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
67            self.entries.truncate(self.capacity);
68        }
69    }
70
71    /// Number of entries.
72    pub fn len(&self) -> usize {
73        self.entries.len()
74    }
75
76    /// Whether the buffer is empty.
77    pub fn is_empty(&self) -> bool {
78        self.entries.is_empty()
79    }
80
81    /// Count of valid (non-fallback) results.
82    pub fn valid_count(&self) -> usize {
83        self.entries
84            .iter()
85            .filter(|(_, r)| r.valid && !r.fallback_used)
86            .count()
87    }
88
89    /// Sample a batch of entries (highest priority first).
90    pub fn sample(&self, batch_size: usize) -> Vec<&BuildResult> {
91        self.entries
92            .iter()
93            .take(batch_size)
94            .map(|(_, r)| r)
95            .collect()
96    }
97
98    /// Compute priority for a build result.
99    fn compute_priority(result: &BuildResult) -> f64 {
100        if !result.valid {
101            return 0.001; // Low priority for invalid results
102        }
103        if result.fallback_used {
104            return 0.01;
105        }
106        // Priority = reward = improvement ratio
107        let improvement = result
108            .compiler_cycles
109            .saturating_sub(result.clock_cycles.unwrap_or(result.compiler_cycles));
110        if result.compiler_cycles == 0 {
111            return 1.0;
112        }
113        1.0 + (improvement as f64 / result.compiler_cycles as f64)
114    }
115
116    /// Save replay buffer to disk as rkyv archive.
117    pub fn save(&self, path: Option<&Path>) -> Result<(), String> {
118        let path = path.map(PathBuf::from).unwrap_or_else(default_replay_path);
119        if let Some(parent) = path.parent() {
120            std::fs::create_dir_all(parent)
121                .map_err(|e| format!("mkdir {}: {}", parent.display(), e))?;
122        }
123
124        let archive = ReplayArchive {
125            entries: self.entries.iter().map(|(_, r)| r.clone()).collect(),
126        };
127
128        let bytes = rkyv::to_bytes::<rkyv::rancor::Error>(&archive)
129            .map_err(|e| format!("rkyv serialize: {}", e))?;
130        std::fs::write(&path, &bytes).map_err(|e| format!("write {}: {}", path.display(), e))?;
131
132        Ok(())
133    }
134
135    /// Load replay buffer from disk. Returns empty buffer if file doesn't exist.
136    pub fn load(capacity: usize, path: Option<&Path>) -> Result<Self, String> {
137        let path = path.map(PathBuf::from).unwrap_or_else(default_replay_path);
138        if !path.exists() {
139            return Ok(Self::new(capacity));
140        }
141
142        let bytes = std::fs::read(&path).map_err(|e| format!("read {}: {}", path.display(), e))?;
143        let archive = rkyv::from_bytes::<ReplayArchive, rkyv::rancor::Error>(&bytes)
144            .map_err(|e| format!("rkyv deserialize: {}", e))?;
145
146        let mut buf = Self::new(capacity);
147        for result in archive.entries {
148            buf.push(result);
149        }
150        Ok(buf)
151    }
152}
153
154#[cfg(test)]
155mod tests {
156    use super::*;
157
158    fn make_result(valid: bool, cycles: Option<u64>, compiler: u64) -> BuildResult {
159        BuildResult {
160            tir_hash: [0u8; 32],
161            generated_tasm: vec!["push 1".into()],
162            valid,
163            clock_cycles: cycles,
164            compiler_cycles: compiler,
165            fallback_used: false,
166            timestamp: 0,
167            model_version: 1,
168        }
169    }
170
171    #[test]
172    fn replay_buffer_capacity() {
173        let mut buf = ReplayBuffer::new(3);
174        for i in 0..5 {
175            buf.push(make_result(true, Some(10 - i), 10));
176        }
177        assert_eq!(buf.len(), 3);
178    }
179
180    #[test]
181    fn replay_buffer_valid_count() {
182        let mut buf = ReplayBuffer::new(10);
183        buf.push(make_result(true, Some(5), 10));
184        buf.push(make_result(false, None, 10));
185        buf.push(make_result(true, Some(8), 10));
186        assert_eq!(buf.valid_count(), 2);
187    }
188
189    #[test]
190    fn replay_buffer_save_load_roundtrip() {
191        let dir = std::env::temp_dir().join("trident_test_replay");
192        let path = dir.join("test_replay.rkyv");
193        let _ = std::fs::remove_file(&path);
194
195        let mut buf = ReplayBuffer::new(10);
196        buf.push(make_result(true, Some(5), 10));
197        buf.push(make_result(false, None, 10));
198        buf.push(make_result(true, Some(8), 10));
199        buf.save(Some(&path)).unwrap();
200
201        let loaded = ReplayBuffer::load(10, Some(&path)).unwrap();
202        assert_eq!(loaded.len(), 3);
203        assert_eq!(loaded.valid_count(), 2);
204
205        let _ = std::fs::remove_file(&path);
206        let _ = std::fs::remove_dir(&dir);
207    }
208
209    #[test]
210    fn replay_buffer_load_missing_file() {
211        let path = std::env::temp_dir().join("trident_nonexistent_replay.rkyv");
212        let loaded = ReplayBuffer::load(10, Some(&path)).unwrap();
213        assert_eq!(loaded.len(), 0);
214    }
215
216    #[test]
217    fn replay_buffer_priority_ordering() {
218        let mut buf = ReplayBuffer::new(10);
219        buf.push(make_result(true, Some(10), 10)); // no improvement
220        buf.push(make_result(true, Some(5), 10)); // 50% improvement
221        buf.push(make_result(false, None, 10)); // invalid
222        let samples = buf.sample(3);
223        // Highest priority first (after sort on push)
224        assert!(samples[0].valid);
225    }
226}