1use anyhow::{Context, Result};
8use console::{Color as ConsoleColor, Style as ConsoleStyle, style};
9use dialoguer::{Confirm, theme::ColorfulTheme};
10use indexmap::IndexMap;
11use is_terminal::IsTerminal;
12use serde::{Deserialize, Serialize};
13use std::collections::{BTreeMap, HashMap, HashSet};
14use std::fs;
15use std::path::{Path, PathBuf};
16
17use crate::ui::theme;
18use crate::utils::ansi::{AnsiRenderer, MessageStyle};
19
20use crate::config::constants::tools;
21use crate::config::core::tools::{ToolPolicy as ConfigToolPolicy, ToolsConfig};
22use crate::config::mcp::{McpAllowListConfig, McpAllowListRules};
23
/// Built-in tools that are treated as trusted.
///
/// `apply_auto_allow_defaults` forces each of these to `ToolPolicy::Allow` (overwriting any
/// stored value) and appends any missing name to `available_tools` in this order.
/// `should_execute_tool` also auto-approves them when their policy is `Prompt`.
const AUTO_ALLOW_TOOLS: &[&str] = &[
    tools::GREP_SEARCH,
    tools::LIST_FILES,
    tools::UPDATE_PLAN,
    tools::RUN_TERMINAL_CMD,
    tools::READ_FILE,
    tools::EDIT_FILE,
    tools::AST_GREP_SEARCH,
    tools::SIMPLE_SEARCH,
    tools::BASH,
];
/// Response-size cap (64 KiB) installed as the curl tool's `max_response_bytes` when the
/// policy file does not set one (see `ensure_network_constraints`).
const DEFAULT_CURL_MAX_RESPONSE_BYTES: usize = 64 * 1024;
36
/// Per-tool execution decision, serialized as `"allow"`, `"prompt"` or `"deny"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolPolicy {
    /// Execute without asking (`should_execute_tool` returns `Ok(true)`).
    Allow,
    /// Ask before executing; auto-allowed tools and non-interactive sessions are approved
    /// without a question.
    Prompt,
    /// Never execute (`should_execute_tool` returns `Ok(false)`).
    Deny,
}
48
49impl Default for ToolPolicy {
50 fn default() -> Self {
51 ToolPolicy::Prompt
52 }
53}
54
/// On-disk (JSON) representation of the tool policy file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolPolicyConfig {
    /// Schema version; `Default` uses 1, conversions copy the alternative file's version.
    pub version: u32,
    /// Names of tools known to the agent, including `mcp::<provider>::<tool>` entries.
    pub available_tools: Vec<String>,
    /// Policies for built-in tools. MCP tool policies are kept in `mcp.providers` instead.
    pub policies: IndexMap<String, ToolPolicy>,
    /// Per-tool constraints keyed by tool name (the curl entry is always populated on load).
    #[serde(default)]
    pub constraints: IndexMap<String, ToolConstraints>,
    /// MCP allow-list and per-provider tool policies.
    #[serde(default)]
    pub mcp: McpPolicyStore,
}
71
72impl Default for ToolPolicyConfig {
73 fn default() -> Self {
74 Self {
75 version: 1,
76 available_tools: Vec::new(),
77 policies: IndexMap::new(),
78 constraints: IndexMap::new(),
79 mcp: McpPolicyStore::default(),
80 }
81 }
82}
83
/// MCP section of the policy file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpPolicyStore {
    /// Allow-list rules; falls back to `default_secure_mcp_allowlist()` when absent.
    #[serde(default = "default_secure_mcp_allowlist")]
    pub allowlist: McpAllowListConfig,
    /// Per-provider tool policies keyed by provider name.
    #[serde(default)]
    pub providers: IndexMap<String, McpProviderPolicy>,
}
94
95impl Default for McpPolicyStore {
96 fn default() -> Self {
97 Self {
98 allowlist: default_secure_mcp_allowlist(),
99 providers: IndexMap::new(),
100 }
101 }
102}
103
/// Tool policies for a single MCP provider.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct McpProviderPolicy {
    /// Policies keyed by the provider-local tool name (without the `mcp::<provider>::` prefix).
    #[serde(default)]
    pub tools: IndexMap<String, ToolPolicy>,
}
110
111fn default_secure_mcp_allowlist() -> McpAllowListConfig {
112 let mut allowlist = McpAllowListConfig::default();
113 allowlist.enforce = true;
114
115 allowlist.default.logging = Some(vec![
116 "mcp.provider_initialized".to_string(),
117 "mcp.provider_initialization_failed".to_string(),
118 "mcp.tool_filtered".to_string(),
119 "mcp.tool_execution".to_string(),
120 "mcp.tool_failed".to_string(),
121 "mcp.tool_denied".to_string(),
122 ]);
123
124 allowlist.default.configuration = Some(BTreeMap::from([
125 (
126 "client".to_string(),
127 vec![
128 "max_concurrent_connections".to_string(),
129 "request_timeout_seconds".to_string(),
130 "retry_attempts".to_string(),
131 ],
132 ),
133 (
134 "ui".to_string(),
135 vec![
136 "mode".to_string(),
137 "max_events".to_string(),
138 "show_provider_names".to_string(),
139 ],
140 ),
141 (
142 "server".to_string(),
143 vec![
144 "enabled".to_string(),
145 "bind_address".to_string(),
146 "port".to_string(),
147 "transport".to_string(),
148 "name".to_string(),
149 "version".to_string(),
150 ],
151 ),
152 ]));
153
154 let mut time_rules = McpAllowListRules::default();
155 time_rules.tools = Some(vec![
156 "get_*".to_string(),
157 "list_*".to_string(),
158 "convert_timezone".to_string(),
159 "describe_timezone".to_string(),
160 "time_*".to_string(),
161 ]);
162 time_rules.resources = Some(vec!["timezone:*".to_string(), "location:*".to_string()]);
163 time_rules.logging = Some(vec![
164 "mcp.tool_execution".to_string(),
165 "mcp.tool_failed".to_string(),
166 "mcp.tool_denied".to_string(),
167 "mcp.tool_filtered".to_string(),
168 "mcp.provider_initialized".to_string(),
169 ]);
170 time_rules.configuration = Some(BTreeMap::from([
171 (
172 "provider".to_string(),
173 vec!["max_concurrent_requests".to_string()],
174 ),
175 (
176 "time".to_string(),
177 vec!["local_timezone_override".to_string()],
178 ),
179 ]));
180 allowlist.providers.insert("time".to_string(), time_rules);
181
182 let mut context_rules = McpAllowListRules::default();
183 context_rules.tools = Some(vec![
184 "search_*".to_string(),
185 "fetch_*".to_string(),
186 "list_*".to_string(),
187 "context7_*".to_string(),
188 "get_*".to_string(),
189 ]);
190 context_rules.resources = Some(vec![
191 "docs::*".to_string(),
192 "snippets::*".to_string(),
193 "repositories::*".to_string(),
194 "context7::*".to_string(),
195 ]);
196 context_rules.prompts = Some(vec![
197 "context7::*".to_string(),
198 "support::*".to_string(),
199 "docs::*".to_string(),
200 ]);
201 context_rules.logging = Some(vec![
202 "mcp.tool_execution".to_string(),
203 "mcp.tool_failed".to_string(),
204 "mcp.tool_denied".to_string(),
205 "mcp.tool_filtered".to_string(),
206 "mcp.provider_initialized".to_string(),
207 ]);
208 context_rules.configuration = Some(BTreeMap::from([
209 (
210 "provider".to_string(),
211 vec!["max_concurrent_requests".to_string()],
212 ),
213 (
214 "context7".to_string(),
215 vec![
216 "workspace".to_string(),
217 "search_scope".to_string(),
218 "max_results".to_string(),
219 ],
220 ),
221 ]));
222 allowlist
223 .providers
224 .insert("context7".to_string(), context_rules);
225
226 let mut seq_rules = McpAllowListRules::default();
227 seq_rules.tools = Some(vec![
228 "plan".to_string(),
229 "critique".to_string(),
230 "reflect".to_string(),
231 "decompose".to_string(),
232 "sequential_*".to_string(),
233 ]);
234 seq_rules.prompts = Some(vec![
235 "sequential-thinking::*".to_string(),
236 "plan".to_string(),
237 "reflect".to_string(),
238 "critique".to_string(),
239 ]);
240 seq_rules.logging = Some(vec![
241 "mcp.tool_execution".to_string(),
242 "mcp.tool_failed".to_string(),
243 "mcp.tool_denied".to_string(),
244 "mcp.tool_filtered".to_string(),
245 "mcp.provider_initialized".to_string(),
246 ]);
247 seq_rules.configuration = Some(BTreeMap::from([
248 (
249 "provider".to_string(),
250 vec!["max_concurrent_requests".to_string()],
251 ),
252 (
253 "sequencing".to_string(),
254 vec!["max_depth".to_string(), "max_branches".to_string()],
255 ),
256 ]));
257 allowlist
258 .providers
259 .insert("sequential-thinking".to_string(), seq_rules);
260
261 allowlist
262}
263
/// Splits an `mcp::<provider>::<tool>` key into `(provider, tool)`.
///
/// Returns `None` unless the key starts with `mcp::` and both the provider and the tool parts
/// are non-empty. Everything after the provider's `::` separator belongs to the tool name, so
/// `mcp::a::b::c` yields `("a", "b::c")`.
fn parse_mcp_policy_key(tool_name: &str) -> Option<(String, String)> {
    let remainder = tool_name.strip_prefix("mcp::")?;
    let (provider, tool) = remainder.split_once("::")?;
    if provider.is_empty() || tool.is_empty() {
        return None;
    }
    Some((provider.to_owned(), tool.to_owned()))
}
273
/// Alternative policy-file format, tried first by `load_or_create_config`.
///
/// `convert_from_alternative` uses only `version`, each tool's `allow` flag and
/// `constraints`; `default` must be present to parse but is not otherwise read in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeToolPolicyConfig {
    pub version: u32,
    /// Default policy block (required for parsing; not used by the conversion).
    pub default: AlternativeDefaultPolicy,
    /// Per-tool rules keyed by tool name.
    pub tools: IndexMap<String, AlternativeToolPolicy>,
    #[serde(default)]
    pub constraints: IndexMap<String, ToolConstraints>,
}
287
/// `default` section of the alternative format; every field is required when parsing.
/// NOTE(review): none of these values are applied by `convert_from_alternative`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeDefaultPolicy {
    pub allow: bool,
    pub rate_limit_per_run: u32,
    pub max_concurrent: u32,
    pub fs_write: bool,
    pub network: bool,
}
302
/// Per-tool rule in the alternative format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeToolPolicy {
    /// `true` converts to `ToolPolicy::Allow`, `false` to `ToolPolicy::Deny`.
    pub allow: bool,
    /// Parsed but not used by `convert_from_alternative`.
    #[serde(default)]
    pub fs_write: bool,
    /// Parsed but not used by `convert_from_alternative`.
    #[serde(default)]
    pub network: bool,
    /// Parsed but not used by `convert_from_alternative`.
    #[serde(default)]
    pub args_policy: Option<AlternativeArgsPolicy>,
}
318
/// Argument restrictions in the alternative format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeArgsPolicy {
    /// Substrings that presumably should not appear in tool arguments; not enforced in this
    /// file. TODO confirm whether any caller reads it.
    pub deny_substrings: Vec<String>,
}
325
/// Loads, queries and persists tool execution policies.
///
/// Built-in tool policies live in `config.policies`. MCP tool policies (keys of the form
/// `mcp::<provider>::<tool>`) live in `config.mcp.providers`. Public mutating methods write
/// the whole config back to `config_path` through `save_config`; the `update_*` methods do so
/// only when something changed.
#[derive(Clone)]
pub struct ToolPolicyManager {
    /// JSON file the policy config is read from and written to.
    config_path: PathBuf,
    /// In-memory copy of the policy config.
    config: ToolPolicyConfig,
}
332
333impl ToolPolicyManager {
334 pub fn new() -> Result<Self> {
336 let config_path = Self::get_config_path()?;
337 let config = Self::load_or_create_config(&config_path)?;
338
339 Ok(Self {
340 config_path,
341 config,
342 })
343 }
344
345 pub fn new_with_workspace(workspace_root: &PathBuf) -> Result<Self> {
347 let config_path = Self::get_workspace_config_path(workspace_root)?;
348 let config = Self::load_or_create_config(&config_path)?;
349
350 Ok(Self {
351 config_path,
352 config,
353 })
354 }
355
356 fn get_config_path() -> Result<PathBuf> {
358 let home_dir = dirs::home_dir().context("Could not determine home directory")?;
359
360 let vtcode_dir = home_dir.join(".vtcode");
361 if !vtcode_dir.exists() {
362 fs::create_dir_all(&vtcode_dir).context("Failed to create ~/.vtcode directory")?;
363 }
364
365 Ok(vtcode_dir.join("tool-policy.json"))
366 }
367
368 fn get_workspace_config_path(workspace_root: &PathBuf) -> Result<PathBuf> {
370 let workspace_vtcode_dir = workspace_root.join(".vtcode");
371
372 if !workspace_vtcode_dir.exists() {
373 fs::create_dir_all(&workspace_vtcode_dir).with_context(|| {
374 format!(
375 "Failed to create workspace policy directory at {}",
376 workspace_vtcode_dir.display()
377 )
378 })?;
379 }
380
381 Ok(workspace_vtcode_dir.join("tool-policy.json"))
382 }
383
384 fn load_or_create_config(config_path: &PathBuf) -> Result<ToolPolicyConfig> {
386 if config_path.exists() {
387 let content =
388 fs::read_to_string(config_path).context("Failed to read tool policy config")?;
389
390 if let Ok(alt_config) = serde_json::from_str::<AlternativeToolPolicyConfig>(&content) {
392 return Ok(Self::convert_from_alternative(alt_config));
394 }
395
396 match serde_json::from_str(&content) {
398 Ok(mut config) => {
399 Self::apply_auto_allow_defaults(&mut config);
400 Self::ensure_network_constraints(&mut config);
401 Ok(config)
402 }
403 Err(parse_err) => {
404 eprintln!(
405 "Warning: Invalid tool policy config at {} ({}). Resetting to defaults.",
406 config_path.display(),
407 parse_err
408 );
409 Self::reset_to_default(config_path)
410 }
411 }
412 } else {
413 let mut config = ToolPolicyConfig::default();
415 Self::apply_auto_allow_defaults(&mut config);
416 Self::ensure_network_constraints(&mut config);
417 Ok(config)
418 }
419 }
420
421 fn apply_auto_allow_defaults(config: &mut ToolPolicyConfig) {
422 for tool in AUTO_ALLOW_TOOLS {
423 config
424 .policies
425 .entry((*tool).to_string())
426 .and_modify(|policy| *policy = ToolPolicy::Allow)
427 .or_insert(ToolPolicy::Allow);
428 if !config.available_tools.contains(&tool.to_string()) {
429 config.available_tools.push(tool.to_string());
430 }
431 }
432 Self::ensure_network_constraints(config);
433 }
434
435 fn ensure_network_constraints(config: &mut ToolPolicyConfig) {
436 let entry = config
437 .constraints
438 .entry(tools::CURL.to_string())
439 .or_insert_with(ToolConstraints::default);
440
441 if entry.max_response_bytes.is_none() {
442 entry.max_response_bytes = Some(DEFAULT_CURL_MAX_RESPONSE_BYTES);
443 }
444 if entry.allowed_url_schemes.is_none() {
445 entry.allowed_url_schemes = Some(vec!["https".to_string()]);
446 }
447 if entry.denied_url_hosts.is_none() {
448 entry.denied_url_hosts = Some(vec![
449 "localhost".to_string(),
450 "127.0.0.1".to_string(),
451 "0.0.0.0".to_string(),
452 "::1".to_string(),
453 ".localhost".to_string(),
454 ".local".to_string(),
455 ".internal".to_string(),
456 ".lan".to_string(),
457 ]);
458 }
459 }
460
461 fn reset_to_default(config_path: &PathBuf) -> Result<ToolPolicyConfig> {
462 let backup_path = config_path.with_extension("json.bak");
463
464 if let Err(err) = fs::rename(config_path, &backup_path) {
465 eprintln!(
466 "Warning: Unable to back up invalid tool policy config ({}). {}",
467 config_path.display(),
468 err
469 );
470 } else {
471 eprintln!(
472 "Backed up invalid tool policy config to {}",
473 backup_path.display()
474 );
475 }
476
477 let default_config = ToolPolicyConfig::default();
478 Self::write_config(config_path.as_path(), &default_config)?;
479 Ok(default_config)
480 }
481
482 fn write_config(path: &Path, config: &ToolPolicyConfig) -> Result<()> {
483 if let Some(parent) = path.parent() {
484 if !parent.exists() {
485 fs::create_dir_all(parent).with_context(|| {
486 format!(
487 "Failed to create directory for tool policy config at {}",
488 parent.display()
489 )
490 })?;
491 }
492 }
493
494 let serialized = serde_json::to_string_pretty(config)
495 .context("Failed to serialize tool policy config")?;
496
497 fs::write(path, serialized)
498 .with_context(|| format!("Failed to write tool policy config: {}", path.display()))
499 }
500
501 fn convert_from_alternative(alt_config: AlternativeToolPolicyConfig) -> ToolPolicyConfig {
503 let mut policies = IndexMap::new();
504
505 for (tool_name, alt_policy) in alt_config.tools {
507 let policy = if alt_policy.allow {
508 ToolPolicy::Allow
509 } else {
510 ToolPolicy::Deny
511 };
512 policies.insert(tool_name, policy);
513 }
514
515 let mut config = ToolPolicyConfig {
516 version: alt_config.version,
517 available_tools: policies.keys().cloned().collect(),
518 policies,
519 constraints: alt_config.constraints,
520 mcp: McpPolicyStore::default(),
521 };
522 Self::apply_auto_allow_defaults(&mut config);
523 config
524 }
525
526 fn apply_config_policy(&mut self, tool_name: &str, policy: ConfigToolPolicy) {
527 let runtime_policy = match policy {
528 ConfigToolPolicy::Allow => ToolPolicy::Allow,
529 ConfigToolPolicy::Prompt => ToolPolicy::Prompt,
530 ConfigToolPolicy::Deny => ToolPolicy::Deny,
531 };
532
533 self.config
534 .policies
535 .insert(tool_name.to_string(), runtime_policy);
536 }
537
538 fn resolve_config_policy(tools_config: &ToolsConfig, tool_name: &str) -> ConfigToolPolicy {
539 if let Some(policy) = tools_config.policies.get(tool_name) {
540 return policy.clone();
541 }
542
543 match tool_name {
544 tools::LIST_FILES => tools_config
545 .policies
546 .get("list_dir")
547 .or_else(|| tools_config.policies.get("list_directory"))
548 .cloned(),
549 _ => None,
550 }
551 .unwrap_or_else(|| tools_config.default_policy.clone())
552 }
553
554 pub fn apply_tools_config(&mut self, tools_config: &ToolsConfig) -> Result<()> {
556 if self.config.available_tools.is_empty() {
557 return Ok(());
558 }
559
560 for tool in self.config.available_tools.clone() {
561 let config_policy = Self::resolve_config_policy(tools_config, &tool);
562 self.apply_config_policy(&tool, config_policy);
563 }
564
565 Self::apply_auto_allow_defaults(&mut self.config);
566 self.save_config()
567 }
568
569 pub fn update_available_tools(&mut self, tools: Vec<String>) -> Result<()> {
571 let current_tools: HashSet<_> = self.config.policies.keys().cloned().collect();
572 let new_tools: HashSet<_> = tools
573 .iter()
574 .filter(|name| !name.starts_with("mcp::"))
575 .cloned()
576 .collect();
577
578 let mut has_changes = false;
579
580 for tool in tools
582 .iter()
583 .filter(|tool| !tool.starts_with("mcp::") && !current_tools.contains(*tool))
584 {
585 let default_policy = if AUTO_ALLOW_TOOLS.contains(&tool.as_str()) {
586 ToolPolicy::Allow
587 } else {
588 ToolPolicy::Prompt
589 };
590 self.config.policies.insert(tool.clone(), default_policy);
591 has_changes = true;
592 }
593
594 let tools_to_remove: Vec<_> = self
596 .config
597 .policies
598 .keys()
599 .filter(|tool| !new_tools.contains(*tool))
600 .cloned()
601 .collect();
602
603 for tool in tools_to_remove {
604 self.config.policies.shift_remove(&tool);
605 has_changes = true;
606 }
607
608 if self.config.available_tools != tools {
610 self.config.available_tools = tools;
612 has_changes = true;
613 }
614
615 Self::ensure_network_constraints(&mut self.config);
616
617 if has_changes {
618 self.save_config()
619 } else {
620 Ok(())
621 }
622 }
623
    /// Synchronises stored MCP tool policies with the tools each provider currently advertises.
    ///
    /// `provider_tools` maps provider name to its advertised tool names. For each provider,
    /// newly advertised tools start at `Prompt` and tools no longer advertised are dropped.
    /// Providers missing from `provider_tools` are removed entirely. Any `mcp::` keys in the
    /// built-in `policies` map are purged, and `available_tools` is rebuilt as the non-MCP
    /// names plus `mcp::<provider>::<tool>` for every stored MCP tool, sorted and deduplicated.
    /// The file is saved only if something changed.
    ///
    /// NOTE(review): the sort also reorders the built-in names, so a later
    /// `update_available_tools` with its original order will see a difference and save again.
    ///
    /// # Errors
    /// Fails if a required save fails.
    pub fn update_mcp_tools(
        &mut self,
        provider_tools: &HashMap<String, Vec<String>>,
    ) -> Result<()> {
        // Snapshot before inserting, so newly seen providers are not considered for removal.
        let stored_providers: HashSet<String> = self.config.mcp.providers.keys().cloned().collect();
        let mut has_changes = false;

        // `HashMap` iteration order is unspecified, so newly added providers land in the
        // IndexMap in arbitrary order.
        for (provider, tools) in provider_tools {
            let entry = self
                .config
                .mcp
                .providers
                .entry(provider.clone())
                .or_insert_with(McpProviderPolicy::default);

            let existing_tools: HashSet<String> = entry.tools.keys().cloned().collect();
            let advertised: HashSet<String> = tools.iter().cloned().collect();

            // Existing policies are left untouched; only unseen tools get the `Prompt` default.
            for tool in tools {
                if !existing_tools.contains(tool) {
                    entry.tools.insert(tool.clone(), ToolPolicy::Prompt);
                    has_changes = true;
                }
            }

            for stale in existing_tools.difference(&advertised) {
                entry.tools.shift_remove(stale.as_str());
                has_changes = true;
            }
        }

        let advertised_providers: HashSet<String> = provider_tools.keys().cloned().collect();
        for provider in stored_providers
            .difference(&advertised_providers)
            .cloned()
            .collect::<Vec<_>>()
        {
            self.config.mcp.providers.shift_remove(provider.as_str());
            has_changes = true;
        }

        // MCP policies belong in `config.mcp.providers`; copies in `policies` are never read.
        let stale_runtime_keys: Vec<_> = self
            .config
            .policies
            .keys()
            .filter(|name| name.starts_with("mcp::"))
            .cloned()
            .collect();

        for key in stale_runtime_keys {
            self.config.policies.shift_remove(&key);
            has_changes = true;
        }

        let mut available: Vec<String> = self
            .config
            .available_tools
            .iter()
            .filter(|name| !name.starts_with("mcp::"))
            .cloned()
            .collect();

        for (provider, policy) in &self.config.mcp.providers {
            for tool in policy.tools.keys() {
                available.push(format!("mcp::{}::{}", provider, tool));
            }
        }

        available.sort();
        available.dedup();

        if self.config.available_tools != available {
            self.config.available_tools = available;
            has_changes = true;
        }

        if has_changes {
            self.save_config()
        } else {
            Ok(())
        }
    }
714
715 pub fn get_mcp_tool_policy(&self, provider: &str, tool: &str) -> ToolPolicy {
717 self.config
718 .mcp
719 .providers
720 .get(provider)
721 .and_then(|policy| policy.tools.get(tool))
722 .cloned()
723 .unwrap_or(ToolPolicy::Prompt)
724 }
725
726 pub fn set_mcp_tool_policy(
728 &mut self,
729 provider: &str,
730 tool: &str,
731 policy: ToolPolicy,
732 ) -> Result<()> {
733 let entry = self
734 .config
735 .mcp
736 .providers
737 .entry(provider.to_string())
738 .or_insert_with(McpProviderPolicy::default);
739 entry.tools.insert(tool.to_string(), policy);
740 self.save_config()
741 }
742
743 pub fn mcp_allowlist(&self) -> &McpAllowListConfig {
745 &self.config.mcp.allowlist
746 }
747
748 pub fn set_mcp_allowlist(&mut self, allowlist: McpAllowListConfig) -> Result<()> {
750 self.config.mcp.allowlist = allowlist;
751 self.save_config()
752 }
753
754 pub fn get_policy(&self, tool_name: &str) -> ToolPolicy {
756 if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
757 return self.get_mcp_tool_policy(&provider, &tool);
758 }
759
760 self.config
761 .policies
762 .get(tool_name)
763 .cloned()
764 .unwrap_or(ToolPolicy::Prompt)
765 }
766
767 pub fn get_constraints(&self, tool_name: &str) -> Option<&ToolConstraints> {
769 self.config.constraints.get(tool_name)
770 }
771
772 pub fn should_execute_tool(&mut self, tool_name: &str) -> Result<bool> {
774 if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
775 return match self.get_mcp_tool_policy(&provider, &tool) {
776 ToolPolicy::Allow => Ok(true),
777 ToolPolicy::Deny => Ok(false),
778 ToolPolicy::Prompt => {
779 if ToolPolicyManager::is_auto_allow_tool(tool_name) {
780 self.set_mcp_tool_policy(&provider, &tool, ToolPolicy::Allow)?;
781 Ok(true)
782 } else {
783 self.prompt_user_for_tool(tool_name)
784 }
785 }
786 };
787 }
788
789 match self.get_policy(tool_name) {
790 ToolPolicy::Allow => Ok(true),
791 ToolPolicy::Deny => Ok(false),
792 ToolPolicy::Prompt => {
793 if AUTO_ALLOW_TOOLS.contains(&tool_name) {
794 self.set_policy(tool_name, ToolPolicy::Allow)?;
795 return Ok(true);
796 }
797 let should_execute = self.prompt_user_for_tool(tool_name)?;
798 Ok(should_execute)
799 }
800 }
801 }
802
803 pub fn is_auto_allow_tool(tool_name: &str) -> bool {
804 AUTO_ALLOW_TOOLS.contains(&tool_name)
805 }
806
    /// Asks the user whether `tool_name` may run for the current request.
    ///
    /// Returns `Ok(true)` without asking in two cases: when stdin or stdout is not a terminal,
    /// and for auto-allow tools. Otherwise it shows a themed yes/no dialog that defaults to
    /// "no". The answer is not persisted; `self` is not read or modified here. A failure to
    /// read the confirmation is reported and treated as a denial.
    ///
    /// # Errors
    /// Fails only if writing to the renderer fails.
    fn prompt_user_for_tool(&mut self, tool_name: &str) -> Result<bool> {
        let interactive = std::io::stdin().is_terminal() && std::io::stdout().is_terminal();
        let mut renderer = AnsiRenderer::stdout();
        let banner_style = theme::banner_style();

        // NOTE(review): non-interactive sessions approve every prompted tool; confirm this is
        // the intended security posture for CI/piped usage.
        if !interactive {
            let message = format!(
                "Non-interactive environment detected. Auto-approving '{}' tool.",
                tool_name
            );
            renderer.line_with_style(banner_style, &message)?;
            return Ok(true);
        }

        let header = format!("Tool Permission Request: {}", tool_name);
        renderer.line_with_style(banner_style, &header)?;
        renderer.line_with_style(
            banner_style,
            &format!("The agent wants to use the '{}' tool.", tool_name),
        )?;
        renderer.line_with_style(banner_style, "")?;
        renderer.line_with_style(
            banner_style,
            "This decision applies to the current request only.",
        )?;
        renderer.line_with_style(
            banner_style,
            "Update the policy file or use CLI flags to change the default.",
        )?;
        renderer.line_with_style(banner_style, "")?;

        if AUTO_ALLOW_TOOLS.contains(&tool_name) {
            renderer.line_with_style(
                banner_style,
                &format!(
                    "Auto-approving '{}' tool (default trusted tool).",
                    tool_name
                ),
            )?;
            return Ok(true);
        }

        // Map the theme's RGB banner colour onto the xterm 256-colour 6x6x6 cube so the
        // dialog (rendered via `console`) matches the banner.
        let rgb = theme::banner_color();
        // Quantises one 0-255 channel to a cube level 0-5.
        let to_ansi_256 = |value: u8| -> u8 {
            if value < 48 {
                0
            } else if value < 114 {
                1
            } else {
                ((value - 35) / 40).min(5)
            }
        };
        // Cube colours start at palette index 16.
        let rgb_to_index = |r: u8, g: u8, b: u8| -> u8 {
            let r_idx = to_ansi_256(r);
            let g_idx = to_ansi_256(g);
            let b_idx = to_ansi_256(b);
            16 + 36 * r_idx + 6 * g_idx + b_idx
        };
        let color_index = rgb_to_index(rgb.0, rgb.1, rgb.2);
        let dialog_color = ConsoleColor::Color256(color_index);
        let tinted_style = ConsoleStyle::new().for_stderr().fg(dialog_color);

        // Tint every themed element of the dialog with the banner colour.
        let mut dialog_theme = ColorfulTheme::default();
        dialog_theme.prompt_style = tinted_style;
        dialog_theme.prompt_prefix = style("—".to_string()).for_stderr().fg(dialog_color);
        dialog_theme.prompt_suffix = style("—".to_string()).for_stderr().fg(dialog_color);
        dialog_theme.hint_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
        dialog_theme.defaults_style = dialog_theme.hint_style.clone();
        dialog_theme.success_prefix = style("✓".to_string()).for_stderr().fg(dialog_color);
        dialog_theme.success_suffix = style("·".to_string()).for_stderr().fg(dialog_color);
        dialog_theme.error_prefix = style("✗".to_string()).for_stderr().fg(dialog_color);
        dialog_theme.error_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
        dialog_theme.values_style = ConsoleStyle::new().for_stderr().fg(dialog_color);

        let prompt_text = format!("Allow the agent to use '{}'?", tool_name);

        // Default is "no", so pressing Enter denies the tool.
        match Confirm::with_theme(&dialog_theme)
            .with_prompt(prompt_text)
            .default(false)
            .interact()
        {
            Ok(confirmed) => {
                let message = if confirmed {
                    format!("✓ Approved: '{}' tool will run now", tool_name)
                } else {
                    format!("✗ Denied: '{}' tool will not run", tool_name)
                };
                let style = if confirmed {
                    MessageStyle::Tool
                } else {
                    MessageStyle::Error
                };
                renderer.line(style, &message)?;
                Ok(confirmed)
            }
            // Fail closed: an unreadable answer denies the tool.
            Err(e) => {
                renderer.line(
                    MessageStyle::Error,
                    &format!("Failed to read confirmation: {}", e),
                )?;
                Ok(false)
            }
        }
    }
912
913 pub fn set_policy(&mut self, tool_name: &str, policy: ToolPolicy) -> Result<()> {
915 if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
916 return self.set_mcp_tool_policy(&provider, &tool, policy);
917 }
918
919 self.config.policies.insert(tool_name.to_string(), policy);
920 self.save_config()
921 }
922
923 pub fn reset_all_to_prompt(&mut self) -> Result<()> {
925 for policy in self.config.policies.values_mut() {
926 *policy = ToolPolicy::Prompt;
927 }
928 for provider in self.config.mcp.providers.values_mut() {
929 for policy in provider.tools.values_mut() {
930 *policy = ToolPolicy::Prompt;
931 }
932 }
933 self.save_config()
934 }
935
936 pub fn allow_all_tools(&mut self) -> Result<()> {
938 for policy in self.config.policies.values_mut() {
939 *policy = ToolPolicy::Allow;
940 }
941 for provider in self.config.mcp.providers.values_mut() {
942 for policy in provider.tools.values_mut() {
943 *policy = ToolPolicy::Allow;
944 }
945 }
946 self.save_config()
947 }
948
949 pub fn deny_all_tools(&mut self) -> Result<()> {
951 for policy in self.config.policies.values_mut() {
952 *policy = ToolPolicy::Deny;
953 }
954 for provider in self.config.mcp.providers.values_mut() {
955 for policy in provider.tools.values_mut() {
956 *policy = ToolPolicy::Deny;
957 }
958 }
959 self.save_config()
960 }
961
962 pub fn get_policy_summary(&self) -> IndexMap<String, ToolPolicy> {
964 let mut summary = self.config.policies.clone();
965 for (provider, policy) in &self.config.mcp.providers {
966 for (tool, status) in &policy.tools {
967 summary.insert(format!("mcp::{}::{}", provider, tool), status.clone());
968 }
969 }
970 summary
971 }
972
973 fn save_config(&self) -> Result<()> {
975 Self::write_config(&self.config_path, &self.config)
976 }
977
978 pub fn print_status(&self) {
980 println!("{}", style("Tool Policy Status").cyan().bold());
981 println!("Config file: {}", self.config_path.display());
982 println!();
983
984 let summary = self.get_policy_summary();
985
986 if summary.is_empty() {
987 println!("No tools configured yet.");
988 return;
989 }
990
991 let mut allow_count = 0;
992 let mut prompt_count = 0;
993 let mut deny_count = 0;
994
995 for (tool, policy) in &summary {
996 let (status, color_name) = match policy {
997 ToolPolicy::Allow => {
998 allow_count += 1;
999 ("ALLOW", "green")
1000 }
1001 ToolPolicy::Prompt => {
1002 prompt_count += 1;
1003 ("PROMPT", "yellow")
1004 }
1005 ToolPolicy::Deny => {
1006 deny_count += 1;
1007 ("DENY", "red")
1008 }
1009 };
1010
1011 let status_styled = match color_name {
1012 "green" => style(status).green(),
1013 "yellow" => style(status).yellow(),
1014 "red" => style(status).red(),
1015 _ => style(status),
1016 };
1017
1018 println!(
1019 " {} {}",
1020 style(format!("{:15}", tool)).cyan(),
1021 status_styled
1022 );
1023 }
1024
1025 println!();
1026 println!(
1027 "Summary: {} allowed, {} prompt, {} denied",
1028 style(allow_count).green(),
1029 style(prompt_count).yellow(),
1030 style(deny_count).red()
1031 );
1032 }
1033
1034 pub fn config_path(&self) -> &Path {
1036 &self.config_path
1037 }
1038}
1039
/// Optional per-tool limits stored in the policy file; every field defaults to `None`.
///
/// Only the curl entry is populated by this file (see `ensure_network_constraints`). The
/// fields are presumably enforced by the tool implementations elsewhere. TODO confirm.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolConstraints {
    /// Permitted mode names for the tool.
    #[serde(default)]
    pub allowed_modes: Option<Vec<String>>,
    /// Maximum number of results per call.
    #[serde(default)]
    pub max_results_per_call: Option<usize>,
    /// Maximum number of items per call.
    #[serde(default)]
    pub max_items_per_call: Option<usize>,
    /// Response format to use when the caller does not request one.
    #[serde(default)]
    pub default_response_format: Option<String>,
    /// Maximum bytes per read.
    #[serde(default)]
    pub max_bytes_per_read: Option<usize>,
    /// Maximum response size in bytes (curl defaults to `DEFAULT_CURL_MAX_RESPONSE_BYTES`).
    #[serde(default)]
    pub max_response_bytes: Option<usize>,
    /// Permitted URL schemes (curl defaults to `["https"]`).
    #[serde(default)]
    pub allowed_url_schemes: Option<Vec<String>>,
    /// Denied host names; entries with a leading `.` look like suffix patterns (curl defaults
    /// to loopback/local names).
    #[serde(default)]
    pub denied_url_hosts: Option<Vec<String>>,
}
1068
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::constants::tools;
    use tempfile::tempdir;

    #[test]
    fn test_tool_policy_config_serialization() {
        let mut config = ToolPolicyConfig::default();
        config.available_tools = vec![tools::READ_FILE.to_string(), tools::WRITE_FILE.to_string()];
        config
            .policies
            .insert(tools::READ_FILE.to_string(), ToolPolicy::Allow);
        config
            .policies
            .insert(tools::WRITE_FILE.to_string(), ToolPolicy::Prompt);

        let json = serde_json::to_string_pretty(&config).unwrap();
        let deserialized: ToolPolicyConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.available_tools, deserialized.available_tools);
        assert_eq!(config.policies, deserialized.policies);
    }

    #[test]
    fn test_policy_updates() {
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("tool-policy.json");

        let mut config = ToolPolicyConfig::default();
        config.available_tools = vec!["tool1".to_string()];
        config
            .policies
            .insert("tool1".to_string(), ToolPolicy::Prompt);

        let content = serde_json::to_string_pretty(&config).unwrap();
        fs::write(&config_path, content).unwrap();

        let mut loaded_config = ToolPolicyManager::load_or_create_config(&config_path).unwrap();

        let new_tools = vec!["tool1".to_string(), "tool2".to_string()];
        let current_tools: std::collections::HashSet<_> =
            loaded_config.available_tools.iter().collect();

        for tool in &new_tools {
            if !current_tools.contains(tool) {
                loaded_config
                    .policies
                    .insert(tool.clone(), ToolPolicy::Prompt);
            }
        }

        loaded_config.available_tools = new_tools;

        // Loading always merges AUTO_ALLOW_TOOLS into the policy map, so only count the tools
        // this test manages.
        let managed: Vec<&String> = loaded_config
            .policies
            .keys()
            .filter(|name| !ToolPolicyManager::is_auto_allow_tool(name.as_str()))
            .collect();
        assert_eq!(managed.len(), 2);
        assert_eq!(
            loaded_config.policies.get("tool1"),
            Some(&ToolPolicy::Prompt)
        );
        assert_eq!(
            loaded_config.policies.get("tool2"),
            Some(&ToolPolicy::Prompt)
        );
    }

    #[test]
    fn test_update_available_tools_tracks_builtin_tools_only() {
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("tool-policy.json");
        let mut manager = ToolPolicyManager {
            config_path: config_path.clone(),
            config: ToolPolicyConfig::default(),
        };

        manager
            .update_available_tools(vec![
                "tool1".to_string(),
                "mcp::time::get_time".to_string(),
            ])
            .unwrap();

        assert_eq!(manager.get_policy("tool1"), ToolPolicy::Prompt);
        assert!(!manager.config.policies.contains_key("mcp::time::get_time"));
        assert!(config_path.exists());
    }

    #[test]
    fn test_auto_allow_defaults_install_curl_constraints() {
        let mut config = ToolPolicyConfig::default();
        ToolPolicyManager::apply_auto_allow_defaults(&mut config);

        for tool in AUTO_ALLOW_TOOLS {
            assert_eq!(config.policies.get(*tool), Some(&ToolPolicy::Allow));
            assert!(config.available_tools.iter().any(|name| name.as_str() == *tool));
        }

        let curl = config
            .constraints
            .get(tools::CURL)
            .expect("curl constraints are always installed");
        assert_eq!(curl.max_response_bytes, Some(DEFAULT_CURL_MAX_RESPONSE_BYTES));
        assert_eq!(curl.allowed_url_schemes, Some(vec!["https".to_string()]));
        assert!(curl
            .denied_url_hosts
            .as_ref()
            .map_or(false, |hosts| hosts.iter().any(|host| host == "localhost")));
    }

    #[test]
    fn test_parse_mcp_policy_key() {
        assert_eq!(
            parse_mcp_policy_key("mcp::time::get_time"),
            Some(("time".to_string(), "get_time".to_string()))
        );
        assert_eq!(parse_mcp_policy_key("mcp::time"), None);
        assert_eq!(parse_mcp_policy_key(tools::READ_FILE), None);
    }
}