// zeph_bench/loaders/tau2_bench/data.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Data model matching the tau2-bench JSON schema.
//!
//! The schema types derive `Deserialize` and map directly onto the upstream
//! `data/tau2/domains/<domain>/tasks.json` format. [`Domain`] is the exception:
//! it is a routing selector that is not part of the JSON and derives no serde traits.
8
9use serde::{Deserialize, Serialize};
10
/// tau2-bench domain selector for routing loader and env construction.
///
/// This is a plain selector, not part of the upstream JSON schema, so it derives
/// no serde traits. It derives `Hash` so it can key a `HashMap`/`HashSet`
/// (e.g. per-domain caches or result buckets).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Domain {
    /// Customer-service retail domain.
    Retail,
    /// Flight-reservation airline domain.
    Airline,
}

impl Domain {
    /// Short lowercase identifier used in file paths and scenario IDs.
    ///
    /// `const` so the identifier can also be used to build compile-time constants.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Retail => "retail",
            Self::Airline => "airline",
        }
    }
}
30
31/// One task from tau2-bench `tasks.json`.
32///
33/// Unused upstream fields (`description`, `ticket`, `initial_state`, `annotations`)
34/// are collected by `_rest` and silently discarded so we don't fail on schema evolution.
35#[derive(Debug, Clone, Deserialize)]
36pub struct Task {
37    /// Unique task identifier within the domain (e.g. `"0"`, `"retail_1"`).
38    pub id: String,
39    /// User-facing scenario: instructions + optional persona.
40    pub user_scenario: UserScenario,
41    /// Expected tool calls and reward criteria.
42    pub evaluation_criteria: Option<EvaluationCriteria>,
43    /// Forward-compat catch-all for fields not modelled here.
44    #[serde(flatten)]
45    _rest: serde_json::Map<String, serde_json::Value>,
46}
47
48/// Wraps the user instructions for a scenario.
49#[derive(Debug, Clone, Deserialize, Serialize)]
50pub struct UserScenario {
51    /// Instructions given to the user simulator (or a plain string prompt).
52    pub instructions: UserInstructions,
53    /// Optional persona text.
54    #[serde(default)]
55    pub persona: Option<String>,
56}
57
58/// Instructions are either a structured object or a plain string.
59///
60/// The upstream schema uses structured objects for most tasks, but older
61/// or synthetic tasks may provide a plain string.
62#[derive(Debug, Clone, Deserialize, Serialize)]
63#[serde(untagged)]
64pub enum UserInstructions {
65    /// Structured object with individual fields.
66    Structured(StructuredUserInstructions),
67    /// Legacy or synthetic plain-string prompt.
68    Plain(String),
69}
70
71/// Structured form of the user instructions.
72#[derive(Debug, Clone, Deserialize, Serialize)]
73pub struct StructuredUserInstructions {
74    /// Domain identifier string (e.g. `"retail"`).
75    pub domain: String,
76    /// Why the user is calling customer support.
77    pub reason_for_call: String,
78    /// Step-by-step instructions for the user simulator.
79    pub task_instructions: String,
80    /// Information the user already knows (injected into prompt).
81    #[serde(default)]
82    pub known_info: Option<String>,
83    /// Information the user deliberately hides (not injected into prompt).
84    #[serde(default)]
85    pub unknown_info: Option<String>,
86}
87
88/// Expected actions and reward policy for a task.
89#[derive(Debug, Clone, Deserialize, Serialize)]
90pub struct EvaluationCriteria {
91    /// Gold tool calls the agent must make (in any order).
92    #[serde(default)]
93    pub actions: Vec<Action>,
94    /// Reward components required for full credit (e.g. `["ACTION"]`, `["DB", "ACTION"]`).
95    #[serde(default)]
96    pub reward_basis: Vec<String>,
97}
98
99/// One expected tool call from the upstream `Action` data model.
100///
101/// Scoring uses `Action.compare_with_tool_call` semantics:
102/// - If `compare_args` is `None`, all argument keys are compared.
103/// - If `compare_args` is `Some([])`, only the tool name is checked.
104/// - If `compare_args` is `Some(keys)`, only those keys are compared.
105#[derive(Debug, Clone, Deserialize, Serialize)]
106pub struct Action {
107    /// Unique identifier for this action within the task.
108    pub action_id: String,
109    /// Who performs this action — `"assistant"` or `"user"`.
110    #[serde(default = "default_requestor")]
111    pub requestor: String,
112    /// Tool name (must match an available tool exactly).
113    pub name: String,
114    /// Expected arguments to the tool call.
115    #[serde(default)]
116    pub arguments: serde_json::Map<String, serde_json::Value>,
117    /// Argument keys to compare, or `None` for all keys.
118    #[serde(default)]
119    pub compare_args: Option<Vec<String>>,
120    /// Optional human-readable description of the action.
121    #[serde(default)]
122    pub info: Option<String>,
123}
124
/// Serde fallback for `Action::requestor` when a task omits the key.
fn default_requestor() -> String {
    String::from("assistant")
}
128
#[cfg(test)]
mod tests {
    use super::*;

    const RETAIL_FIXTURE: &str = r##"[
      {
        "id": "0",
        "user_scenario": {
          "instructions": {
            "domain": "retail",
            "reason_for_call": "I need to cancel an order",
            "task_instructions": "Cancel order #W1234567",
            "known_info": "Order id: #W1234567"
          },
          "persona": "Impatient customer"
        },
        "evaluation_criteria": {
          "actions": [
            {
              "action_id": "a1",
              "requestor": "assistant",
              "name": "cancel_pending_order",
              "arguments": {"order_id": "#W1234567", "reason": "no_longer_needed"},
              "compare_args": ["order_id", "reason"]
            }
          ],
          "reward_basis": ["ACTION"]
        }
      },
      {
        "id": "1",
        "user_scenario": {
          "instructions": "Plain string instructions for a simple task"
        },
        "evaluation_criteria": {
          "actions": [],
          "reward_basis": ["ACTION"]
        }
      }
    ]"##;

    /// Parses the shared fixture; a failure here means the fixture itself is broken.
    fn fixture_tasks() -> Vec<Task> {
        serde_json::from_str(RETAIL_FIXTURE).expect("RETAIL_FIXTURE must parse")
    }

    #[test]
    fn parse_structured_instructions() {
        let tasks = fixture_tasks();
        assert_eq!(tasks.len(), 2);
        assert_eq!(tasks[0].id, "0");
        assert_eq!(
            tasks[0].user_scenario.persona.as_deref(),
            Some("Impatient customer")
        );
        let UserInstructions::Structured(s) = &tasks[0].user_scenario.instructions else {
            panic!("expected structured");
        };
        assert_eq!(s.domain, "retail");
        assert_eq!(s.reason_for_call, "I need to cancel an order");
        assert_eq!(s.task_instructions, "Cancel order #W1234567");
        assert_eq!(s.known_info.as_deref(), Some("Order id: #W1234567"));
        // Absent optional keys must deserialize to `None`.
        assert_eq!(s.unknown_info, None);
    }

    #[test]
    fn parse_plain_instructions() {
        let tasks = fixture_tasks();
        let UserInstructions::Plain(s) = &tasks[1].user_scenario.instructions else {
            panic!("expected plain");
        };
        assert!(s.contains("Plain string"));
        assert_eq!(tasks[1].user_scenario.persona, None);
    }

    #[test]
    fn parse_evaluation_criteria() {
        let tasks = fixture_tasks();
        let criteria = tasks[0].evaluation_criteria.as_ref().unwrap();
        assert_eq!(criteria.actions.len(), 1);
        assert_eq!(criteria.reward_basis, vec!["ACTION".to_owned()]);
        let action = &criteria.actions[0];
        assert_eq!(action.name, "cancel_pending_order");
        assert_eq!(action.requestor, "assistant");
        assert_eq!(
            action.arguments.get("order_id"),
            Some(&serde_json::Value::String("#W1234567".to_owned()))
        );
        assert_eq!(
            action.compare_args,
            Some(vec!["order_id".to_owned(), "reason".to_owned()])
        );
    }

    #[test]
    fn metadata_roundtrip() {
        let tasks = fixture_tasks();
        let criteria = tasks[0].evaluation_criteria.as_ref().unwrap();
        let value = serde_json::to_value(criteria).unwrap();
        let back: EvaluationCriteria = serde_json::from_value(value).unwrap();
        assert_eq!(back.actions.len(), 1);
        assert_eq!(back.actions[0].name, "cancel_pending_order");
    }

    #[test]
    fn action_defaults_apply_when_fields_missing() {
        let action: Action =
            serde_json::from_str(r#"{"action_id": "a1", "name": "get_user_details"}"#).unwrap();
        // `requestor` falls back to `default_requestor`; the rest use `Default`.
        assert_eq!(action.requestor, "assistant");
        assert!(action.arguments.is_empty());
        assert_eq!(action.compare_args, None);
        assert_eq!(action.info, None);
    }

    #[test]
    fn unknown_fields_and_missing_criteria_are_tolerated() {
        let task: Task = serde_json::from_str(
            r#"{
              "id": "7",
              "description": {"purpose": null},
              "ticket": "free-form ticket",
              "user_scenario": {"instructions": "hello", "persona": null},
              "initial_state": null
            }"#,
        )
        .unwrap();
        assert_eq!(task.id, "7");
        assert!(task.evaluation_criteria.is_none());
        assert_eq!(task.user_scenario.persona, None);
        // Unmodelled upstream fields land in the flatten catch-all.
        assert!(task._rest.contains_key("ticket"));
        assert!(task._rest.contains_key("description"));
    }

    #[test]
    fn domain_as_str() {
        assert_eq!(Domain::Retail.as_str(), "retail");
        assert_eq!(Domain::Airline.as_str(), "airline");
    }
}