Skip to main content

zeph_experiments/
benchmark.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Benchmark dataset types and TOML loading.
5
6use std::path::Path;
7
8use serde::Deserialize;
9
10use super::error::EvalError;
11
/// Largest benchmark file, in bytes, that [`BenchmarkSet::from_file`] will accept:
/// 10 MiB (10 × 2^20 bytes).
const MAX_BENCHMARK_SIZE: u64 = 10 << 20;
14
15/// A set of benchmark cases loaded from a TOML file.
16///
17/// Each [`BenchmarkCase`] defines one prompt/response pair to evaluate. The set is
18/// validated to be non-empty before evaluation begins.
19///
20/// # TOML Format
21///
22/// ```toml
23/// [[cases]]
24/// prompt = "What is the capital of France?"
25/// reference = "Paris"
26/// tags = ["geography", "factual"]
27///
28/// [[cases]]
29/// prompt = "Explain async/await in Rust."
30/// context = "You are a Rust expert."
31/// ```
32///
33/// # Examples
34///
35/// ```rust,no_run
36/// use std::path::Path;
37/// use zeph_experiments::BenchmarkSet;
38///
39/// let set = BenchmarkSet::from_file(Path::new("bench/qa.toml"))
40///     .expect("benchmark file must exist and be valid TOML");
41/// set.validate().expect("benchmark must have at least one case");
42/// println!("{} cases loaded", set.cases.len());
43/// ```
44#[derive(Debug, Clone, Deserialize)]
45pub struct BenchmarkSet {
46    /// The benchmark cases to evaluate. Must be non-empty (enforced by [`Self::validate`]).
47    pub cases: Vec<BenchmarkCase>,
48}
49
50/// A single benchmark case.
51///
52/// The `prompt` is sent to the subject model. If `context` is present it is injected
53/// as a system message before the user turn. If `reference` is present the judge model
54/// uses it to calibrate its score.
55///
56/// # Examples
57///
58/// ```rust
59/// use zeph_experiments::BenchmarkCase;
60///
61/// let case = BenchmarkCase {
62///     prompt: "Name the first element.".into(),
63///     context: Some("You are a chemistry expert.".into()),
64///     reference: Some("Hydrogen".into()),
65///     tags: Some(vec!["chemistry".into()]),
66/// };
67/// assert!(case.reference.as_deref() == Some("Hydrogen"));
68/// ```
69#[derive(Debug, Clone, Deserialize)]
70pub struct BenchmarkCase {
71    /// The prompt sent to the subject model.
72    pub prompt: String,
73    /// Optional system context injected before the user turn.
74    #[serde(default)]
75    pub context: Option<String>,
76    /// Optional reference answer for the judge to calibrate scoring.
77    #[serde(default)]
78    pub reference: Option<String>,
79    /// Optional tags for filtering or grouping results.
80    #[serde(default)]
81    pub tags: Option<Vec<String>>,
82}
83
84impl BenchmarkSet {
85    /// Load a benchmark set from a TOML file.
86    ///
87    /// Performs size guard (10 MiB limit) and canonicalisation before reading.
88    /// Symlinks that escape the file's parent directory are rejected.
89    ///
90    /// # Errors
91    ///
92    /// Returns [`EvalError::BenchmarkLoad`] if the file cannot be read,
93    /// [`EvalError::BenchmarkParse`] if the TOML is invalid,
94    /// [`EvalError::BenchmarkTooLarge`] if the file exceeds the size limit, or
95    /// [`EvalError::PathTraversal`] if canonicalization reveals a symlink escape.
96    pub fn from_file(path: &Path) -> Result<Self, EvalError> {
97        // Canonicalize to resolve symlinks before opening — eliminates TOCTOU race.
98        let canonical = std::fs::canonicalize(path)
99            .map_err(|e| EvalError::BenchmarkLoad(path.display().to_string(), e))?;
100
101        // Verify the canonical path stays within the parent directory.
102        // This prevents symlinks from escaping into arbitrary filesystem locations.
103        if let Some(parent) = path.parent()
104            && let Ok(canonical_parent) = std::fs::canonicalize(parent)
105            && !canonical.starts_with(&canonical_parent)
106        {
107            return Err(EvalError::PathTraversal(canonical.display().to_string()));
108        }
109
110        // Guard against unbounded memory use from oversized files.
111        let metadata = std::fs::metadata(&canonical)
112            .map_err(|e| EvalError::BenchmarkLoad(canonical.display().to_string(), e))?;
113        if metadata.len() > MAX_BENCHMARK_SIZE {
114            return Err(EvalError::BenchmarkTooLarge {
115                path: canonical.display().to_string(),
116                size: metadata.len(),
117                limit: MAX_BENCHMARK_SIZE,
118            });
119        }
120
121        let content = std::fs::read_to_string(&canonical)
122            .map_err(|e| EvalError::BenchmarkLoad(canonical.display().to_string(), e))?;
123        toml::from_str(&content)
124            .map_err(|e| EvalError::BenchmarkParse(canonical.display().to_string(), e.to_string()))
125    }
126
127    /// Validate that the benchmark set is non-empty.
128    ///
129    /// # Errors
130    ///
131    /// Returns [`EvalError::EmptyBenchmarkSet`] if `cases` is empty.
132    pub fn validate(&self) -> Result<(), EvalError> {
133        if self.cases.is_empty() {
134            return Err(EvalError::EmptyBenchmarkSet);
135        }
136        Ok(())
137    }
138}
139
#[cfg(test)]
mod tests {
    #![allow(clippy::redundant_closure_for_method_calls)]

    use super::*;

    /// Deserialize a `BenchmarkSet` from an inline TOML string, panicking on error.
    fn parse(toml: &str) -> BenchmarkSet {
        toml::from_str(toml).expect("valid TOML")
    }

    #[test]
    fn benchmark_from_toml_happy_path() {
        let toml = r#"
[[cases]]
prompt = "What is 2+2?"
"#;
        let set = parse(toml);
        assert_eq!(set.cases.len(), 1);
        assert_eq!(set.cases[0].prompt, "What is 2+2?");
        assert!(set.cases[0].context.is_none());
        assert!(set.cases[0].reference.is_none());
        assert!(set.cases[0].tags.is_none());
    }

    #[test]
    fn benchmark_from_toml_with_all_fields() {
        let toml = r#"
[[cases]]
prompt = "Explain Rust ownership."
context = "You are a Rust expert."
reference = "Ownership is Rust's memory management model."
tags = ["rust", "concepts"]
"#;
        let set = parse(toml);
        assert_eq!(set.cases.len(), 1);
        let case = &set.cases[0];
        assert_eq!(case.context.as_deref(), Some("You are a Rust expert."));
        assert!(case.reference.is_some());
        assert_eq!(case.tags.as_ref().map(std::vec::Vec::len), Some(2));
    }

    #[test]
    fn benchmark_empty_cases_rejected() {
        let set = BenchmarkSet { cases: vec![] };
        assert!(matches!(set.validate(), Err(EvalError::EmptyBenchmarkSet)));
    }

    #[test]
    fn benchmark_from_file_missing_file() {
        let result = BenchmarkSet::from_file(Path::new("/nonexistent/path/benchmark.toml"));
        assert!(matches!(result, Err(EvalError::BenchmarkLoad(_, _))));
    }

    #[test]
    fn benchmark_from_toml_invalid_syntax() {
        let bad = "[[cases\nprompt = 'unclosed'";
        let result: Result<BenchmarkSet, _> = toml::from_str(bad);
        assert!(result.is_err());
    }

    #[test]
    fn benchmark_from_file_invalid_toml() {
        use std::io::Write;
        let mut f = tempfile::NamedTempFile::new().unwrap();
        writeln!(f, "not valid toml ][[]").unwrap();
        let result = BenchmarkSet::from_file(f.path());
        assert!(matches!(result, Err(EvalError::BenchmarkParse(_, _))));
    }

    #[test]
    fn benchmark_from_file_too_large() {
        // `set_len` produces a sparse file one byte over the limit. This exercises the
        // size guard through `from_file` without writing 10 MiB of data.
        let f = tempfile::NamedTempFile::new().unwrap();
        f.as_file().set_len(MAX_BENCHMARK_SIZE + 1).unwrap();
        let err = BenchmarkSet::from_file(f.path()).unwrap_err();
        assert!(
            matches!(
                &err,
                EvalError::BenchmarkTooLarge { size, limit, .. }
                    if *size == MAX_BENCHMARK_SIZE + 1 && *limit == MAX_BENCHMARK_SIZE
            ),
            "unexpected error: {err}"
        );
        assert!(err.to_string().contains("exceeds size limit"));
    }

    #[test]
    fn benchmark_from_file_size_guard_allows_normal_file() {
        use std::io::Write;
        let mut f = tempfile::NamedTempFile::new().unwrap();
        writeln!(f, "[[cases]]\nprompt = \"hello\"").unwrap();
        // A normal-sized file must load without a size error and keep its content.
        let Ok(set) = BenchmarkSet::from_file(f.path()) else {
            panic!("normal-sized benchmark file must load");
        };
        assert_eq!(set.cases.len(), 1);
        assert_eq!(set.cases[0].prompt, "hello");
    }

    #[cfg(unix)]
    #[test]
    fn benchmark_from_file_rejects_symlink_escape() {
        // The link lives in one directory but resolves into a different one.
        let outside = tempfile::tempdir().unwrap();
        let target = outside.path().join("real.toml");
        std::fs::write(&target, "[[cases]]\nprompt = \"hi\"\n").unwrap();

        let inside = tempfile::tempdir().unwrap();
        let link = inside.path().join("link.toml");
        std::os::unix::fs::symlink(&target, &link).unwrap();

        let result = BenchmarkSet::from_file(&link);
        assert!(matches!(result, Err(EvalError::PathTraversal(_))));
    }

    #[cfg(unix)]
    #[test]
    fn benchmark_from_file_allows_symlink_within_parent() {
        // A link that resolves inside its own directory is accepted.
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("real.toml");
        std::fs::write(&target, "[[cases]]\nprompt = \"hi\"\n").unwrap();
        let link = dir.path().join("link.toml");
        std::os::unix::fs::symlink(&target, &link).unwrap();

        assert!(BenchmarkSet::from_file(&link).is_ok());
    }

    #[test]
    fn benchmark_validate_passes_for_nonempty() {
        let set = BenchmarkSet {
            cases: vec![BenchmarkCase {
                prompt: "hello".into(),
                context: None,
                reference: None,
                tags: None,
            }],
        };
        assert!(set.validate().is_ok());
    }
}