1use anyhow::{Context, Result};
8use dialoguer::{
9 Confirm,
10 console::{Color as ConsoleColor, Style as ConsoleStyle, style},
11 theme::ColorfulTheme,
12};
13use indexmap::IndexMap;
14use is_terminal::IsTerminal;
15use serde::{Deserialize, Serialize};
16use std::collections::{BTreeMap, HashMap, HashSet};
17use std::fs;
18use std::path::{Path, PathBuf};
19
20use crate::ui::theme;
21use crate::utils::ansi::{AnsiRenderer, MessageStyle};
22
23use crate::config::constants::tools;
24use crate::config::core::tools::{ToolPolicy as ConfigToolPolicy, ToolsConfig};
25use crate::config::mcp::{McpAllowListConfig, McpAllowListRules};
26
/// Built-in tools that are always granted [`ToolPolicy::Allow`].
///
/// - `apply_auto_allow_defaults` forces these entries to `Allow` (adding any
///   that are missing to `available_tools`) whenever a configuration is loaded,
///   converted, or re-applied.
/// - `update_available_tools` gives them `Allow` when they are first discovered.
/// - `should_execute_tool` promotes them from `Prompt` to `Allow` instead of
///   asking the user.
///
/// Order matters: it is the order in which missing entries are appended to
/// `available_tools`.
///
/// NOTE(review): the list includes `RUN_TERMINAL_CMD`, `EDIT_FILE` and `BASH`,
/// which presumably run shell commands / modify files without any prompt —
/// confirm that this level of trust is intended.
const AUTO_ALLOW_TOOLS: &[&str] = &[
    tools::GREP_SEARCH,
    tools::LIST_FILES,
    tools::UPDATE_PLAN,
    tools::RUN_TERMINAL_CMD,
    tools::READ_FILE,
    tools::EDIT_FILE,
    tools::AST_GREP_SEARCH,
    tools::SIMPLE_SEARCH,
    tools::BASH,
];
/// Response-size cap (64 KiB) installed for the curl tool by
/// `ensure_network_constraints` when no explicit limit is configured.
const DEFAULT_CURL_MAX_RESPONSE_BYTES: usize = 1 << 16;
39
40#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
42#[serde(rename_all = "lowercase")]
43pub enum ToolPolicy {
44 Allow,
46 Prompt,
48 Deny,
50}
51
52impl Default for ToolPolicy {
53 fn default() -> Self {
54 ToolPolicy::Prompt
55 }
56}
57
/// Persisted tool-policy file format (`tool-policy.json`).
///
/// Field order is also the key order of the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolPolicyConfig {
    /// Format version; `ToolPolicyConfig::default()` sets it to `1`.
    pub version: u32,
    /// Names of all tools known to the agent. `update_mcp_tools` also lists
    /// runtime MCP tools here as `mcp::<provider>::<tool>`.
    pub available_tools: Vec<String>,
    /// Policies for built-in (non-MCP) tools, in insertion order. MCP tool
    /// policies live in `mcp.providers` instead.
    pub policies: IndexMap<String, ToolPolicy>,
    /// Optional per-tool execution constraints; empty when absent from the file.
    #[serde(default)]
    pub constraints: IndexMap<String, ToolConstraints>,
    /// MCP allow list and per-provider tool policies; defaulted when absent.
    #[serde(default)]
    pub mcp: McpPolicyStore,
}
74
75impl Default for ToolPolicyConfig {
76 fn default() -> Self {
77 Self {
78 version: 1,
79 available_tools: Vec::new(),
80 policies: IndexMap::new(),
81 constraints: IndexMap::new(),
82 mcp: McpPolicyStore::default(),
83 }
84 }
85}
86
/// MCP-specific policy state stored inside [`ToolPolicyConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpPolicyStore {
    /// Allow list applied to MCP providers. When the field is missing from the
    /// file, it falls back to `default_secure_mcp_allowlist()`, which has
    /// enforcement enabled.
    #[serde(default = "default_secure_mcp_allowlist")]
    pub allowlist: McpAllowListConfig,
    /// Per-provider tool policies, keyed by provider name.
    #[serde(default)]
    pub providers: IndexMap<String, McpProviderPolicy>,
}
97
98impl Default for McpPolicyStore {
99 fn default() -> Self {
100 Self {
101 allowlist: default_secure_mcp_allowlist(),
102 providers: IndexMap::new(),
103 }
104 }
105}
106
/// Tool policies for a single MCP provider.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct McpProviderPolicy {
    /// Policy per tool name, as advertised by the provider (without the
    /// `mcp::<provider>::` prefix).
    #[serde(default)]
    pub tools: IndexMap<String, ToolPolicy>,
}
113
114fn default_secure_mcp_allowlist() -> McpAllowListConfig {
115 let mut allowlist = McpAllowListConfig::default();
116 allowlist.enforce = true;
117
118 allowlist.default.logging = Some(vec![
119 "mcp.provider_initialized".to_string(),
120 "mcp.provider_initialization_failed".to_string(),
121 "mcp.tool_filtered".to_string(),
122 "mcp.tool_execution".to_string(),
123 "mcp.tool_failed".to_string(),
124 "mcp.tool_denied".to_string(),
125 ]);
126
127 allowlist.default.configuration = Some(BTreeMap::from([
128 (
129 "client".to_string(),
130 vec![
131 "max_concurrent_connections".to_string(),
132 "request_timeout_seconds".to_string(),
133 "retry_attempts".to_string(),
134 ],
135 ),
136 (
137 "ui".to_string(),
138 vec![
139 "mode".to_string(),
140 "max_events".to_string(),
141 "show_provider_names".to_string(),
142 ],
143 ),
144 (
145 "server".to_string(),
146 vec![
147 "enabled".to_string(),
148 "bind_address".to_string(),
149 "port".to_string(),
150 "transport".to_string(),
151 "name".to_string(),
152 "version".to_string(),
153 ],
154 ),
155 ]));
156
157 let mut time_rules = McpAllowListRules::default();
158 time_rules.tools = Some(vec![
159 "get_*".to_string(),
160 "list_*".to_string(),
161 "convert_timezone".to_string(),
162 "describe_timezone".to_string(),
163 "time_*".to_string(),
164 ]);
165 time_rules.resources = Some(vec!["timezone:*".to_string(), "location:*".to_string()]);
166 time_rules.logging = Some(vec![
167 "mcp.tool_execution".to_string(),
168 "mcp.tool_failed".to_string(),
169 "mcp.tool_denied".to_string(),
170 "mcp.tool_filtered".to_string(),
171 "mcp.provider_initialized".to_string(),
172 ]);
173 time_rules.configuration = Some(BTreeMap::from([
174 (
175 "provider".to_string(),
176 vec!["max_concurrent_requests".to_string()],
177 ),
178 (
179 "time".to_string(),
180 vec!["local_timezone_override".to_string()],
181 ),
182 ]));
183 allowlist.providers.insert("time".to_string(), time_rules);
184
185 let mut context_rules = McpAllowListRules::default();
186 context_rules.tools = Some(vec![
187 "search_*".to_string(),
188 "fetch_*".to_string(),
189 "list_*".to_string(),
190 "context7_*".to_string(),
191 "get_*".to_string(),
192 ]);
193 context_rules.resources = Some(vec![
194 "docs::*".to_string(),
195 "snippets::*".to_string(),
196 "repositories::*".to_string(),
197 "context7::*".to_string(),
198 ]);
199 context_rules.prompts = Some(vec![
200 "context7::*".to_string(),
201 "support::*".to_string(),
202 "docs::*".to_string(),
203 ]);
204 context_rules.logging = Some(vec![
205 "mcp.tool_execution".to_string(),
206 "mcp.tool_failed".to_string(),
207 "mcp.tool_denied".to_string(),
208 "mcp.tool_filtered".to_string(),
209 "mcp.provider_initialized".to_string(),
210 ]);
211 context_rules.configuration = Some(BTreeMap::from([
212 (
213 "provider".to_string(),
214 vec!["max_concurrent_requests".to_string()],
215 ),
216 (
217 "context7".to_string(),
218 vec![
219 "workspace".to_string(),
220 "search_scope".to_string(),
221 "max_results".to_string(),
222 ],
223 ),
224 ]));
225 allowlist
226 .providers
227 .insert("context7".to_string(), context_rules);
228
229 let mut seq_rules = McpAllowListRules::default();
230 seq_rules.tools = Some(vec![
231 "plan".to_string(),
232 "critique".to_string(),
233 "reflect".to_string(),
234 "decompose".to_string(),
235 "sequential_*".to_string(),
236 ]);
237 seq_rules.prompts = Some(vec![
238 "sequential-thinking::*".to_string(),
239 "plan".to_string(),
240 "reflect".to_string(),
241 "critique".to_string(),
242 ]);
243 seq_rules.logging = Some(vec![
244 "mcp.tool_execution".to_string(),
245 "mcp.tool_failed".to_string(),
246 "mcp.tool_denied".to_string(),
247 "mcp.tool_filtered".to_string(),
248 "mcp.provider_initialized".to_string(),
249 ]);
250 seq_rules.configuration = Some(BTreeMap::from([
251 (
252 "provider".to_string(),
253 vec!["max_concurrent_requests".to_string()],
254 ),
255 (
256 "sequencing".to_string(),
257 vec!["max_depth".to_string(), "max_branches".to_string()],
258 ),
259 ]));
260 allowlist
261 .providers
262 .insert("sequential-thinking".to_string(), seq_rules);
263
264 allowlist
265}
266
/// Splits a runtime MCP key of the form `mcp::<provider>::<tool>` into
/// `(provider, tool)`.
///
/// Everything after the second `::` belongs to the tool name, so
/// `mcp::p::a::b` yields `("p", "a::b")`. Returns `None` in two cases:
/// - the key does not start with `mcp::`;
/// - the provider or tool segment is missing or empty.
fn parse_mcp_policy_key(tool_name: &str) -> Option<(String, String)> {
    let remainder = tool_name.strip_prefix("mcp::")?;
    let (provider, tool) = remainder.split_once("::")?;
    if provider.is_empty() || tool.is_empty() {
        return None;
    }
    Some((provider.to_owned(), tool.to_owned()))
}
276
/// Alternative policy-file layout.
///
/// `load_or_create_config` tries this layout before the native
/// [`ToolPolicyConfig`] format, and a successful parse is converted with
/// `convert_from_alternative`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeToolPolicyConfig {
    /// Format version; copied into the converted configuration.
    pub version: u32,
    /// Default settings block. Required for this layout to parse, but not
    /// carried over by `convert_from_alternative`.
    pub default: AlternativeDefaultPolicy,
    /// Per-tool entries; each becomes `Allow` or `Deny` after conversion.
    pub tools: IndexMap<String, AlternativeToolPolicy>,
    /// Per-tool constraints, copied verbatim on conversion; empty when absent.
    #[serde(default)]
    pub constraints: IndexMap<String, ToolConstraints>,
}
290
/// Default-settings block of [`AlternativeToolPolicyConfig`].
///
/// Every field is required (there is no `#[serde(default)]`), so a file that
/// lacks any of them is not recognized as the alternative layout. None of
/// these values are used by `convert_from_alternative`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeDefaultPolicy {
    /// Default allow flag.
    pub allow: bool,
    /// Per-run rate limit — presumably a maximum number of invocations; TODO confirm.
    pub rate_limit_per_run: u32,
    /// Concurrency limit — not enforced anywhere in this file.
    pub max_concurrent: u32,
    /// Default filesystem-write flag — not enforced anywhere in this file.
    pub fs_write: bool,
    /// Default network-access flag — not enforced anywhere in this file.
    pub network: bool,
}
305
/// Single tool entry of [`AlternativeToolPolicyConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeToolPolicy {
    /// `true` converts to [`ToolPolicy::Allow`], `false` to [`ToolPolicy::Deny`].
    pub allow: bool,
    /// Filesystem-write flag; parsed but ignored by the conversion.
    #[serde(default)]
    pub fs_write: bool,
    /// Network-access flag; parsed but ignored by the conversion.
    #[serde(default)]
    pub network: bool,
    /// Optional argument restrictions; parsed but ignored by the conversion.
    #[serde(default)]
    pub args_policy: Option<AlternativeArgsPolicy>,
}
321
/// Argument restrictions attached to an [`AlternativeToolPolicy`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeArgsPolicy {
    /// Substrings that should be rejected in tool arguments — presumably
    /// enforced elsewhere; nothing in this file consults them.
    pub deny_substrings: Vec<String>,
}
328
/// Loads, queries, updates and persists tool execution policies.
///
/// Every mutating method writes the whole configuration back to `config_path`.
#[derive(Clone)]
pub struct ToolPolicyManager {
    /// Location of the JSON policy file this manager reads and writes.
    config_path: PathBuf,
    /// In-memory copy of the policy file.
    config: ToolPolicyConfig,
}
335
336impl ToolPolicyManager {
337 pub fn new() -> Result<Self> {
339 let config_path = Self::get_config_path()?;
340 let config = Self::load_or_create_config(&config_path)?;
341
342 Ok(Self {
343 config_path,
344 config,
345 })
346 }
347
348 pub fn new_with_workspace(workspace_root: &PathBuf) -> Result<Self> {
350 let config_path = Self::get_workspace_config_path(workspace_root)?;
351 let config = Self::load_or_create_config(&config_path)?;
352
353 Ok(Self {
354 config_path,
355 config,
356 })
357 }
358
359 fn get_config_path() -> Result<PathBuf> {
361 let home_dir = dirs::home_dir().context("Could not determine home directory")?;
362
363 let vtcode_dir = home_dir.join(".vtcode");
364 if !vtcode_dir.exists() {
365 fs::create_dir_all(&vtcode_dir).context("Failed to create ~/.vtcode directory")?;
366 }
367
368 Ok(vtcode_dir.join("tool-policy.json"))
369 }
370
371 fn get_workspace_config_path(workspace_root: &PathBuf) -> Result<PathBuf> {
373 let workspace_vtcode_dir = workspace_root.join(".vtcode");
374
375 if !workspace_vtcode_dir.exists() {
376 fs::create_dir_all(&workspace_vtcode_dir).with_context(|| {
377 format!(
378 "Failed to create workspace policy directory at {}",
379 workspace_vtcode_dir.display()
380 )
381 })?;
382 }
383
384 Ok(workspace_vtcode_dir.join("tool-policy.json"))
385 }
386
387 fn load_or_create_config(config_path: &PathBuf) -> Result<ToolPolicyConfig> {
389 if config_path.exists() {
390 let content =
391 fs::read_to_string(config_path).context("Failed to read tool policy config")?;
392
393 if let Ok(alt_config) = serde_json::from_str::<AlternativeToolPolicyConfig>(&content) {
395 return Ok(Self::convert_from_alternative(alt_config));
397 }
398
399 match serde_json::from_str(&content) {
401 Ok(mut config) => {
402 Self::apply_auto_allow_defaults(&mut config);
403 Self::ensure_network_constraints(&mut config);
404 Ok(config)
405 }
406 Err(parse_err) => {
407 eprintln!(
408 "Warning: Invalid tool policy config at {} ({}). Resetting to defaults.",
409 config_path.display(),
410 parse_err
411 );
412 Self::reset_to_default(config_path)
413 }
414 }
415 } else {
416 let mut config = ToolPolicyConfig::default();
418 Self::apply_auto_allow_defaults(&mut config);
419 Self::ensure_network_constraints(&mut config);
420 Ok(config)
421 }
422 }
423
424 fn apply_auto_allow_defaults(config: &mut ToolPolicyConfig) {
425 for tool in AUTO_ALLOW_TOOLS {
426 config
427 .policies
428 .entry((*tool).to_string())
429 .and_modify(|policy| *policy = ToolPolicy::Allow)
430 .or_insert(ToolPolicy::Allow);
431 if !config.available_tools.contains(&tool.to_string()) {
432 config.available_tools.push(tool.to_string());
433 }
434 }
435 Self::ensure_network_constraints(config);
436 }
437
438 fn ensure_network_constraints(config: &mut ToolPolicyConfig) {
439 let entry = config
440 .constraints
441 .entry(tools::CURL.to_string())
442 .or_insert_with(ToolConstraints::default);
443
444 if entry.max_response_bytes.is_none() {
445 entry.max_response_bytes = Some(DEFAULT_CURL_MAX_RESPONSE_BYTES);
446 }
447 if entry.allowed_url_schemes.is_none() {
448 entry.allowed_url_schemes = Some(vec!["https".to_string()]);
449 }
450 if entry.denied_url_hosts.is_none() {
451 entry.denied_url_hosts = Some(vec![
452 "localhost".to_string(),
453 "127.0.0.1".to_string(),
454 "0.0.0.0".to_string(),
455 "::1".to_string(),
456 ".localhost".to_string(),
457 ".local".to_string(),
458 ".internal".to_string(),
459 ".lan".to_string(),
460 ]);
461 }
462 }
463
464 fn reset_to_default(config_path: &PathBuf) -> Result<ToolPolicyConfig> {
465 let backup_path = config_path.with_extension("json.bak");
466
467 if let Err(err) = fs::rename(config_path, &backup_path) {
468 eprintln!(
469 "Warning: Unable to back up invalid tool policy config ({}). {}",
470 config_path.display(),
471 err
472 );
473 } else {
474 eprintln!(
475 "Backed up invalid tool policy config to {}",
476 backup_path.display()
477 );
478 }
479
480 let default_config = ToolPolicyConfig::default();
481 Self::write_config(config_path.as_path(), &default_config)?;
482 Ok(default_config)
483 }
484
485 fn write_config(path: &Path, config: &ToolPolicyConfig) -> Result<()> {
486 if let Some(parent) = path.parent() {
487 if !parent.exists() {
488 fs::create_dir_all(parent).with_context(|| {
489 format!(
490 "Failed to create directory for tool policy config at {}",
491 parent.display()
492 )
493 })?;
494 }
495 }
496
497 let serialized = serde_json::to_string_pretty(config)
498 .context("Failed to serialize tool policy config")?;
499
500 fs::write(path, serialized)
501 .with_context(|| format!("Failed to write tool policy config: {}", path.display()))
502 }
503
504 fn convert_from_alternative(alt_config: AlternativeToolPolicyConfig) -> ToolPolicyConfig {
506 let mut policies = IndexMap::new();
507
508 for (tool_name, alt_policy) in alt_config.tools {
510 let policy = if alt_policy.allow {
511 ToolPolicy::Allow
512 } else {
513 ToolPolicy::Deny
514 };
515 policies.insert(tool_name, policy);
516 }
517
518 let mut config = ToolPolicyConfig {
519 version: alt_config.version,
520 available_tools: policies.keys().cloned().collect(),
521 policies,
522 constraints: alt_config.constraints,
523 mcp: McpPolicyStore::default(),
524 };
525 Self::apply_auto_allow_defaults(&mut config);
526 config
527 }
528
529 fn apply_config_policy(&mut self, tool_name: &str, policy: ConfigToolPolicy) {
530 let runtime_policy = match policy {
531 ConfigToolPolicy::Allow => ToolPolicy::Allow,
532 ConfigToolPolicy::Prompt => ToolPolicy::Prompt,
533 ConfigToolPolicy::Deny => ToolPolicy::Deny,
534 };
535
536 self.config
537 .policies
538 .insert(tool_name.to_string(), runtime_policy);
539 }
540
541 fn resolve_config_policy(tools_config: &ToolsConfig, tool_name: &str) -> ConfigToolPolicy {
542 if let Some(policy) = tools_config.policies.get(tool_name) {
543 return policy.clone();
544 }
545
546 match tool_name {
547 tools::LIST_FILES => tools_config
548 .policies
549 .get("list_dir")
550 .or_else(|| tools_config.policies.get("list_directory"))
551 .cloned(),
552 _ => None,
553 }
554 .unwrap_or_else(|| tools_config.default_policy.clone())
555 }
556
557 pub fn apply_tools_config(&mut self, tools_config: &ToolsConfig) -> Result<()> {
559 if self.config.available_tools.is_empty() {
560 return Ok(());
561 }
562
563 for tool in self.config.available_tools.clone() {
564 let config_policy = Self::resolve_config_policy(tools_config, &tool);
565 self.apply_config_policy(&tool, config_policy);
566 }
567
568 Self::apply_auto_allow_defaults(&mut self.config);
569 self.save_config()
570 }
571
572 pub fn update_available_tools(&mut self, tools: Vec<String>) -> Result<()> {
574 let current_tools: HashSet<_> = self.config.policies.keys().cloned().collect();
575 let new_tools: HashSet<_> = tools
576 .iter()
577 .filter(|name| !name.starts_with("mcp::"))
578 .cloned()
579 .collect();
580
581 let mut has_changes = false;
582
583 for tool in tools
585 .iter()
586 .filter(|tool| !tool.starts_with("mcp::") && !current_tools.contains(*tool))
587 {
588 let default_policy = if AUTO_ALLOW_TOOLS.contains(&tool.as_str()) {
589 ToolPolicy::Allow
590 } else {
591 ToolPolicy::Prompt
592 };
593 self.config.policies.insert(tool.clone(), default_policy);
594 has_changes = true;
595 }
596
597 let tools_to_remove: Vec<_> = self
599 .config
600 .policies
601 .keys()
602 .filter(|tool| !new_tools.contains(*tool))
603 .cloned()
604 .collect();
605
606 for tool in tools_to_remove {
607 self.config.policies.shift_remove(&tool);
608 has_changes = true;
609 }
610
611 if self.config.available_tools != tools {
613 self.config.available_tools = tools;
615 has_changes = true;
616 }
617
618 Self::ensure_network_constraints(&mut self.config);
619
620 if has_changes {
621 self.save_config()
622 } else {
623 Ok(())
624 }
625 }
626
627 pub fn update_mcp_tools(
629 &mut self,
630 provider_tools: &HashMap<String, Vec<String>>,
631 ) -> Result<()> {
632 let stored_providers: HashSet<String> = self.config.mcp.providers.keys().cloned().collect();
633 let mut has_changes = false;
634
635 for (provider, tools) in provider_tools {
637 let entry = self
638 .config
639 .mcp
640 .providers
641 .entry(provider.clone())
642 .or_insert_with(McpProviderPolicy::default);
643
644 let existing_tools: HashSet<String> = entry.tools.keys().cloned().collect();
645 let advertised: HashSet<String> = tools.iter().cloned().collect();
646
647 for tool in tools {
649 if !existing_tools.contains(tool) {
650 entry.tools.insert(tool.clone(), ToolPolicy::Prompt);
651 has_changes = true;
652 }
653 }
654
655 for stale in existing_tools.difference(&advertised) {
657 entry.tools.shift_remove(stale.as_str());
658 has_changes = true;
659 }
660 }
661
662 let advertised_providers: HashSet<String> = provider_tools.keys().cloned().collect();
664 for provider in stored_providers
665 .difference(&advertised_providers)
666 .cloned()
667 .collect::<Vec<_>>()
668 {
669 self.config.mcp.providers.shift_remove(provider.as_str());
670 has_changes = true;
671 }
672
673 let stale_runtime_keys: Vec<_> = self
675 .config
676 .policies
677 .keys()
678 .filter(|name| name.starts_with("mcp::"))
679 .cloned()
680 .collect();
681
682 for key in stale_runtime_keys {
683 self.config.policies.shift_remove(&key);
684 has_changes = true;
685 }
686
687 let mut available: Vec<String> = self
689 .config
690 .available_tools
691 .iter()
692 .filter(|name| !name.starts_with("mcp::"))
693 .cloned()
694 .collect();
695
696 for (provider, policy) in &self.config.mcp.providers {
697 for tool in policy.tools.keys() {
698 available.push(format!("mcp::{}::{}", provider, tool));
699 }
700 }
701
702 available.sort();
703 available.dedup();
704
705 if self.config.available_tools != available {
707 self.config.available_tools = available;
708 has_changes = true;
709 }
710
711 if has_changes {
712 self.save_config()
713 } else {
714 Ok(())
715 }
716 }
717
718 pub fn get_mcp_tool_policy(&self, provider: &str, tool: &str) -> ToolPolicy {
720 self.config
721 .mcp
722 .providers
723 .get(provider)
724 .and_then(|policy| policy.tools.get(tool))
725 .cloned()
726 .unwrap_or(ToolPolicy::Prompt)
727 }
728
729 pub fn set_mcp_tool_policy(
731 &mut self,
732 provider: &str,
733 tool: &str,
734 policy: ToolPolicy,
735 ) -> Result<()> {
736 let entry = self
737 .config
738 .mcp
739 .providers
740 .entry(provider.to_string())
741 .or_insert_with(McpProviderPolicy::default);
742 entry.tools.insert(tool.to_string(), policy);
743 self.save_config()
744 }
745
746 pub fn mcp_allowlist(&self) -> &McpAllowListConfig {
748 &self.config.mcp.allowlist
749 }
750
751 pub fn set_mcp_allowlist(&mut self, allowlist: McpAllowListConfig) -> Result<()> {
753 self.config.mcp.allowlist = allowlist;
754 self.save_config()
755 }
756
757 pub fn get_policy(&self, tool_name: &str) -> ToolPolicy {
759 if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
760 return self.get_mcp_tool_policy(&provider, &tool);
761 }
762
763 self.config
764 .policies
765 .get(tool_name)
766 .cloned()
767 .unwrap_or(ToolPolicy::Prompt)
768 }
769
770 pub fn get_constraints(&self, tool_name: &str) -> Option<&ToolConstraints> {
772 self.config.constraints.get(tool_name)
773 }
774
775 pub fn should_execute_tool(&mut self, tool_name: &str) -> Result<bool> {
777 if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
778 return match self.get_mcp_tool_policy(&provider, &tool) {
779 ToolPolicy::Allow => Ok(true),
780 ToolPolicy::Deny => Ok(false),
781 ToolPolicy::Prompt => {
782 if ToolPolicyManager::is_auto_allow_tool(tool_name) {
783 self.set_mcp_tool_policy(&provider, &tool, ToolPolicy::Allow)?;
784 Ok(true)
785 } else {
786 self.prompt_user_for_tool(tool_name)
787 }
788 }
789 };
790 }
791
792 match self.get_policy(tool_name) {
793 ToolPolicy::Allow => Ok(true),
794 ToolPolicy::Deny => Ok(false),
795 ToolPolicy::Prompt => {
796 if AUTO_ALLOW_TOOLS.contains(&tool_name) {
797 self.set_policy(tool_name, ToolPolicy::Allow)?;
798 return Ok(true);
799 }
800 let should_execute = self.prompt_user_for_tool(tool_name)?;
801 Ok(should_execute)
802 }
803 }
804 }
805
806 pub fn is_auto_allow_tool(tool_name: &str) -> bool {
807 AUTO_ALLOW_TOOLS.contains(&tool_name)
808 }
809
810 fn prompt_user_for_tool(&mut self, tool_name: &str) -> Result<bool> {
812 let interactive = std::io::stdin().is_terminal() && std::io::stdout().is_terminal();
813 let mut renderer = AnsiRenderer::stdout();
814 let banner_style = theme::banner_style();
815
816 if !interactive {
817 let message = format!(
818 "Non-interactive environment detected. Auto-approving '{}' tool.",
819 tool_name
820 );
821 renderer.line_with_style(banner_style, &message)?;
822 self.set_policy(tool_name, ToolPolicy::Allow)?;
823 return Ok(true);
824 }
825
826 let header = format!("Tool Permission Request: {}", tool_name);
827 renderer.line_with_style(banner_style, &header)?;
828 renderer.line_with_style(
829 banner_style,
830 &format!("The agent wants to use the '{}' tool.", tool_name),
831 )?;
832 renderer.line_with_style(banner_style, "")?;
833 renderer.line_with_style(
834 banner_style,
835 "This decision applies to the current request only.",
836 )?;
837 renderer.line_with_style(
838 banner_style,
839 "Update the policy file or use CLI flags to change the default.",
840 )?;
841 renderer.line_with_style(banner_style, "")?;
842
843 if AUTO_ALLOW_TOOLS.contains(&tool_name) {
844 renderer.line_with_style(
845 banner_style,
846 &format!(
847 "Auto-approving '{}' tool (default trusted tool).",
848 tool_name
849 ),
850 )?;
851 return Ok(true);
852 }
853
854 let rgb = theme::banner_color();
855 let to_ansi_256 = |value: u8| -> u8 {
856 if value < 48 {
857 0
858 } else if value < 114 {
859 1
860 } else {
861 ((value - 35) / 40).min(5)
862 }
863 };
864 let rgb_to_index = |r: u8, g: u8, b: u8| -> u8 {
865 let r_idx = to_ansi_256(r);
866 let g_idx = to_ansi_256(g);
867 let b_idx = to_ansi_256(b);
868 16 + 36 * r_idx + 6 * g_idx + b_idx
869 };
870 let color_index = rgb_to_index(rgb.0, rgb.1, rgb.2);
871 let dialog_color = ConsoleColor::Color256(color_index);
872 let tinted_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
873
874 let mut dialog_theme = ColorfulTheme::default();
875 dialog_theme.prompt_style = tinted_style;
876 dialog_theme.prompt_prefix = style("—".to_string()).for_stderr().fg(dialog_color);
877 dialog_theme.prompt_suffix = style("—".to_string()).for_stderr().fg(dialog_color);
878 dialog_theme.hint_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
879 dialog_theme.defaults_style = dialog_theme.hint_style.clone();
880 dialog_theme.success_prefix = style("✓".to_string()).for_stderr().fg(dialog_color);
881 dialog_theme.success_suffix = style("·".to_string()).for_stderr().fg(dialog_color);
882 dialog_theme.error_prefix = style("✗".to_string()).for_stderr().fg(dialog_color);
883 dialog_theme.error_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
884 dialog_theme.values_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
885
886 let prompt_text = format!("Allow the agent to use '{}'?", tool_name);
887
888 match Confirm::with_theme(&dialog_theme)
889 .with_prompt(prompt_text)
890 .default(false)
891 .interact()
892 {
893 Ok(confirmed) => {
894 let message = if confirmed {
895 format!("✓ Approved: '{}' tool will run now", tool_name)
896 } else {
897 format!("✗ Denied: '{}' tool will not run", tool_name)
898 };
899 let style = if confirmed {
900 MessageStyle::Tool
901 } else {
902 MessageStyle::Error
903 };
904 renderer.line(style, &message)?;
905 Ok(confirmed)
906 }
907 Err(e) => {
908 renderer.line(
909 MessageStyle::Error,
910 &format!("Failed to read confirmation: {}", e),
911 )?;
912 Ok(false)
913 }
914 }
915 }
916
917 pub fn set_policy(&mut self, tool_name: &str, policy: ToolPolicy) -> Result<()> {
919 if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
920 return self.set_mcp_tool_policy(&provider, &tool, policy);
921 }
922
923 self.config.policies.insert(tool_name.to_string(), policy);
924 self.save_config()
925 }
926
927 pub fn reset_all_to_prompt(&mut self) -> Result<()> {
929 for policy in self.config.policies.values_mut() {
930 *policy = ToolPolicy::Prompt;
931 }
932 for provider in self.config.mcp.providers.values_mut() {
933 for policy in provider.tools.values_mut() {
934 *policy = ToolPolicy::Prompt;
935 }
936 }
937 self.save_config()
938 }
939
940 pub fn allow_all_tools(&mut self) -> Result<()> {
942 for policy in self.config.policies.values_mut() {
943 *policy = ToolPolicy::Allow;
944 }
945 for provider in self.config.mcp.providers.values_mut() {
946 for policy in provider.tools.values_mut() {
947 *policy = ToolPolicy::Allow;
948 }
949 }
950 self.save_config()
951 }
952
953 pub fn deny_all_tools(&mut self) -> Result<()> {
955 for policy in self.config.policies.values_mut() {
956 *policy = ToolPolicy::Deny;
957 }
958 for provider in self.config.mcp.providers.values_mut() {
959 for policy in provider.tools.values_mut() {
960 *policy = ToolPolicy::Deny;
961 }
962 }
963 self.save_config()
964 }
965
966 pub fn get_policy_summary(&self) -> IndexMap<String, ToolPolicy> {
968 let mut summary = self.config.policies.clone();
969 for (provider, policy) in &self.config.mcp.providers {
970 for (tool, status) in &policy.tools {
971 summary.insert(format!("mcp::{}::{}", provider, tool), status.clone());
972 }
973 }
974 summary
975 }
976
977 fn save_config(&self) -> Result<()> {
979 Self::write_config(&self.config_path, &self.config)
980 }
981
982 pub fn print_status(&self) {
984 println!("{}", style("Tool Policy Status").cyan().bold());
985 println!("Config file: {}", self.config_path.display());
986 println!();
987
988 let summary = self.get_policy_summary();
989
990 if summary.is_empty() {
991 println!("No tools configured yet.");
992 return;
993 }
994
995 let mut allow_count = 0;
996 let mut prompt_count = 0;
997 let mut deny_count = 0;
998
999 for (tool, policy) in &summary {
1000 let (status, color_name) = match policy {
1001 ToolPolicy::Allow => {
1002 allow_count += 1;
1003 ("ALLOW", "green")
1004 }
1005 ToolPolicy::Prompt => {
1006 prompt_count += 1;
1007 ("PROMPT", "yellow")
1008 }
1009 ToolPolicy::Deny => {
1010 deny_count += 1;
1011 ("DENY", "red")
1012 }
1013 };
1014
1015 let status_styled = match color_name {
1016 "green" => style(status).green(),
1017 "yellow" => style(status).yellow(),
1018 "red" => style(status).red(),
1019 _ => style(status),
1020 };
1021
1022 println!(
1023 " {} {}",
1024 style(format!("{:15}", tool)).cyan(),
1025 status_styled
1026 );
1027 }
1028
1029 println!();
1030 println!(
1031 "Summary: {} allowed, {} prompt, {} denied",
1032 style(allow_count).green(),
1033 style(prompt_count).yellow(),
1034 style(deny_count).red()
1035 );
1036 }
1037
1038 pub fn config_path(&self) -> &Path {
1040 &self.config_path
1041 }
1042}
1043
/// Optional per-tool execution constraints.
///
/// Every field is optional and defaults to `None` when absent from the file.
/// Within this file only the curl-related fields are populated, by
/// `ensure_network_constraints`. Enforcement presumably happens in the tool
/// implementations — confirm there.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolConstraints {
    /// Permitted values for a tool's mode argument (assumed; not used in this file).
    #[serde(default)]
    pub allowed_modes: Option<Vec<String>>,
    /// Upper bound on results returned per call (assumed; not used in this file).
    #[serde(default)]
    pub max_results_per_call: Option<usize>,
    /// Upper bound on items returned per call (assumed; not used in this file).
    #[serde(default)]
    pub max_items_per_call: Option<usize>,
    /// Default response format name (assumed; not used in this file).
    #[serde(default)]
    pub default_response_format: Option<String>,
    /// Upper bound on bytes read per call (assumed; not used in this file).
    #[serde(default)]
    pub max_bytes_per_read: Option<usize>,
    /// Response size cap; defaulted to 64 KiB for curl.
    #[serde(default)]
    pub max_response_bytes: Option<usize>,
    /// Allowed URL schemes; defaulted to `["https"]` for curl.
    #[serde(default)]
    pub allowed_url_schemes: Option<Vec<String>>,
    /// Denied hosts; defaulted to local/loopback hosts for curl.
    #[serde(default)]
    pub denied_url_hosts: Option<Vec<String>>,
}
1072
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::constants::tools;
    use tempfile::tempdir;

    /// A native configuration survives a JSON round trip unchanged.
    #[test]
    fn test_tool_policy_config_serialization() {
        let mut config = ToolPolicyConfig::default();
        config.available_tools = vec![tools::READ_FILE.to_string(), tools::WRITE_FILE.to_string()];
        config
            .policies
            .insert(tools::READ_FILE.to_string(), ToolPolicy::Allow);
        config
            .policies
            .insert(tools::WRITE_FILE.to_string(), ToolPolicy::Prompt);

        let json = serde_json::to_string_pretty(&config).unwrap();
        let deserialized: ToolPolicyConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.available_tools, deserialized.available_tools);
        assert_eq!(config.policies, deserialized.policies);
    }

    /// Loading applies the auto-allow defaults.
    ///
    /// `update_available_tools` then adds new tools as `Prompt`, drops tools
    /// that are no longer offered, and persists the result. The previous
    /// version asserted `len() == 2` directly after loading, which ignored the
    /// auto-allow entries that `load_or_create_config` always adds.
    #[test]
    fn test_policy_updates() {
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("tool-policy.json");

        let mut config = ToolPolicyConfig::default();
        config.available_tools = vec!["tool1".to_string()];
        config
            .policies
            .insert("tool1".to_string(), ToolPolicy::Prompt);
        fs::write(&config_path, serde_json::to_string_pretty(&config).unwrap()).unwrap();

        let loaded_config = ToolPolicyManager::load_or_create_config(&config_path).unwrap();
        assert_eq!(
            loaded_config.policies.get("tool1"),
            Some(&ToolPolicy::Prompt)
        );
        for tool in AUTO_ALLOW_TOOLS {
            assert_eq!(loaded_config.policies.get(*tool), Some(&ToolPolicy::Allow));
        }

        let mut manager = ToolPolicyManager {
            config_path: config_path.clone(),
            config: loaded_config,
        };
        let new_tools = vec!["tool1".to_string(), "tool2".to_string()];
        manager.update_available_tools(new_tools.clone()).unwrap();

        assert_eq!(manager.config.policies.len(), 2);
        assert_eq!(manager.get_policy("tool1"), ToolPolicy::Prompt);
        assert_eq!(manager.get_policy("tool2"), ToolPolicy::Prompt);

        let saved: ToolPolicyConfig =
            serde_json::from_str(&fs::read_to_string(&config_path).unwrap()).unwrap();
        assert_eq!(saved.policies, manager.config.policies);
        assert_eq!(saved.available_tools, new_tools);
    }

    /// Runtime MCP keys are split into provider and tool segments.
    #[test]
    fn test_parse_mcp_policy_key() {
        assert_eq!(
            parse_mcp_policy_key("mcp::time::get_now"),
            Some(("time".to_string(), "get_now".to_string()))
        );
        assert_eq!(parse_mcp_policy_key("mcp::provider::"), None);
        assert_eq!(parse_mcp_policy_key(tools::READ_FILE), None);
    }

    /// MCP policies are stored in the provider store, not the flat map.
    #[test]
    fn test_mcp_policies_are_routed_to_provider_store() {
        let dir = tempdir().unwrap();
        let mut manager = ToolPolicyManager {
            config_path: dir.path().join("tool-policy.json"),
            config: ToolPolicyConfig::default(),
        };

        assert_eq!(manager.get_policy("mcp::time::get_now"), ToolPolicy::Prompt);
        manager
            .set_policy("mcp::time::get_now", ToolPolicy::Deny)
            .unwrap();
        assert_eq!(manager.get_policy("mcp::time::get_now"), ToolPolicy::Deny);
        assert_eq!(
            manager.get_mcp_tool_policy("time", "get_now"),
            ToolPolicy::Deny
        );
        assert!(!manager.config.policies.contains_key("mcp::time::get_now"));
    }
}