scouter_evaluate/
genai.rs

1use crate::error::EvaluationError;
2use crate::evaluate::types::{EvaluationConfig, GenAIEvalResults};
3use crate::utils::{
4    collect_and_align_results, post_process_aligned_results,
5    spawn_evaluation_tasks_with_embeddings, spawn_evaluation_tasks_without_embeddings,
6};
7use pyo3::prelude::*;
8use pyo3::types::{PyList, PySlice};
9use pyo3::IntoPyObjectExt;
10use scouter_state::app_state;
11use scouter_types::genai::{AssertionTask, GenAIEvalConfig, GenAIEvalProfile, LLMJudgeTask};
12use scouter_types::GenAIEvalRecord;
13use scouter_types::PyHelperFuncs;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::{debug, instrument};
18
19/// Main orchestration function that decides which execution path to take
20/// # Arguments
21/// * `dataset`: The dataset containing records to evaluate.
22/// * `embedding_targets`: Optional list of fields to embed.
23#[instrument(skip_all)]
24pub async fn evaluate_genai_dataset(
25    dataset: &GenAIEvalDataset,
26    config: &Arc<EvaluationConfig>,
27) -> Result<GenAIEvalResults, EvaluationError> {
28    debug!(
29        "Starting LLM evaluation for {} records",
30        dataset.records.len()
31    );
32
33    let join_set = match (
34        config.embedder.as_ref(),
35        config.embedding_targets.is_empty(),
36    ) {
37        (Some(embedder), false) => {
38            debug!("Using embedding-enabled evaluation path");
39            spawn_evaluation_tasks_with_embeddings(dataset, embedder.clone(), config).await
40        }
41        _ => {
42            debug!("Using standard evaluation path");
43            spawn_evaluation_tasks_without_embeddings(dataset, config).await
44        }
45    };
46
47    let mut results = collect_and_align_results(join_set, &dataset.records).await?;
48
49    if config.needs_post_processing() {
50        post_process_aligned_results(&mut results, config)?;
51    }
52
53    if config.compute_histograms {
54        results.finalize(config)?;
55    }
56
57    Ok(results)
58}
59
60#[pyclass]
61pub struct DatasetRecords {
62    records: Arc<Vec<GenAIEvalRecord>>,
63    index: usize,
64}
65
66#[pymethods]
67impl DatasetRecords {
68    pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
69        slf
70    }
71
72    pub fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<GenAIEvalRecord> {
73        if slf.index < slf.records.len() {
74            let record = slf.records[slf.index].clone();
75            slf.index += 1;
76            Some(record)
77        } else {
78            None
79        }
80    }
81
82    fn __getitem__<'py>(
83        &self,
84        py: Python<'py>,
85        index: &Bound<'py, PyAny>,
86    ) -> Result<Bound<'py, PyAny>, EvaluationError> {
87        if let Ok(i) = index.extract::<isize>() {
88            let len = self.records.len() as isize;
89            let actual_index = if i < 0 { len + i } else { i };
90
91            if actual_index < 0 || actual_index >= len {
92                return Err(EvaluationError::IndexOutOfBounds {
93                    index: i,
94                    length: self.records.len(),
95                });
96            }
97
98            Ok(self.records[actual_index as usize]
99                .clone()
100                .into_bound_py_any(py)?)
101        } else if let Ok(slice) = index.cast::<PySlice>() {
102            let indices = slice.indices(self.records.len() as isize)?;
103            let mut result = Vec::new();
104
105            let mut i = indices.start;
106            while (indices.step > 0 && i < indices.stop) || (indices.step < 0 && i > indices.stop) {
107                result.push(self.records[i as usize].clone());
108                i += indices.step;
109            }
110
111            Ok(result.into_bound_py_any(py)?)
112        } else {
113            Err(EvaluationError::IndexOrSliceExpected)
114        }
115    }
116
117    fn __len__(&self) -> usize {
118        self.records.len()
119    }
120}
121
122#[pyclass]
123#[derive(Debug, Serialize, Deserialize)]
124pub struct GenAIEvalDataset {
125    pub records: Arc<Vec<GenAIEvalRecord>>,
126    pub profile: Arc<GenAIEvalProfile>,
127}
128
129#[pymethods]
130impl GenAIEvalDataset {
131    #[new]
132    #[pyo3(signature = (records, tasks))]
133    pub fn new(
134        records: Vec<GenAIEvalRecord>,
135        tasks: &Bound<'_, PyList>,
136    ) -> Result<Self, EvaluationError> {
137        let profile = GenAIEvalProfile::new_py(GenAIEvalConfig::default(), tasks)?;
138
139        Ok(Self {
140            records: Arc::new(records),
141            profile: Arc::new(profile),
142        })
143    }
144
145    #[getter]
146    pub fn records(&self) -> DatasetRecords {
147        DatasetRecords {
148            records: Arc::clone(&self.records),
149            index: 0,
150        }
151    }
152
153    fn __iter__(slf: PyRef<'_, Self>) -> DatasetRecords {
154        DatasetRecords {
155            records: Arc::clone(&slf.records),
156            index: 0,
157        }
158    }
159
160    fn __len__(&self) -> usize {
161        self.records.len()
162    }
163
164    #[getter]
165    pub fn llm_judge_tasks(&self) -> Vec<LLMJudgeTask> {
166        self.profile.llm_judge_tasks.clone()
167    }
168
169    #[getter]
170    pub fn assertion_tasks(&self) -> Vec<AssertionTask> {
171        self.profile.assertion_tasks.clone()
172    }
173
174    pub fn print_execution_plan(&self) -> Result<(), EvaluationError> {
175        self.profile.print_execution_plan()?;
176        Ok(())
177    }
178
179    #[pyo3(signature = (config=None))]
180    fn evaluate(
181        &self,
182        config: Option<EvaluationConfig>,
183    ) -> Result<GenAIEvalResults, EvaluationError> {
184        let config = Arc::new(config.unwrap_or_default());
185        app_state()
186            .handle()
187            .block_on(async { evaluate_genai_dataset(self, &config).await })
188    }
189
190    pub fn __str__(&self) -> String {
191        // serialize the struct to a string
192        PyHelperFuncs::__str__(self)
193    }
194
195    /// Update contexts by record ID mapping. This is the safest approach for
196    /// ensuring context updates align with the correct records.
197    ///
198    /// # Arguments
199    /// * `context_map` - Dictionary mapping record_id to new context object
200    ///
201    /// # Returns
202    /// A new `GenAIEvalDataset` with updated contexts for matched IDs
203    ///
204    /// # Example
205    /// ```python
206    /// baseline_dataset = GenAIEvalDataset(records=[...], tasks=[...])
207    ///
208    /// # Update specific records by ID
209    /// new_contexts = {
210    ///     "product_classification_0": updated_context_0,
211    ///     "product_classification_1": updated_context_1,
212    /// }
213    ///
214    /// comparison_dataset = baseline_dataset.with_updated_contexts_by_id(new_contexts)
215    /// ```
216    #[pyo3(signature = (context_map))]
217    pub fn with_updated_contexts_by_id(
218        &self,
219        py: Python<'_>,
220        context_map: HashMap<String, Bound<'_, PyAny>>,
221    ) -> Result<Self, EvaluationError> {
222        let updated_records: Vec<GenAIEvalRecord> = self
223            .records
224            .iter()
225            .map(|record| {
226                if let Some(new_context) = context_map.get(&record.record_id) {
227                    let mut updated_record = record.clone();
228                    updated_record.update_context(py, new_context)?;
229                    Ok(updated_record)
230                } else {
231                    Ok(record.clone())
232                }
233            })
234            .collect::<Result<Vec<GenAIEvalRecord>, EvaluationError>>()?;
235
236        Ok(Self {
237            records: Arc::new(updated_records),
238            profile: Arc::clone(&self.profile),
239        })
240    }
241}