1use std::time::Instant;
8
9use anyhow::Result;
10use regex::Regex;
11
12use crate::config::RuleSettings;
13
14use super::judge::JudgmentProvider;
15use super::types::{JudgmentDecision, JudgmentRequest, JudgmentResult};
16
/// Operation and target scraped from an approval prompt on the terminal screen.
///
/// Both fields are `None` when no recognised prompt was found, and both are
/// `Some` when one was.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct ParsedContext {
    /// Operation name as displayed in the prompt (e.g. `Read`, `Bash`, `WebFetch`);
    /// original case is kept, callers lowercase it for comparison.
    operation: Option<String>,
    /// What the operation acts on (a path, URL, or shell command line), trimmed.
    target: Option<String>,
}
24
25pub struct RuleEngine {
27 settings: RuleSettings,
28 allow_patterns: Vec<Regex>,
30}
31
32impl RuleEngine {
33 pub fn new(settings: RuleSettings) -> Self {
35 let allow_patterns = settings
36 .allow_patterns
37 .iter()
38 .filter_map(|p| match Regex::new(p) {
39 Ok(r) => Some(r),
40 Err(e) => {
41 tracing::warn!(pattern = %p, "Invalid allow_pattern regex: {}", e);
42 None
43 }
44 })
45 .collect();
46
47 Self {
48 settings,
49 allow_patterns,
50 }
51 }
52
53 fn parse_context(screen_context: &str) -> ParsedContext {
55 let last_lines: Vec<&str> = screen_context.lines().rev().take(15).collect();
62 let search_text: String = last_lines.into_iter().rev().collect::<Vec<_>>().join("\n");
63
64 let access_re = Regex::new(r"(?i)Allow\s+(\w+)\s+access\s+to\s+(.+)").expect("valid regex");
66 if let Some(caps) = access_re.captures(&search_text) {
67 return ParsedContext {
68 operation: Some(caps[1].to_string()),
69 target: Some(caps[2].trim().to_string()),
70 };
71 }
72
73 let colon_re = Regex::new(r"(?i)Allow\s+([\w\s]+?):\s+(.+)").expect("valid regex");
75 if let Some(caps) = colon_re.captures(&search_text) {
76 return ParsedContext {
77 operation: Some(caps[1].trim().to_string()),
78 target: Some(caps[2].trim().to_string()),
79 };
80 }
81
82 ParsedContext {
83 operation: None,
84 target: None,
85 }
86 }
87
88 fn check_allow(
90 &self,
91 screen_context: &str,
92 operation: Option<&str>,
93 target: Option<&str>,
94 ) -> Option<String> {
95 for (i, pattern) in self.allow_patterns.iter().enumerate() {
97 if pattern.is_match(screen_context) {
98 return Some(format!(
99 "allow_pattern[{}]: {}",
100 i, self.settings.allow_patterns[i]
101 ));
102 }
103 }
104
105 let op = operation.unwrap_or("").to_lowercase();
106 let tgt = target.unwrap_or("").to_lowercase();
107
108 if self.settings.allow_read {
110 if op == "read" {
111 return Some("allow_read: Read access".to_string());
112 }
113 let read_commands = [
114 "cat ", "head ", "tail ", "less ", "ls ", "find ", "grep ", "wc ",
115 ];
116 if op == "bash" {
117 for cmd in &read_commands {
118 if tgt.starts_with(cmd) || tgt.contains(&format!(" | {}", cmd)) {
119 return Some(format!("allow_read: {}", cmd.trim()));
120 }
121 }
122 }
123 }
124
125 if self.settings.allow_tests && op == "bash" {
127 let test_commands = [
128 "cargo test",
129 "npm test",
130 "npm run test",
131 "npx jest",
132 "npx vitest",
133 "pytest",
134 "python -m pytest",
135 "go test",
136 "dotnet test",
137 "mvn test",
138 "gradle test",
139 ];
140 for cmd in &test_commands {
141 if tgt.starts_with(cmd) || tgt.contains(&format!("&& {}", cmd)) {
142 return Some(format!("allow_tests: {}", cmd));
143 }
144 }
145 }
146
147 if self.settings.allow_fetch {
149 if op == "webfetch" || op == "websearch" {
150 return Some(format!("allow_fetch: {}", op));
151 }
152 if op == "bash"
154 && tgt.starts_with("curl ")
155 && !tgt.contains("-x post")
156 && !tgt.contains("--data")
157 && !tgt.contains(" -d ")
158 {
159 return Some("allow_fetch: curl GET".to_string());
160 }
161 }
162
163 if self.settings.allow_git_readonly && op == "bash" {
165 let git_readonly = [
166 "git status",
167 "git log",
168 "git diff",
169 "git branch",
170 "git show",
171 "git blame",
172 "git stash list",
173 "git remote -v",
174 "git tag",
175 "git rev-parse",
176 "git ls-files",
177 "git ls-tree",
178 ];
179 for cmd in &git_readonly {
180 if tgt.starts_with(cmd) {
181 return Some(format!("allow_git_readonly: {}", cmd));
182 }
183 }
184 }
185
186 if self.settings.allow_format_lint && op == "bash" {
188 let fmt_commands = [
189 "cargo fmt",
190 "cargo clippy",
191 "prettier",
192 "eslint",
193 "rustfmt",
194 "black ",
195 "isort ",
196 "gofmt",
197 "go fmt",
198 "biome ",
199 "deno fmt",
200 "deno lint",
201 ];
202 for cmd in &fmt_commands {
203 if tgt.starts_with(cmd) || tgt.contains(&format!("npx {}", cmd)) {
204 return Some(format!("allow_format_lint: {}", cmd.trim()));
205 }
206 }
207 }
208
209 None
210 }
211}
212
213impl JudgmentProvider for RuleEngine {
214 async fn judge(&self, request: &JudgmentRequest) -> Result<JudgmentResult> {
220 let start = Instant::now();
221 let parsed = Self::parse_context(&request.screen_context);
222
223 if let Some(rule) = self.check_allow(
225 &request.screen_context,
226 parsed.operation.as_deref(),
227 parsed.target.as_deref(),
228 ) {
229 return Ok(JudgmentResult {
230 decision: JudgmentDecision::Approve,
231 reasoning: format!("Allowed by rule: {}", rule),
232 model: format!("rules:{}", rule.split(':').next().unwrap_or("allow")),
233 elapsed_ms: start.elapsed().as_millis() as u64,
234 usage: None,
235 });
236 }
237
238 Ok(JudgmentResult {
240 decision: JudgmentDecision::Uncertain,
241 reasoning: "No matching allow rule".to_string(),
242 model: "rules:abstain".to_string(),
243 elapsed_ms: start.elapsed().as_millis() as u64,
244 usage: None,
245 })
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine configured with `RuleSettings::default()`.
    fn default_engine() -> RuleEngine {
        RuleEngine::new(RuleSettings::default())
    }

    /// Minimal shell-command request; only `screen_context` varies between tests.
    fn request_with_context(screen_context: &str) -> JudgmentRequest {
        JudgmentRequest {
            target: "test:0.1".into(),
            approval_type: "shell_command".into(),
            details: String::new(),
            screen_context: screen_context.into(),
            cwd: "/tmp/project".into(),
            agent_type: "claude_code".into(),
        }
    }

    /// Judges `screen` with `engine`, panicking if the provider returns an error.
    async fn judge_screen(engine: &RuleEngine, screen: &str) -> JudgmentResult {
        engine
            .judge(&request_with_context(screen))
            .await
            .expect("rule engine judgment should not fail")
    }

    #[tokio::test]
    async fn test_allow_read_access() {
        let verdict = judge_screen(
            &default_engine(),
            "Allow Read access to /home/user/project/src/main.rs",
        )
        .await;
        assert_eq!(verdict.decision, JudgmentDecision::Approve);
        assert!(verdict.model.starts_with("rules:"));
    }

    #[tokio::test]
    async fn test_allow_bash_cat() {
        let verdict = judge_screen(&default_engine(), "Allow Bash: cat /etc/hosts").await;
        assert_eq!(verdict.decision, JudgmentDecision::Approve);
    }

    #[tokio::test]
    async fn test_allow_cargo_test() {
        let verdict = judge_screen(&default_engine(), "Allow Bash: cargo test --lib").await;
        assert_eq!(verdict.decision, JudgmentDecision::Approve);
        assert!(verdict.reasoning.contains("allow_tests"));
    }

    #[tokio::test]
    async fn test_allow_git_status() {
        let verdict = judge_screen(&default_engine(), "Allow Bash: git status").await;
        assert_eq!(verdict.decision, JudgmentDecision::Approve);
        assert!(verdict.reasoning.contains("allow_git_readonly"));
    }

    #[tokio::test]
    async fn test_allow_cargo_fmt() {
        let verdict = judge_screen(&default_engine(), "Allow Bash: cargo fmt").await;
        assert_eq!(verdict.decision, JudgmentDecision::Approve);
        assert!(verdict.reasoning.contains("allow_format_lint"));
    }

    #[tokio::test]
    async fn test_allow_webfetch() {
        let verdict = judge_screen(
            &default_engine(),
            "Allow WebFetch: https://docs.rs/ratatui/latest",
        )
        .await;
        assert_eq!(verdict.decision, JudgmentDecision::Approve);
        assert!(verdict.reasoning.contains("allow_fetch"));
    }

    #[tokio::test]
    async fn test_abstain_unknown_command() {
        let verdict = judge_screen(
            &default_engine(),
            "Allow Bash: some-unknown-command --flag",
        )
        .await;
        assert_eq!(verdict.decision, JudgmentDecision::Uncertain);
        assert!(verdict.model.contains("abstain"));
    }

    #[tokio::test]
    async fn test_abstain_edit_operation() {
        // Edits are never auto-approved by the built-in rules.
        let verdict = judge_screen(
            &default_engine(),
            "Allow Edit access to /home/user/project/src/main.rs",
        )
        .await;
        assert_eq!(verdict.decision, JudgmentDecision::Uncertain);
    }

    #[tokio::test]
    async fn test_disabled_allow_read() {
        let engine = RuleEngine::new(RuleSettings {
            allow_read: false,
            ..Default::default()
        });
        let verdict = judge_screen(&engine, "Allow Read access to /home/user/file.txt").await;
        assert_eq!(verdict.decision, JudgmentDecision::Uncertain);
    }

    #[tokio::test]
    async fn test_custom_allow_pattern() {
        let engine = RuleEngine::new(RuleSettings {
            allow_patterns: vec![r"my-safe-tool".to_string()],
            ..Default::default()
        });
        let verdict = judge_screen(&engine, "Allow Bash: my-safe-tool run --safe").await;
        assert_eq!(verdict.decision, JudgmentDecision::Approve);
        assert!(verdict.reasoning.contains("allow_pattern"));
    }

    #[tokio::test]
    async fn test_model_field_format() {
        let verdict = judge_screen(&default_engine(), "Allow Read access to /tmp/file.txt").await;
        assert!(verdict.model.starts_with("rules:"));
    }

    #[tokio::test]
    async fn test_curl_get_allowed() {
        let verdict = judge_screen(
            &default_engine(),
            "Allow Bash: curl https://api.example.com/data",
        )
        .await;
        assert_eq!(verdict.decision, JudgmentDecision::Approve);
    }

    #[tokio::test]
    async fn test_curl_post_abstain() {
        let verdict = judge_screen(
            &default_engine(),
            "Allow Bash: curl -X POST https://api.example.com/data -d '{}'",
        )
        .await;
        assert_eq!(verdict.decision, JudgmentDecision::Uncertain);
    }
}