1use crate::error::EvaluationError;
2use crate::evaluate::evaluator::GenAIEvaluator;
3use crate::evaluate::types::{EvaluationConfig, GenAIEvalResults};
4use crate::genai::GenAIEvalDataset;
5use crate::tasks::evaluator::FieldEvaluator;
6use itertools::iproduct;
7use num_traits::FromPrimitive;
8use potato_head::{Embedder, EmbeddingInput, PyEmbedder};
9use pyo3::prelude::*;
10use rayon::prelude::*;
11use scouter_types::genai::GenAIEvalSet;
12use scouter_types::GenAIEvalRecord;
13use serde_json::Value;
14use simsimd::SpatialSimilarity;
15use std::collections::BTreeMap;
16use std::sync::Arc;
17use tokio::task::JoinSet;
18use tracing::{debug, error, warn};
19
20type EvalTaskResult = (
21 usize, Result<(GenAIEvalSet, BTreeMap<String, Vec<f32>>), String>,
23);
24
25pub async fn spawn_evaluation_tasks_without_embeddings(
34 dataset: &GenAIEvalDataset,
35 _config: &Arc<EvaluationConfig>,
36) -> JoinSet<EvalTaskResult> {
37 let mut join_set = JoinSet::new();
38
39 for (idx, _) in dataset.records.iter().enumerate() {
40 let record_ref = dataset.records.clone();
42 let profile_ref = dataset.profile.clone();
43
44 join_set.spawn(async move {
45 let record = &record_ref[idx];
47
48 debug!(
49 "Starting evaluation for record {} and index {}",
50 record.uid, idx
51 );
52
53 let result = match GenAIEvaluator::process_event_record(record, profile_ref).await {
54 Ok(eval_set) => Ok((eval_set, BTreeMap::new())),
55 Err(e) => Err(format!("Evaluation failed: {}", e)),
56 };
57
58 (idx, result)
59 });
60 }
61
62 join_set
63}
64
65pub async fn spawn_evaluation_tasks_with_embeddings(
73 dataset: &GenAIEvalDataset,
74 embedder: Arc<Embedder>,
75 config: &Arc<EvaluationConfig>,
76) -> JoinSet<EvalTaskResult> {
77 let mut join_set = JoinSet::new();
78
79 for (idx, _) in dataset.records.iter().enumerate() {
80 let record_ref = dataset.records.clone();
81 let profile_ref = dataset.profile.clone();
82 let embedder_ref = embedder.clone();
83 let config_ref = config.clone();
84
85 join_set.spawn(async move {
86 let record = &record_ref[idx];
87
88 let embeddings = generate_embeddings_for_record(
90 record,
91 &embedder_ref,
92 &config_ref.embedding_targets,
93 )
94 .await;
95
96 let result = match GenAIEvaluator::process_event_record(record, profile_ref).await {
98 Ok(eval_set) => Ok((eval_set, embeddings)),
99 Err(e) => Err(format!("Evaluation failed: {}", e)),
100 };
101
102 (idx, result)
103 });
104 }
105
106 join_set
107}
108
109pub async fn generate_embeddings_for_record(
116 record: &GenAIEvalRecord,
117 embedder: &Arc<Embedder>,
118 embedding_targets: &[String],
119) -> BTreeMap<String, Vec<f32>> {
120 let mut embeddings = BTreeMap::new();
121
122 for target in embedding_targets {
123 match FieldEvaluator::extract_field_value(&record.context, target) {
124 Ok(value) => {
125 let text = match value {
126 Value::String(s) => Some(s.clone()),
127 Value::Array(_) | Value::Object(_) => serde_json::to_string(value).ok(),
128 _ => {
129 warn!(
130 "Field '{}' has unsupported type for embedding: {:?}",
131 target, value
132 );
133 None
134 }
135 };
136
137 if let Some(text) = text {
138 match embedder.embed(EmbeddingInput::Texts(vec![text])).await {
139 Ok(embedding_response) => match embedding_response.values() {
140 Ok(values) => {
141 embeddings.insert(target.clone(), values.to_vec());
142 }
143 Err(e) => {
144 error!(
145 "Failed to extract embedding values for target '{}': {:?}",
146 target, e
147 );
148 }
149 },
150 Err(e) => {
151 error!(
152 "Failed to generate embedding for target '{}': {:?}",
153 target, e
154 );
155 }
156 }
157 }
158 }
159 Err(e) => {
160 warn!("Failed to extract field '{}' for embedding: {}", target, e);
161 }
162 }
163 }
164
165 embeddings
166}
167pub async fn collect_and_align_results(
169 mut join_set: JoinSet<EvalTaskResult>,
170 records: &Arc<Vec<GenAIEvalRecord>>,
171) -> Result<GenAIEvalResults, EvaluationError> {
172 let mut results = GenAIEvalResults::new();
173
174 while let Some(join_result) = join_set.join_next().await {
175 match join_result {
176 Ok((idx, eval_result)) => {
177 let record = &records[idx];
178
179 match eval_result {
180 Ok((eval_set, embeddings)) => {
181 results.add_success(record, eval_set, embeddings);
182 }
183 Err(error_msg) => {
184 results.add_failure(record, error_msg);
185 }
186 }
187 }
188 Err(join_error) => {
189 error!("Task join error: {:?}", join_error);
190 }
191 }
192 }
193
194 Ok(results)
195}
196
197pub fn post_process_aligned_results(
199 results: &mut GenAIEvalResults,
200 config: &Arc<EvaluationConfig>,
201) -> Result<(), EvaluationError> {
202 results.aligned_results.par_iter_mut().for_each(|aligned| {
203 for (target, values) in aligned.embeddings.iter() {
205 if let Some(mean) = compute_mean(values) {
206 aligned.mean_embeddings.insert(target.clone(), mean);
207 }
208 }
209
210 if config.compute_similarity {
212 compute_similarity(
213 &config.embedding_targets,
214 &aligned.embeddings,
215 &mut aligned.similarity_scores,
216 );
217 }
218 });
219
220 Ok(())
221}
222
223pub fn parse_embedder(
229 embedder: Option<&Bound<'_, PyAny>>,
230) -> Result<Option<Arc<Embedder>>, EvaluationError> {
231 let embedder_arc = if let Some(embedder_bound) = embedder {
233 if embedder_bound.is_instance_of::<PyEmbedder>() {
234 let py_embedder = embedder_bound.extract::<PyEmbedder>()?;
235 Some(py_embedder.embedder.clone())
236 } else {
237 return Err(EvaluationError::InvalidEmbedderType);
239 }
240 } else {
241 None
242 };
243 Ok(embedder_arc)
244}
245
246pub fn compute_mean(vec: &[f32]) -> Option<f64> {
249 match vec.len() {
250 0 => None,
251 _ => {
252 let sum = vec.iter().sum::<f32>();
253 let length = f32::from_usize(vec.len())?;
254
255 let mean = sum / length;
256 Some(mean as f64)
257 }
258 }
259}
260
261pub fn compute_similarity(
262 targets: &Vec<String>,
263 embeddings: &BTreeMap<String, Vec<f32>>,
264 scores: &mut BTreeMap<String, f64>,
265) {
266 for (a, b) in iproduct!(targets, targets) {
267 if a == b {
269 continue;
270 }
271 if let (Some(vec_a), Some(vec_b)) = (embeddings.get(a), embeddings.get(b)) {
272 if vec_a.len() != vec_b.len() {
273 warn!(
274 "Embedding length mismatch for targets {} and {}: {} vs {}",
275 a,
276 b,
277 vec_a.len(),
278 vec_b.len()
279 );
280 continue;
281 }
282
283 let similarity = f32::cosine(vec_a, vec_b).unwrap_or(-1.0);
284 let key = format!("{}_{}_cosine", a, b);
285 scores.insert(key, similarity);
286 } else {
287 warn!("Missing embeddings for targets {} or {}", a, b);
288 }
289 }
290}