scouter_evaluate/
genai.rs1use crate::error::EvaluationError;
2use crate::evaluate::types::{EvaluationConfig, GenAIEvalResults};
3use crate::utils::{
4 collect_and_align_results, post_process_aligned_results,
5 spawn_evaluation_tasks_with_embeddings, spawn_evaluation_tasks_without_embeddings,
6};
7use pyo3::prelude::*;
8use pyo3::types::{PyList, PySlice};
9use pyo3::IntoPyObjectExt;
10use scouter_state::app_state;
11use scouter_types::genai::{AssertionTask, GenAIEvalConfig, GenAIEvalProfile, LLMJudgeTask};
12use scouter_types::GenAIEvalRecord;
13use scouter_types::PyHelperFuncs;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::{debug, instrument};
18
19#[instrument(skip_all)]
24pub async fn evaluate_genai_dataset(
25 dataset: &GenAIEvalDataset,
26 config: &Arc<EvaluationConfig>,
27) -> Result<GenAIEvalResults, EvaluationError> {
28 debug!(
29 "Starting LLM evaluation for {} records",
30 dataset.records.len()
31 );
32
33 let join_set = match (
34 config.embedder.as_ref(),
35 config.embedding_targets.is_empty(),
36 ) {
37 (Some(embedder), false) => {
38 debug!("Using embedding-enabled evaluation path");
39 spawn_evaluation_tasks_with_embeddings(dataset, embedder.clone(), config).await
40 }
41 _ => {
42 debug!("Using standard evaluation path");
43 spawn_evaluation_tasks_without_embeddings(dataset, config).await
44 }
45 };
46
47 let mut results = collect_and_align_results(join_set, &dataset.records).await?;
48
49 if config.needs_post_processing() {
50 post_process_aligned_results(&mut results, config)?;
51 }
52
53 if config.compute_histograms {
54 results.finalize(config)?;
55 }
56
57 Ok(results)
58}
59
60#[pyclass]
61pub struct DatasetRecords {
62 records: Arc<Vec<GenAIEvalRecord>>,
63 index: usize,
64}
65
66#[pymethods]
67impl DatasetRecords {
68 pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
69 slf
70 }
71
72 pub fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<GenAIEvalRecord> {
73 if slf.index < slf.records.len() {
74 let record = slf.records[slf.index].clone();
75 slf.index += 1;
76 Some(record)
77 } else {
78 None
79 }
80 }
81
82 fn __getitem__<'py>(
83 &self,
84 py: Python<'py>,
85 index: &Bound<'py, PyAny>,
86 ) -> Result<Bound<'py, PyAny>, EvaluationError> {
87 if let Ok(i) = index.extract::<isize>() {
88 let len = self.records.len() as isize;
89 let actual_index = if i < 0 { len + i } else { i };
90
91 if actual_index < 0 || actual_index >= len {
92 return Err(EvaluationError::IndexOutOfBounds {
93 index: i,
94 length: self.records.len(),
95 });
96 }
97
98 Ok(self.records[actual_index as usize]
99 .clone()
100 .into_bound_py_any(py)?)
101 } else if let Ok(slice) = index.cast::<PySlice>() {
102 let indices = slice.indices(self.records.len() as isize)?;
103 let mut result = Vec::new();
104
105 let mut i = indices.start;
106 while (indices.step > 0 && i < indices.stop) || (indices.step < 0 && i > indices.stop) {
107 result.push(self.records[i as usize].clone());
108 i += indices.step;
109 }
110
111 Ok(result.into_bound_py_any(py)?)
112 } else {
113 Err(EvaluationError::IndexOrSliceExpected)
114 }
115 }
116
117 fn __len__(&self) -> usize {
118 self.records.len()
119 }
120}
121
122#[pyclass]
123#[derive(Debug, Serialize, Deserialize)]
124pub struct GenAIEvalDataset {
125 pub records: Arc<Vec<GenAIEvalRecord>>,
126 pub profile: Arc<GenAIEvalProfile>,
127}
128
129#[pymethods]
130impl GenAIEvalDataset {
131 #[new]
132 #[pyo3(signature = (records, tasks))]
133 pub fn new(
134 records: Vec<GenAIEvalRecord>,
135 tasks: &Bound<'_, PyList>,
136 ) -> Result<Self, EvaluationError> {
137 let profile = GenAIEvalProfile::new_py(GenAIEvalConfig::default(), tasks)?;
138
139 Ok(Self {
140 records: Arc::new(records),
141 profile: Arc::new(profile),
142 })
143 }
144
145 #[getter]
146 pub fn records(&self) -> DatasetRecords {
147 DatasetRecords {
148 records: Arc::clone(&self.records),
149 index: 0,
150 }
151 }
152
153 fn __iter__(slf: PyRef<'_, Self>) -> DatasetRecords {
154 DatasetRecords {
155 records: Arc::clone(&slf.records),
156 index: 0,
157 }
158 }
159
160 fn __len__(&self) -> usize {
161 self.records.len()
162 }
163
164 #[getter]
165 pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
166 self.profile.llm_judge_tasks.clone()
167 }
168
169 #[getter]
170 pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
171 self.profile.assertion_tasks.clone()
172 }
173
174 pub fn print_execution_plan(&self) -> Result<(), EvaluationError> {
175 self.profile.print_execution_plan()?;
176 Ok(())
177 }
178
179 #[pyo3(signature = (config=None))]
180 fn evaluate(
181 &self,
182 config: Option<EvaluationConfig>,
183 ) -> Result<GenAIEvalResults, EvaluationError> {
184 let config = Arc::new(config.unwrap_or_default());
185 app_state()
186 .handle()
187 .block_on(async { evaluate_genai_dataset(self, &config).await })
188 }
189
190 pub fn __str__(&self) -> String {
191 PyHelperFuncs::__str__(self)
193 }
194
195 #[pyo3(signature = (context_map))]
217 pub fn with_updated_contexts_by_id(
218 &self,
219 py: Python<'_>,
220 context_map: HashMap<String, Bound<'_, PyAny>>,
221 ) -> Result<Self, EvaluationError> {
222 let updated_records: Vec<GenAIEvalRecord> = self
223 .records
224 .iter()
225 .map(|record| {
226 if let Some(new_context) = context_map.get(&record.record_id) {
227 let mut updated_record = record.clone();
228 updated_record.update_context(py, new_context)?;
229 Ok(updated_record)
230 } else {
231 Ok(record.clone())
232 }
233 })
234 .collect::<Result<Vec<GenAIEvalRecord>, EvaluationError>>()?;
235
236 Ok(Self {
237 records: Arc::new(updated_records),
238 profile: Arc::clone(&self.profile),
239 })
240 }
241}