Skip to main content

zeph_bench/loaders/tau2_bench/
loader.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Dataset loader for tau2-bench retail and airline domains.
5
6use std::fmt::Write as _;
7use std::io::BufReader;
8use std::path::{Path, PathBuf};
9
10use crate::{
11    error::BenchError,
12    scenario::{DatasetLoader, Scenario},
13};
14
15#[cfg(test)]
16use super::data::EvaluationCriteria;
17use super::data::{Domain, Task, UserInstructions};
18
/// Locations of the three JSON files that make up one tau2-bench domain.
///
/// All three files must reside in the same directory (the layout produced by
/// `bench download --dataset tau2-bench`).
///
/// The common traits are derived so a path set can be logged with `{:?}`,
/// duplicated, and compared in assertions without wrapping the type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DataPaths {
    /// JSON array of task objects.
    pub tasks_json: PathBuf,
    /// JSON database seed file for the environment.
    pub db_json: PathBuf,
    /// JSON split definitions (`base`, `train`, `test`).
    pub split_tasks_json: PathBuf,
}
31
32impl DataPaths {
33    /// Resolve the three-file set for `domain` under `root`.
34    #[must_use]
35    pub fn resolve(root: &Path, domain: Domain) -> Self {
36        let dir = root.join(domain.as_str());
37        Self {
38            tasks_json: dir.join("tasks.json"),
39            db_json: dir.join("db.json"),
40            split_tasks_json: dir.join("split_tasks.json"),
41        }
42    }
43}
44
45/// Loads tau2-bench scenarios for a single domain.
46///
47/// The loader reads `tasks.json`, parses each [`Task`] into a [`Scenario`], and
48/// stores the serialised `EvaluationCriteria` JSON in `scenario.metadata` for the
49/// evaluator to retrieve per-scenario at runtime.
50///
51/// # Path convention
52///
53/// Pass the absolute path to `tasks.json` as the `path` argument to
54/// [`DatasetLoader::load`]. The `db.json` and `split_tasks.json` files are
55/// expected to reside in the same directory.
56///
57/// # Examples
58///
59/// ```no_run
60/// use std::path::Path;
61/// use zeph_bench::loaders::tau2_bench::loader::Tau2BenchLoader;
62/// use zeph_bench::scenario::DatasetLoader;
63///
64/// let loader = Tau2BenchLoader::retail();
65/// let scenarios = loader.load(Path::new("/data/tau2-bench/retail/tasks.json")).unwrap();
66/// println!("loaded {} scenarios", scenarios.len());
67/// ```
68pub struct Tau2BenchLoader {
69    /// Domain this loader targets.
70    pub domain: Domain,
71}
72
73impl Tau2BenchLoader {
74    /// Create a loader for the retail domain.
75    #[must_use]
76    pub fn retail() -> Self {
77        Self {
78            domain: Domain::Retail,
79        }
80    }
81
82    /// Create a loader for the airline domain.
83    #[must_use]
84    pub fn airline() -> Self {
85        Self {
86            domain: Domain::Airline,
87        }
88    }
89}
90
91impl DatasetLoader for Tau2BenchLoader {
92    fn name(&self) -> &'static str {
93        match self.domain {
94            Domain::Retail => "tau2-bench-retail",
95            Domain::Airline => "tau2-bench-airline",
96        }
97    }
98
99    /// Load all tasks from `path` (must be `tasks.json`).
100    ///
101    /// All tasks are loaded regardless of `reward_basis` — the evaluator scores
102    /// based on the `actions` field which is present in every task.
103    ///
104    /// # Errors
105    ///
106    /// Returns [`BenchError::InvalidFormat`] when the file cannot be opened or
107    /// when JSON parsing fails.
108    fn load(&self, path: &Path) -> Result<Vec<Scenario>, BenchError> {
109        let file = std::fs::File::open(path)
110            .map_err(|e| BenchError::InvalidFormat(format!("open tasks.json: {e}")))?;
111        let tasks: Vec<Task> = serde_json::from_reader(BufReader::new(file))
112            .map_err(|e| BenchError::InvalidFormat(format!("parse tasks.json: {e}")))?;
113
114        let loader_name = self.name();
115        let mut scenarios = Vec::with_capacity(tasks.len());
116
117        for task in tasks {
118            let id = format!("{}_{}", loader_name, task.id);
119            let prompt = build_prompt(&task);
120            let metadata = serde_json::json!({
121                "domain": loader_name,
122                "tau2_task_id": task.id,
123                "evaluation_criteria": task.evaluation_criteria,
124                "user_scenario": task.user_scenario,
125            });
126            scenarios.push(Scenario::single(id, prompt, "", metadata));
127        }
128
129        Ok(scenarios)
130    }
131}
132
133/// Convert a [`Task`]'s user scenario into a single instruction string.
134///
135/// This is a deliberate MVP simplification — the full tau2-bench benchmark uses
136/// a multi-turn user simulator. Here we collapse it into one upfront prompt.
137///
138/// Multi-turn user simulation is deferred to #4233 (D4 of #3417). The current approach
139/// works for ACTION-only scoring because the agent sees all information at once,
140/// but will under-score tasks where the user simulator drives information
141/// exchange across turns.
142fn build_prompt(task: &Task) -> String {
143    match &task.user_scenario.instructions {
144        UserInstructions::Plain(s) => s.clone(),
145        UserInstructions::Structured(i) => {
146            let mut buf = String::new();
147            writeln!(buf, "You are speaking to a customer support agent.").ok();
148            writeln!(buf, "\nReason for call:\n{}", i.reason_for_call).ok();
149            if let Some(known) = &i.known_info {
150                writeln!(buf, "\nKnown information about you:\n{known}").ok();
151            }
152            writeln!(buf, "\nTask instructions:\n{}", i.task_instructions).ok();
153            buf
154        }
155    }
156}
157
158/// Return the `db.json` path that accompanies the given `tasks.json` path.
159///
160/// # Errors
161///
162/// Returns [`BenchError::InvalidFormat`] if `tasks_json` has no parent directory.
163pub fn db_json_path(tasks_json: &Path) -> Result<PathBuf, BenchError> {
164    tasks_json
165        .parent()
166        .map(|dir| dir.join("db.json"))
167        .ok_or_else(|| {
168            BenchError::InvalidFormat(
169                "tasks.json must have a parent directory containing db.json".into(),
170            )
171        })
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Three tasks: one with structured instructions and a single expected
    /// action, one with plain instructions, and one whose reward basis is
    /// `DB` rather than `ACTION`.
    const TASKS_FIXTURE: &str = r##"[
      {
        "id": "0",
        "user_scenario": {
          "instructions": {
            "domain": "retail",
            "reason_for_call": "Cancel my order",
            "task_instructions": "Cancel order #W0001",
            "known_info": "Order id: #W0001"
          }
        },
        "evaluation_criteria": {
          "actions": [
            {
              "action_id": "a1",
              "requestor": "assistant",
              "name": "cancel_pending_order",
              "arguments": {"order_id": "#W0001", "reason": "no_longer_needed"},
              "compare_args": ["order_id", "reason"]
            }
          ],
          "reward_basis": ["ACTION"]
        }
      },
      {
        "id": "1",
        "user_scenario": {
          "instructions": "Simple plain prompt"
        },
        "evaluation_criteria": {
          "actions": [],
          "reward_basis": ["ACTION"]
        }
      },
      {
        "id": "2",
        "user_scenario": {
          "instructions": "DB-only task"
        },
        "evaluation_criteria": {
          "actions": [],
          "reward_basis": ["DB"]
        }
      }
    ]"##;

    /// Write `json` to a temporary `tasks.json` and load it for `domain`.
    fn load_from_str(json: &str, domain: Domain) -> Vec<Scenario> {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("tasks.json");
        std::fs::write(&path, json).unwrap();
        let loader = Tau2BenchLoader { domain };
        loader.load(&path).unwrap()
    }

    #[test]
    fn load_all_tasks_regardless_of_reward_basis() {
        // All 3 tasks are loaded — reward_basis filter was removed.
        let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
        assert_eq!(scenarios.len(), 3);
    }

    #[test]
    fn load_builds_correct_ids() {
        let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
        assert_eq!(scenarios[0].id, "tau2-bench-retail_0");
        assert_eq!(scenarios[1].id, "tau2-bench-retail_1");
        assert_eq!(scenarios[2].id, "tau2-bench-retail_2");
    }

    #[test]
    fn load_prompt_from_structured_instructions() {
        let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
        let prompt = scenarios[0].primary_prompt().unwrap();
        // Each section is checked on its own: a single `a || b` assertion
        // would still pass if one of the sections were dropped from the prompt.
        assert!(prompt.contains("Reason for call:\nCancel my order"));
        assert!(prompt.contains("Known information about you:\nOrder id: #W0001"));
        assert!(prompt.contains("Task instructions:\nCancel order #W0001"));
    }

    #[test]
    fn load_prompt_from_plain_instructions() {
        let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
        let prompt = scenarios[1].primary_prompt().unwrap();
        assert_eq!(prompt, "Simple plain prompt");
    }

    #[test]
    fn metadata_contains_evaluation_criteria() {
        let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
        let criteria_value = scenarios[0].metadata.get("evaluation_criteria").unwrap();
        let criteria: EvaluationCriteria = serde_json::from_value(criteria_value.clone()).unwrap();
        assert_eq!(criteria.actions.len(), 1);
        assert_eq!(criteria.actions[0].name, "cancel_pending_order");
    }

    #[test]
    fn metadata_roundtrip_preserves_arguments() {
        let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
        let criteria_value = scenarios[0].metadata.get("evaluation_criteria").unwrap();
        let criteria: EvaluationCriteria = serde_json::from_value(criteria_value.clone()).unwrap();
        let arg = criteria.actions[0].arguments.get("order_id").unwrap();
        assert_eq!(arg.as_str(), Some("#W0001"));
    }

    #[test]
    fn airline_loader_name() {
        let loader = Tau2BenchLoader::airline();
        assert_eq!(loader.name(), "tau2-bench-airline");
    }

    #[test]
    fn db_json_path_is_sibling_of_tasks_json() {
        let db = db_json_path(Path::new("/data/retail/tasks.json")).unwrap();
        assert_eq!(db, PathBuf::from("/data/retail/db.json"));
    }

    #[test]
    fn db_json_path_rejects_path_without_parent() {
        // An empty path has no parent component at all.
        assert!(db_json_path(Path::new("")).is_err());
    }
}
284}