1use super::{
7 approval::{AskForApproval, ExecApprovalRequirement, ExecPolicyAmendment},
8 policy::{Decision, Policy, PolicyEvaluation, RuleMatch},
9};
10use crate::command_safety::command_might_be_dangerous;
11use crate::sandboxing::SandboxPolicy;
12use anyhow::{Context, Result};
13use std::{
14 collections::HashSet,
15 path::{Path, PathBuf},
16 sync::Arc,
17};
18use tokio::sync::RwLock;
19
/// Refusal reason used when a prompt would be required but approvals are
/// disabled entirely (`AskForApproval::Never`).
const PROMPT_CONFLICT_REASON: &str =
    "approval required by policy, but AskForApproval is set to Never";

/// Refusal reason used when a non-rule (sandbox) prompt is blocked because
/// `AskForApproval::Reject` has `sandbox_approval` set.
const REJECT_SANDBOX_APPROVAL_REASON: &str =
    "approval required by policy, but AskForApproval::Reject.sandbox_approval is set";

/// Refusal reason used when a prompt coming from an explicit policy rule is
/// blocked because `AskForApproval::Reject` has `rules` set.
const REJECT_RULES_APPROVAL_REASON: &str =
    "approval required by policy rule, but AskForApproval::Reject.rules is set";
26
27fn prompt_is_rejected_by_policy(
28 approval_policy: AskForApproval,
29 prompt_is_rule: bool,
30) -> Option<&'static str> {
31 if prompt_is_rule {
32 if !approval_policy.rejects_rule_prompt() {
33 return None;
34 }
35
36 return Some(if matches!(approval_policy, AskForApproval::Never) {
37 PROMPT_CONFLICT_REASON
38 } else {
39 REJECT_RULES_APPROVAL_REASON
40 });
41 }
42
43 if !approval_policy.rejects_sandbox_prompt() {
44 return None;
45 }
46
47 Some(if matches!(approval_policy, AskForApproval::Never) {
48 PROMPT_CONFLICT_REASON
49 } else {
50 REJECT_SANDBOX_APPROVAL_REASON
51 })
52}
53
54#[derive(Debug, Clone)]
56pub struct ExecPolicyConfig {
57 pub default_sandbox_policy: SandboxPolicy,
59
60 pub default_approval: AskForApproval,
62
63 pub use_heuristics: bool,
65
66 pub max_auto_approve_length: usize,
68}
69
70impl Default for ExecPolicyConfig {
71 fn default() -> Self {
72 Self {
73 default_sandbox_policy: SandboxPolicy::read_only(),
74 default_approval: AskForApproval::UnlessTrusted,
75 use_heuristics: true,
76 max_auto_approve_length: 256,
77 }
78 }
79}
80
81pub struct ExecPolicyManager {
83 policy: RwLock<Policy>,
85
86 trusted_patterns: RwLock<Vec<ExecPolicyAmendment>>,
88
89 sandbox_policy: RwLock<SandboxPolicy>,
91
92 config: ExecPolicyConfig,
94
95 #[expect(dead_code)]
97 workspace_root: PathBuf,
98
99 session_approved: RwLock<HashSet<String>>,
101}
102
103impl ExecPolicyManager {
104 pub fn new(workspace_root: PathBuf, config: ExecPolicyConfig) -> Self {
106 Self {
107 policy: RwLock::new(Policy::empty()),
108 trusted_patterns: RwLock::new(Vec::new()),
109 sandbox_policy: RwLock::new(config.default_sandbox_policy.clone()),
110 config,
111 workspace_root,
112 session_approved: RwLock::new(HashSet::new()),
113 }
114 }
115
116 pub fn with_defaults(workspace_root: PathBuf) -> Self {
118 Self::new(workspace_root, ExecPolicyConfig::default())
119 }
120
121 pub async fn load_policy(&self, path: &Path) -> Result<()> {
123 let parser = super::parser::PolicyParser::new();
124 let loaded_policy = parser
125 .load_file(path)
126 .await
127 .context("Failed to load policy file")?;
128
129 let mut policy = self.policy.write().await;
130 *policy = loaded_policy;
131 Ok(())
132 }
133
134 pub async fn add_prefix_rule(&self, pattern: &[String], decision: Decision) -> Result<()> {
136 let mut policy = self.policy.write().await;
137 policy.add_prefix_rule(pattern, decision)
138 }
139
140 pub async fn add_trusted_pattern(&self, amendment: ExecPolicyAmendment) {
142 let mut patterns = self.trusted_patterns.write().await;
143 patterns.push(amendment);
144 }
145
146 pub async fn set_sandbox_policy(&self, policy: SandboxPolicy) {
148 let mut sandbox = self.sandbox_policy.write().await;
149 *sandbox = policy;
150 }
151
152 pub async fn sandbox_policy(&self) -> SandboxPolicy {
154 self.sandbox_policy.read().await.clone()
155 }
156
157 pub async fn check_approval(&self, command: &[String]) -> ExecApprovalRequirement {
159 let command_key = command.join(" ");
161 {
162 let approved = self.session_approved.read().await;
163 if approved.contains(&command_key) {
164 return ExecApprovalRequirement::skip();
165 }
166 }
167
168 {
170 let patterns = self.trusted_patterns.read().await;
171 for pattern in patterns.iter() {
172 if pattern.matches(command) {
173 return ExecApprovalRequirement::skip();
174 }
175 }
176 }
177
178 let policy = self.policy.read().await;
180 let rule_match = policy.check(command);
181
182 let decision = match &rule_match {
184 RuleMatch::PrefixRuleMatch { decision, .. } => *decision,
185 RuleMatch::HeuristicsRuleMatch { .. } => self.heuristics_decision(command),
186 };
187
188 match decision {
189 Decision::Allow => ExecApprovalRequirement::skip(),
190 Decision::Prompt => {
191 let prompt_is_rule = matches!(
192 rule_match,
193 RuleMatch::PrefixRuleMatch {
194 decision: Decision::Prompt,
195 ..
196 }
197 );
198
199 match prompt_is_rejected_by_policy(self.config.default_approval, prompt_is_rule) {
200 Some(reason) => ExecApprovalRequirement::forbidden(reason),
201 None => ExecApprovalRequirement::needs_approval(
202 self.format_approval_reason(command, &rule_match),
203 ),
204 }
205 }
206 Decision::Forbidden => ExecApprovalRequirement::forbidden(
207 self.format_forbidden_reason(command, &rule_match),
208 ),
209 }
210 }
211
212 pub async fn check_approval_batch(&self, commands: &[Vec<String>]) -> ExecApprovalRequirement {
214 let mut needs_approval_flag = false;
215 let mut reasons = Vec::new();
216
217 for command in commands {
218 let approval = self.check_approval(command).await;
219 if approval.is_forbidden() {
220 return approval;
221 }
222 if approval.requires_approval() {
223 needs_approval_flag = true;
224 if let ExecApprovalRequirement::NeedsApproval {
225 reason: Some(r), ..
226 } = &approval
227 {
228 reasons.push(r.clone());
229 }
230 }
231 }
232
233 if needs_approval_flag {
234 ExecApprovalRequirement::needs_approval(reasons.join("; "))
235 } else {
236 ExecApprovalRequirement::skip()
237 }
238 }
239
240 pub async fn approve_command(&self, command: &[String]) {
242 let command_key = command.join(" ");
243 let mut approved = self.session_approved.write().await;
244 approved.insert(command_key);
245 }
246
247 pub async fn clear_session_approvals(&self) {
249 let mut approved = self.session_approved.write().await;
250 approved.clear();
251 }
252
253 pub async fn evaluate(&self, command: &[String]) -> PolicyEvaluation {
255 let policy = self.policy.read().await;
256 let commands = [command.to_vec()];
257 policy.check_multiple(commands.iter(), &|cmd| self.heuristics_decision(cmd))
258 }
259
260 fn heuristics_decision(&self, command: &[String]) -> Decision {
264 if !self.config.use_heuristics {
265 return Decision::Prompt;
266 }
267
268 if command.is_empty() {
269 return Decision::Prompt;
270 }
271
272 let cmd = &command[0];
273
274 let safe_commands = [
276 "ls", "cat", "head", "tail", "grep", "find", "echo", "pwd", "which", "type", "less",
277 "more", "wc", "sort", "uniq", "diff", "env", "printenv", "hostname", "uname", "date",
278 "whoami", "id", "file", "stat", "tree", "df", "du", "uptime",
279 ];
280
281 if safe_commands.contains(&cmd.as_str()) {
282 return Decision::Allow;
283 }
284
285 if command_might_be_dangerous(command) {
287 if command.iter().any(|arg| arg == "--dry-run" || arg == "-n") {
289 return Decision::Prompt;
290 }
291 return Decision::Forbidden;
292 }
293
294 Decision::Prompt
296 }
297
298 fn format_approval_reason(&self, command: &[String], rule_match: &RuleMatch) -> String {
300 match rule_match {
301 RuleMatch::PrefixRuleMatch { rule, .. } => {
302 format!(
303 "Command '{}' matched rule '{}' requiring confirmation",
304 command.join(" "),
305 rule.pattern.join(" ")
306 )
307 }
308 RuleMatch::HeuristicsRuleMatch { .. } => {
309 format!(
310 "Command '{}' requires confirmation (no explicit policy rule)",
311 command.join(" ")
312 )
313 }
314 }
315 }
316
317 fn format_forbidden_reason(&self, command: &[String], rule_match: &RuleMatch) -> String {
319 match rule_match {
320 RuleMatch::PrefixRuleMatch { rule, .. } => {
321 format!(
322 "Command '{}' is forbidden by rule '{}'",
323 command.join(" "),
324 rule.pattern.join(" ")
325 )
326 }
327 RuleMatch::HeuristicsRuleMatch { .. } => {
328 format!(
329 "Command '{}' is forbidden by safety heuristics",
330 command.join(" ")
331 )
332 }
333 }
334 }
335}
336
337pub type SharedExecPolicyManager = Arc<ExecPolicyManager>;
339
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Turns string literals into the owned argv form the manager expects.
    fn argv(parts: &[&str]) -> Vec<String> {
        parts.iter().copied().map(String::from).collect()
    }

    /// Creates a manager rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned so the directory lives as long as the
    /// test keeps it bound.
    fn manager_with(config: ExecPolicyConfig) -> (TempDir, ExecPolicyManager) {
        let dir = tempdir().unwrap();
        let manager = ExecPolicyManager::new(dir.path().to_path_buf(), config);
        (dir, manager)
    }

    /// The default configuration with only the approval mode overridden.
    fn config_with_approval(default_approval: AskForApproval) -> ExecPolicyConfig {
        ExecPolicyConfig {
            default_approval,
            ..ExecPolicyConfig::default()
        }
    }

    /// A `Reject` approval mode in which only the sandbox and rules switches vary.
    fn reject(sandbox_approval: bool, rules: bool) -> AskForApproval {
        AskForApproval::Reject(crate::exec_policy::RejectConfig {
            sandbox_approval,
            rules,
            request_permissions: false,
            mcp_elicitations: false,
        })
    }

    #[tokio::test]
    async fn test_policy_manager_basic() {
        let (_dir, manager) = manager_with(ExecPolicyConfig::default());
        manager
            .add_prefix_rule(&argv(&["cargo", "build"]), Decision::Allow)
            .await
            .unwrap();

        let allowed = manager.check_approval(&argv(&["cargo", "build"])).await;
        assert!(allowed.can_proceed());

        let unknown = manager.check_approval(&argv(&["unknown", "command"])).await;
        assert!(unknown.requires_approval());
    }

    #[tokio::test]
    async fn test_prompt_conflict_with_never_policy_forbids() {
        let (_dir, manager) = manager_with(config_with_approval(AskForApproval::Never));

        let result = manager.check_approval(&argv(&["unknown", "command"])).await;
        assert_eq!(
            result,
            ExecApprovalRequirement::forbidden(PROMPT_CONFLICT_REASON)
        );
    }

    #[tokio::test]
    async fn test_reject_rules_policy_forbids_rule_prompt() {
        let (_dir, manager) = manager_with(config_with_approval(reject(false, true)));
        manager
            .add_prefix_rule(&argv(&["git"]), Decision::Prompt)
            .await
            .expect("add prompt rule");

        let result = manager.check_approval(&argv(&["git"])).await;
        assert_eq!(
            result,
            ExecApprovalRequirement::forbidden(REJECT_RULES_APPROVAL_REASON)
        );
    }

    #[tokio::test]
    async fn test_reject_sandbox_policy_forbids_non_rule_prompt() {
        let (_dir, manager) = manager_with(config_with_approval(reject(true, false)));

        let result = manager.check_approval(&argv(&["unknown", "command"])).await;
        assert_eq!(
            result,
            ExecApprovalRequirement::forbidden(REJECT_SANDBOX_APPROVAL_REASON)
        );
    }

    #[tokio::test]
    async fn test_trusted_patterns() {
        let (_dir, manager) = manager_with(ExecPolicyConfig::default());
        manager
            .add_trusted_pattern(ExecPolicyAmendment::from_prefix("cargo"))
            .await;

        let result = manager.check_approval(&argv(&["cargo", "test"])).await;
        assert!(result.can_proceed());
    }

    #[tokio::test]
    async fn test_session_approval() {
        let (_dir, manager) = manager_with(ExecPolicyConfig::default());
        let cmd = argv(&["git", "status"]);

        // With no rule and no allowlist entry, the command must prompt.
        assert!(manager.check_approval(&cmd).await.requires_approval());

        // Approving it for the session lets it through...
        manager.approve_command(&cmd).await;
        assert!(manager.check_approval(&cmd).await.can_proceed());

        // ...until session approvals are cleared again.
        manager.clear_session_approvals().await;
        assert!(manager.check_approval(&cmd).await.requires_approval());
    }

    #[tokio::test]
    async fn test_heuristics() {
        let (_dir, manager) = manager_with(ExecPolicyConfig::default());

        // `ls` is on the read-only allowlist.
        assert!(manager.check_approval(&argv(&["ls"])).await.can_proceed());

        // `rm -rf` is refused outright.
        assert!(manager
            .check_approval(&argv(&["rm", "-rf"]))
            .await
            .is_forbidden());
    }
}
498}