1use std::sync::Arc;
24
25use async_trait::async_trait;
26
27use crate::{EvalReport, EvalSuite, TestResult};
28
29#[async_trait]
33pub trait AsyncMetric: Send + Sync + 'static {
34 fn name(&self) -> &'static str;
36
37 async fn score(&self, input: &str, actual_output: &str, expected_keywords: &[&str]) -> f64;
41}
42
43#[async_trait]
47pub trait EvalAgent: Send + Sync {
48 async fn respond(&self, input: &str) -> traitclaw_core::Result<String>;
50}
51
52pub struct EvalRunner {
54 metrics: Vec<Arc<dyn AsyncMetric>>,
55 threshold: f64,
56}
57
58impl EvalRunner {
59 #[must_use]
61 pub fn new() -> Self {
62 Self {
63 metrics: Vec::new(),
64 threshold: 0.7,
65 }
66 }
67
68 #[must_use]
70 pub fn metric(mut self, metric: Box<dyn AsyncMetric>) -> Self {
71 self.metrics.push(Arc::from(metric));
72 self
73 }
74
75 #[must_use]
79 pub fn threshold(mut self, threshold: f64) -> Self {
80 self.threshold = threshold;
81 self
82 }
83
84 pub async fn run(
93 &self,
94 agent: &dyn EvalAgent,
95 suite: &EvalSuite,
96 ) -> traitclaw_core::Result<EvalReport> {
97 let mut results = Vec::new();
98 let mut total_score = 0.0;
99 let mut passed_count = 0;
100 let mut score_count = 0;
101
102 for case in suite.cases() {
103 let actual_output = agent.respond(&case.input).await?;
104
105 let keywords: Vec<&str> = case.expected_keywords.iter().map(String::as_str).collect();
106
107 let mut scores = std::collections::HashMap::new();
108 for metric in &self.metrics {
109 let s = metric.score(&case.input, &actual_output, &keywords).await;
110 scores.insert(metric.name().to_string(), s);
111 total_score += s;
112 score_count += 1;
113 }
114
115 if self.metrics.is_empty() {
117 let kw_score = score_keywords(&actual_output, &keywords);
118 scores.insert("keyword_match".to_string(), kw_score);
119 total_score += kw_score;
120 score_count += 1;
121 }
122
123 let all_pass = scores.values().all(|&s| s >= self.threshold);
124 if all_pass {
125 passed_count += 1;
126 }
127
128 results.push(TestResult {
129 case_id: case.id.clone(),
130 actual_output,
131 scores,
132 passed: all_pass,
133 });
134 }
135
136 let average_score = if score_count > 0 {
137 total_score / score_count as f64
138 } else {
139 0.0
140 };
141
142 Ok(EvalReport {
143 suite_name: suite.name().to_string(),
144 results,
145 average_score,
146 passed: passed_count,
147 total: suite.cases().len(),
148 })
149 }
150}
151
152impl Default for EvalRunner {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
/// Returns the fraction of `keywords` that occur in `output`, ignoring case.
///
/// Both the output and each keyword are lowercased before the substring
/// test, so `"Hello"` matches `"hello world"` and vice versa.
///
/// An empty keyword list yields `1.0`: there is nothing to miss.
fn score_keywords(output: &str, keywords: &[&str]) -> f64 {
    if keywords.is_empty() {
        return 1.0;
    }
    let haystack = output.to_lowercase();
    let matched = keywords
        .iter()
        // Lowercase the keyword too; comparing a raw keyword against the
        // lowercased output would never match keywords containing capitals.
        .filter(|kw| haystack.contains(kw.to_lowercase().as_str()))
        .count();
    matched as f64 / keywords.len() as f64
}
166
167pub struct SyncMetricAdapter<M: crate::Metric>(pub M);
171
172#[async_trait]
173impl<M: crate::Metric> AsyncMetric for SyncMetricAdapter<M> {
174 fn name(&self) -> &'static str {
175 self.0.name()
176 }
177
178 async fn score(&self, input: &str, actual_output: &str, expected_keywords: &[&str]) -> f64 {
179 self.0.score(input, actual_output, expected_keywords)
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{EvalSuite, TestCase};

    /// Agent that answers every prompt with `echo: <prompt>`.
    struct EchoAgent;

    #[async_trait]
    impl EvalAgent for EchoAgent {
        async fn respond(&self, input: &str) -> traitclaw_core::Result<String> {
            Ok(format!("echo: {input}"))
        }
    }

    /// Metric that always reports the same score under a chosen name.
    struct FixedMetric(f64, &'static str);

    #[async_trait]
    impl AsyncMetric for FixedMetric {
        fn name(&self) -> &'static str {
            self.1
        }

        async fn score(&self, _: &str, _: &str, _: &[&str]) -> f64 {
            self.0
        }
    }

    /// Async metric reporting the fraction of keywords found in the
    /// lowercased output.
    struct KeywordAsyncMetric;

    #[async_trait]
    impl AsyncMetric for KeywordAsyncMetric {
        fn name(&self) -> &'static str {
            "keyword"
        }

        async fn score(&self, _: &str, output: &str, kw: &[&str]) -> f64 {
            if kw.is_empty() {
                return 1.0;
            }
            let haystack = output.to_lowercase();
            let hits = kw.iter().filter(|&&k| haystack.contains(k)).count();
            hits as f64 / kw.len() as f64
        }
    }

    /// Three cases whose keyword appears in the echoed output all pass.
    #[tokio::test]
    async fn test_eval_runner_three_cases() {
        let cases = EvalSuite::new("suite")
            .add_case(TestCase::new("c1", "hello").expect_contains("echo"))
            .add_case(TestCase::new("c2", "world").expect_contains("echo"))
            .add_case(TestCase::new("c3", "foo").expect_contains("echo"));
        let evaluator = EvalRunner::new()
            .metric(Box::new(KeywordAsyncMetric))
            .threshold(0.8);

        let outcome = evaluator.run(&EchoAgent, &cases).await.unwrap();

        assert_eq!(outcome.results.len(), 3);
        assert_eq!(outcome.total, 3);
        assert_eq!(outcome.passed, 3);
    }

    /// A keyword missing from the output scores below the threshold.
    #[tokio::test]
    async fn test_eval_runner_threshold_fail() {
        let cases = EvalSuite::new("s")
            .add_case(TestCase::new("c1", "hello").expect_contains("xyzabc"));
        let evaluator = EvalRunner::new()
            .metric(Box::new(KeywordAsyncMetric))
            .threshold(0.8);

        let outcome = evaluator.run(&EchoAgent, &cases).await.unwrap();

        assert_eq!(outcome.passed, 0, "case with 0.0 keyword score should fail");
    }

    /// The average equals the fixed score when every case scores the same.
    #[tokio::test]
    async fn test_eval_runner_average_score() {
        let cases = EvalSuite::new("s")
            .add_case(TestCase::new("c1", "hello"))
            .add_case(TestCase::new("c2", "world"));
        let evaluator = EvalRunner::new()
            .metric(Box::new(FixedMetric(0.8, "m")))
            .threshold(0.7);

        let outcome = evaluator.run(&EchoAgent, &cases).await.unwrap();

        assert!((outcome.average_score - 0.8).abs() < 1e-6);
        assert_eq!(outcome.passed, 2);
    }

    /// The adapter exposes a synchronous metric through the async trait.
    #[tokio::test]
    async fn test_sync_metric_adapter() {
        let wrapped = SyncMetricAdapter(crate::KeywordMetric);
        let value = wrapped.score("in", "hello world", &["hello"]).await;
        assert!((value - 1.0).abs() < 1e-6);
    }

    /// Running an empty suite yields an empty report.
    #[tokio::test]
    async fn test_empty_suite_gives_zero_results() {
        let cases = EvalSuite::new("empty");
        let evaluator = EvalRunner::new().metric(Box::new(KeywordAsyncMetric));

        let outcome = evaluator.run(&EchoAgent, &cases).await.unwrap();

        assert_eq!(outcome.results.len(), 0);
        assert_eq!(outcome.total, 0);
    }
}