Skip to main content

scouter_evaluate/
genai.rs

1use crate::error::EvaluationError;
2use crate::evaluate::types::{EvaluationConfig, GenAIEvalResults};
3use crate::utils::{
4    collect_and_align_results, post_process_aligned_results,
5    spawn_evaluation_tasks_with_embeddings, spawn_evaluation_tasks_without_embeddings,
6};
7use pyo3::prelude::*;
8use pyo3::types::{PyList, PySlice};
9use pyo3::IntoPyObjectExt;
10use scouter_state::app_state;
11use scouter_types::genai::{AssertionTask, GenAIEvalProfile, LLMJudgeTask, TraceAssertionTask};
12use scouter_types::GenAIEvalRecord;
13use scouter_types::PyHelperFuncs;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::{debug, instrument};
18
19/// Main orchestration function that decides which execution path to take
20/// # Arguments
21/// * `dataset`: The dataset containing records to evaluate.
22/// * `embedding_targets`: Optional list of fields to embed.
23#[instrument(skip_all)]
24pub async fn evaluate_genai_dataset(
25    dataset: &GenAIEvalDataset,
26    config: &Arc<EvaluationConfig>,
27) -> Result<GenAIEvalResults, EvaluationError> {
28    debug!(
29        "Starting LLM evaluation for {} records",
30        dataset.records.len()
31    );
32
33    let join_set = match (
34        config.embedder.as_ref(),
35        config.embedding_targets.is_empty(),
36    ) {
37        (Some(embedder), false) => {
38            debug!("Using embedding-enabled evaluation path");
39            spawn_evaluation_tasks_with_embeddings(dataset, embedder.clone(), config).await
40        }
41        _ => {
42            debug!("Using standard evaluation path");
43            spawn_evaluation_tasks_without_embeddings(dataset, config).await
44        }
45    };
46
47    let mut results = collect_and_align_results(join_set, &dataset.records).await?;
48
49    if config.needs_post_processing() {
50        post_process_aligned_results(&mut results, config)?;
51    }
52
53    if config.compute_histograms {
54        results.finalize(config)?;
55    }
56
57    Ok(results)
58}
59
60#[pyclass]
61pub struct DatasetRecords {
62    records: Arc<Vec<GenAIEvalRecord>>,
63    index: usize,
64}
65
66#[pymethods]
67impl DatasetRecords {
68    pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
69        slf
70    }
71
72    pub fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<GenAIEvalRecord> {
73        if slf.index < slf.records.len() {
74            let record = slf.records[slf.index].clone();
75            slf.index += 1;
76            Some(record)
77        } else {
78            None
79        }
80    }
81
82    fn __getitem__<'py>(
83        &self,
84        py: Python<'py>,
85        index: &Bound<'py, PyAny>,
86    ) -> Result<Bound<'py, PyAny>, EvaluationError> {
87        if let Ok(i) = index.extract::<isize>() {
88            let len = self.records.len() as isize;
89            let actual_index = if i < 0 { len + i } else { i };
90
91            if actual_index < 0 || actual_index >= len {
92                return Err(EvaluationError::IndexOutOfBounds {
93                    index: i,
94                    length: self.records.len(),
95                });
96            }
97
98            Ok(self.records[actual_index as usize]
99                .clone()
100                .into_bound_py_any(py)?)
101        } else if let Ok(slice) = index.cast::<PySlice>() {
102            let indices = slice.indices(self.records.len() as isize)?;
103            let mut result = Vec::new();
104
105            let mut i = indices.start;
106            while (indices.step > 0 && i < indices.stop) || (indices.step < 0 && i > indices.stop) {
107                result.push(self.records[i as usize].clone());
108                i += indices.step;
109            }
110
111            Ok(result.into_bound_py_any(py)?)
112        } else {
113            Err(EvaluationError::IndexOrSliceExpected)
114        }
115    }
116
117    fn __len__(&self) -> usize {
118        self.records.len()
119    }
120}
121
122#[pyclass]
123#[derive(Debug, Serialize, Deserialize)]
124pub struct GenAIEvalDataset {
125    pub records: Arc<Vec<GenAIEvalRecord>>,
126    pub profile: Arc<GenAIEvalProfile>,
127}
128
129#[pymethods]
130impl GenAIEvalDataset {
131    #[new]
132    #[pyo3(signature = (records, tasks))]
133    pub fn new(
134        records: Vec<GenAIEvalRecord>,
135        tasks: &Bound<'_, PyList>,
136    ) -> Result<Self, EvaluationError> {
137        let profile = GenAIEvalProfile::new_py(tasks, None, None)?;
138
139        Ok(Self {
140            records: Arc::new(records),
141            profile: Arc::new(profile),
142        })
143    }
144
145    #[getter]
146    pub fn records(&self) -> DatasetRecords {
147        DatasetRecords {
148            records: Arc::clone(&self.records),
149            index: 0,
150        }
151    }
152
153    fn __iter__(slf: PyRef<'_, Self>) -> DatasetRecords {
154        DatasetRecords {
155            records: Arc::clone(&slf.records),
156            index: 0,
157        }
158    }
159
160    fn __len__(&self) -> usize {
161        self.records.len()
162    }
163
164    #[getter]
165    pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
166        self.profile.llm_judge_tasks()
167    }
168
169    #[getter]
170    pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
171        self.profile.assertion_tasks()
172    }
173
174    #[getter]
175    pub fn trace_assertion_tasks(&self) -> Vec<TraceAssertionTask> {
176        self.profile.trace_assertion_tasks()
177    }
178
179    pub fn print_execution_plan(&self) -> Result<(), EvaluationError> {
180        self.profile.print_execution_plan()?;
181        Ok(())
182    }
183
184    #[pyo3(signature = (config=None))]
185    fn evaluate(
186        &self,
187        config: Option<EvaluationConfig>,
188    ) -> Result<GenAIEvalResults, EvaluationError> {
189        let config = Arc::new(config.unwrap_or_default());
190        app_state()
191            .handle()
192            .block_on(async { evaluate_genai_dataset(self, &config).await })
193    }
194
195    pub fn __str__(&self) -> String {
196        // serialize the struct to a string
197        PyHelperFuncs::__str__(self)
198    }
199
200    /// Update contexts by record ID mapping. This is the safest approach for
201    /// ensuring context updates align with the correct records.
202    ///
203    /// # Arguments
204    /// * `context_map` - Dictionary mapping record_id to new context object
205    ///
206    /// # Returns
207    /// A new `GenAIEvalDataset` with updated contexts for matched IDs
208    ///
209    /// # Example
210    /// ```python
211    /// baseline_dataset = GenAIEvalDataset(records=[...], tasks=[...])
212    ///
213    /// # Update specific records by ID
214    /// new_contexts = {
215    ///     "product_classification_0": updated_context_0,
216    ///     "product_classification_1": updated_context_1,
217    /// }
218    ///
219    /// comparison_dataset = baseline_dataset.with_updated_contexts_by_id(new_contexts)
220    /// ```
221    #[pyo3(signature = (context_map))]
222    pub fn with_updated_contexts_by_id(
223        &self,
224        py: Python<'_>,
225        context_map: HashMap<String, Bound<'_, PyAny>>,
226    ) -> Result<Self, EvaluationError> {
227        let updated_records: Vec<GenAIEvalRecord> = self
228            .records
229            .iter()
230            .map(|record| {
231                if let Some(new_context) = context_map.get(&record.record_id) {
232                    let mut updated_record = record.clone();
233                    updated_record.update_context(py, new_context)?;
234                    Ok(updated_record)
235                } else {
236                    Ok(record.clone())
237                }
238            })
239            .collect::<Result<Vec<GenAIEvalRecord>, EvaluationError>>()?;
240
241        Ok(Self {
242            records: Arc::new(updated_records),
243            profile: Arc::clone(&self.profile),
244        })
245    }
246}