Skip to main content

zeph_bench/loaders/
locomo.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::path::Path;
5
6use serde::Deserialize;
7
8use crate::{
9    error::BenchError,
10    scenario::{DatasetLoader, EvalResult, Evaluator, Scenario, token_f1},
11};
12
/// Minimum token-F1 score (inclusive) a LOCOMO response needs to be marked as passed.
const PASS_THRESHOLD: f64 = 0.5;
14
15#[derive(Debug, Deserialize)]
16struct LocomoSession {
17    session_id: String,
18    qa: Vec<LocomoQa>,
19}
20
21#[derive(Debug, Deserialize)]
22struct LocomoQa {
23    question: String,
24    answer: String,
25}
26
/// Loads LOCOMO benchmark scenarios from a JSON file.
///
/// **Source**: [`lmlab/locomo`](https://huggingface.co/datasets/lmlab/locomo) on `HuggingFace`.
///
/// **Schema**: the file is a JSON array of session objects:
/// ```json
/// [
///   {
///     "session_id": "abc",
///     "qa": [
///       {"question": "...", "answer": "..."}
///     ]
///   }
/// ]
/// ```
///
/// Each QA pair within a session becomes one [`Scenario`] with id
/// `"{session_id}_{qa_index}"`. The index is zero-based and restarts at 0 for every
/// session. Scenarios are emitted in file order.
///
/// Only `question` and `answer` are read from each QA object; any other keys are
/// ignored. `metadata` is always [`serde_json::Value::Null`].
///
/// The loader is a stateless unit struct, so it is cheap to copy and default-construct.
///
/// # Examples
///
/// ```no_run
/// use std::path::Path;
/// use zeph_bench::loaders::LocomoLoader;
/// use zeph_bench::scenario::DatasetLoader;
///
/// let scenarios = LocomoLoader.load(Path::new("/data/locomo.json")).unwrap();
/// println!("loaded {} scenarios", scenarios.len());
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct LocomoLoader;
59
60impl DatasetLoader for LocomoLoader {
61    fn name(&self) -> &'static str {
62        "locomo"
63    }
64
65    /// # Errors
66    ///
67    /// Returns [`BenchError::Io`] when the file cannot be read and
68    /// [`BenchError::InvalidFormat`] when JSON parsing fails.
69    fn load(&self, path: &Path) -> Result<Vec<Scenario>, BenchError> {
70        let content = std::fs::read_to_string(path)?;
71        let sessions: Vec<LocomoSession> =
72            serde_json::from_str(&content).map_err(|e| BenchError::InvalidFormat(e.to_string()))?;
73
74        let mut scenarios = Vec::new();
75        for session in sessions {
76            for (idx, qa) in session.qa.iter().enumerate() {
77                scenarios.push(Scenario::single(
78                    format!("{}_{}", session.session_id, idx),
79                    qa.question.clone(),
80                    qa.answer.clone(),
81                    serde_json::Value::Null,
82                ));
83            }
84        }
85        Ok(scenarios)
86    }
87}
88
/// Evaluates LOCOMO responses using token F1 with a pass threshold of 0.5.
///
/// Before scoring, both the agent response and the gold answer are normalized in two ways:
/// - they are lowercased;
/// - every character that is neither alphanumeric nor whitespace is removed.
///
/// As a result, case and trailing punctuation (e.g. `"Paris."` vs `"Paris"`) do not cost
/// points.
///
/// A response passes when its token F1 score against the normalized gold answer is ≥ 0.5.
/// The raw score (in `0.0..=1.0`) is always written to the result regardless of the
/// pass/fail decision, and `details` reports it as `token_f1=<score>` with four decimals.
///
/// The evaluator is a stateless unit struct, so it is cheap to copy and default-construct.
///
/// # Examples
///
/// ```
/// use zeph_bench::{Scenario, loaders::LocomoEvaluator};
/// use zeph_bench::scenario::Evaluator;
///
/// let scenario = Scenario::single(
///     "s1_0",
///     "What is the capital of France?",
///     "Paris",
///     serde_json::Value::Null,
/// );
///
/// let result = LocomoEvaluator.evaluate(&scenario, "Paris");
/// assert!((result.score - 1.0).abs() < f64::EPSILON);
/// assert!(result.passed);
///
/// let bad = LocomoEvaluator.evaluate(&scenario, "completely unrelated answer xyz");
/// assert!(!bad.passed);
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct LocomoEvaluator;
117
118impl Evaluator for LocomoEvaluator {
119    fn evaluate(&self, scenario: &Scenario, agent_response: &str) -> EvalResult {
120        // Normalize both sides before scoring: lowercase and keep only alphanumeric/whitespace.
121        // This ensures "4." and "Leonardo da Vinci." match their expected forms.
122        let normalized_response = normalize_for_f1(agent_response);
123        let normalized_expected = normalize_for_f1(&scenario.expected);
124        let score = token_f1(&normalized_response, &normalized_expected);
125        EvalResult {
126            scenario_id: scenario.id.clone(),
127            score,
128            passed: score >= PASS_THRESHOLD,
129            details: format!("token_f1={score:.4}"),
130        }
131    }
132}
133
/// Normalize a string for token-F1 scoring.
///
/// The string is lowercased, then every character that is neither alphanumeric nor
/// whitespace is removed.
///
/// This covers the casing and punctuation steps of the `SQuAD` answer normalization.
/// Punctuation differences (e.g. "Paris." vs "Paris") therefore do not penalize otherwise
/// correct answers. Punctuation is deleted outright, so `"well-known"` becomes the single
/// token `"wellknown"`.
///
/// Unlike the full `SQuAD` normalization:
/// - articles (`a`, `an`, `the`) are kept;
/// - whitespace is passed through unchanged rather than collapsed.
///
/// Lowercasing happens *before* filtering because some case mappings emit extra
/// combining marks. For example, `'İ'` lowercases to `"i\u{307}"`. Filtering afterwards
/// strips such marks, so `"İstanbul"` and `"istanbul"` normalize to the same string.
fn normalize_for_f1(s: &str) -> String {
    s.to_lowercase()
        .chars()
        .filter(|c| c.is_alphanumeric() || c.is_whitespace())
        .collect()
}
145
#[cfg(test)]
mod tests {
    use super::*;

    const FIXTURE: &str = r#"[
        {
            "session_id": "s1",
            "qa": [
                {"question": "What is Rust?", "answer": "A systems programming language"},
                {"question": "Is it fast?", "answer": "Yes"}
            ]
        }
    ]"#;

    /// Writes `json` to a temporary file and loads it with [`LocomoLoader`], panicking on error.
    fn load_from_str(json: &str) -> Vec<Scenario> {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("locomo.json");
        std::fs::write(&path, json).unwrap();
        LocomoLoader.load(&path).unwrap()
    }

    #[test]
    fn load_parses_scenario_count() {
        let scenarios = load_from_str(FIXTURE);
        assert_eq!(scenarios.len(), 2);
    }

    #[test]
    fn load_builds_correct_ids() {
        let scenarios = load_from_str(FIXTURE);
        assert_eq!(scenarios[0].id, "s1_0");
        assert_eq!(scenarios[1].id, "s1_1");
    }

    #[test]
    fn load_restarts_index_per_session() {
        let scenarios = load_from_str(
            r#"[
                {"session_id": "a", "qa": [{"question": "q", "answer": "x"}]},
                {"session_id": "b", "qa": [
                    {"question": "q1", "answer": "y"},
                    {"question": "q2", "answer": "z"}
                ]}
            ]"#,
        );
        assert_eq!(scenarios.len(), 3);
        assert_eq!(scenarios[0].id, "a_0");
        assert_eq!(scenarios[1].id, "b_0");
        assert_eq!(scenarios[2].id, "b_1");
        assert_eq!(scenarios[2].expected, "z");
    }

    #[test]
    fn load_maps_prompt_and_expected() {
        let scenarios = load_from_str(FIXTURE);
        assert_eq!(scenarios[0].primary_prompt().unwrap(), "What is Rust?");
        assert_eq!(scenarios[0].expected, "A systems programming language");
    }

    #[test]
    fn load_empty_array_yields_no_scenarios() {
        assert!(load_from_str("[]").is_empty());
    }

    #[test]
    fn load_session_without_qa_yields_no_scenarios() {
        assert!(load_from_str(r#"[{"session_id": "s1", "qa": []}]"#).is_empty());
    }

    #[test]
    fn evaluator_perfect_match_passes() {
        let scenarios = load_from_str(FIXTURE);
        let result = LocomoEvaluator.evaluate(&scenarios[0], "A systems programming language");
        assert!((result.score - 1.0).abs() < f64::EPSILON);
        assert!(result.passed);
    }

    #[test]
    fn evaluator_ignores_case_and_punctuation() {
        let scenarios = load_from_str(FIXTURE);
        let result = LocomoEvaluator.evaluate(&scenarios[1], "YES!");
        assert!((result.score - 1.0).abs() < f64::EPSILON);
        assert!(result.passed);
    }

    #[test]
    fn evaluator_reports_id_and_details() {
        let scenarios = load_from_str(FIXTURE);
        let result = LocomoEvaluator.evaluate(&scenarios[1], "Yes");
        assert_eq!(result.scenario_id, "s1_1");
        assert_eq!(result.details, "token_f1=1.0000");
    }

    #[test]
    fn evaluator_no_match_fails() {
        let scenarios = load_from_str(FIXTURE);
        let result = LocomoEvaluator.evaluate(&scenarios[0], "completely different response xyz");
        assert!(!result.passed);
    }

    #[test]
    fn load_invalid_json_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("bad.json");
        std::fs::write(&path, "not json").unwrap();
        assert!(LocomoLoader.load(&path).is_err());
    }

    #[test]
    fn load_missing_file_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("missing.json");
        assert!(LocomoLoader.load(&path).is_err());
    }
}
209}