Skip to main content

split_brain_harness/
static_analysis.rs

//! Static analysis of generated Rust source code.
//!
//! Scans source text for forbidden patterns without compiling it.
//! Fast, deterministic, and free of subprocess calls.
5use serde::{Deserialize, Serialize};
6
7// ---------------------------------------------------------------------------
8// Forbidden patterns
9// ---------------------------------------------------------------------------
10
/// Forbidden `(kind, pattern)` pairs.
///
/// `kind` is a category label copied into each reported [`StaticViolation`];
/// `pattern` is the literal text that [`check`] searches for. Patterns such as
/// `fs::remove_file(` and `env::set_var(` deliberately omit the `std::` prefix
/// so they still match after a `use std::fs;` or `use std::env;` import.
const FORBIDDEN: &[(&str, &str)] = &[
    // Process spawning
    ("process_spawn", "process::Command"),
    ("process_spawn", "Command::new("),
    // Filesystem writes and other mutations (delete, rename, copy, mkdir, chmod)
    ("filesystem_write", "fs::write("),
    ("filesystem_write", "File::create("),
    ("filesystem_write", "OpenOptions"),
    ("filesystem_write", ".write_all("),
    ("filesystem_write", "fs::remove_file("),
    ("filesystem_write", "fs::remove_dir"), // also covers `remove_dir_all`
    ("filesystem_write", "fs::rename("),
    ("filesystem_write", "fs::copy("),
    ("filesystem_write", "fs::create_dir"), // also covers `create_dir_all`
    ("filesystem_write", "fs::hard_link("),
    ("filesystem_write", "fs::set_permissions("),
    // Network access
    ("network_access", "std::net::"),
    ("network_access", "TcpStream"),
    ("network_access", "UdpSocket"),
    ("network_access", "reqwest"),
    ("network_access", "ureq::"),
    ("network_access", "hyper::"),
    ("network_access", "tokio::net"),
    // Unsafe code
    ("unsafe_code", "unsafe {"),
    ("unsafe_code", "unsafe fn "),
    ("unsafe_code", "unsafe impl "),
    ("unsafe_code", "unsafe trait "),
    // Covers `unsafe extern "C" fn` items and edition-2024 `unsafe extern`
    // blocks, whose `safe fn` items are callable without an `unsafe {}` block.
    ("unsafe_code", "unsafe extern"),
    // Environment access (reads and mutations)
    ("env_access", "std::env::"),
    ("env_access", "env::var("),
    ("env_access", "env::var_os("),
    ("env_access", "env::vars("),
    ("env_access", "env::vars_os("),
    ("env_access", "env::args("),
    ("env_access", "env::args_os("),
    ("env_access", "env::set_var("),
    ("env_access", "env::remove_var("),
    ("env_access", "env::set_current_dir("),
    // External crate usage — stdlib only is permitted
    ("external_crate", "serde_json"),
    ("external_crate", "serde::"),
    ("external_crate", "tokio::"),
    ("external_crate", "anyhow::"),
    ("external_crate", "thiserror::"),
    ("external_crate", "regex::"),
    ("external_crate", "chrono::"),
    ("external_crate", "rand::"),
    ("external_crate", "uuid::"),
    ("external_crate", "base64::"),
];
50
51// ---------------------------------------------------------------------------
52// Report types
53// ---------------------------------------------------------------------------
54
55#[derive(Debug, Serialize, Deserialize, Clone)]
56pub struct StaticViolation {
57    pub kind: String,
58    pub pattern: String,
59    pub line: usize,
60}
61
62#[derive(Debug, Serialize, Deserialize, Clone)]
63pub struct StaticAnalysisReport {
64    pub passed: bool,
65    pub violations: Vec<StaticViolation>,
66}
67
68// ---------------------------------------------------------------------------
69// Public API
70// ---------------------------------------------------------------------------
71
72/// Scan `source` for forbidden patterns. Returns a report of all violations.
73/// An empty violations list means the source passed all checks.
74pub fn check(source: &str) -> StaticAnalysisReport {
75    let mut violations: Vec<StaticViolation> = vec![];
76
77    for (line_num, line) in source.lines().enumerate() {
78        // Skip comment lines — // and /// are analysis noise
79        let trimmed = line.trim();
80        if trimmed.starts_with("//") {
81            continue;
82        }
83        for (kind, pattern) in FORBIDDEN {
84            if line.contains(pattern) {
85                violations.push(StaticViolation {
86                    kind: (*kind).to_string(),
87                    pattern: (*pattern).to_string(),
88                    line: line_num + 1,
89                });
90            }
91        }
92    }
93
94    StaticAnalysisReport {
95        passed: violations.is_empty(),
96        violations,
97    }
98}
99
100/// True if the source contains at least one `#[test]` function.
101pub fn has_tests(source: &str) -> bool {
102    source.contains("#[test]")
103}
104
105/// Count the number of `#[test]` occurrences in the source.
106pub fn test_count(source: &str) -> usize {
107    source.matches("#[test]").count()
108}
109
110// ---------------------------------------------------------------------------
111// Tests
112// ---------------------------------------------------------------------------
113
#[cfg(test)]
mod tests {
    // Regression tests for `check`, `has_tests`, and `test_count`. Every
    // snippet below is scanned as plain text; none of it is compiled.
    use super::*;

    const CLEAN_SOURCE: &str = r#"
pub fn run(input: &str) -> Result<String, String> {
    let words = input.split_whitespace().count();
    Ok(format!("{{\"word_count\":{}}}", words))
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn counts_words() {
        let r = run("hello world").unwrap();
        assert!(r.contains("2"));
    }
    #[test]
    fn empty_input() {
        let r = run("").unwrap();
        assert!(r.contains("0"));
    }
}
"#;

    #[test]
    fn clean_source_passes() {
        let report = check(CLEAN_SOURCE);
        assert!(
            report.passed,
            "clean source should pass: {:?}",
            report.violations
        );
    }

    #[test]
    fn process_command_detected() {
        let source = r#"fn run(i: &str) -> Result<String, String> {
    let _ = std::process::Command::new("ls").output();
    Ok("ok".into())
}"#;
        let report = check(source);
        assert!(!report.passed);
        assert!(report.violations.iter().any(|v| v.kind == "process_spawn"));
    }

    #[test]
    fn command_new_shorthand_detected() {
        let source = r#"fn run(i: &str) -> Result<String, String> {
    Command::new("ls");
    Ok("ok".into())
}"#;
        let report = check(source);
        assert!(!report.passed);
        assert!(report
            .violations
            .iter()
            .any(|v| v.pattern == "Command::new("));
    }

    #[test]
    fn fs_write_detected() {
        let source = r#"fn run(i: &str) -> Result<String, String> {
    std::fs::write("out.txt", i).unwrap();
    Ok("ok".into())
}"#;
        let report = check(source);
        assert!(!report.passed);
        assert!(report
            .violations
            .iter()
            .any(|v| v.kind == "filesystem_write"));
    }

    #[test]
    fn file_create_detected() {
        let source = r#"fn run(i: &str) -> Result<String, String> {
    let f = File::create("out.txt").unwrap();
    Ok("ok".into())
}"#;
        let report = check(source);
        assert!(!report.passed);
        assert!(report
            .violations
            .iter()
            .any(|v| v.pattern == "File::create("));
    }

    #[test]
    fn tcpstream_detected() {
        let source = r#"fn run(i: &str) -> Result<String, String> {
    let s = TcpStream::connect("127.0.0.1:80").unwrap();
    Ok("ok".into())
}"#;
        let report = check(source);
        assert!(!report.passed);
        assert!(report.violations.iter().any(|v| v.kind == "network_access"));
    }

    #[test]
    fn reqwest_detected() {
        let source = r#"fn run(i: &str) -> Result<String, String> {
    reqwest::get("https://example.com");
    Ok("ok".into())
}"#;
        let report = check(source);
        assert!(!report.passed);
        assert!(report.violations.iter().any(|v| v.pattern == "reqwest"));
    }

    #[test]
    fn unsafe_block_detected() {
        let source = r#"fn run(i: &str) -> Result<String, String> {
    unsafe { let _ = 0; }
    Ok("ok".into())
}"#;
        let report = check(source);
        assert!(!report.passed);
        assert!(report.violations.iter().any(|v| v.kind == "unsafe_code"));
    }

    #[test]
    fn unsafe_fn_detected() {
        let source = "unsafe fn run(i: &str) -> Result<String, String> { Ok(\"ok\".into()) }";
        let report = check(source);
        assert!(!report.passed);
        assert!(report.violations.iter().any(|v| v.pattern == "unsafe fn "));
    }

    #[test]
    fn env_var_detected() {
        let source = r#"fn run(i: &str) -> Result<String, String> {
    let k = std::env::var("SECRET").unwrap();
    Ok(k)
}"#;
        let report = check(source);
        assert!(!report.passed);
        assert!(report.violations.iter().any(|v| v.kind == "env_access"));
    }

    #[test]
    fn comment_lines_are_skipped() {
        // A comment mentioning a forbidden pattern should NOT fire
        let source = r#"fn run(i: &str) -> Result<String, String> {
    // do NOT use std::process::Command here
    Ok("ok".into())
}"#;
        let report = check(source);
        assert!(report.passed, "comments must not trigger violations");
    }

    #[test]
    fn violation_line_number_is_accurate() {
        let source =
            "fn run(i: &str) -> Result<String, String> {\n    unsafe { }\n    Ok(\"ok\".into())\n}";
        let report = check(source);
        let unsafe_v = report
            .violations
            .iter()
            .find(|v| v.kind == "unsafe_code")
            .expect("the unsafe block on line 2 must be reported");
        assert_eq!(unsafe_v.line, 2, "violation line should be 2");
    }

    #[test]
    fn has_tests_true_when_test_attribute_present() {
        assert!(has_tests(CLEAN_SOURCE));
    }

    #[test]
    fn has_tests_false_when_no_test_attribute() {
        let source = "fn run(i: &str) -> Result<String, String> { Ok(\"ok\".into()) }";
        assert!(!has_tests(source));
    }

    #[test]
    fn test_count_correct() {
        assert_eq!(test_count(CLEAN_SOURCE), 2);
    }

    #[test]
    fn multiple_violations_all_reported() {
        let source = r#"fn run(i: &str) -> Result<String, String> {
    unsafe { }
    let _ = TcpStream::connect("x").unwrap();
    Ok("ok".into())
}"#;
        let report = check(source);
        assert!(!report.passed);
        // Each distinct violation must be reported with its own kind and line,
        // not merely "at least two of something".
        let found: Vec<(&str, usize)> = report
            .violations
            .iter()
            .map(|v| (v.kind.as_str(), v.line))
            .collect();
        assert!(found.contains(&("unsafe_code", 2)), "{:?}", found);
        assert!(found.contains(&("network_access", 3)), "{:?}", found);
    }
}