Skip to main content

scouter_types/genai/
scenario.rs

1use crate::error::TypeError;
2use crate::genai::utils::{extract_assertion_tasks_from_pylist, AssertionTasks};
3use crate::genai::{AgentAssertionTask, AssertionTask, LLMJudgeTask, TraceAssertionTask};
4use crate::util::{json_to_pyobject, pyobject_to_json};
5use crate::PyHelperFuncs;
6use potato_head::create_uuid7;
7use pyo3::prelude::*;
8use pyo3::types::{PyDict, PyList};
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12
13fn default_id() -> String {
14    create_uuid7()
15}
16
/// Serde default for `EvalScenario::max_turns`.
///
/// Mirrors the `max_turns = 10` default declared in the Python constructor
/// signature of `EvalScenario::new`; keep the two in sync.
fn default_max_turns() -> usize {
    const DEFAULT_MAX_TURNS: usize = 10;
    DEFAULT_MAX_TURNS
}
20
21fn default_tasks() -> AssertionTasks {
22    AssertionTasks {
23        assertion: vec![],
24        judge: vec![],
25        trace: vec![],
26        agent: vec![],
27    }
28}
29
30/// A single test case in an offline agent evaluation run.
31///
32/// Scenarios drive `EvalScenarios` (the orchestrator). At minimum, supply an
33/// `initial_query`. Everything else is optional.
34///
35/// ## Task attachment
36/// Use `tasks` to attach scenario-level evaluation tasks (any of the four task
37/// types: `AssertionTask`, `LLMJudgeTask`, `AgentAssertionTask`,
38/// `TraceAssertionTask`). These are evaluated against the agent's **final
39/// response** for this specific scenario. Sub-agent evaluation happens
40/// holistically across all scenarios via the profiles already registered on
41/// `ScouterQueue`.
42///
43/// ## Multi-turn scenarios
44/// Populate `predefined_turns` with follow-up queries (executed in order after
45/// `initial_query`). Leave empty for single-turn evaluation.
46///
47/// ## ReAct / simulated-user scenarios
48/// `simulated_user_persona` and `termination_signal` are placeholder fields
49/// for future ReAct support. Setting them has no effect in the current
50/// implementation.
51#[pyclass]
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
53pub struct EvalScenario {
54    #[pyo3(get, set)]
55    #[serde(default = "default_id")]
56    pub id: String,
57
58    #[pyo3(get, set)]
59    pub initial_query: String,
60
61    #[pyo3(get, set)]
62    #[serde(default)]
63    pub predefined_turns: Vec<String>,
64
65    #[pyo3(get, set)]
66    pub simulated_user_persona: Option<String>,
67
68    #[pyo3(get, set)]
69    pub termination_signal: Option<String>,
70
71    #[pyo3(get, set)]
72    #[serde(default = "default_max_turns")]
73    pub max_turns: usize,
74
75    #[pyo3(get, set)]
76    pub expected_outcome: Option<String>,
77
78    // Stored as structured buckets — same internal representation as GenAIEvalProfile.
79    // Not exposed directly as a Python getter; use the typed getters below.
80    #[serde(default = "default_tasks")]
81    pub tasks: AssertionTasks,
82
83    pub metadata: Option<HashMap<String, Value>>,
84}
85
86#[pymethods]
87#[allow(clippy::too_many_arguments)]
88impl EvalScenario {
89    #[new]
90    #[pyo3(signature = (
91        initial_query,
92        tasks = None,
93        id = None,
94        expected_outcome = None,
95        predefined_turns = None,
96        simulated_user_persona = None,
97        termination_signal = None,
98        max_turns = 10,
99        metadata = None,
100    ))]
101    pub fn new(
102        initial_query: String,
103        tasks: Option<&Bound<'_, PyList>>,
104        id: Option<String>,
105        expected_outcome: Option<String>,
106        predefined_turns: Option<Vec<String>>,
107        simulated_user_persona: Option<String>,
108        termination_signal: Option<String>,
109        max_turns: usize,
110        metadata: Option<&Bound<'_, PyDict>>,
111    ) -> Result<Self, TypeError> {
112        let tasks = match tasks {
113            Some(list) => extract_assertion_tasks_from_pylist(list)?,
114            None => AssertionTasks {
115                assertion: vec![],
116                judge: vec![],
117                trace: vec![],
118                agent: vec![],
119            },
120        };
121
122        let metadata = match metadata {
123            Some(dict) => {
124                let mut map = HashMap::new();
125                for (k, v) in dict.iter() {
126                    let key: String = k.extract()?;
127                    let value: Value = pyobject_to_json(&v)?;
128                    map.insert(key, value);
129                }
130                Some(map)
131            }
132            None => None,
133        };
134
135        Ok(Self {
136            id: id.unwrap_or_else(create_uuid7),
137            initial_query,
138            predefined_turns: predefined_turns.unwrap_or_default(),
139            simulated_user_persona,
140            termination_signal,
141            max_turns,
142            expected_outcome,
143            tasks,
144            metadata,
145        })
146    }
147
148    pub fn __str__(&self) -> String {
149        PyHelperFuncs::__str__(self)
150    }
151
152    pub fn model_dump_json(&self) -> String {
153        PyHelperFuncs::__json__(self)
154    }
155
156    pub fn model_dump(&self, py: Python) -> Result<Py<PyDict>, TypeError> {
157        let json_value = serde_json::to_value(self)?;
158        let dict = PyDict::new(py);
159        json_to_pyobject(py, &json_value, &dict)?;
160        Ok(dict.into())
161    }
162
163    #[getter]
164    pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
165        self.tasks.assertion.clone()
166    }
167
168    #[getter]
169    pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
170        self.tasks.judge.clone()
171    }
172
173    #[getter]
174    pub fn trace_assertion_tasks(&self) -> Vec<TraceAssertionTask> {
175        self.tasks.trace.clone()
176    }
177
178    #[getter]
179    pub fn agent_assertion_tasks(&self) -> Vec<AgentAssertionTask> {
180        self.tasks.agent.clone()
181    }
182
183    #[getter]
184    pub fn has_tasks(&self) -> bool {
185        !self.tasks.assertion.is_empty()
186            || !self.tasks.judge.is_empty()
187            || !self.tasks.trace.is_empty()
188            || !self.tasks.agent.is_empty()
189    }
190
191    /// Returns `true` when `predefined_turns` is non-empty (scripted multi-turn scenario).
192    pub fn is_multi_turn(&self) -> bool {
193        !self.predefined_turns.is_empty()
194    }
195
196    /// Returns `true` when a `simulated_user_persona` is set (ReAct/reactive scenario).
197    ///
198    /// Note: ReAct execution is deferred — this field is a placeholder until
199    /// potato_head supports simulated user LLMs.
200    pub fn is_reactive(&self) -> bool {
201        self.simulated_user_persona.is_some()
202    }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Task buckets with nothing in them.
    fn empty_tasks() -> AssertionTasks {
        AssertionTasks {
            assertion: Vec::new(),
            judge: Vec::new(),
            trace: Vec::new(),
            agent: Vec::new(),
        }
    }

    /// Single-turn scenario with no optional fields set and `max_turns` of 10.
    /// Individual tests override fields with struct-update syntax.
    fn base_scenario(id: &str, initial_query: &str) -> EvalScenario {
        EvalScenario {
            id: id.to_string(),
            initial_query: initial_query.to_string(),
            predefined_turns: Vec::new(),
            simulated_user_persona: None,
            termination_signal: None,
            max_turns: 10,
            expected_outcome: None,
            tasks: empty_tasks(),
            metadata: None,
        }
    }

    #[test]
    fn default_values() {
        let scenario = base_scenario("test-id", "What is 2+2?");

        assert_eq!(scenario.max_turns, 10);
        assert!(!scenario.is_multi_turn());
        assert!(!scenario.is_reactive());
        assert!(!scenario.has_tasks());
    }

    #[test]
    fn multi_turn_detection() {
        let scenario = EvalScenario {
            predefined_turns: vec!["Follow-up".to_string()],
            ..base_scenario("t", "Hello")
        };

        assert!(scenario.is_multi_turn());
        assert!(!scenario.is_reactive());
    }

    #[test]
    fn reactive_detection() {
        let scenario = EvalScenario {
            simulated_user_persona: Some("Busy professional".to_string()),
            termination_signal: Some("That's great".to_string()),
            max_turns: 8,
            ..base_scenario("t", "Hello")
        };

        assert!(!scenario.is_multi_turn());
        assert!(scenario.is_reactive());
    }

    #[test]
    fn serialization_roundtrip() {
        let original = EvalScenario {
            predefined_turns: vec!["Make it spicier".to_string()],
            max_turns: 5,
            expected_outcome: Some("A pasta recipe".to_string()),
            ..base_scenario("roundtrip-id", "Make pasta")
        };

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: EvalScenario = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original.id, decoded.id);
        assert_eq!(original.initial_query, decoded.initial_query);
        assert_eq!(original.predefined_turns, decoded.predefined_turns);
        assert_eq!(original.max_turns, decoded.max_turns);
        assert_eq!(original.expected_outcome, decoded.expected_outcome);
    }

    #[test]
    fn default_id_from_serde() {
        // JSON that omits `id` must still deserialize, with an id generated
        // and the other defaulted fields filled in.
        let payload = r#"{"initial_query": "Hello"}"#;
        let scenario: EvalScenario = serde_json::from_str(payload).unwrap();

        assert!(!scenario.id.is_empty());
        assert_eq!(scenario.max_turns, 10);
        assert!(scenario.predefined_turns.is_empty());
    }
}