Skip to main content

zeph_bench/loaders/tau2_bench/
data.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Data model matching the tau2-bench JSON schema.
5//!
//! Every type that models JSON derives `Deserialize` and maps directly onto the
//! upstream `data/tau2/domains/<domain>/tasks.json` format; `Domain` is a local
//! selector and is never deserialized.
8
9use serde::{Deserialize, Serialize};
10
/// Selects which tau2-bench domain a loader or environment is built for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Domain {
    /// Retail customer-service domain.
    Retail,
    /// Airline flight-reservation domain.
    Airline,
}

impl Domain {
    /// Returns the short lowercase name of the domain, as used in file paths
    /// and scenario IDs.
    #[must_use]
    pub fn as_str(self) -> &'static str {
        const RETAIL: &str = "retail";
        const AIRLINE: &str = "airline";

        match self {
            Self::Retail => RETAIL,
            Self::Airline => AIRLINE,
        }
    }
}
31
32/// One task from tau2-bench `tasks.json`.
33///
34/// Unused upstream fields (`description`, `ticket`, `initial_state`, `annotations`)
35/// are collected by `_rest` and silently discarded so we don't fail on schema evolution.
36#[derive(Debug, Clone, Deserialize)]
37pub struct Task {
38    /// Unique task identifier within the domain (e.g. `"0"`, `"retail_1"`).
39    pub id: String,
40    /// User-facing scenario: instructions + optional persona.
41    pub user_scenario: UserScenario,
42    /// Expected tool calls and reward criteria.
43    pub evaluation_criteria: Option<EvaluationCriteria>,
44    /// Forward-compat catch-all for fields not modelled here.
45    #[serde(flatten)]
46    _rest: serde_json::Map<String, serde_json::Value>,
47}
48
/// Wraps the user instructions for a scenario.
///
/// This is the value stored under the `user_scenario` key of each [`Task`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UserScenario {
    /// Instructions given to the user simulator: either a structured object or
    /// a plain string prompt (see [`UserInstructions`]).
    pub instructions: UserInstructions,
    /// Optional persona text; `None` when the key is missing or `null`.
    #[serde(default)]
    pub persona: Option<String>,
}
58
/// Instructions are either a structured object or a plain string.
///
/// The upstream schema uses structured objects for most tasks, but older
/// or synthetic tasks may provide a plain string.
///
/// The enum is `untagged`, so the JSON carries no discriminator: serde tries the
/// variants in declaration order and keeps the first one that parses. A JSON
/// object is therefore matched against [`StructuredUserInstructions`] first, and
/// a bare JSON string falls through to `Plain`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
#[non_exhaustive]
pub enum UserInstructions {
    /// Structured object with individual fields.
    Structured(StructuredUserInstructions),
    /// Legacy or synthetic plain-string prompt.
    Plain(String),
}
72
/// Structured form of the user instructions.
///
/// `domain`, `reason_for_call` and `task_instructions` are required keys; the
/// two `*_info` fields are optional and become `None` when missing or `null`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StructuredUserInstructions {
    /// Domain identifier string (e.g. `"retail"`).
    pub domain: String,
    /// Why the user is calling customer support.
    pub reason_for_call: String,
    /// Step-by-step instructions for the user simulator.
    pub task_instructions: String,
    /// Information the user already knows (injected into prompt).
    #[serde(default)]
    pub known_info: Option<String>,
    /// Information the user deliberately hides (not injected into prompt).
    #[serde(default)]
    pub unknown_info: Option<String>,
}
89
/// Expected actions and reward policy for a task.
///
/// Both lists fall back to an empty `Vec` when their key is absent.
// NOTE(review): the fields are plain `Vec`s, not `Option<Vec<_>>`, so an
// explicit JSON `null` for either key would be rejected — confirm upstream
// never emits `null` here.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct EvaluationCriteria {
    /// Gold tool calls the agent must make (in any order).
    #[serde(default)]
    pub actions: Vec<Action>,
    /// Reward components required for full credit (e.g. `["ACTION"]`, `["DB", "ACTION"]`).
    #[serde(default)]
    pub reward_basis: Vec<String>,
}
100
/// One expected tool call from the upstream `Action` data model.
///
/// Scoring uses `Action.compare_with_tool_call` semantics:
/// - If `compare_args` is `None`, all argument keys are compared.
/// - If `compare_args` is `Some([])`, only the tool name is checked.
/// - If `compare_args` is `Some(keys)`, only those keys are compared.
///
/// Only `action_id` and `name` are required keys; every other field has a
/// serde default and may be omitted from the JSON.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Action {
    /// Unique identifier for this action within the task.
    pub action_id: String,
    /// Who performs this action — `"assistant"` or `"user"`.
    ///
    /// Filled in by `default_requestor` (i.e. `"assistant"`) when the key is absent.
    #[serde(default = "default_requestor")]
    pub requestor: String,
    /// Tool name (must match an available tool exactly).
    pub name: String,
    /// Expected arguments to the tool call; an empty map when the key is absent.
    #[serde(default)]
    pub arguments: serde_json::Map<String, serde_json::Value>,
    /// Argument keys to compare, or `None` for all keys.
    #[serde(default)]
    pub compare_args: Option<Vec<String>>,
    /// Optional human-readable description of the action.
    #[serde(default)]
    pub info: Option<String>,
}
126
/// Serde fallback for `Action::requestor`: an action that omits the field is
/// attributed to the assistant.
fn default_requestor() -> String {
    String::from("assistant")
}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-task retail fixture: task `"0"` has structured instructions and one
    /// gold action, task `"1"` has a plain-string prompt and no actions.
    const RETAIL_FIXTURE: &str = r##"[
      {
        "id": "0",
        "user_scenario": {
          "instructions": {
            "domain": "retail",
            "reason_for_call": "I need to cancel an order",
            "task_instructions": "Cancel order #W1234567",
            "known_info": "Order id: #W1234567"
          },
          "persona": "Impatient customer"
        },
        "evaluation_criteria": {
          "actions": [
            {
              "action_id": "a1",
              "requestor": "assistant",
              "name": "cancel_pending_order",
              "arguments": {"order_id": "#W1234567", "reason": "no_longer_needed"},
              "compare_args": ["order_id", "reason"]
            }
          ],
          "reward_basis": ["ACTION"]
        }
      },
      {
        "id": "1",
        "user_scenario": {
          "instructions": "Plain string instructions for a simple task"
        },
        "evaluation_criteria": {
          "actions": [],
          "reward_basis": ["ACTION"]
        }
      }
    ]"##;

    /// Parses [`RETAIL_FIXTURE`]; panics if the fixture no longer matches the model.
    fn load_fixture() -> Vec<Task> {
        serde_json::from_str(RETAIL_FIXTURE).unwrap()
    }

    #[test]
    fn parse_structured_instructions() {
        let tasks = load_fixture();
        assert_eq!(tasks.len(), 2);

        let first = &tasks[0];
        assert_eq!(first.id, "0");

        let UserInstructions::Structured(structured) = &first.user_scenario.instructions else {
            panic!("expected structured");
        };
        assert_eq!(structured.domain, "retail");
        assert_eq!(structured.reason_for_call, "I need to cancel an order");
        assert!(structured.known_info.is_some());
    }

    #[test]
    fn parse_plain_instructions() {
        let tasks = load_fixture();

        let UserInstructions::Plain(prompt) = &tasks[1].user_scenario.instructions else {
            panic!("expected plain");
        };
        assert!(prompt.contains("Plain string"));
    }

    #[test]
    fn parse_evaluation_criteria() {
        let tasks = load_fixture();
        let gold_actions = &tasks[0].evaluation_criteria.as_ref().unwrap().actions;
        assert_eq!(gold_actions.len(), 1);

        let action = &gold_actions[0];
        assert_eq!(action.name, "cancel_pending_order");

        let expected_keys = vec!["order_id".to_owned(), "reason".to_owned()];
        assert_eq!(action.compare_args, Some(expected_keys));
    }

    #[test]
    fn metadata_roundtrip() {
        let tasks = load_fixture();
        let original = tasks[0].evaluation_criteria.as_ref().unwrap();

        // Serialize to a JSON value and read it straight back.
        let restored: EvaluationCriteria = serde_json::to_value(original)
            .and_then(serde_json::from_value)
            .unwrap();

        assert_eq!(restored.actions.len(), 1);
        assert_eq!(restored.actions[0].name, "cancel_pending_order");
    }

    #[test]
    fn domain_as_str() {
        let expectations = [(Domain::Retail, "retail"), (Domain::Airline, "airline")];
        for (domain, expected) in expectations {
            assert_eq!(domain.as_str(), expected);
        }
    }
}
223}