1use hashbrown::HashSet;
12use parking_lot::{Mutex, RwLock};
13use std::future::Future;
14use std::path::{Path, PathBuf};
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17
18use serde::{Deserialize, Serialize};
19use serde_json::Value;
20use thiserror::Error;
21
22use crate::config::CommandsConfig;
23use crate::config::constants::tools;
24use crate::dotfile_protection::{
25 AccessContext, AccessType, DotfileGuardian, ProtectionDecision, get_global_guardian,
26};
27use crate::tools::apply_patch::{Patch, PatchOperation, decode_apply_patch_input};
28use crate::tools::command_policy::CommandPolicyEvaluator;
29use crate::tools::invocation::ToolInvocationId;
30use crate::tools::rate_limit_config::{
31 tool_calls_per_minute_from_env, tool_calls_per_second_from_env,
32};
33use crate::tools::registry::{
34 RiskLevel, ToolRiskContext, ToolRiskScorer, ToolSource, WorkspaceTrust,
35};
36use crate::tools::tool_intent::{
37 classify_tool_intent, unified_exec_action_in, unified_exec_action_is, unified_file_action_is,
38 unified_search_action_is,
39};
40use vtcode_config::core::DotfileProtectionConfig;
41
/// Trust level attached to a [`SafetyContext`].
///
/// The default is [`SafetyTrustLevel::Standard`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SafetyTrustLevel {
    /// Does not bypass approval.
    Untrusted,
    /// Default level; does not bypass approval.
    #[default]
    Standard,
    /// Bypasses approval prompts.
    Elevated,
    /// Bypasses approval prompts.
    Full,
}

impl SafetyTrustLevel {
    /// Returns `true` for the levels that skip approval prompts
    /// (`Elevated` and `Full`), `false` for every other level.
    #[inline]
    pub const fn can_bypass_approval(self) -> bool {
        match self {
            Self::Elevated | Self::Full => true,
            Self::Untrusted | Self::Standard => false,
        }
    }
}
58
59#[derive(Debug, Clone)]
61pub struct SafetyContext {
62 pub session_id: String,
63 pub trust_level: SafetyTrustLevel,
64}
65
66impl SafetyContext {
67 #[must_use]
68 pub fn new(session_id: impl Into<String>) -> Self {
69 Self {
70 session_id: session_id.into(),
71 trust_level: SafetyTrustLevel::default(),
72 }
73 }
74}
75
76#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
78pub enum SafetyDecision {
79 Allow,
81 Deny(String),
83 NeedsApproval(String),
85}
86
87impl SafetyDecision {
88 #[inline]
90 pub fn is_allowed(&self) -> bool {
91 matches!(self, SafetyDecision::Allow)
92 }
93
94 #[inline]
96 pub fn is_denied(&self) -> bool {
97 matches!(self, SafetyDecision::Deny(_))
98 }
99
100 #[inline]
102 pub fn needs_approval(&self) -> bool {
103 matches!(self, SafetyDecision::NeedsApproval(_))
104 }
105
106 pub fn reason(&self) -> Option<&str> {
108 match self {
109 SafetyDecision::Allow => None,
110 SafetyDecision::Deny(reason) | SafetyDecision::NeedsApproval(reason) => Some(reason),
111 }
112 }
113}
114
115#[derive(Debug, Error, Clone)]
117pub enum SafetyError {
118 #[error("Rate limit exceeded: {current} calls in {window} (max: {max})")]
119 RateLimitExceeded {
120 current: usize,
121 max: usize,
122 window: &'static str,
123 },
124 #[error("Per-turn tool limit reached (max: {max})")]
125 TurnLimitReached { max: usize },
126 #[error("Session tool limit reached (max: {max})")]
127 SessionLimitReached { max: usize },
128 #[error("Plan mode violation: {0}")]
129 PlanModeViolation(String),
130 #[error("Command policy denied: {0}")]
131 CommandPolicyDenied(String),
132 #[error("Dotfile protection violation: {0}")]
133 DotfileProtectionViolation(String),
134}
135
136#[derive(Debug, Clone)]
138pub struct SafetyCheckResult {
139 pub decision: SafetyDecision,
141 pub retry_after: Option<Duration>,
143 pub violation: Option<SafetyError>,
145}
146
147#[derive(Debug, Clone)]
149pub struct SafetyGatewayConfig {
150 pub max_per_turn: usize,
152 pub max_per_session: usize,
154 pub rate_limit_per_second: usize,
156 pub rate_limit_per_minute: Option<usize>,
158 pub plan_mode_active: bool,
160 pub workspace_trust: WorkspaceTrust,
162 pub approval_risk_threshold: RiskLevel,
164 pub enforce_rate_limits: bool,
167}
168
169impl Default for SafetyGatewayConfig {
170 fn default() -> Self {
171 let rate_limit_per_second = tool_calls_per_second_from_env().unwrap_or(5);
172
173 Self {
174 max_per_turn: 50,
175 max_per_session: 100,
176 rate_limit_per_second,
177 rate_limit_per_minute: tool_calls_per_minute_from_env(),
178 plan_mode_active: false,
179 workspace_trust: WorkspaceTrust::Trusted,
180 approval_risk_threshold: RiskLevel::Medium,
181 enforce_rate_limits: true,
182 }
183 }
184}
185
/// Mutable counters behind the gateway's rate-limiter mutex.
#[derive(Debug, Default)]
struct RateLimiterState {
    /// Timestamps of recorded executions, pruned to the last second.
    calls_per_second: std::collections::VecDeque<Instant>,
    /// Timestamps of recorded executions, pruned to the last sixty seconds.
    calls_per_minute: std::collections::VecDeque<Instant>,
    /// Executions recorded since the last `start_turn`.
    current_turn_count: usize,
    /// Executions recorded over the gateway's lifetime.
    session_count: usize,
}
194
195pub struct SafetyGateway {
201 config: RwLock<SafetyGatewayConfig>,
203 command_policy: Option<CommandPolicyEvaluator>,
205 rate_state: Mutex<RateLimiterState>,
207 preapproved: Mutex<HashSet<String>>,
209 dotfile_guardian: Option<Arc<DotfileGuardian>>,
211}
212
213#[derive(Debug, Clone, PartialEq, Eq, Hash)]
214struct FileAccessTarget {
215 path: PathBuf,
216 access_type: AccessType,
217}
218
219fn primary_path_arg(args: &Value) -> Option<&str> {
220 args.get("path")
221 .and_then(|value| value.as_str())
222 .or_else(|| args.get("file_path").and_then(|value| value.as_str()))
223 .or_else(|| args.get("filepath").and_then(|value| value.as_str()))
224 .or_else(|| args.get("target_path").and_then(|value| value.as_str()))
225}
226
227fn destination_path_arg(args: &Value) -> Option<&str> {
228 args.get("destination").and_then(|value| value.as_str())
229}
230
231fn push_file_access_target(
232 targets: &mut Vec<FileAccessTarget>,
233 path: &str,
234 access_type: AccessType,
235) {
236 let path_str = path.trim();
237 if path_str.is_empty() {
238 return;
239 }
240
241 let path = PathBuf::from(path_str);
242 if targets
245 .iter()
246 .any(|existing| existing.path == path && existing.access_type == access_type)
247 {
248 return;
249 }
250
251 targets.push(FileAccessTarget { path, access_type });
252}
253
254fn command_text_for_tool(tool_name: &str, args: &Value) -> Option<String> {
255 match tool_name {
256 tools::SHELL | tools::RUN_PTY_CMD => crate::tools::command_args::command_text(args)
257 .ok()
258 .flatten(),
259 tools::SEND_PTY_INPUT => {
260 crate::tools::command_args::interactive_input_text(args).map(str::to_owned)
261 }
262 tools::UNIFIED_EXEC if unified_exec_action_is(args, "run") => {
263 crate::tools::command_args::command_text(args)
264 .ok()
265 .flatten()
266 }
267 tools::UNIFIED_EXEC if unified_exec_action_in(args, &["write", "continue"]) => {
268 crate::tools::command_args::interactive_input_text(args).map(str::to_owned)
269 }
270 _ => None,
271 }
272}
273
274fn patch_file_access_targets(args: &Value) -> Vec<FileAccessTarget> {
275 let Ok(Some(patch_input)) = decode_apply_patch_input(args) else {
276 return Vec::new();
277 };
278 let Ok(patch) = Patch::parse(&patch_input.text) else {
279 return Vec::new();
280 };
281
282 let mut targets = Vec::new();
283 for operation in patch.operations() {
284 match operation {
285 PatchOperation::AddFile { path, .. } => {
286 push_file_access_target(&mut targets, path, AccessType::Write);
287 }
288 PatchOperation::DeleteFile { path } => {
289 push_file_access_target(&mut targets, path, AccessType::Delete);
290 }
291 PatchOperation::UpdateFile { path, new_path, .. } => {
292 push_file_access_target(&mut targets, path, AccessType::Modify);
293 if let Some(destination) =
294 new_path.as_deref().filter(|candidate| *candidate != path)
295 {
296 push_file_access_target(&mut targets, destination, AccessType::Write);
297 }
298 }
299 }
300 }
301
302 targets
303}
304
305fn file_access_targets(tool_name: &str, args: &Value) -> Vec<FileAccessTarget> {
306 let mut targets = Vec::new();
307
308 match tool_name {
309 tools::WRITE_FILE | tools::CREATE_FILE => {
310 if let Some(path) = primary_path_arg(args) {
311 push_file_access_target(&mut targets, path, AccessType::Write);
312 }
313 }
314 tools::EDIT_FILE | "search_replace" => {
315 if let Some(path) = primary_path_arg(args) {
316 push_file_access_target(&mut targets, path, AccessType::Modify);
317 }
318 }
319 tools::DELETE_FILE => {
320 if let Some(path) = primary_path_arg(args) {
321 push_file_access_target(&mut targets, path, AccessType::Delete);
322 }
323 }
324 tools::MOVE_FILE => {
325 if let Some(path) = primary_path_arg(args) {
326 push_file_access_target(&mut targets, path, AccessType::Modify);
327 }
328 if let Some(path) = destination_path_arg(args) {
329 push_file_access_target(&mut targets, path, AccessType::Write);
330 }
331 }
332 tools::COPY_FILE => {
333 if let Some(path) = destination_path_arg(args) {
334 push_file_access_target(&mut targets, path, AccessType::Write);
335 }
336 }
337 tools::APPLY_PATCH => {
338 targets.extend(patch_file_access_targets(args));
339 }
340 tools::UNIFIED_FILE if unified_file_action_is(args, "write") => {
341 if let Some(path) = primary_path_arg(args) {
342 push_file_access_target(&mut targets, path, AccessType::Write);
343 }
344 }
345 tools::UNIFIED_FILE if unified_file_action_is(args, "edit") => {
346 if let Some(path) = primary_path_arg(args) {
347 push_file_access_target(&mut targets, path, AccessType::Modify);
348 }
349 }
350 tools::UNIFIED_FILE if unified_file_action_is(args, "delete") => {
351 if let Some(path) = primary_path_arg(args) {
352 push_file_access_target(&mut targets, path, AccessType::Delete);
353 }
354 }
355 tools::UNIFIED_FILE if unified_file_action_is(args, "move") => {
356 if let Some(path) = primary_path_arg(args) {
357 push_file_access_target(&mut targets, path, AccessType::Modify);
358 }
359 if let Some(path) = destination_path_arg(args) {
360 push_file_access_target(&mut targets, path, AccessType::Write);
361 }
362 }
363 tools::UNIFIED_FILE if unified_file_action_is(args, "copy") => {
364 if let Some(path) = destination_path_arg(args) {
365 push_file_access_target(&mut targets, path, AccessType::Write);
366 }
367 }
368 tools::UNIFIED_FILE if unified_file_action_is(args, "patch") => {
369 targets.extend(patch_file_access_targets(args));
370 }
371 _ => {}
372 }
373
374 targets
375}
376
377fn proposed_changes_preview(args: &Value) -> String {
378 const PREVIEW_LIMIT: usize = 500;
379
380 let preview_text = |label: &str, text: &str| {
381 let preview_len = text.len().min(PREVIEW_LIMIT);
382 format!(
383 "{label} ({} bytes):\n{}{}",
384 text.len(),
385 &text[..preview_len],
386 if text.len() > preview_len { "..." } else { "" }
387 )
388 };
389
390 if let Some(content) = args.get("content").and_then(|value| value.as_str()) {
391 return preview_text("Content", content);
392 }
393
394 if let Some(old_str) = args.get("old_str").and_then(|value| value.as_str()) {
395 let new_str = args
396 .get("new_str")
397 .and_then(|value| value.as_str())
398 .unwrap_or("");
399 return format!("Replace:\n '{}'\nWith:\n '{}'", old_str, new_str);
400 }
401
402 if let Ok(Some(patch_input)) = decode_apply_patch_input(args) {
403 return preview_text("Patch", &patch_input.text);
404 }
405
406 "No details provided".to_string()
407}
408
409impl SafetyGateway {
410 pub fn new() -> Self {
412 Self::with_config(SafetyGatewayConfig::default())
413 }
414
415 pub fn with_config(config: SafetyGatewayConfig) -> Self {
417 Self {
418 config: RwLock::new(config),
419 command_policy: None,
420 rate_state: Mutex::new(RateLimiterState::default()),
421 preapproved: Mutex::new(HashSet::new()),
422 dotfile_guardian: None,
423 }
424 }
425
426 pub fn with_dotfile_guardian(mut self, guardian: Arc<DotfileGuardian>) -> Self {
428 self.dotfile_guardian = Some(guardian);
429 self
430 }
431
432 pub async fn with_dotfile_protection(
434 mut self,
435 config: DotfileProtectionConfig,
436 ) -> anyhow::Result<Self> {
437 let guardian = DotfileGuardian::new(config).await?;
438 self.dotfile_guardian = Some(Arc::new(guardian));
439 Ok(self)
440 }
441
442 pub fn with_command_policy(mut self, policy: CommandPolicyEvaluator) -> Self {
444 self.command_policy = Some(policy);
445 self
446 }
447
448 pub fn with_commands_config(mut self, config: &CommandsConfig) -> Self {
450 self.command_policy = Some(CommandPolicyEvaluator::from_config(config));
451 self
452 }
453
454 pub fn set_plan_mode(&self, active: bool) {
456 self.config.write().plan_mode_active = active;
457 }
458
459 pub fn set_workspace_trust(&self, trust: WorkspaceTrust) {
461 self.config.write().workspace_trust = trust;
462 }
463
464 pub fn set_limits(&self, max_per_turn: usize, max_per_session: usize) {
466 let mut config = self.config.write();
467 config.max_per_turn = max_per_turn;
468 config.max_per_session = max_per_session;
469 }
470
471 pub fn set_rate_limits(
473 &self,
474 rate_limit_per_second: usize,
475 rate_limit_per_minute: Option<usize>,
476 ) {
477 let mut config = self.config.write();
478 if rate_limit_per_second > 0 {
479 config.rate_limit_per_second = rate_limit_per_second;
480 }
481 config.rate_limit_per_minute = rate_limit_per_minute.filter(|v| *v > 0);
482 }
483
484 pub fn set_rate_limit_enforcement(&self, enabled: bool) {
486 self.config.write().enforce_rate_limits = enabled;
487 }
488
489 pub fn increase_session_limit(&self, increment: usize) {
491 let mut config = self.config.write();
492 let new_max = config.max_per_session.saturating_add(increment);
493 config.max_per_session = new_max;
494 tracing::info!("Session tool limit increased to {}", new_max);
495 }
496
497 pub fn max_per_session(&self) -> usize {
498 self.config.read().max_per_session
499 }
500
501 pub fn start_turn(&self) {
503 let mut state = self.rate_state.lock();
504 state.current_turn_count = 0;
505 state.calls_per_second.clear();
506 state.calls_per_minute.clear();
507 }
508
509 pub fn preapprove(&self, tool_name: &str) {
511 let mut preapproved = self.preapproved.lock();
512 preapproved.insert(tool_name.to_string());
513 }
514
515 pub fn is_preapproved(&self, tool_name: &str) -> bool {
517 let preapproved = self.preapproved.lock();
518 preapproved.contains(tool_name)
519 }
520
521 pub fn is_destructive(&self, tool_name: &str) -> bool {
523 classify_tool_intent(tool_name, &Value::Object(Default::default())).destructive
524 }
525
526 pub fn is_mutating(&self, tool_name: &str) -> bool {
528 classify_tool_intent(tool_name, &Value::Object(Default::default())).mutating
529 }
530
531 fn is_destructive_call(&self, tool_name: &str, args: &Value) -> bool {
532 classify_tool_intent(tool_name, args).destructive
533 }
534
535 fn is_mutating_call(&self, tool_name: &str, args: &Value) -> bool {
536 classify_tool_intent(tool_name, args).mutating
537 }
538
539 pub fn check_safety<'a>(
545 &'a self,
546 ctx: &'a SafetyContext,
547 tool_name: &'a str,
548 args: &'a Value,
549 ) -> impl Future<Output = SafetyDecision> + 'a {
550 self.check_safety_with_id(ctx, tool_name, args, None)
551 }
552
553 pub async fn check_safety_with_id(
555 &self,
556 ctx: &SafetyContext,
557 tool_name: &str,
558 args: &Value,
559 invocation_id: Option<ToolInvocationId>,
560 ) -> SafetyDecision {
561 let inv_id = invocation_id
562 .map(|id| id.short())
563 .unwrap_or_else(|| "unknown".to_string());
564
565 tracing::trace!(
566 invocation_id = %inv_id,
567 tool = %tool_name,
568 "SafetyGateway: checking safety"
569 );
570
571 if let Err(err) = self.check_rate_limits() {
572 tracing::warn!(
573 invocation_id = %inv_id,
574 error = %err,
575 "SafetyGateway: rate limit exceeded"
576 );
577 return SafetyDecision::Deny(err.to_string());
578 }
579
580 self.evaluate_non_rate_decision(ctx, tool_name, args, &inv_id)
581 .await
582 }
583
584 pub async fn check_and_record(
589 &self,
590 ctx: &SafetyContext,
591 tool_name: &str,
592 args: &Value,
593 ) -> SafetyCheckResult {
594 self.check_and_record_with_id(ctx, tool_name, args, None)
595 .await
596 }
597
598 pub async fn check_and_record_with_id(
600 &self,
601 ctx: &SafetyContext,
602 tool_name: &str,
603 args: &Value,
604 invocation_id: Option<ToolInvocationId>,
605 ) -> SafetyCheckResult {
606 let inv_id = invocation_id
607 .map(|id| id.short())
608 .unwrap_or_else(|| "unknown".to_string());
609 tracing::trace!(
610 invocation_id = %inv_id,
611 tool = %tool_name,
612 "SafetyGateway: checking and recording safety"
613 );
614
615 let decision = self
616 .evaluate_non_rate_decision(ctx, tool_name, args, &inv_id)
617 .await;
618
619 if decision.is_denied() {
620 return SafetyCheckResult {
621 decision,
622 retry_after: None,
623 violation: None,
624 };
625 }
626
627 let now = Instant::now();
628 let mut state = self.rate_state.lock();
629 match self.check_rate_limits_locked(&mut state, now) {
630 Ok(()) => {
631 self.record_execution_locked(&mut state, now);
632 SafetyCheckResult {
633 decision,
634 retry_after: None,
635 violation: None,
636 }
637 }
638 Err(err) => {
639 tracing::warn!(
640 invocation_id = %inv_id,
641 error = %err,
642 "SafetyGateway: rate limit exceeded during atomic reservation"
643 );
644 SafetyCheckResult {
645 decision: SafetyDecision::Deny(err.to_string()),
646 retry_after: self.retry_after_for_violation(&err, &state, now),
647 violation: Some(err),
648 }
649 }
650 }
651 }
652
653 async fn evaluate_non_rate_decision(
654 &self,
655 ctx: &SafetyContext,
656 tool_name: &str,
657 args: &Value,
658 inv_id: &str,
659 ) -> SafetyDecision {
660 if let Some(decision) = self
661 .check_dotfile_protection(tool_name, args, &ctx.session_id)
662 .await
663 {
664 tracing::info!(
665 invocation_id = %inv_id,
666 tool = %tool_name,
667 "SafetyGateway: dotfile protection triggered"
668 );
669 return decision;
670 }
671
672 if self.config.read().plan_mode_active && self.is_mutating_call(tool_name, args) {
673 let reason = format!(
674 "Tool '{}' is blocked in plan mode (read-only). Switch to edit mode to execute.",
675 tool_name
676 );
677 tracing::info!(
678 invocation_id = %inv_id,
679 tool = %tool_name,
680 "SafetyGateway: plan mode violation"
681 );
682 return SafetyDecision::Deny(reason);
683 }
684
685 if let Some(ref policy) = self.command_policy
686 && let Some(command) = command_text_for_tool(tool_name, args)
687 && !policy.allows_text(&command)
688 {
689 let reason = format!("Command '{}' blocked by policy", command);
690 tracing::info!(
691 invocation_id = %inv_id,
692 command = %command,
693 "SafetyGateway: command policy denied"
694 );
695 return SafetyDecision::Deny(reason);
696 }
697
698 if self.is_preapproved(tool_name) {
699 tracing::trace!(
700 invocation_id = %inv_id,
701 tool = %tool_name,
702 "SafetyGateway: tool preapproved"
703 );
704 return SafetyDecision::Allow;
705 }
706
707 if ctx.trust_level.can_bypass_approval() {
708 tracing::trace!(
709 invocation_id = %inv_id,
710 tool = %tool_name,
711 trust_level = ?ctx.trust_level,
712 "SafetyGateway: trust level allows bypass"
713 );
714 return SafetyDecision::Allow;
715 }
716
717 let risk_ctx = self.build_risk_context(tool_name, args);
718 let risk_level = ToolRiskScorer::calculate_risk(&risk_ctx);
719
720 if ToolRiskScorer::requires_justification(
721 risk_level,
722 self.config.read().approval_risk_threshold,
723 ) {
724 let justification = self.build_approval_justification(tool_name, &risk_level, args);
725 tracing::info!(
726 invocation_id = %inv_id,
727 tool = %tool_name,
728 risk = %risk_level,
729 "SafetyGateway: requires approval"
730 );
731 return SafetyDecision::NeedsApproval(justification);
732 }
733
734 if self.is_destructive_call(tool_name, args) {
735 let justification = format!(
736 "Tool '{}' is destructive and may modify files or execute commands.",
737 tool_name
738 );
739 tracing::info!(
740 invocation_id = %inv_id,
741 tool = %tool_name,
742 "SafetyGateway: destructive tool requires approval"
743 );
744 return SafetyDecision::NeedsApproval(justification);
745 }
746
747 SafetyDecision::Allow
748 }
749
750 pub fn record_execution(&self) {
752 let mut state = self.rate_state.lock();
753 self.record_execution_locked(&mut state, Instant::now());
754 }
755
756 fn check_rate_limits(&self) -> Result<(), SafetyError> {
758 let mut state = self.rate_state.lock();
759 self.check_rate_limits_locked(&mut state, Instant::now())
760 }
761
762 fn check_rate_limits_locked(
763 &self,
764 state: &mut RateLimiterState,
765 now: Instant,
766 ) -> Result<(), SafetyError> {
767 let config = self.config.read();
768 self.prune_rate_windows(state, now);
769
770 if config.enforce_rate_limits {
771 if state.calls_per_second.len() >= config.rate_limit_per_second {
772 return Err(SafetyError::RateLimitExceeded {
773 current: state.calls_per_second.len(),
774 max: config.rate_limit_per_second,
775 window: "1s",
776 });
777 }
778
779 if let Some(limit) = config.rate_limit_per_minute
780 && state.calls_per_minute.len() >= limit
781 {
782 return Err(SafetyError::RateLimitExceeded {
783 current: state.calls_per_minute.len(),
784 max: limit,
785 window: "60s",
786 });
787 }
788 }
789
790 if state.current_turn_count >= config.max_per_turn {
791 return Err(SafetyError::TurnLimitReached {
792 max: config.max_per_turn,
793 });
794 }
795
796 if state.session_count >= config.max_per_session {
797 return Err(SafetyError::SessionLimitReached {
798 max: config.max_per_session,
799 });
800 }
801
802 Ok(())
803 }
804
805 fn record_execution_locked(&self, state: &mut RateLimiterState, now: Instant) {
806 state.current_turn_count = state.current_turn_count.saturating_add(1);
807 state.session_count = state.session_count.saturating_add(1);
808 state.calls_per_second.push_back(now);
809 state.calls_per_minute.push_back(now);
810 }
811
812 fn prune_rate_windows(&self, state: &mut RateLimiterState, now: Instant) {
813 while let Some(front) = state.calls_per_second.front() {
814 if now.duration_since(*front) > Duration::from_secs(1) {
815 state.calls_per_second.pop_front();
816 } else {
817 break;
818 }
819 }
820 while let Some(front) = state.calls_per_minute.front() {
821 if now.duration_since(*front) > Duration::from_secs(60) {
822 state.calls_per_minute.pop_front();
823 } else {
824 break;
825 }
826 }
827 }
828
829 fn retry_after_for_violation(
830 &self,
831 violation: &SafetyError,
832 state: &RateLimiterState,
833 now: Instant,
834 ) -> Option<Duration> {
835 match violation {
836 SafetyError::RateLimitExceeded { window: "1s", .. } => state
837 .calls_per_second
838 .front()
839 .map(|first| Duration::from_secs(1).saturating_sub(now.duration_since(*first))),
840 SafetyError::RateLimitExceeded { window: "60s", .. } => state
841 .calls_per_minute
842 .front()
843 .map(|first| Duration::from_secs(60).saturating_sub(now.duration_since(*first))),
844 _ => None,
845 }
846 }
847
848 async fn check_dotfile_protection(
851 &self,
852 tool_name: &str,
853 args: &Value,
854 session_id: &str,
855 ) -> Option<SafetyDecision> {
856 let guardian = match self.dotfile_guardian.as_ref() {
858 Some(g) => g.clone(),
859 None => get_global_guardian()?,
860 };
861
862 let file_targets = file_access_targets(tool_name, args);
863 if file_targets.is_empty() {
864 return None;
865 }
866
867 let proposed_changes = proposed_changes_preview(args);
868
869 for target in file_targets {
870 if !guardian.is_protected(Path::new(&target.path)) {
871 continue;
872 }
873
874 let context =
875 AccessContext::new(&target.path, target.access_type, tool_name, session_id)
876 .with_proposed_changes(&proposed_changes);
877
878 match guardian.request_access(&context).await {
879 Ok(ProtectionDecision::Allowed) => continue,
880 Ok(ProtectionDecision::RequiresConfirmation(req)) => {
881 return Some(SafetyDecision::NeedsApproval(format!(
882 "DOTFILE PROTECTION\n\n\
883 File: {}\n\
884 Operation: {}\n\
885 Reason: {}\n\n\
886 Proposed changes:\n{}\n\n\
887 {}",
888 req.file_path,
889 req.access_type,
890 req.protection_reason,
891 req.proposed_changes,
892 req.warning
893 )));
894 }
895 Ok(ProtectionDecision::RequiresSecondaryAuth(req)) => {
896 return Some(SafetyDecision::NeedsApproval(format!(
897 "DOTFILE SECONDARY AUTHENTICATION REQUIRED\n\n\
898 File: {} (whitelisted)\n\
899 Operation: {}\n\
900 Reason: {}\n\n\
901 This file is on the whitelist but requires secondary authentication.\n\n\
902 Proposed changes:\n{}\n\n\
903 {}",
904 req.file_path,
905 req.access_type,
906 req.protection_reason,
907 req.proposed_changes,
908 req.warning
909 )));
910 }
911 Ok(ProtectionDecision::Blocked(violation)) => {
912 return Some(SafetyDecision::Deny(format!(
913 "DOTFILE MODIFICATION BLOCKED\n\n\
914 File: {}\n\
915 Reason: {}\n\n\
916 Suggestion: {}",
917 violation.file_path, violation.reason, violation.suggestion
918 )));
919 }
920 Ok(ProtectionDecision::Denied(violation)) => {
921 return Some(SafetyDecision::Deny(format!(
922 "DOTFILE ACCESS DENIED\n\n\
923 File: {}\n\
924 Reason: {}\n\n\
925 Suggestion: {}",
926 violation.file_path, violation.reason, violation.suggestion
927 )));
928 }
929 Err(e) => {
930 tracing::error!("Dotfile protection check failed: {}", e);
931 return Some(SafetyDecision::Deny(format!(
932 "Dotfile protection check failed: {}",
933 e
934 )));
935 }
936 }
937 }
938
939 None
940 }
941
942 pub fn dotfile_guardian(&self) -> Option<&Arc<DotfileGuardian>> {
944 self.dotfile_guardian.as_ref()
945 }
946
947 fn build_risk_context(&self, tool_name: &str, args: &Value) -> ToolRiskContext {
949 let source = if tool_name.starts_with("mcp_") {
950 ToolSource::Mcp
951 } else if tool_name.starts_with("acp_") {
952 ToolSource::Acp
953 } else {
954 ToolSource::Internal
955 };
956
957 let web_search =
958 tool_name == tools::UNIFIED_SEARCH && unified_search_action_is(args, "web");
959 let risk_tool_name = if web_search {
960 "unified_search:web"
961 } else {
962 tool_name
963 };
964
965 let mut ctx = ToolRiskContext::new(
966 risk_tool_name.to_string(),
967 source,
968 self.config.read().workspace_trust,
969 );
970
971 if self.is_mutating_call(tool_name, args) {
973 ctx = ctx.as_write();
974 }
975 if self.is_destructive_call(tool_name, args) {
976 ctx = ctx.as_destructive();
977 }
978
979 if tool_name == tools::WEB_SEARCH || tool_name == tools::FETCH_URL || web_search {
981 ctx = ctx.accesses_network();
982 }
983
984 if let Some(command) = command_text_for_tool(tool_name, args) {
986 ctx = ctx.with_args(command.split_whitespace().map(String::from).collect());
987 }
988
989 ctx
990 }
991
992 fn build_approval_justification(
994 &self,
995 tool_name: &str,
996 risk_level: &RiskLevel,
997 args: &Value,
998 ) -> String {
999 let mut parts = Vec::new();
1000
1001 parts.push(format!("Tool: {}", tool_name));
1002 parts.push(format!("Risk level: {}", risk_level));
1003
1004 if self.is_destructive_call(tool_name, args) {
1005 parts.push("This tool may modify or delete files.".to_string());
1006 }
1007
1008 if let Some(command) = command_text_for_tool(tool_name, args) {
1009 parts.push(format!("Command: {}", command));
1010 }
1011
1012 let file_targets = file_access_targets(tool_name, args);
1013 if let Some(target) = file_targets.first() {
1014 parts.push(format!("Path: {}", target.path.display()));
1015 if file_targets.len() > 1 {
1016 parts.push(format!("Additional targets: {}", file_targets.len() - 1));
1017 }
1018 }
1019
1020 parts.join("\n")
1021 }
1022
1023 pub fn get_stats(&self) -> SafetyStats {
1025 let state = self.rate_state.lock();
1026 let preapproved = self.preapproved.lock();
1027 let config = self.config.read();
1028
1029 SafetyStats {
1030 turn_count: state.current_turn_count,
1031 session_count: state.session_count,
1032 max_per_turn: config.max_per_turn,
1033 max_per_session: config.max_per_session,
1034 plan_mode_active: config.plan_mode_active,
1035 preapproved_count: preapproved.len(),
1036 }
1037 }
1038}
1039
1040impl Default for SafetyGateway {
1041 fn default() -> Self {
1042 Self::new()
1043 }
1044}
1045
/// Point-in-time snapshot of gateway counters and settings.
#[derive(Debug, Clone)]
pub struct SafetyStats {
    /// Executions recorded in the current turn.
    pub turn_count: usize,
    /// Executions recorded in the session so far.
    pub session_count: usize,
    /// Configured per-turn cap.
    pub max_per_turn: usize,
    /// Configured per-session cap.
    pub max_per_session: usize,
    /// Whether plan mode is currently active.
    pub plan_mode_active: bool,
    /// Number of preapproved tool names.
    pub preapproved_count: usize,
}
1056
1057#[cfg(test)]
1058mod tests {
1059 use super::*;
1060 use vtcode_config::core::DotfileProtectionConfig;
1061
1062 fn make_ctx() -> SafetyContext {
1063 SafetyContext::new("test-session")
1064 }
1065
1066 #[tokio::test]
1067 async fn test_allow_read_only_tools() {
1068 let gateway = SafetyGateway::new();
1069 let ctx = make_ctx();
1070
1071 let decision = gateway
1072 .check_safety(&ctx, "read_file", &serde_json::json!({"path": "/tmp/test"}))
1073 .await;
1074
1075 assert!(decision.is_allowed());
1076 }
1077
1078 #[tokio::test]
1079 async fn test_destructive_tool_needs_approval() {
1080 let gateway = SafetyGateway::new();
1081 let ctx = make_ctx();
1082
1083 let decision = gateway
1084 .check_safety(
1085 &ctx,
1086 "delete_file",
1087 &serde_json::json!({"path": "/tmp/test"}),
1088 )
1089 .await;
1090
1091 assert!(decision.needs_approval());
1092 }
1093
1094 #[tokio::test]
1095 async fn test_plan_mode_blocks_mutating() {
1096 let gateway = SafetyGateway::new();
1097 gateway.set_plan_mode(true);
1098 let ctx = make_ctx();
1099
1100 let decision = gateway
1101 .check_safety(
1102 &ctx,
1103 "write_file",
1104 &serde_json::json!({"path": "/tmp/test"}),
1105 )
1106 .await;
1107
1108 assert!(decision.is_denied());
1109 assert!(decision.reason().unwrap().contains("plan mode"));
1110 }
1111
1112 #[tokio::test]
1113 async fn test_preapproved_tools_allowed() {
1114 let gateway = SafetyGateway::new();
1115 gateway.preapprove("delete_file");
1116 let ctx = make_ctx();
1117
1118 let decision = gateway
1119 .check_safety(
1120 &ctx,
1121 "delete_file",
1122 &serde_json::json!({"path": "/tmp/test"}),
1123 )
1124 .await;
1125
1126 assert!(decision.is_allowed());
1127 }
1128
1129 #[tokio::test]
1130 async fn test_trust_level_bypass() {
1131 let gateway = SafetyGateway::new();
1132 let mut ctx = make_ctx();
1133 ctx.trust_level = SafetyTrustLevel::Full;
1134
1135 let decision = gateway
1136 .check_safety(
1137 &ctx,
1138 "delete_file",
1139 &serde_json::json!({"path": "/tmp/test"}),
1140 )
1141 .await;
1142
1143 assert!(decision.is_allowed());
1144 }
1145
1146 #[tokio::test]
1147 async fn test_rate_limiting() {
1148 let config = SafetyGatewayConfig {
1149 rate_limit_per_second: 2,
1150 ..Default::default()
1151 };
1152 let gateway = SafetyGateway::with_config(config);
1153 let ctx = make_ctx();
1154
1155 gateway.record_execution();
1157 gateway.record_execution();
1158
1159 let decision = gateway
1161 .check_safety(&ctx, "read_file", &serde_json::json!({}))
1162 .await;
1163
1164 assert!(decision.is_denied());
1165 assert!(decision.reason().unwrap().contains("Rate limit"));
1166 }
1167
1168 #[tokio::test]
1169 async fn test_atomic_check_and_record_rate_limited() {
1170 let config = SafetyGatewayConfig {
1171 rate_limit_per_second: 2,
1172 ..Default::default()
1173 };
1174 let gateway = SafetyGateway::with_config(config);
1175 let ctx = make_ctx();
1176
1177 let first = gateway
1178 .check_and_record(&ctx, "read_file", &serde_json::json!({}))
1179 .await;
1180 assert!(first.decision.is_allowed());
1181
1182 let second = gateway
1183 .check_and_record(&ctx, "read_file", &serde_json::json!({}))
1184 .await;
1185 assert!(second.decision.is_allowed());
1186
1187 let third = gateway
1188 .check_and_record(&ctx, "read_file", &serde_json::json!({}))
1189 .await;
1190 assert!(third.decision.is_denied());
1191 assert!(third.retry_after.is_some());
1192 assert!(matches!(
1193 third.violation,
1194 Some(SafetyError::RateLimitExceeded { .. })
1195 ));
1196 }
1197
1198 #[tokio::test]
1199 async fn test_command_policy_enforcement() {
1200 let mut commands_config = CommandsConfig::default();
1201 commands_config.deny_list.push("rm".to_string());
1202
1203 let gateway = SafetyGateway::new().with_commands_config(&commands_config);
1204 let ctx = make_ctx();
1205
1206 let decision = gateway
1207 .check_safety(&ctx, "shell", &serde_json::json!({"command": "rm -rf /"}))
1208 .await;
1209
1210 assert!(decision.is_denied());
1211 }
1212
1213 #[tokio::test]
1214 async fn test_unified_exec_command_policy_enforcement_with_indexed_args() {
1215 let mut commands_config = CommandsConfig::default();
1216 commands_config.deny_list.push("rm".to_string());
1217
1218 let gateway = SafetyGateway::new().with_commands_config(&commands_config);
1219 let ctx = make_ctx();
1220
1221 let decision = gateway
1222 .check_safety(
1223 &ctx,
1224 tools::UNIFIED_EXEC,
1225 &serde_json::json!({
1226 "command.0": "rm",
1227 "command.1": "-rf",
1228 "command.2": "/"
1229 }),
1230 )
1231 .await;
1232
1233 assert!(decision.is_denied());
1234 }
1235
1236 #[tokio::test]
1237 async fn test_unified_exec_continue_command_policy_enforcement_with_input() {
1238 let mut commands_config = CommandsConfig::default();
1239 commands_config.deny_list.push("rm".to_string());
1240
1241 let gateway = SafetyGateway::new().with_commands_config(&commands_config);
1242 let ctx = make_ctx();
1243
1244 let decision = gateway
1245 .check_safety(
1246 &ctx,
1247 tools::UNIFIED_EXEC,
1248 &serde_json::json!({
1249 "action": "continue",
1250 "session_id": "run-123",
1251 "input": "rm -rf /\n"
1252 }),
1253 )
1254 .await;
1255
1256 assert!(decision.is_denied());
1257 }
1258
1259 #[tokio::test]
1260 async fn test_send_pty_input_command_policy_enforcement_with_input() {
1261 let mut commands_config = CommandsConfig::default();
1262 commands_config.deny_list.push("rm".to_string());
1263
1264 let gateway = SafetyGateway::new().with_commands_config(&commands_config);
1265 let ctx = make_ctx();
1266
1267 let decision = gateway
1268 .check_safety(
1269 &ctx,
1270 tools::SEND_PTY_INPUT,
1271 &serde_json::json!({
1272 "session_id": "run-123",
1273 "input": "rm -rf /\n"
1274 }),
1275 )
1276 .await;
1277
1278 assert!(decision.is_denied());
1279 }
1280
1281 #[tokio::test]
1282 async fn test_apply_patch_dotfile_protection_requires_approval() {
1283 let guardian = Arc::new(
1284 DotfileGuardian::new(DotfileProtectionConfig {
1285 audit_logging_enabled: false,
1286 create_backups: false,
1287 ..Default::default()
1288 })
1289 .await
1290 .expect("guardian should initialize"),
1291 );
1292 let gateway = SafetyGateway::new().with_dotfile_guardian(guardian);
1293 let ctx = make_ctx();
1294
1295 let decision = gateway
1296 .check_safety(
1297 &ctx,
1298 tools::APPLY_PATCH,
1299 &serde_json::json!({
1300 "input": "*** Begin Patch\n*** Update File: .gitignore\n@@\n-old\n+new\n*** End Patch\n"
1301 }),
1302 )
1303 .await;
1304
1305 assert!(decision.needs_approval());
1306 assert!(
1307 decision
1308 .reason()
1309 .is_some_and(|reason| reason.contains(".gitignore"))
1310 );
1311 }
1312
/// `patch_file_access_targets` yields one `Modify` target per updated file,
/// in the same order the files appear in the patch.
#[test]
fn test_patch_file_access_targets_preserve_patch_order() {
    let patch_args = serde_json::json!({
        "input": "*** Begin Patch\n*** Update File: src/main.rs\n@@\n-old\n+new\n*** Update File: .gitignore\n@@\n-old\n+new\n*** End Patch\n"
    });
    let targets = patch_file_access_targets(&patch_args);

    // Paths in the order they occur in the patch text.
    let expected_paths = ["src/main.rs", ".gitignore"];
    assert_eq!(targets.len(), 2);

    for (index, expected_path) in expected_paths.into_iter().enumerate() {
        let target = &targets[index];
        assert_eq!(target.path, PathBuf::from(expected_path));
        assert_eq!(target.access_type, AccessType::Modify);
    }
}
1325
1326 #[tokio::test]
1327 async fn test_unified_file_patch_dotfile_protection_requires_approval() {
1328 let guardian = Arc::new(
1329 DotfileGuardian::new(DotfileProtectionConfig {
1330 audit_logging_enabled: false,
1331 create_backups: false,
1332 ..Default::default()
1333 })
1334 .await
1335 .expect("guardian should initialize"),
1336 );
1337 let gateway = SafetyGateway::new().with_dotfile_guardian(guardian);
1338 let ctx = make_ctx();
1339
1340 let decision = gateway
1341 .check_safety(
1342 &ctx,
1343 tools::UNIFIED_FILE,
1344 &serde_json::json!({
1345 "action": "patch",
1346 "patch": "*** Begin Patch\n*** Update File: .gitignore\n@@\n-old\n+new\n*** End Patch\n"
1347 }),
1348 )
1349 .await;
1350
1351 assert!(decision.needs_approval());
1352 assert!(
1353 decision
1354 .reason()
1355 .is_some_and(|reason| reason.contains(".gitignore"))
1356 );
1357 }
1358
1359 #[tokio::test]
1360 async fn test_stats_tracking() {
1361 let gateway = SafetyGateway::new();
1362 gateway.preapprove("test_tool");
1363 gateway.record_execution();
1364 gateway.record_execution();
1365
1366 let stats = gateway.get_stats();
1367 assert_eq!(stats.turn_count, 2);
1368 assert_eq!(stats.session_count, 2);
1369 assert_eq!(stats.preapproved_count, 1);
1370 }
1371
1372 #[tokio::test]
1373 async fn test_start_turn_resets_counters() {
1374 let gateway = SafetyGateway::new();
1375
1376 gateway.record_execution();
1377 gateway.record_execution();
1378
1379 let stats_before = gateway.get_stats();
1380 assert_eq!(stats_before.turn_count, 2);
1381
1382 gateway.start_turn();
1383
1384 let stats_after = gateway.get_stats();
1385 assert_eq!(stats_after.turn_count, 0);
1386 assert_eq!(stats_after.session_count, 2); }
1388
1389 #[tokio::test]
1390 async fn test_increase_session_limit_updates_limit() {
1391 let gateway = SafetyGateway::new();
1392 gateway.set_limits(10, 1);
1393 let ctx = make_ctx();
1394
1395 let first = gateway
1396 .check_and_record(&ctx, "read_file", &serde_json::json!({}))
1397 .await;
1398 assert!(first.decision.is_allowed());
1399
1400 let second = gateway
1401 .check_and_record(&ctx, "read_file", &serde_json::json!({}))
1402 .await;
1403 assert!(second.decision.is_denied());
1404
1405 gateway.increase_session_limit(1);
1406
1407 let third = gateway
1408 .check_and_record(&ctx, "read_file", &serde_json::json!({}))
1409 .await;
1410 assert!(third.decision.is_allowed());
1411 }
1412}