1use std::panic::{AssertUnwindSafe, catch_unwind};
7
8use regex::Regex;
9use swink_agent::prefix_chars;
10
11use crate::evaluator::Evaluator;
12use crate::score::Score;
13use crate::types::{EvalCase, EvalMetricResult, Invocation, ResponseCriteria};
14
/// Evaluator that checks an invocation's final response against the case's
/// [`ResponseCriteria`] (exact match, substring, regex pattern, or a custom
/// scoring closure).
///
/// It carries no state, so a single value can be reused for any number of cases.
#[derive(Debug, Clone, Copy, Default)]
pub struct ResponseMatcher;
19
20impl Evaluator for ResponseMatcher {
21 fn name(&self) -> &'static str {
22 "response"
23 }
24
25 fn evaluate(&self, case: &EvalCase, invocation: &Invocation) -> Option<EvalMetricResult> {
26 let criteria = case.expected_response.as_ref()?;
27 let actual = invocation.final_response.as_deref().unwrap_or("");
28
29 let (score, details) = match criteria {
30 ResponseCriteria::Exact { expected } => {
31 if actual == expected {
32 (Score::pass(), "exact match".to_string())
33 } else {
34 (
35 Score::fail(),
36 format!("expected exact match, got: {}", truncate(actual, 100)),
37 )
38 }
39 }
40 ResponseCriteria::Contains { substring } => {
41 if actual.contains(substring.as_str()) {
42 (Score::pass(), format!("contains \"{substring}\""))
43 } else {
44 (
45 Score::fail(),
46 format!(
47 "expected to contain \"{substring}\", got: {}",
48 truncate(actual, 100)
49 ),
50 )
51 }
52 }
53 ResponseCriteria::Regex { pattern } => match Regex::new(pattern) {
54 Ok(re) => {
55 if re.is_match(actual) {
56 (Score::pass(), format!("matches pattern /{pattern}/"))
57 } else {
58 (
59 Score::fail(),
60 format!("does not match /{pattern}/, got: {}", truncate(actual, 100)),
61 )
62 }
63 }
64 Err(e) => (Score::fail(), format!("invalid regex: {e}")),
65 },
66 ResponseCriteria::Custom(f) => match catch_unwind(AssertUnwindSafe(|| f(actual))) {
67 Ok(score) => {
68 let details = format!("custom score: {:.2}", score.value);
69 (score, details)
70 }
71 Err(payload) => {
72 let msg = payload
73 .downcast_ref::<&str>()
74 .copied()
75 .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
76 .unwrap_or("unknown panic");
77 (Score::fail(), format!("custom matcher panicked: {msg}"))
78 }
79 },
80 };
81
82 Some(EvalMetricResult {
83 evaluator_name: "response".to_string(),
84 score,
85 details: Some(details),
86 })
87 }
88}
89
90fn truncate(s: &str, max_len: usize) -> String {
92 if s.chars().count() <= max_len {
93 s.to_string()
94 } else {
95 format!("{}...", prefix_chars(s, max_len))
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;
    use std::time::Duration;

    use swink_agent::{AssistantMessage, ContentBlock, Cost, ModelSpec, StopReason, Usage};

    use crate::types::{EvalCase, Invocation, TurnRecord};

    /// Builds an otherwise-empty case whose only expectation is `criteria`.
    fn minimal_case_with_response(criteria: ResponseCriteria) -> EvalCase {
        EvalCase {
            id: "test".to_string(),
            name: "Test".to_string(),
            description: None,
            system_prompt: "test".to_string(),
            user_messages: vec!["test".to_string()],
            expected_trajectory: None,
            expected_response: Some(criteria),
            expected_assertion: None,
            expected_interactions: None,
            few_shot_examples: vec![],
            budget: None,
            evaluators: vec![],
            metadata: serde_json::Value::Null,
            attachments: vec![],
            session_id: None,
            expected_environment_state: None,
            expected_tool_intent: None,
            semantic_tool_selection: false,
            state_capture: None,
        }
    }

    /// Builds a single-turn invocation whose final response is `text`.
    fn invocation_with_response(text: &str) -> Invocation {
        Invocation {
            turns: vec![TurnRecord {
                turn_index: 0,
                assistant_message: AssistantMessage {
                    content: vec![ContentBlock::Text {
                        text: text.to_string(),
                    }],
                    provider: "test".to_string(),
                    model_id: "test-model".to_string(),
                    usage: Usage::default(),
                    cost: Cost::default(),
                    stop_reason: StopReason::Stop,
                    error_message: None,
                    error_kind: None,
                    timestamp: 0,
                    cache_hint: None,
                },
                tool_calls: vec![],
                tool_results: vec![],
                duration: Duration::from_millis(10),
            }],
            total_usage: Usage::default(),
            total_cost: Cost::default(),
            total_duration: Duration::from_millis(10),
            final_response: Some(text.to_string()),
            stop_reason: StopReason::Stop,
            model: ModelSpec::new("test", "test-model"),
        }
    }

    /// Evaluates `text` against `criteria` and unwraps the (always present) result.
    fn run(criteria: ResponseCriteria, text: &str) -> EvalMetricResult {
        let case = minimal_case_with_response(criteria);
        let invocation = invocation_with_response(text);
        ResponseMatcher
            .evaluate(&case, &invocation)
            .expect("a case with response criteria always yields a result")
    }

    /// Asserts that `result`'s score value equals `expected`.
    fn assert_score(result: &EvalMetricResult, expected: f64) {
        assert!(
            (result.score.value - expected).abs() < f64::EPSILON,
            "expected score {expected}, got {}",
            result.score.value
        );
    }

    /// Returns the details string the matcher attaches to every result.
    fn details_of(result: &EvalMetricResult) -> &str {
        result
            .details
            .as_deref()
            .expect("the response matcher always sets details")
    }

    #[test]
    fn truncate_short_string() {
        assert_eq!(truncate("hello", 10), "hello");
    }

    #[test]
    fn truncate_long_string() {
        let long = "a".repeat(200);
        let result = truncate(&long, 100);
        assert_eq!(result.len(), 103);
        assert!(result.ends_with("..."));
    }

    #[test]
    fn truncate_multibyte_string_is_utf8_safe() {
        let text = format!("{}🙂tail", "a".repeat(99));
        let result = truncate(&text, 100);
        assert_eq!(result, format!("{}🙂...", "a".repeat(99)));
    }

    #[test]
    fn no_expected_response_skips_evaluation() {
        let mut case = minimal_case_with_response(ResponseCriteria::Contains {
            substring: "anything".into(),
        });
        case.expected_response = None;
        let invocation = invocation_with_response("anything");
        assert!(ResponseMatcher.evaluate(&case, &invocation).is_none());
    }

    #[test]
    fn result_is_tagged_with_evaluator_name() {
        let result = run(ResponseCriteria::Exact { expected: "hi".into() }, "hi");
        assert_eq!(result.evaluator_name, ResponseMatcher.name());
        assert_eq!(result.evaluator_name, "response");
    }

    #[test]
    fn exact_match_passes() {
        let result = run(ResponseCriteria::Exact { expected: "hello".into() }, "hello");
        assert_score(&result, Score::pass().value);
        assert_eq!(details_of(&result), "exact match");
    }

    #[test]
    fn exact_mismatch_fails_and_reports_actual() {
        let result = run(ResponseCriteria::Exact { expected: "hello".into() }, "hello!");
        assert_score(&result, Score::fail().value);
        assert_eq!(details_of(&result), "expected exact match, got: hello!");
    }

    #[test]
    fn missing_final_response_is_treated_as_empty() {
        let case = minimal_case_with_response(ResponseCriteria::Exact { expected: "".into() });
        let mut invocation = invocation_with_response("ignored");
        invocation.final_response = None;
        let result = ResponseMatcher.evaluate(&case, &invocation).unwrap();
        assert_score(&result, Score::pass().value);
    }

    #[test]
    fn contains_checks_substring() {
        let hit = run(
            ResponseCriteria::Contains { substring: "world".into() },
            "hello world",
        );
        assert_score(&hit, Score::pass().value);
        assert_eq!(details_of(&hit), "contains \"world\"");

        let miss = run(
            ResponseCriteria::Contains { substring: "mars".into() },
            "hello world",
        );
        assert_score(&miss, Score::fail().value);
        assert_eq!(
            details_of(&miss),
            "expected to contain \"mars\", got: hello world"
        );
    }

    #[test]
    fn failure_details_truncate_long_responses() {
        let long = "x".repeat(150);
        let result = run(ResponseCriteria::Contains { substring: "y".into() }, &long);
        assert_eq!(
            details_of(&result),
            format!("expected to contain \"y\", got: {}...", "x".repeat(100))
        );
    }

    #[test]
    fn regex_checks_pattern() {
        let hit = run(ResponseCriteria::Regex { pattern: r"^\d{3}$".into() }, "123");
        assert_score(&hit, Score::pass().value);
        assert_eq!(details_of(&hit), r"matches pattern /^\d{3}$/");

        let miss = run(ResponseCriteria::Regex { pattern: r"^\d{3}$".into() }, "12a");
        assert_score(&miss, Score::fail().value);
        assert_eq!(details_of(&miss), r"does not match /^\d{3}$/, got: 12a");
    }

    #[test]
    fn invalid_regex_fails_without_panicking() {
        let result = run(ResponseCriteria::Regex { pattern: "(".into() }, "anything");
        assert_score(&result, Score::fail().value);
        let details = details_of(&result);
        assert!(
            details.starts_with("invalid regex: "),
            "expected invalid-regex report, got: {details}"
        );
    }

    #[test]
    fn custom_fn_score_is_passed_through() {
        let criteria = ResponseCriteria::Custom(Arc::new(|text: &str| {
            if text == "ok" {
                Score::pass()
            } else {
                Score::fail()
            }
        }));
        let result = run(criteria, "ok");
        assert_score(&result, Score::pass().value);
        assert_eq!(
            details_of(&result),
            format!("custom score: {:.2}", Score::pass().value)
        );
    }

    #[test]
    fn custom_fn_panic_caught_as_failure() {
        let criteria = ResponseCriteria::Custom(Arc::new(|_: &str| -> Score {
            panic!("deliberate test panic");
        }));
        let case = minimal_case_with_response(criteria);
        let invocation = invocation_with_response("anything");

        let result = ResponseMatcher.evaluate(&case, &invocation).unwrap();
        assert!((result.score.value - 0.0).abs() < f64::EPSILON);
        let details = result.details.unwrap();
        assert!(
            details.contains("panicked"),
            "expected panic mention, got: {details}"
        );
        assert!(
            details.contains("deliberate test panic"),
            "expected panic message, got: {details}"
        );
    }
}