Skip to main content

vtcode_core/tools/
safety_gateway.rs

1//! Unified Safety Gateway
2//!
3//! Consolidates all safety checking mechanisms into a single gateway:
4//! - Rate limiting (from runloop's tool_call_safety)
5//! - Destructive tool detection
6//! - Command policy enforcement
7//! - Plan mode restrictions
8//!
9//! This provides consistent safety decisions across all tool execution paths.
10
11use hashbrown::HashSet;
12use parking_lot::{Mutex, RwLock};
13use std::future::Future;
14use std::path::{Path, PathBuf};
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use thiserror::Error;
21
22use crate::config::CommandsConfig;
23use crate::config::constants::tools;
24use crate::dotfile_protection::{
25    AccessContext, AccessType, DotfileGuardian, ProtectionDecision, get_global_guardian,
26};
27use crate::tools::apply_patch::{Patch, PatchOperation, decode_apply_patch_input};
28use crate::tools::command_policy::CommandPolicyEvaluator;
29use crate::tools::invocation::ToolInvocationId;
30use crate::tools::rate_limit_config::{
31    tool_calls_per_minute_from_env, tool_calls_per_second_from_env,
32};
33use crate::tools::registry::{
34    RiskLevel, ToolRiskContext, ToolRiskScorer, ToolSource, WorkspaceTrust,
35};
36use crate::tools::tool_intent::{
37    classify_tool_intent, unified_exec_action_in, unified_exec_action_is, unified_file_action_is,
38    unified_search_action_is,
39};
40use vtcode_config::core::DotfileProtectionConfig;
41
/// Caller trust level consulted by the safety gateway when deciding whether
/// risk-based approval prompts may be skipped.
///
/// The default is [`SafetyTrustLevel::Standard`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SafetyTrustLevel {
    /// Lowest trust; `can_bypass_approval` is `false`.
    Untrusted,
    /// Default trust; `can_bypass_approval` is `false`.
    #[default]
    Standard,
    /// Raised trust; `can_bypass_approval` is `true`.
    Elevated,
    /// Highest trust; `can_bypass_approval` is `true`.
    Full,
}

impl SafetyTrustLevel {
    /// Returns `true` for [`Elevated`](Self::Elevated) and [`Full`](Self::Full),
    /// `false` for the two lower levels.
    #[inline]
    pub const fn can_bypass_approval(self) -> bool {
        !matches!(self, Self::Untrusted | Self::Standard)
    }
}
58
59/// Minimal execution context required for safety decisions.
60#[derive(Debug, Clone)]
61pub struct SafetyContext {
62    pub session_id: String,
63    pub trust_level: SafetyTrustLevel,
64}
65
66impl SafetyContext {
67    #[must_use]
68    pub fn new(session_id: impl Into<String>) -> Self {
69        Self {
70            session_id: session_id.into(),
71            trust_level: SafetyTrustLevel::default(),
72        }
73    }
74}
75
76/// Safety decision for a tool invocation
77#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
78pub enum SafetyDecision {
79    /// Tool execution is allowed without user approval
80    Allow,
81    /// Tool execution is denied with a reason
82    Deny(String),
83    /// Tool execution requires user approval with justification
84    NeedsApproval(String),
85}
86
87impl SafetyDecision {
88    /// Whether execution can proceed (Allow only)
89    #[inline]
90    pub fn is_allowed(&self) -> bool {
91        matches!(self, SafetyDecision::Allow)
92    }
93
94    /// Whether execution is blocked (Deny)
95    #[inline]
96    pub fn is_denied(&self) -> bool {
97        matches!(self, SafetyDecision::Deny(_))
98    }
99
100    /// Whether user approval is needed
101    #[inline]
102    pub fn needs_approval(&self) -> bool {
103        matches!(self, SafetyDecision::NeedsApproval(_))
104    }
105
106    /// Get the reason/justification if present
107    pub fn reason(&self) -> Option<&str> {
108        match self {
109            SafetyDecision::Allow => None,
110            SafetyDecision::Deny(reason) | SafetyDecision::NeedsApproval(reason) => Some(reason),
111        }
112    }
113}
114
115/// Errors from safety checks
116#[derive(Debug, Error, Clone)]
117pub enum SafetyError {
118    #[error("Rate limit exceeded: {current} calls in {window} (max: {max})")]
119    RateLimitExceeded {
120        current: usize,
121        max: usize,
122        window: &'static str,
123    },
124    #[error("Per-turn tool limit reached (max: {max})")]
125    TurnLimitReached { max: usize },
126    #[error("Session tool limit reached (max: {max})")]
127    SessionLimitReached { max: usize },
128    #[error("Plan mode violation: {0}")]
129    PlanModeViolation(String),
130    #[error("Command policy denied: {0}")]
131    CommandPolicyDenied(String),
132    #[error("Dotfile protection violation: {0}")]
133    DotfileProtectionViolation(String),
134}
135
136/// Result of a safety check with optional retry hint metadata.
137#[derive(Debug, Clone)]
138pub struct SafetyCheckResult {
139    /// Final decision for this invocation.
140    pub decision: SafetyDecision,
141    /// Suggested delay before retrying if the decision is a rate-limit denial.
142    pub retry_after: Option<Duration>,
143    /// Structured error when denial is produced by a safety limit.
144    pub violation: Option<SafetyError>,
145}
146
147/// Configuration for the safety gateway
148#[derive(Debug, Clone)]
149pub struct SafetyGatewayConfig {
150    /// Maximum tool calls per turn
151    pub max_per_turn: usize,
152    /// Maximum tool calls per session
153    pub max_per_session: usize,
154    /// Rate limit: calls per second
155    pub rate_limit_per_second: usize,
156    /// Rate limit: calls per minute (optional burst protection)
157    pub rate_limit_per_minute: Option<usize>,
158    /// Whether plan mode is active (read-only)
159    pub plan_mode_active: bool,
160    /// Workspace trust level
161    pub workspace_trust: WorkspaceTrust,
162    /// Risk threshold for requiring approval
163    pub approval_risk_threshold: RiskLevel,
164    /// Enforce short-window rate limiting (per-second/per-minute).
165    /// Turn/session limits are always enforced.
166    pub enforce_rate_limits: bool,
167}
168
169impl Default for SafetyGatewayConfig {
170    fn default() -> Self {
171        let rate_limit_per_second = tool_calls_per_second_from_env().unwrap_or(5);
172
173        Self {
174            max_per_turn: 50,
175            max_per_session: 100,
176            rate_limit_per_second,
177            rate_limit_per_minute: tool_calls_per_minute_from_env(),
178            plan_mode_active: false,
179            workspace_trust: WorkspaceTrust::Trusted,
180            approval_risk_threshold: RiskLevel::Medium,
181            enforce_rate_limits: true,
182        }
183    }
184}
185
/// Mutable limiter bookkeeping, kept behind the gateway's `rate_state` mutex.
///
/// The two windows are pruned lazily, whenever limits are checked.
#[derive(Debug, Default)]
struct RateLimiterState {
    /// Timestamps of recorded calls for the 1-second sliding window.
    calls_per_second: std::collections::VecDeque<Instant>,
    /// Timestamps of recorded calls for the 60-second sliding window.
    calls_per_minute: std::collections::VecDeque<Instant>,
    /// Calls recorded since the last `start_turn`.
    current_turn_count: usize,
    /// Calls recorded over the gateway's lifetime.
    session_count: usize,
}
194
195/// Unified Safety Gateway
196///
197/// Consolidates rate limiting, destructive tool detection, command policy
198/// enforcement, plan mode restrictions, and dotfile protection into a single
199/// safety decision point.
200pub struct SafetyGateway {
201    /// Configuration
202    config: RwLock<SafetyGatewayConfig>,
203    /// Command policy evaluator (optional, for shell commands)
204    command_policy: Option<CommandPolicyEvaluator>,
205    /// Rate limiter state
206    rate_state: Mutex<RateLimiterState>,
207    /// Preapproved tools for this session
208    preapproved: Mutex<HashSet<String>>,
209    /// Dotfile guardian for protected file access
210    dotfile_guardian: Option<Arc<DotfileGuardian>>,
211}
212
213#[derive(Debug, Clone, PartialEq, Eq, Hash)]
214struct FileAccessTarget {
215    path: PathBuf,
216    access_type: AccessType,
217}
218
219fn primary_path_arg(args: &Value) -> Option<&str> {
220    args.get("path")
221        .and_then(|value| value.as_str())
222        .or_else(|| args.get("file_path").and_then(|value| value.as_str()))
223        .or_else(|| args.get("filepath").and_then(|value| value.as_str()))
224        .or_else(|| args.get("target_path").and_then(|value| value.as_str()))
225}
226
227fn destination_path_arg(args: &Value) -> Option<&str> {
228    args.get("destination").and_then(|value| value.as_str())
229}
230
231fn push_file_access_target(
232    targets: &mut Vec<FileAccessTarget>,
233    path: &str,
234    access_type: AccessType,
235) {
236    let path_str = path.trim();
237    if path_str.is_empty() {
238        return;
239    }
240
241    let path = PathBuf::from(path_str);
242    // For small number of targets, linear search is faster than HashSet.
243    // In large patches, we'll use a local HashSet in patch_file_access_targets.
244    if targets
245        .iter()
246        .any(|existing| existing.path == path && existing.access_type == access_type)
247    {
248        return;
249    }
250
251    targets.push(FileAccessTarget { path, access_type });
252}
253
254fn command_text_for_tool(tool_name: &str, args: &Value) -> Option<String> {
255    match tool_name {
256        tools::SHELL | tools::RUN_PTY_CMD => crate::tools::command_args::command_text(args)
257            .ok()
258            .flatten(),
259        tools::SEND_PTY_INPUT => {
260            crate::tools::command_args::interactive_input_text(args).map(str::to_owned)
261        }
262        tools::UNIFIED_EXEC if unified_exec_action_is(args, "run") => {
263            crate::tools::command_args::command_text(args)
264                .ok()
265                .flatten()
266        }
267        tools::UNIFIED_EXEC if unified_exec_action_in(args, &["write", "continue"]) => {
268            crate::tools::command_args::interactive_input_text(args).map(str::to_owned)
269        }
270        _ => None,
271    }
272}
273
274fn patch_file_access_targets(args: &Value) -> Vec<FileAccessTarget> {
275    let Ok(Some(patch_input)) = decode_apply_patch_input(args) else {
276        return Vec::new();
277    };
278    let Ok(patch) = Patch::parse(&patch_input.text) else {
279        return Vec::new();
280    };
281
282    let mut targets = Vec::new();
283    for operation in patch.operations() {
284        match operation {
285            PatchOperation::AddFile { path, .. } => {
286                push_file_access_target(&mut targets, path, AccessType::Write);
287            }
288            PatchOperation::DeleteFile { path } => {
289                push_file_access_target(&mut targets, path, AccessType::Delete);
290            }
291            PatchOperation::UpdateFile { path, new_path, .. } => {
292                push_file_access_target(&mut targets, path, AccessType::Modify);
293                if let Some(destination) =
294                    new_path.as_deref().filter(|candidate| *candidate != path)
295                {
296                    push_file_access_target(&mut targets, destination, AccessType::Write);
297                }
298            }
299        }
300    }
301
302    targets
303}
304
305fn file_access_targets(tool_name: &str, args: &Value) -> Vec<FileAccessTarget> {
306    let mut targets = Vec::new();
307
308    match tool_name {
309        tools::WRITE_FILE | tools::CREATE_FILE => {
310            if let Some(path) = primary_path_arg(args) {
311                push_file_access_target(&mut targets, path, AccessType::Write);
312            }
313        }
314        tools::EDIT_FILE | "search_replace" => {
315            if let Some(path) = primary_path_arg(args) {
316                push_file_access_target(&mut targets, path, AccessType::Modify);
317            }
318        }
319        tools::DELETE_FILE => {
320            if let Some(path) = primary_path_arg(args) {
321                push_file_access_target(&mut targets, path, AccessType::Delete);
322            }
323        }
324        tools::MOVE_FILE => {
325            if let Some(path) = primary_path_arg(args) {
326                push_file_access_target(&mut targets, path, AccessType::Modify);
327            }
328            if let Some(path) = destination_path_arg(args) {
329                push_file_access_target(&mut targets, path, AccessType::Write);
330            }
331        }
332        tools::COPY_FILE => {
333            if let Some(path) = destination_path_arg(args) {
334                push_file_access_target(&mut targets, path, AccessType::Write);
335            }
336        }
337        tools::APPLY_PATCH => {
338            targets.extend(patch_file_access_targets(args));
339        }
340        tools::UNIFIED_FILE if unified_file_action_is(args, "write") => {
341            if let Some(path) = primary_path_arg(args) {
342                push_file_access_target(&mut targets, path, AccessType::Write);
343            }
344        }
345        tools::UNIFIED_FILE if unified_file_action_is(args, "edit") => {
346            if let Some(path) = primary_path_arg(args) {
347                push_file_access_target(&mut targets, path, AccessType::Modify);
348            }
349        }
350        tools::UNIFIED_FILE if unified_file_action_is(args, "delete") => {
351            if let Some(path) = primary_path_arg(args) {
352                push_file_access_target(&mut targets, path, AccessType::Delete);
353            }
354        }
355        tools::UNIFIED_FILE if unified_file_action_is(args, "move") => {
356            if let Some(path) = primary_path_arg(args) {
357                push_file_access_target(&mut targets, path, AccessType::Modify);
358            }
359            if let Some(path) = destination_path_arg(args) {
360                push_file_access_target(&mut targets, path, AccessType::Write);
361            }
362        }
363        tools::UNIFIED_FILE if unified_file_action_is(args, "copy") => {
364            if let Some(path) = destination_path_arg(args) {
365                push_file_access_target(&mut targets, path, AccessType::Write);
366            }
367        }
368        tools::UNIFIED_FILE if unified_file_action_is(args, "patch") => {
369            targets.extend(patch_file_access_targets(args));
370        }
371        _ => {}
372    }
373
374    targets
375}
376
377fn proposed_changes_preview(args: &Value) -> String {
378    const PREVIEW_LIMIT: usize = 500;
379
380    let preview_text = |label: &str, text: &str| {
381        let preview_len = text.len().min(PREVIEW_LIMIT);
382        format!(
383            "{label} ({} bytes):\n{}{}",
384            text.len(),
385            &text[..preview_len],
386            if text.len() > preview_len { "..." } else { "" }
387        )
388    };
389
390    if let Some(content) = args.get("content").and_then(|value| value.as_str()) {
391        return preview_text("Content", content);
392    }
393
394    if let Some(old_str) = args.get("old_str").and_then(|value| value.as_str()) {
395        let new_str = args
396            .get("new_str")
397            .and_then(|value| value.as_str())
398            .unwrap_or("");
399        return format!("Replace:\n  '{}'\nWith:\n  '{}'", old_str, new_str);
400    }
401
402    if let Ok(Some(patch_input)) = decode_apply_patch_input(args) {
403        return preview_text("Patch", &patch_input.text);
404    }
405
406    "No details provided".to_string()
407}
408
409impl SafetyGateway {
410    /// Create a new safety gateway with default configuration
411    pub fn new() -> Self {
412        Self::with_config(SafetyGatewayConfig::default())
413    }
414
415    /// Create a new safety gateway with custom configuration
416    pub fn with_config(config: SafetyGatewayConfig) -> Self {
417        Self {
418            config: RwLock::new(config),
419            command_policy: None,
420            rate_state: Mutex::new(RateLimiterState::default()),
421            preapproved: Mutex::new(HashSet::new()),
422            dotfile_guardian: None,
423        }
424    }
425
426    /// Set the dotfile guardian for protected file access
427    pub fn with_dotfile_guardian(mut self, guardian: Arc<DotfileGuardian>) -> Self {
428        self.dotfile_guardian = Some(guardian);
429        self
430    }
431
432    /// Create and set a dotfile guardian from configuration
433    pub async fn with_dotfile_protection(
434        mut self,
435        config: DotfileProtectionConfig,
436    ) -> anyhow::Result<Self> {
437        let guardian = DotfileGuardian::new(config).await?;
438        self.dotfile_guardian = Some(Arc::new(guardian));
439        Ok(self)
440    }
441
442    /// Set the command policy evaluator for shell command checks
443    pub fn with_command_policy(mut self, policy: CommandPolicyEvaluator) -> Self {
444        self.command_policy = Some(policy);
445        self
446    }
447
448    /// Create from commands config
449    pub fn with_commands_config(mut self, config: &CommandsConfig) -> Self {
450        self.command_policy = Some(CommandPolicyEvaluator::from_config(config));
451        self
452    }
453
454    /// Enable or disable plan mode
455    pub fn set_plan_mode(&self, active: bool) {
456        self.config.write().plan_mode_active = active;
457    }
458
459    /// Set workspace trust level
460    pub fn set_workspace_trust(&self, trust: WorkspaceTrust) {
461        self.config.write().workspace_trust = trust;
462    }
463
464    /// Update rate limits
465    pub fn set_limits(&self, max_per_turn: usize, max_per_session: usize) {
466        let mut config = self.config.write();
467        config.max_per_turn = max_per_turn;
468        config.max_per_session = max_per_session;
469    }
470
471    /// Update rate-limiter thresholds.
472    pub fn set_rate_limits(
473        &self,
474        rate_limit_per_second: usize,
475        rate_limit_per_minute: Option<usize>,
476    ) {
477        let mut config = self.config.write();
478        if rate_limit_per_second > 0 {
479            config.rate_limit_per_second = rate_limit_per_second;
480        }
481        config.rate_limit_per_minute = rate_limit_per_minute.filter(|v| *v > 0);
482    }
483
484    /// Enable or disable rate-limit enforcement while preserving counters.
485    pub fn set_rate_limit_enforcement(&self, enabled: bool) {
486        self.config.write().enforce_rate_limits = enabled;
487    }
488
489    /// Increase session limit dynamically
490    pub fn increase_session_limit(&self, increment: usize) {
491        let mut config = self.config.write();
492        let new_max = config.max_per_session.saturating_add(increment);
493        config.max_per_session = new_max;
494        tracing::info!("Session tool limit increased to {}", new_max);
495    }
496
497    pub fn max_per_session(&self) -> usize {
498        self.config.read().max_per_session
499    }
500
501    /// Reset turn counters (call at start of new turn)
502    pub fn start_turn(&self) {
503        let mut state = self.rate_state.lock();
504        state.current_turn_count = 0;
505        state.calls_per_second.clear();
506        state.calls_per_minute.clear();
507    }
508
509    /// Preapprove a tool for this session
510    pub fn preapprove(&self, tool_name: &str) {
511        let mut preapproved = self.preapproved.lock();
512        preapproved.insert(tool_name.to_string());
513    }
514
515    /// Check if a tool is preapproved
516    pub fn is_preapproved(&self, tool_name: &str) -> bool {
517        let preapproved = self.preapproved.lock();
518        preapproved.contains(tool_name)
519    }
520
521    /// Check if a tool is destructive
522    pub fn is_destructive(&self, tool_name: &str) -> bool {
523        classify_tool_intent(tool_name, &Value::Object(Default::default())).destructive
524    }
525
526    /// Check if a tool is mutating
527    pub fn is_mutating(&self, tool_name: &str) -> bool {
528        classify_tool_intent(tool_name, &Value::Object(Default::default())).mutating
529    }
530
531    fn is_destructive_call(&self, tool_name: &str, args: &Value) -> bool {
532        classify_tool_intent(tool_name, args).destructive
533    }
534
535    fn is_mutating_call(&self, tool_name: &str, args: &Value) -> bool {
536        classify_tool_intent(tool_name, args).mutating
537    }
538
539    /// Main entry point: check safety for a tool invocation.
540    ///
541    /// Returns a [`SafetyDecision`] indicating whether execution can proceed.
542    /// Inline-delegating wrapper that returns the inner future directly to
543    /// avoid an extra coroutine state machine (audit section 16).
544    pub fn check_safety<'a>(
545        &'a self,
546        ctx: &'a SafetyContext,
547        tool_name: &'a str,
548        args: &'a Value,
549    ) -> impl Future<Output = SafetyDecision> + 'a {
550        self.check_safety_with_id(ctx, tool_name, args, None)
551    }
552
553    /// Check safety with explicit invocation ID for correlation
554    pub async fn check_safety_with_id(
555        &self,
556        ctx: &SafetyContext,
557        tool_name: &str,
558        args: &Value,
559        invocation_id: Option<ToolInvocationId>,
560    ) -> SafetyDecision {
561        let inv_id = invocation_id
562            .map(|id| id.short())
563            .unwrap_or_else(|| "unknown".to_string());
564
565        tracing::trace!(
566            invocation_id = %inv_id,
567            tool = %tool_name,
568            "SafetyGateway: checking safety"
569        );
570
571        if let Err(err) = self.check_rate_limits() {
572            tracing::warn!(
573                invocation_id = %inv_id,
574                error = %err,
575                "SafetyGateway: rate limit exceeded"
576            );
577            return SafetyDecision::Deny(err.to_string());
578        }
579
580        self.evaluate_non_rate_decision(ctx, tool_name, args, &inv_id)
581            .await
582    }
583
584    /// Check safety and atomically reserve a rate-limit slot on success.
585    ///
586    /// This avoids split check/record races by validating rate limits and recording
587    /// execution under a single lock acquisition.
588    pub async fn check_and_record(
589        &self,
590        ctx: &SafetyContext,
591        tool_name: &str,
592        args: &Value,
593    ) -> SafetyCheckResult {
594        self.check_and_record_with_id(ctx, tool_name, args, None)
595            .await
596    }
597
598    /// Check safety with correlation ID and atomically reserve a rate-limit slot.
599    pub async fn check_and_record_with_id(
600        &self,
601        ctx: &SafetyContext,
602        tool_name: &str,
603        args: &Value,
604        invocation_id: Option<ToolInvocationId>,
605    ) -> SafetyCheckResult {
606        let inv_id = invocation_id
607            .map(|id| id.short())
608            .unwrap_or_else(|| "unknown".to_string());
609        tracing::trace!(
610            invocation_id = %inv_id,
611            tool = %tool_name,
612            "SafetyGateway: checking and recording safety"
613        );
614
615        let decision = self
616            .evaluate_non_rate_decision(ctx, tool_name, args, &inv_id)
617            .await;
618
619        if decision.is_denied() {
620            return SafetyCheckResult {
621                decision,
622                retry_after: None,
623                violation: None,
624            };
625        }
626
627        let now = Instant::now();
628        let mut state = self.rate_state.lock();
629        match self.check_rate_limits_locked(&mut state, now) {
630            Ok(()) => {
631                self.record_execution_locked(&mut state, now);
632                SafetyCheckResult {
633                    decision,
634                    retry_after: None,
635                    violation: None,
636                }
637            }
638            Err(err) => {
639                tracing::warn!(
640                    invocation_id = %inv_id,
641                    error = %err,
642                    "SafetyGateway: rate limit exceeded during atomic reservation"
643                );
644                SafetyCheckResult {
645                    decision: SafetyDecision::Deny(err.to_string()),
646                    retry_after: self.retry_after_for_violation(&err, &state, now),
647                    violation: Some(err),
648                }
649            }
650        }
651    }
652
653    async fn evaluate_non_rate_decision(
654        &self,
655        ctx: &SafetyContext,
656        tool_name: &str,
657        args: &Value,
658        inv_id: &str,
659    ) -> SafetyDecision {
660        if let Some(decision) = self
661            .check_dotfile_protection(tool_name, args, &ctx.session_id)
662            .await
663        {
664            tracing::info!(
665                invocation_id = %inv_id,
666                tool = %tool_name,
667                "SafetyGateway: dotfile protection triggered"
668            );
669            return decision;
670        }
671
672        if self.config.read().plan_mode_active && self.is_mutating_call(tool_name, args) {
673            let reason = format!(
674                "Tool '{}' is blocked in plan mode (read-only). Switch to edit mode to execute.",
675                tool_name
676            );
677            tracing::info!(
678                invocation_id = %inv_id,
679                tool = %tool_name,
680                "SafetyGateway: plan mode violation"
681            );
682            return SafetyDecision::Deny(reason);
683        }
684
685        if let Some(ref policy) = self.command_policy
686            && let Some(command) = command_text_for_tool(tool_name, args)
687            && !policy.allows_text(&command)
688        {
689            let reason = format!("Command '{}' blocked by policy", command);
690            tracing::info!(
691                invocation_id = %inv_id,
692                command = %command,
693                "SafetyGateway: command policy denied"
694            );
695            return SafetyDecision::Deny(reason);
696        }
697
698        if self.is_preapproved(tool_name) {
699            tracing::trace!(
700                invocation_id = %inv_id,
701                tool = %tool_name,
702                "SafetyGateway: tool preapproved"
703            );
704            return SafetyDecision::Allow;
705        }
706
707        if ctx.trust_level.can_bypass_approval() {
708            tracing::trace!(
709                invocation_id = %inv_id,
710                tool = %tool_name,
711                trust_level = ?ctx.trust_level,
712                "SafetyGateway: trust level allows bypass"
713            );
714            return SafetyDecision::Allow;
715        }
716
717        let risk_ctx = self.build_risk_context(tool_name, args);
718        let risk_level = ToolRiskScorer::calculate_risk(&risk_ctx);
719
720        if ToolRiskScorer::requires_justification(
721            risk_level,
722            self.config.read().approval_risk_threshold,
723        ) {
724            let justification = self.build_approval_justification(tool_name, &risk_level, args);
725            tracing::info!(
726                invocation_id = %inv_id,
727                tool = %tool_name,
728                risk = %risk_level,
729                "SafetyGateway: requires approval"
730            );
731            return SafetyDecision::NeedsApproval(justification);
732        }
733
734        if self.is_destructive_call(tool_name, args) {
735            let justification = format!(
736                "Tool '{}' is destructive and may modify files or execute commands.",
737                tool_name
738            );
739            tracing::info!(
740                invocation_id = %inv_id,
741                tool = %tool_name,
742                "SafetyGateway: destructive tool requires approval"
743            );
744            return SafetyDecision::NeedsApproval(justification);
745        }
746
747        SafetyDecision::Allow
748    }
749
750    /// Record that a tool call was executed (for rate limiting)
751    pub fn record_execution(&self) {
752        let mut state = self.rate_state.lock();
753        self.record_execution_locked(&mut state, Instant::now());
754    }
755
756    /// Check rate limits without recording
757    fn check_rate_limits(&self) -> Result<(), SafetyError> {
758        let mut state = self.rate_state.lock();
759        self.check_rate_limits_locked(&mut state, Instant::now())
760    }
761
762    fn check_rate_limits_locked(
763        &self,
764        state: &mut RateLimiterState,
765        now: Instant,
766    ) -> Result<(), SafetyError> {
767        let config = self.config.read();
768        self.prune_rate_windows(state, now);
769
770        if config.enforce_rate_limits {
771            if state.calls_per_second.len() >= config.rate_limit_per_second {
772                return Err(SafetyError::RateLimitExceeded {
773                    current: state.calls_per_second.len(),
774                    max: config.rate_limit_per_second,
775                    window: "1s",
776                });
777            }
778
779            if let Some(limit) = config.rate_limit_per_minute
780                && state.calls_per_minute.len() >= limit
781            {
782                return Err(SafetyError::RateLimitExceeded {
783                    current: state.calls_per_minute.len(),
784                    max: limit,
785                    window: "60s",
786                });
787            }
788        }
789
790        if state.current_turn_count >= config.max_per_turn {
791            return Err(SafetyError::TurnLimitReached {
792                max: config.max_per_turn,
793            });
794        }
795
796        if state.session_count >= config.max_per_session {
797            return Err(SafetyError::SessionLimitReached {
798                max: config.max_per_session,
799            });
800        }
801
802        Ok(())
803    }
804
805    fn record_execution_locked(&self, state: &mut RateLimiterState, now: Instant) {
806        state.current_turn_count = state.current_turn_count.saturating_add(1);
807        state.session_count = state.session_count.saturating_add(1);
808        state.calls_per_second.push_back(now);
809        state.calls_per_minute.push_back(now);
810    }
811
812    fn prune_rate_windows(&self, state: &mut RateLimiterState, now: Instant) {
813        while let Some(front) = state.calls_per_second.front() {
814            if now.duration_since(*front) > Duration::from_secs(1) {
815                state.calls_per_second.pop_front();
816            } else {
817                break;
818            }
819        }
820        while let Some(front) = state.calls_per_minute.front() {
821            if now.duration_since(*front) > Duration::from_secs(60) {
822                state.calls_per_minute.pop_front();
823            } else {
824                break;
825            }
826        }
827    }
828
829    fn retry_after_for_violation(
830        &self,
831        violation: &SafetyError,
832        state: &RateLimiterState,
833        now: Instant,
834    ) -> Option<Duration> {
835        match violation {
836            SafetyError::RateLimitExceeded { window: "1s", .. } => state
837                .calls_per_second
838                .front()
839                .map(|first| Duration::from_secs(1).saturating_sub(now.duration_since(*first))),
840            SafetyError::RateLimitExceeded { window: "60s", .. } => state
841                .calls_per_minute
842                .front()
843                .map(|first| Duration::from_secs(60).saturating_sub(now.duration_since(*first))),
844            _ => None,
845        }
846    }
847
848    /// Check dotfile protection for file operations.
849    /// Returns Some(SafetyDecision) if dotfile protection applies, None otherwise.
850    async fn check_dotfile_protection(
851        &self,
852        tool_name: &str,
853        args: &Value,
854        session_id: &str,
855    ) -> Option<SafetyDecision> {
856        // Use local guardian if set, otherwise try global guardian
857        let guardian = match self.dotfile_guardian.as_ref() {
858            Some(g) => g.clone(),
859            None => get_global_guardian()?,
860        };
861
862        let file_targets = file_access_targets(tool_name, args);
863        if file_targets.is_empty() {
864            return None;
865        }
866
867        let proposed_changes = proposed_changes_preview(args);
868
869        for target in file_targets {
870            if !guardian.is_protected(Path::new(&target.path)) {
871                continue;
872            }
873
874            let context =
875                AccessContext::new(&target.path, target.access_type, tool_name, session_id)
876                    .with_proposed_changes(&proposed_changes);
877
878            match guardian.request_access(&context).await {
879                Ok(ProtectionDecision::Allowed) => continue,
880                Ok(ProtectionDecision::RequiresConfirmation(req)) => {
881                    return Some(SafetyDecision::NeedsApproval(format!(
882                        "DOTFILE PROTECTION\n\n\
883                        File: {}\n\
884                        Operation: {}\n\
885                        Reason: {}\n\n\
886                        Proposed changes:\n{}\n\n\
887                        {}",
888                        req.file_path,
889                        req.access_type,
890                        req.protection_reason,
891                        req.proposed_changes,
892                        req.warning
893                    )));
894                }
895                Ok(ProtectionDecision::RequiresSecondaryAuth(req)) => {
896                    return Some(SafetyDecision::NeedsApproval(format!(
897                        "DOTFILE SECONDARY AUTHENTICATION REQUIRED\n\n\
898                        File: {} (whitelisted)\n\
899                        Operation: {}\n\
900                        Reason: {}\n\n\
901                        This file is on the whitelist but requires secondary authentication.\n\n\
902                        Proposed changes:\n{}\n\n\
903                        {}",
904                        req.file_path,
905                        req.access_type,
906                        req.protection_reason,
907                        req.proposed_changes,
908                        req.warning
909                    )));
910                }
911                Ok(ProtectionDecision::Blocked(violation)) => {
912                    return Some(SafetyDecision::Deny(format!(
913                        "DOTFILE MODIFICATION BLOCKED\n\n\
914                            File: {}\n\
915                            Reason: {}\n\n\
916                            Suggestion: {}",
917                        violation.file_path, violation.reason, violation.suggestion
918                    )));
919                }
920                Ok(ProtectionDecision::Denied(violation)) => {
921                    return Some(SafetyDecision::Deny(format!(
922                        "DOTFILE ACCESS DENIED\n\n\
923                            File: {}\n\
924                            Reason: {}\n\n\
925                            Suggestion: {}",
926                        violation.file_path, violation.reason, violation.suggestion
927                    )));
928                }
929                Err(e) => {
930                    tracing::error!("Dotfile protection check failed: {}", e);
931                    return Some(SafetyDecision::Deny(format!(
932                        "Dotfile protection check failed: {}",
933                        e
934                    )));
935                }
936            }
937        }
938
939        None
940    }
941
942    /// Get the dotfile guardian (if configured)
943    pub fn dotfile_guardian(&self) -> Option<&Arc<DotfileGuardian>> {
944        self.dotfile_guardian.as_ref()
945    }
946
947    /// Build risk context from tool name and arguments
948    fn build_risk_context(&self, tool_name: &str, args: &Value) -> ToolRiskContext {
949        let source = if tool_name.starts_with("mcp_") {
950            ToolSource::Mcp
951        } else if tool_name.starts_with("acp_") {
952            ToolSource::Acp
953        } else {
954            ToolSource::Internal
955        };
956
957        let web_search =
958            tool_name == tools::UNIFIED_SEARCH && unified_search_action_is(args, "web");
959        let risk_tool_name = if web_search {
960            "unified_search:web"
961        } else {
962            tool_name
963        };
964
965        let mut ctx = ToolRiskContext::new(
966            risk_tool_name.to_string(),
967            source,
968            self.config.read().workspace_trust,
969        );
970
971        // Set flags based on tool type
972        if self.is_mutating_call(tool_name, args) {
973            ctx = ctx.as_write();
974        }
975        if self.is_destructive_call(tool_name, args) {
976            ctx = ctx.as_destructive();
977        }
978
979        // Check for network access
980        if tool_name == tools::WEB_SEARCH || tool_name == tools::FETCH_URL || web_search {
981            ctx = ctx.accesses_network();
982        }
983
984        // Extract command args for shell tools
985        if let Some(command) = command_text_for_tool(tool_name, args) {
986            ctx = ctx.with_args(command.split_whitespace().map(String::from).collect());
987        }
988
989        ctx
990    }
991
992    /// Build justification message for approval prompt
993    fn build_approval_justification(
994        &self,
995        tool_name: &str,
996        risk_level: &RiskLevel,
997        args: &Value,
998    ) -> String {
999        let mut parts = Vec::new();
1000
1001        parts.push(format!("Tool: {}", tool_name));
1002        parts.push(format!("Risk level: {}", risk_level));
1003
1004        if self.is_destructive_call(tool_name, args) {
1005            parts.push("This tool may modify or delete files.".to_string());
1006        }
1007
1008        if let Some(command) = command_text_for_tool(tool_name, args) {
1009            parts.push(format!("Command: {}", command));
1010        }
1011
1012        let file_targets = file_access_targets(tool_name, args);
1013        if let Some(target) = file_targets.first() {
1014            parts.push(format!("Path: {}", target.path.display()));
1015            if file_targets.len() > 1 {
1016                parts.push(format!("Additional targets: {}", file_targets.len() - 1));
1017            }
1018        }
1019
1020        parts.join("\n")
1021    }
1022
1023    /// Get current session statistics
1024    pub fn get_stats(&self) -> SafetyStats {
1025        let state = self.rate_state.lock();
1026        let preapproved = self.preapproved.lock();
1027        let config = self.config.read();
1028
1029        SafetyStats {
1030            turn_count: state.current_turn_count,
1031            session_count: state.session_count,
1032            max_per_turn: config.max_per_turn,
1033            max_per_session: config.max_per_session,
1034            plan_mode_active: config.plan_mode_active,
1035            preapproved_count: preapproved.len(),
1036        }
1037    }
1038}
1039
1040impl Default for SafetyGateway {
1041    fn default() -> Self {
1042        Self::new()
1043    }
1044}
1045
/// Point-in-time statistics snapshot taken from the safety gateway.
///
/// Derives `PartialEq`/`Eq` so snapshots can be compared directly, e.g. in
/// tests or when detecting changes between two polls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SafetyStats {
    /// Tool executions recorded in the current turn (reset when a turn starts).
    pub turn_count: usize,
    /// Tool executions recorded across the whole session.
    pub session_count: usize,
    /// Configured maximum number of tool calls per turn.
    pub max_per_turn: usize,
    /// Configured maximum number of tool calls per session.
    pub max_per_session: usize,
    /// Whether plan mode (which denies mutating tools) is active.
    pub plan_mode_active: bool,
    /// Number of entries currently in the pre-approved tool set.
    pub preapproved_count: usize,
}
1056
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use vtcode_config::core::DotfileProtectionConfig;

    /// Patch that updates `.gitignore`, a protected dotfile.
    const GITIGNORE_PATCH: &str =
        "*** Begin Patch\n*** Update File: .gitignore\n@@\n-old\n+new\n*** End Patch\n";

    fn make_ctx() -> SafetyContext {
        SafetyContext::new("test-session")
    }

    /// Gateway whose command policy deny-lists `rm`.
    fn gateway_denying_rm() -> SafetyGateway {
        let mut commands_config = CommandsConfig::default();
        commands_config.deny_list.push("rm".to_string());
        SafetyGateway::new().with_commands_config(&commands_config)
    }

    /// Gateway limited to two calls per second.
    fn two_per_second_gateway() -> SafetyGateway {
        SafetyGateway::with_config(SafetyGatewayConfig {
            rate_limit_per_second: 2,
            ..Default::default()
        })
    }

    /// Gateway with a local dotfile guardian that neither audits nor backs up.
    async fn gateway_with_quiet_guardian() -> SafetyGateway {
        let guardian = DotfileGuardian::new(DotfileProtectionConfig {
            audit_logging_enabled: false,
            create_backups: false,
            ..Default::default()
        })
        .await
        .expect("guardian should initialize");
        SafetyGateway::new().with_dotfile_guardian(Arc::new(guardian))
    }

    /// Asserts that `tool_name` with `args` is denied by the `rm` deny-list.
    async fn assert_denied_by_rm_policy(tool_name: &str, args: Value) {
        let gateway = gateway_denying_rm();
        let ctx = make_ctx();
        let decision = gateway.check_safety(&ctx, tool_name, &args).await;
        assert!(decision.is_denied());
    }

    /// Asserts that patching `.gitignore` through `tool_name` needs approval
    /// and that the prompt names the file.
    async fn assert_gitignore_patch_needs_approval(tool_name: &str, args: Value) {
        let gateway = gateway_with_quiet_guardian().await;
        let ctx = make_ctx();
        let decision = gateway.check_safety(&ctx, tool_name, &args).await;
        assert!(decision.needs_approval());
        assert!(
            decision
                .reason()
                .is_some_and(|reason| reason.contains(".gitignore"))
        );
    }

    #[tokio::test]
    async fn test_allow_read_only_tools() {
        let gateway = SafetyGateway::new();
        let ctx = make_ctx();
        let decision = gateway
            .check_safety(&ctx, "read_file", &json!({"path": "/tmp/test"}))
            .await;
        assert!(decision.is_allowed());
    }

    #[tokio::test]
    async fn test_destructive_tool_needs_approval() {
        let gateway = SafetyGateway::new();
        let ctx = make_ctx();
        let decision = gateway
            .check_safety(&ctx, "delete_file", &json!({"path": "/tmp/test"}))
            .await;
        assert!(decision.needs_approval());
    }

    #[tokio::test]
    async fn test_plan_mode_blocks_mutating() {
        let gateway = SafetyGateway::new();
        gateway.set_plan_mode(true);
        let ctx = make_ctx();
        let decision = gateway
            .check_safety(&ctx, "write_file", &json!({"path": "/tmp/test"}))
            .await;
        assert!(decision.is_denied());
        assert!(
            decision
                .reason()
                .is_some_and(|reason| reason.contains("plan mode"))
        );
    }

    #[tokio::test]
    async fn test_preapproved_tools_allowed() {
        let gateway = SafetyGateway::new();
        gateway.preapprove("delete_file");
        let ctx = make_ctx();
        let decision = gateway
            .check_safety(&ctx, "delete_file", &json!({"path": "/tmp/test"}))
            .await;
        assert!(decision.is_allowed());
    }

    #[tokio::test]
    async fn test_trust_level_bypass() {
        let gateway = SafetyGateway::new();
        let mut ctx = make_ctx();
        ctx.trust_level = SafetyTrustLevel::Full;
        let decision = gateway
            .check_safety(&ctx, "delete_file", &json!({"path": "/tmp/test"}))
            .await;
        assert!(decision.is_allowed());
    }

    #[tokio::test]
    async fn test_rate_limiting() {
        let gateway = two_per_second_gateway();
        let ctx = make_ctx();

        // Fill the one-second window up to its limit of two calls.
        gateway.record_execution();
        gateway.record_execution();

        // A third call inside the same window must be denied.
        let decision = gateway.check_safety(&ctx, "read_file", &json!({})).await;
        assert!(decision.is_denied());
        assert!(
            decision
                .reason()
                .is_some_and(|reason| reason.contains("Rate limit"))
        );
    }

    #[tokio::test]
    async fn test_atomic_check_and_record_rate_limited() {
        let gateway = two_per_second_gateway();
        let ctx = make_ctx();

        for _ in 0..2 {
            let outcome = gateway.check_and_record(&ctx, "read_file", &json!({})).await;
            assert!(outcome.decision.is_allowed());
        }

        let third = gateway.check_and_record(&ctx, "read_file", &json!({})).await;
        assert!(third.decision.is_denied());
        assert!(third.retry_after.is_some());
        assert!(matches!(
            third.violation,
            Some(SafetyError::RateLimitExceeded { .. })
        ));
    }

    #[tokio::test]
    async fn test_command_policy_enforcement() {
        assert_denied_by_rm_policy("shell", json!({"command": "rm -rf /"})).await;
    }

    #[tokio::test]
    async fn test_unified_exec_command_policy_enforcement_with_indexed_args() {
        assert_denied_by_rm_policy(
            tools::UNIFIED_EXEC,
            json!({
                "command.0": "rm",
                "command.1": "-rf",
                "command.2": "/"
            }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_unified_exec_continue_command_policy_enforcement_with_input() {
        assert_denied_by_rm_policy(
            tools::UNIFIED_EXEC,
            json!({
                "action": "continue",
                "session_id": "run-123",
                "input": "rm -rf /\n"
            }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_send_pty_input_command_policy_enforcement_with_input() {
        assert_denied_by_rm_policy(
            tools::SEND_PTY_INPUT,
            json!({
                "session_id": "run-123",
                "input": "rm -rf /\n"
            }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_apply_patch_dotfile_protection_requires_approval() {
        assert_gitignore_patch_needs_approval(
            tools::APPLY_PATCH,
            json!({ "input": GITIGNORE_PATCH }),
        )
        .await;
    }

    #[test]
    fn test_patch_file_access_targets_preserve_patch_order() {
        let targets = patch_file_access_targets(&json!({
            "input": "*** Begin Patch\n*** Update File: src/main.rs\n@@\n-old\n+new\n*** Update File: .gitignore\n@@\n-old\n+new\n*** End Patch\n"
        }));

        assert_eq!(targets.len(), 2);
        assert_eq!(targets[0].path, PathBuf::from("src/main.rs"));
        assert_eq!(targets[0].access_type, AccessType::Modify);
        assert_eq!(targets[1].path, PathBuf::from(".gitignore"));
        assert_eq!(targets[1].access_type, AccessType::Modify);
    }

    #[tokio::test]
    async fn test_unified_file_patch_dotfile_protection_requires_approval() {
        assert_gitignore_patch_needs_approval(
            tools::UNIFIED_FILE,
            json!({
                "action": "patch",
                "patch": GITIGNORE_PATCH
            }),
        )
        .await;
    }

    #[tokio::test]
    async fn test_stats_tracking() {
        let gateway = SafetyGateway::new();
        gateway.preapprove("test_tool");
        gateway.record_execution();
        gateway.record_execution();

        let stats = gateway.get_stats();
        assert_eq!(stats.turn_count, 2);
        assert_eq!(stats.session_count, 2);
        assert_eq!(stats.preapproved_count, 1);
    }

    #[tokio::test]
    async fn test_start_turn_resets_counters() {
        let gateway = SafetyGateway::new();
        gateway.record_execution();
        gateway.record_execution();
        assert_eq!(gateway.get_stats().turn_count, 2);

        gateway.start_turn();

        let after = gateway.get_stats();
        assert_eq!(after.turn_count, 0);
        // The session-wide counter survives a turn reset.
        assert_eq!(after.session_count, 2);
    }

    #[tokio::test]
    async fn test_increase_session_limit_updates_limit() {
        let gateway = SafetyGateway::new();
        gateway.set_limits(10, 1);
        let ctx = make_ctx();

        let first = gateway.check_and_record(&ctx, "read_file", &json!({})).await;
        assert!(first.decision.is_allowed());

        let second = gateway.check_and_record(&ctx, "read_file", &json!({})).await;
        assert!(second.decision.is_denied());

        gateway.increase_session_limit(1);

        let third = gateway.check_and_record(&ctx, "read_file", &json!({})).await;
        assert!(third.decision.is_allowed());
    }
}