1use crate::error::TypeError;
2use crate::genai::utils::{extract_assertion_tasks_from_pylist, AssertionTasks};
3use crate::genai::{AgentAssertionTask, AssertionTask, LLMJudgeTask, TraceAssertionTask};
4use crate::util::{json_to_pyobject, pyobject_to_json};
5use crate::PyHelperFuncs;
6use potato_head::create_uuid7;
7use pyo3::prelude::*;
8use pyo3::types::{PyDict, PyList};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12
13fn default_id() -> String {
14 create_uuid7()
15}
16
/// Serde fallback for `EvalScenario::max_turns`.
///
/// Returns the turn limit applied when a serialized scenario omits `max_turns`.
fn default_max_turns() -> usize {
    // Same value the Python constructor uses as its `max_turns` default.
    const FALLBACK_MAX_TURNS: usize = 10;
    FALLBACK_MAX_TURNS
}
20
21fn default_tasks() -> AssertionTasks {
22 AssertionTasks {
23 assertion: vec![],
24 judge: vec![],
25 trace: vec![],
26 agent: vec![],
27 }
28}
29
30#[pyclass]
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
53pub struct EvalScenario {
54 #[pyo3(get, set)]
55 #[serde(default = "default_id")]
56 pub id: String,
57
58 #[pyo3(get, set)]
59 pub initial_query: String,
60
61 #[pyo3(get, set)]
62 #[serde(default)]
63 pub predefined_turns: Vec<String>,
64
65 #[pyo3(get, set)]
66 pub simulated_user_persona: Option<String>,
67
68 #[pyo3(get, set)]
69 pub termination_signal: Option<String>,
70
71 #[pyo3(get, set)]
72 #[serde(default = "default_max_turns")]
73 pub max_turns: usize,
74
75 #[pyo3(get, set)]
76 pub expected_outcome: Option<String>,
77
78 #[serde(default = "default_tasks")]
81 pub tasks: AssertionTasks,
82
83 pub metadata: Option<HashMap<String, Value>>,
84}
85
86#[pymethods]
87#[allow(clippy::too_many_arguments)]
88impl EvalScenario {
89 #[new]
90 #[pyo3(signature = (
91 initial_query,
92 tasks = None,
93 id = None,
94 expected_outcome = None,
95 predefined_turns = None,
96 simulated_user_persona = None,
97 termination_signal = None,
98 max_turns = 10,
99 metadata = None,
100 ))]
101 pub fn new(
102 initial_query: String,
103 tasks: Option<&Bound<'_, PyList>>,
104 id: Option<String>,
105 expected_outcome: Option<String>,
106 predefined_turns: Option<Vec<String>>,
107 simulated_user_persona: Option<String>,
108 termination_signal: Option<String>,
109 max_turns: usize,
110 metadata: Option<&Bound<'_, PyDict>>,
111 ) -> Result<Self, TypeError> {
112 let tasks = match tasks {
113 Some(list) => extract_assertion_tasks_from_pylist(list)?,
114 None => AssertionTasks {
115 assertion: vec![],
116 judge: vec![],
117 trace: vec![],
118 agent: vec![],
119 },
120 };
121
122 let metadata = match metadata {
123 Some(dict) => {
124 let mut map = HashMap::new();
125 for (k, v) in dict.iter() {
126 let key: String = k.extract()?;
127 let value: Value = pyobject_to_json(&v)?;
128 map.insert(key, value);
129 }
130 Some(map)
131 }
132 None => None,
133 };
134
135 Ok(Self {
136 id: id.unwrap_or_else(create_uuid7),
137 initial_query,
138 predefined_turns: predefined_turns.unwrap_or_default(),
139 simulated_user_persona,
140 termination_signal,
141 max_turns,
142 expected_outcome,
143 tasks,
144 metadata,
145 })
146 }
147
148 pub fn __str__(&self) -> String {
149 PyHelperFuncs::__str__(self)
150 }
151
152 pub fn model_dump_json(&self) -> String {
153 PyHelperFuncs::__json__(self)
154 }
155
156 pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, TypeError> {
157 let json_value = serde_json::to_value(self)?;
158 let dict = PyDict::new(py);
159 json_to_pyobject(py, &json_value, &dict)?;
160 Ok(dict.into())
161 }
162
163 #[getter]
164 pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
165 self.tasks.assertion.clone()
166 }
167
168 #[getter]
169 pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
170 self.tasks.judge.clone()
171 }
172
173 #[getter]
174 pub fn trace_assertion_tasks(&self) -> Vec<TraceAssertionTask> {
175 self.tasks.trace.clone()
176 }
177
178 #[getter]
179 pub fn agent_assertion_tasks(&self) -> Vec<AgentAssertionTask> {
180 self.tasks.agent.clone()
181 }
182
183 #[getter]
184 pub fn has_tasks(&self) -> bool {
185 !self.tasks.assertion.is_empty()
186 || !self.tasks.judge.is_empty()
187 || !self.tasks.trace.is_empty()
188 || !self.tasks.agent.is_empty()
189 }
190
191 pub fn is_multi_turn(&self) -> bool {
193 !self.predefined_turns.is_empty()
194 }
195
196 pub fn is_reactive(&self) -> bool {
201 self.simulated_user_persona.is_some()
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Task set in which every category is empty.
    fn empty_tasks() -> AssertionTasks {
        AssertionTasks {
            assertion: Vec::new(),
            judge: Vec::new(),
            trace: Vec::new(),
            agent: Vec::new(),
        }
    }

    /// Baseline scenario: no scripted turns, no persona, no tasks, `max_turns` of 10.
    /// Individual tests override the fields they care about via struct-update syntax.
    fn baseline(id: &str, initial_query: &str) -> EvalScenario {
        EvalScenario {
            id: id.to_string(),
            initial_query: initial_query.to_string(),
            predefined_turns: Vec::new(),
            simulated_user_persona: None,
            termination_signal: None,
            max_turns: 10,
            expected_outcome: None,
            tasks: empty_tasks(),
            metadata: None,
        }
    }

    #[test]
    fn default_values() {
        let scenario = baseline("test-id", "What is 2+2?");

        assert_eq!(scenario.max_turns, 10);
        assert!(!scenario.is_multi_turn());
        assert!(!scenario.is_reactive());
        assert!(!scenario.has_tasks());
    }

    #[test]
    fn multi_turn_detection() {
        let scenario = EvalScenario {
            predefined_turns: vec!["Follow-up".to_string()],
            ..baseline("t", "Hello")
        };

        assert!(scenario.is_multi_turn());
        assert!(!scenario.is_reactive());
    }

    #[test]
    fn reactive_detection() {
        let scenario = EvalScenario {
            simulated_user_persona: Some("Busy professional".to_string()),
            termination_signal: Some("That's great".to_string()),
            max_turns: 8,
            ..baseline("t", "Hello")
        };

        assert!(!scenario.is_multi_turn());
        assert!(scenario.is_reactive());
    }

    #[test]
    fn serialization_roundtrip() {
        let original = EvalScenario {
            predefined_turns: vec!["Make it spicier".to_string()],
            max_turns: 5,
            expected_outcome: Some("A pasta recipe".to_string()),
            ..baseline("roundtrip-id", "Make pasta")
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: EvalScenario = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original.id, decoded.id);
        assert_eq!(original.initial_query, decoded.initial_query);
        assert_eq!(original.predefined_turns, decoded.predefined_turns);
        assert_eq!(original.max_turns, decoded.max_turns);
        assert_eq!(original.expected_outcome, decoded.expected_outcome);
    }

    #[test]
    fn default_id_from_serde() {
        // Only the required field is present; everything else must come from serde defaults.
        let payload = r#"{"initial_query": "Hello"}"#;
        let scenario: EvalScenario = serde_json::from_str(payload).unwrap();

        assert!(!scenario.id.is_empty());
        assert_eq!(scenario.max_turns, 10);
        assert!(scenario.predefined_turns.is_empty());
    }
}