// File: split_brain_harness/code_gen.rs

//! Code generation layer for Phase 3 of the Ephemeral Tool Forge.
//!
//! Sends a `CapabilityRequest` to the inference engine and parses the response
//! into a `GeneratedTool` with static analysis results. No compilation or
//! execution happens here.
6use crate::backends::InferenceEngine;
7use crate::capability::CapabilityRequest;
8use crate::static_analysis::{self, StaticAnalysisReport};
9use crate::types::Soul;
10use serde::{Deserialize, Serialize};
11
// ---------------------------------------------------------------------------
// Output type
// ---------------------------------------------------------------------------
15
16/// The result of one code generation pass.
17///
18/// Contains the raw source, extracted metadata, and the static analysis
19/// report. Does not contain compiled artefacts — those live in Phase 4.
20#[derive(Debug, Serialize, Deserialize, Clone)]
21pub struct GeneratedTool {
22    /// Rust source code as returned by the model (extracted from the code block).
23    pub source: String,
24    /// The first public/private `fn` name found in the source.
25    pub function_name: String,
26    /// True if the source contains at least two `#[test]` functions.
27    pub tests_included: bool,
28    /// Number of `#[test]` annotations found.
29    pub test_count: usize,
30    /// Static analysis results.
31    pub static_analysis: StaticAnalysisReport,
32}
33
// ---------------------------------------------------------------------------
// Generator
// ---------------------------------------------------------------------------
37
38/// Uses an inference engine to generate Rust source for a CapabilityRequest.
39pub struct CodeGenerator<'e> {
40    engine: &'e dyn InferenceEngine,
41    soul: &'e Soul,
42}
43
44impl<'e> CodeGenerator<'e> {
45    pub fn new(engine: &'e dyn InferenceEngine, soul: &'e Soul) -> Self {
46        Self { engine, soul }
47    }
48
49    /// Generate Rust source code for `req`.
50    ///
51    /// Returns `Err` if the model call fails or no code block is present in
52    /// the response. Static analysis violations do NOT cause an Err — they
53    /// are reported inside `GeneratedTool`.
54    pub async fn generate(&self, req: &CapabilityRequest) -> Result<GeneratedTool, String> {
55        let prompt = build_prompt(req);
56
57        let raw = self
58            .engine
59            .generate(&self.soul.code_gen_system_prompt, &prompt)
60            .await?;
61
62        let source = extract_code_block(&raw).ok_or_else(|| {
63            format!(
64                "model did not return a Rust code block \
65                 (expected ```rust ... ```) — raw response length: {} chars",
66                raw.len()
67            )
68        })?;
69
70        let function_name = extract_function_name(&source).unwrap_or_else(|| "unknown".to_string());
71        let test_count = static_analysis::test_count(&source);
72        let tests_included = test_count >= 2;
73        let sa = static_analysis::check(&source);
74
75        Ok(GeneratedTool {
76            source,
77            function_name,
78            tests_included,
79            test_count,
80            static_analysis: sa,
81        })
82    }
83}
84
// ---------------------------------------------------------------------------
// Prompt builder
// ---------------------------------------------------------------------------
88
89/// Build the generation prompt from a CapabilityRequest.
90pub fn build_prompt(req: &CapabilityRequest) -> String {
91    format!(
92        "<capability_request>\n\
93         capability: {cap}\n\
94         input_contract: {inp}\n\
95         output_contract: {out}\n\
96         reason: {reason}\n\
97         constraints:\n\
98           no_network: {no_net}\n\
99           read_only_input: {ro}\n\
100           max_runtime_ms: {rt}\n\
101           max_memory_mb: {mem}\n\
102         </capability_request>\n\n\
103         Respond with ONLY a ```rust ... ``` code block. No prose before or after it.",
104        cap = req.capability,
105        inp = req.input_contract,
106        out = req.output_contract,
107        reason = req.reason,
108        no_net = req.constraints.no_network,
109        ro = req.constraints.read_only_input,
110        rt = req.constraints.max_runtime_ms,
111        mem = req.constraints.max_memory_mb,
112    )
113}
114
// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------
118
/// Extract the content of the first ` ```rust ... ``` ` block in `response`.
///
/// Both `rust` and `Rust` language tags are accepted; whichever opening fence
/// appears first in the response wins. LF and CRLF line endings are both
/// accepted. Triple backticks on `//` comment lines (e.g. a fenced example in a
/// `///` doc comment of the generated code) are skipped when looking for the
/// closing fence, so they do not truncate the extracted code.
///
/// Returns `None` if there is no opening fence, no closing fence, or the block
/// is empty after trimming.
pub fn extract_code_block(response: &str) -> Option<String> {
    // Normalise to LF so only one line-ending variant has to be matched.
    let normalised = response.replace("\r\n", "\n");

    // Start of the block body just after the earliest accepted opening fence.
    let body_start = ["```rust\n", "```Rust\n"]
        .iter()
        .filter_map(|&fence| normalised.find(fence).map(|at| at + fence.len()))
        .min()?;
    let body = &normalised[body_start..];

    let body_end = closing_fence_offset(body)?;
    let code = body[..body_end].trim();
    if code.is_empty() {
        None
    } else {
        Some(code.to_string())
    }
}

/// Byte offset in `body` of the triple backticks that close the code block.
///
/// Lines whose first non-blank characters are `//` are skipped so fences inside
/// comments are ignored. If nothing is found that way, the first triple
/// backticks anywhere are used. Returns `None` when `body` has no backticks
/// fence at all.
fn closing_fence_offset(body: &str) -> Option<usize> {
    let mut line_start = 0;
    for line in body.split_inclusive('\n') {
        if !line.trim_start().starts_with("//") {
            if let Some(col) = line.find("```") {
                return Some(line_start + col);
            }
        }
        line_start += line.len();
    }
    body.find("```")
}
144
/// Find the name of the first function declared in `source`.
///
/// Only declarations at the start of a (trimmed) line are considered, so `fn`
/// inside comments or expressions is ignored. A visibility modifier (`pub`,
/// `pub(crate)`, `pub(in path)`, …) and the qualifiers `const`, `async`,
/// `unsafe` and `extern "ABI"` may precede `fn`.
///
/// Returns `None` if no function declaration is found.
pub fn extract_function_name(source: &str) -> Option<String> {
    source.lines().find_map(|line| fn_name_in_line(line.trim()))
}

/// Return the declared name if the trimmed `line` starts a function
/// declaration, otherwise `None`.
fn fn_name_in_line(line: &str) -> Option<String> {
    let mut rest = strip_visibility(line);
    // Peel off any qualifiers that may sit between the visibility and `fn`.
    loop {
        if let Some(after) = rest
            .strip_prefix("const ")
            .or_else(|| rest.strip_prefix("async "))
            .or_else(|| rest.strip_prefix("unsafe "))
        {
            rest = after.trim_start();
        } else if let Some(after) = rest.strip_prefix("extern ") {
            rest = skip_abi_string(after.trim_start())?;
        } else {
            break;
        }
    }
    let name: String = rest
        .strip_prefix("fn ")?
        .trim_start()
        .chars()
        .take_while(|c| c.is_alphanumeric() || *c == '_')
        .collect();
    if name.is_empty() {
        None
    } else {
        Some(name)
    }
}

/// Strip a leading visibility modifier (`pub`, `pub(crate)`, `pub(in a::b)`, …).
fn strip_visibility(line: &str) -> &str {
    match line.strip_prefix("pub") {
        Some(after) if after.starts_with('(') => match after.find(')') {
            Some(close) => after[close + 1..].trim_start(),
            None => line,
        },
        Some(after) if after.starts_with(char::is_whitespace) => after.trim_start(),
        _ => line,
    }
}

/// Skip an optional quoted ABI string (`"C"`, `"system"`, …) that follows
/// `extern`. Returns `None` if the string is unterminated.
fn skip_abi_string(s: &str) -> Option<&str> {
    match s.strip_prefix('"') {
        Some(quoted) => {
            let close = quoted.find('"')?;
            Some(quoted[close + 1..].trim_start())
        }
        None => Some(s),
    }
}
162
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
166
#[cfg(test)]
mod tests {
    use super::*;
    use crate::capability::CapabilityConstraints;

    fn clean_req() -> CapabilityRequest {
        CapabilityRequest {
            kind: "capability_request".into(),
            capability: "word_count".into(),
            input_contract: "utf8 text".into(),
            output_contract: "json object with word_count".into(),
            constraints: CapabilityConstraints::default(),
            reason: "text reasoning insufficient".into(),
        }
    }

    // A well-formed model response: leading prose, one `rust` code block with a
    // `run` function and two unit tests, then trailing prose. The braces in the
    // embedded `format!` string are escaped (`{{` / `}}`) so the fixture is
    // itself valid Rust should a later phase compile it.
    const CLEAN_RESPONSE: &str = r#"Here is the Rust implementation:

```rust
pub fn run(input: &str) -> Result<String, String> {
    let count = input.split_whitespace().count();
    Ok(format!("{{\"word_count\":{}}}", count))
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn two_words() {
        assert!(run("hello world").unwrap().contains("2"));
    }
    #[test]
    fn empty() {
        assert!(run("").unwrap().contains("0"));
    }
}
```

That should fulfil the contract.
"#;

    // --- extract_code_block ---

    #[test]
    fn extracts_rust_code_block() {
        let source = extract_code_block(CLEAN_RESPONSE).unwrap();
        assert!(
            source.contains("pub fn run"),
            "extracted code must contain the function"
        );
        assert!(
            !source.contains("```"),
            "backticks must not appear in extracted code"
        );
    }

    #[test]
    fn returns_none_when_no_code_block() {
        let response = "Here is some analysis but no code.";
        assert!(extract_code_block(response).is_none());
    }

    #[test]
    fn returns_none_for_empty_code_block() {
        let response = "```rust\n```";
        assert!(extract_code_block(response).is_none());
    }

    #[test]
    fn extract_code_block_ignores_leading_prose() {
        let r = "Some explanation.\n\n```rust\nfn run(i: &str) -> Result<String, String> { Ok(i.into()) }\n```\n";
        let code = extract_code_block(r).unwrap();
        assert!(code.starts_with("fn run"));
    }

    #[test]
    fn extract_code_block_accepts_crlf_and_capitalised_tag() {
        let r = "Intro\r\n```Rust\r\nfn run() {}\r\n```\r\n";
        assert_eq!(extract_code_block(r).unwrap(), "fn run() {}");
    }

    // --- extract_function_name ---

    #[test]
    fn extracts_pub_fn_name() {
        let src = "pub fn run(input: &str) -> Result<String, String> {\n    Ok(\"ok\".into())\n}";
        assert_eq!(extract_function_name(src).unwrap(), "run");
    }

    #[test]
    fn extracts_private_fn_name() {
        let src = "fn process(input: &str) -> Result<String, String> {\n    Ok(\"ok\".into())\n}";
        assert_eq!(extract_function_name(src).unwrap(), "process");
    }

    #[test]
    fn returns_none_for_no_function() {
        let src = "// just a comment\nconst X: u32 = 0;";
        assert!(extract_function_name(src).is_none());
    }

    // --- build_prompt ---

    #[test]
    fn build_prompt_includes_all_fields() {
        let req = clean_req();
        let prompt = build_prompt(&req);
        assert!(prompt.contains("word_count"));
        assert!(prompt.contains("utf8 text"));
        assert!(prompt.contains("json object with word_count"));
        assert!(prompt.contains("text reasoning insufficient"));
        assert!(prompt.contains("no_network: true"));
        assert!(prompt.contains("read_only_input: true"));
        assert!(prompt.contains("max_runtime_ms: "));
        assert!(prompt.contains("max_memory_mb: "));
        assert!(prompt.contains("</capability_request>"));
    }

    // --- GeneratedTool construction ---

    #[test]
    fn generated_tool_from_clean_response() {
        // Mirrors the assembly done in `CodeGenerator::generate`, minus the model call.
        let source = extract_code_block(CLEAN_RESPONSE).unwrap();
        let tc = static_analysis::test_count(&source);
        let sa = static_analysis::check(&source);
        let tool = GeneratedTool {
            function_name: extract_function_name(&source).unwrap_or_else(|| "unknown".to_string()),
            tests_included: tc >= 2,
            test_count: tc,
            static_analysis: sa,
            source,
        };
        assert_eq!(tool.function_name, "run");
        assert_eq!(tool.test_count, 2, "response contains exactly two #[test] fns");
        assert!(tool.tests_included, "response includes 2 tests");
        assert!(tool.static_analysis.passed, "clean code passes analysis");
    }

    #[test]
    fn generated_tool_flags_unsafe_source() {
        let bad_response = "```rust\npub fn run(i: &str) -> Result<String, String> {\n    unsafe { }\n    Ok(\"ok\".into())\n}\n#[test]\nfn t1() {}\n#[test]\nfn t2() {}\n```";
        let source = extract_code_block(bad_response).unwrap();
        let sa = static_analysis::check(&source);
        assert!(!sa.passed, "unsafe code should fail static analysis");
    }
}