Skip to main content

scouter_evaluate/
genai.rs

1use crate::error::EvaluationError;
2use crate::evaluate::types::{EvalResults, EvaluationConfig};
3use crate::utils::{
4    collect_and_align_results, post_process_aligned_results,
5    spawn_evaluation_tasks_with_embeddings, spawn_evaluation_tasks_without_embeddings,
6};
7use pyo3::prelude::*;
8use pyo3::types::{PyList, PySlice};
9use pyo3::IntoPyObjectExt;
10use scouter_state::app_state;
11use scouter_types::genai::{
12    AgentAssertionTask, AssertionTask, GenAIEvalProfile, LLMJudgeTask, TraceAssertionTask,
13};
14use scouter_types::trace::sql::TraceSpan;
15use scouter_types::EvalRecord;
16use scouter_types::PyHelperFuncs;
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19use std::sync::Arc;
20use tracing::{debug, instrument};
21
22/// Main orchestration function that decides which execution path to take
23/// # Arguments
24/// * `dataset`: The dataset containing records to evaluate.
25/// * `embedding_targets`: Optional list of fields to embed.
26#[instrument(skip_all)]
27pub async fn evaluate_genai_dataset(
28    dataset: &EvalDataset,
29    config: &Arc<EvaluationConfig>,
30) -> Result<EvalResults, EvaluationError> {
31    debug!(
32        "Starting LLM evaluation for {} records",
33        dataset.records.len()
34    );
35
36    let join_set = match (
37        config.embedder.as_ref(),
38        config.embedding_targets.is_empty(),
39    ) {
40        (Some(embedder), false) => {
41            debug!("Using embedding-enabled evaluation path");
42            spawn_evaluation_tasks_with_embeddings(dataset, embedder.clone(), config).await
43        }
44        _ => {
45            debug!("Using standard evaluation path");
46            spawn_evaluation_tasks_without_embeddings(dataset, config).await
47        }
48    };
49
50    let mut results = collect_and_align_results(join_set, &dataset.records).await?;
51
52    if config.needs_post_processing() {
53        post_process_aligned_results(&mut results, config)?;
54    }
55
56    if config.compute_histograms {
57        results.finalize(config)?;
58    }
59
60    Ok(results)
61}
62
63#[pyclass]
64pub struct DatasetRecords {
65    records: Arc<Vec<EvalRecord>>,
66    index: usize,
67}
68
69#[pymethods]
70impl DatasetRecords {
71    pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
72        slf
73    }
74
75    pub fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<EvalRecord> {
76        if slf.index < slf.records.len() {
77            let record = slf.records[slf.index].clone();
78            slf.index += 1;
79            Some(record)
80        } else {
81            None
82        }
83    }
84
85    fn __getitem__<'py>(
86        &self,
87        py: Python<'py>,
88        index: &Bound<'py, PyAny>,
89    ) -> Result<Bound<'py, PyAny>, EvaluationError> {
90        if let Ok(i) = index.extract::<isize>() {
91            let len = self.records.len() as isize;
92            let actual_index = if i < 0 { len + i } else { i };
93
94            if actual_index < 0 || actual_index >= len {
95                return Err(EvaluationError::IndexOutOfBounds {
96                    index: i,
97                    length: self.records.len(),
98                });
99            }
100
101            Ok(self.records[actual_index as usize]
102                .clone()
103                .into_bound_py_any(py)?)
104        } else if let Ok(slice) = index.cast::<PySlice>() {
105            let indices = slice.indices(self.records.len() as isize)?;
106            let mut result = Vec::new();
107
108            let mut i = indices.start;
109            while (indices.step > 0 && i < indices.stop) || (indices.step < 0 && i > indices.stop) {
110                result.push(self.records[i as usize].clone());
111                i += indices.step;
112            }
113
114            Ok(result.into_bound_py_any(py)?)
115        } else {
116            Err(EvaluationError::IndexOrSliceExpected)
117        }
118    }
119
120    fn __len__(&self) -> usize {
121        self.records.len()
122    }
123}
124
125#[pyclass]
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct EvalDataset {
128    pub records: Arc<Vec<EvalRecord>>,
129    pub profile: Arc<GenAIEvalProfile>,
130    #[serde(skip)]
131    pub spans: Arc<Vec<TraceSpan>>,
132}
133
134#[pymethods]
135impl EvalDataset {
136    #[new]
137    #[pyo3(signature = (records, tasks))]
138    pub fn new(
139        records: Vec<EvalRecord>,
140        tasks: &Bound<'_, PyList>,
141    ) -> Result<Self, EvaluationError> {
142        let profile = GenAIEvalProfile::new_py(tasks, None, None)?;
143
144        Ok(Self {
145            records: Arc::new(records),
146            profile: Arc::new(profile),
147            spans: Arc::new(vec![]),
148        })
149    }
150
151    #[getter]
152    pub fn records(&self) -> DatasetRecords {
153        DatasetRecords {
154            records: Arc::clone(&self.records),
155            index: 0,
156        }
157    }
158
159    fn __iter__(slf: PyRef<'_, Self>) -> DatasetRecords {
160        DatasetRecords {
161            records: Arc::clone(&slf.records),
162            index: 0,
163        }
164    }
165
166    fn __len__(&self) -> usize {
167        self.records.len()
168    }
169
170    #[getter]
171    pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
172        self.profile.llm_judge_tasks()
173    }
174
175    #[getter]
176    pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
177        self.profile.assertion_tasks()
178    }
179
180    #[getter]
181    pub fn trace_assertion_tasks(&self) -> Vec<TraceAssertionTask> {
182        self.profile.trace_assertion_tasks()
183    }
184
185    #[getter]
186    pub fn agent_assertion_tasks(&self) -> Vec<AgentAssertionTask> {
187        self.profile.agent_assertion_tasks()
188    }
189
190    pub fn print_execution_plan(&self) -> Result<(), EvaluationError> {
191        self.profile.print_execution_plan()?;
192        Ok(())
193    }
194
195    #[pyo3(signature = (config=None))]
196    fn evaluate(&self, config: Option<EvaluationConfig>) -> Result<EvalResults, EvaluationError> {
197        let config = Arc::new(config.unwrap_or_default());
198        app_state()
199            .handle()
200            .block_on(async { evaluate_genai_dataset(self, &config).await })
201    }
202
203    pub fn __str__(&self) -> String {
204        // serialize the struct to a string
205        PyHelperFuncs::__str__(self)
206    }
207
208    /// Update contexts by record ID mapping. This is the safest approach for
209    /// ensuring context updates align with the correct records.
210    ///
211    /// # Arguments
212    /// * `context_map` - Dictionary mapping record_id to new context object
213    ///
214    /// # Returns
215    /// A new `EvalDataset` with updated contexts for matched IDs
216    ///
217    /// # Example
218    /// ```python
219    /// baseline_dataset = EvalDataset(records=[...], tasks=[...])
220    ///
221    /// # Update specific records by ID
222    /// new_contexts = {
223    ///     "product_classification_0": updated_context_0,
224    ///     "product_classification_1": updated_context_1,
225    /// }
226    ///
227    /// comparison_dataset = baseline_dataset.with_updated_contexts_by_id(new_contexts)
228    /// ```
229    #[pyo3(signature = (context_map))]
230    pub fn with_updated_contexts_by_id(
231        &self,
232        py: Python<'_>,
233        context_map: HashMap<String, Bound<'_, PyAny>>,
234    ) -> Result<Self, EvaluationError> {
235        let updated_records: Vec<EvalRecord> = self
236            .records
237            .iter()
238            .map(|record| {
239                if let Some(new_context) = context_map.get(&record.record_id) {
240                    let mut updated_record = record.clone();
241                    updated_record.update_context(py, new_context)?;
242                    Ok(updated_record)
243                } else {
244                    Ok(record.clone())
245                }
246            })
247            .collect::<Result<Vec<EvalRecord>, EvaluationError>>()?;
248
249        Ok(Self {
250            records: Arc::new(updated_records),
251            profile: Arc::clone(&self.profile),
252            spans: Arc::clone(&self.spans),
253        })
254    }
255}