// split_brain_harness/generative_forge.rs

//! Phase 3 of the Ephemeral Tool Forge — source generation and static analysis.
//!
//! Takes a `CapabilityRequest`, calls the inference engine to produce a Rust
//! function with inline tests, then runs static analysis and verifies tests
//! are present.  Does NOT execute the code — that is Phase 4 (`wasm_forge`).
//!
//! - **Status**: production-quality, full test coverage.
//! - **Requires**: a configured inference backend (SBH_BACKEND / SBH_API_KEY).
//! - **CLI entry point**: `sbh forge "<capability>" "<input>"`
//!
//! Pipeline:
//!   input validation → budget check → policy check → code generation →
//!   static analysis → test presence check → memory update
14use std::time::Instant;
15
16use crate::backends::InferenceEngine;
17use crate::capability::{Budget, CapabilityMemoryRecord, CapabilityRequest, ToolMetrics};
18use crate::code_gen::{CodeGenerator, GeneratedTool};
19use crate::input_validation;
20use crate::policy::{self, PolicyState};
21use crate::tool_memory::CapabilityMemory;
22use crate::types::Soul;
23use serde::{Deserialize, Serialize};
24
// ---------------------------------------------------------------------------
// Report type — richer than Phase 2 ToolRunReport
// ---------------------------------------------------------------------------
28
/// Full result of one generative forge pass.
///
/// One report is produced per call to `GenerativeForge::handle`, whether the
/// request was rejected up front, generation failed, or a tool was generated.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct GenerativeReport {
    /// True if the request passed all pre-generation checks.
    pub accepted: bool,
    /// Non-empty when the request was rejected before generation.
    pub rejection_reasons: Vec<String>,
    /// True if static analysis passed AND tests are present.
    /// False when generation succeeds but output is unsafe or missing tests.
    pub verification_passed: bool,
    /// Phase 3 never executes against real data.
    pub executed: bool,
    /// The generated tool (present when generation succeeded).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generated_tool: Option<GeneratedTool>,
    /// Set when the LLM call or code block extraction failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub generation_error: Option<String>,
    /// Measurements for this pass. For accepted requests `runtime_ms` is the
    /// wall-clock time of the generation call, `input_bytes` the length of the
    /// caller's input and `output_bytes` the length of the generated source
    /// (0 when generation failed). Rejected requests carry
    /// `ToolMetrics::default()`.
    pub metrics: ToolMetrics,
    /// True: source is not persisted after this call.
    pub destroyed: bool,
    /// The record written to capability memory; present only when
    /// verification passed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub memory_update: Option<CapabilityMemoryRecord>,
}
53
// ---------------------------------------------------------------------------
// Supervisor
// ---------------------------------------------------------------------------
57
/// Phase 3 supervisor. One instance per session.
///
/// Holds the session budget and running policy state, the capability memory
/// that successful passes are written to, the borrowed inference engine used
/// for code generation, and a log of every report produced so far.
pub struct GenerativeForge<'e> {
    /// Limits consulted by the budget check before each generation.
    budget: Budget,
    /// Running per-session counters; updated after every generation attempt.
    state: PolicyState,
    /// Capability memory; receives a record whenever verification passes.
    pub memory: CapabilityMemory,
    /// Inference backend handed to the code generator. Borrowed for `'e`.
    engine: &'e dyn InferenceEngine,
    /// Passed by reference to the code generator on every request.
    soul: Soul,
    /// One entry per `handle` call, in call order (see `audit`).
    session_log: Vec<GenerativeReport>,
}
67
68impl<'e> GenerativeForge<'e> {
69    pub fn new(engine: &'e dyn InferenceEngine, soul: Soul) -> Self {
70        Self {
71            budget: Budget::default(),
72            state: PolicyState::default(),
73            memory: CapabilityMemory::new(),
74            engine,
75            soul,
76            session_log: vec![],
77        }
78    }
79
80    pub fn with_budget(budget: Budget, engine: &'e dyn InferenceEngine, soul: Soul) -> Self {
81        Self {
82            budget,
83            state: PolicyState::default(),
84            memory: CapabilityMemory::new(),
85            engine,
86            soul,
87            session_log: vec![],
88        }
89    }
90
91    /// Every call is recorded in the session log regardless of outcome.
92    pub fn audit(&self) -> &[GenerativeReport] {
93        &self.session_log
94    }
95
96    /// Process one CapabilityRequest.
97    pub async fn handle(&mut self, req: &CapabilityRequest, input: &str) -> GenerativeReport {
98        let report = self.handle_inner(req, input).await;
99        self.session_log.push(report.clone());
100        report
101    }
102
103    async fn handle_inner(&mut self, req: &CapabilityRequest, input: &str) -> GenerativeReport {
104        // Input validation
105        if let Err(e) = input_validation::validate_forge_input(input) {
106            return rejected(vec![format!("input validation: {e}")]);
107        }
108        if let Err(e) = input_validation::validate_capability_fields(req) {
109            return rejected(vec![format!("capability field validation: {e}")]);
110        }
111
112        // Budget check
113        if let Some(reason) = self.state.budget_exceeded(&self.budget) {
114            return rejected(vec![reason]);
115        }
116
117        // Policy checks
118        let violations = policy::check_request(req);
119        if !violations.is_empty() {
120            return rejected(violations.into_iter().map(|v| v.detail).collect());
121        }
122
123        // Code generation
124        let generator = CodeGenerator::new(self.engine, &self.soul);
125        let start = Instant::now();
126        let gen_result = generator.generate(req).await;
127        let generation_ms = start.elapsed().as_millis() as u64;
128
129        let generated: GeneratedTool = match gen_result {
130            Ok(tool) => tool,
131            Err(e) => {
132                let metrics = ToolMetrics {
133                    runtime_ms: generation_ms,
134                    input_bytes: input.len(),
135                    output_bytes: 0,
136                    success: false,
137                };
138                self.state.record_run(&metrics);
139                return GenerativeReport {
140                    accepted: true,
141                    rejection_reasons: vec![],
142                    verification_passed: false,
143                    executed: false,
144                    generated_tool: None,
145                    generation_error: Some(format!("code generation failed: {e}")),
146                    metrics,
147                    destroyed: true,
148                    memory_update: None,
149                };
150            }
151        };
152
153        // Verification: static analysis must pass AND at least 2 tests present
154        let verification_passed = generated.static_analysis.passed && generated.tests_included;
155
156        let metrics = ToolMetrics {
157            runtime_ms: generation_ms,
158            input_bytes: input.len(),
159            output_bytes: generated.source.len(),
160            success: verification_passed,
161        };
162        self.state.record_run(&metrics);
163
164        // Memory update on success
165        let memory_update = if verification_passed {
166            let signature = CapabilityMemory::derive_signature(req);
167            let record = CapabilityMemoryRecord {
168                problem_signature: signature,
169                solution_pattern: format!("generated:{}", req.capability),
170                input_shape: shape_token(&req.input_contract),
171                output_shape: shape_token(&req.output_contract),
172                constraints: req.constraints.clone(),
173            };
174            self.memory.upsert(record.clone(), &metrics);
175            Some(record)
176        } else {
177            None
178        };
179
180        GenerativeReport {
181            accepted: true,
182            rejection_reasons: vec![],
183            verification_passed,
184            executed: false,
185            generated_tool: Some(generated),
186            generation_error: None,
187            metrics,
188            destroyed: true,
189            memory_update,
190        }
191    }
192
193    pub fn tools_invoked(&self) -> usize {
194        self.state.tools_invoked
195    }
196}
197
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
201
202fn rejected(reasons: Vec<String>) -> GenerativeReport {
203    GenerativeReport {
204        accepted: false,
205        rejection_reasons: reasons,
206        verification_passed: false,
207        executed: false,
208        generated_tool: None,
209        generation_error: None,
210        metrics: ToolMetrics::default(),
211        destroyed: false,
212        memory_update: None,
213    }
214}
215
/// Condenses a free-text contract description into a short shape token.
///
/// Each whitespace-separated word is lower-cased and stripped of leading and
/// trailing non-alphanumeric characters (interior punctuation is kept). The
/// first three words that are still non-empty after stripping are joined with
/// `_`. Returns an empty string when no word survives.
///
/// Words that consist solely of punctuation (e.g. `-`, `->`) are discarded
/// *before* the three-word limit is applied, so they do not use up a slot:
/// `"utf8 - text lines"` yields `"utf8_text_lines"`.
fn shape_token(contract: &str) -> String {
    contract
        .split_whitespace()
        .filter_map(|word| {
            // Lower-case first, then trim, so the result of case mapping is
            // what gets inspected for alphanumeric edges.
            let lowered = word.to_lowercase();
            let core = lowered.trim_matches(|c: char| !c.is_alphanumeric());
            (!core.is_empty()).then(|| core.to_string())
        })
        // Limit applies to meaningful words only.
        .take(3)
        .collect::<Vec<_>>()
        .join("_")
}
229
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
233
// Unit tests for the Phase 3 forge. Every test drives `GenerativeForge::handle`
// with a canned inference engine, so no real backend is contacted.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::capability::CapabilityConstraints;
    use crate::soul;
    use async_trait::async_trait;

    // --- Mock engine helpers ---

    /// Engine that ignores both prompts and always returns `response`.
    struct MockEngine {
        response: String,
    }

    #[async_trait]
    impl crate::backends::InferenceEngine for MockEngine {
        async fn generate(&self, _sys: &str, _prompt: &str) -> Result<String, String> {
            Ok(self.response.clone())
        }
    }

    /// Engine that always fails with the message "backend unavailable".
    struct ErrorEngine;

    #[async_trait]
    impl crate::backends::InferenceEngine for ErrorEngine {
        async fn generate(&self, _sys: &str, _prompt: &str) -> Result<String, String> {
            Err("backend unavailable".into())
        }
    }

    /// Builds a request for capability `cap` with default constraints.
    /// The acceptance tests below rely on this request passing every
    /// pre-generation check.
    fn clean_req(cap: &str) -> CapabilityRequest {
        CapabilityRequest {
            kind: "capability_request".into(),
            capability: cap.into(),
            input_contract: "utf8 text lines".into(),
            output_contract: "json object".into(),
            constraints: CapabilityConstraints::default(),
            reason: "text reasoning insufficient for this task".into(),
        }
    }

    // Fixture: fenced Rust block with a `run` function and two `#[test]`s;
    // expected to pass verification.
    // NOTE(review): the `format!` string inside this fixture appears to have
    // unescaped braces and would presumably not compile as real Rust. Phase 3
    // never compiles the output, so these tests do not depend on it — confirm
    // before reusing the fixture for anything that builds the code.
    const CLEAN_RUST_RESPONSE: &str = r#"```rust
pub fn run(input: &str) -> Result<String, String> {
    let count = input.split_whitespace().count();
    Ok(format!("{\"word_count\":{}}", count))
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn two_words() {
        assert!(run("hello world").unwrap().contains("2"));
    }
    #[test]
    fn empty_input() {
        assert!(run("").unwrap().contains("0"));
    }
}
```"#;

    // Fixture: contains an `unsafe` block (plus two trivial tests); expected
    // to be flagged by static analysis with violation kind "unsafe_code".
    const UNSAFE_RUST_RESPONSE: &str = r#"```rust
pub fn run(input: &str) -> Result<String, String> {
    unsafe { let _ = 0; }
    Ok(format!("{\"word_count\":{}}", input.len()))
}
#[test]
fn t1() {}
#[test]
fn t2() {}
```"#;

    // Fixture: a fenced Rust block with no tests at all.
    const NO_TESTS_RESPONSE: &str = r#"```rust
pub fn run(input: &str) -> Result<String, String> {
    Ok(input.to_string())
}
```"#;

    // Fixture: plain prose with no fenced code block to extract.
    const NO_CODE_BLOCK_RESPONSE: &str = "Here is my analysis but no code block.";

    // --- Acceptance paths ---

    /// Happy path: accepted, verified, never executed, tool present, no error.
    #[tokio::test]
    async fn accepts_clean_request_and_clean_code() {
        let engine = MockEngine {
            response: CLEAN_RUST_RESPONSE.into(),
        };
        let soul = soul::load(None).unwrap();
        let mut forge = GenerativeForge::new(&engine, soul);
        let req = clean_req("word_count");
        let report = forge.handle(&req, "some input").await;

        assert!(report.accepted);
        assert!(
            report.verification_passed,
            "clean code + 2 tests should pass"
        );
        assert!(!report.executed, "Phase 3 never executes");
        assert!(report.generated_tool.is_some());
        assert!(report.generation_error.is_none());
    }

    // --- Rejection paths (before generation) ---

    /// Clearing `no_network` must be rejected with a reason naming it.
    #[tokio::test]
    async fn rejects_network_access_request() {
        let engine = MockEngine {
            response: CLEAN_RUST_RESPONSE.into(),
        };
        let soul = soul::load(None).unwrap();
        let mut forge = GenerativeForge::new(&engine, soul);
        let mut req = clean_req("fetch_url");
        req.constraints.no_network = false;
        let report = forge.handle(&req, "input").await;
        assert!(!report.accepted);
        assert!(report
            .rejection_reasons
            .iter()
            .any(|r| r.contains("no_network")));
    }

    /// With `max_tools_per_session: 1`, the second request is rejected.
    #[tokio::test]
    async fn rejects_when_budget_exhausted() {
        let budget = Budget {
            max_tools_per_session: 1,
            ..Budget::default()
        };
        let engine = MockEngine {
            response: CLEAN_RUST_RESPONSE.into(),
        };
        let soul = soul::load(None).unwrap();
        let mut forge = GenerativeForge::with_budget(budget, &engine, soul);
        // First call uses up the single allowed tool run.
        forge.handle(&clean_req("word_count"), "a").await;
        let report = forge.handle(&clean_req("word_count"), "b").await;
        assert!(!report.accepted);
        assert!(report.rejection_reasons[0].contains("session tool limit"));
    }

    /// Input one byte over `MAX_FORGE_INPUT_BYTES` fails input validation.
    #[tokio::test]
    async fn rejects_oversized_input() {
        let engine = MockEngine {
            response: CLEAN_RUST_RESPONSE.into(),
        };
        let soul = soul::load(None).unwrap();
        let mut forge = GenerativeForge::new(&engine, soul);
        let big = "x".repeat(crate::input_validation::MAX_FORGE_INPUT_BYTES + 1);
        let report = forge.handle(&clean_req("word_count"), &big).await;
        assert!(!report.accepted);
        assert!(report.rejection_reasons[0].contains("input validation"));
    }

    // --- Verification failure paths (after generation) ---

    /// Generated code containing `unsafe` is accepted as a request but fails
    /// verification with an "unsafe_code" violation.
    #[tokio::test]
    async fn fails_verification_on_static_violation() {
        let engine = MockEngine {
            response: UNSAFE_RUST_RESPONSE.into(),
        };
        let soul = soul::load(None).unwrap();
        let mut forge = GenerativeForge::new(&engine, soul);
        let report = forge.handle(&clean_req("word_count"), "input").await;
        assert!(report.accepted, "request was accepted");
        assert!(
            !report.verification_passed,
            "unsafe code must fail verification"
        );
        let tool = report.generated_tool.unwrap();
        assert!(!tool.static_analysis.passed);
        assert!(tool
            .static_analysis
            .violations
            .iter()
            .any(|v| v.kind == "unsafe_code"));
    }

    /// Generated code without tests fails verification.
    #[tokio::test]
    async fn fails_verification_on_missing_tests() {
        let engine = MockEngine {
            response: NO_TESTS_RESPONSE.into(),
        };
        let soul = soul::load(None).unwrap();
        let mut forge = GenerativeForge::new(&engine, soul);
        let report = forge.handle(&clean_req("word_count"), "input").await;
        assert!(report.accepted);
        assert!(
            !report.verification_passed,
            "code without tests must fail verification"
        );
        let tool = report.generated_tool.unwrap();
        assert!(!tool.tests_included);
    }

    /// A response with no code block surfaces as a generation error, not a
    /// rejection.
    #[tokio::test]
    async fn generation_error_when_no_code_block() {
        let engine = MockEngine {
            response: NO_CODE_BLOCK_RESPONSE.into(),
        };
        let soul = soul::load(None).unwrap();
        let mut forge = GenerativeForge::new(&engine, soul);
        let report = forge.handle(&clean_req("word_count"), "input").await;
        assert!(report.accepted, "request was valid");
        assert!(!report.verification_passed);
        assert!(report.generated_tool.is_none());
        assert!(report.generation_error.is_some());
    }

    /// A backend failure is propagated into `generation_error`.
    #[tokio::test]
    async fn generation_error_when_backend_fails() {
        let soul = soul::load(None).unwrap();
        let mut forge = GenerativeForge::new(&ErrorEngine, soul);
        let report = forge.handle(&clean_req("word_count"), "input").await;
        assert!(report.accepted);
        assert!(!report.verification_passed);
        assert!(report
            .generation_error
            .as_deref()
            .unwrap_or("")
            .contains("backend unavailable"));
    }

    // --- Session log ---

    /// Both rejected and accepted calls appear in `audit()`, in call order.
    #[tokio::test]
    async fn session_log_records_every_call() {
        let engine = MockEngine {
            response: CLEAN_RUST_RESPONSE.into(),
        };
        let soul = soul::load(None).unwrap();
        let mut forge = GenerativeForge::new(&engine, soul);
        let mut req_bad = clean_req("x");
        req_bad.constraints.no_network = false;
        forge.handle(&req_bad, "a").await;
        forge.handle(&clean_req("word_count"), "b").await;
        assert_eq!(forge.audit().len(), 2);
        assert!(!forge.audit()[0].accepted);
        assert!(forge.audit()[1].accepted);
    }

    // --- Memory update ---

    /// A verified tool produces a memory record and stores it.
    #[tokio::test]
    async fn memory_updated_on_success() {
        let engine = MockEngine {
            response: CLEAN_RUST_RESPONSE.into(),
        };
        let soul = soul::load(None).unwrap();
        let mut forge = GenerativeForge::new(&engine, soul);
        let report = forge.handle(&clean_req("word_count"), "input").await;
        assert!(report.memory_update.is_some());
        assert_eq!(forge.memory.len(), 1);
    }

    /// A tool that fails verification leaves capability memory untouched.
    #[tokio::test]
    async fn memory_not_updated_on_verification_failure() {
        let engine = MockEngine {
            response: UNSAFE_RUST_RESPONSE.into(),
        };
        let soul = soul::load(None).unwrap();
        let mut forge = GenerativeForge::new(&engine, soul);
        forge.handle(&clean_req("word_count"), "input").await;
        assert_eq!(
            forge.memory.len(),
            0,
            "memory must not be updated when verification fails"
        );
    }
}