scouter_evaluate/
utils.rs

1use crate::error::EvaluationError;
2use crate::evaluate::evaluator::GenAIEvaluator;
3use crate::evaluate::types::{EvaluationConfig, GenAIEvalResults};
4use crate::genai::GenAIEvalDataset;
5use crate::tasks::evaluator::FieldEvaluator;
6use itertools::iproduct;
7use num_traits::FromPrimitive;
8use potato_head::{Embedder, EmbeddingInput, PyEmbedder};
9use pyo3::prelude::*;
10use rayon::prelude::*;
11use scouter_types::genai::GenAIEvalSet;
12use scouter_types::GenAIEvalRecord;
13use serde_json::Value;
14use simsimd::SpatialSimilarity;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use tokio::task::JoinSet;
18use tracing::{debug, error, warn};
19
20type EvalTaskResult = (
21    usize, // Index into records array
22    Result<(GenAIEvalSet, BTreeMap<String, Vec<f32>>), String>,
23);
24
25/// Spawn tasks without embedding support
26/// This function will spawn a task that runs the workflows and extracts the results
27/// If there is an error during workflow execution, it will log the error and return None for that record
28/// # Arguments
29/// * `workflow` - The workflow to execute for each record.
30/// * `records` - The list of GenAIEvalRecords to process.
31/// # Returns
32/// A JoinSet containing tuples of record ID and optional GenAIEvalTaskResult.
33pub async fn spawn_evaluation_tasks_without_embeddings(
34    dataset: &GenAIEvalDataset,
35    _config: &Arc<EvaluationConfig>,
36) -> JoinSet<EvalTaskResult> {
37    let mut join_set = JoinSet::new();
38
39    for (idx, _) in dataset.records.iter().enumerate() {
40        // cloning here so we can reference inside async move
41        let record_ref = dataset.records.clone();
42        let profile_ref = dataset.profile.clone();
43
44        join_set.spawn(async move {
45            // Access record by index - no cloning
46            let record = &record_ref[idx];
47
48            debug!(
49                "Starting evaluation for record {} and index {}",
50                record.uid, idx
51            );
52
53            let result = match GenAIEvaluator::process_event_record(record, profile_ref).await {
54                Ok(eval_set) => Ok((eval_set, BTreeMap::new())),
55                Err(e) => Err(format!("Evaluation failed: {}", e)),
56            };
57
58            (idx, result)
59        });
60    }
61
62    join_set
63}
64
65/// Spawn tasks to run evaluation workflows with embedding calculations
66/// # Arguments
67/// * `dataset` - The GenAIEvalDataset containing records to evaluate.
68/// * `embedder` - The Embedder instance to use for generating embeddings.
69/// * `config` - The EvaluationConfig containing evaluation settings.
70/// # Returns
71/// A JoinSet containing GenAIEvalTaskResults for each record.
72pub async fn spawn_evaluation_tasks_with_embeddings(
73    dataset: &GenAIEvalDataset,
74    embedder: Arc<Embedder>,
75    config: &Arc<EvaluationConfig>,
76) -> JoinSet<EvalTaskResult> {
77    let mut join_set = JoinSet::new();
78
79    for (idx, _) in dataset.records.iter().enumerate() {
80        let record_ref = dataset.records.clone();
81        let profile_ref = dataset.profile.clone();
82        let embedder_ref = embedder.clone();
83        let config_ref = config.clone();
84
85        join_set.spawn(async move {
86            let record = &record_ref[idx];
87
88            // Generate embeddings
89            let embeddings = generate_embeddings_for_record(
90                record,
91                &embedder_ref,
92                &config_ref.embedding_targets,
93            )
94            .await;
95
96            // Execute evaluation
97            let result = match GenAIEvaluator::process_event_record(record, profile_ref).await {
98                Ok(eval_set) => Ok((eval_set, embeddings)),
99                Err(e) => Err(format!("Evaluation failed: {}", e)),
100            };
101
102            (idx, result)
103        });
104    }
105
106    join_set
107}
108
109/// Helper for extracting embeddings for a single record. Used in the genai evaulation workflow.
110/// # Arguments
111/// * `record` - The GenAIEvalRecord to extract embeddings from.
112/// * `embedder` - The Embedder instance to use for generating embeddings.
113/// * `embedding_targets` - The list of keys in the record's context to generate embeddings for.
114/// # Returns
115pub async fn generate_embeddings_for_record(
116    record: &GenAIEvalRecord,
117    embedder: &Arc<Embedder>,
118    embedding_targets: &[String],
119) -> BTreeMap<String, Vec<f32>> {
120    let mut embeddings = BTreeMap::new();
121
122    for target in embedding_targets {
123        match FieldEvaluator::extract_field_value(&record.context, target) {
124            Ok(value) => {
125                let text = match value {
126                    Value::String(s) => Some(s.clone()),
127                    Value::Array(_) | Value::Object(_) => serde_json::to_string(value).ok(),
128                    _ => {
129                        warn!(
130                            "Field '{}' has unsupported type for embedding: {:?}",
131                            target, value
132                        );
133                        None
134                    }
135                };
136
137                if let Some(text) = text {
138                    match embedder.embed(EmbeddingInput::Texts(vec![text])).await {
139                        Ok(embedding_response) => match embedding_response.values() {
140                            Ok(values) => {
141                                embeddings.insert(target.clone(), values.to_vec());
142                            }
143                            Err(e) => {
144                                error!(
145                                    "Failed to extract embedding values for target '{}': {:?}",
146                                    target, e
147                                );
148                            }
149                        },
150                        Err(e) => {
151                            error!(
152                                "Failed to generate embedding for target '{}': {:?}",
153                                target, e
154                            );
155                        }
156                    }
157                }
158            }
159            Err(e) => {
160                warn!("Failed to extract field '{}' for embedding: {}", target, e);
161            }
162        }
163    }
164
165    embeddings
166}
167/// Collect and align results with original records
168pub async fn collect_and_align_results(
169    mut join_set: JoinSet<EvalTaskResult>,
170    records: &Arc<Vec<GenAIEvalRecord>>,
171) -> Result<GenAIEvalResults, EvaluationError> {
172    let mut results = GenAIEvalResults::new();
173
174    while let Some(join_result) = join_set.join_next().await {
175        match join_result {
176            Ok((idx, eval_result)) => {
177                let record = &records[idx];
178
179                match eval_result {
180                    Ok((eval_set, embeddings)) => {
181                        results.add_success(record, eval_set, embeddings);
182                    }
183                    Err(error_msg) => {
184                        results.add_failure(record, error_msg);
185                    }
186                }
187            }
188            Err(join_error) => {
189                error!("Task join error: {:?}", join_error);
190            }
191        }
192    }
193
194    Ok(results)
195}
196
197/// Post-process aligned results
198pub fn post_process_aligned_results(
199    results: &mut GenAIEvalResults,
200    config: &Arc<EvaluationConfig>,
201) -> Result<(), EvaluationError> {
202    results.aligned_results.par_iter_mut().for_each(|aligned| {
203        // Compute embedding means
204        for (target, values) in aligned.embeddings.iter() {
205            if let Some(mean) = compute_mean(values) {
206                aligned.mean_embeddings.insert(target.clone(), mean);
207            }
208        }
209
210        // Compute similarities
211        if config.compute_similarity {
212            compute_similarity(
213                &config.embedding_targets,
214                &aligned.embeddings,
215                &mut aligned.similarity_scores,
216            );
217        }
218    });
219
220    Ok(())
221}
222
223/// Helper function for extracting embedder and runtime from optional PyEmbedder
224/// # Arguments
225/// * `embedder` - Optional reference to a PyEmbedder instance.
226/// # Returns
227/// An optional Arc-wrapped Embedder instance if provided, otherwise None.
228pub fn parse_embedder(
229    embedder: Option<&Bound<'_, PyAny>>,
230) -> Result<Option<Arc<Embedder>>, EvaluationError> {
231    // Extract embedder and runtime if PyEmbedder is provided
232    let embedder_arc = if let Some(embedder_bound) = embedder {
233        if embedder_bound.is_instance_of::<PyEmbedder>() {
234            let py_embedder = embedder_bound.extract::<PyEmbedder>()?;
235            Some(py_embedder.embedder.clone())
236        } else {
237            // embedder provided but not a PyEmbedder instance
238            return Err(EvaluationError::InvalidEmbedderType);
239        }
240    } else {
241        None
242    };
243    Ok(embedder_arc)
244}
245
246/// Calculate the mean of for a slice of f32 values
247/// There's no need for a generic implementation here, as we only need f32 for embeddings
248pub fn compute_mean(vec: &[f32]) -> Option<f64> {
249    match vec.len() {
250        0 => None,
251        _ => {
252            let sum = vec.iter().sum::<f32>();
253            let length = f32::from_usize(vec.len())?;
254
255            let mean = sum / length;
256            Some(mean as f64)
257        }
258    }
259}
260
261pub fn compute_similarity(
262    targets: &Vec<String>,
263    embeddings: &BTreeMap<String, Vec<f32>>,
264    scores: &mut BTreeMap<String, f64>,
265) {
266    for (a, b) in iproduct!(targets, targets) {
267        // only want unique pairs
268        if a == b {
269            continue;
270        }
271        if let (Some(vec_a), Some(vec_b)) = (embeddings.get(a), embeddings.get(b)) {
272            if vec_a.len() != vec_b.len() {
273                warn!(
274                    "Embedding length mismatch for targets {} and {}: {} vs {}",
275                    a,
276                    b,
277                    vec_a.len(),
278                    vec_b.len()
279                );
280                continue;
281            }
282
283            let similarity = f32::cosine(vec_a, vec_b).unwrap_or(-1.0);
284            let key = format!("{}_{}_cosine", a, b);
285            scores.insert(key, similarity);
286        } else {
287            warn!("Missing embeddings for targets {} or {}", a, b);
288        }
289    }
290}