// scouter_evaluate/evaluate/types.rs

1use crate::error::EvaluationError;
2use crate::evaluate::compare::compare_results;
3use crate::utils::parse_embedder;
4use crate::utils::post_process_aligned_results;
5use ndarray::Array2;
6use owo_colors::OwoColorize;
7use potato_head::Embedder;
8use potato_head::PyHelperFuncs;
9use pyo3::prelude::*;
10use pyo3::types::IntoPyDict;
11use pyo3::types::PyDict;
12use scouter_profile::{Histogram, NumProfiler};
13use scouter_types::genai::GenAIEvalSet;
14use scouter_types::{GenAIEvalRecord, TaskResultTableEntry, WorkflowResultTableEntry};
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::{BTreeMap, HashMap};
18use std::sync::Arc;
19use tabled::Tabled;
20use tabled::{
21    settings::{object::Rows, Alignment, Color, Format, Style},
22    Table,
23};
24
/// A task that appears in only one side of a baseline-vs-comparison run.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct MissingTask {
    /// Identifier of the task missing from one side.
    #[pyo3(get)]
    pub task_id: String,

    /// Which side has the task — filtered on as "baseline_only" or
    /// "comparison_only" in `print_missing_tasks`.
    #[pyo3(get)]
    pub present_in: String,

    /// Record the task belongs to.
    #[pyo3(get)]
    pub record_id: String,
}
37
/// Pass/fail status of a single task in both the baseline and comparison runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct TaskComparison {
    #[pyo3(get)]
    pub task_id: String,

    #[pyo3(get)]
    pub record_id: String,

    /// Whether the task passed in the baseline run.
    #[pyo3(get)]
    pub baseline_passed: bool,

    /// Whether the task passed in the comparison run.
    #[pyo3(get)]
    pub comparison_passed: bool,

    /// True when baseline and comparison outcomes differ.
    #[pyo3(get)]
    pub status_changed: bool,
}
56
/// Workflow-level comparison between a baseline and a comparison evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct WorkflowComparison {
    #[pyo3(get)]
    pub baseline_id: String,

    #[pyo3(get)]
    pub comparison_id: String,

    /// Fraction of tasks passing in the baseline run (0.0–1.0).
    #[pyo3(get)]
    pub baseline_pass_rate: f64,

    /// Fraction of tasks passing in the comparison run (0.0–1.0).
    #[pyo3(get)]
    pub comparison_pass_rate: f64,

    /// comparison_pass_rate - baseline_pass_rate.
    #[pyo3(get)]
    pub pass_rate_delta: f64,

    /// True when the delta dropped below the regression threshold.
    #[pyo3(get)]
    pub is_regression: bool,

    /// Per-task comparison results for this workflow pair.
    #[pyo3(get)]
    pub task_comparisons: Vec<TaskComparison>,
}
81
/// Row model for the per-workflow comparison console table.
/// All fields are pre-formatted (and possibly ANSI-colored) strings.
#[derive(Tabled)]
struct WorkflowComparisonEntry {
    #[tabled(rename = "Baseline ID")]
    baseline_id: String,
    #[tabled(rename = "Comparison ID")]
    comparison_id: String,
    #[tabled(rename = "Baseline Pass Rate")]
    baseline_pass_rate: String,
    #[tabled(rename = "Comparison Pass Rate")]
    comparison_pass_rate: String,
    #[tabled(rename = "Delta")]
    delta: String,
    #[tabled(rename = "Status")]
    status: String,
}
97
/// Cross-workflow aggregate statistics for a single task id.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct TaskAggregateStats {
    #[pyo3(get)]
    pub task_id: String,

    /// Number of workflow comparisons in which this task appeared.
    #[pyo3(get)]
    pub workflows_evaluated: usize,

    /// How many of those passed in the baseline run.
    #[pyo3(get)]
    pub baseline_pass_count: usize,

    /// How many of those passed in the comparison run.
    #[pyo3(get)]
    pub comparison_pass_count: usize,

    /// How many flipped status between the two runs.
    #[pyo3(get)]
    pub status_changed_count: usize,

    /// baseline_pass_count / workflows_evaluated.
    #[pyo3(get)]
    pub baseline_pass_rate: f64,

    /// comparison_pass_count / workflows_evaluated.
    #[pyo3(get)]
    pub comparison_pass_rate: f64,
}
122
/// Row model for the per-task aggregate console table.
/// All fields are pre-formatted (and possibly ANSI-colored) strings.
#[derive(Tabled)]
struct TaskAggregateEntry {
    #[tabled(rename = "Task ID")]
    task_id: String,
    #[tabled(rename = "Workflows")]
    workflows: String,
    #[tabled(rename = "Baseline Pass Rate")]
    baseline_rate: String,
    #[tabled(rename = "Comparison Pass Rate")]
    comparison_rate: String,
    #[tabled(rename = "Delta")]
    delta: String,
    #[tabled(rename = "Status Changes")]
    changes: String,
}
138
/// Full output of comparing two evaluations (see `compare_results`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ComparisonResults {
    /// One entry per matched workflow pair.
    #[pyo3(get)]
    pub workflow_comparisons: Vec<WorkflowComparison>,

    #[pyo3(get)]
    pub total_workflows: usize,

    #[pyo3(get)]
    pub improved_workflows: usize,

    #[pyo3(get)]
    pub regressed_workflows: usize,

    #[pyo3(get)]
    pub unchanged_workflows: usize,

    /// Mean of `pass_rate_delta` across matched workflows.
    #[pyo3(get)]
    pub mean_pass_rate_delta: f64,

    /// Tasks whose pass/fail status flipped in any workflow.
    #[pyo3(get)]
    pub task_status_changes: Vec<TaskComparison>,

    /// Tasks present in only one of the two evaluations.
    #[pyo3(get)]
    pub missing_tasks: Vec<MissingTask>,

    #[pyo3(get)]
    pub baseline_workflow_count: usize,

    #[pyo3(get)]
    pub comparison_workflow_count: usize,

    /// Overall verdict: true when any workflow regressed.
    #[pyo3(get)]
    pub regressed: bool,
}
175
176#[pymethods]
177impl ComparisonResults {
178    #[getter]
179    pub fn task_aggregate_stats(&self) -> Vec<TaskAggregateStats> {
180        let mut task_stats: HashMap<String, (usize, usize, usize, usize)> = HashMap::new();
181
182        for wc in &self.workflow_comparisons {
183            for tc in &wc.task_comparisons {
184                let entry = task_stats.entry(tc.task_id.clone()).or_insert((0, 0, 0, 0));
185                entry.0 += 1; // workflows_evaluated
186                if tc.baseline_passed {
187                    entry.1 += 1; // baseline_pass_count
188                }
189                if tc.comparison_passed {
190                    entry.2 += 1; // comparison_pass_count
191                }
192                if tc.status_changed {
193                    entry.3 += 1; // status_changed_count
194                }
195            }
196        }
197
198        task_stats
199            .into_iter()
200            .map(
201                |(task_id, (total, baseline_pass, comparison_pass, changed))| TaskAggregateStats {
202                    task_id,
203                    workflows_evaluated: total,
204                    baseline_pass_count: baseline_pass,
205                    comparison_pass_count: comparison_pass,
206                    status_changed_count: changed,
207                    baseline_pass_rate: baseline_pass as f64 / total as f64,
208                    comparison_pass_rate: comparison_pass as f64 / total as f64,
209                },
210            )
211            .collect()
212    }
213
    /// True when any task was present in only one of the two evaluations.
    #[getter]
    pub fn has_missing_tasks(&self) -> bool {
        !self.missing_tasks.is_empty()
    }
218
219    pub fn print_missing_tasks(&self) {
220        if !self.missing_tasks.is_empty() {
221            println!("\n{}", "⚠ Missing Tasks".yellow().bold());
222
223            let baseline_only: Vec<_> = self
224                .missing_tasks
225                .iter()
226                .filter(|t| t.present_in == "baseline_only")
227                .collect();
228
229            let comparison_only: Vec<_> = self
230                .missing_tasks
231                .iter()
232                .filter(|t| t.present_in == "comparison_only")
233                .collect();
234
235            if !baseline_only.is_empty() {
236                println!("  Baseline only ({} tasks):", baseline_only.len());
237                for task in baseline_only {
238                    println!("    - {} - {}", task.task_id, task.record_id);
239                }
240            }
241
242            if !comparison_only.is_empty() {
243                println!("  Comparison only ({} tasks):", comparison_only.len());
244                for task in comparison_only {
245                    println!("    - {} - {}", task.task_id, task.record_id);
246                }
247            }
248        }
249    }
250
    /// Python `str()` — serialized via `PyHelperFuncs::__str__`.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
254
    /// Print the full comparison report to the console, in order:
    /// workflow summary, task status changes (if any), per-task aggregates,
    /// missing tasks (if any), and overall summary stats.
    pub fn as_table(&self) {
        self.print_summary_table();

        if !self.task_status_changes.is_empty() {
            println!(
                "\n{}",
                "Task Status Changes (Workflow-Specific)"
                    .truecolor(245, 77, 85)
                    .bold()
            );
            self.print_status_changes_table();
        }

        self.print_task_aggregate_table();

        if self.has_missing_tasks() {
            self.print_missing_tasks();
        }

        self.print_summary_stats();
    }
276
    /// Print the per-task cross-workflow aggregate table.
    /// Delta is colored green (> +1%), red (< -1%), or yellow (within ±1%).
    /// Prints nothing when there are no stats.
    fn print_task_aggregate_table(&self) {
        let stats = self.task_aggregate_stats();

        if stats.is_empty() {
            return;
        }

        println!(
            "\n{}",
            "Task Aggregate Stats (Cross-Workflow)"
                .truecolor(245, 77, 85)
                .bold()
        );

        let entries: Vec<_> = stats
            .iter()
            .map(|ts| {
                let baseline_rate = format!("{:.1}%", ts.baseline_pass_rate * 100.0);
                let comparison_rate = format!("{:.1}%", ts.comparison_pass_rate * 100.0);
                let delta_val = (ts.comparison_pass_rate - ts.baseline_pass_rate) * 100.0;
                let delta_str = format!("{:+.1}%", delta_val);

                // Color by magnitude of change; ±1% counts as "no change".
                let colored_delta = if delta_val > 1.0 {
                    delta_str.green().to_string()
                } else if delta_val < -1.0 {
                    delta_str.red().to_string()
                } else {
                    delta_str.yellow().to_string()
                };

                // Guard against divide-by-zero for the change percentage.
                let change_pct = if ts.workflows_evaluated > 0 {
                    (ts.status_changed_count as f64 / ts.workflows_evaluated as f64) * 100.0
                } else {
                    0.0
                };

                TaskAggregateEntry {
                    task_id: ts.task_id.clone(),
                    workflows: ts.workflows_evaluated.to_string(),
                    baseline_rate,
                    comparison_rate,
                    delta: colored_delta,
                    changes: format!(
                        "{}/{} ({:.0}%)",
                        ts.status_changed_count, ts.workflows_evaluated, change_pct
                    ),
                }
            })
            .collect();

        let mut table = Table::new(entries);
        table.with(Style::sharp());

        // Header row: brand red, bold, centered.
        table.modify(
            Rows::new(0..1),
            (
                Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
                Alignment::center(),
                Color::BOLD,
            ),
        );

        println!("{}", table);
    }
341
342    fn print_summary_table(&self) {
343        let entries: Vec<_> = self
344            .workflow_comparisons
345            .iter()
346            .map(|wc| {
347                let baseline_rate = format!("{:.1}%", wc.baseline_pass_rate * 100.0);
348                let comparison_rate = format!("{:.1}%", wc.comparison_pass_rate * 100.0);
349                let delta_val = wc.pass_rate_delta * 100.0;
350                let delta_str = format!("{:+.1}%", delta_val);
351
352                let colored_delta = if delta_val > 1.0 {
353                    delta_str.green().to_string()
354                } else if delta_val < -1.0 {
355                    delta_str.red().to_string()
356                } else {
357                    delta_str.yellow().to_string()
358                };
359
360                let status = if wc.is_regression {
361                    "Regressed".red().to_string()
362                } else if wc.pass_rate_delta > 0.01 {
363                    "Improved".green().to_string()
364                } else {
365                    "Unchanged".yellow().to_string()
366                };
367
368                WorkflowComparisonEntry {
369                    baseline_id: wc.baseline_id[..16.min(wc.baseline_id.len())]
370                        .to_string()
371                        .truecolor(249, 179, 93)
372                        .to_string(),
373                    comparison_id: wc.comparison_id[..16.min(wc.comparison_id.len())]
374                        .to_string()
375                        .truecolor(249, 179, 93)
376                        .to_string(),
377                    baseline_pass_rate: baseline_rate,
378                    comparison_pass_rate: comparison_rate,
379                    delta: colored_delta,
380                    status,
381                }
382            })
383            .collect();
384
385        let mut table = Table::new(entries);
386        table.with(Style::sharp());
387
388        table.modify(
389            Rows::new(0..1),
390            (
391                Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
392                Alignment::center(),
393                Color::BOLD,
394            ),
395        );
396
397        println!("{}", table);
398    }
399
400    fn print_status_changes_table(&self) {
401        let entries: Vec<_> = self
402            .task_status_changes
403            .iter()
404            .map(|tc| {
405                let baseline_status = if tc.baseline_passed {
406                    "✓ Pass".green().to_string()
407                } else {
408                    "✗ Fail".red().to_string()
409                };
410
411                let comparison_status = if tc.comparison_passed {
412                    "✓ Pass".green().to_string()
413                } else {
414                    "✗ Fail".red().to_string()
415                };
416
417                let change = match (tc.baseline_passed, tc.comparison_passed) {
418                    (true, false) => "Pass → Fail".red().bold().to_string(),
419                    (false, true) => "Fail → Pass".green().bold().to_string(),
420                    _ => "No Change".yellow().to_string(),
421                };
422
423                TaskStatusChangeEntry {
424                    task_id: tc.task_id.clone(),
425                    baseline_status,
426                    comparison_status,
427                    change,
428                }
429            })
430            .collect();
431
432        let mut table = Table::new(entries);
433        table.with(Style::sharp());
434
435        table.modify(
436            Rows::new(0..1),
437            (
438                Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
439                Alignment::center(),
440                Color::BOLD,
441            ),
442        );
443
444        println!("{}", table);
445    }
446
447    fn print_summary_stats(&self) {
448        println!("\n{}", "Summary".truecolor(245, 77, 85).bold());
449
450        let regression_indicator = if self.regressed {
451            "⚠️  REGRESSION DETECTED".red().bold().to_string()
452        } else if self.improved_workflows > 0 {
453            "✅ IMPROVEMENT DETECTED".green().bold().to_string()
454        } else {
455            "➡️  NO SIGNIFICANT CHANGE".yellow().bold().to_string()
456        };
457
458        println!("  Overall Status: {}", regression_indicator);
459        println!("  Total Workflows: {}", self.total_workflows);
460        println!(
461            "  Improved: {}",
462            self.improved_workflows.to_string().green()
463        );
464        println!(
465            "  Regressed: {}",
466            self.regressed_workflows.to_string().red()
467        );
468        println!(
469            "  Unchanged: {}",
470            self.unchanged_workflows.to_string().yellow()
471        );
472
473        let mean_delta_str = format!("{:+.2}%", self.mean_pass_rate_delta * 100.0);
474        let colored_mean = if self.mean_pass_rate_delta > 0.0 {
475            mean_delta_str.green().to_string()
476        } else if self.mean_pass_rate_delta < 0.0 {
477            mean_delta_str.red().to_string()
478        } else {
479            mean_delta_str.yellow().to_string()
480        };
481        println!("  Mean Pass Rate Delta: {}", colored_mean);
482    }
483}
484
485impl ComparisonResults {}
486
/// Row model for the task status-change console table.
/// All fields are pre-formatted (and possibly ANSI-colored) strings.
#[derive(Tabled)]
struct TaskStatusChangeEntry {
    #[tabled(rename = "Task ID")]
    task_id: String,
    #[tabled(rename = "Baseline")]
    baseline_status: String,
    #[tabled(rename = "Comparison")]
    comparison_status: String,
    #[tabled(rename = "Change")]
    change: String,
}
498
/// Container for the results of a GenAI evaluation run.
#[derive(Debug, Serialize, Deserialize)]
#[pyclass]
pub struct GenAIEvalResults {
    /// Aligned results in original record order
    pub aligned_results: Vec<AlignedEvalResult>,

    /// UIDs of records whose evaluation failed (see `add_failure`).
    #[pyo3(get)]
    pub errored_tasks: Vec<String>,

    /// Optional 2D projection / cluster assignment for visualization.
    pub cluster_data: Option<ClusterData>,

    /// Histograms per numeric feature, computed in `finalize` when enabled.
    #[pyo3(get)]
    pub histograms: Option<HashMap<String, Histogram>>,

    // Not serialized; rebuilt lazily by `build_array_dataset`.
    #[serde(skip)]
    pub array_dataset: Option<ArrayDataset>,

    // Not serialized; maps record_id -> index into `aligned_results`.
    #[serde(skip)]
    pub results_by_id: HashMap<String, usize>,
}
519
520#[pymethods]
521impl GenAIEvalResults {
522    pub fn __getitem__(&self, key: &str) -> Result<AlignedEvalResult, EvaluationError> {
523        self.results_by_id
524            .get(key)
525            .and_then(|&idx| self.aligned_results.get(idx))
526            .cloned()
527            .ok_or_else(|| EvaluationError::MissingKeyError(key.to_string()))
528    }
529
530    #[getter]
531    pub fn successful_count(&self) -> usize {
532        self.aligned_results.iter().filter(|r| r.success).count()
533    }
534
535    #[getter]
536    pub fn failed_count(&self) -> usize {
537        self.aligned_results.iter().filter(|r| !r.success).count()
538    }
539
    /// Export to dataframe format.
    ///
    /// Flattens every aligned result into task-level records, builds a
    /// column-oriented dict, and hands it to `pandas.DataFrame.from_dict` or
    /// `polars.DataFrame` (with an explicit, non-strict schema).
    ///
    /// # Arguments
    /// * `polars` - If true build a polars DataFrame, otherwise pandas.
    ///
    /// # Errors
    /// `EvaluationError::NoResultsFound` when there are no task records;
    /// otherwise any Python-side conversion/import error.
    #[pyo3(signature = (polars=false))]
    pub fn to_dataframe<'py>(
        &mut self,
        py: Python<'py>,
        polars: bool,
    ) -> Result<Bound<'py, PyAny>, EvaluationError> {
        let all_task_records: Vec<_> = self
            .aligned_results
            .iter()
            .flat_map(|r| r.to_flat_task_records())
            .collect();

        if all_task_records.is_empty() {
            return Err(EvaluationError::NoResultsFound);
        }

        let py_records = PyDict::new(py);

        // Collect all unique column names (BTreeSet => deterministic order).
        let mut all_columns = std::collections::BTreeSet::new();
        for record in &all_task_records {
            all_columns.extend(record.keys().cloned());
        }

        // Build columns in consistent order; missing cells become JSON null.
        for column_name in all_columns {
            let column_data: Vec<_> = all_task_records
                .iter()
                .map(|record| record.get(&column_name).cloned().unwrap_or(Value::Null))
                .collect();

            let py_col = pythonize::pythonize(py, &column_data)?;
            py_records.set_item(&column_name, py_col)?;
        }

        let module = if polars { "polars" } else { "pandas" };
        let df_module = py.import(module)?;
        let df_class = df_module.getattr("DataFrame")?;

        if polars {
            // polars gets an explicit schema; strict=false lets mismatched
            // cells coerce instead of raising.
            let schema = self.get_schema_mapping(py)?;
            let schema_dict = &[("schema", schema)].into_py_dict(py)?;
            schema_dict.set_item("strict", false)?;
            Ok(df_class.call((py_records,), Some(schema_dict))?)
        } else {
            Ok(df_class.call_method1("from_dict", (py_records,))?)
        }
    }
589
    /// Python `str()` — serialized via `PyHelperFuncs::__str__`.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }
593
594    #[pyo3(signature = (show_tasks=false))]
595    /// Display results as a table in the console
596    /// # Arguments
597    /// * `show_tasks` - If true, display detailed task results; otherwise, show workflow summary
598    pub fn as_table(&mut self, show_tasks: bool) {
599        if show_tasks {
600            let tasks_table = self.build_tasks_table();
601            println!("\n{}", "Task Details".truecolor(245, 77, 85).bold());
602            println!("{}", tasks_table);
603        } else {
604            let workflow_table = self.build_workflow_table();
605            println!("\n{}", "Workflow Summary".truecolor(245, 77, 85).bold());
606            println!("{}", workflow_table);
607        }
608    }
609
    /// Serialize the results to a JSON string (pydantic-style name).
    pub fn model_dump_json(&self) -> String {
        PyHelperFuncs::__json__(self)
    }
613
614    #[staticmethod]
615    pub fn model_validate_json(json_string: String) -> Result<Self, EvaluationError> {
616        Ok(serde_json::from_str(&json_string)?)
617    }
618
    /// Compare this evaluation against another baseline evaluation
    /// Matches workflows by their task structure and compares task-by-task
    ///
    /// # Arguments
    /// * `baseline` - The evaluation treated as the baseline.
    /// * `regression_threshold` - Pass-rate drop counted as a regression
    ///   (Python default 0.05).
    ///
    /// # Errors
    /// Propagates any error from `compare_results`.
    #[pyo3(signature = (baseline, regression_threshold=0.05))]
    pub fn compare_to(
        &self,
        baseline: &GenAIEvalResults,
        regression_threshold: f64,
    ) -> Result<ComparisonResults, EvaluationError> {
        // Note argument order: `self` is the comparison side.
        compare_results(baseline, self, regression_threshold)
    }
629}
630
631impl GenAIEvalResults {
632    fn get_schema_mapping<'py>(
633        &self,
634        py: Python<'py>,
635    ) -> Result<Bound<'py, PyDict>, EvaluationError> {
636        let schema = PyDict::new(py);
637        let pl = py.import("polars")?;
638
639        schema.set_item("created_at", pl.getattr("Utf8")?)?;
640        schema.set_item("record_uid", pl.getattr("Utf8")?)?;
641        schema.set_item("success", pl.getattr("Boolean")?)?;
642        schema.set_item("workflow_error", pl.getattr("Utf8")?)?;
643
644        schema.set_item("workflow_total_tasks", pl.getattr("Int64")?)?;
645        schema.set_item("workflow_passed_tasks", pl.getattr("Int64")?)?;
646        schema.set_item("workflow_failed_tasks", pl.getattr("Int64")?)?;
647        schema.set_item("workflow_pass_rate", pl.getattr("Float64")?)?;
648        schema.set_item("workflow_duration_ms", pl.getattr("Int64")?)?;
649
650        schema.set_item("task_id", pl.getattr("Utf8")?)?;
651        schema.set_item("task_type", pl.getattr("Utf8")?)?;
652        schema.set_item("task_passed", pl.getattr("Boolean")?)?;
653        schema.set_item("task_value", pl.getattr("Float64")?)?;
654        schema.set_item("task_message", pl.getattr("Utf8")?)?;
655        schema.set_item("task_field_path", pl.getattr("Utf8")?)?;
656        schema.set_item("task_operator", pl.getattr("Utf8")?)?;
657        schema.set_item("task_expected", pl.getattr("Utf8")?)?;
658        schema.set_item("task_actual", pl.getattr("Utf8")?)?;
659
660        schema.set_item("context", pl.getattr("Utf8")?)?;
661        schema.set_item("embedding_means", pl.getattr("Utf8")?)?;
662        schema.set_item("similarity_scores", pl.getattr("Utf8")?)?;
663
664        Ok(schema)
665    }
666    /// Build workflow result table for console display
667    fn build_workflow_table(&self) -> Table {
668        let entries: Vec<WorkflowResultTableEntry> = self
669            .aligned_results
670            .iter()
671            .flat_map(|result| result.eval_set.build_workflow_entries())
672            .collect();
673
674        let mut table = Table::new(entries);
675        table.with(Style::sharp());
676
677        table.modify(
678            Rows::new(0..1),
679            (
680                Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
681                Alignment::center(),
682                Color::BOLD,
683            ),
684        );
685        table
686    }
687
    /// Build detailed task results table for console display
    ///
    /// Sorts results by record_id (falling back to record_uid when empty) so
    /// rows group by record, then flattens each result's task entries.
    fn build_tasks_table(&mut self) -> Table {
        // Sort in place by the display id used below.
        self.aligned_results.sort_by(|a, b| {
            let a_id = if a.record_id.is_empty() {
                &a.record_uid
            } else {
                &a.record_id
            };
            let b_id = if b.record_id.is_empty() {
                &b.record_uid
            } else {
                &b.record_id
            };
            a_id.cmp(b_id)
        });

        let entries: Vec<TaskResultTableEntry> = self
            .aligned_results
            .iter_mut()
            .flat_map(|result| {
                // Same fallback as the sort key above.
                let resolved_id = if result.record_id.is_empty() {
                    &result.record_uid
                } else {
                    &result.record_id
                };

                result.eval_set.build_task_entries(resolved_id)
            })
            .collect();

        let mut table = Table::new(entries);
        table.with(Style::sharp());

        // Header row: brand red, bold, centered.
        table.modify(
            Rows::new(0..1),
            (
                Format::content(|s: &str| s.truecolor(245, 77, 85).bold().to_string()),
                Alignment::center(),
                Color::BOLD,
            ),
        );

        table
    }
732
733    pub fn new() -> Self {
734        Self {
735            aligned_results: Vec::new(),
736            errored_tasks: Vec::new(),
737            array_dataset: None,
738            cluster_data: None,
739            histograms: None,
740            results_by_id: HashMap::new(),
741        }
742    }
743
744    /// Add a successful result
745    pub fn add_success(
746        &mut self,
747        record: &GenAIEvalRecord,
748        eval_set: GenAIEvalSet,
749        embeddings: BTreeMap<String, Vec<f32>>,
750    ) {
751        self.aligned_results.push(AlignedEvalResult::from_success(
752            record, eval_set, embeddings,
753        ));
754
755        if !record.record_id.is_empty() {
756            self.results_by_id
757                .insert(record.record_id.clone(), self.aligned_results.len() - 1);
758        }
759    }
760
761    /// Add a failed result - only reference the record
762    pub fn add_failure(&mut self, record: &GenAIEvalRecord, error: String) {
763        let uid = record.uid.clone();
764
765        self.aligned_results
766            .push(AlignedEvalResult::from_failure(record, error));
767
768        self.errored_tasks.push(uid);
769    }
770
    /// Finalize the results by performing post-processing steps which includes:
    /// - Post-processing embeddings (if any)
    /// - Building the array dataset (if not already built)
    /// - Performing clustering and dimensionality reduction (if enabled) for visualization
    /// # Arguments
    /// * `config` - The evaluation configuration that dictates post-processing behavior
    /// # Returns
    /// * `Result<(), EvaluationError>` - Returns Ok(()) if successful, otherwise returns
    pub fn finalize(&mut self, config: &Arc<EvaluationConfig>) -> Result<(), EvaluationError> {
        // Post-process embeddings if needed
        if !config.embedding_targets.is_empty() {
            post_process_aligned_results(self, config)?;
        }

        if config.compute_histograms {
            // No-op when the dataset was already built.
            self.build_array_dataset()?;

            // Compute histograms for all numeric fields
            if let Some(array_dataset) = &self.array_dataset {
                let profiler = NumProfiler::new();
                // 10 bins, non-parallel — presumably a display default; confirm
                // against NumProfiler::compute_histogram before changing.
                let histograms = profiler.compute_histogram(
                    &array_dataset.data.view(),
                    &array_dataset.feature_names,
                    &10,
                    false,
                )?;
                self.histograms = Some(histograms);
            }
        }

        Ok(())
    }
803
804    /// Build an NDArray dataset from the result tasks
805    fn build_array_dataset(&mut self) -> Result<(), EvaluationError> {
806        if self.array_dataset.is_none() {
807            self.array_dataset = Some(ArrayDataset::from_results(self)?);
808        }
809        Ok(())
810    }
811}
812
/// `Default` delegates to `new()` — an empty results container.
impl Default for GenAIEvalResults {
    fn default() -> Self {
        Self::new()
    }
}
818
819pub fn array_to_dict<'py>(
820    py: Python<'py>,
821    array: &ArrayDataset,
822) -> Result<Bound<'py, PyDict>, EvaluationError> {
823    let pydict = PyDict::new(py);
824
825    // set task ids
826    pydict.set_item(
827        "task",
828        array.idx_map.values().cloned().collect::<Vec<String>>(),
829    )?;
830
831    // set feature columns
832    for (i, feature) in array.feature_names.iter().enumerate() {
833        let column_data: Vec<f64> = array.data.column(i).to_vec();
834        pydict.set_item(feature, column_data)?;
835    }
836
837    // add cluster column if available
838    if array.clusters.len() == array.data.nrows() {
839        pydict.set_item("cluster", array.clusters.clone())?;
840    }
841    Ok(pydict)
842}
843
/// 2D projection and cluster assignment of evaluation rows for visualization.
/// `x`, `y`, and `clusters` are parallel per-row vectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct ClusterData {
    #[pyo3(get)]
    pub x: Vec<f64>,
    #[pyo3(get)]
    pub y: Vec<f64>,
    #[pyo3(get)]
    pub clusters: Vec<i32>,
    /// Row index -> record uid.
    pub idx_map: HashMap<usize, String>,
}
855
856impl ClusterData {
857    pub fn new(
858        x: Vec<f64>,
859        y: Vec<f64>,
860        clusters: Vec<i32>,
861        idx_map: HashMap<usize, String>,
862    ) -> Self {
863        ClusterData {
864            x,
865            y,
866            clusters,
867            idx_map,
868        }
869    }
870}
871
/// Numeric matrix view of evaluation results:
/// rows = records, columns = `feature_names` (task scores, embedding means,
/// similarity scores). `idx_map` maps row index -> record uid; `clusters`
/// holds per-row labels when clustering has run.
#[derive(Debug)]
pub struct ArrayDataset {
    pub data: Array2<f64>,
    pub feature_names: Vec<String>,
    pub idx_map: HashMap<usize, String>,
    pub clusters: Vec<i32>,
}
879
/// `Default` delegates to `new()` — an empty dataset.
impl Default for ArrayDataset {
    fn default() -> Self {
        Self::new()
    }
}
885
886impl ArrayDataset {
887    pub fn new() -> Self {
888        Self {
889            data: Array2::zeros((0, 0)),
890            feature_names: Vec::new(),
891            idx_map: HashMap::new(),
892            clusters: vec![],
893        }
894    }
895
896    /// Build feature names from aligned results
897    /// This extracts all unique task names, embedding targets, and similarity pairs
898    /// from the evaluation results to create column names for the array dataset
899    fn build_feature_names(results: &GenAIEvalResults) -> Result<Vec<String>, EvaluationError> {
900        // Get first successful result to determine schema
901        let first_result = results
902            .aligned_results
903            .iter()
904            .find(|r| r.success)
905            .ok_or(EvaluationError::NoResultsFound)?;
906
907        let mut names = Vec::new();
908
909        for task_record in &first_result.eval_set.records {
910            names.push(task_record.task_id.clone());
911        }
912
913        names.extend(first_result.mean_embeddings.keys().cloned());
914        names.extend(first_result.similarity_scores.keys().cloned());
915
916        Ok(names)
917    }
918
    // Build array dataset from aligned evaluation results
    /// Creates a 2D array where:
    /// - Rows = evaluation records
    /// - Columns = task scores, embedding means, similarity scores
    ///
    /// Returns an empty dataset when there are no results at all, but an
    /// error when results exist and none succeeded.
    pub fn from_results(results: &GenAIEvalResults) -> Result<Self, EvaluationError> {
        if results.aligned_results.is_empty() {
            return Ok(Self::new());
        }

        // Only include successful evaluations in the dataset
        let successful_results: Vec<&AlignedEvalResult> = results
            .aligned_results
            .iter()
            .filter(|r| r.success)
            .collect();

        if successful_results.is_empty() {
            return Err(EvaluationError::NoResultsFound);
        }

        // Schema comes from the first successful result; rows whose tasks
        // don't match it fall back to 0.0 below.
        let feature_names = Self::build_feature_names(results)?;
        let n_rows = successful_results.len();
        let n_cols = feature_names.len();

        let mut data = Vec::with_capacity(n_rows * n_cols);
        let mut idx_map = HashMap::new();

        // Build task score lookup for efficient access
        // This maps task_id -> score value for quick column population
        for (row_idx, aligned) in successful_results.iter().enumerate() {
            // Row index -> record uid, used later for display/export.
            idx_map.insert(row_idx, aligned.record_uid.clone());

            // Build lookup map from task_id to value
            let task_scores: HashMap<String, f64> = aligned
                .eval_set
                .records
                .iter()
                .map(|task| (task.task_id.clone(), task.value))
                .collect();

            // Collect all values in correct column order
            let row: Vec<f64> = feature_names
                .iter()
                .map(|feature_name| {
                    // Try task scores first
                    if let Some(&score) = task_scores.get(feature_name) {
                        return score;
                    }

                    // Try embedding means
                    if let Some(&mean) = aligned.mean_embeddings.get(feature_name) {
                        return mean;
                    }

                    // Try similarity scores
                    if let Some(&sim) = aligned.similarity_scores.get(feature_name) {
                        return sim;
                    }

                    // Default for missing values
                    0.0
                })
                .collect();

            data.extend(row);
        }

        // Shape is guaranteed: each row pushed exactly n_cols values.
        let array = Array2::from_shape_vec((n_rows, n_cols), data)?;

        Ok(Self {
            data: array,
            feature_names,
            idx_map,
            clusters: vec![],
        })
    }
995}
996
/// Per-record evaluation result after alignment: pairs an evaluated record
/// with its task results, raw embeddings, and derived embedding metrics.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[pyclass]
pub struct AlignedEvalResult {
    /// Caller-facing record identifier (presumably not guaranteed unique
    /// across runs, unlike `record_uid` — TODO confirm).
    #[pyo3(get)]
    pub record_id: String,

    /// Unique identifier of the source record (`GenAIEvalRecord.uid`).
    #[pyo3(get)]
    pub record_uid: String,

    /// Task-level evaluation results for this record.
    #[pyo3(get)]
    pub eval_set: GenAIEvalSet,

    /// Raw embedding vectors keyed by target field name.
    /// Skipped during serialization — only the derived means/similarities persist.
    #[pyo3(get)]
    #[serde(skip)]
    pub embeddings: BTreeMap<String, Vec<f32>>,

    /// Mean value of each embedding vector, keyed by target field name.
    #[pyo3(get)]
    pub mean_embeddings: BTreeMap<String, f64>,

    /// Pairwise similarity scores between embedding targets.
    #[pyo3(get)]
    pub similarity_scores: BTreeMap<String, f64>,

    /// Whether the evaluation of this record succeeded.
    #[pyo3(get)]
    pub success: bool,

    /// Error message when `success` is false; `None` otherwise.
    #[pyo3(get)]
    pub error_message: Option<String>,

    /// Copy of the record's JSON-object context, captured for dataframe export.
    /// Not serialized and not exposed to Python.
    #[serde(skip)]
    pub context_snapshot: Option<BTreeMap<String, serde_json::Value>>,
}
1028
#[pymethods]
impl AlignedEvalResult {
    /// Python `str()` representation; delegates to the shared helper formatter.
    pub fn __str__(&self) -> String {
        PyHelperFuncs::__str__(self)
    }

    /// Number of task results in this record's eval set
    /// (exposed to Python as a read-only property).
    #[getter]
    pub fn task_count(&self) -> usize {
        self.eval_set.records.len()
    }
}
1040
1041impl AlignedEvalResult {
1042    /// Create from successful evaluation
1043    pub fn from_success(
1044        record: &GenAIEvalRecord,
1045        eval_set: GenAIEvalSet,
1046        embeddings: BTreeMap<String, Vec<f32>>,
1047    ) -> Self {
1048        Self {
1049            record_uid: record.uid.clone(),
1050            record_id: record.record_id.clone(),
1051            eval_set,
1052            embeddings,
1053            mean_embeddings: BTreeMap::new(),
1054            similarity_scores: BTreeMap::new(),
1055            success: true,
1056            error_message: None,
1057            context_snapshot: None,
1058        }
1059    }
1060
1061    /// Create from failed evaluation
1062    pub fn from_failure(record: &GenAIEvalRecord, error: String) -> Self {
1063        Self {
1064            record_uid: record.uid.clone(),
1065            eval_set: GenAIEvalSet::empty(),
1066            embeddings: BTreeMap::new(),
1067            mean_embeddings: BTreeMap::new(),
1068            similarity_scores: BTreeMap::new(),
1069            success: false,
1070            error_message: Some(error),
1071            context_snapshot: None,
1072            record_id: record.record_id.clone(),
1073        }
1074    }
1075
1076    /// Capture context snapshot for dataframe export
1077    pub fn capture_context(&mut self, record: &GenAIEvalRecord) {
1078        if let serde_json::Value::Object(context_map) = &record.context {
1079            self.context_snapshot = Some(
1080                context_map
1081                    .iter()
1082                    .map(|(k, v)| (k.clone(), v.clone()))
1083                    .collect(),
1084            );
1085        }
1086    }
1087
1088    /// Get flattened data for dataframe export
1089    pub fn to_flat_task_records(&self) -> Vec<BTreeMap<String, serde_json::Value>> {
1090        let mut records = Vec::new();
1091
1092        for task_result in &self.eval_set.records {
1093            let mut flat = BTreeMap::new();
1094
1095            // Workflow metadata (repeated for each task)
1096            flat.insert(
1097                "created_at".to_string(),
1098                self.eval_set.inner.created_at.to_rfc3339().into(),
1099            );
1100            flat.insert("record_uid".to_string(), self.record_uid.clone().into());
1101            flat.insert("success".to_string(), self.success.into());
1102
1103            // insert workflow error or "" if none
1104            flat.insert(
1105                "workflow_error".to_string(),
1106                match &self.error_message {
1107                    Some(err) => serde_json::Value::String(err.clone()),
1108                    None => serde_json::Value::String("".to_string()),
1109                },
1110            );
1111
1112            // Workflow-level metrics
1113            flat.insert(
1114                "workflow_total_tasks".to_string(),
1115                self.eval_set.inner.total_tasks.into(),
1116            );
1117            flat.insert(
1118                "workflow_passed_tasks".to_string(),
1119                self.eval_set.inner.passed_tasks.into(),
1120            );
1121            flat.insert(
1122                "workflow_failed_tasks".to_string(),
1123                self.eval_set.inner.failed_tasks.into(),
1124            );
1125            flat.insert(
1126                "workflow_pass_rate".to_string(),
1127                self.eval_set.inner.pass_rate.into(),
1128            );
1129            flat.insert(
1130                "workflow_duration_ms".to_string(),
1131                self.eval_set.inner.duration_ms.into(),
1132            );
1133
1134            // Task-specific data
1135            flat.insert("task_id".to_string(), task_result.task_id.clone().into());
1136            flat.insert(
1137                "task_type".to_string(),
1138                task_result.task_type.to_string().into(),
1139            );
1140            flat.insert("task_passed".to_string(), task_result.passed.into());
1141            flat.insert("task_value".to_string(), task_result.value.into());
1142            flat.insert(
1143                "task_message".to_string(),
1144                serde_json::Value::String(task_result.message.clone()),
1145            );
1146
1147            flat.insert(
1148                "task_field_path".to_string(),
1149                match &task_result.field_path {
1150                    Some(path) => serde_json::Value::String(path.clone()),
1151                    None => serde_json::Value::Null,
1152                },
1153            );
1154
1155            flat.insert(
1156                "task_operator".to_string(),
1157                task_result.operator.to_string().into(),
1158            );
1159            flat.insert("task_expected".to_string(), task_result.expected.clone());
1160            flat.insert("task_actual".to_string(), task_result.actual.clone());
1161
1162            // Context as single JSON column
1163            flat.insert(
1164                "context".to_string(),
1165                self.context_snapshot
1166                    .as_ref()
1167                    .map(|ctx| serde_json::to_value(ctx).unwrap_or(serde_json::Value::Null))
1168                    .unwrap_or(serde_json::Value::Null),
1169            );
1170
1171            // Embedding means as single JSON column
1172            flat.insert(
1173                "embedding_means".to_string(),
1174                serde_json::to_value(&self.mean_embeddings).unwrap_or(serde_json::Value::Null),
1175            );
1176
1177            // Similarity scores as single JSON column
1178            flat.insert(
1179                "similarity_scores".to_string(),
1180                serde_json::to_value(&self.similarity_scores).unwrap_or(serde_json::Value::Null),
1181            );
1182
1183            records.push(flat);
1184        }
1185
1186        records
1187    }
1188}
1189
/// Configuration for running an evaluation, including optional
/// embedding-based post-processing.
#[derive(Debug, Clone, Default)]
#[pyclass]
pub struct EvaluationConfig {
    /// Optional embedder for embedding-based evaluations.
    pub embedder: Option<Arc<Embedder>>,

    /// Fields in the record to generate embeddings for.
    pub embedding_targets: Vec<String>,

    /// This will compute similarities for all combinations of embeddings in the
    /// targets, e.g. if you have targets ["a", "b"], it will compute similarity
    /// between a-b.
    pub compute_similarity: bool,

    /// Whether to compute histograms for all scores, embeddings and
    /// similarities (if available).
    pub compute_histograms: bool,
}
1206
1207#[pymethods]
1208impl EvaluationConfig {
1209    #[new]
1210    #[pyo3(signature = (embedder=None, embedding_targets=None, compute_similarity=false, compute_histograms=false))]
1211    /// Creates a new EvaluationConfig instance.
1212    /// # Arguments
1213    /// * `embedder` - Optional reference to a PyEmbedder instance.
1214    /// * `embedding_targets` - Optional list of fields in the record to generate embeddings for.
1215    /// * `compute_similarity` - Whether to compute similarities between embeddings.
1216    /// * `compute_histograms` - Whether to compute histograms for all scores, embeddings and similarities (if available).
1217    /// # Returns
1218    /// A new EvaluationConfig instance.
1219    fn new(
1220        embedder: Option<&Bound<'_, PyAny>>,
1221        embedding_targets: Option<Vec<String>>,
1222        compute_similarity: bool,
1223        compute_histograms: bool,
1224    ) -> Result<Self, EvaluationError> {
1225        let embedder = parse_embedder(embedder)?;
1226        let embedding_targets = embedding_targets.unwrap_or_default();
1227
1228        Ok(Self {
1229            embedder,
1230            embedding_targets,
1231            compute_similarity,
1232            compute_histograms,
1233        })
1234    }
1235
1236    pub fn needs_post_processing(&self) -> bool {
1237        !self.embedding_targets.is_empty()
1238    }
1239}