1use serde::{Deserialize, Serialize};
6
/// Forbidden source substrings, grouped by violation category.
///
/// Each entry is `(kind, pattern)`: `kind` is copied into
/// `StaticViolation::kind` and `pattern` is searched for in every
/// non-comment line of the scanned source by `check`. Entries are matched
/// independently, so one line can yield several violations — for example
/// `std::process::Command::new(` hits both `process_spawn` patterns.
/// Matching is substring-based and does not resolve imports or aliases, so
/// this is a heuristic gate rather than a proof of safety.
const FORBIDDEN: &[(&str, &str)] = &[
    // Spawning child processes.
    ("process_spawn", "process::Command"),
    ("process_spawn", "Command::new("),
    // Creating, modifying, moving or deleting files and directories.
    ("filesystem_write", "fs::write("),
    ("filesystem_write", "File::create("),
    ("filesystem_write", "OpenOptions"),
    ("filesystem_write", ".write_all("),
    // `File::options()` returns an `OpenOptions` without naming the type.
    ("filesystem_write", "File::options("),
    ("filesystem_write", "fs::remove_file("),
    // Prefix also covers `fs::remove_dir_all(`.
    ("filesystem_write", "fs::remove_dir"),
    // Prefix also covers `fs::create_dir_all(`.
    ("filesystem_write", "fs::create_dir"),
    ("filesystem_write", "DirBuilder"),
    ("filesystem_write", "fs::rename("),
    ("filesystem_write", "fs::copy("),
    ("filesystem_write", "fs::hard_link("),
    ("filesystem_write", "fs::set_permissions("),
    // Sockets and HTTP clients.
    ("network_access", "std::net::"),
    ("network_access", "TcpStream"),
    ("network_access", "UdpSocket"),
    ("network_access", "reqwest"),
    ("network_access", "ureq::"),
    ("network_access", "hyper::"),
    ("network_access", "tokio::net"),
    // Any form of `unsafe`.
    ("unsafe_code", "unsafe {"),
    ("unsafe_code", "unsafe fn "),
    ("unsafe_code", "unsafe impl "),
    // Reading the process environment or command line.
    ("env_access", "std::env::"),
    ("env_access", "env::var("),
    ("env_access", "env::args("),
    // Crates outside the standard library.
    ("external_crate", "serde_json"),
    ("external_crate", "serde::"),
    ("external_crate", "tokio::"),
    ("external_crate", "anyhow::"),
    ("external_crate", "thiserror::"),
    ("external_crate", "regex::"),
    ("external_crate", "chrono::"),
    ("external_crate", "rand::"),
    ("external_crate", "uuid::"),
    ("external_crate", "base64::"),
];
50
51#[derive(Debug, Serialize, Deserialize, Clone)]
56pub struct StaticViolation {
57 pub kind: String,
58 pub pattern: String,
59 pub line: usize,
60}
61
62#[derive(Debug, Serialize, Deserialize, Clone)]
63pub struct StaticAnalysisReport {
64 pub passed: bool,
65 pub violations: Vec<StaticViolation>,
66}
67
68pub fn check(source: &str) -> StaticAnalysisReport {
75 let mut violations: Vec<StaticViolation> = vec![];
76
77 for (line_num, line) in source.lines().enumerate() {
78 let trimmed = line.trim();
80 if trimmed.starts_with("//") {
81 continue;
82 }
83 for (kind, pattern) in FORBIDDEN {
84 if line.contains(pattern) {
85 violations.push(StaticViolation {
86 kind: (*kind).to_string(),
87 pattern: (*pattern).to_string(),
88 line: line_num + 1,
89 });
90 }
91 }
92 }
93
94 StaticAnalysisReport {
95 passed: violations.is_empty(),
96 violations,
97 }
98}
99
/// Reports whether `source` contains at least one `#[test]` attribute.
///
/// Always agrees with `test_count(source) > 0`: lines whose first
/// non-whitespace characters are `//` are ignored, so a commented-out
/// `// #[test]` does not count as a test.
pub fn has_tests(source: &str) -> bool {
    code_lines(source).any(|line| line.contains("#[test]"))
}

/// Counts the `#[test]` attributes in `source`.
///
/// Lines whose first non-whitespace characters are `//` (line and doc
/// comments) are skipped, matching the comment rule used by `check`. Block
/// comments and string literals are not recognised, so a `#[test]` inside
/// them is still counted.
pub fn test_count(source: &str) -> usize {
    code_lines(source)
        .map(|line| line.matches("#[test]").count())
        .sum()
}

/// Yields the lines of `source` that are not `//` line comments.
fn code_lines(source: &str) -> impl Iterator<Item = &str> {
    source
        .lines()
        .filter(|line| !line.trim_start().starts_with("//"))
}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-behaved tool: pure string processing plus two unit tests.
    const CLEAN_SOURCE: &str = r#"
pub fn run(input: &str) -> Result<String, String> {
    let words = input.split_whitespace().count();
    Ok(format!("{{\"word_count\":{}}}", words))
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn counts_words() {
        let r = run("hello world").unwrap();
        assert!(r.contains("2"));
    }
    #[test]
    fn empty_input() {
        let r = run("").unwrap();
        assert!(r.contains("0"));
    }
}
"#;

    /// `true` when any violation in `report` belongs to category `kind`.
    fn reports_kind(report: &StaticAnalysisReport, kind: &str) -> bool {
        report.violations.iter().any(|v| v.kind == kind)
    }

    /// `true` when any violation in `report` was raised by exactly `pattern`.
    fn reports_pattern(report: &StaticAnalysisReport, pattern: &str) -> bool {
        report.violations.iter().any(|v| v.pattern == pattern)
    }

    #[test]
    fn clean_source_passes() {
        let outcome = check(CLEAN_SOURCE);
        assert!(
            outcome.passed,
            "clean source should pass: {:?}",
            outcome.violations
        );
    }

    #[test]
    fn process_command_detected() {
        let outcome = check(
            r#"fn run(i: &str) -> Result<String, String> {
    let _ = std::process::Command::new("ls").output();
    Ok("ok".into())
}"#,
        );
        assert!(!outcome.passed);
        assert!(reports_kind(&outcome, "process_spawn"));
    }

    #[test]
    fn command_new_shorthand_detected() {
        let outcome = check(
            r#"fn run(i: &str) -> Result<String, String> {
    Command::new("ls");
    Ok("ok".into())
}"#,
        );
        assert!(reports_pattern(&outcome, "Command::new("));
    }

    #[test]
    fn fs_write_detected() {
        let outcome = check(
            r#"fn run(i: &str) -> Result<String, String> {
    std::fs::write("out.txt", i).unwrap();
    Ok("ok".into())
}"#,
        );
        assert!(!outcome.passed);
        assert!(reports_kind(&outcome, "filesystem_write"));
    }

    #[test]
    fn file_create_detected() {
        let outcome = check(
            r#"fn run(i: &str) -> Result<String, String> {
    let f = File::create("out.txt").unwrap();
    Ok("ok".into())
}"#,
        );
        assert!(reports_pattern(&outcome, "File::create("));
    }

    #[test]
    fn tcpstream_detected() {
        let outcome = check(
            r#"fn run(i: &str) -> Result<String, String> {
    let s = TcpStream::connect("127.0.0.1:80").unwrap();
    Ok("ok".into())
}"#,
        );
        assert!(reports_kind(&outcome, "network_access"));
    }

    #[test]
    fn reqwest_detected() {
        let outcome = check(
            r#"fn run(i: &str) -> Result<String, String> {
    reqwest::get("https://example.com");
    Ok("ok".into())
}"#,
        );
        assert!(reports_pattern(&outcome, "reqwest"));
    }

    #[test]
    fn unsafe_block_detected() {
        let outcome = check(
            r#"fn run(i: &str) -> Result<String, String> {
    unsafe { let _ = 0; }
    Ok("ok".into())
}"#,
        );
        assert!(reports_kind(&outcome, "unsafe_code"));
    }

    #[test]
    fn unsafe_fn_detected() {
        let outcome =
            check("unsafe fn run(i: &str) -> Result<String, String> { Ok(\"ok\".into()) }");
        assert!(reports_pattern(&outcome, "unsafe fn "));
    }

    #[test]
    fn env_var_detected() {
        let outcome = check(
            r#"fn run(i: &str) -> Result<String, String> {
    let k = std::env::var("SECRET").unwrap();
    Ok(k)
}"#,
        );
        assert!(reports_kind(&outcome, "env_access"));
    }

    #[test]
    fn comment_lines_are_skipped() {
        let outcome = check(
            r#"fn run(i: &str) -> Result<String, String> {
    // do NOT use std::process::Command here
    Ok("ok".into())
}"#,
        );
        assert!(outcome.passed, "comments must not trigger violations");
    }

    #[test]
    fn violation_line_number_is_accurate() {
        let outcome =
            check("fn run(i: &str) -> Result<String, String> {\n    unsafe { }\n    Ok(\"ok\".into())\n}");
        let first_unsafe_line = outcome
            .violations
            .iter()
            .find(|v| v.kind == "unsafe_code")
            .map(|v| v.line);
        assert_eq!(first_unsafe_line, Some(2), "violation line should be 2");
    }

    #[test]
    fn has_tests_true_when_test_attribute_present() {
        assert!(has_tests(CLEAN_SOURCE));
    }

    #[test]
    fn has_tests_false_when_no_test_attribute() {
        assert!(!has_tests(
            "fn run(i: &str) -> Result<String, String> { Ok(\"ok\".into()) }"
        ));
    }

    #[test]
    fn test_count_correct() {
        assert_eq!(test_count(CLEAN_SOURCE), 2);
    }

    #[test]
    fn multiple_violations_all_reported() {
        let outcome = check(
            r#"fn run(i: &str) -> Result<String, String> {
    unsafe { }
    let _ = TcpStream::connect("x").unwrap();
    Ok("ok".into())
}"#,
        );
        assert!(!outcome.passed);
        assert!(outcome.violations.len() >= 2);
    }
}
299}