// scouter_evaluate/genai.rs
use crate::error::EvaluationError;
2use crate::evaluate::types::{EvalResults, EvaluationConfig};
3use crate::utils::{
4 collect_and_align_results, post_process_aligned_results,
5 spawn_evaluation_tasks_with_embeddings, spawn_evaluation_tasks_without_embeddings,
6};
7use pyo3::prelude::*;
8use pyo3::types::{PyList, PySlice};
9use pyo3::IntoPyObjectExt;
10use scouter_state::app_state;
11use scouter_types::genai::{
12 AgentAssertionTask, AssertionTask, GenAIEvalProfile, LLMJudgeTask, TraceAssertionTask,
13};
14use scouter_types::trace::sql::TraceSpan;
15use scouter_types::EvalRecord;
16use scouter_types::PyHelperFuncs;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::Arc;
20use tracing::{debug, instrument};
21
22#[instrument(skip_all)]
27pub async fn evaluate_genai_dataset(
28 dataset: &EvalDataset,
29 config: &Arc<EvaluationConfig>,
30) -> Result<EvalResults, EvaluationError> {
31 debug!(
32 "Starting LLM evaluation for {} records",
33 dataset.records.len()
34 );
35
36 let join_set = match (
37 config.embedder.as_ref(),
38 config.embedding_targets.is_empty(),
39 ) {
40 (Some(embedder), false) => {
41 debug!("Using embedding-enabled evaluation path");
42 spawn_evaluation_tasks_with_embeddings(dataset, embedder.clone(), config).await
43 }
44 _ => {
45 debug!("Using standard evaluation path");
46 spawn_evaluation_tasks_without_embeddings(dataset, config).await
47 }
48 };
49
50 let mut results = collect_and_align_results(join_set, &dataset.records).await?;
51
52 if config.needs_post_processing() {
53 post_process_aligned_results(&mut results, config)?;
54 }
55
56 if config.compute_histograms {
57 results.finalize(config)?;
58 }
59
60 Ok(results)
61}
62
/// Python-facing iterator/sequence view over a shared list of evaluation records.
///
/// Holds an `Arc` to the records so creating this view never copies the
/// backing vector; `index` is the cursor advanced by the Python iterator
/// protocol (`__next__`).
#[pyclass]
pub struct DatasetRecords {
    // Shared, immutable backing storage for the records.
    records: Arc<Vec<EvalRecord>>,
    // Current position of the Python iteration cursor.
    index: usize,
}
68
69#[pymethods]
70impl DatasetRecords {
71 pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
72 slf
73 }
74
75 pub fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<EvalRecord> {
76 if slf.index < slf.records.len() {
77 let record = slf.records[slf.index].clone();
78 slf.index += 1;
79 Some(record)
80 } else {
81 None
82 }
83 }
84
85 fn __getitem__<'py>(
86 &self,
87 py: Python<'py>,
88 index: &Bound<'py, PyAny>,
89 ) -> Result<Bound<'py, PyAny>, EvaluationError> {
90 if let Ok(i) = index.extract::<isize>() {
91 let len = self.records.len() as isize;
92 let actual_index = if i < 0 { len + i } else { i };
93
94 if actual_index < 0 || actual_index >= len {
95 return Err(EvaluationError::IndexOutOfBounds {
96 index: i,
97 length: self.records.len(),
98 });
99 }
100
101 Ok(self.records[actual_index as usize]
102 .clone()
103 .into_bound_py_any(py)?)
104 } else if let Ok(slice) = index.cast::<PySlice>() {
105 let indices = slice.indices(self.records.len() as isize)?;
106 let mut result = Vec::new();
107
108 let mut i = indices.start;
109 while (indices.step > 0 && i < indices.stop) || (indices.step < 0 && i > indices.stop) {
110 result.push(self.records[i as usize].clone());
111 i += indices.step;
112 }
113
114 Ok(result.into_bound_py_any(py)?)
115 } else {
116 Err(EvaluationError::IndexOrSliceExpected)
117 }
118 }
119
120 fn __len__(&self) -> usize {
121 self.records.len()
122 }
123}
124
/// A dataset of GenAI evaluation records paired with the evaluation profile
/// (judge/assertion tasks) used to score them.
///
/// All fields are `Arc`-shared, so cloning the dataset is cheap.
#[pyclass]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalDataset {
    /// Records to be evaluated.
    pub records: Arc<Vec<EvalRecord>>,
    /// Profile describing the evaluation tasks to run.
    pub profile: Arc<GenAIEvalProfile>,
    /// Trace spans associated with the records. Skipped by serde, so it is
    /// empty after deserialization.
    #[serde(skip)]
    pub spans: Arc<Vec<TraceSpan>>,
}
133
134#[pymethods]
135impl EvalDataset {
136 #[new]
137 #[pyo3(signature = (records, tasks))]
138 pub fn new(
139 records: Vec<EvalRecord>,
140 tasks: &Bound<'_, PyList>,
141 ) -> Result<Self, EvaluationError> {
142 let profile = GenAIEvalProfile::new_py(tasks, None, None)?;
143
144 Ok(Self {
145 records: Arc::new(records),
146 profile: Arc::new(profile),
147 spans: Arc::new(vec![]),
148 })
149 }
150
151 #[getter]
152 pub fn records(&self) -> DatasetRecords {
153 DatasetRecords {
154 records: Arc::clone(&self.records),
155 index: 0,
156 }
157 }
158
159 fn __iter__(slf: PyRef<'_, Self>) -> DatasetRecords {
160 DatasetRecords {
161 records: Arc::clone(&slf.records),
162 index: 0,
163 }
164 }
165
166 fn __len__(&self) -> usize {
167 self.records.len()
168 }
169
170 #[getter]
171 pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
172 self.profile.llm_judge_tasks()
173 }
174
175 #[getter]
176 pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
177 self.profile.assertion_tasks()
178 }
179
180 #[getter]
181 pub fn trace_assertion_tasks(&self) -> Vec<TraceAssertionTask> {
182 self.profile.trace_assertion_tasks()
183 }
184
185 #[getter]
186 pub fn agent_assertion_tasks(&self) -> Vec<AgentAssertionTask> {
187 self.profile.agent_assertion_tasks()
188 }
189
190 pub fn print_execution_plan(&self) -> Result<(), EvaluationError> {
191 self.profile.print_execution_plan()?;
192 Ok(())
193 }
194
195 #[pyo3(signature = (config=None))]
196 fn evaluate(&self, config: Option<EvaluationConfig>) -> Result<EvalResults, EvaluationError> {
197 let config = Arc::new(config.unwrap_or_default());
198 app_state()
199 .handle()
200 .block_on(async { evaluate_genai_dataset(self, &config).await })
201 }
202
203 pub fn __str__(&self) -> String {
204 PyHelperFuncs::__str__(self)
206 }
207
208 #[pyo3(signature = (context_map))]
230 pub fn with_updated_contexts_by_id(
231 &self,
232 py: Python<'_>,
233 context_map: HashMap<String, Bound<'_, PyAny>>,
234 ) -> Result<Self, EvaluationError> {
235 let updated_records: Vec<EvalRecord> = self
236 .records
237 .iter()
238 .map(|record| {
239 if let Some(new_context) = context_map.get(&record.record_id) {
240 let mut updated_record = record.clone();
241 updated_record.update_context(py, new_context)?;
242 Ok(updated_record)
243 } else {
244 Ok(record.clone())
245 }
246 })
247 .collect::<Result<Vec<EvalRecord>, EvaluationError>>()?;
248
249 Ok(Self {
250 records: Arc::new(updated_records),
251 profile: Arc::clone(&self.profile),
252 spans: Arc::clone(&self.spans),
253 })
254 }
255}