scouter_evaluate/
types.rs

1use crate::error::EvaluationError;
2use crate::util::{parse_embedder, post_process};
3use ndarray::Array2;
4use potato_head::{create_uuid7, Embedder, PyHelperFuncs, Score};
5use pyo3::prelude::*;
6use pyo3::types::PyDict;
7use pyo3::IntoPyObjectExt;
8use scouter_profile::{Histogram, NumProfiler};
9use scouter_types::{is_pydantic_basemodel, json_to_pyobject_value, pyobject_to_json};
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::collections::{BTreeMap, HashMap};
13use std::sync::Arc;
14
15pub fn array_to_dict<'py>(
16    py: Python<'py>,
17    array: &ArrayDataset,
18) -> Result<Bound<'py, PyDict>, EvaluationError> {
19    let pydict = PyDict::new(py);
20
21    // set task ids
22    pydict.set_item(
23        "task",
24        array.idx_map.values().cloned().collect::<Vec<String>>(),
25    )?;
26
27    // set feature columns
28    for (i, feature) in array.feature_names.iter().enumerate() {
29        let column_data: Vec<f64> = array.data.column(i).to_vec();
30        pydict.set_item(feature, column_data)?;
31    }
32
33    // add cluster column if available
34    if array.clusters.len() == array.data.nrows() {
35        pydict.set_item("cluster", array.clusters.clone())?;
36    }
37    Ok(pydict)
38}
39
/// Enhanced results collection that captures both successes and failures
#[derive(Debug, Serialize, Deserialize)]
#[pyclass]
pub struct LLMEvalResults {
    // Per-task results, looked up by string key (see `__getitem__`).
    pub results: HashMap<String, LLMEvalTaskResult>,

    // One entry per task that errored; `failed_count` reports its length.
    // NOTE(review): presumably these are task/record ids — confirm against the producer.
    #[pyo3(get)]
    pub errored_tasks: Vec<String>,

    // Optional cluster data; returned as a clone by the `cluster_data` getter.
    pub cluster_data: Option<ClusterData>,

    // Histograms keyed by string; populated by `finalize` when
    // `config.compute_histograms` is set.
    #[pyo3(get)]
    pub histograms: Option<HashMap<String, Histogram>>,

    // Lazily built matrix view of `results`. Skipped by serde, so it is `None`
    // after deserialization and is rebuilt on demand.
    #[serde(skip)]
    pub array_dataset: Option<ArrayDataset>,
}
57
58#[pymethods]
59impl LLMEvalResults {
60    /// Get tasks for a specific record ID
61    pub fn __getitem__(&self, key: &str) -> Result<LLMEvalTaskResult, EvaluationError> {
62        match self.results.get(key) {
63            Some(value) => Ok(value.clone()),
64            None => Err(EvaluationError::MissingKeyError(key.to_string())),
65        }
66    }
67
68    pub fn __str__(&self) -> String {
69        PyHelperFuncs::__str__(self)
70    }
71
72    pub fn model_dump_json(&self) -> String {
73        // serialize the struct to a string
74        PyHelperFuncs::__json__(self)
75    }
76
77    #[staticmethod]
78    pub fn model_validate_json(json_string: String) -> Result<LLMEvalResults, EvaluationError> {
79        // deserialize the string to a struct
80        Ok(serde_json::from_str(&json_string)?)
81    }
82
83    #[pyo3(signature = (polars=false))]
84    pub fn to_dataframe<'py>(
85        &mut self,
86        py: Python<'py>,
87        polars: bool,
88    ) -> Result<Bound<'py, PyAny>, EvaluationError> {
89        if self.array_dataset.is_none() {
90            self.build_array_dataset()?;
91        }
92
93        let dataset = self.array_dataset.as_ref().unwrap();
94        let records = array_to_dict(py, dataset)?;
95
96        let module = if polars { "polars" } else { "pandas" };
97
98        let df_module = py.import(module)?;
99        let df_class = df_module.getattr("DataFrame")?;
100
101        if polars {
102            Ok(df_class.call1((records,))?)
103        } else {
104            Ok(df_class.call_method1("from_dict", (records,))?)
105        }
106    }
107
108    #[getter]
109    pub fn cluster_data(&self) -> Option<ClusterData> {
110        self.cluster_data.clone()
111    }
112
113    #[getter]
114    pub fn successful_count(&self) -> usize {
115        self.results.len()
116    }
117
118    #[getter]
119    pub fn failed_count(&self) -> usize {
120        self.errored_tasks.len()
121    }
122}
123
124impl LLMEvalResults {
125    /// Finalize the results by performing post-processing steps which includes:
126    /// - Post-processing embeddings (if any)
127    /// - Building the array dataset (if not already built)
128    /// - Performing clustering and dimensionality reduction (if enabled) for visualization
129    /// # Arguments
130    /// * `config` - The evaluation configuration that dictates post-processing behavior
131    /// # Returns
132    /// * `Result<(), EvaluationError>` - Returns Ok(()) if successful, otherwise returns
133    pub fn finalize(&mut self, config: &Arc<EvaluationConfig>) -> Result<(), EvaluationError> {
134        // Post-process embeddings if needed
135        if !config.embedding_targets.is_empty() {
136            post_process(self, config);
137        }
138
139        if config.compute_histograms {
140            self.build_array_dataset()?;
141
142            // Compute histograms for all numeric fields
143            if let Some(array_dataset) = &self.array_dataset {
144                let profiler = NumProfiler::new();
145                let histograms = profiler.compute_histogram(
146                    &array_dataset.data.view(),
147                    &array_dataset.feature_names,
148                    &10,
149                    false,
150                )?;
151                self.histograms = Some(histograms);
152            }
153        }
154
155        Ok(())
156    }
157
158    /// Build an NDArray dataset from the result tasks
159    fn build_array_dataset(&mut self) -> Result<(), EvaluationError> {
160        if self.array_dataset.is_none() {
161            self.array_dataset = Some(ArrayDataset::from_results(self)?);
162        }
163        Ok(())
164    }
165}
166
/// Cluster information attached to an `LLMEvalResults`.
///
/// NOTE(review): `x`/`y` look like 2-D coordinates for visualization and `clusters`
/// like one label per point — confirm against the code that constructs this.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ClusterData {
    // First coordinate series; readable from Python.
    #[pyo3(get)]
    pub x: Vec<f64>,
    // Second coordinate series; readable from Python.
    #[pyo3(get)]
    pub y: Vec<f64>,
    // Cluster labels; readable from Python.
    #[pyo3(get)]
    pub clusters: Vec<i32>,
    // Index -> string id mapping; not exposed to Python (no `#[pyo3(get)]`).
    pub idx_map: HashMap<usize, String>,
}
178
179impl ClusterData {
180    pub fn new(
181        x: Vec<f64>,
182        y: Vec<f64>,
183        clusters: Vec<i32>,
184        idx_map: HashMap<usize, String>,
185    ) -> Self {
186        ClusterData {
187            x,
188            y,
189            clusters,
190            idx_map,
191        }
192    }
193}
194
/// Dense numeric view of evaluation results: one row per task, one column per feature.
#[derive(Debug)]
pub struct ArrayDataset {
    // Row-major matrix of feature values (rows = tasks, columns = `feature_names`).
    pub data: Array2<f64>,
    // Column names, in the same order as the columns of `data`.
    pub feature_names: Vec<String>,
    // Row index -> task id, as filled in by `from_results`.
    pub idx_map: HashMap<usize, String>,
    // Cluster labels; empty when built by `new` or `from_results`.
    pub clusters: Vec<i32>,
}
202
203impl Default for ArrayDataset {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
209impl ArrayDataset {
210    pub fn new() -> Self {
211        Self {
212            data: Array2::zeros((0, 0)),
213            feature_names: Vec::new(),
214            idx_map: HashMap::new(),
215            clusters: vec![],
216        }
217    }
218
219    /// Build feature names from the results keys
220    /// This is used when constructing a dataframe from the results and when writing records
221    /// to the server
222    fn build_feature_names(results: &LLMEvalResults) -> Result<Vec<String>, EvaluationError> {
223        let first_task = results
224            .results
225            .values()
226            .next()
227            .ok_or(EvaluationError::NoResultsFound)?;
228
229        let mut names = Vec::new();
230
231        // BTreeMap iteration is already sorted
232        names.extend(first_task.metrics.keys().cloned());
233        names.extend(first_task.mean_embeddings.keys().cloned());
234        names.extend(first_task.similarity_scores.keys().cloned());
235
236        Ok(names)
237    }
238
239    fn from_results(results: &LLMEvalResults) -> Result<Self, EvaluationError> {
240        if results.results.is_empty() {
241            return Ok(Self::new());
242        }
243
244        let feature_names = Self::build_feature_names(results)?;
245        let n_rows = results.results.len();
246        let n_cols = feature_names.len();
247
248        let mut data = Vec::with_capacity(n_rows * n_cols);
249        let mut idx_map = HashMap::new();
250
251        // Build data matrix efficiently
252        for (i, task) in results.results.values().enumerate() {
253            idx_map.insert(i, task.id.clone());
254
255            // Collect all values in correct order (metrics, embeddings, similarities)
256            let row: Vec<f64> = feature_names
257                .iter()
258                .map(|name| {
259                    if let Some(score) = task.metrics.get(name) {
260                        score.score as f64
261                    } else if let Some(&mean) = task.mean_embeddings.get(name) {
262                        mean
263                    } else if let Some(&sim) = task.similarity_scores.get(name) {
264                        sim
265                    } else {
266                        0.0 // Default for missing values
267                    }
268                })
269                .collect();
270
271            data.extend(row);
272        }
273
274        let array = Array2::from_shape_vec((n_rows, n_cols), data)?;
275
276        Ok(Self {
277            data: array,
278            feature_names,
279            idx_map,
280            clusters: vec![],
281        })
282    }
283}
284
285impl LLMEvalResults {
286    pub fn new() -> Self {
287        Self {
288            results: HashMap::new(),
289            errored_tasks: Vec::new(),
290            array_dataset: None,
291            cluster_data: None,
292            histograms: None,
293        }
294    }
295}
296
297impl Default for LLMEvalResults {
298    fn default() -> Self {
299        Self::new()
300    }
301}
302
/// Struct for collecting results from LLM evaluation tasks.
///
/// All fields are readable from Python. `embedding` is skipped by serde, so it is not
/// part of the serialized form and is empty after deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct LLMEvalTaskResult {
    // Identifier of the task/record this result belongs to.
    #[pyo3(get)]
    pub id: String,

    // Scores keyed by metric name (sorted, because this is a BTreeMap).
    #[pyo3(get)]
    pub metrics: BTreeMap<String, Score>,

    // Raw embedding vectors keyed by name; not serialized.
    #[pyo3(get)]
    #[serde(skip)]
    pub embedding: BTreeMap<String, Vec<f32>>,

    // Mean embedding value per name; starts empty in `new`.
    // NOTE(review): presumably filled during post-processing — confirm.
    #[pyo3(get)]
    pub mean_embeddings: BTreeMap<String, f64>,

    // Similarity score per name; starts empty in `new`.
    // NOTE(review): presumably filled during post-processing — confirm.
    #[pyo3(get)]
    pub similarity_scores: BTreeMap<String, f64>,
}
323
#[pymethods]
impl LLMEvalTaskResult {
    /// String representation, delegated to `PyHelperFuncs::__str__`.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
}
330
331impl LLMEvalTaskResult {
332    pub fn new(
333        id: String,
334        metrics: BTreeMap<String, Score>,
335        embedding: BTreeMap<String, Vec<f32>>,
336    ) -> Self {
337        Self {
338            id,
339            metrics,
340            embedding,
341            mean_embeddings: BTreeMap::new(),
342            similarity_scores: BTreeMap::new(),
343        }
344    }
345}
346
/// A single record to evaluate: an id plus its context stored as JSON.
#[pyclass]
#[derive(Clone, Debug)]
pub struct LLMEvalRecord {
    // Record identifier; a UUID7 is generated by the constructor when none is supplied.
    pub id: String,
    // Record context, converted from a Python dict or pydantic model into a JSON value.
    pub context: Value,
}
353
354#[pymethods]
355impl LLMEvalRecord {
356    #[new]
357    #[pyo3(signature = (
358        context,
359        id=None
360    ))]
361
362    /// Creates a new LLMRecord instance.
363    /// The context is either a python dictionary or a pydantic basemodel.
364    pub fn new(
365        py: Python<'_>,
366        context: Bound<'_, PyAny>,
367        id: Option<String>,
368    ) -> Result<Self, EvaluationError> {
369        // check if context is a PyDict or PyObject(Pydantic model)
370        let context_val = if context.is_instance_of::<PyDict>() {
371            pyobject_to_json(&context)?
372        } else if is_pydantic_basemodel(py, &context)? {
373            // Dump pydantic model to dictionary
374            let model = context.call_method0("model_dump")?;
375
376            // Serialize the dictionary to JSON
377            pyobject_to_json(&model)?
378        } else {
379            Err(EvaluationError::MustBeDictOrBaseModel)?
380        };
381
382        let id = id.unwrap_or_else(create_uuid7);
383
384        Ok(LLMEvalRecord {
385            id,
386            context: context_val,
387        })
388    }
389
390    #[getter]
391    pub fn context<'py>(&self, py: Python<'py>) -> Result<Bound<'py, PyAny>, EvaluationError> {
392        Ok(json_to_pyobject_value(py, &self.context)?
393            .into_bound_py_any(py)?
394            .clone())
395    }
396}
397
/// Configuration controlling which post-processing steps run on evaluation results.
#[derive(Debug, Clone, Default)]
#[pyclass]
pub struct EvaluationConfig {
    // Optional embedder for embedding-based evaluations (shared via `Arc`).
    pub embedder: Option<Arc<Embedder>>,

    // Fields in the record to generate embeddings for. A non-empty list triggers
    // post-processing in `LLMEvalResults::finalize`.
    pub embedding_targets: Vec<String>,

    // Whether to compute similarities for all combinations of embeddings in the targets,
    // e.g. with targets ["a", "b"] the similarity between a and b is computed.
    pub compute_similarity: bool,

    // Whether to run clustering for all scores, embeddings and similarities (if available).
    pub cluster: bool,

    // Whether to compute histograms for all scores, embeddings and similarities (if available).
    pub compute_histograms: bool,
}
417
418#[pymethods]
419impl EvaluationConfig {
420    #[new]
421    #[pyo3(signature = (embedder=None, embedding_targets=None, compute_similarity=false, cluster=false, compute_histograms=false))]
422    /// Creates a new EvaluationConfig instance.
423    /// # Arguments
424    /// * `embedder` - Optional reference to a PyEmbedder instance.
425    /// * `embedding_targets` - Optional list of fields in the record to generate embeddings for.
426    /// * `compute_similarity` - Whether to compute similarities between embeddings.
427    /// * `cluster` - Whether to run clustering for all scores, embeddings and similarities (if available).
428    /// * `compute_histograms` - Whether to compute histograms for all scores, embeddings and similarities (if available).
429    /// # Returns
430    /// A new EvaluationConfig instance.
431    fn new(
432        embedder: Option<&Bound<'_, PyAny>>,
433        embedding_targets: Option<Vec<String>>,
434        compute_similarity: bool,
435        cluster: bool,
436        compute_histograms: bool,
437    ) -> Result<Self, EvaluationError> {
438        let embedder = parse_embedder(embedder)?;
439        let embedding_targets = embedding_targets.unwrap_or_default();
440
441        Ok(Self {
442            embedder,
443            embedding_targets,
444            compute_similarity,
445            cluster,
446            compute_histograms,
447        })
448    }
449
450    pub fn needs_post_processing(&self) -> bool {
451        !self.embedding_targets.is_empty() || self.cluster
452    }
453}