swink_agent_eval/
testing.rs1#![forbid(unsafe_code)]
11
12use std::sync::Mutex;
13use std::time::Duration;
14
15use crate::judge::{JudgeClient, JudgeError, JudgeFuture, JudgeVerdict};
16
/// Scriptable [`JudgeClient`] test double.
///
/// Replays a queue of pre-recorded outcomes (verdicts or errors) in order and
/// falls back to a fixed "tail" outcome once the queue is empty. Every call is
/// counted and can be inspected with [`MockJudge::call_count`].
pub struct MockJudge {
    // Script state lives behind a `Mutex` so `judge(&self, ..)` can advance it
    // through a shared reference.
    inner: Mutex<MockInner>,
}
41
/// Mutable state of a [`MockJudge`], guarded by its mutex.
struct MockInner {
    /// Scripted outcomes still to be returned, consumed front first.
    outcomes: Vec<MockOutcome>,
    /// Outcome replayed (as a fresh copy) on every call once `outcomes` is empty.
    tail: MockOutcome,
    /// Total number of `judge` invocations observed so far.
    calls: usize,
}
50
/// One scripted result of a mock judge call.
enum MockOutcome {
    /// Yielded to the caller as `Ok(verdict)`.
    Verdict(JudgeVerdict),
    /// Yielded to the caller as `Err(error)`.
    Error(JudgeError),
}
55
56impl MockOutcome {
57 fn clone_boxed(&self) -> Self {
58 match self {
59 Self::Verdict(v) => Self::Verdict(v.clone()),
60 Self::Error(e) => Self::Error(clone_judge_error(e)),
61 }
62 }
63}
64
65fn clone_judge_error(err: &JudgeError) -> JudgeError {
66 match err {
67 JudgeError::Transport(s) => JudgeError::Transport(s.clone()),
68 JudgeError::Timeout => JudgeError::Timeout,
69 JudgeError::MalformedResponse(s) => JudgeError::MalformedResponse(s.clone()),
70 JudgeError::Other(s) => JudgeError::Other(s.clone()),
71 }
72}
73
74impl MockJudge {
75 #[must_use]
78 pub fn with_verdicts(verdicts: Vec<JudgeVerdict>) -> Self {
79 let outcomes = verdicts.into_iter().map(MockOutcome::Verdict).collect();
80 Self::new(
81 outcomes,
82 MockOutcome::Error(JudgeError::Other(
83 "MockJudge outcome queue exhausted".into(),
84 )),
85 )
86 }
87
88 #[must_use]
92 pub const fn always_err(err: JudgeError) -> Self {
93 Self::new(Vec::new(), MockOutcome::Error(err))
94 }
95
96 #[must_use]
98 pub fn always_pass() -> Self {
99 Self::new(
100 Vec::new(),
101 MockOutcome::Verdict(JudgeVerdict {
102 score: 1.0,
103 pass: true,
104 reason: Some("mock pass".into()),
105 label: None,
106 }),
107 )
108 }
109
110 #[must_use]
112 pub fn always_fail() -> Self {
113 Self::new(
114 Vec::new(),
115 MockOutcome::Verdict(JudgeVerdict {
116 score: 0.0,
117 pass: false,
118 reason: Some("mock fail".into()),
119 label: None,
120 }),
121 )
122 }
123
124 #[must_use]
128 pub fn mixed_sequence(sequence: Vec<Result<JudgeVerdict, JudgeError>>) -> Self {
129 let outcomes = sequence
130 .into_iter()
131 .map(|r| match r {
132 Ok(v) => MockOutcome::Verdict(v),
133 Err(e) => MockOutcome::Error(e),
134 })
135 .collect();
136 Self::new(
137 outcomes,
138 MockOutcome::Error(JudgeError::Other(
139 "MockJudge outcome queue exhausted".into(),
140 )),
141 )
142 }
143
144 const fn new(outcomes: Vec<MockOutcome>, tail: MockOutcome) -> Self {
145 Self {
146 inner: Mutex::new(MockInner {
147 outcomes,
148 tail,
149 calls: 0,
150 }),
151 }
152 }
153
154 #[must_use]
156 pub fn call_count(&self) -> usize {
157 self.inner.lock().map(|g| g.calls).unwrap_or_default()
158 }
159}
160
161impl JudgeClient for MockJudge {
162 fn judge<'a>(&'a self, _prompt: &'a str) -> JudgeFuture<'a> {
163 Box::pin(async move {
164 let outcome = {
165 let mut guard = self.inner.lock().expect("MockJudge mutex poisoned");
166 guard.calls += 1;
167 if guard.outcomes.is_empty() {
168 guard.tail.clone_boxed()
169 } else {
170 guard.outcomes.remove(0)
171 }
172 };
173 match outcome {
174 MockOutcome::Verdict(v) => Ok(v),
175 MockOutcome::Error(e) => Err(e),
176 }
177 })
178 }
179}
180
/// [`JudgeClient`] test double that always passes, but only after sleeping for
/// a fixed delay. Useful for exercising timeouts and concurrency limits.
#[derive(Debug, Clone)]
pub struct SlowMockJudge {
    // Delay awaited before each verdict is returned.
    sleep: Duration,
}
201
202impl SlowMockJudge {
203 #[must_use]
206 pub const fn new(sleep: Duration) -> Self {
207 Self { sleep }
208 }
209}
210
211impl JudgeClient for SlowMockJudge {
212 fn judge<'a>(&'a self, _prompt: &'a str) -> JudgeFuture<'a> {
213 Box::pin(async move {
214 tokio::time::sleep(self.sleep).await;
215 Ok(JudgeVerdict {
216 score: 1.0,
217 pass: true,
218 reason: Some("slow pass".into()),
219 label: None,
220 })
221 })
222 }
223}
224
/// [`JudgeClient`] test double whose judge future panics when polled. Useful
/// for verifying that callers isolate panicking judges.
#[derive(Debug, Clone)]
pub struct PanickingMockJudge {
    // Panic payload raised by every judge call.
    message: &'static str,
}
245
246impl PanickingMockJudge {
247 #[must_use]
250 pub const fn new() -> Self {
251 Self {
252 message: "judge panic",
253 }
254 }
255
256 #[must_use]
258 pub const fn with_message(message: &'static str) -> Self {
259 Self { message }
260 }
261}
262
263impl Default for PanickingMockJudge {
264 fn default() -> Self {
265 Self::new()
266 }
267}
268
269impl JudgeClient for PanickingMockJudge {
270 fn judge<'a>(&'a self, _prompt: &'a str) -> JudgeFuture<'a> {
271 Box::pin(async move { panic!("{}", self.message) })
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal verdict with a score that matches `pass`.
    fn verdict(pass: bool) -> JudgeVerdict {
        JudgeVerdict {
            score: if pass { 1.0 } else { 0.0 },
            pass,
            reason: None,
            label: None,
        }
    }

    #[tokio::test]
    async fn with_verdicts_replays_in_order() {
        let judge = MockJudge::with_verdicts(vec![verdict(true), verdict(false)]);
        let v1 = judge.judge("a").await.unwrap();
        assert!(v1.pass);
        let v2 = judge.judge("b").await.unwrap();
        assert!(!v2.pass);
    }

    #[tokio::test]
    async fn with_verdicts_tail_errors_when_exhausted() {
        let judge = MockJudge::with_verdicts(vec![verdict(true)]);
        let _ = judge.judge("a").await.unwrap();
        let err = judge.judge("b").await.unwrap_err();
        match err {
            JudgeError::Other(msg) => assert!(msg.contains("exhausted")),
            other => panic!("expected Other, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn always_err_returns_configured_variant() {
        let judge = MockJudge::always_err(JudgeError::Timeout);
        for _ in 0..3 {
            match judge.judge("x").await {
                Err(JudgeError::Timeout) => {}
                other => panic!("expected Timeout, got {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn always_pass_fail_return_canned_verdicts() {
        let pass = MockJudge::always_pass();
        let p = pass.judge("x").await.unwrap();
        assert!(p.pass);
        let fail = MockJudge::always_fail();
        let f = fail.judge("x").await.unwrap();
        assert!(!f.pass);
    }

    #[tokio::test]
    async fn mixed_sequence_preserves_order() {
        let judge = MockJudge::mixed_sequence(vec![
            Ok(verdict(true)),
            Err(JudgeError::MalformedResponse("bad".into())),
            Ok(verdict(false)),
        ]);
        assert!(judge.judge("a").await.unwrap().pass);
        match judge.judge("b").await.unwrap_err() {
            JudgeError::MalformedResponse(m) => assert_eq!(m, "bad"),
            other => panic!("expected MalformedResponse, got {other:?}"),
        }
        assert!(!judge.judge("c").await.unwrap().pass);
    }

    #[tokio::test]
    async fn mixed_sequence_tail_errors_when_exhausted() {
        let judge = MockJudge::mixed_sequence(vec![Err(JudgeError::Timeout)]);
        assert!(matches!(judge.judge("a").await, Err(JudgeError::Timeout)));
        match judge.judge("b").await.unwrap_err() {
            JudgeError::Other(msg) => assert!(msg.contains("exhausted")),
            other => panic!("expected Other, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn call_count_tracks_invocations() {
        let judge = MockJudge::always_pass();
        assert_eq!(judge.call_count(), 0);
        let _ = judge.judge("a").await;
        let _ = judge.judge("b").await;
        assert_eq!(judge.call_count(), 2);
    }

    #[tokio::test]
    async fn call_count_includes_failed_calls() {
        // An empty queue goes straight to the "exhausted" tail error.
        let judge = MockJudge::with_verdicts(Vec::new());
        assert!(judge.judge("a").await.is_err());
        assert!(judge.judge("b").await.is_err());
        assert_eq!(judge.call_count(), 2);
    }

    #[tokio::test]
    async fn slow_mock_judge_eventually_passes() {
        let judge = SlowMockJudge::new(std::time::Duration::from_millis(1));
        let v = judge.judge("x").await.unwrap();
        assert!(v.pass);
    }

    #[tokio::test]
    #[should_panic(expected = "judge panic")]
    async fn panicking_mock_judge_panics_with_default_message() {
        let _ = PanickingMockJudge::new().judge("x").await;
    }

    #[tokio::test]
    #[should_panic(expected = "custom boom")]
    async fn panicking_mock_judge_panics_with_custom_message() {
        let _ = PanickingMockJudge::with_message("custom boom")
            .judge("x")
            .await;
    }

    #[tokio::test]
    async fn dyn_dispatch_compiles() {
        use std::sync::Arc;
        let judge: Arc<dyn JudgeClient> = Arc::new(MockJudge::always_pass());
        let _ = judge.judge("prompt").await.unwrap();
    }
}