Skip to main content

scouter_types/genai/
utils.rs

1use crate::error::TypeError;
2use crate::genai::{AssertionTask, LLMJudgeTask, TaskConfig, TasksFile, TraceAssertionTask};
3use pyo3::prelude::*;
4use pyo3::types::PyList;
5use serde::{Deserialize, Serialize};
6use std::collections::BTreeSet;
7
8#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
9pub struct AssertionTasks {
10    pub assertion: Vec<AssertionTask>,
11    pub judge: Vec<LLMJudgeTask>,
12    pub trace: Vec<TraceAssertionTask>,
13}
14
15impl AssertionTasks {
16    pub fn collect_non_judge_task_ids(&self) -> BTreeSet<String> {
17        self.assertion
18            .iter()
19            .map(|t| t.id.clone())
20            .chain(self.trace.iter().map(|t| t.id.clone()))
21            .collect()
22    }
23
24    pub fn collect_all_task_ids(&self) -> Result<BTreeSet<String>, TypeError> {
25        let mut task_ids = BTreeSet::new();
26
27        for task in &self.assertion {
28            task_ids.insert(task.id.clone());
29        }
30        for task in &self.judge {
31            task_ids.insert(task.id.clone());
32        }
33        for task in &self.trace {
34            task_ids.insert(task.id.clone());
35        }
36
37        let total_tasks = self.assertion.len() + self.judge.len() + self.trace.len();
38        if task_ids.len() != total_tasks {
39            return Err(TypeError::DuplicateTaskIds);
40        }
41
42        Ok(task_ids)
43    }
44
45    /// this is meant to consume the tasks from TasksFile and convert them into AssertionTasks, separating them into assertion, judge, and trace tasks. It also validates that there are no duplicate task IDs across all tasks.
46    pub fn from_tasks_file(tasks: TasksFile) -> Self {
47        let mut assertion = Vec::new();
48        let mut judge = Vec::new();
49        let mut trace = Vec::new();
50
51        for task in tasks.tasks {
52            match task {
53                TaskConfig::Assertion(t) => assertion.push(t),
54                TaskConfig::LLMJudge(t) => judge.push(*t),
55                TaskConfig::TraceAssertion(t) => trace.push(t),
56            }
57        }
58
59        AssertionTasks {
60            assertion,
61            judge,
62            trace,
63        }
64    }
65}
66
67/// Helper function to extract AssertionTask and LLMJudgeTask from a PyList
68pub fn extract_assertion_tasks_from_pylist(
69    list: &Bound<'_, PyList>,
70) -> Result<AssertionTasks, TypeError> {
71    let mut assertion_tasks = Vec::new();
72    let mut llm_judge_tasks = Vec::new();
73    let mut trace_tasks = Vec::new();
74
75    for item in list.iter() {
76        if item.is_instance_of::<AssertionTask>() {
77            let task = item.extract::<AssertionTask>()?;
78            assertion_tasks.push(task);
79        } else if item.is_instance_of::<LLMJudgeTask>() {
80            let task = item.extract::<LLMJudgeTask>()?;
81            llm_judge_tasks.push(task);
82        } else if item.is_instance_of::<TraceAssertionTask>() {
83            let task = item.extract::<TraceAssertionTask>()?;
84            trace_tasks.push(task);
85        } else {
86            return Err(TypeError::InvalidAssertionTaskType);
87        }
88    }
89    Ok(AssertionTasks {
90        assertion: assertion_tasks,
91        judge: llm_judge_tasks,
92        trace: trace_tasks,
93    })
94}