scouter_types/genai/
utils.rs1use crate::error::TypeError;
2use crate::genai::{
3 AgentAssertionTask, AssertionTask, LLMJudgeTask, TaskConfig, TasksFile, TraceAssertionTask,
4};
5use pyo3::prelude::*;
6use pyo3::types::PyList;
7use serde::{Deserialize, Serialize};
8use std::collections::BTreeSet;
9
10#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
11pub struct AssertionTasks {
12 pub assertion: Vec<AssertionTask>,
13 pub judge: Vec<LLMJudgeTask>,
14 pub trace: Vec<TraceAssertionTask>,
15 pub agent: Vec<AgentAssertionTask>,
16}
17
18impl AssertionTasks {
19 pub fn collect_non_judge_task_ids(&self) -> BTreeSet<String> {
20 self.assertion
21 .iter()
22 .map(|t| t.id.clone())
23 .chain(self.trace.iter().map(|t| t.id.clone()))
24 .chain(self.agent.iter().map(|t| t.id.clone()))
25 .collect()
26 }
27
28 pub fn collect_all_task_ids(&self) -> Result<BTreeSet<String>, TypeError> {
29 let mut task_ids = BTreeSet::new();
30
31 for task in &self.assertion {
32 task_ids.insert(task.id.clone());
33 }
34 for task in &self.judge {
35 task_ids.insert(task.id.clone());
36 }
37 for task in &self.trace {
38 task_ids.insert(task.id.clone());
39 }
40 for task in &self.agent {
41 task_ids.insert(task.id.clone());
42 }
43
44 let total_tasks =
45 self.assertion.len() + self.judge.len() + self.trace.len() + self.agent.len();
46 if task_ids.len() != total_tasks {
47 return Err(TypeError::DuplicateTaskIds);
48 }
49
50 Ok(task_ids)
51 }
52
53 pub fn from_tasks_file(tasks: TasksFile) -> Self {
54 let mut assertion = Vec::new();
55 let mut judge = Vec::new();
56 let mut trace = Vec::new();
57 let mut agent = Vec::new();
58
59 for task in tasks.tasks {
60 match task {
61 TaskConfig::Assertion(t) => assertion.push(t),
62 TaskConfig::LLMJudge(t) => judge.push(*t),
63 TaskConfig::TraceAssertion(t) => trace.push(t),
64 TaskConfig::AgentAssertion(t) => agent.push(t),
65 }
66 }
67
68 AssertionTasks {
69 assertion,
70 judge,
71 trace,
72 agent,
73 }
74 }
75}
76
77pub fn extract_assertion_tasks_from_pylist(
79 list: &Bound<'_, PyList>,
80) -> Result<AssertionTasks, TypeError> {
81 let mut assertion_tasks = Vec::new();
82 let mut llm_judge_tasks = Vec::new();
83 let mut trace_tasks = Vec::new();
84 let mut agent_tasks = Vec::new();
85
86 for item in list.iter() {
87 if item.is_instance_of::<AssertionTask>() {
88 let task = item.extract::<AssertionTask>()?;
89 assertion_tasks.push(task);
90 } else if item.is_instance_of::<LLMJudgeTask>() {
91 let task = item.extract::<LLMJudgeTask>()?;
92 llm_judge_tasks.push(task);
93 } else if item.is_instance_of::<TraceAssertionTask>() {
94 let task = item.extract::<TraceAssertionTask>()?;
95 trace_tasks.push(task);
96 } else if item.is_instance_of::<AgentAssertionTask>() {
97 let task = item.extract::<AgentAssertionTask>()?;
98 agent_tasks.push(task);
99 } else {
100 return Err(TypeError::InvalidAssertionTaskType);
101 }
102 }
103 Ok(AssertionTasks {
104 assertion: assertion_tasks,
105 judge: llm_judge_tasks,
106 trace: trace_tasks,
107 agent: agent_tasks,
108 })
109}