scouter_types/genai/
utils.rs1use crate::error::TypeError;
2use crate::genai::{AssertionTask, LLMJudgeTask, TaskConfig, TasksFile, TraceAssertionTask};
3use pyo3::prelude::*;
4use pyo3::types::PyList;
5use serde::{Deserialize, Serialize};
6use std::collections::BTreeSet;
7
8#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
9pub struct AssertionTasks {
10 pub assertion: Vec<AssertionTask>,
11 pub judge: Vec<LLMJudgeTask>,
12 pub trace: Vec<TraceAssertionTask>,
13}
14
15impl AssertionTasks {
16 pub fn collect_non_judge_task_ids(&self) -> BTreeSet<String> {
17 self.assertion
18 .iter()
19 .map(|t| t.id.clone())
20 .chain(self.trace.iter().map(|t| t.id.clone()))
21 .collect()
22 }
23
24 pub fn collect_all_task_ids(&self) -> Result<BTreeSet<String>, TypeError> {
25 let mut task_ids = BTreeSet::new();
26
27 for task in &self.assertion {
28 task_ids.insert(task.id.clone());
29 }
30 for task in &self.judge {
31 task_ids.insert(task.id.clone());
32 }
33 for task in &self.trace {
34 task_ids.insert(task.id.clone());
35 }
36
37 let total_tasks = self.assertion.len() + self.judge.len() + self.trace.len();
38 if task_ids.len() != total_tasks {
39 return Err(TypeError::DuplicateTaskIds);
40 }
41
42 Ok(task_ids)
43 }
44
45 pub fn from_tasks_file(tasks: TasksFile) -> Self {
47 let mut assertion = Vec::new();
48 let mut judge = Vec::new();
49 let mut trace = Vec::new();
50
51 for task in tasks.tasks {
52 match task {
53 TaskConfig::Assertion(t) => assertion.push(t),
54 TaskConfig::LLMJudge(t) => judge.push(*t),
55 TaskConfig::TraceAssertion(t) => trace.push(t),
56 }
57 }
58
59 AssertionTasks {
60 assertion,
61 judge,
62 trace,
63 }
64 }
65}
66
67pub fn extract_assertion_tasks_from_pylist(
69 list: &Bound<'_, PyList>,
70) -> Result<AssertionTasks, TypeError> {
71 let mut assertion_tasks = Vec::new();
72 let mut llm_judge_tasks = Vec::new();
73 let mut trace_tasks = Vec::new();
74
75 for item in list.iter() {
76 if item.is_instance_of::<AssertionTask>() {
77 let task = item.extract::<AssertionTask>()?;
78 assertion_tasks.push(task);
79 } else if item.is_instance_of::<LLMJudgeTask>() {
80 let task = item.extract::<LLMJudgeTask>()?;
81 llm_judge_tasks.push(task);
82 } else if item.is_instance_of::<TraceAssertionTask>() {
83 let task = item.extract::<TraceAssertionTask>()?;
84 trace_tasks.push(task);
85 } else {
86 return Err(TypeError::InvalidAssertionTaskType);
87 }
88 }
89 Ok(AssertionTasks {
90 assertion: assertion_tasks,
91 judge: llm_judge_tasks,
92 trace: trace_tasks,
93 })
94}