Skip to main content

scouter_types/genai/utils.rs

1use crate::error::TypeError;
2use crate::genai::{
3    AgentAssertionTask, AssertionTask, LLMJudgeTask, TaskConfig, TasksFile, TraceAssertionTask,
4};
5use pyo3::prelude::*;
6use pyo3::types::PyList;
7use serde::{Deserialize, Serialize};
8use std::collections::BTreeSet;
9
10#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
11pub struct AssertionTasks {
12    pub assertion: Vec<AssertionTask>,
13    pub judge: Vec<LLMJudgeTask>,
14    pub trace: Vec<TraceAssertionTask>,
15    pub agent: Vec<AgentAssertionTask>,
16}
17
18impl AssertionTasks {
19    pub fn collect_non_judge_task_ids(&self) -> BTreeSet<String> {
20        self.assertion
21            .iter()
22            .map(|t| t.id.clone())
23            .chain(self.trace.iter().map(|t| t.id.clone()))
24            .chain(self.agent.iter().map(|t| t.id.clone()))
25            .collect()
26    }
27
28    pub fn collect_all_task_ids(&self) -> Result<BTreeSet<String>, TypeError> {
29        let mut task_ids = BTreeSet::new();
30
31        for task in &self.assertion {
32            task_ids.insert(task.id.clone());
33        }
34        for task in &self.judge {
35            task_ids.insert(task.id.clone());
36        }
37        for task in &self.trace {
38            task_ids.insert(task.id.clone());
39        }
40        for task in &self.agent {
41            task_ids.insert(task.id.clone());
42        }
43
44        let total_tasks =
45            self.assertion.len() + self.judge.len() + self.trace.len() + self.agent.len();
46        if task_ids.len() != total_tasks {
47            return Err(TypeError::DuplicateTaskIds);
48        }
49
50        Ok(task_ids)
51    }
52
53    pub fn from_tasks_file(tasks: TasksFile) -> Self {
54        let mut assertion = Vec::new();
55        let mut judge = Vec::new();
56        let mut trace = Vec::new();
57        let mut agent = Vec::new();
58
59        for task in tasks.tasks {
60            match task {
61                TaskConfig::Assertion(t) => assertion.push(t),
62                TaskConfig::LLMJudge(t) => judge.push(*t),
63                TaskConfig::TraceAssertion(t) => trace.push(t),
64                TaskConfig::AgentAssertion(t) => agent.push(t),
65            }
66        }
67
68        AssertionTasks {
69            assertion,
70            judge,
71            trace,
72            agent,
73        }
74    }
75}
76
77/// Helper function to extract AssertionTask and LLMJudgeTask from a PyList
78pub fn extract_assertion_tasks_from_pylist(
79    list: &Bound<'_, PyList>,
80) -> Result<AssertionTasks, TypeError> {
81    let mut assertion_tasks = Vec::new();
82    let mut llm_judge_tasks = Vec::new();
83    let mut trace_tasks = Vec::new();
84    let mut agent_tasks = Vec::new();
85
86    for item in list.iter() {
87        if item.is_instance_of::<AssertionTask>() {
88            let task = item.extract::<AssertionTask>()?;
89            assertion_tasks.push(task);
90        } else if item.is_instance_of::<LLMJudgeTask>() {
91            let task = item.extract::<LLMJudgeTask>()?;
92            llm_judge_tasks.push(task);
93        } else if item.is_instance_of::<TraceAssertionTask>() {
94            let task = item.extract::<TraceAssertionTask>()?;
95            trace_tasks.push(task);
96        } else if item.is_instance_of::<AgentAssertionTask>() {
97            let task = item.extract::<AgentAssertionTask>()?;
98            agent_tasks.push(task);
99        } else {
100            return Err(TypeError::InvalidAssertionTaskType);
101        }
102    }
103    Ok(AssertionTasks {
104        assertion: assertion_tasks,
105        judge: llm_judge_tasks,
106        trace: trace_tasks,
107        agent: agent_tasks,
108    })
109}