//! scouter_evaluate/evaluate/types.rs
//!
//! Result, comparison, and dataset types for GenAI evaluation runs.
1use crate::error::EvaluationError;
2use crate::evaluate::compare::compare_results;
3use crate::utils::parse_embedder;
4use crate::utils::post_process_aligned_results;
5use ndarray::Array2;
6use owo_colors::OwoColorize;
7use potato_head::Embedder;
8use potato_head::PyHelperFuncs;
9use pyo3::prelude::*;
10use pyo3::types::IntoPyDict;
11use pyo3::types::PyDict;
12use scouter_profile::{Histogram, NumProfiler};
13use scouter_types::genai::EvalSet;
14use scouter_types::{EvalRecord, TaskResultTableEntry, WorkflowResultTableEntry};
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::{BTreeMap, HashMap};
18use std::sync::Arc;
19use tabled::Tabled;
20use tabled::{
21    settings::{object::Rows, Alignment, Color, Format, Style},
22    Table,
23};
24
/// A task that exists in only one of the two evaluations being compared.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct MissingTask {
    /// Identifier of the task that is missing from one side.
    #[pyo3(get)]
    pub task_id: String,

    /// Which side contains the task: "baseline_only" or "comparison_only"
    /// (these exact strings are matched in `print_missing_tasks`).
    #[pyo3(get)]
    pub present_in: String,

    /// Identifier of the evaluation record the task belongs to.
    #[pyo3(get)]
    pub record_id: String,
}
37
/// Pass/fail outcome of a single task in both the baseline and comparison runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct TaskComparison {
    /// Identifier of the compared task.
    #[pyo3(get)]
    pub task_id: String,

    /// Identifier of the record this task belongs to.
    #[pyo3(get)]
    pub record_id: String,

    /// Whether the task passed in the baseline run.
    #[pyo3(get)]
    pub baseline_passed: bool,

    /// Whether the task passed in the comparison run.
    #[pyo3(get)]
    pub comparison_passed: bool,

    /// True when the pass/fail status differs between the two runs.
    #[pyo3(get)]
    pub status_changed: bool,
}
56
/// Comparison of one baseline workflow run against one comparison workflow run.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct WorkflowComparison {
    /// Identifier of the baseline workflow.
    #[pyo3(get)]
    pub baseline_id: String,

    /// Identifier of the comparison workflow.
    #[pyo3(get)]
    pub comparison_id: String,

    /// Fraction of tasks passed in the baseline run (0.0..=1.0).
    #[pyo3(get)]
    pub baseline_pass_rate: f64,

    /// Fraction of tasks passed in the comparison run (0.0..=1.0).
    #[pyo3(get)]
    pub comparison_pass_rate: f64,

    /// comparison_pass_rate minus baseline_pass_rate.
    #[pyo3(get)]
    pub pass_rate_delta: f64,

    /// True when this workflow is considered a regression
    /// (determined by `compare_results` against the regression threshold).
    #[pyo3(get)]
    pub is_regression: bool,

    /// Per-task breakdown of the comparison.
    #[pyo3(get)]
    pub task_comparisons: Vec<TaskComparison>,
}
81
/// Pre-formatted (possibly ANSI-colored) display row for the workflow
/// comparison table. Field order defines the column order.
#[derive(Tabled)]
struct WorkflowComparisonEntry {
    #[tabled(rename = "Baseline ID")]
    baseline_id: String,
    #[tabled(rename = "Comparison ID")]
    comparison_id: String,
    #[tabled(rename = "Baseline Pass Rate")]
    baseline_pass_rate: String,
    #[tabled(rename = "Comparison Pass Rate")]
    comparison_pass_rate: String,
    #[tabled(rename = "Delta")]
    delta: String,
    #[tabled(rename = "Status")]
    status: String,
}
97
/// Cross-workflow aggregate statistics for a single task id.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct TaskAggregateStats {
    /// Identifier of the aggregated task.
    #[pyo3(get)]
    pub task_id: String,

    /// Number of workflow comparisons in which this task appeared.
    #[pyo3(get)]
    pub workflows_evaluated: usize,

    /// How many times the task passed in the baseline runs.
    #[pyo3(get)]
    pub baseline_pass_count: usize,

    /// How many times the task passed in the comparison runs.
    #[pyo3(get)]
    pub comparison_pass_count: usize,

    /// How many times the task's pass/fail status flipped.
    #[pyo3(get)]
    pub status_changed_count: usize,

    /// baseline_pass_count / workflows_evaluated (0.0 when none evaluated).
    #[pyo3(get)]
    pub baseline_pass_rate: f64,

    /// comparison_pass_count / workflows_evaluated (0.0 when none evaluated).
    #[pyo3(get)]
    pub comparison_pass_rate: f64,
}
122
/// Pre-formatted (possibly ANSI-colored) display row for the cross-workflow
/// task aggregate table. Field order defines the column order.
#[derive(Tabled)]
struct TaskAggregateEntry {
    #[tabled(rename = "Task ID")]
    task_id: String,
    #[tabled(rename = "Workflows")]
    workflows: String,
    #[tabled(rename = "Baseline Pass Rate")]
    baseline_rate: String,
    #[tabled(rename = "Comparison Pass Rate")]
    comparison_rate: String,
    #[tabled(rename = "Delta")]
    delta: String,
    #[tabled(rename = "Status Changes")]
    changes: String,
}
138
/// Full output of comparing one evaluation run against a baseline
/// (produced by `compare_results`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ComparisonResults {
    /// One entry per matched baseline/comparison workflow pair.
    #[pyo3(get)]
    pub workflow_comparisons: Vec<WorkflowComparison>,

    /// Total number of workflows compared.
    #[pyo3(get)]
    pub total_workflows: usize,

    /// Count of workflows whose pass rate improved.
    #[pyo3(get)]
    pub improved_workflows: usize,

    /// Count of workflows flagged as regressions.
    #[pyo3(get)]
    pub regressed_workflows: usize,

    /// Count of workflows with no significant change.
    #[pyo3(get)]
    pub unchanged_workflows: usize,

    /// Mean of the per-workflow pass-rate deltas.
    #[pyo3(get)]
    pub mean_pass_rate_delta: f64,

    /// Tasks whose pass/fail status flipped in a specific workflow.
    #[pyo3(get)]
    pub task_status_changes: Vec<TaskComparison>,

    /// Tasks that appear on only one side of the comparison.
    #[pyo3(get)]
    pub missing_tasks: Vec<MissingTask>,

    /// Number of workflows found in the baseline evaluation.
    #[pyo3(get)]
    pub baseline_workflow_count: usize,

    /// Number of workflows found in the comparison evaluation.
    #[pyo3(get)]
    pub comparison_workflow_count: usize,

    /// Overall verdict: true when the comparison is considered a regression.
    #[pyo3(get)]
    pub regressed: bool,
}
175
176#[pymethods]
177impl ComparisonResults {
178    #[getter]
179    pub fn task_aggregate_stats(&self) -> Vec<TaskAggregateStats> {
180        let mut task_stats: HashMap<String, (usize, usize, usize, usize)> = HashMap::new();
181
182        for wc in &self.workflow_comparisons {
183            for tc in &wc.task_comparisons {
184                let entry = task_stats.entry(tc.task_id.clone()).or_insert((0, 0, 0, 0));
185                entry.0 += 1; // workflows_evaluated
186                if tc.baseline_passed {
187                    entry.1 += 1; // baseline_pass_count
188                }
189                if tc.comparison_passed {
190                    entry.2 += 1; // comparison_pass_count
191                }
192                if tc.status_changed {
193                    entry.3 += 1; // status_changed_count
194                }
195            }
196        }
197
198        task_stats
199            .into_iter()
200            .map(
201                |(task_id, (total, baseline_pass, comparison_pass, changed))| TaskAggregateStats {
202                    task_id,
203                    workflows_evaluated: total,
204                    baseline_pass_count: baseline_pass,
205                    comparison_pass_count: comparison_pass,
206                    status_changed_count: changed,
207                    baseline_pass_rate: if total > 0 {
208                        baseline_pass as f64 / total as f64
209                    } else {
210                        0.0
211                    },
212                    comparison_pass_rate: if total > 0 {
213                        comparison_pass as f64 / total as f64
214                    } else {
215                        0.0
216                    },
217                },
218            )
219            .collect()
220    }
221
222    #[getter]
223    pub fn has_missing_tasks(&self) -> bool {
224        !self.missing_tasks.is_empty()
225    }
226
227    pub fn print_missing_tasks(&self) {
228        if !self.missing_tasks.is_empty() {
229            println!("\n{}", "⚠ Missing Tasks".yellow().bold());
230
231            let baseline_only: Vec<_> = self
232                .missing_tasks
233                .iter()
234                .filter(|t| t.present_in == "baseline_only")
235                .collect();
236
237            let comparison_only: Vec<_> = self
238                .missing_tasks
239                .iter()
240                .filter(|t| t.present_in == "comparison_only")
241                .collect();
242
243            if !baseline_only.is_empty() {
244                println!("  Baseline only ({} tasks):", baseline_only.len());
245                for task in baseline_only {
246                    println!("    - {} - {}", task.task_id, task.record_id);
247                }
248            }
249
250            if !comparison_only.is_empty() {
251                println!("  Comparison only ({} tasks):", comparison_only.len());
252                for task in comparison_only {
253                    println!("    - {} - {}", task.task_id, task.record_id);
254                }
255            }
256        }
257    }
258
259    pub fn __str__(&self) -> String {
260        PyHelperFuncs::__str__(self)
261    }
262
263    pub fn as_table(&self) {
264        self.print_summary_table();
265
266        if !self.task_status_changes.is_empty() {
267            println!(
268                "\n{}",
269                "Task Status Changes (Workflow-Specific)"
270                    .truecolor(245, 77, 85)
271                    .bold()
272            );
273            self.print_status_changes_table();
274        }
275
276        self.print_task_aggregate_table();
277
278        if self.has_missing_tasks() {
279            self.print_missing_tasks();
280        }
281
282        self.print_summary_stats();
283    }
284
285    fn print_task_aggregate_table(&self) {
286        let stats = self.task_aggregate_stats();
287
288        if stats.is_empty() {
289            return;
290        }
291
292        println!(
293            "\n{}",
294            "Task Aggregate Stats (Cross-Workflow)"
295                .truecolor(245, 77, 85)
296                .bold()
297        );
298
299        let entries: Vec<_> = stats
300            .iter()
301            .map(|ts| {
302                let baseline_rate = format!("{:.1}%", ts.baseline_pass_rate * 100.0);
303                let comparison_rate = format!("{:.1}%", ts.comparison_pass_rate * 100.0);
304                let delta_val = (ts.comparison_pass_rate - ts.baseline_pass_rate) * 100.0;
305                let delta_str = format!("{:+.1}%", delta_val);
306
307                let colored_delta = if delta_val > 1.0 {
308                    delta_str.green().to_string()
309                } else if delta_val < -1.0 {
310                    delta_str.red().to_string()
311                } else {
312                    delta_str.yellow().to_string()
313                };
314
315                let change_pct = if ts.workflows_evaluated > 0 {
316                    (ts.status_changed_count as f64 / ts.workflows_evaluated as f64) * 100.0
317                } else {
318                    0.0
319                };
320
321                TaskAggregateEntry {
322                    task_id: ts.task_id.clone(),
323                    workflows: ts.workflows_evaluated.to_string(),
324                    baseline_rate,
325                    comparison_rate,
326                    delta: colored_delta,
327                    changes: format!(
328                        "{}/{} ({:.0}%)",
329                        ts.status_changed_count, ts.workflows_evaluated, change_pct
330                    ),
331                }
332            })
333            .collect();
334
335        let mut table = Table::new(entries);
336        table.with(Style::sharp());
337
338        table.modify(
339            Rows::new(0..1),
340            (
341                Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
342                Alignment::center(),
343                Color::BOLD,
344            ),
345        );
346
347        println!("{}", table);
348    }
349
350    fn print_summary_table(&self) {
351        let entries: Vec<_> = self
352            .workflow_comparisons
353            .iter()
354            .map(|wc| {
355                let baseline_rate = format!("{:.1}%", wc.baseline_pass_rate * 100.0);
356                let comparison_rate = format!("{:.1}%", wc.comparison_pass_rate * 100.0);
357                let delta_val = wc.pass_rate_delta * 100.0;
358                let delta_str = format!("{:+.1}%", delta_val);
359
360                let colored_delta = if delta_val > 1.0 {
361                    delta_str.green().to_string()
362                } else if delta_val < -1.0 {
363                    delta_str.red().to_string()
364                } else {
365                    delta_str.yellow().to_string()
366                };
367
368                let status = if wc.is_regression {
369                    "Regressed".red().to_string()
370                } else if wc.pass_rate_delta > 0.01 {
371                    "Improved".green().to_string()
372                } else {
373                    "Unchanged".yellow().to_string()
374                };
375
376                WorkflowComparisonEntry {
377                    baseline_id: wc.baseline_id[..16.min(wc.baseline_id.len())]
378                        .to_string()
379                        .truecolor(249, 179, 93)
380                        .to_string(),
381                    comparison_id: wc.comparison_id[..16.min(wc.comparison_id.len())]
382                        .to_string()
383                        .truecolor(249, 179, 93)
384                        .to_string(),
385                    baseline_pass_rate: baseline_rate,
386                    comparison_pass_rate: comparison_rate,
387                    delta: colored_delta,
388                    status,
389                }
390            })
391            .collect();
392
393        let mut table = Table::new(entries);
394        table.with(Style::sharp());
395
396        table.modify(
397            Rows::new(0..1),
398            (
399                Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
400                Alignment::center(),
401                Color::BOLD,
402            ),
403        );
404
405        println!("{}", table);
406    }
407
408    fn print_status_changes_table(&self) {
409        let entries: Vec<_> = self
410            .task_status_changes
411            .iter()
412            .map(|tc| {
413                let baseline_status = if tc.baseline_passed {
414                    "✓ Pass".green().to_string()
415                } else {
416                    "✗ Fail".red().to_string()
417                };
418
419                let comparison_status = if tc.comparison_passed {
420                    "✓ Pass".green().to_string()
421                } else {
422                    "✗ Fail".red().to_string()
423                };
424
425                let change = match (tc.baseline_passed, tc.comparison_passed) {
426                    (true, false) => "Pass → Fail".red().bold().to_string(),
427                    (false, true) => "Fail → Pass".green().bold().to_string(),
428                    _ => "No Change".yellow().to_string(),
429                };
430
431                TaskStatusChangeEntry {
432                    task_id: tc.task_id.clone(),
433                    baseline_status,
434                    comparison_status,
435                    change,
436                }
437            })
438            .collect();
439
440        let mut table = Table::new(entries);
441        table.with(Style::sharp());
442
443        table.modify(
444            Rows::new(0..1),
445            (
446                Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
447                Alignment::center(),
448                Color::BOLD,
449            ),
450        );
451
452        println!("{}", table);
453    }
454
455    fn print_summary_stats(&self) {
456        println!("\n{}", "Summary".truecolor(245, 77, 85).bold());
457
458        let regression_indicator = if self.regressed {
459            "⚠️  REGRESSION DETECTED".red().bold().to_string()
460        } else if self.improved_workflows > 0 {
461            "✅ IMPROVEMENT DETECTED".green().bold().to_string()
462        } else {
463            "➡️  NO SIGNIFICANT CHANGE".yellow().bold().to_string()
464        };
465
466        println!("  Overall Status: {}", regression_indicator);
467        println!("  Total Workflows: {}", self.total_workflows);
468        println!(
469            "  Improved: {}",
470            self.improved_workflows.to_string().green()
471        );
472        println!(
473            "  Regressed: {}",
474            self.regressed_workflows.to_string().red()
475        );
476        println!(
477            "  Unchanged: {}",
478            self.unchanged_workflows.to_string().yellow()
479        );
480
481        let mean_delta_str = format!("{:+.2}%", self.mean_pass_rate_delta * 100.0);
482        let colored_mean = if self.mean_pass_rate_delta > 0.0 {
483            mean_delta_str.green().to_string()
484        } else if self.mean_pass_rate_delta < 0.0 {
485            mean_delta_str.red().to_string()
486        } else {
487            mean_delta_str.yellow().to_string()
488        };
489        println!("  Mean Pass Rate Delta: {}", colored_mean);
490    }
491}
492
493impl ComparisonResults {}
494
/// Pre-formatted (possibly ANSI-colored) display row for the task status
/// change table. Field order defines the column order.
#[derive(Tabled)]
struct TaskStatusChangeEntry {
    #[tabled(rename = "Task ID")]
    task_id: String,
    #[tabled(rename = "Baseline")]
    baseline_status: String,
    #[tabled(rename = "Comparison")]
    comparison_status: String,
    #[tabled(rename = "Change")]
    change: String,
}
506
/// Container for all results of one evaluation run, exposed to Python.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct EvalResults {
    /// Aligned results in original record order
    pub aligned_results: Vec<AlignedEvalResult>,

    /// Uids of records whose evaluation failed (pushed by `add_failure`).
    #[pyo3(get)]
    pub errored_tasks: Vec<String>,

    /// Optional 2D projection + cluster labels for visualization.
    pub cluster_data: Option<ClusterData>,

    /// Histograms per numeric feature, computed in `finalize` when enabled.
    #[pyo3(get)]
    pub histograms: Option<HashMap<String, Histogram>>,

    /// Cached numeric dataset built from results; runtime-only, not serialized.
    #[serde(skip)]
    pub array_dataset: Option<ArrayDataset>,

    /// record_id -> index into `aligned_results`; runtime-only, not serialized.
    #[serde(skip)]
    pub results_by_id: HashMap<String, usize>,
}
527
#[pymethods]
impl EvalResults {
    /// Python `results[record_id]` lookup: returns a clone of the aligned
    /// result, or `MissingKeyError` for an unknown key.
    pub fn __getitem__(&self, key: &str) -> Result<AlignedEvalResult, EvaluationError> {
        self.results_by_id
            .get(key)
            .and_then(|&idx| self.aligned_results.get(idx))
            .cloned()
            .ok_or_else(|| EvaluationError::MissingKeyError(key.to_string()))
    }

    /// Number of aligned results whose evaluation succeeded.
    #[getter]
    pub fn successful_count(&self) -> usize {
        self.aligned_results.iter().filter(|r| r.success).count()
    }

    /// Number of aligned results whose evaluation failed.
    #[getter]
    pub fn failed_count(&self) -> usize {
        self.aligned_results.iter().filter(|r| !r.success).count()
    }

    /// Export to dataframe format
    ///
    /// Flattens every aligned result into per-task records, pads missing
    /// columns with nulls, and builds a pandas (default) or polars dataframe.
    ///
    /// # Errors
    /// Returns `NoResultsFound` when there are no task records to export.
    #[pyo3(signature = (polars=false))]
    pub fn to_dataframe<'py>(
        &mut self,
        py: Python<'py>,
        polars: bool,
    ) -> Result<Bound<'py, PyAny>, EvaluationError> {
        let all_task_records: Vec<_> = self
            .aligned_results
            .iter()
            .flat_map(|r| r.to_flat_task_records())
            .collect();

        if all_task_records.is_empty() {
            return Err(EvaluationError::NoResultsFound);
        }

        let py_records = PyDict::new(py);

        // Collect all unique column names (BTreeSet keeps them sorted and
        // therefore in a deterministic order)
        let mut all_columns = std::collections::BTreeSet::new();
        for record in &all_task_records {
            all_columns.extend(record.keys().cloned());
        }

        // Build columns in consistent order
        for column_name in all_columns {
            // Records missing this column are padded with nulls so every
            // column has the same length.
            let column_data: Vec<_> = all_task_records
                .iter()
                .map(|record| record.get(&column_name).cloned().unwrap_or(Value::Null))
                .collect();

            let py_col = pythonize::pythonize(py, &column_data)?;
            py_records.set_item(&column_name, py_col)?;
        }

        let module = if polars { "polars" } else { "pandas" };
        let df_module = py.import(module)?;
        let df_class = df_module.getattr("DataFrame")?;

        if polars {
            // Pass an explicit non-strict schema so polars does not have to
            // infer dtypes from null-padded columns.
            let schema = self.get_schema_mapping(py)?;
            let schema_dict = &[("schema", schema)].into_py_dict(py)?;
            schema_dict.set_item("strict", false)?;
            Ok(df_class.call((py_records,), Some(schema_dict))?)
        } else {
            Ok(df_class.call_method1("from_dict", (py_records,))?)
        }
    }

    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }

    #[pyo3(signature = (show_tasks=false))]
    /// Display results as a table in the console
    /// # Arguments
    /// * `show_tasks` - If true, display detailed task results; otherwise, show workflow summary
    pub fn as_table(&mut self, show_tasks: bool) {
        if show_tasks {
            let tasks_table = self.build_tasks_table();
            println!("\n{}", "Task Details".truecolor(245, 77, 85).bold());
            println!("{}", tasks_table);
        } else {
            let workflow_table = self.build_workflow_table();
            println!("\n{}", "Workflow Summary".truecolor(245, 77, 85).bold());
            println!("{}", workflow_table);
        }
    }

    /// Serialize the results to a JSON string.
    pub fn model_dump_json(&self) -> String {
        PyHelperFuncs::__json__(self)
    }

    /// Deserialize results from a JSON string produced by `model_dump_json`.
    #[staticmethod]
    pub fn model_validate_json(json_string: String) -> Result<Self, EvaluationError> {
        Ok(serde_json::from_str(&json_string)?)
    }

    /// Compare this evaluation against another baseline evaluation
    /// Matches workflows by their task structure and compares task-by-task
    #[pyo3(signature = (baseline, regression_threshold=0.05))]
    pub fn compare_to(
        &self,
        baseline: &EvalResults,
        regression_threshold: f64,
    ) -> Result<ComparisonResults, EvaluationError> {
        compare_results(baseline, self, regression_threshold)
    }
}
638
639impl EvalResults {
640    fn get_schema_mapping<'py>(
641        &self,
642        py: Python<'py>,
643    ) -> Result<Bound<'py, PyDict>, EvaluationError> {
644        let schema = PyDict::new(py);
645        let pl = py.import("polars")?;
646
647        schema.set_item("created_at", pl.getattr("Utf8")?)?;
648        schema.set_item("record_uid", pl.getattr("Utf8")?)?;
649        schema.set_item("success", pl.getattr("Boolean")?)?;
650        schema.set_item("workflow_error", pl.getattr("Utf8")?)?;
651
652        schema.set_item("workflow_total_tasks", pl.getattr("Int64")?)?;
653        schema.set_item("workflow_passed_tasks", pl.getattr("Int64")?)?;
654        schema.set_item("workflow_failed_tasks", pl.getattr("Int64")?)?;
655        schema.set_item("workflow_pass_rate", pl.getattr("Float64")?)?;
656        schema.set_item("workflow_duration_ms", pl.getattr("Int64")?)?;
657
658        schema.set_item("task_id", pl.getattr("Utf8")?)?;
659        schema.set_item("task_type", pl.getattr("Utf8")?)?;
660        schema.set_item("task_passed", pl.getattr("Boolean")?)?;
661        schema.set_item("task_value", pl.getattr("Float64")?)?;
662        schema.set_item("task_message", pl.getattr("Utf8")?)?;
663        schema.set_item("task_assertion", pl.getattr("Utf8")?)?;
664        schema.set_item("task_operator", pl.getattr("Utf8")?)?;
665        schema.set_item("task_expected", pl.getattr("Utf8")?)?;
666        schema.set_item("task_actual", pl.getattr("Utf8")?)?;
667
668        schema.set_item("context", pl.getattr("Utf8")?)?;
669        schema.set_item("embedding_means", pl.getattr("Utf8")?)?;
670        schema.set_item("similarity_scores", pl.getattr("Utf8")?)?;
671
672        Ok(schema)
673    }
674    /// Build workflow result table for console display
675    fn build_workflow_table(&self) -> Table {
676        let entries: Vec<WorkflowResultTableEntry> = self
677            .aligned_results
678            .iter()
679            .flat_map(|result| result.eval_set.build_workflow_entries())
680            .collect();
681
682        let mut table = Table::new(entries);
683        table.with(Style::sharp());
684
685        table.modify(
686            Rows::new(0..1),
687            (
688                Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
689                Alignment::center(),
690                Color::BOLD,
691            ),
692        );
693        table
694    }
695
696    /// Build detailed task results table for console display
697    fn build_tasks_table(&mut self) -> Table {
698        self.aligned_results.sort_by(|a, b| {
699            let a_id = if a.record_id.is_empty() {
700                &a.record_uid
701            } else {
702                &a.record_id
703            };
704            let b_id = if b.record_id.is_empty() {
705                &b.record_uid
706            } else {
707                &b.record_id
708            };
709            a_id.cmp(b_id)
710        });
711
712        let entries: Vec<TaskResultTableEntry> = self
713            .aligned_results
714            .iter_mut()
715            .flat_map(|result| {
716                let resolved_id = if result.record_id.is_empty() {
717                    &result.record_uid
718                } else {
719                    &result.record_id
720                };
721
722                result.eval_set.build_task_entries(resolved_id)
723            })
724            .collect();
725
726        let mut table = Table::new(entries);
727        table.with(Style::sharp());
728
729        table.modify(
730            Rows::new(0..1),
731            (
732                Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
733                Alignment::center(),
734                Color::BOLD,
735            ),
736        );
737
738        table
739    }
740
741    pub fn new() -> Self {
742        Self {
743            aligned_results: Vec::new(),
744            errored_tasks: Vec::new(),
745            array_dataset: None,
746            cluster_data: None,
747            histograms: None,
748            results_by_id: HashMap::new(),
749        }
750    }
751
752    /// Add a successful result
753    pub fn add_success(
754        &mut self,
755        record: &EvalRecord,
756        eval_set: EvalSet,
757        embeddings: BTreeMap<String, Vec<f32>>,
758    ) {
759        self.aligned_results.push(AlignedEvalResult::from_success(
760            record, eval_set, embeddings,
761        ));
762
763        if !record.record_id.is_empty() {
764            self.results_by_id
765                .insert(record.record_id.clone(), self.aligned_results.len() - 1);
766        }
767    }
768
769    /// Add a failed result - only reference the record
770    pub fn add_failure(&mut self, record: &EvalRecord, error: String) {
771        let uid = record.uid.clone();
772
773        self.aligned_results
774            .push(AlignedEvalResult::from_failure(record, error));
775
776        self.errored_tasks.push(uid);
777    }
778
779    /// Finalize the results by performing post-processing steps which includes:
780    /// - Post-processing embeddings (if any)
781    /// - Building the array dataset (if not already built)
782    /// - Performing clustering and dimensionality reduction (if enabled) for visualization
783    /// # Arguments
784    /// * `config` - The evaluation configuration that dictates post-processing behavior
785    /// # Returns
786    /// * `Result<(), EvaluationError>` - Returns Ok(()) if successful, otherwise returns
787    pub fn finalize(&mut self, config: &Arc<EvaluationConfig>) -> Result<(), EvaluationError> {
788        // Post-process embeddings if needed
789        if !config.embedding_targets.is_empty() {
790            post_process_aligned_results(self, config)?;
791        }
792
793        if config.compute_histograms {
794            self.build_array_dataset()?;
795
796            // Compute histograms for all numeric fields
797            if let Some(array_dataset) = &self.array_dataset {
798                let profiler = NumProfiler::new();
799                let histograms = profiler.compute_histogram(
800                    &array_dataset.data.view(),
801                    &array_dataset.feature_names,
802                    &10,
803                    false,
804                )?;
805                self.histograms = Some(histograms);
806            }
807        }
808
809        Ok(())
810    }
811
812    /// Build an NDArray dataset from the result tasks
813    fn build_array_dataset(&mut self) -> Result<(), EvaluationError> {
814        if self.array_dataset.is_none() {
815            self.array_dataset = Some(ArrayDataset::from_results(self)?);
816        }
817        Ok(())
818    }
819}
820
impl Default for EvalResults {
    /// Equivalent to `EvalResults::new()`: an empty result set.
    fn default() -> Self {
        Self::new()
    }
}
826
827pub fn array_to_dict<'py>(
828    py: Python<'py>,
829    array: &ArrayDataset,
830) -> Result<Bound<'py, PyDict>, EvaluationError> {
831    let pydict = PyDict::new(py);
832
833    // set task ids
834    pydict.set_item(
835        "task",
836        array.idx_map.values().cloned().collect::<Vec<String>>(),
837    )?;
838
839    // set feature columns
840    for (i, feature) in array.feature_names.iter().enumerate() {
841        let column_data: Vec<f64> = array.data.column(i).to_vec();
842        pydict.set_item(feature, column_data)?;
843    }
844
845    // add cluster column if available
846    if array.clusters.len() == array.data.nrows() {
847        pydict.set_item("cluster", array.clusters.clone())?;
848    }
849    Ok(pydict)
850}
851
/// 2D projection of evaluation records for cluster visualization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ClusterData {
    /// X coordinates, one per record.
    #[pyo3(get)]
    pub x: Vec<f64>,
    /// Y coordinates, one per record.
    #[pyo3(get)]
    pub y: Vec<f64>,
    /// Cluster label per record.
    #[pyo3(get)]
    pub clusters: Vec<i32>,
    /// Row index -> record identifier mapping (not exposed to Python).
    pub idx_map: HashMap<usize, String>,
}
863
864impl ClusterData {
865    pub fn new(
866        x: Vec<f64>,
867        y: Vec<f64>,
868        clusters: Vec<i32>,
869        idx_map: HashMap<usize, String>,
870    ) -> Self {
871        ClusterData {
872            x,
873            y,
874            clusters,
875            idx_map,
876        }
877    }
878}
879
/// Numeric matrix view of evaluation results: one row per successful record,
/// one column per feature (task scores, embedding means, similarity scores).
#[derive(Debug, Clone)]
pub struct ArrayDataset {
    /// Row-major data matrix (rows = records, columns = features).
    pub data: Array2<f64>,
    /// Column names, in the same order as the matrix columns.
    pub feature_names: Vec<String>,
    /// Row index -> record uid mapping.
    pub idx_map: HashMap<usize, String>,
    /// Cluster label per row; empty until clustering has run.
    pub clusters: Vec<i32>,
}
887
impl Default for ArrayDataset {
    /// Equivalent to `ArrayDataset::new()`: an empty 0x0 dataset.
    fn default() -> Self {
        Self::new()
    }
}
893
impl ArrayDataset {
    /// Create an empty dataset: 0x0 matrix, no features, no clusters.
    pub fn new() -> Self {
        Self {
            data: Array2::zeros((0, 0)),
            feature_names: Vec::new(),
            idx_map: HashMap::new(),
            clusters: vec![],
        }
    }

    /// Build feature names from aligned results
    /// This extracts all unique task names, embedding targets, and similarity pairs
    /// from the evaluation results to create column names for the array dataset
    ///
    /// NOTE(review): the schema is taken from the *first* successful result,
    /// which assumes every successful result shares the same task ids and
    /// embedding/similarity keys — confirm with the evaluation pipeline.
    fn build_feature_names(results: &EvalResults) -> Result<Vec<String>, EvaluationError> {
        // Get first successful result to determine schema
        let first_result = results
            .aligned_results
            .iter()
            .find(|r| r.success)
            .ok_or(EvaluationError::NoResultsFound)?;

        let mut names = Vec::new();

        // One column per task in the eval set ...
        for task_record in &first_result.eval_set.records {
            names.push(task_record.task_id.clone());
        }

        // ... plus one per embedding-mean key and one per similarity-score key.
        names.extend(first_result.mean_embeddings.keys().cloned());
        names.extend(first_result.similarity_scores.keys().cloned());

        Ok(names)
    }

    /// Build array dataset from aligned evaluation results
    /// Creates a 2D array where:
    /// - Rows = evaluation records
    /// - Columns = task scores, embedding means, similarity scores
    ///
    /// NOTE(review): an entirely empty result set returns an empty dataset,
    /// while a non-empty set with zero successes returns `NoResultsFound` —
    /// confirm the asymmetry is intentional.
    pub fn from_results(results: &EvalResults) -> Result<Self, EvaluationError> {
        if results.aligned_results.is_empty() {
            return Ok(Self::new());
        }

        // Only include successful evaluations in the dataset
        let successful_results: Vec<&AlignedEvalResult> = results
            .aligned_results
            .iter()
            .filter(|r| r.success)
            .collect();

        if successful_results.is_empty() {
            return Err(EvaluationError::NoResultsFound);
        }

        let feature_names = Self::build_feature_names(results)?;
        let n_rows = successful_results.len();
        let n_cols = feature_names.len();

        // Flat row-major buffer; reshaped into (n_rows, n_cols) below.
        let mut data = Vec::with_capacity(n_rows * n_cols);
        let mut idx_map = HashMap::new();

        // Build task score lookup for efficient access
        // This maps task_id -> score value for quick column population
        for (row_idx, aligned) in successful_results.iter().enumerate() {
            idx_map.insert(row_idx, aligned.record_uid.clone());

            // Build lookup map from task_id to value
            let task_scores: HashMap<String, f64> = aligned
                .eval_set
                .records
                .iter()
                .map(|task| (task.task_id.clone(), task.value))
                .collect();

            // Collect all values in correct column order
            let row: Vec<f64> = feature_names
                .iter()
                .map(|feature_name| {
                    // Try task scores first
                    if let Some(&score) = task_scores.get(feature_name) {
                        return score;
                    }

                    // Try embedding means
                    if let Some(&mean) = aligned.mean_embeddings.get(feature_name) {
                        return mean;
                    }

                    // Try similarity scores
                    if let Some(&sim) = aligned.similarity_scores.get(feature_name) {
                        return sim;
                    }

                    // Default for missing values
                    0.0
                })
                .collect();

            data.extend(row);
        }

        let array = Array2::from_shape_vec((n_rows, n_cols), data)?;

        Ok(Self {
            data: array,
            feature_names,
            idx_map,
            clusters: vec![],
        })
    }
}
1004
/// One evaluation record aligned with its computed evaluation artifacts
/// (task results, optional embeddings, derived scores).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct AlignedEvalResult {
    /// User-facing identifier of the evaluated record.
    #[pyo3(get)]
    pub record_id: String,

    /// Unique identifier of the evaluated record.
    #[pyo3(get)]
    pub record_uid: String,

    /// Per-task evaluation results for this record.
    #[pyo3(get)]
    pub eval_set: EvalSet,

    /// Raw embedding vectors keyed by target name.
    /// Skipped during serde serialization (large, reproducible).
    #[pyo3(get)]
    #[serde(skip)]
    pub embeddings: BTreeMap<String, Vec<f32>>,

    /// Mean of each embedding vector, keyed by target name.
    #[pyo3(get)]
    pub mean_embeddings: BTreeMap<String, f64>,

    /// Pairwise similarity scores between embedding targets.
    #[pyo3(get)]
    pub similarity_scores: BTreeMap<String, f64>,

    /// Whether the evaluation for this record completed successfully.
    #[pyo3(get)]
    pub success: bool,

    /// Error description when `success` is false; `None` otherwise.
    #[pyo3(get)]
    pub error_message: Option<String>,

    /// Snapshot of the record's context (used for dataframe export);
    /// not serialized.
    #[serde(skip)]
    pub context_snapshot: Option<BTreeMap<String, serde_json::Value>>,
}
1036
#[pymethods]
impl AlignedEvalResult {
    /// Python `str()` representation, delegated to the shared helper.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }

    /// Number of task results contained in this record's eval set.
    #[getter]
    pub fn task_count(&self) -> usize {
        self.eval_set.records.len()
    }
}
1048
1049impl AlignedEvalResult {
1050    /// Create from successful evaluation
1051    pub fn from_success(
1052        record: &EvalRecord,
1053        eval_set: EvalSet,
1054        embeddings: BTreeMap<String, Vec<f32>>,
1055    ) -> Self {
1056        Self {
1057            record_uid: record.uid.clone(),
1058            record_id: record.record_id.clone(),
1059            eval_set,
1060            embeddings,
1061            mean_embeddings: BTreeMap::new(),
1062            similarity_scores: BTreeMap::new(),
1063            success: true,
1064            error_message: None,
1065            context_snapshot: None,
1066        }
1067    }
1068
1069    /// Create from failed evaluation
1070    pub fn from_failure(record: &EvalRecord, error: String) -> Self {
1071        Self {
1072            record_uid: record.uid.clone(),
1073            eval_set: EvalSet::empty(),
1074            embeddings: BTreeMap::new(),
1075            mean_embeddings: BTreeMap::new(),
1076            similarity_scores: BTreeMap::new(),
1077            success: false,
1078            error_message: Some(error),
1079            context_snapshot: None,
1080            record_id: record.record_id.clone(),
1081        }
1082    }
1083
1084    /// Capture context snapshot for dataframe export
1085    pub fn capture_context(&mut self, record: &EvalRecord) {
1086        if let serde_json::Value::Object(context_map) = &record.context {
1087            self.context_snapshot = Some(
1088                context_map
1089                    .iter()
1090                    .map(|(k, v)| (k.clone(), v.clone()))
1091                    .collect(),
1092            );
1093        }
1094    }
1095
1096    /// Get flattened data for dataframe export
1097    pub fn to_flat_task_records(&self) -> Vec<BTreeMap<String, serde_json::Value>> {
1098        let mut records = Vec::new();
1099
1100        for task_result in &self.eval_set.records {
1101            let mut flat = BTreeMap::new();
1102
1103            // Workflow metadata (repeated for each task)
1104            flat.insert(
1105                "created_at".to_string(),
1106                self.eval_set.inner.created_at.to_rfc3339().into(),
1107            );
1108            flat.insert("record_uid".to_string(), self.record_uid.clone().into());
1109            flat.insert("success".to_string(), self.success.into());
1110
1111            // insert workflow error or "" if none
1112            flat.insert(
1113                "workflow_error".to_string(),
1114                match &self.error_message {
1115                    Some(err) => serde_json::Value::String(err.clone()),
1116                    None => serde_json::Value::String("".to_string()),
1117                },
1118            );
1119
1120            // Workflow-level metrics
1121            flat.insert(
1122                "workflow_total_tasks".to_string(),
1123                self.eval_set.inner.total_tasks.into(),
1124            );
1125            flat.insert(
1126                "workflow_passed_tasks".to_string(),
1127                self.eval_set.inner.passed_tasks.into(),
1128            );
1129            flat.insert(
1130                "workflow_failed_tasks".to_string(),
1131                self.eval_set.inner.failed_tasks.into(),
1132            );
1133            flat.insert(
1134                "workflow_pass_rate".to_string(),
1135                self.eval_set.inner.pass_rate.into(),
1136            );
1137            flat.insert(
1138                "workflow_duration_ms".to_string(),
1139                self.eval_set.inner.duration_ms.into(),
1140            );
1141
1142            // Task-specific data
1143            flat.insert("task_id".to_string(), task_result.task_id.clone().into());
1144            flat.insert(
1145                "task_type".to_string(),
1146                task_result.task_type.to_string().into(),
1147            );
1148            flat.insert("task_passed".to_string(), task_result.passed.into());
1149            flat.insert("task_value".to_string(), task_result.value.into());
1150            flat.insert(
1151                "task_message".to_string(),
1152                serde_json::Value::String(task_result.message.clone()),
1153            );
1154
1155            flat.insert(
1156                "task_assertion".to_string(),
1157                serde_json::to_value(task_result.assertion.clone())
1158                    .unwrap_or(serde_json::Value::Null),
1159            );
1160
1161            flat.insert(
1162                "task_operator".to_string(),
1163                task_result.operator.to_string().into(),
1164            );
1165            flat.insert("task_expected".to_string(), task_result.expected.clone());
1166            flat.insert("task_actual".to_string(), task_result.actual.clone());
1167
1168            // Context as single JSON column
1169            flat.insert(
1170                "context".to_string(),
1171                self.context_snapshot
1172                    .as_ref()
1173                    .map(|ctx| serde_json::to_value(ctx).unwrap_or(serde_json::Value::Null))
1174                    .unwrap_or(serde_json::Value::Null),
1175            );
1176
1177            // Embedding means as single JSON column
1178            flat.insert(
1179                "embedding_means".to_string(),
1180                serde_json::to_value(&self.mean_embeddings).unwrap_or(serde_json::Value::Null),
1181            );
1182
1183            // Similarity scores as single JSON column
1184            flat.insert(
1185                "similarity_scores".to_string(),
1186                serde_json::to_value(&self.similarity_scores).unwrap_or(serde_json::Value::Null),
1187            );
1188
1189            records.push(flat);
1190        }
1191
1192        records
1193    }
1194}
1195
/// Configuration controlling optional embedding and statistics steps of an
/// evaluation run.
#[derive(Debug, Clone, Default)]
#[pyclass]
pub struct EvaluationConfig {
    /// Optional embedder for embedding-based evaluations.
    pub embedder: Option<Arc<Embedder>>,

    /// Fields in the record to generate embeddings for.
    pub embedding_targets: Vec<String>,

    /// Compute similarities for all combinations of embeddings in the
    /// targets — e.g. targets ["a", "b"] produce the a-b similarity.
    pub compute_similarity: bool,

    /// Compute histograms for all scores, embeddings and similarities
    /// (when those values are available).
    pub compute_histograms: bool,
}
1212
1213#[pymethods]
1214impl EvaluationConfig {
1215    #[new]
1216    #[pyo3(signature = (embedder=None, embedding_targets=None, compute_similarity=false, compute_histograms=false))]
1217    /// Creates a new EvaluationConfig instance.
1218    /// # Arguments
1219    /// * `embedder` - Optional reference to a PyEmbedder instance.
1220    /// * `embedding_targets` - Optional list of fields in the record to generate embeddings for.
1221    /// * `compute_similarity` - Whether to compute similarities between embeddings.
1222    /// * `compute_histograms` - Whether to compute histograms for all scores, embeddings and similarities (if available).
1223    /// # Returns
1224    /// A new EvaluationConfig instance.
1225    fn new(
1226        embedder: Option<&Bound<'_, PyAny>>,
1227        embedding_targets: Option<Vec<String>>,
1228        compute_similarity: bool,
1229        compute_histograms: bool,
1230    ) -> Result<Self, EvaluationError> {
1231        let embedder = parse_embedder(embedder)?;
1232        let embedding_targets = embedding_targets.unwrap_or_default();
1233
1234        Ok(Self {
1235            embedder,
1236            embedding_targets,
1237            compute_similarity,
1238            compute_histograms,
1239        })
1240    }
1241
1242    pub fn needs_post_processing(&self) -> bool {
1243        !self.embedding_targets.is_empty()
1244    }
1245}