split_brain_harness/
code_gen.rs1use crate::backends::InferenceEngine;
7use crate::capability::CapabilityRequest;
8use crate::static_analysis::{self, StaticAnalysisReport};
9use crate::types::Soul;
10use serde::{Deserialize, Serialize};
11
12#[derive(Debug, Serialize, Deserialize, Clone)]
21pub struct GeneratedTool {
22 pub source: String,
24 pub function_name: String,
26 pub tests_included: bool,
28 pub test_count: usize,
30 pub static_analysis: StaticAnalysisReport,
32}
33
34pub struct CodeGenerator<'e> {
40 engine: &'e dyn InferenceEngine,
41 soul: &'e Soul,
42}
43
44impl<'e> CodeGenerator<'e> {
45 pub fn new(engine: &'e dyn InferenceEngine, soul: &'e Soul) -> Self {
46 Self { engine, soul }
47 }
48
49 pub async fn generate(&self, req: &CapabilityRequest) -> Result<GeneratedTool, String> {
55 let prompt = build_prompt(req);
56
57 let raw = self
58 .engine
59 .generate(&self.soul.code_gen_system_prompt, &prompt)
60 .await?;
61
62 let source = extract_code_block(&raw).ok_or_else(|| {
63 format!(
64 "model did not return a Rust code block \
65 (expected ```rust ... ```) — raw response length: {} chars",
66 raw.len()
67 )
68 })?;
69
70 let function_name = extract_function_name(&source).unwrap_or_else(|| "unknown".to_string());
71 let test_count = static_analysis::test_count(&source);
72 let tests_included = test_count >= 2;
73 let sa = static_analysis::check(&source);
74
75 Ok(GeneratedTool {
76 source,
77 function_name,
78 tests_included,
79 test_count,
80 static_analysis: sa,
81 })
82 }
83}
84
85pub fn build_prompt(req: &CapabilityRequest) -> String {
91 format!(
92 "<capability_request>\n\
93 capability: {cap}\n\
94 input_contract: {inp}\n\
95 output_contract: {out}\n\
96 reason: {reason}\n\
97 constraints:\n\
98 no_network: {no_net}\n\
99 read_only_input: {ro}\n\
100 max_runtime_ms: {rt}\n\
101 max_memory_mb: {mem}\n\
102 </capability_request>\n\n\
103 Respond with ONLY a ```rust ... ``` code block. No prose before or after it.",
104 cap = req.capability,
105 inp = req.input_contract,
106 out = req.output_contract,
107 reason = req.reason,
108 no_net = req.constraints.no_network,
109 ro = req.constraints.read_only_input,
110 rt = req.constraints.max_runtime_ms,
111 mem = req.constraints.max_memory_mb,
112 )
113}
114
/// Extracts the code from the first fenced block whose opening fence is three
/// backticks immediately followed by `rust` or `Rust` and a newline.
///
/// `\r\n` line endings are normalised to `\n` before searching. The block
/// ends at the first following line made up solely of backticks (a Markdown
/// closing fence). Backtick runs *inside* the code, for example in a `///`
/// doc-comment example, therefore do not truncate it. If no such line exists,
/// the first triple backtick after the opening fence is used instead. That
/// fallback still accepts a closing fence glued to the end of a code line.
///
/// Returns `None` if there is no opening fence, no closing backticks, or the
/// block is empty after trimming surrounding whitespace.
pub fn extract_code_block(response: &str) -> Option<String> {
    let normalised = response.replace("\r\n", "\n");

    // The earliest opening fence wins, whichever capitalisation it uses.
    let start = ["```rust\n", "```Rust\n"]
        .iter()
        .filter_map(|&marker| normalised.find(marker).map(|at| at + marker.len()))
        .min()?;
    let body = &normalised[start..];

    let end = closing_fence_offset(body).or_else(|| body.find("```"))?;
    let code = body[..end].trim();
    if code.is_empty() {
        None
    } else {
        Some(code.to_string())
    }
}

/// Returns the byte offset within `body` of the first line that contains
/// nothing but three or more backticks (surrounding whitespace allowed).
fn closing_fence_offset(body: &str) -> Option<usize> {
    let mut offset = 0;
    for line in body.split_inclusive('\n') {
        let fence = line.trim();
        if fence.len() >= 3 && fence.bytes().all(|b| b == b'`') {
            return Some(offset);
        }
        offset += line.len();
    }
    None
}
144
/// Best-effort guess at the entry-point function name in generated source.
///
/// Scans line by line for function headers. Recognised forms include
/// `fn name`, `pub fn name`, `pub(crate) async fn name`, `const fn`,
/// `unsafe fn` and `extern "C" fn`. The first *public* function wins. If none
/// is public, the first function of any visibility is returned. This keeps a
/// private helper, or a `#[test]` fn placed before the entry point, from
/// being reported instead of the entry point.
///
/// This is a line-based heuristic, not a parser. Headers split across lines,
/// or with a restricted visibility containing spaces (`pub(in path)`), are
/// not recognised.
pub fn extract_function_name(source: &str) -> Option<String> {
    let mut first_private: Option<String> = None;
    for line in source.lines() {
        if let Some((is_public, name)) = parse_fn_header(line) {
            if is_public {
                return Some(name);
            }
            if first_private.is_none() {
                first_private = Some(name);
            }
        }
    }
    first_private
}

/// Parses one line as a function header. Returns whether it is `pub` and the
/// function's name, or `None` if the line does not start a function.
fn parse_fn_header(line: &str) -> Option<(bool, String)> {
    let mut is_public = false;
    let mut tokens = line.split_whitespace();
    while let Some(token) = tokens.next() {
        match token {
            "fn" => {
                // Stop at `(`, `<`, etc. so `run<T>(x: T)` yields `run`.
                let name: String = tokens
                    .next()?
                    .chars()
                    .take_while(|c| c.is_alphanumeric() || *c == '_')
                    .collect();
                return if name.is_empty() {
                    None
                } else {
                    Some((is_public, name))
                };
            }
            "async" | "const" | "unsafe" | "extern" => {}
            t if t == "pub" || t.starts_with("pub(") => is_public = true,
            // ABI string such as the `"C"` in `extern "C" fn`.
            t if t.starts_with('"') => {}
            _ => return None,
        }
    }
    None
}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::capability::CapabilityConstraints;

    /// A well-formed request that uses `CapabilityConstraints::default()`.
    fn clean_req() -> CapabilityRequest {
        CapabilityRequest {
            kind: "capability_request".into(),
            capability: "word_count".into(),
            input_contract: "utf8 text".into(),
            output_contract: "json object with word_count".into(),
            constraints: CapabilityConstraints::default(),
            reason: "text reasoning insufficient".into(),
        }
    }

    /// A typical model response: prose, one fenced Rust block with the tool
    /// and two tests, then more prose.
    ///
    /// The embedded code is meant to be valid Rust, so the literal JSON braces
    /// in its format string are escaped as `{{` and `}}`.
    const CLEAN_RESPONSE: &str = r#"Here is the Rust implementation:

```rust
pub fn run(input: &str) -> Result<String, String> {
    let count = input.split_whitespace().count();
    Ok(format!("{{\"word_count\":{}}}", count))
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn two_words() {
        assert!(run("hello world").unwrap().contains("2"));
    }
    #[test]
    fn empty() {
        assert!(run("").unwrap().contains("0"));
    }
}
```

That should fulfil the contract.
"#;

    // --- extract_code_block ---

    #[test]
    fn extracts_rust_code_block() {
        let source = extract_code_block(CLEAN_RESPONSE).unwrap();
        assert!(
            source.contains("pub fn run"),
            "extracted code must contain the function"
        );
        assert!(
            !source.contains("```"),
            "backticks must not appear in extracted code"
        );
    }

    #[test]
    fn returns_none_when_no_code_block() {
        let response = "Here is some analysis but no code.";
        assert!(extract_code_block(response).is_none());
    }

    #[test]
    fn returns_none_for_empty_code_block() {
        let response = "```rust\n```";
        assert!(extract_code_block(response).is_none());
    }

    #[test]
    fn extract_code_block_ignores_leading_prose() {
        let r = "Some explanation.\n\n```rust\nfn run(i: &str) -> Result<String, String> { Ok(i.into()) }\n```\n";
        let code = extract_code_block(r).unwrap();
        assert!(code.starts_with("fn run"));
    }

    /// Windows-style line endings must not hide the opening fence.
    #[test]
    fn extract_code_block_normalises_crlf() {
        let r = "```rust\r\nfn a() {}\r\n```\r\n";
        assert_eq!(extract_code_block(r).unwrap(), "fn a() {}");
    }

    // --- extract_function_name ---

    #[test]
    fn extracts_pub_fn_name() {
        let src = "pub fn run(input: &str) -> Result<String, String> {\n Ok(\"ok\".into())\n}";
        assert_eq!(extract_function_name(src).unwrap(), "run");
    }

    #[test]
    fn extracts_private_fn_name() {
        let src = "fn process(input: &str) -> Result<String, String> {\n Ok(\"ok\".into())\n}";
        assert_eq!(extract_function_name(src).unwrap(), "process");
    }

    #[test]
    fn returns_none_for_no_function() {
        let src = "// just a comment\nconst X: u32 = 0;";
        assert!(extract_function_name(src).is_none());
    }

    // --- build_prompt ---

    #[test]
    fn build_prompt_includes_all_fields() {
        let req = clean_req();
        let prompt = build_prompt(&req);
        assert!(prompt.contains("word_count"));
        assert!(prompt.contains("utf8 text"));
        assert!(prompt.contains("no_network: true"));
        assert!(prompt.contains("read_only_input: true"));
    }

    /// The request section is closed and the answer-format instruction comes last.
    #[test]
    fn build_prompt_ends_with_format_instruction() {
        let prompt = build_prompt(&clean_req());
        assert!(prompt.starts_with("<capability_request>\n"));
        assert!(prompt.contains("reason: text reasoning insufficient\n"));
        assert!(prompt.contains("</capability_request>\n\n"));
        assert!(prompt.ends_with("No prose before or after it."));
    }

    // --- GeneratedTool assembly ---

    /// Mirrors the post-processing in `CodeGenerator::generate` without an engine.
    #[test]
    fn generated_tool_from_clean_response() {
        let source = extract_code_block(CLEAN_RESPONSE).unwrap();
        let tc = static_analysis::test_count(&source);
        let sa = static_analysis::check(&source);
        let tool = GeneratedTool {
            function_name: extract_function_name(&source).unwrap_or_default(),
            tests_included: tc >= 2,
            test_count: tc,
            static_analysis: sa,
            source,
        };
        assert_eq!(tool.function_name, "run");
        assert!(tool.tests_included, "response includes 2 tests");
        assert!(tool.static_analysis.passed, "clean code passes analysis");
    }

    #[test]
    fn generated_tool_flags_unsafe_source() {
        let bad_response = "```rust\npub fn run(i: &str) -> Result<String, String> {\n unsafe { }\n Ok(\"ok\".into())\n}\n#[test]\nfn t1() {}\n#[test]\nfn t2() {}\n```";
        let source = extract_code_block(bad_response).unwrap();
        let sa = static_analysis::check(&source);
        assert!(!sa.passed, "unsafe code should fail static analysis");
    }
}