Skip to main content

split_brain_harness/
tool_forge.rs

//! Phase 2 mock tool forge — the supervisor.
//!
//! The supervisor:
//!   1. Validates the raw input and the model-supplied request fields.
//!   2. Checks the budget (session limits).
//!   3. Runs static policy checks (no network, read-only, resource ceilings).
//!   4. Looks up the capability name in the mock registry.
//!   5. Executes the mock (deterministic, no generated code).
//!   6. Updates capability memory (successful runs only).
//!   7. Returns a ToolRunReport and appends it to the session audit log.
//!
//! The model never runs code. The model emits a CapabilityRequest; the
//! supervisor decides, runs, and destroys.
13use std::collections::HashMap;
14use std::time::Instant;
15
16use crate::capability::{
17    Budget, CapabilityMemoryRecord, CapabilityRequest, ToolMetrics, ToolRunReport,
18};
19use crate::input_validation;
20use crate::policy::{self, PolicyState};
21use crate::tool_memory::CapabilityMemory;
22
23/// Signature for a mock tool implementation.
24///
25/// Arguments:
26/// - `input` — the original user input text passed to the harness
27/// - `req`   — the parsed CapabilityRequest from the model
28///
29/// Returns a serialized JSON string on success, or an error description.
30pub type MockToolFn = fn(input: &str, req: &CapabilityRequest) -> Result<String, String>;
31
32/// Registry of hand-written mock tool implementations keyed by capability name.
33pub struct MockToolRegistry {
34    tools: HashMap<&'static str, MockToolFn>,
35}
36
37impl MockToolRegistry {
38    fn new() -> Self {
39        let mut tools: HashMap<&'static str, MockToolFn> = HashMap::new();
40        tools.insert("stream_parse_logs", mock_stream_parse_logs);
41        tools.insert("word_count", mock_word_count);
42        tools.insert("json_extract", mock_json_extract);
43        Self { tools }
44    }
45
46    pub fn get(&self, name: &str) -> Option<MockToolFn> {
47        self.tools.get(name).copied()
48    }
49
50    pub fn known_capabilities(&self) -> Vec<&'static str> {
51        let mut v: Vec<&'static str> = self.tools.keys().copied().collect();
52        v.sort_unstable();
53        v
54    }
55}
56
57/// The supervisor. Owns session state (budget accounting + capability memory).
58/// One Forge per session; each call to `handle` uses up budget.
59pub struct Forge {
60    budget: Budget,
61    state: PolicyState,
62    registry: MockToolRegistry,
63    pub memory: CapabilityMemory,
64    /// Immutable record of every decision made this session.
65    session_log: Vec<ToolRunReport>,
66}
67
68impl Forge {
69    pub fn new() -> Self {
70        Self {
71            budget: Budget::default(),
72            state: PolicyState::default(),
73            registry: MockToolRegistry::new(),
74            memory: CapabilityMemory::new(),
75            session_log: vec![],
76        }
77    }
78
79    pub fn with_budget(budget: Budget) -> Self {
80        Self {
81            budget,
82            state: PolicyState::default(),
83            registry: MockToolRegistry::new(),
84            memory: CapabilityMemory::new(),
85            session_log: vec![],
86        }
87    }
88
89    /// Immutable record of every ToolRunReport produced this session,
90    /// in order. Includes accepted and rejected requests.
91    pub fn audit(&self) -> &[ToolRunReport] {
92        &self.session_log
93    }
94
95    /// Process one CapabilityRequest from the model.
96    ///
97    /// Decision order:
98    ///   input validation → budget check → policy check → registry lookup
99    ///   → execute → memory update → audit log
100    ///
101    /// Every call is recorded in `self.session_log` regardless of outcome.
102    pub fn handle(&mut self, req: &CapabilityRequest, input: &str) -> ToolRunReport {
103        let report = self.handle_inner(req, input);
104        self.session_log.push(report.clone());
105        report
106    }
107
108    fn handle_inner(&mut self, req: &CapabilityRequest, input: &str) -> ToolRunReport {
109        // Validate forge input — reject malformed strings before any processing
110        if let Err(e) = input_validation::validate_forge_input(input) {
111            return rejected(vec![format!("input validation: {e}")]);
112        }
113
114        // Validate capability request fields — length limits on model-supplied strings
115        if let Err(e) = input_validation::validate_capability_fields(req) {
116            return rejected(vec![format!("capability field validation: {e}")]);
117        }
118
119        // Budget check — fail fast if session limits are already exhausted
120        if let Some(reason) = self.state.budget_exceeded(&self.budget) {
121            return rejected(vec![reason]);
122        }
123
124        // Static policy checks
125        let violations = policy::check_request(req);
126        if !violations.is_empty() {
127            return rejected(violations.into_iter().map(|v| v.detail).collect());
128        }
129
130        // Registry lookup — explicit allowlist of known capabilities
131        let mock_fn = match self.registry.get(&req.capability) {
132            Some(f) => f,
133            None => {
134                return rejected(vec![format!(
135                    "capability '{}' is not registered; known: {}",
136                    req.capability,
137                    self.registry.known_capabilities().join(", ")
138                )]);
139            }
140        };
141
142        // Execute
143        let start = Instant::now();
144        let exec_result = mock_fn(input, req);
145        let runtime_ms = start.elapsed().as_millis() as u64;
146
147        let (output_str, success) = match exec_result {
148            Ok(out) => (Some(out), true),
149            Err(e) => (Some(format!("{{\"error\":\"{e}\"}}")), false),
150        };
151
152        let metrics = ToolMetrics {
153            runtime_ms,
154            input_bytes: input.len(),
155            output_bytes: output_str.as_deref().map(|s| s.len()).unwrap_or(0),
156            success,
157        };
158
159        // Record against budget
160        self.state.record_run(&metrics);
161
162        // Update capability memory on success
163        let memory_update = if success {
164            let signature = CapabilityMemory::derive_signature(req);
165            let record = CapabilityMemoryRecord {
166                problem_signature: signature,
167                solution_pattern: format!("mock:{}", req.capability),
168                input_shape: shape_token(&req.input_contract),
169                output_shape: shape_token(&req.output_contract),
170                constraints: req.constraints.clone(),
171            };
172            self.memory.upsert(record.clone(), &metrics);
173            Some(record)
174        } else {
175            None
176        };
177
178        ToolRunReport {
179            accepted: true,
180            rejection_reasons: vec![],
181            verification_passed: true, // Phase 2: mocks are pre-verified by definition
182            executed: true,
183            output: output_str,
184            metrics,
185            destroyed: true, // lifecycle complete; no binary existed to destroy
186            memory_update,
187        }
188    }
189
190    pub fn tools_invoked(&self) -> usize {
191        self.state.tools_invoked
192    }
193}
194
195impl Default for Forge {
196    fn default() -> Self {
197        Self::new()
198    }
199}
200
201// ---------------------------------------------------------------------------
202// Helper
203// ---------------------------------------------------------------------------
204
205fn rejected(reasons: Vec<String>) -> ToolRunReport {
206    ToolRunReport {
207        accepted: false,
208        rejection_reasons: reasons,
209        verification_passed: false,
210        executed: false,
211        output: None,
212        metrics: ToolMetrics::default(),
213        destroyed: false,
214        memory_update: None,
215    }
216}
217
/// Condenses a free-text contract into a short `snake_case` shape token.
///
/// Only the first three whitespace-separated words are considered. Each is
/// lowercased and stripped of leading/trailing non-alphanumeric characters.
/// Words that become empty are dropped, and the rest are joined with `_`.
fn shape_token(contract: &str) -> String {
    let mut parts: Vec<String> = Vec::new();
    for word in contract.split_whitespace().take(3) {
        let lowered = word.to_lowercase();
        let core = lowered.trim_matches(|c: char| !c.is_alphanumeric());
        if !core.is_empty() {
            parts.push(core.to_owned());
        }
    }
    parts.join("_")
}
231
232// ---------------------------------------------------------------------------
233// Mock tool implementations
234// ---------------------------------------------------------------------------
235
236/// Counts input lines and lines matching common HTTP status code patterns.
237/// Returns JSON: {"total_lines": N, "status_counts": {"200": M, ...}}
238fn mock_stream_parse_logs(input: &str, _req: &CapabilityRequest) -> Result<String, String> {
239    let mut status_counts: HashMap<String, usize> = HashMap::new();
240    let mut total = 0usize;
241
242    for line in input.lines() {
243        total += 1;
244        // Very simple: look for a 3-digit HTTP status code after a space
245        let mut found = false;
246        for token in line.split_whitespace() {
247            if token.len() == 3 && token.chars().all(|c| c.is_ascii_digit()) {
248                let first = token.chars().next().unwrap();
249                if ('1'..='5').contains(&first) {
250                    *status_counts.entry(token.to_string()).or_insert(0) += 1;
251                    found = true;
252                    break;
253                }
254            }
255        }
256        if !found && !line.trim().is_empty() {
257            *status_counts.entry("unknown".to_string()).or_insert(0) += 1;
258        }
259    }
260
261    let counts_json: Vec<String> = status_counts
262        .iter()
263        .map(|(k, v)| format!("\"{k}\":{v}"))
264        .collect();
265
266    Ok(format!(
267        "{{\"total_lines\":{total},\"status_counts\":{{{}}},\"note\":\"mock:stream_parse_logs\"}}",
268        counts_json.join(",")
269    ))
270}
271
272/// Counts words, lines, and characters in the input text.
273/// Returns JSON: {"word_count": N, "line_count": N, "char_count": N}
274fn mock_word_count(input: &str, _req: &CapabilityRequest) -> Result<String, String> {
275    let word_count = input.split_whitespace().count();
276    let line_count = input.lines().count();
277    let char_count = input.chars().count();
278    Ok(format!(
279        "{{\"word_count\":{word_count},\"line_count\":{line_count},\"char_count\":{char_count},\"note\":\"mock:word_count\"}}"
280    ))
281}
282
283/// Parses input as JSON and returns the top-level field names.
284/// Returns JSON: {"fields": ["a", "b", ...]}
285fn mock_json_extract(input: &str, _req: &CapabilityRequest) -> Result<String, String> {
286    let parsed: serde_json::Value =
287        serde_json::from_str(input).map_err(|e| format!("json parse error: {e}"))?;
288
289    let fields: Vec<String> = match &parsed {
290        serde_json::Value::Object(map) => map.keys().map(|k| format!("\"{k}\"")).collect(),
291        _ => return Err("input must be a JSON object".into()),
292    };
293
294    Ok(format!(
295        "{{\"fields\":[{}],\"note\":\"mock:json_extract\"}}",
296        fields.join(",")
297    ))
298}
299
300// ---------------------------------------------------------------------------
301// Tests
302// ---------------------------------------------------------------------------
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::capability::CapabilityConstraints;

    /// Builds a request for `capability` that satisfies every static policy check.
    fn clean_req(capability: &str) -> CapabilityRequest {
        CapabilityRequest {
            kind: "capability_request".into(),
            capability: capability.into(),
            input_contract: "utf8 text".into(),
            output_contract: "json object".into(),
            constraints: CapabilityConstraints::default(),
            reason: "text reasoning cannot efficiently process this".into(),
        }
    }

    /// Sends a single clean request for `capability` through a fresh forge.
    fn run_once(capability: &str, input: &str) -> ToolRunReport {
        let mut forge = Forge::new();
        forge.handle(&clean_req(capability), input)
    }

    /// A budget that allows exactly one tool run per session.
    fn one_tool_budget() -> Budget {
        Budget {
            max_tools_per_session: 1,
            ..Budget::default()
        }
    }

    /// A clean `word_count` request with the `no_network` constraint lifted.
    fn networked_req() -> CapabilityRequest {
        let mut req = clean_req("word_count");
        req.constraints.no_network = false;
        req
    }

    #[test]
    fn accepts_registered_capability_and_executes() {
        let report = run_once("word_count", "hello world\nsecond line");
        assert!(report.accepted, "should be accepted");
        assert!(report.executed, "should have executed");
        assert!(report.destroyed, "lifecycle must be marked complete");
        assert!(report.rejection_reasons.is_empty());
        let output = report.output.expect("accepted run must produce output");
        assert!(output.contains("word_count"));
    }

    #[test]
    fn rejects_unknown_capability() {
        let report = run_once("nonexistent_tool", "input");
        assert!(!report.accepted);
        assert!(!report.executed);
        assert!(report.rejection_reasons[0].contains("not registered"));
    }

    #[test]
    fn rejects_network_access_request() {
        let report = Forge::new().handle(&networked_req(), "input");
        assert!(!report.accepted);
        let mentions_network = report
            .rejection_reasons
            .iter()
            .any(|reason| reason.contains("no_network"));
        assert!(mentions_network);
    }

    #[test]
    fn rejects_when_budget_exhausted() {
        let mut forge = Forge::with_budget(one_tool_budget());
        let req = clean_req("word_count");
        // The first run consumes the only slot in the budget...
        forge.handle(&req, "hello");
        // ...so the second one must be refused.
        let report = forge.handle(&req, "hello");
        assert!(!report.accepted);
        assert!(report.rejection_reasons[0].contains("session tool limit"));
    }

    #[test]
    fn updates_memory_on_success() {
        let mut forge = Forge::new();
        let report = forge.handle(&clean_req("word_count"), "hello world");
        assert!(report.memory_update.is_some());
        assert!(!forge.memory.is_empty());
    }

    #[test]
    fn mock_stream_parse_logs_counts_status_codes() {
        let input = [
            "127.0.0.1 - - [01/Jan/2025] \"GET / HTTP/1.1\" 200 1234",
            "127.0.0.1 - - [01/Jan/2025] \"GET /missing HTTP/1.1\" 404 0",
            "127.0.0.1 - - [01/Jan/2025] \"POST /api HTTP/1.1\" 200 500",
        ]
        .join("\n");
        let report = run_once("stream_parse_logs", &input);
        assert!(report.accepted);
        let output = report.output.expect("accepted run must produce output");
        for needle in ["\"200\"", "\"404\"", "total_lines"] {
            assert!(output.contains(needle));
        }
    }

    #[test]
    fn mock_word_count_correct_counts() {
        let report = run_once("word_count", "hello world\nthird word here");
        assert!(report.accepted);
        let output = report.output.expect("accepted run must produce output");
        // Five whitespace-separated words across two lines.
        assert!(output.contains("\"word_count\":5"), "got: {output}");
        assert!(output.contains("\"line_count\":2"), "got: {output}");
    }

    #[test]
    fn mock_json_extract_returns_field_names() {
        let report = run_once("json_extract", r#"{"alpha": 1, "beta": "two"}"#);
        assert!(report.accepted);
        let output = report.output.expect("accepted run must produce output");
        assert!(
            output.contains("alpha") && output.contains("beta"),
            "got: {output}"
        );
    }

    #[test]
    fn mock_json_extract_error_on_non_object() {
        let report = run_once("json_extract", "[1, 2, 3]");
        assert!(report.accepted, "accepted — mock ran to completion");
        assert!(!report.metrics.success, "but execution failed");
    }

    #[test]
    fn budget_tracks_multiple_runs() {
        let mut forge = Forge::new();
        for input in ["a", "b"] {
            forge.handle(&clean_req("word_count"), input);
        }
        assert_eq!(forge.tools_invoked(), 2);
    }

    #[test]
    fn memory_accumulates_across_runs() {
        let mut forge = Forge::new();
        for input in ["first input", "second input"] {
            forge.handle(&clean_req("word_count"), input);
        }
        // Both runs share a signature, so they collapse into one entry.
        assert_eq!(forge.memory.len(), 1);
        let signature = CapabilityMemory::derive_signature(&clean_req("word_count"));
        let entry = forge.memory.lookup(&signature).unwrap();
        assert_eq!(entry.metrics.runs, 2);
    }

    // --- Input validation at the forge boundary ---

    #[test]
    fn forge_rejects_oversized_input() {
        let oversized = "x".repeat(crate::input_validation::MAX_FORGE_INPUT_BYTES + 1);
        let report = run_once("word_count", &oversized);
        assert!(!report.accepted);
        assert!(report.rejection_reasons[0].contains("input validation"));
    }

    #[test]
    fn forge_rejects_null_byte_in_input() {
        let report = run_once("word_count", "good\x00bad");
        assert!(!report.accepted);
        assert!(report.rejection_reasons[0].contains("input validation"));
    }

    #[test]
    fn forge_rejects_oversized_capability_name() {
        let mut req = clean_req("word_count");
        req.capability = "x".repeat(crate::input_validation::MAX_CAPABILITY_NAME_BYTES + 1);
        let report = Forge::new().handle(&req, "hello");
        assert!(!report.accepted);
        assert!(report.rejection_reasons[0].contains("capability field validation"));
    }

    // --- Session audit log ---

    #[test]
    fn session_log_records_all_calls() {
        let mut forge = Forge::new();
        forge.handle(&clean_req("word_count"), "a");
        forge.handle(&clean_req("nonexistent"), "b");
        assert_eq!(forge.audit().len(), 2, "both calls must appear in audit log");
    }

    #[test]
    fn session_log_records_rejections() {
        let mut forge = Forge::new();
        forge.handle(&networked_req(), "input");
        let log = forge.audit();
        assert_eq!(log.len(), 1);
        assert!(!log[0].accepted, "rejected call must be in audit log");
    }

    // --- Idempotent repeated calls ---

    #[test]
    fn repeated_calls_same_input_produce_same_output() {
        let mut forge = Forge::new();
        let req = clean_req("word_count");
        let first = forge.handle(&req, "hello world");
        let second = forge.handle(&req, "hello world");
        // Mocks are deterministic, so the outputs must match exactly.
        assert_eq!(first.output, second.output, "mock tools must be deterministic");
    }

    // --- Failure recovery after exception ---

    #[test]
    fn failure_recovery_bad_then_good_input() {
        let mut forge = Forge::new();

        // A JSON array is valid JSON but not an object, so json_extract fails.
        let failed = forge.handle(&clean_req("json_extract"), "[1, 2, 3]");
        assert!(failed.accepted, "accepted — mock ran to completion");
        assert!(!failed.metrics.success, "but execution failed (not an object)");

        // A different tool must still succeed afterwards: no corrupted state.
        let recovered = forge.handle(&clean_req("word_count"), "hello world");
        assert!(recovered.accepted);
        assert!(
            recovered.metrics.success,
            "word_count should succeed after json_extract failed"
        );
    }

    // --- Shared state isolation between Forge instances ---

    #[test]
    fn two_forge_instances_do_not_share_memory() {
        let mut forge_a = Forge::new();
        let forge_b = Forge::new();

        forge_a.handle(&clean_req("word_count"), "a");
        assert_eq!(forge_a.memory.len(), 1, "forge_a should have 1 memory entry");
        assert_eq!(forge_b.memory.len(), 0, "forge_b memory must be independent");
    }

    #[test]
    fn two_forge_instances_do_not_share_budget() {
        let mut forge_a = Forge::with_budget(one_tool_budget());
        let mut forge_b = Forge::with_budget(one_tool_budget());

        // Use up forge_a's single slot, then confirm it is exhausted.
        forge_a.handle(&clean_req("word_count"), "a");
        let refused = forge_a.handle(&clean_req("word_count"), "b");
        assert!(!refused.accepted, "forge_a should be exhausted");

        // forge_b still has its own untouched budget.
        let allowed = forge_b.handle(&clean_req("word_count"), "c");
        assert!(allowed.accepted, "forge_b budget must be independent of forge_a");
    }
}