Skip to main content

vtcode_core/command_safety/
unified.rs

1//! Unified Command Evaluator - Phase 5
2//!
3//! Merges CommandPolicyEvaluator with command_safety module to provide
4//! comprehensive command validation combining:
5//! - Policy-based rules (allow/deny prefixes, regexes, globs)
6//! - Safety rules (subcommand validation, dangerous patterns)
7//! - Shell parsing (decompose complex scripts)
8//! - Audit logging & caching
9
10use crate::command_safety::{
11    AuditEntry, CommandDatabase, SafeCommandRegistry, SafetyAuditLogger, SafetyDecision,
12    SafetyDecisionCache, command_might_be_dangerous, parse_bash_lc_commands,
13};
14use anyhow::Result;
15use std::path::PathBuf;
16use std::sync::Arc;
17
/// Why an evaluation ended with the verdict it did.
///
/// Each variant names the stage that produced the decision; variants that
/// carry a `String` hold a human-readable explanation for that stage.
#[derive(Debug, Clone, PartialEq)]
pub enum EvaluationReason {
    /// A policy rule permitted the command; the payload describes the rule.
    PolicyAllow(String),
    /// A policy rule rejected the command; the payload describes the rule.
    PolicyDeny(String),
    /// The command made it through every safety check.
    SafetyAllow,
    /// A safety check rejected the command; the payload says why.
    SafetyDeny(String),
    /// The command matched a hardcoded dangerous pattern.
    DangerousCommand(String),
    /// The verdict was served from the decision cache: the `bool` is the
    /// cached allow/deny flag and the `String` is the cached reason text.
    CacheHit(bool, String),
}
34
35impl std::fmt::Display for EvaluationReason {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        match self {
38            Self::PolicyAllow(msg) => write!(f, "Policy Allow: {}", msg),
39            Self::PolicyDeny(msg) => write!(f, "Policy Deny: {}", msg),
40            Self::SafetyAllow => write!(f, "Safety Allow"),
41            Self::SafetyDeny(msg) => write!(f, "Safety Deny: {}", msg),
42            Self::DangerousCommand(msg) => write!(f, "Dangerous: {}", msg),
43            Self::CacheHit(allowed, msg) => {
44                write!(
45                    f,
46                    "Cache {} {}",
47                    if *allowed { "Allow" } else { "Deny" },
48                    msg
49                )
50            }
51        }
52    }
53}
54
55/// Complete evaluation result
56#[derive(Clone, Debug)]
57pub struct EvaluationResult {
58    /// Whether the command is allowed
59    pub allowed: bool,
60    /// Primary reason for decision
61    pub primary_reason: EvaluationReason,
62    /// Secondary reasons (e.g., policy rule that matched)
63    pub secondary_reasons: Vec<String>,
64    /// Resolved command path (if available)
65    pub resolved_path: Option<PathBuf>,
66}
67
68/// Unified command evaluator combining policies and safety rules
69#[derive(Clone)]
70pub struct UnifiedCommandEvaluator {
71    // Safety components
72    registry: SafeCommandRegistry,
73    database: CommandDatabase,
74    cache: SafetyDecisionCache,
75    audit_logger: SafetyAuditLogger,
76}
77
78impl UnifiedCommandEvaluator {
79    async fn log_audit_entry(
80        &self,
81        command: &[String],
82        allowed: bool,
83        reason: impl Into<String>,
84        decision_type: &str,
85    ) {
86        self.audit_logger
87            .log(AuditEntry::new(
88                command.to_vec(),
89                allowed,
90                reason.into(),
91                decision_type.to_string(),
92            ))
93            .await;
94    }
95
96    /// Create a new unified evaluator with default components
97    pub fn new() -> Self {
98        Self {
99            registry: SafeCommandRegistry::new(),
100            database: CommandDatabase,
101            cache: SafetyDecisionCache::new(1000),
102            audit_logger: SafetyAuditLogger::new(true),
103        }
104    }
105
106    /// Evaluate a command with full context (async)
107    ///
108    /// # Evaluation Pipeline
109    /// 1. Check cache
110    /// 2. Apply dangerous command detection
111    /// 3. Apply safety registry rules (with subcommand validation)
112    /// 4. Handle shell parsing if needed (bash -lc)
113    /// 5. Log audit entry (async)
114    /// 6. Cache result
115    pub async fn evaluate(&self, command: &[String]) -> Result<EvaluationResult> {
116        if command.is_empty() {
117            return Ok(EvaluationResult {
118                allowed: false,
119                primary_reason: EvaluationReason::SafetyDeny("empty command".into()),
120                secondary_reasons: vec![],
121                resolved_path: None,
122            });
123        }
124
125        let command_text = command.join(" ");
126
127        // 1. Check cache first
128        if let Some(cached_decision) = self.cache.get(&command_text).await {
129            let reason =
130                EvaluationReason::CacheHit(cached_decision.is_safe, cached_decision.reason.clone());
131            // Note: Audit logging skipped for cached decisions (logged on original evaluation)
132            return Ok(EvaluationResult {
133                allowed: cached_decision.is_safe,
134                primary_reason: reason,
135                secondary_reasons: vec![],
136                resolved_path: None,
137            });
138        }
139
140        // 2. Check dangerous commands first (fail-fast)
141        if command_might_be_dangerous(command) {
142            let result = EvaluationResult {
143                allowed: false,
144                primary_reason: EvaluationReason::DangerousCommand(
145                    "matches dangerous patterns".into(),
146                ),
147                secondary_reasons: vec![],
148                resolved_path: None,
149            };
150            self.log_audit_entry(command, false, "matches dangerous patterns", "Dangerous")
151                .await;
152            self.cache
153                .put(
154                    command_text.clone(),
155                    false,
156                    "dangerous command pattern".into(),
157                )
158                .await;
159            return Ok(result);
160        }
161
162        // 3. Apply safety registry rules
163        let registry_decision = self.registry.is_safe(command);
164        match registry_decision {
165            SafetyDecision::Deny(reason) => {
166                let result = EvaluationResult {
167                    allowed: false,
168                    primary_reason: EvaluationReason::SafetyDeny(reason.clone()),
169                    secondary_reasons: vec!["registry rule".into()],
170                    resolved_path: None,
171                };
172                self.log_audit_entry(command, false, reason.clone(), "Deny")
173                    .await;
174                self.cache
175                    .put(command_text.clone(), false, reason.clone())
176                    .await;
177                return Ok(result);
178            }
179            SafetyDecision::Allow => {
180                // Passed registry, continue to database checks
181            }
182            SafetyDecision::Unknown => {
183                // Continue to database checks
184            }
185        }
186
187        // 4. Apply command database rules
188        // Note: Database rules are optional. Currently the registry covers the main use cases.
189        // In a production system, this would merge database rules with registry rules.
190        // For now, we skip explicit database check as the registry is comprehensive.
191
192        // 5. Handle shell parsing for bash -lc and similar patterns
193        // Note: For simplicity, we evaluate each sub-command non-recursively
194        // by applying the same checks. In production, this could be refactored to support recursion.
195        if let Some(scripts) = parse_bash_lc_commands(command) {
196            for script in scripts {
197                // Apply the same checks to each script without recursive call
198                if command_might_be_dangerous(&script) {
199                    let result = EvaluationResult {
200                        allowed: false,
201                        primary_reason: EvaluationReason::DangerousCommand(format!(
202                            "dangerous in sub-script: {}",
203                            script.join(" ")
204                        )),
205                        secondary_reasons: vec![],
206                        resolved_path: None,
207                    };
208                    self.cache
209                        .put(
210                            command_text.clone(),
211                            false,
212                            result.primary_reason.to_string(),
213                        )
214                        .await;
215                    return Ok(result);
216                }
217
218                // Check safety registry for sub-command
219                if let SafetyDecision::Deny(reason) = self.registry.is_safe(&script) {
220                    let result = EvaluationResult {
221                        allowed: false,
222                        primary_reason: EvaluationReason::SafetyDeny(format!(
223                            "sub-command denied: {}",
224                            reason
225                        )),
226                        secondary_reasons: vec![],
227                        resolved_path: None,
228                    };
229                    self.cache
230                        .put(
231                            command_text.clone(),
232                            false,
233                            result.primary_reason.to_string(),
234                        )
235                        .await;
236                    return Ok(result);
237                }
238            }
239        }
240
241        // 6. All checks passed
242        let result = EvaluationResult {
243            allowed: true,
244            primary_reason: EvaluationReason::SafetyAllow,
245            secondary_reasons: vec!["passed all safety checks".into()],
246            resolved_path: None,
247        };
248        self.log_audit_entry(command, true, "passed all safety checks", "Allow")
249            .await;
250        self.cache
251            .put(command_text, true, "passed all safety checks".into())
252            .await;
253        Ok(result)
254    }
255
256    /// Evaluate with explicit policy check (requires external CommandPolicyEvaluator)
257    ///
258    /// This is a placeholder for integration with CommandPolicyEvaluator.
259    /// In a real implementation, this would:
260    /// 1. Check policy rules first (deny precedence)
261    /// 2. Then apply safety rules
262    /// 3. Merge results
263    pub async fn evaluate_with_policy(
264        &self,
265        command: &[String],
266        policy_allowed: bool,
267        policy_reason: &str,
268    ) -> Result<EvaluationResult> {
269        // If policy explicitly denies, stop here
270        if !policy_allowed {
271            return Ok(EvaluationResult {
272                allowed: false,
273                primary_reason: EvaluationReason::PolicyDeny(policy_reason.into()),
274                secondary_reasons: vec![],
275                resolved_path: None,
276            });
277        }
278
279        // Policy allows, continue with safety checks
280        self.evaluate(command).await
281    }
282
283    /// Get reference to the cache for metrics/debugging
284    pub fn cache(&self) -> &SafetyDecisionCache {
285        &self.cache
286    }
287
288    /// Get reference to the audit logger
289    pub fn audit_logger(&self) -> &SafetyAuditLogger {
290        &self.audit_logger
291    }
292
293    /// Get reference to the registry
294    pub fn registry(&self) -> &SafeCommandRegistry {
295        &self.registry
296    }
297
298    /// Get reference to the database
299    pub fn database(&self) -> &CommandDatabase {
300        &self.database
301    }
302}
303
304impl Default for UnifiedCommandEvaluator {
305    fn default() -> Self {
306        Self::new()
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned argv from string literals to keep the tests readable.
    fn argv(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|&part| part.to_owned()).collect()
    }

    #[tokio::test]
    async fn empty_command_denied() {
        let evaluator = UnifiedCommandEvaluator::new();
        let result = evaluator.evaluate(&[]).await.unwrap();
        assert!(!result.allowed);
    }

    #[tokio::test]
    async fn dangerous_command_denied() {
        let evaluator = UnifiedCommandEvaluator::new();
        let result = evaluator
            .evaluate(&argv(&["rm", "-rf", "/"]))
            .await
            .unwrap();
        assert!(!result.allowed);
        // `matches!` only yields a bool; it must be asserted to have any effect.
        assert!(matches!(
            result.primary_reason,
            EvaluationReason::DangerousCommand(_)
        ));
    }

    #[tokio::test]
    async fn safe_command_allowed() {
        let evaluator = UnifiedCommandEvaluator::new();
        // git is in the default safe registry
        let result = evaluator
            .evaluate(&argv(&["git", "status"]))
            .await
            .unwrap();
        assert!(result.allowed);
    }

    #[tokio::test]
    async fn cache_hit_on_repeated_command() {
        let evaluator = UnifiedCommandEvaluator::new();
        let cmd = argv(&["git", "status"]);

        // First evaluation
        let result1 = evaluator.evaluate(&cmd).await.unwrap();
        assert!(result1.allowed);

        // Second evaluation (should be cached)
        let result2 = evaluator.evaluate(&cmd).await.unwrap();
        assert!(result2.allowed);
        assert!(matches!(
            result2.primary_reason,
            EvaluationReason::CacheHit(true, _)
        ));
        // Cache hits are not re-audited, so only the first evaluation is logged.
        assert_eq!(evaluator.audit_logger().count().await, 1);
    }

    #[tokio::test]
    async fn dangerous_command_is_audited() {
        let evaluator = UnifiedCommandEvaluator::new();
        evaluator
            .evaluate(&argv(&["rm", "-rf", "/"]))
            .await
            .unwrap();

        let entries = evaluator.audit_logger().entries().await;
        assert_eq!(entries.len(), 1);
        assert!(!entries[0].allowed);
        assert_eq!(entries[0].decision_type, "Dangerous");
    }

    #[tokio::test]
    async fn safe_command_is_audited() {
        let evaluator = UnifiedCommandEvaluator::new();
        evaluator
            .evaluate(&argv(&["git", "status"]))
            .await
            .unwrap();

        let entries = evaluator.audit_logger().entries().await;
        assert_eq!(entries.len(), 1);
        assert!(entries[0].allowed);
        assert_eq!(entries[0].decision_type, "Allow");
    }

    #[tokio::test]
    async fn bash_lc_decomposition() {
        let evaluator = UnifiedCommandEvaluator::new();
        // bash -lc with mixed safe/unsafe commands
        let cmd = argv(&["bash", "-lc", "git status && rm -rf /"]);
        let result = evaluator.evaluate(&cmd).await.unwrap();
        // Should detect the rm -rf in the sub-command
        assert!(!result.allowed);
    }

    #[test]
    fn evaluation_reason_display() {
        let reason = EvaluationReason::PolicyAllow("test".into());
        assert_eq!(reason.to_string(), "Policy Allow: test");

        let reason = EvaluationReason::SafetyDeny("forbidden".into());
        assert_eq!(reason.to_string(), "Safety Deny: forbidden");
    }

    #[tokio::test]
    async fn policy_deny_stops_evaluation() {
        let evaluator = UnifiedCommandEvaluator::new();
        let result = evaluator
            .evaluate_with_policy(&argv(&["git", "status"]), false, "policy blocked")
            .await
            .unwrap();
        assert!(!result.allowed);
        assert!(matches!(
            result.primary_reason,
            EvaluationReason::PolicyDeny(_)
        ));
    }

    #[tokio::test]
    async fn policy_allow_continues_to_safety_checks() {
        let evaluator = UnifiedCommandEvaluator::new();
        let result = evaluator
            .evaluate_with_policy(&argv(&["git", "status"]), true, "policy allowed")
            .await
            .unwrap();
        // Policy allows, git status passes safety checks
        assert!(result.allowed);
    }

    #[tokio::test]
    async fn safety_deny_overrides_policy_allow() {
        let evaluator = UnifiedCommandEvaluator::new();
        let result = evaluator
            .evaluate_with_policy(&argv(&["rm", "-rf", "/"]), true, "policy allowed")
            .await
            .unwrap();
        // Policy allows but safety rules deny
        assert!(!result.allowed);
        assert!(matches!(
            result.primary_reason,
            EvaluationReason::DangerousCommand(_)
        ));
    }

    #[tokio::test]
    async fn evaluation_result_contains_reasons() {
        let evaluator = UnifiedCommandEvaluator::new();
        let result = evaluator
            .evaluate(&argv(&["git", "status"]))
            .await
            .unwrap();
        assert!(result.allowed);
        assert!(!result.secondary_reasons.is_empty());
    }

    #[tokio::test]
    async fn forbidden_git_subcommand_denied() {
        let evaluator = UnifiedCommandEvaluator::new();
        // git push is not in the allowed subcommands for git
        let result = evaluator
            .evaluate(&argv(&["git", "push"]))
            .await
            .unwrap();
        assert!(!result.allowed);
    }
}
478
479/// Policy-aware evaluator adapter for backward compatibility with CommandPolicyEvaluator
480///
481/// This adapter wraps UnifiedCommandEvaluator with policy rule evaluation,
482/// allowing gradual migration from CommandPolicyEvaluator to UnifiedCommandEvaluator.
483#[derive(Clone)]
484pub struct PolicyAwareEvaluator {
485    unified: Arc<UnifiedCommandEvaluator>,
486    /// Policy allow decision (if Some, policy layer is active)
487    allow_policy_decision: Option<bool>,
488    policy_reason: Option<String>,
489}
490
491impl PolicyAwareEvaluator {
492    /// Create a new policy-aware evaluator with default components
493    pub fn new() -> Self {
494        Self {
495            unified: Arc::new(UnifiedCommandEvaluator::new()),
496            allow_policy_decision: None,
497            policy_reason: None,
498        }
499    }
500
501    /// Create with explicit policy decision
502    pub fn with_policy(allow_policy_decision: bool, policy_reason: impl Into<String>) -> Self {
503        Self {
504            unified: Arc::new(UnifiedCommandEvaluator::new()),
505            allow_policy_decision: Some(allow_policy_decision),
506            policy_reason: Some(policy_reason.into()),
507        }
508    }
509
510    /// Evaluate command with optional policy layer
511    pub async fn evaluate(&self, command: &[String]) -> Result<EvaluationResult> {
512        // Apply policy layer if configured
513        if let (Some(policy_allowed), Some(reason)) =
514            (&self.allow_policy_decision, &self.policy_reason)
515        {
516            self.unified
517                .evaluate_with_policy(command, *policy_allowed, reason)
518                .await
519        } else {
520            // No policy configured, use pure safety evaluation
521            self.unified.evaluate(command).await
522        }
523    }
524
525    /// Set policy decision (allows updating policy after creation)
526    pub fn set_policy(&mut self, allowed: bool, reason: impl Into<String>) {
527        self.allow_policy_decision = Some(allowed);
528        self.policy_reason = Some(reason.into());
529    }
530
531    /// Clear policy decision (revert to pure safety evaluation)
532    pub fn clear_policy(&mut self) {
533        self.allow_policy_decision = None;
534        self.policy_reason = None;
535    }
536
537    /// Get reference to the underlying evaluator for advanced access
538    pub fn unified(&self) -> Arc<UnifiedCommandEvaluator> {
539        Arc::clone(&self.unified)
540    }
541}
542
543impl Default for PolicyAwareEvaluator {
544    fn default() -> Self {
545        Self::new()
546    }
547}
548
#[cfg(test)]
mod adapter_tests {
    use super::*;

    /// Build an owned argv from string literals to keep the tests readable.
    fn argv(parts: &[&str]) -> Vec<String> {
        parts.iter().map(|&part| part.to_owned()).collect()
    }

    #[tokio::test]
    async fn policy_aware_without_policy_uses_safety() {
        let evaluator = PolicyAwareEvaluator::new();
        let result = evaluator
            .evaluate(&argv(&["git", "status"]))
            .await
            .unwrap();
        assert!(result.allowed);
    }

    #[tokio::test]
    async fn policy_aware_with_deny_policy_blocks_safe_command() {
        let evaluator = PolicyAwareEvaluator::with_policy(false, "policy blocked");
        let result = evaluator
            .evaluate(&argv(&["git", "status"]))
            .await
            .unwrap();
        assert!(!result.allowed);
        // `matches!` only yields a bool; it must be asserted to have any effect.
        assert!(matches!(
            result.primary_reason,
            EvaluationReason::PolicyDeny(_)
        ));
    }

    #[tokio::test]
    async fn policy_aware_with_allow_policy_still_blocks_dangerous() {
        let evaluator = PolicyAwareEvaluator::with_policy(true, "policy allowed");
        let result = evaluator
            .evaluate(&argv(&["rm", "-rf", "/"]))
            .await
            .unwrap();
        // Policy allows, but safety rules should deny
        assert!(!result.allowed);
    }

    #[tokio::test]
    async fn policy_aware_mutable_set_policy() {
        let mut evaluator = PolicyAwareEvaluator::new();
        let cmd = argv(&["git", "status"]);

        // Initially no policy
        let result1 = evaluator.evaluate(&cmd).await.unwrap();
        assert!(result1.allowed);

        // Set deny policy
        evaluator.set_policy(false, "policy blocked");
        let result2 = evaluator.evaluate(&cmd).await.unwrap();
        assert!(!result2.allowed);

        // Clear policy
        evaluator.clear_policy();
        let result3 = evaluator.evaluate(&cmd).await.unwrap();
        assert!(result3.allowed);
    }
}