zeph_bench/loaders/
gaia.rs1use std::{
5 io::{BufRead as _, BufReader},
6 path::Path,
7};
8
9use serde::Deserialize;
10
11use crate::{
12 error::BenchError,
13 scenario::{DatasetLoader, EvalResult, Evaluator, Scenario, gaia_normalized_exact_match},
14};
15
16#[derive(Debug, Deserialize)]
17struct GaiaRecord {
18 task_id: String,
19 #[serde(rename = "Question")]
20 question: String,
21 #[serde(rename = "Level")]
22 level: u8,
23 #[serde(rename = "Final answer")]
24 final_answer: String,
25 #[serde(rename = "Annotator Metadata")]
26 annotator_metadata: Option<serde_json::Value>,
27}
28
/// Loader for GAIA JSONL files with an optional level filter.
#[derive(Debug)]
pub struct GaiaLoader {
    /// When `Some(n)`, only records whose level equals `n` are loaded;
    /// `None` disables filtering.
    pub level: Option<u8>,
}

impl GaiaLoader {
    /// Creates a loader that applies no level filter.
    #[must_use]
    pub fn all_levels() -> Self {
        let level = None;
        Self { level }
    }

    /// Creates a loader restricted to records of the given `level`.
    #[must_use]
    pub fn with_level(level: u8) -> Self {
        let level = Some(level);
        Self { level }
    }
}
107
108impl DatasetLoader for GaiaLoader {
109 fn name(&self) -> &'static str {
110 "gaia"
111 }
112
113 fn load(&self, path: &Path) -> Result<Vec<Scenario>, BenchError> {
118 let file = std::fs::File::open(path)?;
119 let reader = BufReader::new(file);
120
121 let mut scenarios = Vec::new();
122 for (line_number, line) in reader.lines().enumerate() {
123 let line = line?;
124 let trimmed = line.trim();
125 if trimmed.is_empty() {
126 continue;
127 }
128 let record: GaiaRecord = serde_json::from_str(trimmed)
129 .map_err(|e| BenchError::InvalidFormat(format!("line {line_number}: {e}")))?;
130
131 if let Some(filter_level) = self.level
132 && record.level != filter_level
133 {
134 continue;
135 }
136
137 let metadata = serde_json::json!({
138 "level": record.level,
139 "annotator_metadata": record.annotator_metadata,
140 });
141
142 scenarios.push(Scenario::single(
143 record.task_id,
144 record.question,
145 record.final_answer,
146 metadata,
147 ));
148 }
149 Ok(scenarios)
150 }
151}
152
153#[derive(Debug)]
177pub struct GaiaEvaluator;
178
179impl Evaluator for GaiaEvaluator {
180 fn evaluate(&self, scenario: &Scenario, agent_response: &str) -> EvalResult {
181 let passed = gaia_normalized_exact_match(agent_response, &scenario.expected);
182 EvalResult {
183 scenario_id: scenario.id.clone(),
184 score: if passed { 1.0 } else { 0.0 },
185 passed,
186 details: format!(
187 "gaia_normalized_exact_match={}",
188 if passed { "true" } else { "false" }
189 ),
190 }
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-record JSONL fixture: `t1` and `t3` are level 1, `t2` is level 2.
    const FIXTURE: &str = r#"{"task_id": "t1", "Question": "What year did WWII end?", "Level": 1, "Final answer": "1945", "Annotator Metadata": {"difficulty": "easy"}}
{"task_id": "t2", "Question": "Who wrote Hamlet?", "Level": 2, "Final answer": "Shakespeare", "Annotator Metadata": null}
{"task_id": "t3", "Question": "Capital of Japan?", "Level": 1, "Final answer": "Tokyo", "Annotator Metadata": null}
"#;

    /// Writes `jsonl` to a temporary file and loads it through a
    /// [`GaiaLoader`] configured with the given `level` filter.
    fn load_from_str(jsonl: &str, level: Option<u8>) -> Vec<Scenario> {
        let tmp = tempfile::tempdir().unwrap();
        let file = tmp.path().join("gaia.jsonl");
        std::fs::write(&file, jsonl).unwrap();

        let loader = GaiaLoader { level };
        loader.load(&file).unwrap()
    }

    /// Loads the shared fixture without any level filter.
    fn load_fixture() -> Vec<Scenario> {
        load_from_str(FIXTURE, None)
    }

    #[test]
    fn load_all_levels_parses_scenario_count() {
        assert_eq!(load_fixture().len(), 3);
    }

    #[test]
    fn load_filters_by_level() {
        let level_one = load_from_str(FIXTURE, Some(1));
        assert_eq!(level_one.len(), 2);
        assert!(level_one.iter().all(|s| s.metadata["level"] == 1));
    }

    #[test]
    fn load_maps_task_id_to_scenario_id() {
        let scenarios = load_fixture();
        let [first, second, ..] = scenarios.as_slice() else {
            panic!("fixture should yield at least two scenarios");
        };
        assert_eq!(first.id, "t1");
        assert_eq!(second.id, "t2");
    }

    #[test]
    fn load_maps_prompt_and_expected() {
        let scenarios = load_fixture();
        let first = &scenarios[0];
        assert_eq!(first.primary_prompt().unwrap(), "What year did WWII end?");
        assert_eq!(first.expected, "1945");
    }

    #[test]
    fn load_stores_level_in_metadata() {
        let scenarios = load_fixture();
        assert_eq!(scenarios[1].metadata["level"], 2);
    }

    #[test]
    fn evaluator_normalized_match_passes() {
        let scenarios = load_fixture();
        let outcome = GaiaEvaluator.evaluate(&scenarios[0], "1945");
        assert!(outcome.passed);
    }

    #[test]
    fn evaluator_wrong_answer_fails() {
        let scenarios = load_fixture();
        let outcome = GaiaEvaluator.evaluate(&scenarios[0], "1944");
        assert!(!outcome.passed);
        assert!(outcome.score < f64::EPSILON);
    }

    #[test]
    fn evaluator_strips_article_the() {
        let scenarios = load_fixture();
        // Expected answer is "Tokyo"; a leading "The" must not cause a miss.
        let outcome = GaiaEvaluator.evaluate(&scenarios[2], "The Tokyo");
        assert!(outcome.passed);
    }

    #[test]
    fn load_invalid_jsonl_returns_error() {
        let tmp = tempfile::tempdir().unwrap();
        let file = tmp.path().join("bad.jsonl");
        std::fs::write(&file, "not json\n").unwrap();

        let loaded = GaiaLoader::all_levels().load(&file);
        assert!(loaded.is_err());
    }

    #[test]
    fn all_levels_constructor() {
        assert!(GaiaLoader::all_levels().level.is_none());
    }

    #[test]
    fn with_level_constructor() {
        assert_eq!(GaiaLoader::with_level(2).level, Some(2));
    }
}