zeph_bench/loaders/tau2_bench/
data.rs1use serde::{Deserialize, Serialize};
10
/// Benchmark domain that a set of tasks belongs to.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum Domain {
    /// The retail domain.
    Retail,
    /// The airline domain.
    Airline,
}
20
21impl Domain {
22 #[must_use]
24 pub fn as_str(self) -> &'static str {
25 match self {
26 Self::Retail => "retail",
27 Self::Airline => "airline",
28 }
29 }
30}
31
32#[derive(Debug, Clone, Deserialize)]
37pub struct Task {
38 pub id: String,
40 pub user_scenario: UserScenario,
42 pub evaluation_criteria: Option<EvaluationCriteria>,
44 #[serde(flatten)]
46 _rest: serde_json::Map<String, serde_json::Value>,
47}
48
49#[derive(Debug, Clone, Deserialize, Serialize)]
51pub struct UserScenario {
52 pub instructions: UserInstructions,
54 #[serde(default)]
56 pub persona: Option<String>,
57}
58
59#[derive(Debug, Clone, Deserialize, Serialize)]
64#[serde(untagged)]
65#[non_exhaustive]
66pub enum UserInstructions {
67 Structured(StructuredUserInstructions),
69 Plain(String),
71}
72
73#[derive(Debug, Clone, Deserialize, Serialize)]
75pub struct StructuredUserInstructions {
76 pub domain: String,
78 pub reason_for_call: String,
80 pub task_instructions: String,
82 #[serde(default)]
84 pub known_info: Option<String>,
85 #[serde(default)]
87 pub unknown_info: Option<String>,
88}
89
90#[derive(Debug, Clone, Deserialize, Serialize)]
92pub struct EvaluationCriteria {
93 #[serde(default)]
95 pub actions: Vec<Action>,
96 #[serde(default)]
98 pub reward_basis: Vec<String>,
99}
100
101#[derive(Debug, Clone, Deserialize, Serialize)]
108pub struct Action {
109 pub action_id: String,
111 #[serde(default = "default_requestor")]
113 pub requestor: String,
114 pub name: String,
116 #[serde(default)]
118 pub arguments: serde_json::Map<String, serde_json::Value>,
119 #[serde(default)]
121 pub compare_args: Option<Vec<String>>,
122 #[serde(default)]
124 pub info: Option<String>,
125}
126
/// Serde default for the `requestor` field of `Action`: `"assistant"`.
///
/// The function is referenced by name from the
/// `#[serde(default = "default_requestor")]` attribute, so renaming it
/// requires updating that string as well.
fn default_requestor() -> String {
    String::from("assistant")
}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-task fixture: task `"0"` has structured instructions and one
    /// listed action; task `"1"` has a plain instruction string.
    ///
    /// The `r##"…"##` delimiter is required because the JSON contains the
    /// character sequence `"#`, which would terminate an `r#"…"#` literal.
    const RETAIL_FIXTURE: &str = r##"[
        {
            "id": "0",
            "user_scenario": {
                "instructions": {
                    "domain": "retail",
                    "reason_for_call": "I need to cancel an order",
                    "task_instructions": "Cancel order #W1234567",
                    "known_info": "Order id: #W1234567"
                },
                "persona": "Impatient customer"
            },
            "evaluation_criteria": {
                "actions": [
                    {
                        "action_id": "a1",
                        "requestor": "assistant",
                        "name": "cancel_pending_order",
                        "arguments": {"order_id": "#W1234567", "reason": "no_longer_needed"},
                        "compare_args": ["order_id", "reason"]
                    }
                ],
                "reward_basis": ["ACTION"]
            }
        },
        {
            "id": "1",
            "user_scenario": {
                "instructions": "Plain string instructions for a simple task"
            },
            "evaluation_criteria": {
                "actions": [],
                "reward_basis": ["ACTION"]
            }
        }
    ]"##;

    /// Parses the shared fixture; a parse failure is a bug in the fixture.
    fn load_fixture() -> Vec<Task> {
        serde_json::from_str(RETAIL_FIXTURE).expect("fixture must be valid task JSON")
    }

    #[test]
    fn parse_structured_instructions() {
        let tasks = load_fixture();
        assert_eq!(tasks.len(), 2);
        assert_eq!(tasks[0].id, "0");
        assert_eq!(
            tasks[0].user_scenario.persona.as_deref(),
            Some("Impatient customer")
        );
        match &tasks[0].user_scenario.instructions {
            UserInstructions::Structured(s) => {
                assert_eq!(s.domain, "retail");
                assert_eq!(s.reason_for_call, "I need to cancel an order");
                assert_eq!(s.task_instructions, "Cancel order #W1234567");
                // Pin the exact value rather than only checking presence.
                assert_eq!(s.known_info.as_deref(), Some("Order id: #W1234567"));
                assert!(s.unknown_info.is_none());
            }
            UserInstructions::Plain(_) => panic!("expected structured"),
        }
    }

    #[test]
    fn parse_plain_instructions() {
        let tasks = load_fixture();
        assert!(tasks[1].user_scenario.persona.is_none());
        match &tasks[1].user_scenario.instructions {
            UserInstructions::Plain(s) => assert!(s.contains("Plain string")),
            UserInstructions::Structured(_) => panic!("expected plain"),
        }
    }

    #[test]
    fn parse_evaluation_criteria() {
        let tasks = load_fixture();
        let criteria = tasks[0].evaluation_criteria.as_ref().unwrap();
        assert_eq!(criteria.actions.len(), 1);
        assert_eq!(criteria.actions[0].name, "cancel_pending_order");
        assert_eq!(
            criteria.actions[0].compare_args,
            Some(vec!["order_id".to_owned(), "reason".to_owned()])
        );
    }

    #[test]
    fn metadata_roundtrip() {
        let tasks = load_fixture();
        let criteria = tasks[0].evaluation_criteria.as_ref().unwrap();
        let value = serde_json::to_value(criteria).unwrap();
        let back: EvaluationCriteria = serde_json::from_value(value).unwrap();
        assert_eq!(back.actions.len(), 1);
        assert_eq!(back.actions[0].name, "cancel_pending_order");
        assert_eq!(back.reward_basis, ["ACTION"]);
    }

    #[test]
    fn action_defaults_apply_when_keys_are_absent() {
        let action: Action =
            serde_json::from_str(r#"{"action_id": "a2", "name": "get_order_details"}"#).unwrap();
        assert_eq!(action.requestor, "assistant");
        assert!(action.arguments.is_empty());
        assert!(action.compare_args.is_none());
        assert!(action.info.is_none());
    }

    #[test]
    fn absent_optional_keys_and_unknown_keys() {
        let task: Task = serde_json::from_str(
            r#"{"id": "2", "user_scenario": {"instructions": "hi"}, "description": {"purpose": "x"}}"#,
        )
        .unwrap();
        assert!(task.evaluation_criteria.is_none());
        assert!(task.user_scenario.persona.is_none());
        // Unknown keys land in the flatten catch-all; known keys do not.
        assert!(task._rest.contains_key("description"));
        assert!(!task._rest.contains_key("id"));
    }

    #[test]
    fn domain_as_str() {
        assert_eq!(Domain::Retail.as_str(), "retail");
        assert_eq!(Domain::Airline.as_str(), "airline");
    }
}