Skip to main content

tmai_core/auto_approve/
rules.rs

//! Rule-based auto-approve engine.
//!
//! Evaluates screen context against allow rules to make instant
//! (sub-millisecond) approval decisions without AI.
//! If no allow rule matches, the decision is Uncertain — which becomes
//! ManualRequired in Rules mode, or escalates to AI in Hybrid mode.
7use std::time::Instant;
8
9use anyhow::Result;
10use regex::Regex;
11
12use crate::config::RuleSettings;
13
14use super::judge::JudgmentProvider;
15use super::types::{JudgmentDecision, JudgmentRequest, JudgmentResult};
16
/// Operation and target parsed from Claude Code's approval prompt.
///
/// Both fields are `None` when the prompt text could not be recognised.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct ParsedContext {
    /// Operation type: "Read", "Edit", "Bash", "WebFetch", "WebSearch", "MCP tool", etc.
    operation: Option<String>,
    /// Target: file path, command string, URL, etc.
    target: Option<String>,
}
24
25/// Rule engine that evaluates allow rules against screen context
26pub struct RuleEngine {
27    settings: RuleSettings,
28    /// Compiled user-defined allow patterns
29    allow_patterns: Vec<Regex>,
30}
31
32impl RuleEngine {
33    /// Create a new RuleEngine with the given settings
34    pub fn new(settings: RuleSettings) -> Self {
35        let allow_patterns = settings
36            .allow_patterns
37            .iter()
38            .filter_map(|p| match Regex::new(p) {
39                Ok(r) => Some(r),
40                Err(e) => {
41                    tracing::warn!(pattern = %p, "Invalid allow_pattern regex: {}", e);
42                    None
43                }
44            })
45            .collect();
46
47        Self {
48            settings,
49            allow_patterns,
50        }
51    }
52
53    /// Parse the screen context to extract operation and target
54    fn parse_context(screen_context: &str) -> ParsedContext {
55        // Claude Code approval prompts follow patterns like:
56        //   "Allow Read access to /path/to/file"
57        //   "Allow Edit access to /path/to/file"
58        //   "Allow Bash: git status"
59        //   "Allow WebFetch: https://..."
60        //   "Allow MCP tool: tool_name"
61        let last_lines: Vec<&str> = screen_context.lines().rev().take(15).collect();
62        let search_text: String = last_lines.into_iter().rev().collect::<Vec<_>>().join("\n");
63
64        // Try "Allow <Operation> access to <target>" pattern
65        let access_re = Regex::new(r"(?i)Allow\s+(\w+)\s+access\s+to\s+(.+)").expect("valid regex");
66        if let Some(caps) = access_re.captures(&search_text) {
67            return ParsedContext {
68                operation: Some(caps[1].to_string()),
69                target: Some(caps[2].trim().to_string()),
70            };
71        }
72
73        // Try "Allow <Operation>: <target>" pattern
74        let colon_re = Regex::new(r"(?i)Allow\s+([\w\s]+?):\s+(.+)").expect("valid regex");
75        if let Some(caps) = colon_re.captures(&search_text) {
76            return ParsedContext {
77                operation: Some(caps[1].trim().to_string()),
78                target: Some(caps[2].trim().to_string()),
79            };
80        }
81
82        ParsedContext {
83            operation: None,
84            target: None,
85        }
86    }
87
88    /// Check allow rules; returns the matching rule name if allowed
89    fn check_allow(
90        &self,
91        screen_context: &str,
92        operation: Option<&str>,
93        target: Option<&str>,
94    ) -> Option<String> {
95        // User-defined allow patterns (highest priority)
96        for (i, pattern) in self.allow_patterns.iter().enumerate() {
97            if pattern.is_match(screen_context) {
98                return Some(format!(
99                    "allow_pattern[{}]: {}",
100                    i, self.settings.allow_patterns[i]
101                ));
102            }
103        }
104
105        let op = operation.unwrap_or("").to_lowercase();
106        let tgt = target.unwrap_or("").to_lowercase();
107
108        // Read operations
109        if self.settings.allow_read {
110            if op == "read" {
111                return Some("allow_read: Read access".to_string());
112            }
113            let read_commands = [
114                "cat ", "head ", "tail ", "less ", "ls ", "find ", "grep ", "wc ",
115            ];
116            if op == "bash" {
117                for cmd in &read_commands {
118                    if tgt.starts_with(cmd) || tgt.contains(&format!(" | {}", cmd)) {
119                        return Some(format!("allow_read: {}", cmd.trim()));
120                    }
121                }
122            }
123        }
124
125        // Test execution
126        if self.settings.allow_tests && op == "bash" {
127            let test_commands = [
128                "cargo test",
129                "npm test",
130                "npm run test",
131                "npx jest",
132                "npx vitest",
133                "pytest",
134                "python -m pytest",
135                "go test",
136                "dotnet test",
137                "mvn test",
138                "gradle test",
139            ];
140            for cmd in &test_commands {
141                if tgt.starts_with(cmd) || tgt.contains(&format!("&& {}", cmd)) {
142                    return Some(format!("allow_tests: {}", cmd));
143                }
144            }
145        }
146
147        // Fetch/search operations
148        if self.settings.allow_fetch {
149            if op == "webfetch" || op == "websearch" {
150                return Some(format!("allow_fetch: {}", op));
151            }
152            // curl GET (no -X POST, no --data, no -d)
153            if op == "bash"
154                && tgt.starts_with("curl ")
155                && !tgt.contains("-x post")
156                && !tgt.contains("--data")
157                && !tgt.contains(" -d ")
158            {
159                return Some("allow_fetch: curl GET".to_string());
160            }
161        }
162
163        // Git read-only commands
164        if self.settings.allow_git_readonly && op == "bash" {
165            let git_readonly = [
166                "git status",
167                "git log",
168                "git diff",
169                "git branch",
170                "git show",
171                "git blame",
172                "git stash list",
173                "git remote -v",
174                "git tag",
175                "git rev-parse",
176                "git ls-files",
177                "git ls-tree",
178            ];
179            for cmd in &git_readonly {
180                if tgt.starts_with(cmd) {
181                    return Some(format!("allow_git_readonly: {}", cmd));
182                }
183            }
184        }
185
186        // Format/lint commands
187        if self.settings.allow_format_lint && op == "bash" {
188            let fmt_commands = [
189                "cargo fmt",
190                "cargo clippy",
191                "prettier",
192                "eslint",
193                "rustfmt",
194                "black ",
195                "isort ",
196                "gofmt",
197                "go fmt",
198                "biome ",
199                "deno fmt",
200                "deno lint",
201            ];
202            for cmd in &fmt_commands {
203                if tgt.starts_with(cmd) || tgt.contains(&format!("npx {}", cmd)) {
204                    return Some(format!("allow_format_lint: {}", cmd.trim()));
205                }
206            }
207        }
208
209        None
210    }
211}
212
213impl JudgmentProvider for RuleEngine {
214    /// Evaluate allow rules against the request (instant, sub-millisecond).
215    ///
216    /// Returns Approve if an allow rule matches, Uncertain otherwise.
217    /// There are no deny rules — unmatched requests fall through to manual
218    /// approval (Rules mode) or AI escalation (Hybrid mode).
219    async fn judge(&self, request: &JudgmentRequest) -> Result<JudgmentResult> {
220        let start = Instant::now();
221        let parsed = Self::parse_context(&request.screen_context);
222
223        // Check allow rules
224        if let Some(rule) = self.check_allow(
225            &request.screen_context,
226            parsed.operation.as_deref(),
227            parsed.target.as_deref(),
228        ) {
229            return Ok(JudgmentResult {
230                decision: JudgmentDecision::Approve,
231                reasoning: format!("Allowed by rule: {}", rule),
232                model: format!("rules:{}", rule.split(':').next().unwrap_or("allow")),
233                elapsed_ms: start.elapsed().as_millis() as u64,
234                usage: None,
235            });
236        }
237
238        // No matching rule — abstain (hand off to manual or AI)
239        Ok(JudgmentResult {
240            decision: JudgmentDecision::Uncertain,
241            reasoning: "No matching allow rule".to_string(),
242            model: "rules:abstain".to_string(),
243            elapsed_ms: start.elapsed().as_millis() as u64,
244            usage: None,
245        })
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine built from `RuleSettings::default()`.
    fn default_engine() -> RuleEngine {
        RuleEngine::new(RuleSettings::default())
    }

    /// Request whose only varying field is `screen_context`.
    fn request_with_context(screen_context: &str) -> JudgmentRequest {
        JudgmentRequest {
            target: "test:0.1".to_string(),
            approval_type: "shell_command".to_string(),
            details: String::new(),
            screen_context: screen_context.to_string(),
            cwd: "/tmp/project".to_string(),
            agent_type: "claude_code".to_string(),
        }
    }

    /// Judge `screen_context` with `engine`, panicking if judging fails.
    async fn judge_screen(engine: &RuleEngine, screen_context: &str) -> JudgmentResult {
        engine
            .judge(&request_with_context(screen_context))
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_allow_read_access() {
        let outcome = judge_screen(
            &default_engine(),
            "Allow Read access to /home/user/project/src/main.rs",
        )
        .await;
        assert_eq!(outcome.decision, JudgmentDecision::Approve);
        assert!(outcome.model.starts_with("rules:"));
    }

    #[tokio::test]
    async fn test_allow_bash_cat() {
        let outcome = judge_screen(&default_engine(), "Allow Bash: cat /etc/hosts").await;
        assert_eq!(outcome.decision, JudgmentDecision::Approve);
    }

    #[tokio::test]
    async fn test_allow_cargo_test() {
        let outcome = judge_screen(&default_engine(), "Allow Bash: cargo test --lib").await;
        assert_eq!(outcome.decision, JudgmentDecision::Approve);
        assert!(outcome.reasoning.contains("allow_tests"));
    }

    #[tokio::test]
    async fn test_allow_git_status() {
        let outcome = judge_screen(&default_engine(), "Allow Bash: git status").await;
        assert_eq!(outcome.decision, JudgmentDecision::Approve);
        assert!(outcome.reasoning.contains("allow_git_readonly"));
    }

    #[tokio::test]
    async fn test_allow_cargo_fmt() {
        let outcome = judge_screen(&default_engine(), "Allow Bash: cargo fmt").await;
        assert_eq!(outcome.decision, JudgmentDecision::Approve);
        assert!(outcome.reasoning.contains("allow_format_lint"));
    }

    #[tokio::test]
    async fn test_allow_webfetch() {
        let outcome = judge_screen(
            &default_engine(),
            "Allow WebFetch: https://docs.rs/ratatui/latest",
        )
        .await;
        assert_eq!(outcome.decision, JudgmentDecision::Approve);
        assert!(outcome.reasoning.contains("allow_fetch"));
    }

    #[tokio::test]
    async fn test_abstain_unknown_command() {
        let outcome =
            judge_screen(&default_engine(), "Allow Bash: some-unknown-command --flag").await;
        assert_eq!(outcome.decision, JudgmentDecision::Uncertain);
        assert!(outcome.model.contains("abstain"));
    }

    #[tokio::test]
    async fn test_abstain_edit_operation() {
        // Edit operations must not be auto-approved by the default rules.
        let outcome = judge_screen(
            &default_engine(),
            "Allow Edit access to /home/user/project/src/main.rs",
        )
        .await;
        assert_eq!(outcome.decision, JudgmentDecision::Uncertain);
    }

    #[tokio::test]
    async fn test_disabled_allow_read() {
        let engine = RuleEngine::new(RuleSettings {
            allow_read: false,
            ..Default::default()
        });
        let outcome = judge_screen(&engine, "Allow Read access to /home/user/file.txt").await;
        // With allow_read off, a Read prompt abstains instead of being approved.
        assert_eq!(outcome.decision, JudgmentDecision::Uncertain);
    }

    #[tokio::test]
    async fn test_custom_allow_pattern() {
        let engine = RuleEngine::new(RuleSettings {
            allow_patterns: vec![r"my-safe-tool".to_string()],
            ..Default::default()
        });
        let outcome = judge_screen(&engine, "Allow Bash: my-safe-tool run --safe").await;
        assert_eq!(outcome.decision, JudgmentDecision::Approve);
        assert!(outcome.reasoning.contains("allow_pattern"));
    }

    #[tokio::test]
    async fn test_model_field_format() {
        let outcome = judge_screen(&default_engine(), "Allow Read access to /tmp/file.txt").await;
        assert!(outcome.model.starts_with("rules:"));
    }

    #[tokio::test]
    async fn test_curl_get_allowed() {
        let outcome = judge_screen(
            &default_engine(),
            "Allow Bash: curl https://api.example.com/data",
        )
        .await;
        assert_eq!(outcome.decision, JudgmentDecision::Approve);
    }

    #[tokio::test]
    async fn test_curl_post_abstain() {
        let outcome = judge_screen(
            &default_engine(),
            "Allow Bash: curl -X POST https://api.example.com/data -d '{}'",
        )
        .await;
        // A POST via curl is not auto-approved by the rules: abstain.
        assert_eq!(outcome.decision, JudgmentDecision::Uncertain);
    }
}
392}