zeph_bench/loaders/tau2_bench/
loader.rs1use std::fmt::Write as _;
7use std::io::BufReader;
8use std::path::{Path, PathBuf};
9
10use crate::{
11 error::BenchError,
12 scenario::{DatasetLoader, Scenario},
13};
14
15#[cfg(test)]
16use super::data::EvaluationCriteria;
17use super::data::{Domain, Task, UserInstructions};
18
/// Locations of the data files that belong to a single benchmark domain.
///
/// The struct only stores paths; constructing it does not touch the
/// filesystem.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DataPaths {
    /// Path to the domain's `tasks.json` file.
    pub tasks_json: PathBuf,
    /// Path to the domain's `db.json` file.
    pub db_json: PathBuf,
    /// Path to the domain's `split_tasks.json` file.
    pub split_tasks_json: PathBuf,
}
31
32impl DataPaths {
33 #[must_use]
35 pub fn resolve(root: &Path, domain: Domain) -> Self {
36 let dir = root.join(domain.as_str());
37 Self {
38 tasks_json: dir.join("tasks.json"),
39 db_json: dir.join("db.json"),
40 split_tasks_json: dir.join("split_tasks.json"),
41 }
42 }
43}
44
45pub struct Tau2BenchLoader {
69 pub domain: Domain,
71}
72
73impl Tau2BenchLoader {
74 #[must_use]
76 pub fn retail() -> Self {
77 Self {
78 domain: Domain::Retail,
79 }
80 }
81
82 #[must_use]
84 pub fn airline() -> Self {
85 Self {
86 domain: Domain::Airline,
87 }
88 }
89}
90
91impl DatasetLoader for Tau2BenchLoader {
92 fn name(&self) -> &'static str {
93 match self.domain {
94 Domain::Retail => "tau2-bench-retail",
95 Domain::Airline => "tau2-bench-airline",
96 }
97 }
98
99 fn load(&self, path: &Path) -> Result<Vec<Scenario>, BenchError> {
109 let file = std::fs::File::open(path)
110 .map_err(|e| BenchError::InvalidFormat(format!("open tasks.json: {e}")))?;
111 let tasks: Vec<Task> = serde_json::from_reader(BufReader::new(file))
112 .map_err(|e| BenchError::InvalidFormat(format!("parse tasks.json: {e}")))?;
113
114 let loader_name = self.name();
115 let mut scenarios = Vec::with_capacity(tasks.len());
116
117 for task in tasks {
118 let id = format!("{}_{}", loader_name, task.id);
119 let prompt = build_prompt(&task);
120 let metadata = serde_json::json!({
121 "domain": loader_name,
122 "tau2_task_id": task.id,
123 "evaluation_criteria": task.evaluation_criteria,
124 "user_scenario": task.user_scenario,
125 });
126 scenarios.push(Scenario::single(id, prompt, "", metadata));
127 }
128
129 Ok(scenarios)
130 }
131}
132
133fn build_prompt(task: &Task) -> String {
143 match &task.user_scenario.instructions {
144 UserInstructions::Plain(s) => s.clone(),
145 UserInstructions::Structured(i) => {
146 let mut buf = String::new();
147 writeln!(buf, "You are speaking to a customer support agent.").ok();
148 writeln!(buf, "\nReason for call:\n{}", i.reason_for_call).ok();
149 if let Some(known) = &i.known_info {
150 writeln!(buf, "\nKnown information about you:\n{known}").ok();
151 }
152 writeln!(buf, "\nTask instructions:\n{}", i.task_instructions).ok();
153 buf
154 }
155 }
156}
157
158pub fn db_json_path(tasks_json: &Path) -> Result<PathBuf, BenchError> {
164 tasks_json
165 .parent()
166 .map(|dir| dir.join("db.json"))
167 .ok_or_else(|| {
168 BenchError::InvalidFormat(
169 "tasks.json must have a parent directory containing db.json".into(),
170 )
171 })
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Three tasks: one with structured instructions and an expected action,
    /// one with plain instructions, and one whose reward basis is `DB` only.
    const TASKS_FIXTURE: &str = r##"[
        {
            "id": "0",
            "user_scenario": {
                "instructions": {
                    "domain": "retail",
                    "reason_for_call": "Cancel my order",
                    "task_instructions": "Cancel order #W0001",
                    "known_info": "Order id: #W0001"
                }
            },
            "evaluation_criteria": {
                "actions": [
                    {
                        "action_id": "a1",
                        "requestor": "assistant",
                        "name": "cancel_pending_order",
                        "arguments": {"order_id": "#W0001", "reason": "no_longer_needed"},
                        "compare_args": ["order_id", "reason"]
                    }
                ],
                "reward_basis": ["ACTION"]
            }
        },
        {
            "id": "1",
            "user_scenario": {
                "instructions": "Simple plain prompt"
            },
            "evaluation_criteria": {
                "actions": [],
                "reward_basis": ["ACTION"]
            }
        },
        {
            "id": "2",
            "user_scenario": {
                "instructions": "DB-only task"
            },
            "evaluation_criteria": {
                "actions": [],
                "reward_basis": ["DB"]
            }
        }
    ]"##;

    /// Writes `json` to a temporary `tasks.json` and loads it for `domain`.
    fn load_from_str(json: &str, domain: Domain) -> Vec<Scenario> {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("tasks.json");
        std::fs::write(&path, json).unwrap();
        let loader = Tau2BenchLoader { domain };
        loader.load(&path).unwrap()
    }

    #[test]
    fn load_all_tasks_regardless_of_reward_basis() {
        let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
        assert_eq!(scenarios.len(), 3);
    }

    #[test]
    fn load_builds_correct_ids() {
        let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
        assert_eq!(scenarios[0].id, "tau2-bench-retail_0");
        assert_eq!(scenarios[1].id, "tau2-bench-retail_1");
        assert_eq!(scenarios[2].id, "tau2-bench-retail_2");
    }

    #[test]
    fn load_prompt_from_structured_instructions() {
        let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
        let prompt = scenarios[0].primary_prompt().unwrap();
        // Every section must be present; an `||` here would hide a dropped section.
        assert!(prompt.contains("Cancel my order"));
        assert!(prompt.contains("Cancel order #W0001"));
        assert!(prompt.contains("Order id: #W0001"));
    }

    #[test]
    fn load_prompt_from_plain_instructions() {
        let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
        let prompt = scenarios[1].primary_prompt().unwrap();
        assert_eq!(prompt, "Simple plain prompt");
    }

    #[test]
    fn metadata_contains_evaluation_criteria() {
        let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
        let criteria_value = scenarios[0].metadata.get("evaluation_criteria").unwrap();
        let criteria: EvaluationCriteria = serde_json::from_value(criteria_value.clone()).unwrap();
        assert_eq!(criteria.actions.len(), 1);
        assert_eq!(criteria.actions[0].name, "cancel_pending_order");
    }

    #[test]
    fn metadata_roundtrip_preserves_arguments() {
        let scenarios = load_from_str(TASKS_FIXTURE, Domain::Retail);
        let criteria_value = scenarios[0].metadata.get("evaluation_criteria").unwrap();
        let criteria: EvaluationCriteria = serde_json::from_value(criteria_value.clone()).unwrap();
        let arg = criteria.actions[0].arguments.get("order_id").unwrap();
        assert_eq!(arg.as_str(), Some("#W0001"));
    }

    #[test]
    fn retail_loader_name() {
        let loader = Tau2BenchLoader::retail();
        assert_eq!(loader.name(), "tau2-bench-retail");
    }

    #[test]
    fn airline_loader_name() {
        let loader = Tau2BenchLoader::airline();
        assert_eq!(loader.name(), "tau2-bench-airline");
    }

    #[test]
    fn resolve_places_files_under_domain_directory() {
        let root = Path::new("data");
        let paths = DataPaths::resolve(root, Domain::Airline);
        let dir = root.join(Domain::Airline.as_str());
        assert_eq!(paths.tasks_json, dir.join("tasks.json"));
        assert_eq!(paths.db_json, dir.join("db.json"));
        assert_eq!(paths.split_tasks_json, dir.join("split_tasks.json"));
    }

    #[test]
    fn db_json_path_is_sibling_of_tasks_json() {
        let dir = Path::new("data").join("retail");
        let resolved = db_json_path(&dir.join("tasks.json")).unwrap();
        assert_eq!(resolved, dir.join("db.json"));
    }

    #[test]
    fn db_json_path_rejects_path_without_parent() {
        // `Path::parent` yields `None` for the empty path.
        assert!(db_json_path(Path::new("")).is_err());
    }
}