scouter_evaluate/
genai.rs1use crate::error::EvaluationError;
2use crate::evaluate::types::{EvaluationConfig, GenAIEvalResults};
3use crate::utils::{
4 collect_and_align_results, post_process_aligned_results,
5 spawn_evaluation_tasks_with_embeddings, spawn_evaluation_tasks_without_embeddings,
6};
7use pyo3::prelude::*;
8use pyo3::types::{PyList, PySlice};
9use pyo3::IntoPyObjectExt;
10use scouter_state::app_state;
11use scouter_types::genai::{AssertionTask, GenAIEvalProfile, LLMJudgeTask, TraceAssertionTask};
12use scouter_types::GenAIEvalRecord;
13use scouter_types::PyHelperFuncs;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::{debug, instrument};
18
19#[instrument(skip_all)]
24pub async fn evaluate_genai_dataset(
25 dataset: &GenAIEvalDataset,
26 config: &Arc<EvaluationConfig>,
27) -> Result<GenAIEvalResults, EvaluationError> {
28 debug!(
29 "Starting LLM evaluation for {} records",
30 dataset.records.len()
31 );
32
33 let join_set = match (
34 config.embedder.as_ref(),
35 config.embedding_targets.is_empty(),
36 ) {
37 (Some(embedder), false) => {
38 debug!("Using embedding-enabled evaluation path");
39 spawn_evaluation_tasks_with_embeddings(dataset, embedder.clone(), config).await
40 }
41 _ => {
42 debug!("Using standard evaluation path");
43 spawn_evaluation_tasks_without_embeddings(dataset, config).await
44 }
45 };
46
47 let mut results = collect_and_align_results(join_set, &dataset.records).await?;
48
49 if config.needs_post_processing() {
50 post_process_aligned_results(&mut results, config)?;
51 }
52
53 if config.compute_histograms {
54 results.finalize(config)?;
55 }
56
57 Ok(results)
58}
59
#[pyclass]
pub struct DatasetRecords {
    // Shared, immutable view of the dataset's records; Arc keeps iterator
    // construction O(1) (refcount bump) instead of cloning the Vec.
    records: Arc<Vec<GenAIEvalRecord>>,
    // Cursor for the Python iterator protocol (advanced by __next__).
    index: usize,
}
65
66#[pymethods]
67impl DatasetRecords {
68 pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
69 slf
70 }
71
72 pub fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<GenAIEvalRecord> {
73 if slf.index < slf.records.len() {
74 let record = slf.records[slf.index].clone();
75 slf.index += 1;
76 Some(record)
77 } else {
78 None
79 }
80 }
81
82 fn __getitem__<'py>(
83 &self,
84 py: Python<'py>,
85 index: &Bound<'py, PyAny>,
86 ) -> Result<Bound<'py, PyAny>, EvaluationError> {
87 if let Ok(i) = index.extract::<isize>() {
88 let len = self.records.len() as isize;
89 let actual_index = if i < 0 { len + i } else { i };
90
91 if actual_index < 0 || actual_index >= len {
92 return Err(EvaluationError::IndexOutOfBounds {
93 index: i,
94 length: self.records.len(),
95 });
96 }
97
98 Ok(self.records[actual_index as usize]
99 .clone()
100 .into_bound_py_any(py)?)
101 } else if let Ok(slice) = index.cast::<PySlice>() {
102 let indices = slice.indices(self.records.len() as isize)?;
103 let mut result = Vec::new();
104
105 let mut i = indices.start;
106 while (indices.step > 0 && i < indices.stop) || (indices.step < 0 && i > indices.stop) {
107 result.push(self.records[i as usize].clone());
108 i += indices.step;
109 }
110
111 Ok(result.into_bound_py_any(py)?)
112 } else {
113 Err(EvaluationError::IndexOrSliceExpected)
114 }
115 }
116
117 fn __len__(&self) -> usize {
118 self.records.len()
119 }
120}
121
#[pyclass]
#[derive(Debug, Serialize, Deserialize)]
pub struct GenAIEvalDataset {
    // Records under evaluation; Arc-shared so iterators/getters are cheap.
    pub records: Arc<Vec<GenAIEvalRecord>>,
    // Evaluation profile (judge/assertion tasks) built from the Python task list.
    pub profile: Arc<GenAIEvalProfile>,
}
128
129#[pymethods]
130impl GenAIEvalDataset {
131 #[new]
132 #[pyo3(signature = (records, tasks))]
133 pub fn new(
134 records: Vec<GenAIEvalRecord>,
135 tasks: &Bound<'_, PyList>,
136 ) -> Result<Self, EvaluationError> {
137 let profile = GenAIEvalProfile::new_py(tasks, None, None)?;
138
139 Ok(Self {
140 records: Arc::new(records),
141 profile: Arc::new(profile),
142 })
143 }
144
145 #[getter]
146 pub fn records(&self) -> DatasetRecords {
147 DatasetRecords {
148 records: Arc::clone(&self.records),
149 index: 0,
150 }
151 }
152
153 fn __iter__(slf: PyRef<'_, Self>) -> DatasetRecords {
154 DatasetRecords {
155 records: Arc::clone(&slf.records),
156 index: 0,
157 }
158 }
159
160 fn __len__(&self) -> usize {
161 self.records.len()
162 }
163
164 #[getter]
165 pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
166 self.profile.llm_judge_tasks()
167 }
168
169 #[getter]
170 pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
171 self.profile.assertion_tasks()
172 }
173
174 #[getter]
175 pub fn trace_assertion_tasks(&self) -> Vec<TraceAssertionTask> {
176 self.profile.trace_assertion_tasks()
177 }
178
179 pub fn print_execution_plan(&self) -> Result<(), EvaluationError> {
180 self.profile.print_execution_plan()?;
181 Ok(())
182 }
183
184 #[pyo3(signature = (config=None))]
185 fn evaluate(
186 &self,
187 config: Option<EvaluationConfig>,
188 ) -> Result<GenAIEvalResults, EvaluationError> {
189 let config = Arc::new(config.unwrap_or_default());
190 app_state()
191 .handle()
192 .block_on(async { evaluate_genai_dataset(self, &config).await })
193 }
194
195 pub fn __str__(&self) -> String {
196 PyHelperFuncs::__str__(self)
198 }
199
200 #[pyo3(signature = (context_map))]
222 pub fn with_updated_contexts_by_id(
223 &self,
224 py: Python<'_>,
225 context_map: HashMap<String, Bound<'_, PyAny>>,
226 ) -> Result<Self, EvaluationError> {
227 let updated_records: Vec<GenAIEvalRecord> = self
228 .records
229 .iter()
230 .map(|record| {
231 if let Some(new_context) = context_map.get(&record.record_id) {
232 let mut updated_record = record.clone();
233 updated_record.update_context(py, new_context)?;
234 Ok(updated_record)
235 } else {
236 Ok(record.clone())
237 }
238 })
239 .collect::<Result<Vec<GenAIEvalRecord>, EvaluationError>>()?;
240
241 Ok(Self {
242 records: Arc::new(updated_records),
243 profile: Arc::clone(&self.profile),
244 })
245 }
246}