scouter_evaluate/evaluate/
compare.rs1use crate::error::EvaluationError;
2use crate::evaluate::types::{
3 ComparisonResults, GenAIEvalResults, MissingTask, TaskComparison, WorkflowComparison,
4};
5use std::collections::HashMap;
6
7pub fn compare_results(
39 baseline: &GenAIEvalResults,
40 comparison: &GenAIEvalResults,
41 regression_threshold: f64,
42) -> Result<ComparisonResults, EvaluationError> {
43 let baseline_map: HashMap<_, _> = baseline
44 .aligned_results
45 .iter()
46 .filter(|r| r.success)
47 .map(|r| {
48 if r.record_id.is_empty() {
49 (r.record_uid.as_str(), r)
50 } else {
51 (r.record_id.as_str(), r)
52 }
53 })
54 .collect();
55
56 let comparison_map: HashMap<_, _> = comparison
57 .aligned_results
58 .iter()
59 .filter(|r| r.success)
60 .map(|r| {
61 if r.record_id.is_empty() {
62 (r.record_uid.as_str(), r)
63 } else {
64 (r.record_id.as_str(), r)
65 }
66 })
67 .collect();
68
69 let mut workflow_comparisons = Vec::new();
70 let mut task_status_changes = Vec::new();
71 let mut missing_tasks = Vec::new();
72
73 for (record_id, baseline_result) in &baseline_map {
74 if let Some(comparison_result) = comparison_map.get(record_id) {
75 let baseline_task_map: HashMap<_, _> = baseline_result
76 .eval_set
77 .records
78 .iter()
79 .map(|t| (t.task_id.as_str(), t))
80 .collect();
81
82 let comparison_task_map: HashMap<_, _> = comparison_result
83 .eval_set
84 .records
85 .iter()
86 .map(|t| (t.task_id.as_str(), t))
87 .collect();
88
89 let mut workflow_task_comparisons = Vec::new();
90 let mut matched_baseline_passed = 0;
91 let mut matched_comparison_passed = 0;
92 let mut total_matched = 0;
93
94 for (task_id, baseline_task) in &baseline_task_map {
95 if let Some(comparison_task) = comparison_task_map.get(task_id) {
96 let status_changed = baseline_task.passed != comparison_task.passed;
97
98 if baseline_task.passed {
99 matched_baseline_passed += 1;
100 }
101 if comparison_task.passed {
102 matched_comparison_passed += 1;
103 }
104 total_matched += 1;
105
106 let task_comp = TaskComparison {
107 task_id: task_id.to_string(),
108 baseline_passed: baseline_task.passed,
109 comparison_passed: comparison_task.passed,
110 status_changed,
111 record_id: (*record_id).to_string(),
112 };
113
114 workflow_task_comparisons.push(task_comp.clone());
115
116 if status_changed {
117 task_status_changes.push(task_comp.clone());
118 }
119 } else {
120 missing_tasks.push(MissingTask {
121 task_id: task_id.to_string(),
122 present_in: "baseline_only".to_string(),
123 record_id: (*record_id).to_string(),
124 });
125 }
126 }
127
128 for task_id in comparison_task_map.keys() {
129 if !baseline_task_map.contains_key(task_id) {
130 missing_tasks.push(MissingTask {
131 task_id: task_id.to_string(),
132 present_in: "comparison_only".to_string(),
133 record_id: (*record_id).to_string(),
134 });
135 }
136 }
137
138 let baseline_pass_rate = if total_matched > 0 {
139 matched_baseline_passed as f64 / total_matched as f64
140 } else {
141 0.0
142 };
143
144 let comparison_pass_rate = if total_matched > 0 {
145 matched_comparison_passed as f64 / total_matched as f64
146 } else {
147 0.0
148 };
149
150 let pass_rate_delta = comparison_pass_rate - baseline_pass_rate;
151 let is_regression = pass_rate_delta < -regression_threshold;
152
153 workflow_comparisons.push(WorkflowComparison {
154 baseline_id: (*record_id).to_string(),
155 comparison_id: (*record_id).to_string(),
156 baseline_pass_rate,
157 comparison_pass_rate,
158 pass_rate_delta,
159 is_regression,
160 task_comparisons: workflow_task_comparisons,
161 });
162 }
163 }
164
165 let (improved, regressed, unchanged) =
166 workflow_comparisons
167 .iter()
168 .fold((0, 0, 0), |(i, r, u), wc| {
169 if wc.is_regression {
170 (i, r + 1, u)
171 } else if wc.pass_rate_delta > 0.01 {
172 (i + 1, r, u)
173 } else {
174 (i, r, u + 1)
175 }
176 });
177
178 let mean_delta = if !workflow_comparisons.is_empty() {
179 workflow_comparisons
180 .iter()
181 .map(|wc| wc.pass_rate_delta)
182 .sum::<f64>()
183 / workflow_comparisons.len() as f64
184 } else {
185 0.0
186 };
187
188 let has_regressed = regressed > 0;
189
190 Ok(ComparisonResults {
191 workflow_comparisons,
192 total_workflows: baseline_map.len().min(comparison_map.len()),
193 improved_workflows: improved,
194 regressed_workflows: regressed,
195 unchanged_workflows: unchanged,
196 mean_pass_rate_delta: mean_delta,
197 task_status_changes,
198 missing_tasks,
199 baseline_workflow_count: baseline.aligned_results.len(),
200 comparison_workflow_count: comparison.aligned_results.len(),
201 regressed: has_regressed,
202 })
203}