1use crate::error::EvaluationError;
2use crate::evaluate::compare::compare_results;
3use crate::utils::parse_embedder;
4use crate::utils::post_process_aligned_results;
5use ndarray::Array2;
6use owo_colors::OwoColorize;
7use potato_head::Embedder;
8use potato_head::PyHelperFuncs;
9use pyo3::prelude::*;
10use pyo3::types::IntoPyDict;
11use pyo3::types::PyDict;
12use scouter_profile::{Histogram, NumProfiler};
13use scouter_types::genai::EvalSet;
14use scouter_types::{EvalRecord, TaskResultTableEntry, WorkflowResultTableEntry};
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::{BTreeMap, HashMap};
18use std::sync::Arc;
19use tabled::Tabled;
20use tabled::{
21 settings::{object::Rows, Alignment, Color, Format, Style},
22 Table,
23};
24
/// A task that appears in only one of the two evaluation runs being compared,
/// so no baseline-vs-comparison verdict could be computed for it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct MissingTask {
    /// Identifier of the task within its eval set.
    #[pyo3(get)]
    pub task_id: String,

    /// Which side contains the task: "baseline_only" or "comparison_only"
    /// (these exact strings are matched in `print_missing_tasks`).
    #[pyo3(get)]
    pub present_in: String,

    /// Record id of the workflow run the task belongs to.
    #[pyo3(get)]
    pub record_id: String,
}
37
/// Pass/fail outcome of a single task in both the baseline and the
/// comparison run.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct TaskComparison {
    /// Identifier of the task within its eval set.
    #[pyo3(get)]
    pub task_id: String,

    /// Record id of the workflow run this task belongs to.
    #[pyo3(get)]
    pub record_id: String,

    /// Whether the task passed in the baseline run.
    #[pyo3(get)]
    pub baseline_passed: bool,

    /// Whether the task passed in the comparison run.
    #[pyo3(get)]
    pub comparison_passed: bool,

    /// True when the pass/fail outcome differs between the two runs.
    #[pyo3(get)]
    pub status_changed: bool,
}
56
/// Comparison of one workflow run between a baseline and a comparison
/// evaluation, including the per-task breakdown.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct WorkflowComparison {
    /// Identifier of the baseline workflow run.
    #[pyo3(get)]
    pub baseline_id: String,

    /// Identifier of the comparison workflow run.
    #[pyo3(get)]
    pub comparison_id: String,

    /// Baseline pass rate as a fraction in [0.0, 1.0] (rendered as a
    /// percentage by the table printers).
    #[pyo3(get)]
    pub baseline_pass_rate: f64,

    /// Comparison pass rate as a fraction in [0.0, 1.0].
    #[pyo3(get)]
    pub comparison_pass_rate: f64,

    /// Pass-rate delta between the runs; positive values are rendered as
    /// improvements (see `print_summary_table`).
    #[pyo3(get)]
    pub pass_rate_delta: f64,

    /// True when this workflow pair was judged a regression.
    #[pyo3(get)]
    pub is_regression: bool,

    /// Task-level outcomes for this workflow pair.
    #[pyo3(get)]
    pub task_comparisons: Vec<TaskComparison>,
}
81
/// Pre-formatted display row for the workflow comparison summary table.
/// All fields are strings because they carry ANSI color codes and
/// percent formatting applied before the table is built.
#[derive(Tabled)]
struct WorkflowComparisonEntry {
    #[tabled(rename = "Baseline ID")]
    baseline_id: String,
    #[tabled(rename = "Comparison ID")]
    comparison_id: String,
    #[tabled(rename = "Baseline Pass Rate")]
    baseline_pass_rate: String,
    #[tabled(rename = "Comparison Pass Rate")]
    comparison_pass_rate: String,
    #[tabled(rename = "Delta")]
    delta: String,
    #[tabled(rename = "Status")]
    status: String,
}
97
/// Cross-workflow aggregate of a single task's outcomes, derived from
/// `ComparisonResults::task_aggregate_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct TaskAggregateStats {
    /// Identifier of the task within its eval set.
    #[pyo3(get)]
    pub task_id: String,

    /// Number of workflow comparisons in which this task appeared.
    #[pyo3(get)]
    pub workflows_evaluated: usize,

    /// Workflows in which the task passed in the baseline run.
    #[pyo3(get)]
    pub baseline_pass_count: usize,

    /// Workflows in which the task passed in the comparison run.
    #[pyo3(get)]
    pub comparison_pass_count: usize,

    /// Workflows in which the task's pass/fail status flipped.
    #[pyo3(get)]
    pub status_changed_count: usize,

    /// baseline_pass_count / workflows_evaluated (0.0 when no workflows).
    #[pyo3(get)]
    pub baseline_pass_rate: f64,

    /// comparison_pass_count / workflows_evaluated (0.0 when no workflows).
    #[pyo3(get)]
    pub comparison_pass_rate: f64,
}
122
/// Pre-formatted display row for the cross-workflow task aggregate table;
/// values are stringified/colored before table construction.
#[derive(Tabled)]
struct TaskAggregateEntry {
    #[tabled(rename = "Task ID")]
    task_id: String,
    #[tabled(rename = "Workflows")]
    workflows: String,
    #[tabled(rename = "Baseline Pass Rate")]
    baseline_rate: String,
    #[tabled(rename = "Comparison Pass Rate")]
    comparison_rate: String,
    #[tabled(rename = "Delta")]
    delta: String,
    #[tabled(rename = "Status Changes")]
    changes: String,
}
138
/// Outcome of comparing two evaluation runs (baseline vs. comparison),
/// produced by `compare_results` / `EvalResults::compare_to`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ComparisonResults {
    /// One entry per matched workflow pair.
    #[pyo3(get)]
    pub workflow_comparisons: Vec<WorkflowComparison>,

    /// Total number of workflow pairs compared.
    #[pyo3(get)]
    pub total_workflows: usize,

    /// Count of workflows whose pass rate improved.
    #[pyo3(get)]
    pub improved_workflows: usize,

    /// Count of workflows judged regressions.
    #[pyo3(get)]
    pub regressed_workflows: usize,

    /// Count of workflows with no significant change.
    #[pyo3(get)]
    pub unchanged_workflows: usize,

    /// Mean of pass-rate deltas across workflows, as a fraction
    /// (multiplied by 100 only when printed).
    #[pyo3(get)]
    pub mean_pass_rate_delta: f64,

    /// Tasks whose pass/fail status flipped within a specific workflow.
    #[pyo3(get)]
    pub task_status_changes: Vec<TaskComparison>,

    /// Tasks present in only one of the two runs.
    #[pyo3(get)]
    pub missing_tasks: Vec<MissingTask>,

    /// Number of workflow runs in the baseline evaluation.
    #[pyo3(get)]
    pub baseline_workflow_count: usize,

    /// Number of workflow runs in the comparison evaluation.
    #[pyo3(get)]
    pub comparison_workflow_count: usize,

    /// True when the comparison is considered an overall regression.
    #[pyo3(get)]
    pub regressed: bool,
}
175
176#[pymethods]
177impl ComparisonResults {
178 #[getter]
179 pub fn task_aggregate_stats(&self) -> Vec<TaskAggregateStats> {
180 let mut task_stats: HashMap<String, (usize, usize, usize, usize)> = HashMap::new();
181
182 for wc in &self.workflow_comparisons {
183 for tc in &wc.task_comparisons {
184 let entry = task_stats.entry(tc.task_id.clone()).or_insert((0, 0, 0, 0));
185 entry.0 += 1; if tc.baseline_passed {
187 entry.1 += 1; }
189 if tc.comparison_passed {
190 entry.2 += 1; }
192 if tc.status_changed {
193 entry.3 += 1; }
195 }
196 }
197
198 task_stats
199 .into_iter()
200 .map(
201 |(task_id, (total, baseline_pass, comparison_pass, changed))| TaskAggregateStats {
202 task_id,
203 workflows_evaluated: total,
204 baseline_pass_count: baseline_pass,
205 comparison_pass_count: comparison_pass,
206 status_changed_count: changed,
207 baseline_pass_rate: if total > 0 {
208 baseline_pass as f64 / total as f64
209 } else {
210 0.0
211 },
212 comparison_pass_rate: if total > 0 {
213 comparison_pass as f64 / total as f64
214 } else {
215 0.0
216 },
217 },
218 )
219 .collect()
220 }
221
222 #[getter]
223 pub fn has_missing_tasks(&self) -> bool {
224 !self.missing_tasks.is_empty()
225 }
226
227 pub fn print_missing_tasks(&self) {
228 if !self.missing_tasks.is_empty() {
229 println!("\n{}", "⚠ Missing Tasks".yellow().bold());
230
231 let baseline_only: Vec<_> = self
232 .missing_tasks
233 .iter()
234 .filter(|t| t.present_in == "baseline_only")
235 .collect();
236
237 let comparison_only: Vec<_> = self
238 .missing_tasks
239 .iter()
240 .filter(|t| t.present_in == "comparison_only")
241 .collect();
242
243 if !baseline_only.is_empty() {
244 println!(" Baseline only ({} tasks):", baseline_only.len());
245 for task in baseline_only {
246 println!(" - {} - {}", task.task_id, task.record_id);
247 }
248 }
249
250 if !comparison_only.is_empty() {
251 println!(" Comparison only ({} tasks):", comparison_only.len());
252 for task in comparison_only {
253 println!(" - {} - {}", task.task_id, task.record_id);
254 }
255 }
256 }
257 }
258
259 pub fn __str__(&self) -> String {
260 PyHelperFuncs::__str__(self)
261 }
262
263 pub fn as_table(&self) {
264 self.print_summary_table();
265
266 if !self.task_status_changes.is_empty() {
267 println!(
268 "\n{}",
269 "Task Status Changes (Workflow-Specific)"
270 .truecolor(245, 77, 85)
271 .bold()
272 );
273 self.print_status_changes_table();
274 }
275
276 self.print_task_aggregate_table();
277
278 if self.has_missing_tasks() {
279 self.print_missing_tasks();
280 }
281
282 self.print_summary_stats();
283 }
284
285 fn print_task_aggregate_table(&self) {
286 let stats = self.task_aggregate_stats();
287
288 if stats.is_empty() {
289 return;
290 }
291
292 println!(
293 "\n{}",
294 "Task Aggregate Stats (Cross-Workflow)"
295 .truecolor(245, 77, 85)
296 .bold()
297 );
298
299 let entries: Vec<_> = stats
300 .iter()
301 .map(|ts| {
302 let baseline_rate = format!("{:.1}%", ts.baseline_pass_rate * 100.0);
303 let comparison_rate = format!("{:.1}%", ts.comparison_pass_rate * 100.0);
304 let delta_val = (ts.comparison_pass_rate - ts.baseline_pass_rate) * 100.0;
305 let delta_str = format!("{:+.1}%", delta_val);
306
307 let colored_delta = if delta_val > 1.0 {
308 delta_str.green().to_string()
309 } else if delta_val < -1.0 {
310 delta_str.red().to_string()
311 } else {
312 delta_str.yellow().to_string()
313 };
314
315 let change_pct = if ts.workflows_evaluated > 0 {
316 (ts.status_changed_count as f64 / ts.workflows_evaluated as f64) * 100.0
317 } else {
318 0.0
319 };
320
321 TaskAggregateEntry {
322 task_id: ts.task_id.clone(),
323 workflows: ts.workflows_evaluated.to_string(),
324 baseline_rate,
325 comparison_rate,
326 delta: colored_delta,
327 changes: format!(
328 "{}/{} ({:.0}%)",
329 ts.status_changed_count, ts.workflows_evaluated, change_pct
330 ),
331 }
332 })
333 .collect();
334
335 let mut table = Table::new(entries);
336 table.with(Style::sharp());
337
338 table.modify(
339 Rows::new(0..1),
340 (
341 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
342 Alignment::center(),
343 Color::BOLD,
344 ),
345 );
346
347 println!("{}", table);
348 }
349
350 fn print_summary_table(&self) {
351 let entries: Vec<_> = self
352 .workflow_comparisons
353 .iter()
354 .map(|wc| {
355 let baseline_rate = format!("{:.1}%", wc.baseline_pass_rate * 100.0);
356 let comparison_rate = format!("{:.1}%", wc.comparison_pass_rate * 100.0);
357 let delta_val = wc.pass_rate_delta * 100.0;
358 let delta_str = format!("{:+.1}%", delta_val);
359
360 let colored_delta = if delta_val > 1.0 {
361 delta_str.green().to_string()
362 } else if delta_val < -1.0 {
363 delta_str.red().to_string()
364 } else {
365 delta_str.yellow().to_string()
366 };
367
368 let status = if wc.is_regression {
369 "Regressed".red().to_string()
370 } else if wc.pass_rate_delta > 0.01 {
371 "Improved".green().to_string()
372 } else {
373 "Unchanged".yellow().to_string()
374 };
375
376 WorkflowComparisonEntry {
377 baseline_id: wc.baseline_id[..16.min(wc.baseline_id.len())]
378 .to_string()
379 .truecolor(249, 179, 93)
380 .to_string(),
381 comparison_id: wc.comparison_id[..16.min(wc.comparison_id.len())]
382 .to_string()
383 .truecolor(249, 179, 93)
384 .to_string(),
385 baseline_pass_rate: baseline_rate,
386 comparison_pass_rate: comparison_rate,
387 delta: colored_delta,
388 status,
389 }
390 })
391 .collect();
392
393 let mut table = Table::new(entries);
394 table.with(Style::sharp());
395
396 table.modify(
397 Rows::new(0..1),
398 (
399 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
400 Alignment::center(),
401 Color::BOLD,
402 ),
403 );
404
405 println!("{}", table);
406 }
407
408 fn print_status_changes_table(&self) {
409 let entries: Vec<_> = self
410 .task_status_changes
411 .iter()
412 .map(|tc| {
413 let baseline_status = if tc.baseline_passed {
414 "✓ Pass".green().to_string()
415 } else {
416 "✗ Fail".red().to_string()
417 };
418
419 let comparison_status = if tc.comparison_passed {
420 "✓ Pass".green().to_string()
421 } else {
422 "✗ Fail".red().to_string()
423 };
424
425 let change = match (tc.baseline_passed, tc.comparison_passed) {
426 (true, false) => "Pass → Fail".red().bold().to_string(),
427 (false, true) => "Fail → Pass".green().bold().to_string(),
428 _ => "No Change".yellow().to_string(),
429 };
430
431 TaskStatusChangeEntry {
432 task_id: tc.task_id.clone(),
433 baseline_status,
434 comparison_status,
435 change,
436 }
437 })
438 .collect();
439
440 let mut table = Table::new(entries);
441 table.with(Style::sharp());
442
443 table.modify(
444 Rows::new(0..1),
445 (
446 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
447 Alignment::center(),
448 Color::BOLD,
449 ),
450 );
451
452 println!("{}", table);
453 }
454
455 fn print_summary_stats(&self) {
456 println!("\n{}", "Summary".truecolor(245, 77, 85).bold());
457
458 let regression_indicator = if self.regressed {
459 "⚠️ REGRESSION DETECTED".red().bold().to_string()
460 } else if self.improved_workflows > 0 {
461 "✅ IMPROVEMENT DETECTED".green().bold().to_string()
462 } else {
463 "➡️ NO SIGNIFICANT CHANGE".yellow().bold().to_string()
464 };
465
466 println!(" Overall Status: {}", regression_indicator);
467 println!(" Total Workflows: {}", self.total_workflows);
468 println!(
469 " Improved: {}",
470 self.improved_workflows.to_string().green()
471 );
472 println!(
473 " Regressed: {}",
474 self.regressed_workflows.to_string().red()
475 );
476 println!(
477 " Unchanged: {}",
478 self.unchanged_workflows.to_string().yellow()
479 );
480
481 let mean_delta_str = format!("{:+.2}%", self.mean_pass_rate_delta * 100.0);
482 let colored_mean = if self.mean_pass_rate_delta > 0.0 {
483 mean_delta_str.green().to_string()
484 } else if self.mean_pass_rate_delta < 0.0 {
485 mean_delta_str.red().to_string()
486 } else {
487 mean_delta_str.yellow().to_string()
488 };
489 println!(" Mean Pass Rate Delta: {}", colored_mean);
490 }
491}
492
493impl ComparisonResults {}
494
/// Pre-formatted display row for the task status-change table; strings
/// carry ANSI colors applied before the table is built.
#[derive(Tabled)]
struct TaskStatusChangeEntry {
    #[tabled(rename = "Task ID")]
    task_id: String,
    #[tabled(rename = "Baseline")]
    baseline_status: String,
    #[tabled(rename = "Comparison")]
    comparison_status: String,
    #[tabled(rename = "Change")]
    change: String,
}
506
/// Aggregated results of an evaluation run: one aligned result per record,
/// plus optional derived data (histograms, clusters, feature matrix).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct EvalResults {
    /// One entry per evaluated record (successes and failures alike).
    /// Not a pyo3 getter; reachable from Python via `__getitem__`.
    pub aligned_results: Vec<AlignedEvalResult>,

    /// Uids of records whose evaluation errored (populated by `add_failure`).
    #[pyo3(get)]
    pub errored_tasks: Vec<String>,

    /// Optional 2-D cluster projection of the results.
    pub cluster_data: Option<ClusterData>,

    /// Per-feature histograms, computed in `finalize` when
    /// `compute_histograms` is enabled.
    #[pyo3(get)]
    pub histograms: Option<HashMap<String, Histogram>>,

    /// Cached numeric feature matrix; not serialized, rebuilt on demand
    /// by `build_array_dataset`.
    #[serde(skip)]
    pub array_dataset: Option<ArrayDataset>,

    /// record_id -> index into `aligned_results` for keyed lookup.
    /// Not serialized, so it is empty after `model_validate_json`.
    #[serde(skip)]
    pub results_by_id: HashMap<String, usize>,
}
527
528#[pymethods]
529impl EvalResults {
530 pub fn __getitem__(&self, key: &str) -> Result<AlignedEvalResult, EvaluationError> {
531 self.results_by_id
532 .get(key)
533 .and_then(|&idx| self.aligned_results.get(idx))
534 .cloned()
535 .ok_or_else(|| EvaluationError::MissingKeyError(key.to_string()))
536 }
537
538 #[getter]
539 pub fn successful_count(&self) -> usize {
540 self.aligned_results.iter().filter(|r| r.success).count()
541 }
542
543 #[getter]
544 pub fn failed_count(&self) -> usize {
545 self.aligned_results.iter().filter(|r| !r.success).count()
546 }
547
548 #[pyo3(signature = (polars=false))]
550 pub fn to_dataframe<'py>(
551 &mut self,
552 py: Python<'py>,
553 polars: bool,
554 ) -> Result<Bound<'py, PyAny>, EvaluationError> {
555 let all_task_records: Vec<_> = self
556 .aligned_results
557 .iter()
558 .flat_map(|r| r.to_flat_task_records())
559 .collect();
560
561 if all_task_records.is_empty() {
562 return Err(EvaluationError::NoResultsFound);
563 }
564
565 let py_records = PyDict::new(py);
566
567 let mut all_columns = std::collections::BTreeSet::new();
569 for record in &all_task_records {
570 all_columns.extend(record.keys().cloned());
571 }
572
573 for column_name in all_columns {
575 let column_data: Vec<_> = all_task_records
576 .iter()
577 .map(|record| record.get(&column_name).cloned().unwrap_or(Value::Null))
578 .collect();
579
580 let py_col = pythonize::pythonize(py, &column_data)?;
581 py_records.set_item(&column_name, py_col)?;
582 }
583
584 let module = if polars { "polars" } else { "pandas" };
585 let df_module = py.import(module)?;
586 let df_class = df_module.getattr("DataFrame")?;
587
588 if polars {
589 let schema = self.get_schema_mapping(py)?;
590 let schema_dict = &[("schema", schema)].into_py_dict(py)?;
591 schema_dict.set_item("strict", false)?;
592 Ok(df_class.call((py_records,), Some(schema_dict))?)
593 } else {
594 Ok(df_class.call_method1("from_dict", (py_records,))?)
595 }
596 }
597
598 pub fn __str__(&self) -> String {
599 PyHelperFuncs::__str__(self)
600 }
601
602 #[pyo3(signature = (show_tasks=false))]
603 pub fn as_table(&mut self, show_tasks: bool) {
607 if show_tasks {
608 let tasks_table = self.build_tasks_table();
609 println!("\n{}", "Task Details".truecolor(245, 77, 85).bold());
610 println!("{}", tasks_table);
611 } else {
612 let workflow_table = self.build_workflow_table();
613 println!("\n{}", "Workflow Summary".truecolor(245, 77, 85).bold());
614 println!("{}", workflow_table);
615 }
616 }
617
618 pub fn model_dump_json(&self) -> String {
619 PyHelperFuncs::__json__(self)
620 }
621
622 #[staticmethod]
623 pub fn model_validate_json(json_string: String) -> Result<Self, EvaluationError> {
624 Ok(serde_json::from_str(&json_string)?)
625 }
626
627 #[pyo3(signature = (baseline, regression_threshold=0.05))]
630 pub fn compare_to(
631 &self,
632 baseline: &EvalResults,
633 regression_threshold: f64,
634 ) -> Result<ComparisonResults, EvaluationError> {
635 compare_results(baseline, self, regression_threshold)
636 }
637}
638
639impl EvalResults {
640 fn get_schema_mapping<'py>(
641 &self,
642 py: Python<'py>,
643 ) -> Result<Bound<'py, PyDict>, EvaluationError> {
644 let schema = PyDict::new(py);
645 let pl = py.import("polars")?;
646
647 schema.set_item("created_at", pl.getattr("Utf8")?)?;
648 schema.set_item("record_uid", pl.getattr("Utf8")?)?;
649 schema.set_item("success", pl.getattr("Boolean")?)?;
650 schema.set_item("workflow_error", pl.getattr("Utf8")?)?;
651
652 schema.set_item("workflow_total_tasks", pl.getattr("Int64")?)?;
653 schema.set_item("workflow_passed_tasks", pl.getattr("Int64")?)?;
654 schema.set_item("workflow_failed_tasks", pl.getattr("Int64")?)?;
655 schema.set_item("workflow_pass_rate", pl.getattr("Float64")?)?;
656 schema.set_item("workflow_duration_ms", pl.getattr("Int64")?)?;
657
658 schema.set_item("task_id", pl.getattr("Utf8")?)?;
659 schema.set_item("task_type", pl.getattr("Utf8")?)?;
660 schema.set_item("task_passed", pl.getattr("Boolean")?)?;
661 schema.set_item("task_value", pl.getattr("Float64")?)?;
662 schema.set_item("task_message", pl.getattr("Utf8")?)?;
663 schema.set_item("task_assertion", pl.getattr("Utf8")?)?;
664 schema.set_item("task_operator", pl.getattr("Utf8")?)?;
665 schema.set_item("task_expected", pl.getattr("Utf8")?)?;
666 schema.set_item("task_actual", pl.getattr("Utf8")?)?;
667
668 schema.set_item("context", pl.getattr("Utf8")?)?;
669 schema.set_item("embedding_means", pl.getattr("Utf8")?)?;
670 schema.set_item("similarity_scores", pl.getattr("Utf8")?)?;
671
672 Ok(schema)
673 }
674 fn build_workflow_table(&self) -> Table {
676 let entries: Vec<WorkflowResultTableEntry> = self
677 .aligned_results
678 .iter()
679 .flat_map(|result| result.eval_set.build_workflow_entries())
680 .collect();
681
682 let mut table = Table::new(entries);
683 table.with(Style::sharp());
684
685 table.modify(
686 Rows::new(0..1),
687 (
688 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
689 Alignment::center(),
690 Color::BOLD,
691 ),
692 );
693 table
694 }
695
696 fn build_tasks_table(&mut self) -> Table {
698 self.aligned_results.sort_by(|a, b| {
699 let a_id = if a.record_id.is_empty() {
700 &a.record_uid
701 } else {
702 &a.record_id
703 };
704 let b_id = if b.record_id.is_empty() {
705 &b.record_uid
706 } else {
707 &b.record_id
708 };
709 a_id.cmp(b_id)
710 });
711
712 let entries: Vec<TaskResultTableEntry> = self
713 .aligned_results
714 .iter_mut()
715 .flat_map(|result| {
716 let resolved_id = if result.record_id.is_empty() {
717 &result.record_uid
718 } else {
719 &result.record_id
720 };
721
722 result.eval_set.build_task_entries(resolved_id)
723 })
724 .collect();
725
726 let mut table = Table::new(entries);
727 table.with(Style::sharp());
728
729 table.modify(
730 Rows::new(0..1),
731 (
732 Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
733 Alignment::center(),
734 Color::BOLD,
735 ),
736 );
737
738 table
739 }
740
741 pub fn new() -> Self {
742 Self {
743 aligned_results: Vec::new(),
744 errored_tasks: Vec::new(),
745 array_dataset: None,
746 cluster_data: None,
747 histograms: None,
748 results_by_id: HashMap::new(),
749 }
750 }
751
752 pub fn add_success(
754 &mut self,
755 record: &EvalRecord,
756 eval_set: EvalSet,
757 embeddings: BTreeMap<String, Vec<f32>>,
758 ) {
759 self.aligned_results.push(AlignedEvalResult::from_success(
760 record, eval_set, embeddings,
761 ));
762
763 if !record.record_id.is_empty() {
764 self.results_by_id
765 .insert(record.record_id.clone(), self.aligned_results.len() - 1);
766 }
767 }
768
769 pub fn add_failure(&mut self, record: &EvalRecord, error: String) {
771 let uid = record.uid.clone();
772
773 self.aligned_results
774 .push(AlignedEvalResult::from_failure(record, error));
775
776 self.errored_tasks.push(uid);
777 }
778
779 pub fn finalize(&mut self, config: &Arc<EvaluationConfig>) -> Result<(), EvaluationError> {
788 if !config.embedding_targets.is_empty() {
790 post_process_aligned_results(self, config)?;
791 }
792
793 if config.compute_histograms {
794 self.build_array_dataset()?;
795
796 if let Some(array_dataset) = &self.array_dataset {
798 let profiler = NumProfiler::new();
799 let histograms = profiler.compute_histogram(
800 &array_dataset.data.view(),
801 &array_dataset.feature_names,
802 &10,
803 false,
804 )?;
805 self.histograms = Some(histograms);
806 }
807 }
808
809 Ok(())
810 }
811
812 fn build_array_dataset(&mut self) -> Result<(), EvaluationError> {
814 if self.array_dataset.is_none() {
815 self.array_dataset = Some(ArrayDataset::from_results(self)?);
816 }
817 Ok(())
818 }
819}
820
821impl Default for EvalResults {
822 fn default() -> Self {
823 Self::new()
824 }
825}
826
827pub fn array_to_dict<'py>(
828 py: Python<'py>,
829 array: &ArrayDataset,
830) -> Result<Bound<'py, PyDict>, EvaluationError> {
831 let pydict = PyDict::new(py);
832
833 pydict.set_item(
835 "task",
836 array.idx_map.values().cloned().collect::<Vec<String>>(),
837 )?;
838
839 for (i, feature) in array.feature_names.iter().enumerate() {
841 let column_data: Vec<f64> = array.data.column(i).to_vec();
842 pydict.set_item(feature, column_data)?;
843 }
844
845 if array.clusters.len() == array.data.nrows() {
847 pydict.set_item("cluster", array.clusters.clone())?;
848 }
849 Ok(pydict)
850}
851
/// 2-D projection of results with per-point cluster assignments.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ClusterData {
    /// X coordinates (presumably one per result row — confirm in producer).
    #[pyo3(get)]
    pub x: Vec<f64>,
    /// Y coordinates, parallel to `x`.
    #[pyo3(get)]
    pub y: Vec<f64>,
    /// Cluster label per point.
    #[pyo3(get)]
    pub clusters: Vec<i32>,
    /// Row index -> record uid mapping (not exposed to Python).
    pub idx_map: HashMap<usize, String>,
}
863
864impl ClusterData {
865 pub fn new(
866 x: Vec<f64>,
867 y: Vec<f64>,
868 clusters: Vec<i32>,
869 idx_map: HashMap<usize, String>,
870 ) -> Self {
871 ClusterData {
872 x,
873 y,
874 clusters,
875 idx_map,
876 }
877 }
878}
879
/// Numeric feature matrix built from successful results: one row per
/// successful record, one column per feature (task values, embedding
/// means, similarity scores — see `ArrayDataset::build_feature_names`).
#[derive(Debug, Clone)]
pub struct ArrayDataset {
    /// Matrix data; rows = successful results, columns = features.
    pub data: Array2<f64>,
    /// Column names aligned with `data`'s columns.
    pub feature_names: Vec<String>,
    /// Row index -> record uid.
    pub idx_map: HashMap<usize, String>,
    /// Optional cluster label per row (empty until clustering runs).
    pub clusters: Vec<i32>,
}
887
888impl Default for ArrayDataset {
889 fn default() -> Self {
890 Self::new()
891 }
892}
893
894impl ArrayDataset {
895 pub fn new() -> Self {
896 Self {
897 data: Array2::zeros((0, 0)),
898 feature_names: Vec::new(),
899 idx_map: HashMap::new(),
900 clusters: vec![],
901 }
902 }
903
904 fn build_feature_names(results: &EvalResults) -> Result<Vec<String>, EvaluationError> {
908 let first_result = results
910 .aligned_results
911 .iter()
912 .find(|r| r.success)
913 .ok_or(EvaluationError::NoResultsFound)?;
914
915 let mut names = Vec::new();
916
917 for task_record in &first_result.eval_set.records {
918 names.push(task_record.task_id.clone());
919 }
920
921 names.extend(first_result.mean_embeddings.keys().cloned());
922 names.extend(first_result.similarity_scores.keys().cloned());
923
924 Ok(names)
925 }
926
927 pub fn from_results(results: &EvalResults) -> Result<Self, EvaluationError> {
932 if results.aligned_results.is_empty() {
933 return Ok(Self::new());
934 }
935
936 let successful_results: Vec<&AlignedEvalResult> = results
938 .aligned_results
939 .iter()
940 .filter(|r| r.success)
941 .collect();
942
943 if successful_results.is_empty() {
944 return Err(EvaluationError::NoResultsFound);
945 }
946
947 let feature_names = Self::build_feature_names(results)?;
948 let n_rows = successful_results.len();
949 let n_cols = feature_names.len();
950
951 let mut data = Vec::with_capacity(n_rows * n_cols);
952 let mut idx_map = HashMap::new();
953
954 for (row_idx, aligned) in successful_results.iter().enumerate() {
957 idx_map.insert(row_idx, aligned.record_uid.clone());
958
959 let task_scores: HashMap<String, f64> = aligned
961 .eval_set
962 .records
963 .iter()
964 .map(|task| (task.task_id.clone(), task.value))
965 .collect();
966
967 let row: Vec<f64> = feature_names
969 .iter()
970 .map(|feature_name| {
971 if let Some(&score) = task_scores.get(feature_name) {
973 return score;
974 }
975
976 if let Some(&mean) = aligned.mean_embeddings.get(feature_name) {
978 return mean;
979 }
980
981 if let Some(&sim) = aligned.similarity_scores.get(feature_name) {
983 return sim;
984 }
985
986 0.0
988 })
989 .collect();
990
991 data.extend(row);
992 }
993
994 let array = Array2::from_shape_vec((n_rows, n_cols), data)?;
995
996 Ok(Self {
997 data: array,
998 feature_names,
999 idx_map,
1000 clusters: vec![],
1001 })
1002 }
1003}
1004
/// Evaluation outcome for a single record: the record's identity aligned
/// with its eval set, embeddings, and derived scores.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct AlignedEvalResult {
    /// Caller-supplied record id; may be empty, in which case `record_uid`
    /// is used as the display/sort key.
    #[pyo3(get)]
    pub record_id: String,

    /// Unique id of the evaluated record.
    #[pyo3(get)]
    pub record_uid: String,

    /// Task results produced by the evaluation workflow.
    #[pyo3(get)]
    pub eval_set: EvalSet,

    /// Raw embedding vectors keyed by target name (not serialized).
    #[pyo3(get)]
    #[serde(skip)]
    pub embeddings: BTreeMap<String, Vec<f32>>,

    /// Scalar summary per embedding target (empty at construction;
    /// presumably filled during post-processing — see `EvalResults::finalize`).
    #[pyo3(get)]
    pub mean_embeddings: BTreeMap<String, f64>,

    /// Similarity scores keyed by name (empty at construction;
    /// presumably filled during post-processing).
    #[pyo3(get)]
    pub similarity_scores: BTreeMap<String, f64>,

    /// Whether the evaluation workflow completed successfully.
    #[pyo3(get)]
    pub success: bool,

    /// Error message when `success` is false.
    #[pyo3(get)]
    pub error_message: Option<String>,

    /// Snapshot of the record's JSON context object (see `capture_context`);
    /// not serialized.
    #[serde(skip)]
    pub context_snapshot: Option<BTreeMap<String, serde_json::Value>>,
}
1036
1037#[pymethods]
1038impl AlignedEvalResult {
1039 pub fn __str__(&self) -> String {
1040 PyHelperFuncs::__str__(self)
1041 }
1042
1043 #[getter]
1044 pub fn task_count(&self) -> usize {
1045 self.eval_set.records.len()
1046 }
1047}
1048
1049impl AlignedEvalResult {
1050 pub fn from_success(
1052 record: &EvalRecord,
1053 eval_set: EvalSet,
1054 embeddings: BTreeMap<String, Vec<f32>>,
1055 ) -> Self {
1056 Self {
1057 record_uid: record.uid.clone(),
1058 record_id: record.record_id.clone(),
1059 eval_set,
1060 embeddings,
1061 mean_embeddings: BTreeMap::new(),
1062 similarity_scores: BTreeMap::new(),
1063 success: true,
1064 error_message: None,
1065 context_snapshot: None,
1066 }
1067 }
1068
1069 pub fn from_failure(record: &EvalRecord, error: String) -> Self {
1071 Self {
1072 record_uid: record.uid.clone(),
1073 eval_set: EvalSet::empty(),
1074 embeddings: BTreeMap::new(),
1075 mean_embeddings: BTreeMap::new(),
1076 similarity_scores: BTreeMap::new(),
1077 success: false,
1078 error_message: Some(error),
1079 context_snapshot: None,
1080 record_id: record.record_id.clone(),
1081 }
1082 }
1083
1084 pub fn capture_context(&mut self, record: &EvalRecord) {
1086 if let serde_json::Value::Object(context_map) = &record.context {
1087 self.context_snapshot = Some(
1088 context_map
1089 .iter()
1090 .map(|(k, v)| (k.clone(), v.clone()))
1091 .collect(),
1092 );
1093 }
1094 }
1095
1096 pub fn to_flat_task_records(&self) -> Vec<BTreeMap<String, serde_json::Value>> {
1098 let mut records = Vec::new();
1099
1100 for task_result in &self.eval_set.records {
1101 let mut flat = BTreeMap::new();
1102
1103 flat.insert(
1105 "created_at".to_string(),
1106 self.eval_set.inner.created_at.to_rfc3339().into(),
1107 );
1108 flat.insert("record_uid".to_string(), self.record_uid.clone().into());
1109 flat.insert("success".to_string(), self.success.into());
1110
1111 flat.insert(
1113 "workflow_error".to_string(),
1114 match &self.error_message {
1115 Some(err) => serde_json::Value::String(err.clone()),
1116 None => serde_json::Value::String("".to_string()),
1117 },
1118 );
1119
1120 flat.insert(
1122 "workflow_total_tasks".to_string(),
1123 self.eval_set.inner.total_tasks.into(),
1124 );
1125 flat.insert(
1126 "workflow_passed_tasks".to_string(),
1127 self.eval_set.inner.passed_tasks.into(),
1128 );
1129 flat.insert(
1130 "workflow_failed_tasks".to_string(),
1131 self.eval_set.inner.failed_tasks.into(),
1132 );
1133 flat.insert(
1134 "workflow_pass_rate".to_string(),
1135 self.eval_set.inner.pass_rate.into(),
1136 );
1137 flat.insert(
1138 "workflow_duration_ms".to_string(),
1139 self.eval_set.inner.duration_ms.into(),
1140 );
1141
1142 flat.insert("task_id".to_string(), task_result.task_id.clone().into());
1144 flat.insert(
1145 "task_type".to_string(),
1146 task_result.task_type.to_string().into(),
1147 );
1148 flat.insert("task_passed".to_string(), task_result.passed.into());
1149 flat.insert("task_value".to_string(), task_result.value.into());
1150 flat.insert(
1151 "task_message".to_string(),
1152 serde_json::Value::String(task_result.message.clone()),
1153 );
1154
1155 flat.insert(
1156 "task_assertion".to_string(),
1157 serde_json::to_value(task_result.assertion.clone())
1158 .unwrap_or(serde_json::Value::Null),
1159 );
1160
1161 flat.insert(
1162 "task_operator".to_string(),
1163 task_result.operator.to_string().into(),
1164 );
1165 flat.insert("task_expected".to_string(), task_result.expected.clone());
1166 flat.insert("task_actual".to_string(), task_result.actual.clone());
1167
1168 flat.insert(
1170 "context".to_string(),
1171 self.context_snapshot
1172 .as_ref()
1173 .map(|ctx| serde_json::to_value(ctx).unwrap_or(serde_json::Value::Null))
1174 .unwrap_or(serde_json::Value::Null),
1175 );
1176
1177 flat.insert(
1179 "embedding_means".to_string(),
1180 serde_json::to_value(&self.mean_embeddings).unwrap_or(serde_json::Value::Null),
1181 );
1182
1183 flat.insert(
1185 "similarity_scores".to_string(),
1186 serde_json::to_value(&self.similarity_scores).unwrap_or(serde_json::Value::Null),
1187 );
1188
1189 records.push(flat);
1190 }
1191
1192 records
1193 }
1194}
1195
/// Configuration controlling optional post-processing of evaluation results
/// (see `EvalResults::finalize`).
#[derive(Debug, Clone, Default)]
#[pyclass]
pub struct EvaluationConfig {
    /// Embedder used to produce embeddings; parsed from a Python object.
    pub embedder: Option<Arc<Embedder>>,

    /// Targets to embed; when empty, embedding post-processing is skipped.
    pub embedding_targets: Vec<String>,

    /// Whether similarity scores should be computed.
    pub compute_similarity: bool,

    /// Whether per-feature histograms should be computed in `finalize`.
    pub compute_histograms: bool,
}
1212
1213#[pymethods]
1214impl EvaluationConfig {
1215 #[new]
1216 #[pyo3(signature = (embedder=None, embedding_targets=None, compute_similarity=false, compute_histograms=false))]
1217 fn new(
1226 embedder: Option<&Bound<'_, PyAny>>,
1227 embedding_targets: Option<Vec<String>>,
1228 compute_similarity: bool,
1229 compute_histograms: bool,
1230 ) -> Result<Self, EvaluationError> {
1231 let embedder = parse_embedder(embedder)?;
1232 let embedding_targets = embedding_targets.unwrap_or_default();
1233
1234 Ok(Self {
1235 embedder,
1236 embedding_targets,
1237 compute_similarity,
1238 compute_histograms,
1239 })
1240 }
1241
1242 pub fn needs_post_processing(&self) -> bool {
1243 !self.embedding_targets.is_empty()
1244 }
1245}