1use crate::error::EvaluationError;
2use crate::evaluate::compare::compare_results;
3use crate::utils::parse_embedder;
4use crate::utils::post_process_aligned_results;
5use ndarray::Array2;
6use owo_colors::OwoColorize;
7use potato_head::Embedder;
8use potato_head::PyHelperFuncs;
9use pyo3::prelude::*;
10use pyo3::types::IntoPyDict;
11use pyo3::types::PyDict;
12use scouter_profile::{Histogram, NumProfiler};
13use scouter_types::genai::GenAIEvalSet;
14use scouter_types::{GenAIEvalRecord, TaskResultTableEntry, WorkflowResultTableEntry};
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::{BTreeMap, HashMap};
18use std::sync::Arc;
19use tabled::Tabled;
20use tabled::{
21 settings::{object::Rows, Alignment, Color, Format, Style},
22 Table,
23};
24
/// A task present in only one of the two evaluation runs being compared.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct MissingTask {
    /// Identifier of the missing task.
    #[pyo3(get)]
    pub task_id: String,

    /// Which side has the task: "baseline_only" or "comparison_only"
    /// (the values filtered on in `ComparisonResults::print_missing_tasks`).
    #[pyo3(get)]
    pub present_in: String,

    /// Identifier of the record this task belongs to.
    #[pyo3(get)]
    pub record_id: String,
}
37
/// Pass/fail outcome of one task in both the baseline and comparison runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct TaskComparison {
    /// Identifier of the compared task.
    #[pyo3(get)]
    pub task_id: String,

    /// Identifier of the record the task belongs to.
    #[pyo3(get)]
    pub record_id: String,

    /// Whether the task passed in the baseline run.
    #[pyo3(get)]
    pub baseline_passed: bool,

    /// Whether the task passed in the comparison run.
    #[pyo3(get)]
    pub comparison_passed: bool,

    /// True when the pass/fail status differs between the two runs.
    #[pyo3(get)]
    pub status_changed: bool,
}
56
/// Comparison of a single workflow between a baseline run and a comparison run.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct WorkflowComparison {
    /// Workflow identifier on the baseline side.
    #[pyo3(get)]
    pub baseline_id: String,

    /// Workflow identifier on the comparison side.
    #[pyo3(get)]
    pub comparison_id: String,

    /// Fraction of tasks passed in the baseline run (0.0..=1.0).
    #[pyo3(get)]
    pub baseline_pass_rate: f64,

    /// Fraction of tasks passed in the comparison run (0.0..=1.0).
    #[pyo3(get)]
    pub comparison_pass_rate: f64,

    /// comparison_pass_rate - baseline_pass_rate (positive = improvement).
    #[pyo3(get)]
    pub pass_rate_delta: f64,

    /// True when the delta exceeded the configured regression threshold.
    #[pyo3(get)]
    pub is_regression: bool,

    /// Per-task comparisons for this workflow pair.
    #[pyo3(get)]
    pub task_comparisons: Vec<TaskComparison>,
}
81
/// Display row for the workflow comparison summary table (values pre-formatted
/// and pre-colored as strings).
#[derive(Tabled)]
struct WorkflowComparisonEntry {
    #[tabled(rename = "Baseline ID")]
    baseline_id: String,
    #[tabled(rename = "Comparison ID")]
    comparison_id: String,
    #[tabled(rename = "Baseline Pass Rate")]
    baseline_pass_rate: String,
    #[tabled(rename = "Comparison Pass Rate")]
    comparison_pass_rate: String,
    #[tabled(rename = "Delta")]
    delta: String,
    #[tabled(rename = "Status")]
    status: String,
}
97
/// Cross-workflow aggregate statistics for a single task id.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct TaskAggregateStats {
    /// Identifier of the aggregated task.
    #[pyo3(get)]
    pub task_id: String,

    /// Number of workflows in which this task was compared.
    #[pyo3(get)]
    pub workflows_evaluated: usize,

    /// Workflows where the task passed in the baseline run.
    #[pyo3(get)]
    pub baseline_pass_count: usize,

    /// Workflows where the task passed in the comparison run.
    #[pyo3(get)]
    pub comparison_pass_count: usize,

    /// Workflows where the task's pass/fail status changed.
    #[pyo3(get)]
    pub status_changed_count: usize,

    /// baseline_pass_count / workflows_evaluated.
    #[pyo3(get)]
    pub baseline_pass_rate: f64,

    /// comparison_pass_count / workflows_evaluated.
    #[pyo3(get)]
    pub comparison_pass_rate: f64,
}
122
/// Display row for the per-task aggregate table (values pre-formatted and
/// pre-colored as strings).
#[derive(Tabled)]
struct TaskAggregateEntry {
    #[tabled(rename = "Task ID")]
    task_id: String,
    #[tabled(rename = "Workflows")]
    workflows: String,
    #[tabled(rename = "Baseline Pass Rate")]
    baseline_rate: String,
    #[tabled(rename = "Comparison Pass Rate")]
    comparison_rate: String,
    #[tabled(rename = "Delta")]
    delta: String,
    #[tabled(rename = "Status Changes")]
    changes: String,
}
138
/// Full outcome of comparing a baseline evaluation run against a newer run
/// (produced by `compare_results`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ComparisonResults {
    /// One entry per matched workflow pair.
    #[pyo3(get)]
    pub workflow_comparisons: Vec<WorkflowComparison>,

    /// Number of workflow pairs compared.
    #[pyo3(get)]
    pub total_workflows: usize,

    /// Count of workflows classified as improved.
    #[pyo3(get)]
    pub improved_workflows: usize,

    /// Count of workflows classified as regressed.
    #[pyo3(get)]
    pub regressed_workflows: usize,

    /// Count of workflows with no significant change.
    #[pyo3(get)]
    pub unchanged_workflows: usize,

    /// Mean pass-rate delta across workflow comparisons.
    #[pyo3(get)]
    pub mean_pass_rate_delta: f64,

    /// Task comparisons whose pass/fail status flipped between runs.
    #[pyo3(get)]
    pub task_status_changes: Vec<TaskComparison>,

    /// Tasks present in only one of the two runs.
    #[pyo3(get)]
    pub missing_tasks: Vec<MissingTask>,

    /// Total workflow count in the baseline run.
    #[pyo3(get)]
    pub baseline_workflow_count: usize,

    /// Total workflow count in the comparison run.
    #[pyo3(get)]
    pub comparison_workflow_count: usize,

    /// Overall regression flag for the whole comparison.
    #[pyo3(get)]
    pub regressed: bool,
}
175
176#[pymethods]
177impl ComparisonResults {
178 #[getter]
179 pub fn task_aggregate_stats(&self) -> Vec<TaskAggregateStats> {
180 let mut task_stats: HashMap<String, (usize, usize, usize, usize)> = HashMap::new();
181
182 for wc in &self.workflow_comparisons {
183 for tc in &wc.task_comparisons {
184 let entry = task_stats.entry(tc.task_id.clone()).or_insert((0, 0, 0, 0));
185 entry.0 += 1; if tc.baseline_passed {
187 entry.1 += 1; }
189 if tc.comparison_passed {
190 entry.2 += 1; }
192 if tc.status_changed {
193 entry.3 += 1; }
195 }
196 }
197
198 task_stats
199 .into_iter()
200 .map(
201 |(task_id, (total, baseline_pass, comparison_pass, changed))| TaskAggregateStats {
202 task_id,
203 workflows_evaluated: total,
204 baseline_pass_count: baseline_pass,
205 comparison_pass_count: comparison_pass,
206 status_changed_count: changed,
207 baseline_pass_rate: baseline_pass as f64 / total as f64,
208 comparison_pass_rate: comparison_pass as f64 / total as f64,
209 },
210 )
211 .collect()
212 }
213
214 #[getter]
215 pub fn has_missing_tasks(&self) -> bool {
216 !self.missing_tasks.is_empty()
217 }
218
219 pub fn print_missing_tasks(&self) {
220 if !self.missing_tasks.is_empty() {
221 println!("\n{}", "⚠ Missing Tasks".yellow().bold());
222
223 let baseline_only: Vec<_> = self
224 .missing_tasks
225 .iter()
226 .filter(|t| t.present_in == "baseline_only")
227 .collect();
228
229 let comparison_only: Vec<_> = self
230 .missing_tasks
231 .iter()
232 .filter(|t| t.present_in == "comparison_only")
233 .collect();
234
235 if !baseline_only.is_empty() {
236 println!(" Baseline only ({} tasks):", baseline_only.len());
237 for task in baseline_only {
238 println!(" - {} - {}", task.task_id, task.record_id);
239 }
240 }
241
242 if !comparison_only.is_empty() {
243 println!(" Comparison only ({} tasks):", comparison_only.len());
244 for task in comparison_only {
245 println!(" - {} - {}", task.task_id, task.record_id);
246 }
247 }
248 }
249 }
250
251 pub fn __str__(&self) -> String {
252 PyHelperFuncs::__str__(self)
253 }
254
255 pub fn as_table(&self) {
256 self.print_summary_table();
257
258 if !self.task_status_changes.is_empty() {
259 println!(
260 "\n{}",
261 "Task Status Changes (Workflow-Specific)"
262 .truecolor(245, 77, 85)
263 .bold()
264 );
265 self.print_status_changes_table();
266 }
267
268 self.print_task_aggregate_table();
269
270 if self.has_missing_tasks() {
271 self.print_missing_tasks();
272 }
273
274 self.print_summary_stats();
275 }
276
277 fn print_task_aggregate_table(&self) {
278 let stats = self.task_aggregate_stats();
279
280 if stats.is_empty() {
281 return;
282 }
283
284 println!(
285 "\n{}",
286 "Task Aggregate Stats (Cross-Workflow)"
287 .truecolor(245, 77, 85)
288 .bold()
289 );
290
291 let entries: Vec<_> = stats
292 .iter()
293 .map(|ts| {
294 let baseline_rate = format!("{:.1}%", ts.baseline_pass_rate * 100.0);
295 let comparison_rate = format!("{:.1}%", ts.comparison_pass_rate * 100.0);
296 let delta_val = (ts.comparison_pass_rate - ts.baseline_pass_rate) * 100.0;
297 let delta_str = format!("{:+.1}%", delta_val);
298
299 let colored_delta = if delta_val > 1.0 {
300 delta_str.green().to_string()
301 } else if delta_val < -1.0 {
302 delta_str.red().to_string()
303 } else {
304 delta_str.yellow().to_string()
305 };
306
307 let change_pct = if ts.workflows_evaluated > 0 {
308 (ts.status_changed_count as f64 / ts.workflows_evaluated as f64) * 100.0
309 } else {
310 0.0
311 };
312
313 TaskAggregateEntry {
314 task_id: ts.task_id.clone(),
315 workflows: ts.workflows_evaluated.to_string(),
316 baseline_rate,
317 comparison_rate,
318 delta: colored_delta,
319 changes: format!(
320 "{}/{} ({:.0}%)",
321 ts.status_changed_count, ts.workflows_evaluated, change_pct
322 ),
323 }
324 })
325 .collect();
326
327 let mut table = Table::new(entries);
328 table.with(Style::sharp());
329
330 table.modify(
331 Rows::new(0..1),
332 (
333 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
334 Alignment::center(),
335 Color::BOLD,
336 ),
337 );
338
339 println!("{}", table);
340 }
341
342 fn print_summary_table(&self) {
343 let entries: Vec<_> = self
344 .workflow_comparisons
345 .iter()
346 .map(|wc| {
347 let baseline_rate = format!("{:.1}%", wc.baseline_pass_rate * 100.0);
348 let comparison_rate = format!("{:.1}%", wc.comparison_pass_rate * 100.0);
349 let delta_val = wc.pass_rate_delta * 100.0;
350 let delta_str = format!("{:+.1}%", delta_val);
351
352 let colored_delta = if delta_val > 1.0 {
353 delta_str.green().to_string()
354 } else if delta_val < -1.0 {
355 delta_str.red().to_string()
356 } else {
357 delta_str.yellow().to_string()
358 };
359
360 let status = if wc.is_regression {
361 "Regressed".red().to_string()
362 } else if wc.pass_rate_delta > 0.01 {
363 "Improved".green().to_string()
364 } else {
365 "Unchanged".yellow().to_string()
366 };
367
368 WorkflowComparisonEntry {
369 baseline_id: wc.baseline_id[..16.min(wc.baseline_id.len())]
370 .to_string()
371 .truecolor(249, 179, 93)
372 .to_string(),
373 comparison_id: wc.comparison_id[..16.min(wc.comparison_id.len())]
374 .to_string()
375 .truecolor(249, 179, 93)
376 .to_string(),
377 baseline_pass_rate: baseline_rate,
378 comparison_pass_rate: comparison_rate,
379 delta: colored_delta,
380 status,
381 }
382 })
383 .collect();
384
385 let mut table = Table::new(entries);
386 table.with(Style::sharp());
387
388 table.modify(
389 Rows::new(0..1),
390 (
391 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
392 Alignment::center(),
393 Color::BOLD,
394 ),
395 );
396
397 println!("{}", table);
398 }
399
400 fn print_status_changes_table(&self) {
401 let entries: Vec<_> = self
402 .task_status_changes
403 .iter()
404 .map(|tc| {
405 let baseline_status = if tc.baseline_passed {
406 "✓ Pass".green().to_string()
407 } else {
408 "✗ Fail".red().to_string()
409 };
410
411 let comparison_status = if tc.comparison_passed {
412 "✓ Pass".green().to_string()
413 } else {
414 "✗ Fail".red().to_string()
415 };
416
417 let change = match (tc.baseline_passed, tc.comparison_passed) {
418 (true, false) => "Pass → Fail".red().bold().to_string(),
419 (false, true) => "Fail → Pass".green().bold().to_string(),
420 _ => "No Change".yellow().to_string(),
421 };
422
423 TaskStatusChangeEntry {
424 task_id: tc.task_id.clone(),
425 baseline_status,
426 comparison_status,
427 change,
428 }
429 })
430 .collect();
431
432 let mut table = Table::new(entries);
433 table.with(Style::sharp());
434
435 table.modify(
436 Rows::new(0..1),
437 (
438 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
439 Alignment::center(),
440 Color::BOLD,
441 ),
442 );
443
444 println!("{}", table);
445 }
446
447 fn print_summary_stats(&self) {
448 println!("\n{}", "Summary".truecolor(245, 77, 85).bold());
449
450 let regression_indicator = if self.regressed {
451 "⚠️ REGRESSION DETECTED".red().bold().to_string()
452 } else if self.improved_workflows > 0 {
453 "✅ IMPROVEMENT DETECTED".green().bold().to_string()
454 } else {
455 "➡️ NO SIGNIFICANT CHANGE".yellow().bold().to_string()
456 };
457
458 println!(" Overall Status: {}", regression_indicator);
459 println!(" Total Workflows: {}", self.total_workflows);
460 println!(
461 " Improved: {}",
462 self.improved_workflows.to_string().green()
463 );
464 println!(
465 " Regressed: {}",
466 self.regressed_workflows.to_string().red()
467 );
468 println!(
469 " Unchanged: {}",
470 self.unchanged_workflows.to_string().yellow()
471 );
472
473 let mean_delta_str = format!("{:+.2}%", self.mean_pass_rate_delta * 100.0);
474 let colored_mean = if self.mean_pass_rate_delta > 0.0 {
475 mean_delta_str.green().to_string()
476 } else if self.mean_pass_rate_delta < 0.0 {
477 mean_delta_str.red().to_string()
478 } else {
479 mean_delta_str.yellow().to_string()
480 };
481 println!(" Mean Pass Rate Delta: {}", colored_mean);
482 }
483}
484
// NOTE(review): empty inherent impl block — contributes nothing and is a
// candidate for removal.
impl ComparisonResults {}
486
/// Display row for the task status-change table (values pre-formatted and
/// pre-colored as strings).
#[derive(Tabled)]
struct TaskStatusChangeEntry {
    #[tabled(rename = "Task ID")]
    task_id: String,
    #[tabled(rename = "Baseline")]
    baseline_status: String,
    #[tabled(rename = "Comparison")]
    comparison_status: String,
    #[tabled(rename = "Change")]
    change: String,
}
498
/// Evaluation results for a set of GenAI records, one aligned entry per record.
#[derive(Debug, Serialize, Deserialize)]
#[pyclass]
pub struct GenAIEvalResults {
    // One entry per evaluated record (both successes and failures).
    pub aligned_results: Vec<AlignedEvalResult>,

    /// UIDs of records whose evaluation errored (see `add_failure`).
    #[pyo3(get)]
    pub errored_tasks: Vec<String>,

    // Optional 2D projection + cluster labels; not exposed as a Python getter.
    pub cluster_data: Option<ClusterData>,

    /// Per-feature histograms, computed in `finalize` when
    /// `compute_histograms` is enabled.
    #[pyo3(get)]
    pub histograms: Option<HashMap<String, Histogram>>,

    // Cached feature matrix; built lazily, never serialized.
    #[serde(skip)]
    pub array_dataset: Option<ArrayDataset>,

    // record_id -> index into `aligned_results`; never serialized. Only
    // successful records with a non-empty record_id are inserted (see
    // `add_success` / `add_failure`).
    #[serde(skip)]
    pub results_by_id: HashMap<String, usize>,
}
519
520#[pymethods]
521impl GenAIEvalResults {
522 pub fn __getitem__(&self, key: &str) -> Result<AlignedEvalResult, EvaluationError> {
523 self.results_by_id
524 .get(key)
525 .and_then(|&idx| self.aligned_results.get(idx))
526 .cloned()
527 .ok_or_else(|| EvaluationError::MissingKeyError(key.to_string()))
528 }
529
530 #[getter]
531 pub fn successful_count(&self) -> usize {
532 self.aligned_results.iter().filter(|r| r.success).count()
533 }
534
535 #[getter]
536 pub fn failed_count(&self) -> usize {
537 self.aligned_results.iter().filter(|r| !r.success).count()
538 }
539
540 #[pyo3(signature = (polars=false))]
542 pub fn to_dataframe<'py>(
543 &mut self,
544 py: Python<'py>,
545 polars: bool,
546 ) -> Result<Bound<'py, PyAny>, EvaluationError> {
547 let all_task_records: Vec<_> = self
548 .aligned_results
549 .iter()
550 .flat_map(|r| r.to_flat_task_records())
551 .collect();
552
553 if all_task_records.is_empty() {
554 return Err(EvaluationError::NoResultsFound);
555 }
556
557 let py_records = PyDict::new(py);
558
559 let mut all_columns = std::collections::BTreeSet::new();
561 for record in &all_task_records {
562 all_columns.extend(record.keys().cloned());
563 }
564
565 for column_name in all_columns {
567 let column_data: Vec<_> = all_task_records
568 .iter()
569 .map(|record| record.get(&column_name).cloned().unwrap_or(Value::Null))
570 .collect();
571
572 let py_col = pythonize::pythonize(py, &column_data)?;
573 py_records.set_item(&column_name, py_col)?;
574 }
575
576 let module = if polars { "polars" } else { "pandas" };
577 let df_module = py.import(module)?;
578 let df_class = df_module.getattr("DataFrame")?;
579
580 if polars {
581 let schema = self.get_schema_mapping(py)?;
582 let schema_dict = &[("schema", schema)].into_py_dict(py)?;
583 schema_dict.set_item("strict", false)?;
584 Ok(df_class.call((py_records,), Some(schema_dict))?)
585 } else {
586 Ok(df_class.call_method1("from_dict", (py_records,))?)
587 }
588 }
589
590 pub fn __str__(&self) -> String {
591 PyHelperFuncs::__str__(self)
592 }
593
594 #[pyo3(signature = (show_tasks=false))]
595 pub fn as_table(&mut self, show_tasks: bool) {
599 if show_tasks {
600 let tasks_table = self.build_tasks_table();
601 println!("\n{}", "Task Details".truecolor(245, 77, 85).bold());
602 println!("{}", tasks_table);
603 } else {
604 let workflow_table = self.build_workflow_table();
605 println!("\n{}", "Workflow Summary".truecolor(245, 77, 85).bold());
606 println!("{}", workflow_table);
607 }
608 }
609
610 pub fn model_dump_json(&self) -> String {
611 PyHelperFuncs::__json__(self)
612 }
613
614 #[staticmethod]
615 pub fn model_validate_json(json_string: String) -> Result<Self, EvaluationError> {
616 Ok(serde_json::from_str(&json_string)?)
617 }
618
619 #[pyo3(signature = (baseline, regression_threshold=0.05))]
622 pub fn compare_to(
623 &self,
624 baseline: &GenAIEvalResults,
625 regression_threshold: f64,
626 ) -> Result<ComparisonResults, EvaluationError> {
627 compare_results(baseline, self, regression_threshold)
628 }
629}
630
631impl GenAIEvalResults {
632 fn get_schema_mapping<'py>(
633 &self,
634 py: Python<'py>,
635 ) -> Result<Bound<'py, PyDict>, EvaluationError> {
636 let schema = PyDict::new(py);
637 let pl = py.import("polars")?;
638
639 schema.set_item("created_at", pl.getattr("Utf8")?)?;
640 schema.set_item("record_uid", pl.getattr("Utf8")?)?;
641 schema.set_item("success", pl.getattr("Boolean")?)?;
642 schema.set_item("workflow_error", pl.getattr("Utf8")?)?;
643
644 schema.set_item("workflow_total_tasks", pl.getattr("Int64")?)?;
645 schema.set_item("workflow_passed_tasks", pl.getattr("Int64")?)?;
646 schema.set_item("workflow_failed_tasks", pl.getattr("Int64")?)?;
647 schema.set_item("workflow_pass_rate", pl.getattr("Float64")?)?;
648 schema.set_item("workflow_duration_ms", pl.getattr("Int64")?)?;
649
650 schema.set_item("task_id", pl.getattr("Utf8")?)?;
651 schema.set_item("task_type", pl.getattr("Utf8")?)?;
652 schema.set_item("task_passed", pl.getattr("Boolean")?)?;
653 schema.set_item("task_value", pl.getattr("Float64")?)?;
654 schema.set_item("task_message", pl.getattr("Utf8")?)?;
655 schema.set_item("task_field_path", pl.getattr("Utf8")?)?;
656 schema.set_item("task_operator", pl.getattr("Utf8")?)?;
657 schema.set_item("task_expected", pl.getattr("Utf8")?)?;
658 schema.set_item("task_actual", pl.getattr("Utf8")?)?;
659
660 schema.set_item("context", pl.getattr("Utf8")?)?;
661 schema.set_item("embedding_means", pl.getattr("Utf8")?)?;
662 schema.set_item("similarity_scores", pl.getattr("Utf8")?)?;
663
664 Ok(schema)
665 }
666 fn build_workflow_table(&self) -> Table {
668 let entries: Vec<WorkflowResultTableEntry> = self
669 .aligned_results
670 .iter()
671 .flat_map(|result| result.eval_set.build_workflow_entries())
672 .collect();
673
674 let mut table = Table::new(entries);
675 table.with(Style::sharp());
676
677 table.modify(
678 Rows::new(0..1),
679 (
680 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
681 Alignment::center(),
682 Color::BOLD,
683 ),
684 );
685 table
686 }
687
688 fn build_tasks_table(&mut self) -> Table {
690 self.aligned_results.sort_by(|a, b| {
691 let a_id = if a.record_id.is_empty() {
692 &a.record_uid
693 } else {
694 &a.record_id
695 };
696 let b_id = if b.record_id.is_empty() {
697 &b.record_uid
698 } else {
699 &b.record_id
700 };
701 a_id.cmp(b_id)
702 });
703
704 let entries: Vec<TaskResultTableEntry> = self
705 .aligned_results
706 .iter_mut()
707 .flat_map(|result| {
708 let resolved_id = if result.record_id.is_empty() {
709 &result.record_uid
710 } else {
711 &result.record_id
712 };
713
714 result.eval_set.build_task_entries(resolved_id)
715 })
716 .collect();
717
718 let mut table = Table::new(entries);
719 table.with(Style::sharp());
720
721 table.modify(
722 Rows::new(0..1),
723 (
724 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
725 Alignment::center(),
726 Color::BOLD,
727 ),
728 );
729
730 table
731 }
732
733 pub fn new() -> Self {
734 Self {
735 aligned_results: Vec::new(),
736 errored_tasks: Vec::new(),
737 array_dataset: None,
738 cluster_data: None,
739 histograms: None,
740 results_by_id: HashMap::new(),
741 }
742 }
743
744 pub fn add_success(
746 &mut self,
747 record: &GenAIEvalRecord,
748 eval_set: GenAIEvalSet,
749 embeddings: BTreeMap<String, Vec<f32>>,
750 ) {
751 self.aligned_results.push(AlignedEvalResult::from_success(
752 record, eval_set, embeddings,
753 ));
754
755 if !record.record_id.is_empty() {
756 self.results_by_id
757 .insert(record.record_id.clone(), self.aligned_results.len() - 1);
758 }
759 }
760
761 pub fn add_failure(&mut self, record: &GenAIEvalRecord, error: String) {
763 let uid = record.uid.clone();
764
765 self.aligned_results
766 .push(AlignedEvalResult::from_failure(record, error));
767
768 self.errored_tasks.push(uid);
769 }
770
771 pub fn finalize(&mut self, config: &Arc<EvaluationConfig>) -> Result<(), EvaluationError> {
780 if !config.embedding_targets.is_empty() {
782 post_process_aligned_results(self, config)?;
783 }
784
785 if config.compute_histograms {
786 self.build_array_dataset()?;
787
788 if let Some(array_dataset) = &self.array_dataset {
790 let profiler = NumProfiler::new();
791 let histograms = profiler.compute_histogram(
792 &array_dataset.data.view(),
793 &array_dataset.feature_names,
794 &10,
795 false,
796 )?;
797 self.histograms = Some(histograms);
798 }
799 }
800
801 Ok(())
802 }
803
804 fn build_array_dataset(&mut self) -> Result<(), EvaluationError> {
806 if self.array_dataset.is_none() {
807 self.array_dataset = Some(ArrayDataset::from_results(self)?);
808 }
809 Ok(())
810 }
811}
812
impl Default for GenAIEvalResults {
    /// Equivalent to [`GenAIEvalResults::new`]: all collections empty, no
    /// cached dataset, clusters, or histograms.
    fn default() -> Self {
        Self::new()
    }
}
818
819pub fn array_to_dict<'py>(
820 py: Python<'py>,
821 array: &ArrayDataset,
822) -> Result<Bound<'py, PyDict>, EvaluationError> {
823 let pydict = PyDict::new(py);
824
825 pydict.set_item(
827 "task",
828 array.idx_map.values().cloned().collect::<Vec<String>>(),
829 )?;
830
831 for (i, feature) in array.feature_names.iter().enumerate() {
833 let column_data: Vec<f64> = array.data.column(i).to_vec();
834 pydict.set_item(feature, column_data)?;
835 }
836
837 if array.clusters.len() == array.data.nrows() {
839 pydict.set_item("cluster", array.clusters.clone())?;
840 }
841 Ok(pydict)
842}
843
/// 2D projection of evaluation results with cluster labels.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ClusterData {
    /// X coordinates, one per result row.
    #[pyo3(get)]
    pub x: Vec<f64>,
    /// Y coordinates, one per result row.
    #[pyo3(get)]
    pub y: Vec<f64>,
    /// Cluster label per result row.
    #[pyo3(get)]
    pub clusters: Vec<i32>,
    // Row index -> record identifier (not exposed to Python).
    pub idx_map: HashMap<usize, String>,
}
855
856impl ClusterData {
857 pub fn new(
858 x: Vec<f64>,
859 y: Vec<f64>,
860 clusters: Vec<i32>,
861 idx_map: HashMap<usize, String>,
862 ) -> Self {
863 ClusterData {
864 x,
865 y,
866 clusters,
867 idx_map,
868 }
869 }
870}
871
/// Dense numeric feature matrix built from evaluation results
/// (rows = successful results, columns = `feature_names`).
#[derive(Debug)]
pub struct ArrayDataset {
    // Row-major matrix, shape (n successful results, n features).
    pub data: Array2<f64>,
    // Column names: task ids, then embedding keys, then similarity keys.
    pub feature_names: Vec<String>,
    // Row index -> record uid.
    pub idx_map: HashMap<usize, String>,
    // Optional cluster label per row; empty until clustering runs.
    pub clusters: Vec<i32>,
}
879
impl Default for ArrayDataset {
    /// Equivalent to [`ArrayDataset::new`]: an empty 0x0 dataset.
    fn default() -> Self {
        Self::new()
    }
}
885
886impl ArrayDataset {
887 pub fn new() -> Self {
888 Self {
889 data: Array2::zeros((0, 0)),
890 feature_names: Vec::new(),
891 idx_map: HashMap::new(),
892 clusters: vec![],
893 }
894 }
895
896 fn build_feature_names(results: &GenAIEvalResults) -> Result<Vec<String>, EvaluationError> {
900 let first_result = results
902 .aligned_results
903 .iter()
904 .find(|r| r.success)
905 .ok_or(EvaluationError::NoResultsFound)?;
906
907 let mut names = Vec::new();
908
909 for task_record in &first_result.eval_set.records {
910 names.push(task_record.task_id.clone());
911 }
912
913 names.extend(first_result.mean_embeddings.keys().cloned());
914 names.extend(first_result.similarity_scores.keys().cloned());
915
916 Ok(names)
917 }
918
919 pub fn from_results(results: &GenAIEvalResults) -> Result<Self, EvaluationError> {
924 if results.aligned_results.is_empty() {
925 return Ok(Self::new());
926 }
927
928 let successful_results: Vec<&AlignedEvalResult> = results
930 .aligned_results
931 .iter()
932 .filter(|r| r.success)
933 .collect();
934
935 if successful_results.is_empty() {
936 return Err(EvaluationError::NoResultsFound);
937 }
938
939 let feature_names = Self::build_feature_names(results)?;
940 let n_rows = successful_results.len();
941 let n_cols = feature_names.len();
942
943 let mut data = Vec::with_capacity(n_rows * n_cols);
944 let mut idx_map = HashMap::new();
945
946 for (row_idx, aligned) in successful_results.iter().enumerate() {
949 idx_map.insert(row_idx, aligned.record_uid.clone());
950
951 let task_scores: HashMap<String, f64> = aligned
953 .eval_set
954 .records
955 .iter()
956 .map(|task| (task.task_id.clone(), task.value))
957 .collect();
958
959 let row: Vec<f64> = feature_names
961 .iter()
962 .map(|feature_name| {
963 if let Some(&score) = task_scores.get(feature_name) {
965 return score;
966 }
967
968 if let Some(&mean) = aligned.mean_embeddings.get(feature_name) {
970 return mean;
971 }
972
973 if let Some(&sim) = aligned.similarity_scores.get(feature_name) {
975 return sim;
976 }
977
978 0.0
980 })
981 .collect();
982
983 data.extend(row);
984 }
985
986 let array = Array2::from_shape_vec((n_rows, n_cols), data)?;
987
988 Ok(Self {
989 data: array,
990 feature_names,
991 idx_map,
992 clusters: vec![],
993 })
994 }
995}
996
/// Evaluation outcome for one record, pairing the record's identifiers with
/// its eval set, embeddings, and success/error state.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct AlignedEvalResult {
    /// Caller-supplied record id (may be empty; `record_uid` is the fallback).
    #[pyo3(get)]
    pub record_id: String,

    /// Internal unique id of the record.
    #[pyo3(get)]
    pub record_uid: String,

    /// Task-level evaluation results (empty for failed records).
    #[pyo3(get)]
    pub eval_set: GenAIEvalSet,

    /// Raw embeddings per target; not serialized.
    #[pyo3(get)]
    #[serde(skip)]
    pub embeddings: BTreeMap<String, Vec<f32>>,

    /// Mean embedding value per target, filled by post-processing.
    #[pyo3(get)]
    pub mean_embeddings: BTreeMap<String, f64>,

    /// Similarity score per target, filled by post-processing.
    #[pyo3(get)]
    pub similarity_scores: BTreeMap<String, f64>,

    /// Whether the workflow evaluation succeeded for this record.
    #[pyo3(get)]
    pub success: bool,

    /// Error message when `success` is false.
    #[pyo3(get)]
    pub error_message: Option<String>,

    // Snapshot of the record's JSON-object context (see `capture_context`);
    // not serialized.
    #[serde(skip)]
    pub context_snapshot: Option<BTreeMap<String, serde_json::Value>>,
}
1028
#[pymethods]
impl AlignedEvalResult {
    /// Pretty-printed representation used by Python's `str()`.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }

    /// Number of task results in this record's eval set.
    #[getter]
    pub fn task_count(&self) -> usize {
        self.eval_set.records.len()
    }
}
1040
1041impl AlignedEvalResult {
1042 pub fn from_success(
1044 record: &GenAIEvalRecord,
1045 eval_set: GenAIEvalSet,
1046 embeddings: BTreeMap<String, Vec<f32>>,
1047 ) -> Self {
1048 Self {
1049 record_uid: record.uid.clone(),
1050 record_id: record.record_id.clone(),
1051 eval_set,
1052 embeddings,
1053 mean_embeddings: BTreeMap::new(),
1054 similarity_scores: BTreeMap::new(),
1055 success: true,
1056 error_message: None,
1057 context_snapshot: None,
1058 }
1059 }
1060
1061 pub fn from_failure(record: &GenAIEvalRecord, error: String) -> Self {
1063 Self {
1064 record_uid: record.uid.clone(),
1065 eval_set: GenAIEvalSet::empty(),
1066 embeddings: BTreeMap::new(),
1067 mean_embeddings: BTreeMap::new(),
1068 similarity_scores: BTreeMap::new(),
1069 success: false,
1070 error_message: Some(error),
1071 context_snapshot: None,
1072 record_id: record.record_id.clone(),
1073 }
1074 }
1075
1076 pub fn capture_context(&mut self, record: &GenAIEvalRecord) {
1078 if let serde_json::Value::Object(context_map) = &record.context {
1079 self.context_snapshot = Some(
1080 context_map
1081 .iter()
1082 .map(|(k, v)| (k.clone(), v.clone()))
1083 .collect(),
1084 );
1085 }
1086 }
1087
1088 pub fn to_flat_task_records(&self) -> Vec<BTreeMap<String, serde_json::Value>> {
1090 let mut records = Vec::new();
1091
1092 for task_result in &self.eval_set.records {
1093 let mut flat = BTreeMap::new();
1094
1095 flat.insert(
1097 "created_at".to_string(),
1098 self.eval_set.inner.created_at.to_rfc3339().into(),
1099 );
1100 flat.insert("record_uid".to_string(), self.record_uid.clone().into());
1101 flat.insert("success".to_string(), self.success.into());
1102
1103 flat.insert(
1105 "workflow_error".to_string(),
1106 match &self.error_message {
1107 Some(err) => serde_json::Value::String(err.clone()),
1108 None => serde_json::Value::String("".to_string()),
1109 },
1110 );
1111
1112 flat.insert(
1114 "workflow_total_tasks".to_string(),
1115 self.eval_set.inner.total_tasks.into(),
1116 );
1117 flat.insert(
1118 "workflow_passed_tasks".to_string(),
1119 self.eval_set.inner.passed_tasks.into(),
1120 );
1121 flat.insert(
1122 "workflow_failed_tasks".to_string(),
1123 self.eval_set.inner.failed_tasks.into(),
1124 );
1125 flat.insert(
1126 "workflow_pass_rate".to_string(),
1127 self.eval_set.inner.pass_rate.into(),
1128 );
1129 flat.insert(
1130 "workflow_duration_ms".to_string(),
1131 self.eval_set.inner.duration_ms.into(),
1132 );
1133
1134 flat.insert("task_id".to_string(), task_result.task_id.clone().into());
1136 flat.insert(
1137 "task_type".to_string(),
1138 task_result.task_type.to_string().into(),
1139 );
1140 flat.insert("task_passed".to_string(), task_result.passed.into());
1141 flat.insert("task_value".to_string(), task_result.value.into());
1142 flat.insert(
1143 "task_message".to_string(),
1144 serde_json::Value::String(task_result.message.clone()),
1145 );
1146
1147 flat.insert(
1148 "task_field_path".to_string(),
1149 match &task_result.field_path {
1150 Some(path) => serde_json::Value::String(path.clone()),
1151 None => serde_json::Value::Null,
1152 },
1153 );
1154
1155 flat.insert(
1156 "task_operator".to_string(),
1157 task_result.operator.to_string().into(),
1158 );
1159 flat.insert("task_expected".to_string(), task_result.expected.clone());
1160 flat.insert("task_actual".to_string(), task_result.actual.clone());
1161
1162 flat.insert(
1164 "context".to_string(),
1165 self.context_snapshot
1166 .as_ref()
1167 .map(|ctx| serde_json::to_value(ctx).unwrap_or(serde_json::Value::Null))
1168 .unwrap_or(serde_json::Value::Null),
1169 );
1170
1171 flat.insert(
1173 "embedding_means".to_string(),
1174 serde_json::to_value(&self.mean_embeddings).unwrap_or(serde_json::Value::Null),
1175 );
1176
1177 flat.insert(
1179 "similarity_scores".to_string(),
1180 serde_json::to_value(&self.similarity_scores).unwrap_or(serde_json::Value::Null),
1181 );
1182
1183 records.push(flat);
1184 }
1185
1186 records
1187 }
1188}
1189
/// Options controlling post-processing of evaluation results.
#[derive(Debug, Clone, Default)]
#[pyclass]
pub struct EvaluationConfig {
    // Embedder used for embedding post-processing, shared across threads.
    pub embedder: Option<Arc<Embedder>>,

    // Context fields to embed; empty means no embedding post-processing
    // (see `needs_post_processing` and `GenAIEvalResults::finalize`).
    pub embedding_targets: Vec<String>,

    // Whether to compute similarity scores during post-processing.
    pub compute_similarity: bool,

    // Whether to build the feature matrix and histograms in `finalize`.
    pub compute_histograms: bool,
}
1206
1207#[pymethods]
1208impl EvaluationConfig {
1209 #[new]
1210 #[pyo3(signature = (embedder=None, embedding_targets=None, compute_similarity=false, compute_histograms=false))]
1211 fn new(
1220 embedder: Option<&Bound<'_, PyAny>>,
1221 embedding_targets: Option<Vec<String>>,
1222 compute_similarity: bool,
1223 compute_histograms: bool,
1224 ) -> Result<Self, EvaluationError> {
1225 let embedder = parse_embedder(embedder)?;
1226 let embedding_targets = embedding_targets.unwrap_or_default();
1227
1228 Ok(Self {
1229 embedder,
1230 embedding_targets,
1231 compute_similarity,
1232 compute_histograms,
1233 })
1234 }
1235
1236 pub fn needs_post_processing(&self) -> bool {
1237 !self.embedding_targets.is_empty()
1238 }
1239}