1use serde::{Deserialize, Serialize};
11use std::path::Path;
12
13#[derive(Debug, Clone)]
15pub struct BenchmarkTask {
16 pub name: &'static str,
18 pub prompt: &'static str,
20 pub max_steps: usize,
22 pub verify: fn(&BenchmarkResult) -> f64,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct BenchmarkResult {
29 pub name: String,
30 pub steps: usize,
31 pub completed: bool,
32 pub tool_errors: usize,
33 pub loop_warnings: usize,
34 pub output: String,
36 pub score: f64,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct BenchmarkReport {
43 pub timestamp: u64,
44 pub commit: String,
45 pub results: Vec<BenchmarkResult>,
46 pub avg_score: f64,
48 pub std_dev: f64,
50}
51
52const TASK_QA: BenchmarkTask = BenchmarkTask {
58 name: "qa_simple",
59 prompt: "What is the capital of France? Answer with the finish tool.",
60 max_steps: 3,
61 verify: verify_qa,
62};
63
64fn verify_qa(r: &BenchmarkResult) -> f64 {
65 if !r.completed {
66 return 0.0;
67 }
68 let has_paris = r.output.to_lowercase().contains("paris");
69 let efficiency = if r.steps <= 1 {
70 1.0
71 } else {
72 0.8 / r.steps as f64
73 };
74 if has_paris {
75 (0.7 + efficiency * 0.3).min(1.0)
76 } else {
77 0.1
78 }
79}
80
81const TASK_READ: BenchmarkTask = BenchmarkTask {
83 name: "read_file",
84 prompt: "Read the file Cargo.toml in the current directory and tell me the package name. Use finish tool with the name.",
85 max_steps: 5,
86 verify: verify_read,
87};
88
89fn verify_read(r: &BenchmarkResult) -> f64 {
90 if !r.completed {
91 return 0.0;
92 }
93 let output_lower = r.output.to_lowercase();
95 let found_name = output_lower.contains("rust-code")
96 || output_lower.contains("sgr-agent")
97 || output_lower.contains("package");
98 let efficiency = (3.0 / r.steps.max(1) as f64).min(1.0);
99 if found_name {
100 0.6 + efficiency * 0.4
101 } else if r.completed {
102 0.3
103 } else {
104 0.0
105 }
106}
107
108const TASK_SEARCH: BenchmarkTask = BenchmarkTask {
110 name: "code_search",
111 prompt: "Search for the function 'parse_spec' in the codebase and tell me which file it's in. Use finish tool.",
112 max_steps: 8,
113 verify: verify_search,
114};
115
116fn verify_search(r: &BenchmarkResult) -> f64 {
117 if !r.completed {
118 return 0.0;
119 }
120 let output_lower = r.output.to_lowercase();
121 let found = output_lower.contains("openapi")
122 || output_lower.contains("spec.rs")
123 || output_lower.contains("parse_spec");
124 let no_errors = r.tool_errors == 0;
125 let efficiency = (5.0 / r.steps.max(1) as f64).min(1.0);
126 let mut score = 0.0;
127 if found {
128 score += 0.5;
129 }
130 if no_errors {
131 score += 0.2;
132 }
133 score += efficiency * 0.3;
134 score.min(1.0)
135}
136
137const TASK_MULTI: BenchmarkTask = BenchmarkTask {
139 name: "multi_step",
140 prompt: "Read crates/sgr-agent/src/lib.rs, count how many pub mod declarations it has, and answer with the count using finish tool.",
141 max_steps: 10,
142 verify: verify_multi,
143};
144
145fn verify_multi(r: &BenchmarkResult) -> f64 {
146 if !r.completed {
147 return 0.0;
148 }
149 let has_number = r.output.chars().any(|c| c.is_ascii_digit());
151 let no_loops = r.loop_warnings == 0;
152 let efficiency = (4.0 / r.steps.max(1) as f64).min(1.0);
153 let mut score = 0.0;
154 if has_number {
155 score += 0.5;
156 }
157 if no_loops {
158 score += 0.2;
159 }
160 score += efficiency * 0.3;
161 score.min(1.0)
162}
163
164const TASK_GIT: BenchmarkTask = BenchmarkTask {
166 name: "git_status",
167 prompt: "Check git status of this repo. Tell me which branch we're on and if there are uncommitted changes. Use finish tool.",
168 max_steps: 5,
169 verify: verify_git,
170};
171
172fn verify_git(r: &BenchmarkResult) -> f64 {
173 if !r.completed {
174 return 0.0;
175 }
176 let output_lower = r.output.to_lowercase();
177 let has_branch = output_lower.contains("master")
178 || output_lower.contains("main")
179 || output_lower.contains("branch");
180 let has_status = output_lower.contains("clean")
181 || output_lower.contains("uncommitted")
182 || output_lower.contains("modified")
183 || output_lower.contains("changes");
184 let efficiency = (3.0 / r.steps.max(1) as f64).min(1.0);
185 let mut score = 0.0;
186 if has_branch {
187 score += 0.35;
188 }
189 if has_status {
190 score += 0.35;
191 }
192 score += efficiency * 0.3;
193 score.min(1.0)
194}
195
196pub fn all_tasks() -> Vec<BenchmarkTask> {
202 vec![TASK_QA, TASK_READ, TASK_SEARCH, TASK_MULTI, TASK_GIT]
203}
204
205pub fn compute_report(results: Vec<BenchmarkResult>, commit: &str) -> BenchmarkReport {
207 let n = results.len() as f64;
208 let avg = if n > 0.0 {
209 results.iter().map(|r| r.score).sum::<f64>() / n
210 } else {
211 0.0
212 };
213 let variance = if n > 1.0 {
214 results.iter().map(|r| (r.score - avg).powi(2)).sum::<f64>() / (n - 1.0)
215 } else {
216 0.0
217 };
218 let std_dev = variance.sqrt();
219 let ts = std::time::SystemTime::now()
220 .duration_since(std::time::UNIX_EPOCH)
221 .unwrap_or_default()
222 .as_secs();
223
224 BenchmarkReport {
225 timestamp: ts,
226 commit: commit.to_string(),
227 results,
228 avg_score: avg,
229 std_dev,
230 }
231}
232
233pub fn format_report(report: &BenchmarkReport) -> String {
235 let mut out = format!(
236 "## Benchmark Report\n\n\
237 Commit: {} | Score: {:.3} ± {:.3}\n\n\
238 | Task | Steps | Errors | Score | Status |\n\
239 |------|-------|--------|-------|--------|\n",
240 report.commit, report.avg_score, report.std_dev,
241 );
242 for r in &report.results {
243 let status = if r.score >= 0.8 {
244 "✓"
245 } else if r.score >= 0.5 {
246 "~"
247 } else {
248 "✗"
249 };
250 out.push_str(&format!(
251 "| {} | {} | {} | {:.2} | {} |\n",
252 r.name, r.steps, r.tool_errors, r.score, status,
253 ));
254 }
255 out.push_str(&format!(
256 "\n**Average: {:.3} ± {:.3}**\n",
257 report.avg_score, report.std_dev,
258 ));
259 out
260}
261
262pub fn log_benchmark(agent_home: &str, report: &BenchmarkReport) -> Result<(), String> {
264 let path = Path::new(agent_home).join("benchmark.jsonl");
265 let line = serde_json::to_string(report).map_err(|e| format!("serialize: {}", e))?;
266 use std::io::Write;
267 let mut f = std::fs::OpenOptions::new()
268 .create(true)
269 .append(true)
270 .open(&path)
271 .map_err(|e| format!("open: {}", e))?;
272 writeln!(f, "{}", line).map_err(|e| format!("write: {}", e))?;
273 Ok(())
274}
275
276pub fn load_benchmarks(agent_home: &str) -> Vec<BenchmarkReport> {
278 let path = Path::new(agent_home).join("benchmark.jsonl");
279 let content = match std::fs::read_to_string(&path) {
280 Ok(c) => c,
281 Err(_) => return Vec::new(),
282 };
283 content
284 .lines()
285 .filter(|l| !l.trim().is_empty())
286 .filter_map(|l| serde_json::from_str(l).ok())
287 .collect()
288}
289
290pub fn compare(before: &BenchmarkReport, after: &BenchmarkReport) -> &'static str {
292 if after.avg_score > before.avg_score + before.std_dev * 0.5 {
293 "keep" } else if after.avg_score < before.avg_score - before.std_dev * 0.5 {
295 "discard" } else {
297 "neutral" }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a `BenchmarkResult` with zero loop warnings and the other fields as given.
    fn result(
        name: &str,
        steps: usize,
        completed: bool,
        tool_errors: usize,
        output: &str,
        score: f64,
    ) -> BenchmarkResult {
        BenchmarkResult {
            name: name.into(),
            steps,
            completed,
            tool_errors,
            loop_warnings: 0,
            output: output.into(),
            score,
        }
    }

    /// Fixture: a `BenchmarkReport` without per-task results.
    fn report(timestamp: u64, commit: &str, avg_score: f64, std_dev: f64) -> BenchmarkReport {
        BenchmarkReport {
            timestamp,
            commit: commit.into(),
            results: Vec::new(),
            avg_score,
            std_dev,
        }
    }

    #[test]
    fn all_tasks_has_five() {
        assert_eq!(all_tasks().len(), 5);
    }

    #[test]
    fn verify_qa_correct() {
        let r = result("qa", 1, true, 0, "The capital of France is Paris.", 0.0);
        let s = verify_qa(&r);
        assert!(
            s > 0.9,
            "correct answer in 1 step should score >0.9, got {}",
            s
        );
    }

    #[test]
    fn verify_qa_wrong() {
        let r = result("qa", 1, true, 0, "I don't know", 0.0);
        assert!(verify_qa(&r) < 0.5);
    }

    #[test]
    fn verify_qa_not_completed() {
        let r = result("qa", 3, false, 1, "", 0.0);
        assert_eq!(verify_qa(&r), 0.0);
    }

    #[test]
    fn compute_report_avg_and_stddev() {
        let results = vec![
            result("a", 1, true, 0, "", 0.8),
            result("b", 2, true, 0, "", 0.6),
        ];
        let report = compute_report(results, "abc123");
        assert!((report.avg_score - 0.7).abs() < 0.001);
        assert!(report.std_dev > 0.0);
    }

    #[test]
    fn compare_improvement() {
        let before = report(0, "a", 0.5, 0.1);
        let after = report(1, "b", 0.7, 0.1);
        assert_eq!(compare(&before, &after), "keep");
    }

    #[test]
    fn compare_regression() {
        let before = report(0, "a", 0.8, 0.05);
        let after = report(1, "b", 0.6, 0.05);
        assert_eq!(compare(&before, &after), "discard");
    }

    #[test]
    fn compare_neutral() {
        let before = report(0, "a", 0.7, 0.15);
        let after = report(1, "b", 0.72, 0.15);
        assert_eq!(compare(&before, &after), "neutral");
    }

    #[test]
    fn format_report_markdown() {
        let report = BenchmarkReport {
            results: vec![result("test", 2, true, 0, "done", 0.9)],
            ..report(0, "abc123", 0.9, 0.0)
        };
        let md = format_report(&report);
        assert!(md.contains("abc123"));
        assert!(md.contains("0.900"));
        assert!(md.contains("test"));
    }

    #[test]
    fn log_and_load_benchmarks() {
        // Round-trips one report through the JSONL log inside a throwaway directory.
        let dir = tempfile::tempdir().unwrap();
        let home = dir.path().to_str().unwrap();
        let report = report(12345, "test", 0.75, 0.1);
        log_benchmark(home, &report).unwrap();
        let history = load_benchmarks(home);
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].avg_score, 0.75);
    }
}