Skip to main content

zeph_bench/
scenario.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::path::Path;
5
6use crate::error::BenchError;
7
/// Author of a [`Turn`] in a scenario conversation.
///
/// A conversation alternates between the human side ([`Role::User`]) and the
/// model side ([`Role::Assistant`]).
///
/// # Examples
///
/// ```
/// use zeph_bench::scenario::Role;
///
/// let speaker = Role::User;
/// assert_eq!(speaker, Role::User);
/// assert_ne!(speaker, Role::Assistant);
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Role {
    /// Message written by the human user.
    User,
    /// Message written by the AI assistant.
    Assistant,
}
25
26/// One turn in a multi-turn scenario conversation.
27///
28/// # Examples
29///
30/// ```
31/// use zeph_bench::scenario::{Role, Turn};
32///
33/// let turn = Turn { role: Role::User, content: "What is the capital of France?".into() };
34/// assert!(matches!(turn.role, Role::User));
35/// ```
36#[derive(Debug, Clone)]
37pub struct Turn {
38    /// Who authored this turn.
39    pub role: Role,
40    /// Text content of the turn.
41    pub content: String,
42}
43
44/// A single benchmark scenario loaded from a dataset file.
45///
46/// Each scenario represents one question/task that will be presented to the agent.
47/// The `id` field is used to correlate agent responses with ground-truth answers and
48/// to skip already-completed scenarios during a `--resume` run.
49///
50/// Construct via [`Scenario::single`] for single-turn scenarios (all built-in loaders),
51/// or push [`Turn`]s directly into [`Scenario::turns`] for multi-turn scenarios.
52///
53/// # Examples
54///
55/// ```
56/// use zeph_bench::Scenario;
57///
58/// let scenario = Scenario::single(
59///     "gaia_t42",
60///     "What is the boiling point of water in Celsius?",
61///     "100",
62///     serde_json::json!({"level": 1}),
63/// );
64/// assert_eq!(scenario.id, "gaia_t42");
65/// assert_eq!(scenario.primary_prompt().unwrap(), "What is the boiling point of water in Celsius?");
66/// ```
67#[derive(Debug, Clone)]
68pub struct Scenario {
69    /// Unique identifier within the dataset (e.g. `"frames_0"`, `"s1_2"`).
70    pub id: String,
71    /// Ordered turns in this scenario. Non-empty by contract of [`Scenario::single`].
72    ///
73    /// Direct construction is allowed for multi-turn scenarios; callers must ensure
74    /// at least one [`Role::User`] turn is present before calling [`Scenario::primary_prompt`].
75    pub turns: Vec<Turn>,
76    /// The gold-standard answer used for scoring.
77    pub expected: String,
78    /// Dataset-specific extras such as difficulty level or `reasoning_types`.
79    ///
80    /// Set to [`serde_json::Value::Null`] when the dataset has no extra metadata.
81    pub metadata: serde_json::Value,
82}
83
84impl Scenario {
85    /// Convenience constructor for single-turn scenarios.
86    ///
87    /// Wraps `prompt` in a one-element [`Vec<Turn>`] with [`Role::User`]. All built-in
88    /// dataset loaders use this constructor.
89    ///
90    /// # Examples
91    ///
92    /// ```
93    /// use zeph_bench::Scenario;
94    ///
95    /// let s = Scenario::single("id1", "What year?", "2026", serde_json::Value::Null);
96    /// assert_eq!(s.primary_prompt().unwrap(), "What year?");
97    /// ```
98    #[must_use]
99    pub fn single(
100        id: impl Into<String>,
101        prompt: impl Into<String>,
102        expected: impl Into<String>,
103        metadata: serde_json::Value,
104    ) -> Self {
105        Self {
106            id: id.into(),
107            turns: vec![Turn {
108                role: Role::User,
109                content: prompt.into(),
110            }],
111            expected: expected.into(),
112            metadata,
113        }
114    }
115
116    /// Returns the content of the first [`Role::User`] turn.
117    ///
118    /// # Errors
119    ///
120    /// Returns [`BenchError::InvalidFormat`] when `turns` is empty or contains no
121    /// [`Role::User`] entry. Loaders must construct via [`Scenario::single`] or push
122    /// at least one user turn.
123    ///
124    /// # Examples
125    ///
126    /// ```
127    /// use zeph_bench::Scenario;
128    ///
129    /// let s = Scenario::single("id1", "hello", "world", serde_json::Value::Null);
130    /// assert_eq!(s.primary_prompt().unwrap(), "hello");
131    /// ```
132    pub fn primary_prompt(&self) -> Result<&str, BenchError> {
133        self.turns
134            .iter()
135            .find(|t| matches!(t.role, Role::User))
136            .map(|t| t.content.as_str())
137            .ok_or_else(|| {
138                BenchError::InvalidFormat(format!("scenario '{}' has no user turn", self.id))
139            })
140    }
141}
142
/// Outcome of scoring one agent response against a scenario's expected answer.
///
/// Instances are returned by [`Evaluator::evaluate`]. `score` always lies in `0.0..=1.0`:
/// - `1.0` means a full match (exact or token-level, depending on the evaluator);
/// - `0.0` means no match at all;
/// - values in between indicate partial token overlap (LOCOMO token-F1 evaluator).
///
/// # Examples
///
/// ```
/// use zeph_bench::EvalResult;
///
/// let outcome = EvalResult {
///     scenario_id: "s1".into(),
///     score: 0.75,
///     passed: true,
///     details: "token_f1=0.7500".into(),
/// };
/// assert!(outcome.passed);
/// assert_eq!(outcome.scenario_id, "s1");
/// ```
#[derive(Debug, Clone)]
pub struct EvalResult {
    /// Identifier of the scenario this result belongs to.
    pub scenario_id: String,
    /// Score in the range `0.0..=1.0`.
    pub score: f64,
    /// Set when `score` reaches the evaluator's own pass threshold (`score >= threshold`).
    pub passed: bool,
    /// Human-readable explanation, e.g. `"token_f1=0.7500"` or `"exact_match=true"`.
    pub details: String,
}
174
175/// Loads scenarios from a dataset file on disk.
176///
177/// Implement this trait to add support for a new dataset format. The harness
178/// calls [`DatasetLoader::load`] once per run to materialise the full scenario
179/// list before iterating.
180///
181/// Built-in implementations:
182/// - [`crate::loaders::LocomoLoader`] — JSON array of sessions
183/// - [`crate::loaders::FramesLoader`] — JSONL, one record per line
184/// - [`crate::loaders::GaiaLoader`] — JSONL with optional level filter
185pub trait DatasetLoader {
186    /// Short identifier matching the dataset name in [`crate::DatasetRegistry`].
187    fn name(&self) -> &'static str;
188
189    /// Load all matching scenarios from `path`.
190    ///
191    /// # Errors
192    ///
193    /// Returns [`BenchError::Io`] when the file cannot be opened or read, and
194    /// [`BenchError::InvalidFormat`] when the file content cannot be parsed.
195    fn load(&self, path: &Path) -> Result<Vec<Scenario>, BenchError>;
196}
197
198/// Scores one agent response against a [`Scenario`].
199///
200/// Each dataset loader ships a paired evaluator:
201/// - [`crate::loaders::LocomoEvaluator`] — token F1 with threshold 0.5
202/// - [`crate::loaders::FramesEvaluator`] — exact match (case-insensitive, punctuation stripped)
203/// - [`crate::loaders::GaiaEvaluator`] — GAIA-normalized exact match (articles stripped)
204pub trait Evaluator {
205    /// Compute and return an [`EvalResult`] for the given `agent_response`.
206    fn evaluate(&self, scenario: &Scenario, agent_response: &str) -> EvalResult;
207}
208
/// Token-level F1 between `prediction` and `reference`.
///
/// Each string is split on whitespace and reduced to its *set* of distinct tokens, so
/// repeated tokens count only once. Precision is `|shared| / |prediction tokens|`, recall
/// is `|shared| / |reference tokens|`, and the result is their harmonic mean. Tokens are
/// compared verbatim: there is no lowercasing and no punctuation stripping.
///
/// Returns `0.0` when either string has no tokens, or when the two share none.
///
/// The LOCOMO evaluator uses this metric because it tolerates small wording differences.
///
/// # Examples
///
/// ```
/// use zeph_bench::token_f1;
///
/// // Identical token sets score 1.
/// assert!((token_f1("hello world", "hello world") - 1.0).abs() < f64::EPSILON);
///
/// // Disjoint token sets score 0.
/// assert!(token_f1("foo bar", "baz qux") < f64::EPSILON);
///
/// // Partial overlap lands strictly between 0 and 1.
/// let f1 = token_f1("the cat sat", "the cat ran");
/// assert!(f1 > 0.0 && f1 < 1.0);
///
/// // An empty side scores 0.
/// assert!(token_f1("", "hello") < f64::EPSILON);
/// ```
#[must_use]
// usize -> f64 only loses precision above 2^53 tokens, far beyond any realistic answer.
#[allow(clippy::cast_precision_loss)]
pub fn token_f1(prediction: &str, reference: &str) -> f64 {
    use std::collections::HashSet;

    let predicted: HashSet<&str> = prediction.split_whitespace().collect();
    let gold: HashSet<&str> = reference.split_whitespace().collect();

    if predicted.is_empty() || gold.is_empty() {
        return 0.0;
    }

    // With both sets non-empty, precision + recall is zero exactly when nothing is
    // shared, so an integer check replaces the float-equality test.
    let shared = predicted.intersection(&gold).count();
    if shared == 0 {
        return 0.0;
    }

    let shared = shared as f64;
    let precision = shared / predicted.len() as f64;
    let recall = shared / gold.len() as f64;
    2.0 * precision * recall / (precision + recall)
}
258
259/// Exact match after lowercasing and stripping punctuation/whitespace.
260///
261/// Both strings are normalized by:
262/// 1. Keeping only alphanumeric characters and whitespace.
263/// 2. Converting to lowercase.
264/// 3. Collapsing runs of whitespace to a single space.
265///
266/// Used by the FRAMES evaluator.
267///
268/// # Examples
269///
270/// ```
271/// use zeph_bench::exact_match;
272///
273/// assert!(exact_match("Hello, World!", "hello world"));
274/// assert!(exact_match("answer: YES.", "answer yes"));
275/// assert!(!exact_match("foo", "bar"));
276/// ```
277#[must_use]
278pub fn exact_match(prediction: &str, reference: &str) -> bool {
279    normalize_basic(prediction) == normalize_basic(reference)
280}
281
282/// GAIA-normalized exact match: lowercase, strip articles, strip punctuation, collapse
283/// whitespace, then compare.
284///
285/// Normalization steps (in order):
286/// 1. Keep only alphanumeric characters and whitespace.
287/// 2. Convert to lowercase.
288/// 3. Remove the articles `a`, `an`, and `the`.
289/// 4. Collapse whitespace and compare.
290///
291/// This matches the official GAIA leaderboard scoring script.
292///
293/// # Examples
294///
295/// ```
296/// use zeph_bench::gaia_normalized_exact_match;
297///
298/// // Articles are stripped from both sides.
299/// assert!(gaia_normalized_exact_match("The Tokyo", "Tokyo"));
300/// assert!(gaia_normalized_exact_match("a cat sat on an apple", "cat sat on apple"));
301///
302/// // Different answers do not match.
303/// assert!(!gaia_normalized_exact_match("1944", "1945"));
304/// ```
305#[must_use]
306pub fn gaia_normalized_exact_match(prediction: &str, reference: &str) -> bool {
307    normalize_gaia(prediction) == normalize_gaia(reference)
308}
309
/// Normalizes `s` for [`exact_match`]: lowercases it, keeps only alphanumeric characters
/// and whitespace, and collapses whitespace runs into single spaces.
///
/// Lowercasing runs *before* filtering because `to_lowercase` can emit code points that
/// are not alphanumeric. For example, `'İ'` lowercases to `"i\u{307}"`, and U+0307 is a
/// combining mark. Filtering afterwards guarantees the output contains only
/// alphanumerics separated by single spaces, so `"İzmir"` normalizes to `"izmir"`.
fn normalize_basic(s: &str) -> String {
    s.to_lowercase()
        .chars()
        .filter(|c| c.is_alphanumeric() || c.is_whitespace())
        .collect::<String>()
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
}
319
320fn normalize_gaia(s: &str) -> String {
321    const ARTICLES: &[&str] = &["a", "an", "the"];
322
323    // Map Unicode subscript/superscript digits to their ASCII equivalents before
324    // stripping — this ensures "H₂O" and "H2O" normalize identically.
325    let ascii_mapped: String = s.chars().map(ascii_fold_digit).collect();
326
327    let stripped = ascii_mapped
328        .chars()
329        .filter(|c| c.is_alphanumeric() || c.is_whitespace())
330        .collect::<String>()
331        .to_lowercase();
332
333    stripped
334        .split_whitespace()
335        .filter(|tok| !ARTICLES.contains(tok))
336        .collect::<Vec<_>>()
337        .join(" ")
338}
339
/// Converts a Unicode subscript or superscript digit to the corresponding ASCII digit.
///
/// Subscripts `₀`–`₉` (U+2080–U+2089) and superscripts `⁰ ¹ ² ³ ⁴`–`⁹` (U+2070,
/// U+00B9, U+00B2, U+00B3, U+2074–U+2079) become `'0'`–`'9'`. Any other character is
/// returned unchanged.
fn ascii_fold_digit(c: char) -> char {
    // Position `i` in either table corresponds to the ASCII digit at position `i`.
    const SUBSCRIPT: [char; 10] = [
        '\u{2080}', '\u{2081}', '\u{2082}', '\u{2083}', '\u{2084}', '\u{2085}', '\u{2086}',
        '\u{2087}', '\u{2088}', '\u{2089}',
    ];
    const SUPERSCRIPT: [char; 10] = [
        '\u{2070}', '\u{00B9}', '\u{00B2}', '\u{00B3}', '\u{2074}', '\u{2075}', '\u{2076}',
        '\u{2077}', '\u{2078}', '\u{2079}',
    ];
    const ASCII: [char; 10] = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'];

    SUBSCRIPT
        .iter()
        .position(|&sub| sub == c)
        .or_else(|| SUPERSCRIPT.iter().position(|&sup| sup == c))
        .map_or(c, |digit| ASCII[digit])
}
358
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_f1_identical() {
        assert!((token_f1("hello world", "hello world") - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn token_f1_no_overlap() {
        assert!(token_f1("foo bar", "baz qux") < f64::EPSILON);
    }

    #[test]
    fn token_f1_partial_overlap() {
        let f1 = token_f1("hello world foo", "hello world bar");
        assert!(f1 > 0.0 && f1 < 1.0);
    }

    #[test]
    fn token_f1_empty_prediction() {
        assert!(token_f1("", "hello") < f64::EPSILON);
    }

    #[test]
    fn token_f1_empty_reference() {
        assert!(token_f1("hello", "") < f64::EPSILON);
    }

    #[test]
    fn token_f1_counts_token_types_not_occurrences() {
        // Duplicates collapse: {"a", "b"} vs {"a"} gives P = 0.5, R = 1.0, so F1 = 2/3.
        let f1 = token_f1("a a b", "a");
        assert!((f1 - 2.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn exact_match_identical() {
        assert!(exact_match("Hello, World!", "hello world"));
    }

    #[test]
    fn exact_match_differs() {
        assert!(!exact_match("foo", "bar"));
    }

    #[test]
    fn exact_match_strips_punctuation() {
        assert!(exact_match("answer: yes.", "answer yes"));
    }

    #[test]
    fn exact_match_collapses_whitespace() {
        assert!(exact_match("  New\tYork \n", "new york"));
    }

    #[test]
    fn gaia_normalized_strips_articles() {
        assert!(gaia_normalized_exact_match(
            "The quick brown fox",
            "quick brown fox"
        ));
    }

    #[test]
    fn gaia_normalized_strips_a_an() {
        assert!(gaia_normalized_exact_match(
            "a cat sat on an apple",
            "cat sat on apple"
        ));
    }

    #[test]
    fn gaia_normalized_differs() {
        assert!(!gaia_normalized_exact_match("cat", "dog"));
    }

    #[test]
    fn gaia_normalized_subscript_digits_match_ascii() {
        // Model may respond with Unicode subscript "H₂O" — must match ASCII "H2O".
        assert!(gaia_normalized_exact_match("H\u{2082}O", "H2O"));
    }

    #[test]
    fn gaia_normalized_superscript_digits_match_ascii() {
        // "E = mc²": '=' is stripped, '²' folds to '2'.
        assert!(gaia_normalized_exact_match("E = mc\u{00B2}", "e mc2"));
    }

    #[test]
    fn ascii_fold_digit_covers_every_digit() {
        let subs: String = "\u{2080}\u{2081}\u{2082}\u{2083}\u{2084}\u{2085}\u{2086}\u{2087}\u{2088}\u{2089}"
            .chars()
            .map(ascii_fold_digit)
            .collect();
        assert_eq!(subs, "0123456789");

        let sups: String = "\u{2070}\u{00B9}\u{00B2}\u{00B3}\u{2074}\u{2075}\u{2076}\u{2077}\u{2078}\u{2079}"
            .chars()
            .map(ascii_fold_digit)
            .collect();
        assert_eq!(sups, "0123456789");

        assert_eq!(ascii_fold_digit('x'), 'x');
    }

    #[test]
    fn single_constructs_one_user_turn() {
        let s = Scenario::single("id1", "hello", "world", serde_json::Value::Null);
        assert_eq!(s.turns.len(), 1);
        assert!(matches!(s.turns[0].role, Role::User));
        assert_eq!(s.turns[0].content, "hello");
        assert_eq!(s.expected, "world");
    }

    #[test]
    fn primary_prompt_returns_first_user_turn_content() {
        let s = Scenario::single("id1", "What year?", "2026", serde_json::Value::Null);
        assert_eq!(s.primary_prompt().unwrap(), "What year?");
    }

    #[test]
    fn primary_prompt_skips_leading_assistant_turns() {
        let s = Scenario {
            id: "id2".into(),
            turns: vec![
                Turn {
                    role: Role::Assistant,
                    content: "I am ready.".into(),
                },
                Turn {
                    role: Role::User,
                    content: "What is Rust?".into(),
                },
            ],
            expected: "A systems language".into(),
            metadata: serde_json::Value::Null,
        };
        assert_eq!(s.primary_prompt().unwrap(), "What is Rust?");
    }

    #[test]
    fn primary_prompt_errors_when_no_user_turn() {
        let s = Scenario {
            id: "id3".into(),
            turns: vec![Turn {
                role: Role::Assistant,
                content: "assistant only".into(),
            }],
            expected: String::new(),
            metadata: serde_json::Value::Null,
        };
        assert!(s.primary_prompt().is_err());

        let empty = Scenario {
            id: "id4".into(),
            turns: vec![],
            expected: String::new(),
            metadata: serde_json::Value::Null,
        };
        assert!(empty.primary_prompt().is_err());
    }
}