1use std::path::PathBuf;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicUsize, Ordering};
11
12use serde::{Deserialize, Serialize};
13use tokio::sync::Semaphore;
14use tokio_util::sync::CancellationToken;
15use tracing::{debug, info, warn};
16
17use swink_agent::{
18 Agent, AssistantMessage, ContentBlock, Cost, ModelSpec, SessionState, StopReason, Usage,
19 UserMessage,
20};
21
22use crate::cache::{CacheKey, EvaluationDataStore, FingerprintContext};
23use crate::error::EvalError;
24use crate::evaluator::EvaluatorRegistry;
25use crate::score::{Score, Verdict};
26#[cfg(feature = "telemetry")]
27use crate::telemetry::{CaseSpan, EvalsTelemetry, RunSetSpan, RunSetSpanRef};
28use crate::trajectory::TrajectoryCollector;
29use crate::types::{
30 EvalCase, EvalCaseResult, EvalMetricResult, EvalSet, EvalSetResult, EvalSummary, Invocation,
31 TurnRecord,
32};
33
34struct FactoryCancellationGuard(CancellationToken);
35
36impl Drop for FactoryCancellationGuard {
37 fn drop(&mut self) {
38 self.0.cancel();
39 }
40}
41
42pub trait AgentFactory: Send + Sync {
44 fn create_agent(&self, case: &EvalCase) -> Result<(Agent, CancellationToken), EvalError>;
46
47 fn with_initial_session(&self, _state: &SessionState) {}
50
51 fn tool_set_hash(&self, _case: &EvalCase) -> Option<String> {
54 None
55 }
56
57 fn agent_model(&self, _case: &EvalCase) -> Option<String> {
59 None
60 }
61}
62
/// Summary statistics for one evaluator over repeated evaluation runs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunnerMetricSample {
    /// Name of the evaluator the samples belong to.
    pub evaluator_name: String,
    /// Raw score values that were summarised, in the order they were supplied.
    pub scores: Vec<f64>,
    /// Arithmetic mean of `scores` (0.0 when there are no scores).
    pub mean: f64,
    /// Population standard deviation of `scores` (0.0 with fewer than two scores).
    pub std_dev: f64,
}
77
78impl RunnerMetricSample {
79 fn from_samples(evaluator_name: String, scores: Vec<f64>) -> Self {
80 #[allow(clippy::cast_precision_loss)]
81 let n = scores.len() as f64;
82 let mean = if scores.is_empty() {
83 0.0
84 } else {
85 scores.iter().sum::<f64>() / n
86 };
87 let std_dev = if scores.len() <= 1 {
88 0.0
89 } else {
90 (scores.iter().map(|s| (s - mean).powi(2)).sum::<f64>() / n).sqrt()
91 };
92 Self {
93 evaluator_name,
94 scores,
95 mean,
96 std_dev,
97 }
98 }
99}
100
/// Runs eval cases against agents built by an [`AgentFactory`] and scores the
/// resulting invocations with an [`EvaluatorRegistry`]. Configure it with the
/// `with_*` builder methods.
pub struct EvalRunner {
    /// Evaluators applied to each invocation.
    registry: EvaluatorRegistry,
    /// Maximum number of cases `run_set` executes concurrently (always > 0).
    parallelism: usize,
    /// Number of evaluator passes per case in `run_set` (always > 0).
    num_runs: u32,
    /// Optional store used to reuse agent invocations keyed by case fingerprint.
    cache: Option<Arc<dyn EvaluationDataStore>>,
    /// Optional token checked before and during case execution.
    cancel: Option<CancellationToken>,
    /// Optional `SessionState` JSON file loaded at the start of each run.
    initial_session_file: Option<PathBuf>,
    /// Agent invocation attempt counter; the `Arc` lets per-case futures share it.
    agent_invocations: Arc<AtomicUsize>,
    /// Optional telemetry sink for run-set, case, and evaluator spans.
    #[cfg(feature = "telemetry")]
    telemetry: Option<Arc<EvalsTelemetry>>,
}
114
115impl EvalRunner {
116 #[must_use]
118 pub fn new(registry: EvaluatorRegistry) -> Self {
119 Self {
120 registry,
121 parallelism: 1,
122 num_runs: 1,
123 cache: None,
124 cancel: None,
125 initial_session_file: None,
126 agent_invocations: Arc::new(AtomicUsize::new(0)),
127 #[cfg(feature = "telemetry")]
128 telemetry: None,
129 }
130 }
131
132 #[must_use]
134 pub fn with_defaults() -> Self {
135 Self::new(EvaluatorRegistry::with_defaults())
136 }
137
138 #[must_use]
144 pub fn with_parallelism(mut self, n: usize) -> Self {
145 assert!(n > 0, "EvalRunner::with_parallelism: n must be > 0");
146 self.parallelism = n;
147 self
148 }
149
150 #[must_use]
156 pub fn with_num_runs(mut self, n: u32) -> Self {
157 assert!(n > 0, "EvalRunner::with_num_runs: n must be > 0");
158 self.num_runs = n;
159 self
160 }
161
162 #[must_use]
165 pub fn with_cache(mut self, store: Arc<dyn EvaluationDataStore>) -> Self {
166 self.cache = Some(store);
167 self
168 }
169
170 #[must_use]
172 pub fn with_cancellation(mut self, token: CancellationToken) -> Self {
173 self.cancel = Some(token);
174 self
175 }
176
177 #[must_use]
181 pub fn with_initial_session_file(mut self, path: PathBuf) -> Self {
182 self.initial_session_file = Some(path);
183 self
184 }
185
186 #[cfg(feature = "telemetry")]
190 #[must_use]
191 pub fn with_telemetry(mut self, telemetry: Arc<EvalsTelemetry>) -> Self {
192 self.telemetry = Some(telemetry);
193 self
194 }
195
196 #[must_use]
198 pub fn agent_invocation_count(&self) -> usize {
199 self.agent_invocations.load(Ordering::SeqCst)
200 }
201
202 pub fn reset_agent_invocation_count(&self) {
204 self.agent_invocations.store(0, Ordering::SeqCst);
205 }
206
207 pub async fn run_case(
209 &self,
210 case: &EvalCase,
211 factory: &dyn AgentFactory,
212 ) -> Result<EvalCaseResult, EvalError> {
213 info!(case_id = %case.id, case_name = %case.name, "running eval case");
214 let initial_session = self.load_initial_session()?;
215 if let Some(state) = &initial_session {
216 factory.with_initial_session(state);
217 }
218 let invocation =
219 invoke_agent_impl(case, factory, self.cancel.as_ref(), &self.agent_invocations).await?;
220 let metric_results = self.registry.evaluate(case, &invocation);
221 Ok(scored_case_result(case, invocation, metric_results))
222 }
223
224 #[allow(clippy::too_many_lines)]
226 pub async fn run_set(
227 &self,
228 eval_set: &EvalSet,
229 factory: &dyn AgentFactory,
230 ) -> Result<EvalSetResult, EvalError> {
231 info!(
232 set_id = %eval_set.id, cases = eval_set.cases.len(),
233 parallelism = self.parallelism, num_runs = self.num_runs,
234 cache = self.cache.is_some(), "running eval set"
235 );
236
237 #[cfg(feature = "telemetry")]
239 let run_set_span: Option<RunSetSpan> = self
240 .telemetry
241 .as_ref()
242 .map(|t| t.start_run_set_span(eval_set));
243
244 let initial_session = self.load_initial_session()?;
245 let initial_session_json = initial_session
246 .as_ref()
247 .map(serde_json::to_value)
248 .transpose()
249 .map_err(EvalError::from)?;
250 if let Some(state) = &initial_session {
251 factory.with_initial_session(state);
252 }
253
254 let semaphore = Arc::new(Semaphore::new(self.parallelism));
255 let eval_set_id = eval_set.id.clone();
256
257 let mut futures_vec = Vec::with_capacity(eval_set.cases.len());
260 for (index, case) in eval_set.cases.iter().enumerate() {
261 let sem = Arc::clone(&semaphore);
262 let cache = self.cache.clone();
263 let registry = &self.registry;
264 let num_runs = self.num_runs;
265 let cancel = self.cancel.clone();
266 let initial_session_value = initial_session_json.clone();
267 let agent_invocations = Arc::clone(&self.agent_invocations);
268 let eval_set_id = eval_set_id.clone();
269 #[cfg(feature = "telemetry")]
270 let telemetry = self.telemetry.clone();
271 #[cfg(feature = "telemetry")]
272 let run_set_context = run_set_span.as_ref().map(|s| RunSetSpanRef {
273 context: s.context().clone(),
274 set_id: eval_set_id.clone(),
275 });
276
277 futures_vec.push(async move {
278 if let Some(tok) = &cancel
279 && tok.is_cancelled()
280 {
281 return (index, cancelled_case_result(case));
282 }
283 let permit = match sem.acquire_owned().await {
284 Ok(p) => p,
285 Err(_) => return (index, cancelled_case_result(case)),
286 };
287 if let Some(tok) = &cancel
288 && tok.is_cancelled()
289 {
290 drop(permit);
291 return (index, cancelled_case_result(case));
292 }
293
294 #[cfg(feature = "telemetry")]
295 let case_span: Option<CaseSpan> = match (&telemetry, &run_set_context) {
296 (Some(t), Some(parent)) => Some(t.start_case_span_raw(parent, case)),
297 _ => None,
298 };
299 #[cfg(feature = "telemetry")]
300 let case_start = std::time::Instant::now();
301
302 let result = execute_case(
303 case,
304 factory,
305 &eval_set_id,
306 cache.as_deref(),
307 registry,
308 num_runs,
309 cancel.as_ref(),
310 initial_session_value.as_ref(),
311 &agent_invocations,
312 #[cfg(feature = "telemetry")]
313 telemetry.as_deref(),
314 #[cfg(feature = "telemetry")]
315 case_span.as_ref(),
316 )
317 .await
318 .unwrap_or_else(|e| error_case_result(case, &e));
319
320 #[cfg(feature = "telemetry")]
321 if let Some(span) = case_span {
322 span.end(&result, case_start.elapsed());
323 }
324
325 drop(permit);
326 (index, result)
327 });
328 }
329
330 let results: Vec<(usize, EvalCaseResult)> = futures::future::join_all(futures_vec).await;
331 let mut ordered: Vec<Option<EvalCaseResult>> =
332 (0..eval_set.cases.len()).map(|_| None).collect();
333 for (index, result) in results {
334 ordered[index] = Some(result);
335 }
336 let case_results: Vec<EvalCaseResult> = ordered
337 .into_iter()
338 .map(|slot| slot.expect("every case produces a result"))
339 .collect();
340
341 let mut total_cost = Cost::default();
342 let mut total_usage = Usage::default();
343 let mut total_duration = std::time::Duration::ZERO;
344 let mut passed = 0usize;
345 let mut failed = 0usize;
346 for result in &case_results {
347 total_cost += result.invocation.total_cost.clone();
348 total_usage += result.invocation.total_usage.clone();
349 total_duration += result.invocation.total_duration;
350 if result.verdict.is_pass() {
351 passed += 1;
352 } else {
353 failed += 1;
354 }
355 }
356 let summary = EvalSummary {
357 total_cases: eval_set.cases.len(),
358 passed,
359 failed,
360 total_cost,
361 total_usage,
362 total_duration,
363 };
364 info!(
365 set_id = %eval_set.id, passed = summary.passed,
366 failed = summary.failed, total = summary.total_cases,
367 "eval set complete"
368 );
369
370 #[cfg(feature = "telemetry")]
371 if let Some(span) = run_set_span {
372 span.end(summary.passed, summary.failed);
373 }
374
375 Ok(EvalSetResult {
376 eval_set_id: eval_set.id.clone(),
377 case_results,
378 summary,
379 timestamp: swink_agent::now_timestamp(),
380 })
381 }
382
383 fn load_initial_session(&self) -> Result<Option<SessionState>, EvalError> {
384 let Some(path) = &self.initial_session_file else {
385 return Ok(None);
386 };
387 let bytes = std::fs::read(path).map_err(|err| {
388 EvalError::invalid_case(format!(
389 "initial_session_file `{}` unreadable: {err}",
390 path.display()
391 ))
392 })?;
393 let state: SessionState = serde_json::from_slice(&bytes).map_err(|err| {
394 EvalError::invalid_case(format!(
395 "initial_session_file `{}` is not valid SessionState JSON: {err}",
396 path.display()
397 ))
398 })?;
399 Ok(Some(state))
400 }
401}
402
403#[allow(clippy::too_many_arguments)]
404async fn execute_case(
405 case: &EvalCase,
406 factory: &dyn AgentFactory,
407 eval_set_id: &str,
408 cache: Option<&(dyn EvaluationDataStore + 'static)>,
409 registry: &EvaluatorRegistry,
410 num_runs: u32,
411 cancel: Option<&CancellationToken>,
412 initial_session_json: Option<&serde_json::Value>,
413 agent_invocations: &AtomicUsize,
414 #[cfg(feature = "telemetry")] telemetry: Option<&EvalsTelemetry>,
415 #[cfg(feature = "telemetry")] case_span: Option<&CaseSpan>,
416) -> Result<EvalCaseResult, EvalError> {
417 info!(case_id = %case.id, case_name = %case.name, "running eval case");
418
419 let fingerprint = case.content_fingerprint();
420 let fp_ctx = FingerprintContext {
421 initial_session: initial_session_json.cloned(),
422 tool_set_hash: factory.tool_set_hash(case),
423 agent_model: factory.agent_model(case),
424 };
425 let cache_key = CacheKey::from_fingerprint(&fingerprint, &fp_ctx);
426
427 let cached = match cache {
428 Some(store) => match store.get(eval_set_id, &case.id, &cache_key) {
429 Ok(v) => v,
430 Err(err) => {
431 warn!(case_id = %case.id, error = %err, "cache read failed");
432 None
433 }
434 },
435 None => None,
436 };
437
438 let invocation = if let Some(inv) = cached {
439 debug!(case_id = %case.id, "cache hit");
440 inv
441 } else {
442 let inv = invoke_agent_impl(case, factory, cancel, agent_invocations).await?;
443 if let Some(store) = cache
444 && let Err(err) = store.put(eval_set_id, &case.id, &cache_key, &inv)
445 {
446 warn!(case_id = %case.id, error = %err, "cache write failed");
447 }
448 inv
449 };
450
451 let metric_results = dispatch_evaluators(
452 registry,
453 case,
454 &invocation,
455 num_runs,
456 cancel,
457 #[cfg(feature = "telemetry")]
458 telemetry,
459 #[cfg(feature = "telemetry")]
460 case_span,
461 );
462 Ok(scored_case_result(case, invocation, metric_results))
463}
464
465fn scored_case_result(
466 case: &EvalCase,
467 invocation: Invocation,
468 mut metric_results: Vec<EvalMetricResult>,
469) -> EvalCaseResult {
470 if metric_results.is_empty() {
471 metric_results.push(no_applicable_evaluators_metric());
472 }
473 let verdict = if metric_results.iter().all(|r| r.score.verdict().is_pass()) {
474 Verdict::Pass
475 } else {
476 Verdict::Fail
477 };
478 EvalCaseResult {
479 case_id: case.id.clone(),
480 invocation,
481 metric_results,
482 verdict,
483 }
484}
485
486fn no_applicable_evaluators_metric() -> EvalMetricResult {
487 EvalMetricResult {
488 evaluator_name: "no_applicable_evaluators".to_string(),
489 score: Score::fail(),
490 details: Some(
491 "no evaluator produced a metric; configure an applicable evaluator or expected criteria"
492 .to_string(),
493 ),
494 }
495}
496
497async fn invoke_agent_impl(
498 case: &EvalCase,
499 factory: &dyn AgentFactory,
500 cancel: Option<&CancellationToken>,
501 agent_invocations: &AtomicUsize,
502) -> Result<Invocation, EvalError> {
503 agent_invocations.fetch_add(1, Ordering::SeqCst);
504 let (mut agent, factory_cancel) = factory.create_agent(case)?;
505 let _factory_cancel = FactoryCancellationGuard(factory_cancel);
506 let messages: Vec<_> = case
507 .user_messages
508 .iter()
509 .map(|text| {
510 swink_agent::AgentMessage::Llm(swink_agent::LlmMessage::User(UserMessage {
511 content: vec![ContentBlock::Text { text: text.clone() }],
512 timestamp: swink_agent::now_timestamp(),
513 cache_hint: None,
514 }))
515 })
516 .collect();
517 let stream = agent.prompt_stream(messages)?;
518 let invocation = if let Some(tok) = cancel {
519 tokio::select! {
520 biased;
521 () = tok.cancelled() => {
522 return Ok(cancellation_placeholder_invocation());
525 }
526 inv = TrajectoryCollector::collect_from_stream(stream) => inv,
527 }
528 } else {
529 TrajectoryCollector::collect_from_stream(stream).await
530 };
531 Ok(invocation)
532}
533
534#[allow(clippy::too_many_arguments)]
535fn dispatch_evaluators(
536 registry: &EvaluatorRegistry,
537 case: &EvalCase,
538 invocation: &Invocation,
539 num_runs: u32,
540 cancel: Option<&CancellationToken>,
541 #[cfg(feature = "telemetry")] telemetry: Option<&EvalsTelemetry>,
542 #[cfg(feature = "telemetry")] case_span: Option<&CaseSpan>,
543) -> Vec<EvalMetricResult> {
544 debug_assert!(num_runs > 0);
545 if num_runs == 1 {
546 return run_registry_once(
547 registry,
548 case,
549 invocation,
550 #[cfg(feature = "telemetry")]
551 telemetry,
552 #[cfg(feature = "telemetry")]
553 case_span,
554 );
555 }
556
557 let mut per_evaluator: std::collections::BTreeMap<String, Vec<EvalMetricResult>> =
558 std::collections::BTreeMap::new();
559 let mut cancelled = false;
560 for run_idx in 0..num_runs {
561 if let Some(tok) = cancel
562 && tok.is_cancelled()
563 {
564 cancelled = true;
565 break;
566 }
567 #[cfg(feature = "telemetry")]
571 let iteration_telemetry = if run_idx == 0 { telemetry } else { None };
572 #[cfg(feature = "telemetry")]
573 let iteration_case_span = if run_idx == 0 { case_span } else { None };
574 let iteration = run_registry_once(
575 registry,
576 case,
577 invocation,
578 #[cfg(feature = "telemetry")]
579 iteration_telemetry,
580 #[cfg(feature = "telemetry")]
581 iteration_case_span,
582 );
583 for metric in iteration {
584 per_evaluator
585 .entry(metric.evaluator_name.clone())
586 .or_default()
587 .push(metric);
588 }
589 debug!(case_id = %case.id, run = run_idx + 1, "num_runs sample recorded");
590 }
591
592 let mut aggregated: Vec<EvalMetricResult> = per_evaluator
593 .into_iter()
594 .map(|(name, samples)| {
595 let scores: Vec<f64> = samples.iter().map(|m| m.score.value).collect();
596 let threshold = samples.first().map_or(0.5, |m| m.score.threshold);
597 let sample = RunnerMetricSample::from_samples(name.clone(), scores);
598 let mut detail_lines = vec![format!(
599 "num_runs={} mean={:.4} std_dev={:.4}",
600 sample.scores.len(),
601 sample.mean,
602 sample.std_dev
603 )];
604 let prior: Vec<_> = samples.iter().filter_map(|m| m.details.clone()).collect();
605 if !prior.is_empty() {
606 detail_lines.push(prior.join(" | "));
607 }
608 EvalMetricResult {
609 evaluator_name: name,
610 score: Score::new(sample.mean, threshold),
611 details: Some(detail_lines.join(" :: ")),
612 }
613 })
614 .collect();
615
616 if cancelled {
617 aggregated.push(cancelled_metric_result(
618 "runner cancellation observed during multi-run evaluator dispatch",
619 ));
620 }
621
622 aggregated
623}
624
625fn run_registry_once(
629 registry: &EvaluatorRegistry,
630 case: &EvalCase,
631 invocation: &Invocation,
632 #[cfg(feature = "telemetry")] telemetry: Option<&EvalsTelemetry>,
633 #[cfg(feature = "telemetry")] case_span: Option<&CaseSpan>,
634) -> Vec<EvalMetricResult> {
635 #[cfg(feature = "telemetry")]
636 if let (Some(t), Some(parent)) = (telemetry, case_span) {
637 return registry.evaluate_instrumented(case, invocation, |name, run| {
638 let span = t.start_evaluator_span(parent, name);
639 let outcome = run();
640 match outcome.as_ref() {
641 Some(metric) => span.end(metric),
642 None => span.end_inapplicable(name),
643 }
644 outcome
645 });
646 }
647 registry.evaluate(case, invocation)
648}
649
650fn cancelled_case_result(case: &EvalCase) -> EvalCaseResult {
651 EvalCaseResult {
652 case_id: case.id.clone(),
653 invocation: error_invocation(None),
654 metric_results: vec![cancelled_metric_result(
655 "runner cancellation observed before case completion",
656 )],
657 verdict: Verdict::Fail,
658 }
659}
660
661fn cancelled_metric_result(details: &str) -> EvalMetricResult {
662 EvalMetricResult {
663 evaluator_name: "cancelled".to_string(),
664 score: Score::fail(),
665 details: Some(details.to_string()),
666 }
667}
668
669fn error_case_result(case: &EvalCase, err: &EvalError) -> EvalCaseResult {
670 warn!(case_id = %case.id, error = %err, "eval case failed with error");
671 EvalCaseResult {
672 case_id: case.id.clone(),
673 invocation: error_invocation(Some(err.to_string())),
674 metric_results: vec![EvalMetricResult {
675 evaluator_name: "error".to_string(),
676 score: Score::fail(),
677 details: Some(err.to_string()),
678 }],
679 verdict: Verdict::Fail,
680 }
681}
682
683fn cancellation_placeholder_invocation() -> Invocation {
684 error_invocation(None)
685}
686
687fn error_invocation(error_message: Option<String>) -> Invocation {
688 let turns = error_message
689 .map(|msg| {
690 vec![TurnRecord {
691 turn_index: 0,
692 assistant_message: AssistantMessage {
693 content: vec![],
694 provider: String::new(),
695 model_id: String::new(),
696 usage: Usage::default(),
697 cost: Cost::default(),
698 stop_reason: StopReason::Error,
699 error_message: Some(msg),
700 error_kind: None,
701 timestamp: swink_agent::now_timestamp(),
702 cache_hint: None,
703 },
704 tool_calls: vec![],
705 tool_results: vec![],
706 duration: std::time::Duration::ZERO,
707 }]
708 })
709 .unwrap_or_default();
710 Invocation {
711 turns,
712 total_usage: Usage::default(),
713 total_cost: Cost::default(),
714 total_duration: std::time::Duration::ZERO,
715 final_response: None,
716 stop_reason: StopReason::Error,
717 model: ModelSpec::new("unknown", "unknown"),
718 }
719}