split_brain_harness/
tool_memory.rs1use std::path::Path;
10
11use serde::{Deserialize, Serialize};
12
13use crate::capability::{CapabilityMemoryRecord, CapabilityRequest, ToolMetrics};
14
/// Aggregate execution statistics accumulated for one stored pattern.
///
/// Values are updated through [`PatternMetrics::record`]; the derived `Default` starts
/// every counter at zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PatternMetrics {
    /// Number of runs recorded so far.
    pub runs: u64,
    /// Number of recorded runs that reported success.
    pub successes: u64,
    /// Sum of the reported runtimes of all recorded runs, in milliseconds.
    pub total_runtime_ms: u64,
    /// Sum of the reported input sizes of all recorded runs, in bytes.
    pub total_input_bytes: usize,
    /// Sum of the reported output sizes of all recorded runs, in bytes.
    pub total_output_bytes: usize,
    /// Length of the current failure streak; a successful run resets it to zero.
    pub consecutive_failures: u64,
}
26
27impl PatternMetrics {
28 pub fn record(&mut self, metrics: &ToolMetrics) {
29 self.runs += 1;
30 if metrics.success {
31 self.successes += 1;
32 self.consecutive_failures = 0;
33 } else {
34 self.consecutive_failures += 1;
35 }
36 self.total_runtime_ms += metrics.runtime_ms;
37 self.total_input_bytes += metrics.input_bytes;
38 self.total_output_bytes += metrics.output_bytes;
39 }
40
41 pub fn success_rate(&self) -> f64 {
42 if self.runs == 0 {
43 return 0.0;
44 }
45 self.successes as f64 / self.runs as f64
46 }
47
48 pub fn avg_runtime_ms(&self) -> f64 {
49 if self.runs == 0 {
50 return 0.0;
51 }
52 self.total_runtime_ms as f64 / self.runs as f64
53 }
54}
55
/// One stored memory item: a capability record paired with its accumulated run metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
    /// The remembered capability record; its `problem_signature` is the key matched by
    /// [`CapabilityMemory::lookup`] and [`CapabilityMemory::upsert`].
    pub record: CapabilityMemoryRecord,
    /// Statistics accumulated over the runs recorded for this entry.
    pub metrics: PatternMetrics,
}
62
/// Collection of [`MemoryEntry`] values that can be saved to and loaded from a JSON file.
#[derive(Debug, Default)]
pub struct CapabilityMemory {
    /// Stored entries in insertion order. `upsert` updates an existing entry with the
    /// same `problem_signature` instead of appending a duplicate.
    entries: Vec<MemoryEntry>,
}
69
70impl CapabilityMemory {
71 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub fn lookup(&self, signature: &str) -> Option<&MemoryEntry> {
77 self.entries
78 .iter()
79 .find(|e| e.record.problem_signature == signature)
80 }
81
82 pub fn upsert(&mut self, record: CapabilityMemoryRecord, metrics: &ToolMetrics) {
84 if let Some(entry) = self
85 .entries
86 .iter_mut()
87 .find(|e| e.record.problem_signature == record.problem_signature)
88 {
89 entry.record = record;
90 entry.metrics.record(metrics);
91 } else {
92 let mut pm = PatternMetrics::default();
93 pm.record(metrics);
94 self.entries.push(MemoryEntry {
95 record,
96 metrics: pm,
97 });
98 }
99 }
100
101 pub fn save(&self, path: &Path) -> Result<(), String> {
103 let json = serde_json::to_string_pretty(&self.entries)
104 .map_err(|e| format!("serialize error: {e}"))?;
105 std::fs::write(path, json)
106 .map_err(|e| format!("failed to write memory to {}: {e}", path.display()))?;
107 Ok(())
108 }
109
110 pub fn load(path: &Path) -> Result<Self, String> {
113 if !path.exists() {
114 return Ok(Self::new());
115 }
116 let json = std::fs::read_to_string(path)
117 .map_err(|e| format!("failed to read memory from {}: {e}", path.display()))?;
118 let entries: Vec<MemoryEntry> = serde_json::from_str(&json)
119 .map_err(|e| format!("failed to parse memory file {}: {e}", path.display()))?;
120 Ok(Self { entries })
121 }
122
123 pub fn len(&self) -> usize {
125 self.entries.len()
126 }
127
128 pub fn is_empty(&self) -> bool {
129 self.entries.is_empty()
130 }
131
132 pub fn derive_signature(req: &CapabilityRequest) -> String {
136 let cap = req.capability.to_lowercase().replace(' ', "_");
137 let inp = shape_token(&req.input_contract);
138 let out = shape_token(&req.output_contract);
139 format!("{cap}:{inp}:{out}")
140 }
141}
142
/// Condenses a free-text contract into a short token for use inside a signature.
///
/// Only the first three whitespace-separated words are considered. Each is lowercased
/// and stripped of leading and trailing non-alphanumeric characters; words that end up
/// empty are dropped, and the survivors are joined with `_`.
fn shape_token(contract: &str) -> String {
    let mut parts: Vec<String> = Vec::new();
    for word in contract.split_whitespace().take(3) {
        let lowered = word.to_lowercase();
        let core = lowered.trim_matches(|ch: char| !ch.is_alphanumeric());
        if !core.is_empty() {
            parts.push(core.to_string());
        }
    }
    parts.join("_")
}
157
/// Unit tests for pattern metrics, the capability memory, and signature derivation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::capability::CapabilityConstraints;

    /// Builds a record with the given signature and fixed placeholder values elsewhere.
    fn make_record(sig: &str) -> CapabilityMemoryRecord {
        CapabilityMemoryRecord {
            problem_signature: sig.into(),
            solution_pattern: "mock".into(),
            input_shape: "utf8_lines".into(),
            output_shape: "json_counts".into(),
            constraints: CapabilityConstraints::default(),
        }
    }

    /// Metrics for a successful run: 10 ms, 100 bytes in, 50 bytes out.
    fn ok_metrics() -> ToolMetrics {
        ToolMetrics {
            runtime_ms: 10,
            input_bytes: 100,
            output_bytes: 50,
            success: true,
        }
    }

    #[test]
    fn new_memory_is_empty() {
        let mem = CapabilityMemory::new();
        assert!(mem.is_empty());
        assert_eq!(mem.len(), 0);
    }

    #[test]
    fn lookup_returns_none_when_empty() {
        let mem = CapabilityMemory::new();
        assert!(mem.lookup("anything").is_none());
    }

    #[test]
    fn upsert_then_lookup() {
        let mut mem = CapabilityMemory::new();
        mem.upsert(make_record("test:sig"), &ok_metrics());
        assert!(mem.lookup("test:sig").is_some());
    }

    #[test]
    fn upsert_accumulates_metrics() {
        let mut mem = CapabilityMemory::new();
        mem.upsert(make_record("sig"), &ok_metrics());
        mem.upsert(make_record("sig"), &ok_metrics());
        let entry = mem.lookup("sig").unwrap();
        assert_eq!(entry.metrics.runs, 2);
        assert_eq!(entry.metrics.successes, 2);
        assert_eq!(entry.metrics.total_runtime_ms, 20);
    }

    #[test]
    fn upsert_different_sigs_stored_separately() {
        let mut mem = CapabilityMemory::new();
        mem.upsert(make_record("a"), &ok_metrics());
        mem.upsert(make_record("b"), &ok_metrics());
        assert_eq!(mem.len(), 2);
    }

    #[test]
    fn success_rate_correct() {
        let mut pm = PatternMetrics::default();
        pm.record(&ToolMetrics {
            success: true,
            ..Default::default()
        });
        pm.record(&ToolMetrics {
            success: false,
            ..Default::default()
        });
        assert!((pm.success_rate() - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn success_resets_consecutive_failures() {
        let fail = ToolMetrics {
            success: false,
            ..Default::default()
        };
        let mut pm = PatternMetrics::default();
        pm.record(&fail);
        pm.record(&fail);
        assert_eq!(pm.consecutive_failures, 2);
        pm.record(&ok_metrics());
        assert_eq!(pm.consecutive_failures, 0);
        assert_eq!(pm.runs, 3);
        assert_eq!(pm.successes, 1);
    }

    #[test]
    fn avg_runtime_ms_handles_zero_and_nonzero_runs() {
        let mut pm = PatternMetrics::default();
        // No runs yet: the average is defined as zero rather than dividing by zero.
        assert!(pm.avg_runtime_ms().abs() < f64::EPSILON);
        pm.record(&ok_metrics());
        pm.record(&ok_metrics());
        assert!((pm.avg_runtime_ms() - 10.0).abs() < f64::EPSILON);
    }

    #[test]
    fn derive_signature_is_stable() {
        let req = CapabilityRequest {
            kind: "capability_request".into(),
            capability: "stream_parse_logs".into(),
            input_contract: "UTF-8 log lines from stdin".into(),
            output_contract: "JSON array of matching events".into(),
            constraints: CapabilityConstraints::default(),
            reason: "test".into(),
        };
        let s1 = CapabilityMemory::derive_signature(&req);
        let s2 = CapabilityMemory::derive_signature(&req);
        assert_eq!(s1, s2);
        assert!(s1.starts_with("stream_parse_logs:"));
        // Pin the exact format: signatures are persisted, so accidental changes to the
        // derivation would orphan previously stored entries.
        assert_eq!(s1, "stream_parse_logs:utf-8_log_lines:json_array_of");
    }

    #[test]
    fn derive_signature_different_contracts_differ() {
        let req_a = CapabilityRequest {
            kind: "capability_request".into(),
            capability: "parse".into(),
            input_contract: "utf8 text".into(),
            output_contract: "json counts".into(),
            constraints: CapabilityConstraints::default(),
            reason: "r".into(),
        };
        let req_b = CapabilityRequest {
            input_contract: "binary blob".into(),
            ..req_a.clone()
        };
        assert_ne!(
            CapabilityMemory::derive_signature(&req_a),
            CapabilityMemory::derive_signature(&req_b)
        );
    }

    #[test]
    fn save_and_load_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("memory.json");

        let mut mem = CapabilityMemory::new();
        mem.upsert(make_record("word_count:utf8:json"), &ok_metrics());
        mem.upsert(
            make_record("word_count:utf8:json"),
            &ToolMetrics {
                success: false,
                ..Default::default()
            },
        );
        mem.save(&path).unwrap();

        let loaded = CapabilityMemory::load(&path).unwrap();
        assert_eq!(loaded.len(), 1);
        let entry = loaded.lookup("word_count:utf8:json").unwrap();
        assert_eq!(entry.metrics.runs, 2);
        assert_eq!(entry.metrics.successes, 1);
        assert_eq!(entry.metrics.consecutive_failures, 1);
    }

    #[test]
    fn load_nonexistent_path_returns_empty() {
        // A fresh temporary directory guarantees the file is absent, without relying on
        // the state (or existence) of a shared location such as `/tmp`.
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("does-not-exist.json");
        let mem = CapabilityMemory::load(&path).unwrap();
        assert!(mem.is_empty());
    }

    #[test]
    fn load_corrupt_file_returns_err() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("bad.json");
        std::fs::write(&path, b"not valid json [[{{").unwrap();
        assert!(CapabilityMemory::load(&path).is_err());
    }

    #[test]
    fn save_preserves_consecutive_failures() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("mem.json");

        let mut mem = CapabilityMemory::new();
        let fail = ToolMetrics {
            success: false,
            ..Default::default()
        };
        mem.upsert(make_record("sig"), &fail);
        mem.upsert(make_record("sig"), &fail);
        mem.save(&path).unwrap();

        let loaded = CapabilityMemory::load(&path).unwrap();
        assert_eq!(
            loaded.lookup("sig").unwrap().metrics.consecutive_failures,
            2
        );
    }
}