swink_agent_eval/evaluators/code/
extractor.rs1use regex::Regex;
9use std::sync::Arc;
10
11use crate::judge::JudgeClient;
12
13#[derive(Clone)]
15pub enum CodeExtractorStrategy {
16 MarkdownFence {
19 language: Option<String>,
21 },
22 Regex { pattern: Regex },
24 Llm {
30 prompt: String,
32 judge: Arc<dyn JudgeClient>,
34 },
35}
36
37impl std::fmt::Debug for CodeExtractorStrategy {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 match self {
40 Self::MarkdownFence { language } => f
41 .debug_struct("MarkdownFence")
42 .field("language", language)
43 .finish(),
44 Self::Regex { pattern } => f
45 .debug_struct("Regex")
46 .field("pattern", &pattern.as_str())
47 .finish(),
48 Self::Llm { prompt, .. } => f
49 .debug_struct("Llm")
50 .field("prompt_len", &prompt.len())
51 .finish(),
52 }
53 }
54}
55
56pub struct CodeExtractor {
63 strategy: CodeExtractorStrategy,
64}
65
66impl CodeExtractor {
67 #[must_use]
69 pub const fn new(strategy: CodeExtractorStrategy) -> Self {
70 Self { strategy }
71 }
72
73 #[must_use]
75 pub const fn markdown_fence() -> Self {
76 Self::new(CodeExtractorStrategy::MarkdownFence { language: None })
77 }
78
79 pub async fn extract(&self, response: &str) -> Option<String> {
81 match &self.strategy {
82 CodeExtractorStrategy::MarkdownFence { language } => {
83 extract_markdown_fence(response, language.as_deref())
84 }
85 CodeExtractorStrategy::Regex { pattern } => {
86 pattern.captures(response).and_then(|caps| {
87 caps.get(1)
88 .or_else(|| caps.get(0))
89 .map(|m| m.as_str().to_string())
90 })
91 }
92 CodeExtractorStrategy::Llm { prompt, judge } => {
93 let rendered = format!("{prompt}\n\n---\n{response}");
94 match judge.judge(&rendered).await {
95 Ok(verdict) => verdict.reason,
96 Err(_) => None,
97 }
98 }
99 }
100 }
101}
102
/// Returns the body of the first fenced Markdown code block in `response`.
///
/// Fences follow CommonMark rules:
/// * An opener is a run of at least three backticks or tildes. Leading
///   whitespace is tolerated.
/// * A closer is a run of the same character, at least as long as the
///   opener, followed only by whitespace.
/// * The language tag is the first word of the opener's info string, split on
///   whitespace or `,`. So `rust,ignore` or `python title="x.py"` still match
///   `rust` and `python`.
///
/// When `required_language` is `Some`, blocks whose tag does not match it
/// (ASCII case-insensitively) are skipped in their entirety, so their contents
/// are never mistaken for fences. A matching block with no closing fence, such
/// as a truncated response, yields everything after the opener.
///
/// Trailing newlines are stripped from the returned body. Returns `None` when
/// no suitable block exists.
fn extract_markdown_fence(response: &str, required_language: Option<&str>) -> Option<String> {
    let mut lines = response.lines();
    while let Some(line) = lines.next() {
        let Some((marker, run, language)) = fence_opener(line) else {
            continue;
        };
        let wanted = match required_language {
            Some(expected) => language.eq_ignore_ascii_case(expected),
            None => true,
        };
        let mut body = String::new();
        // Consume the block through its closing fence even when it is not the
        // one we want, so its contents are never re-scanned as openers.
        for inner in lines.by_ref() {
            if closes_fence(inner, marker, run) {
                break;
            }
            if wanted {
                body.push_str(inner);
                body.push('\n');
            }
        }
        if wanted {
            // Also reached for an unterminated fence: keep what was collected.
            return Some(body.trim_end_matches('\n').to_string());
        }
    }
    None
}

/// Parses `line` as an opening code fence.
///
/// Returns the fence character, the length of its run, and the language tag
/// (the first word of the info string, possibly empty).
fn fence_opener(line: &str) -> Option<(char, usize, &str)> {
    let trimmed = line.trim_start();
    let marker = trimmed.chars().next().filter(|&c| matches!(c, '`' | '~'))?;
    // `marker` is ASCII, so the byte difference equals the run length and
    // `run` is a valid slice boundary.
    let run = trimmed.len() - trimmed.trim_start_matches(marker).len();
    if run < 3 {
        return None;
    }
    let info = trimmed[run..].trim();
    // CommonMark: a backtick fence's info string may not contain backticks.
    // Such a line (e.g. ```x```) is inline code, not a fence.
    if marker == '`' && info.contains('`') {
        return None;
    }
    let language = info
        .split(|c: char| c.is_whitespace() || c == ',')
        .next()
        .unwrap_or("");
    Some((marker, run, language))
}

/// Reports whether `line` closes a fence opened with `run` copies of `marker`.
///
/// A closer uses the same character, has at least as many of them, and has
/// nothing but whitespace after the run.
fn closes_fence(line: &str, marker: char, run: usize) -> bool {
    let trimmed = line.trim_start();
    let rest = trimmed.trim_start_matches(marker);
    trimmed.len() - rest.len() >= run && rest.trim().is_empty()
}
130
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn markdown_fence_extracts_first_block() {
        let response =
            "Here is the code:\n\n```rust\nfn add(a: i32, b: i32) -> i32 { a + b }\n```\n";
        let out = extract_markdown_fence(response, Some("rust"));
        assert_eq!(
            out.as_deref(),
            Some("fn add(a: i32, b: i32) -> i32 { a + b }")
        );
    }

    #[test]
    fn markdown_fence_skips_non_matching_language() {
        let response = "```python\nprint('hi')\n```\n\n```rust\nfn a() {}\n```\n";
        let out = extract_markdown_fence(response, Some("rust"));
        assert_eq!(out.as_deref(), Some("fn a() {}"));
    }

    #[test]
    fn markdown_fence_ignores_language_when_none() {
        let response = "```\nanything\n```";
        let out = extract_markdown_fence(response, None);
        assert_eq!(out.as_deref(), Some("anything"));
    }

    #[test]
    fn markdown_fence_language_match_is_ascii_case_insensitive() {
        let response = "```RUST\nfn c() {}\n```";
        let out = extract_markdown_fence(response, Some("rust"));
        assert_eq!(out.as_deref(), Some("fn c() {}"));
    }

    #[test]
    fn markdown_fence_returns_none_without_any_fence() {
        assert_eq!(extract_markdown_fence("plain prose, no code", None), None);
        assert_eq!(extract_markdown_fence("", None), None);
    }

    #[test]
    fn markdown_fence_returns_none_when_language_never_matches() {
        let response = "```python\nprint(1)\n```\n";
        assert_eq!(extract_markdown_fence(response, Some("rust")), None);
    }

    #[test]
    fn markdown_fence_keeps_body_of_unterminated_fence() {
        // Truncated model output: the opener is present but the closer is not.
        let response = "```rust\nfn a() {}\nfn b() {}";
        let out = extract_markdown_fence(response, Some("rust"));
        assert_eq!(out.as_deref(), Some("fn a() {}\nfn b() {}"));
    }

    #[test]
    fn markdown_fence_handles_crlf_and_trailing_blank_lines() {
        let response = "```rust\r\nfn a() {}\r\n\r\n```\r\n";
        let out = extract_markdown_fence(response, Some("rust"));
        assert_eq!(out.as_deref(), Some("fn a() {}"));
    }

    #[test]
    fn markdown_fence_preserves_inner_indentation() {
        let response = "```\nfn f() {\n    1\n}\n```";
        let out = extract_markdown_fence(response, None);
        assert_eq!(out.as_deref(), Some("fn f() {\n    1\n}"));
    }

    #[test]
    fn strategy_debug_shows_language_and_regex_source() {
        let fence = CodeExtractorStrategy::MarkdownFence {
            language: Some("rust".to_string()),
        };
        assert_eq!(format!("{fence:?}"), "MarkdownFence { language: Some(\"rust\") }");

        let regex = CodeExtractorStrategy::Regex {
            pattern: Regex::new(r"a+").expect("static pattern is valid"),
        };
        assert_eq!(format!("{regex:?}"), "Regex { pattern: \"a+\" }");
    }
}
159}