scouter_evaluate/evaluate/
compare.rs

use crate::error::EvaluationError;
use crate::evaluate::types::{
    ComparisonResults, GenAIEvalResults, MissingTask, TaskComparison, WorkflowComparison,
};
use std::collections::{BTreeMap, HashMap};

7/// Compares two GenAIEvalResults datasets and produces a ComparisonResults summary.
8///
9/// Every workflow is compared against the baseline. The comparison identifies:
10/// - Tasks that passed/failed in both datasets
11/// - Tasks that changed status between baseline and comparison
12/// - Tasks missing in either dataset
13/// - Overall pass rate deltas and regression detection
14///
15/// # Arguments
16///
17/// * `baseline` - The baseline evaluation results to compare against
18/// * `comparison` - The evaluation results being compared
19/// * `regression_threshold` - Pass rate delta threshold to flag as regression
20///
21/// # Returns
22///
23/// A `ComparisonResults` struct containing workflow comparisons, task-level changes,
24/// and aggregate statistics.
25///
26/// # Errors
27///
28/// Returns `EvaluationError` if comparison processing fails.
29///
30/// # Algorithm
31///
32/// 1. Map baseline and comparison results by `record_uid`, filtering for successful runs
33/// 2. For each record present in both datasets:
34///    - Build task maps keyed by `task_id`
35///    - Compare task pass/fail status for all matched tasks
36///    - Track tasks only in baseline or comparison
37/// 3. Aggregate workflow-level statistics (pass rates, deltas, regressions)
38pub fn compare_results(
39    baseline: &GenAIEvalResults,
40    comparison: &GenAIEvalResults,
41    regression_threshold: f64,
42) -> Result<ComparisonResults, EvaluationError> {
43    let baseline_map: HashMap<_, _> = baseline
44        .aligned_results
45        .iter()
46        .filter(|r| r.success)
47        .map(|r| {
48            if r.record_id.is_empty() {
49                (r.record_uid.as_str(), r)
50            } else {
51                (r.record_id.as_str(), r)
52            }
53        })
54        .collect();
55
56    let comparison_map: HashMap<_, _> = comparison
57        .aligned_results
58        .iter()
59        .filter(|r| r.success)
60        .map(|r| {
61            if r.record_id.is_empty() {
62                (r.record_uid.as_str(), r)
63            } else {
64                (r.record_id.as_str(), r)
65            }
66        })
67        .collect();
68
69    let mut workflow_comparisons = Vec::new();
70    let mut task_status_changes = Vec::new();
71    let mut missing_tasks = Vec::new();
72
73    for (record_id, baseline_result) in &baseline_map {
74        if let Some(comparison_result) = comparison_map.get(record_id) {
75            let baseline_task_map: HashMap<_, _> = baseline_result
76                .eval_set
77                .records
78                .iter()
79                .map(|t| (t.task_id.as_str(), t))
80                .collect();
81
82            let comparison_task_map: HashMap<_, _> = comparison_result
83                .eval_set
84                .records
85                .iter()
86                .map(|t| (t.task_id.as_str(), t))
87                .collect();
88
89            let mut workflow_task_comparisons = Vec::new();
90            let mut matched_baseline_passed = 0;
91            let mut matched_comparison_passed = 0;
92            let mut total_matched = 0;
93
94            for (task_id, baseline_task) in &baseline_task_map {
95                if let Some(comparison_task) = comparison_task_map.get(task_id) {
96                    let status_changed = baseline_task.passed != comparison_task.passed;
97
98                    if baseline_task.passed {
99                        matched_baseline_passed += 1;
100                    }
101                    if comparison_task.passed {
102                        matched_comparison_passed += 1;
103                    }
104                    total_matched += 1;
105
106                    let task_comp = TaskComparison {
107                        task_id: task_id.to_string(),
108                        baseline_passed: baseline_task.passed,
109                        comparison_passed: comparison_task.passed,
110                        status_changed,
111                        record_id: (*record_id).to_string(),
112                    };
113
114                    workflow_task_comparisons.push(task_comp.clone());
115
116                    if status_changed {
117                        task_status_changes.push(task_comp.clone());
118                    }
119                } else {
120                    missing_tasks.push(MissingTask {
121                        task_id: task_id.to_string(),
122                        present_in: "baseline_only".to_string(),
123                        record_id: (*record_id).to_string(),
124                    });
125                }
126            }
127
128            for task_id in comparison_task_map.keys() {
129                if !baseline_task_map.contains_key(task_id) {
130                    missing_tasks.push(MissingTask {
131                        task_id: task_id.to_string(),
132                        present_in: "comparison_only".to_string(),
133                        record_id: (*record_id).to_string(),
134                    });
135                }
136            }
137
138            let baseline_pass_rate = if total_matched > 0 {
139                matched_baseline_passed as f64 / total_matched as f64
140            } else {
141                0.0
142            };
143
144            let comparison_pass_rate = if total_matched > 0 {
145                matched_comparison_passed as f64 / total_matched as f64
146            } else {
147                0.0
148            };
149
150            let pass_rate_delta = comparison_pass_rate - baseline_pass_rate;
151            let is_regression = pass_rate_delta < -regression_threshold;
152
153            workflow_comparisons.push(WorkflowComparison {
154                baseline_id: (*record_id).to_string(),
155                comparison_id: (*record_id).to_string(),
156                baseline_pass_rate,
157                comparison_pass_rate,
158                pass_rate_delta,
159                is_regression,
160                task_comparisons: workflow_task_comparisons,
161            });
162        }
163    }
164
165    let (improved, regressed, unchanged) =
166        workflow_comparisons
167            .iter()
168            .fold((0, 0, 0), |(i, r, u), wc| {
169                if wc.is_regression {
170                    (i, r + 1, u)
171                } else if wc.pass_rate_delta > 0.01 {
172                    (i + 1, r, u)
173                } else {
174                    (i, r, u + 1)
175                }
176            });
177
178    let mean_delta = if !workflow_comparisons.is_empty() {
179        workflow_comparisons
180            .iter()
181            .map(|wc| wc.pass_rate_delta)
182            .sum::<f64>()
183            / workflow_comparisons.len() as f64
184    } else {
185        0.0
186    };
187
188    let has_regressed = regressed > 0;
189
190    Ok(ComparisonResults {
191        workflow_comparisons,
192        total_workflows: baseline_map.len().min(comparison_map.len()),
193        improved_workflows: improved,
194        regressed_workflows: regressed,
195        unchanged_workflows: unchanged,
196        mean_pass_rate_delta: mean_delta,
197        task_status_changes,
198        missing_tasks,
199        baseline_workflow_count: baseline.aligned_results.len(),
200        comparison_workflow_count: comparison.aligned_results.len(),
201        regressed: has_regressed,
202    })
203}