zeph_bench/loaders/tau2_bench/
data.rs1use serde::{Deserialize, Serialize};
10
/// Benchmark domain selector.
///
/// A small `Copy` enum with one variant per supported domain; use
/// [`Domain::as_str`] to obtain the lowercase label of a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Domain {
    /// The domain labelled `"retail"`.
    Retail,
    /// The domain labelled `"airline"`.
    Airline,
}

impl Domain {
    /// Returns the lowercase, static string label of this domain
    /// (`"retail"` or `"airline"`).
    ///
    /// This is a `const fn`, so the label can also be computed in `const`
    /// and `static` initialisers.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        // Exhaustive on purpose: adding a variant must fail to compile here
        // until a label is chosen for it.
        match self {
            Self::Retail => "retail",
            Self::Airline => "airline",
        }
    }
}
30
31#[derive(Debug, Clone, Deserialize)]
36pub struct Task {
37 pub id: String,
39 pub user_scenario: UserScenario,
41 pub evaluation_criteria: Option<EvaluationCriteria>,
43 #[serde(flatten)]
45 _rest: serde_json::Map<String, serde_json::Value>,
46}
47
48#[derive(Debug, Clone, Deserialize, Serialize)]
50pub struct UserScenario {
51 pub instructions: UserInstructions,
53 #[serde(default)]
55 pub persona: Option<String>,
56}
57
58#[derive(Debug, Clone, Deserialize, Serialize)]
63#[serde(untagged)]
64pub enum UserInstructions {
65 Structured(StructuredUserInstructions),
67 Plain(String),
69}
70
71#[derive(Debug, Clone, Deserialize, Serialize)]
73pub struct StructuredUserInstructions {
74 pub domain: String,
76 pub reason_for_call: String,
78 pub task_instructions: String,
80 #[serde(default)]
82 pub known_info: Option<String>,
83 #[serde(default)]
85 pub unknown_info: Option<String>,
86}
87
88#[derive(Debug, Clone, Deserialize, Serialize)]
90pub struct EvaluationCriteria {
91 #[serde(default)]
93 pub actions: Vec<Action>,
94 #[serde(default)]
96 pub reward_basis: Vec<String>,
97}
98
99#[derive(Debug, Clone, Deserialize, Serialize)]
106pub struct Action {
107 pub action_id: String,
109 #[serde(default = "default_requestor")]
111 pub requestor: String,
112 pub name: String,
114 #[serde(default)]
116 pub arguments: serde_json::Map<String, serde_json::Value>,
117 #[serde(default)]
119 pub compare_args: Option<Vec<String>>,
120 #[serde(default)]
122 pub info: Option<String>,
123}
124
/// Serde default for [`Action::requestor`]: produces `"assistant"` when the
/// `requestor` key is missing from the input.
fn default_requestor() -> String {
    String::from("assistant")
}
128
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-task fixture: task `"0"` uses structured instructions with one
    /// expected action, task `"1"` uses plain-string instructions and no actions.
    const RETAIL_FIXTURE: &str = r##"[
        {
            "id": "0",
            "user_scenario": {
                "instructions": {
                    "domain": "retail",
                    "reason_for_call": "I need to cancel an order",
                    "task_instructions": "Cancel order #W1234567",
                    "known_info": "Order id: #W1234567"
                },
                "persona": "Impatient customer"
            },
            "evaluation_criteria": {
                "actions": [
                    {
                        "action_id": "a1",
                        "requestor": "assistant",
                        "name": "cancel_pending_order",
                        "arguments": {"order_id": "#W1234567", "reason": "no_longer_needed"},
                        "compare_args": ["order_id", "reason"]
                    }
                ],
                "reward_basis": ["ACTION"]
            }
        },
        {
            "id": "1",
            "user_scenario": {
                "instructions": "Plain string instructions for a simple task"
            },
            "evaluation_criteria": {
                "actions": [],
                "reward_basis": ["ACTION"]
            }
        }
    ]"##;

    /// Parses [`RETAIL_FIXTURE`] once per test, with a message that points at
    /// the fixture if it ever stops being valid.
    fn fixture_tasks() -> Vec<Task> {
        serde_json::from_str(RETAIL_FIXTURE).expect("RETAIL_FIXTURE must be valid task JSON")
    }

    #[test]
    fn parse_structured_instructions() {
        let tasks = fixture_tasks();
        assert_eq!(tasks.len(), 2);
        assert_eq!(tasks[0].id, "0");
        assert_eq!(
            tasks[0].user_scenario.persona.as_deref(),
            Some("Impatient customer")
        );
        match &tasks[0].user_scenario.instructions {
            UserInstructions::Structured(s) => {
                assert_eq!(s.domain, "retail");
                assert_eq!(s.reason_for_call, "I need to cancel an order");
                assert_eq!(s.task_instructions, "Cancel order #W1234567");
                // Pin the value, not just its presence.
                assert_eq!(s.known_info.as_deref(), Some("Order id: #W1234567"));
                // Absent optional key must come back as `None`.
                assert_eq!(s.unknown_info, None);
            }
            UserInstructions::Plain(_) => panic!("expected structured"),
        }
    }

    #[test]
    fn parse_plain_instructions() {
        let tasks = fixture_tasks();
        assert_eq!(tasks[1].id, "1");
        assert!(tasks[1].user_scenario.persona.is_none());
        match &tasks[1].user_scenario.instructions {
            UserInstructions::Plain(s) => assert!(s.contains("Plain string")),
            UserInstructions::Structured(_) => panic!("expected plain"),
        }
    }

    #[test]
    fn parse_evaluation_criteria() {
        let tasks = fixture_tasks();
        let criteria = tasks[0].evaluation_criteria.as_ref().unwrap();
        assert_eq!(criteria.actions.len(), 1);
        assert_eq!(criteria.reward_basis, vec!["ACTION".to_owned()]);

        let action = &criteria.actions[0];
        assert_eq!(action.action_id, "a1");
        assert_eq!(action.requestor, "assistant");
        assert_eq!(action.name, "cancel_pending_order");
        assert_eq!(
            action.arguments.get("order_id").and_then(|v| v.as_str()),
            Some("#W1234567")
        );
        assert_eq!(
            action.compare_args,
            Some(vec!["order_id".to_owned(), "reason".to_owned()])
        );
    }

    #[test]
    fn metadata_roundtrip() {
        let tasks = fixture_tasks();
        let criteria = tasks[0].evaluation_criteria.as_ref().unwrap();
        let value = serde_json::to_value(criteria).unwrap();
        let back: EvaluationCriteria = serde_json::from_value(value).unwrap();
        assert_eq!(back.actions.len(), 1);
        assert_eq!(back.actions[0].name, "cancel_pending_order");
        assert_eq!(back.reward_basis, criteria.reward_basis);
    }

    #[test]
    fn action_defaults_apply_when_fields_missing() {
        // Only the two required keys are present; everything else must default.
        let action: Action =
            serde_json::from_str(r#"{"action_id": "a1", "name": "noop"}"#).unwrap();
        assert_eq!(action.requestor, "assistant");
        assert!(action.arguments.is_empty());
        assert_eq!(action.compare_args, None);
        assert_eq!(action.info, None);
    }

    #[test]
    fn evaluation_criteria_defaults_to_empty_lists() {
        let criteria: EvaluationCriteria = serde_json::from_str("{}").unwrap();
        assert!(criteria.actions.is_empty());
        assert!(criteria.reward_basis.is_empty());
    }

    #[test]
    fn task_tolerates_missing_and_unknown_fields() {
        // `description` is not a modelled field and `evaluation_criteria` is absent.
        let task: Task = serde_json::from_str(
            r#"{
                "id": "7",
                "description": {"purpose": "x"},
                "user_scenario": {"instructions": "hi"}
            }"#,
        )
        .unwrap();
        assert_eq!(task.id, "7");
        assert!(task.evaluation_criteria.is_none());
        assert!(task.user_scenario.persona.is_none());
        match &task.user_scenario.instructions {
            UserInstructions::Plain(s) => assert_eq!(s, "hi"),
            UserInstructions::Structured(_) => panic!("expected plain"),
        }
    }

    #[test]
    fn domain_as_str() {
        assert_eq!(Domain::Retail.as_str(), "retail");
        assert_eq!(Domain::Airline.as_str(), "airline");
    }
}