1use anyhow::{Context, Result};
8use console::{Color as ConsoleColor, Style as ConsoleStyle, style};
9use dialoguer::{Confirm, theme::ColorfulTheme};
10use indexmap::IndexMap;
11use is_terminal::IsTerminal;
12use serde::{Deserialize, Serialize};
13use std::collections::{BTreeMap, HashMap, HashSet};
14use std::fs;
15use std::path::{Path, PathBuf};
16
17use crate::ui::theme;
18use crate::utils::ansi::{AnsiRenderer, MessageStyle};
19
20use crate::config::constants::tools;
21use crate::config::core::tools::{ToolPolicy as ConfigToolPolicy, ToolsConfig};
22use crate::config::mcp::{McpAllowListConfig, McpAllowListRules};
23
/// Tool names whose policy is forced to [`ToolPolicy::Allow`] whenever the
/// auto-allow defaults are applied to a configuration.
const AUTO_ALLOW_TOOLS: &[&str] = &["run_terminal_cmd", "bash"];
/// Response-size cap (64 KiB) stored as `max_response_bytes` on the
/// `tools::CURL` constraint entry when that entry has no cap of its own.
const DEFAULT_CURL_MAX_RESPONSE_BYTES: usize = 64 * 1024;
26
/// Execution policy attached to a single tool.
///
/// Serialized with lowercase variant names (`"allow"`, `"prompt"`, `"deny"`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolPolicy {
    /// The tool runs without asking.
    Allow,
    /// The user is asked for confirmation before the tool runs.
    Prompt,
    /// The tool is never run.
    Deny,
}
38
39impl Default for ToolPolicy {
40 fn default() -> Self {
41 ToolPolicy::Prompt
42 }
43}
44
/// On-disk tool policy configuration (the contents of `tool-policy.json`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolPolicyConfig {
    /// Schema version; `Default` sets it to 1.
    pub version: u32,
    /// Names of the tools currently known to the manager.
    pub available_tools: Vec<String>,
    /// Policy per tool name, kept in insertion order.
    pub policies: IndexMap<String, ToolPolicy>,
    /// Per-tool constraints; empty when the field is absent from the file.
    #[serde(default)]
    pub constraints: IndexMap<String, ToolConstraints>,
    /// MCP allow list and per-provider tool policies; defaulted when absent.
    #[serde(default)]
    pub mcp: McpPolicyStore,
}
61
62impl Default for ToolPolicyConfig {
63 fn default() -> Self {
64 Self {
65 version: 1,
66 available_tools: Vec::new(),
67 policies: IndexMap::new(),
68 constraints: IndexMap::new(),
69 mcp: McpPolicyStore::default(),
70 }
71 }
72}
73
/// MCP-specific part of the tool policy configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpPolicyStore {
    /// Allow list applied to MCP providers; falls back to
    /// `default_secure_mcp_allowlist()` when absent from the file.
    #[serde(default = "default_secure_mcp_allowlist")]
    pub allowlist: McpAllowListConfig,
    /// Tool policies keyed by MCP provider name; empty when absent.
    #[serde(default)]
    pub providers: IndexMap<String, McpProviderPolicy>,
}
84
85impl Default for McpPolicyStore {
86 fn default() -> Self {
87 Self {
88 allowlist: default_secure_mcp_allowlist(),
89 providers: IndexMap::new(),
90 }
91 }
92}
93
/// Tool policies stored for a single MCP provider.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct McpProviderPolicy {
    /// Policy per tool name (without the `mcp::<provider>::` prefix), kept in
    /// insertion order; empty when absent from the file.
    #[serde(default)]
    pub tools: IndexMap<String, ToolPolicy>,
}
100
101fn default_secure_mcp_allowlist() -> McpAllowListConfig {
102 let mut allowlist = McpAllowListConfig::default();
103 allowlist.enforce = true;
104
105 allowlist.default.logging = Some(vec![
106 "mcp.provider_initialized".to_string(),
107 "mcp.provider_initialization_failed".to_string(),
108 "mcp.tool_filtered".to_string(),
109 "mcp.tool_execution".to_string(),
110 "mcp.tool_failed".to_string(),
111 "mcp.tool_denied".to_string(),
112 ]);
113
114 allowlist.default.configuration = Some(BTreeMap::from([
115 (
116 "client".to_string(),
117 vec![
118 "max_concurrent_connections".to_string(),
119 "request_timeout_seconds".to_string(),
120 "retry_attempts".to_string(),
121 ],
122 ),
123 (
124 "ui".to_string(),
125 vec![
126 "mode".to_string(),
127 "max_events".to_string(),
128 "show_provider_names".to_string(),
129 ],
130 ),
131 (
132 "server".to_string(),
133 vec![
134 "enabled".to_string(),
135 "bind_address".to_string(),
136 "port".to_string(),
137 "transport".to_string(),
138 "name".to_string(),
139 "version".to_string(),
140 ],
141 ),
142 ]));
143
144 let mut time_rules = McpAllowListRules::default();
145 time_rules.tools = Some(vec![
146 "get_*".to_string(),
147 "list_*".to_string(),
148 "convert_timezone".to_string(),
149 "describe_timezone".to_string(),
150 "time_*".to_string(),
151 ]);
152 time_rules.resources = Some(vec!["timezone:*".to_string(), "location:*".to_string()]);
153 time_rules.logging = Some(vec![
154 "mcp.tool_execution".to_string(),
155 "mcp.tool_failed".to_string(),
156 "mcp.tool_denied".to_string(),
157 "mcp.tool_filtered".to_string(),
158 "mcp.provider_initialized".to_string(),
159 ]);
160 time_rules.configuration = Some(BTreeMap::from([
161 (
162 "provider".to_string(),
163 vec!["max_concurrent_requests".to_string()],
164 ),
165 (
166 "time".to_string(),
167 vec!["local_timezone_override".to_string()],
168 ),
169 ]));
170 allowlist.providers.insert("time".to_string(), time_rules);
171
172 let mut context_rules = McpAllowListRules::default();
173 context_rules.tools = Some(vec![
174 "search_*".to_string(),
175 "fetch_*".to_string(),
176 "list_*".to_string(),
177 "context7_*".to_string(),
178 "get_*".to_string(),
179 ]);
180 context_rules.resources = Some(vec![
181 "docs::*".to_string(),
182 "snippets::*".to_string(),
183 "repositories::*".to_string(),
184 "context7::*".to_string(),
185 ]);
186 context_rules.prompts = Some(vec![
187 "context7::*".to_string(),
188 "support::*".to_string(),
189 "docs::*".to_string(),
190 ]);
191 context_rules.logging = Some(vec![
192 "mcp.tool_execution".to_string(),
193 "mcp.tool_failed".to_string(),
194 "mcp.tool_denied".to_string(),
195 "mcp.tool_filtered".to_string(),
196 "mcp.provider_initialized".to_string(),
197 ]);
198 context_rules.configuration = Some(BTreeMap::from([
199 (
200 "provider".to_string(),
201 vec!["max_concurrent_requests".to_string()],
202 ),
203 (
204 "context7".to_string(),
205 vec![
206 "workspace".to_string(),
207 "search_scope".to_string(),
208 "max_results".to_string(),
209 ],
210 ),
211 ]));
212 allowlist
213 .providers
214 .insert("context7".to_string(), context_rules);
215
216 let mut seq_rules = McpAllowListRules::default();
217 seq_rules.tools = Some(vec![
218 "plan".to_string(),
219 "critique".to_string(),
220 "reflect".to_string(),
221 "decompose".to_string(),
222 "sequential_*".to_string(),
223 ]);
224 seq_rules.prompts = Some(vec![
225 "sequential-thinking::*".to_string(),
226 "plan".to_string(),
227 "reflect".to_string(),
228 "critique".to_string(),
229 ]);
230 seq_rules.logging = Some(vec![
231 "mcp.tool_execution".to_string(),
232 "mcp.tool_failed".to_string(),
233 "mcp.tool_denied".to_string(),
234 "mcp.tool_filtered".to_string(),
235 "mcp.provider_initialized".to_string(),
236 ]);
237 seq_rules.configuration = Some(BTreeMap::from([
238 (
239 "provider".to_string(),
240 vec!["max_concurrent_requests".to_string()],
241 ),
242 (
243 "sequencing".to_string(),
244 vec!["max_depth".to_string(), "max_branches".to_string()],
245 ),
246 ]));
247 allowlist
248 .providers
249 .insert("sequential-thinking".to_string(), seq_rules);
250
251 allowlist
252}
253
/// Splits a key of the form `mcp::<provider>::<tool>` into its provider and
/// tool parts.
///
/// Returns `None` when the name does not start with `mcp::`, when the second
/// `::` separator is missing, or when either part is empty. Everything after
/// the second separator belongs to the tool part, so a tool name may itself
/// contain `::`.
fn parse_mcp_policy_key(tool_name: &str) -> Option<(String, String)> {
    let remainder = tool_name.strip_prefix("mcp::")?;
    let (provider, tool) = remainder.split_once("::")?;
    if provider.is_empty() || tool.is_empty() {
        return None;
    }
    Some((provider.to_owned(), tool.to_owned()))
}
263
/// Alternative policy file layout that is tried first when loading a config
/// and, if it parses, converted into a [`ToolPolicyConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeToolPolicyConfig {
    /// Schema version, carried over unchanged during conversion.
    pub version: u32,
    /// Default settings block. Required for the file to parse in this layout;
    /// not read by the conversion visible in this file.
    pub default: AlternativeDefaultPolicy,
    /// Per-tool settings; each entry's `allow` flag becomes Allow or Deny.
    pub tools: IndexMap<String, AlternativeToolPolicy>,
    /// Per-tool constraints, carried over unchanged; empty when absent.
    #[serde(default)]
    pub constraints: IndexMap<String, ToolConstraints>,
}
277
/// `default` section of the alternative policy layout.
///
/// NOTE(review): none of these fields is read in this file; the descriptions
/// below are inferred from the field names only — confirm against the format
/// specification.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeDefaultPolicy {
    /// Presumably whether tools are allowed by default.
    pub allow: bool,
    /// Presumably a per-run call limit.
    pub rate_limit_per_run: u32,
    /// Presumably a cap on concurrent tool executions.
    pub max_concurrent: u32,
    /// Presumably whether filesystem writes are permitted by default.
    pub fs_write: bool,
    /// Presumably whether network access is permitted by default.
    pub network: bool,
}
292
/// Per-tool entry of the alternative policy layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeToolPolicy {
    /// `true` converts to [`ToolPolicy::Allow`], `false` to [`ToolPolicy::Deny`].
    pub allow: bool,
    /// Defaults to `false` when absent. Not read in this file; presumably a
    /// filesystem-write permission — confirm.
    #[serde(default)]
    pub fs_write: bool,
    /// Defaults to `false` when absent. Not read in this file; presumably a
    /// network permission — confirm.
    #[serde(default)]
    pub network: bool,
    /// Optional argument policy; `None` when absent. Not read in this file.
    #[serde(default)]
    pub args_policy: Option<AlternativeArgsPolicy>,
}
308
/// Argument policy of a tool in the alternative layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeArgsPolicy {
    /// Not read in this file; presumably substrings that must not appear in
    /// the tool's arguments — confirm against the consumer.
    pub deny_substrings: Vec<String>,
}
315
/// Loads, queries, mutates and persists a [`ToolPolicyConfig`].
#[derive(Clone)]
pub struct ToolPolicyManager {
    /// Location of the JSON file the configuration is saved to.
    config_path: PathBuf,
    /// In-memory copy of the configuration.
    config: ToolPolicyConfig,
}
322
323impl ToolPolicyManager {
324 pub fn new() -> Result<Self> {
326 let config_path = Self::get_config_path()?;
327 let config = Self::load_or_create_config(&config_path)?;
328
329 Ok(Self {
330 config_path,
331 config,
332 })
333 }
334
335 pub fn new_with_workspace(workspace_root: &PathBuf) -> Result<Self> {
337 let config_path = Self::get_workspace_config_path(workspace_root)?;
338 let config = Self::load_or_create_config(&config_path)?;
339
340 Ok(Self {
341 config_path,
342 config,
343 })
344 }
345
346 fn get_config_path() -> Result<PathBuf> {
348 let home_dir = dirs::home_dir().context("Could not determine home directory")?;
349
350 let vtcode_dir = home_dir.join(".vtcode");
351 if !vtcode_dir.exists() {
352 fs::create_dir_all(&vtcode_dir).context("Failed to create ~/.vtcode directory")?;
353 }
354
355 Ok(vtcode_dir.join("tool-policy.json"))
356 }
357
358 fn get_workspace_config_path(workspace_root: &PathBuf) -> Result<PathBuf> {
360 let workspace_vtcode_dir = workspace_root.join(".vtcode");
361
362 if !workspace_vtcode_dir.exists() {
363 fs::create_dir_all(&workspace_vtcode_dir).with_context(|| {
364 format!(
365 "Failed to create workspace policy directory at {}",
366 workspace_vtcode_dir.display()
367 )
368 })?;
369 }
370
371 Ok(workspace_vtcode_dir.join("tool-policy.json"))
372 }
373
374 fn load_or_create_config(config_path: &PathBuf) -> Result<ToolPolicyConfig> {
376 if config_path.exists() {
377 let content =
378 fs::read_to_string(config_path).context("Failed to read tool policy config")?;
379
380 if let Ok(alt_config) = serde_json::from_str::<AlternativeToolPolicyConfig>(&content) {
382 return Ok(Self::convert_from_alternative(alt_config));
384 }
385
386 match serde_json::from_str(&content) {
388 Ok(mut config) => {
389 Self::apply_auto_allow_defaults(&mut config);
390 Self::ensure_network_constraints(&mut config);
391 Ok(config)
392 }
393 Err(parse_err) => {
394 eprintln!(
395 "Warning: Invalid tool policy config at {} ({}). Resetting to defaults.",
396 config_path.display(),
397 parse_err
398 );
399 Self::reset_to_default(config_path)
400 }
401 }
402 } else {
403 let mut config = ToolPolicyConfig::default();
405 Self::apply_auto_allow_defaults(&mut config);
406 Self::ensure_network_constraints(&mut config);
407 Ok(config)
408 }
409 }
410
411 fn apply_auto_allow_defaults(config: &mut ToolPolicyConfig) {
412 for tool in AUTO_ALLOW_TOOLS {
413 config
414 .policies
415 .entry((*tool).to_string())
416 .and_modify(|policy| *policy = ToolPolicy::Allow)
417 .or_insert(ToolPolicy::Allow);
418 if !config.available_tools.contains(&tool.to_string()) {
419 config.available_tools.push(tool.to_string());
420 }
421 }
422 Self::ensure_network_constraints(config);
423 }
424
425 fn ensure_network_constraints(config: &mut ToolPolicyConfig) {
426 let entry = config
427 .constraints
428 .entry(tools::CURL.to_string())
429 .or_insert_with(ToolConstraints::default);
430
431 if entry.max_response_bytes.is_none() {
432 entry.max_response_bytes = Some(DEFAULT_CURL_MAX_RESPONSE_BYTES);
433 }
434 if entry.allowed_url_schemes.is_none() {
435 entry.allowed_url_schemes = Some(vec!["https".to_string()]);
436 }
437 if entry.denied_url_hosts.is_none() {
438 entry.denied_url_hosts = Some(vec![
439 "localhost".to_string(),
440 "127.0.0.1".to_string(),
441 "0.0.0.0".to_string(),
442 "::1".to_string(),
443 ".localhost".to_string(),
444 ".local".to_string(),
445 ".internal".to_string(),
446 ".lan".to_string(),
447 ]);
448 }
449 }
450
451 fn reset_to_default(config_path: &PathBuf) -> Result<ToolPolicyConfig> {
452 let backup_path = config_path.with_extension("json.bak");
453
454 if let Err(err) = fs::rename(config_path, &backup_path) {
455 eprintln!(
456 "Warning: Unable to back up invalid tool policy config ({}). {}",
457 config_path.display(),
458 err
459 );
460 } else {
461 eprintln!(
462 "Backed up invalid tool policy config to {}",
463 backup_path.display()
464 );
465 }
466
467 let default_config = ToolPolicyConfig::default();
468 Self::write_config(config_path.as_path(), &default_config)?;
469 Ok(default_config)
470 }
471
472 fn write_config(path: &Path, config: &ToolPolicyConfig) -> Result<()> {
473 if let Some(parent) = path.parent() {
474 if !parent.exists() {
475 fs::create_dir_all(parent).with_context(|| {
476 format!(
477 "Failed to create directory for tool policy config at {}",
478 parent.display()
479 )
480 })?;
481 }
482 }
483
484 let serialized = serde_json::to_string_pretty(config)
485 .context("Failed to serialize tool policy config")?;
486
487 fs::write(path, serialized)
488 .with_context(|| format!("Failed to write tool policy config: {}", path.display()))
489 }
490
491 fn convert_from_alternative(alt_config: AlternativeToolPolicyConfig) -> ToolPolicyConfig {
493 let mut policies = IndexMap::new();
494
495 for (tool_name, alt_policy) in alt_config.tools {
497 let policy = if alt_policy.allow {
498 ToolPolicy::Allow
499 } else {
500 ToolPolicy::Deny
501 };
502 policies.insert(tool_name, policy);
503 }
504
505 let mut config = ToolPolicyConfig {
506 version: alt_config.version,
507 available_tools: policies.keys().cloned().collect(),
508 policies,
509 constraints: alt_config.constraints,
510 mcp: McpPolicyStore::default(),
511 };
512 Self::apply_auto_allow_defaults(&mut config);
513 config
514 }
515
516 fn apply_config_policy(&mut self, tool_name: &str, policy: ConfigToolPolicy) {
517 let runtime_policy = match policy {
518 ConfigToolPolicy::Allow => ToolPolicy::Allow,
519 ConfigToolPolicy::Prompt => ToolPolicy::Prompt,
520 ConfigToolPolicy::Deny => ToolPolicy::Deny,
521 };
522
523 self.config
524 .policies
525 .insert(tool_name.to_string(), runtime_policy);
526 }
527
528 fn resolve_config_policy(tools_config: &ToolsConfig, tool_name: &str) -> ConfigToolPolicy {
529 if let Some(policy) = tools_config.policies.get(tool_name) {
530 return policy.clone();
531 }
532
533 match tool_name {
534 tools::LIST_FILES => tools_config
535 .policies
536 .get("list_dir")
537 .or_else(|| tools_config.policies.get("list_directory"))
538 .cloned(),
539 _ => None,
540 }
541 .unwrap_or_else(|| tools_config.default_policy.clone())
542 }
543
544 pub fn apply_tools_config(&mut self, tools_config: &ToolsConfig) -> Result<()> {
546 if self.config.available_tools.is_empty() {
547 return Ok(());
548 }
549
550 for tool in self.config.available_tools.clone() {
551 let config_policy = Self::resolve_config_policy(tools_config, &tool);
552 self.apply_config_policy(&tool, config_policy);
553 }
554
555 Self::apply_auto_allow_defaults(&mut self.config);
556 self.save_config()
557 }
558
559 pub fn update_available_tools(&mut self, tools: Vec<String>) -> Result<()> {
561 let current_tools: HashSet<_> = self.config.policies.keys().cloned().collect();
562 let new_tools: HashSet<_> = tools
563 .iter()
564 .filter(|name| !name.starts_with("mcp::"))
565 .cloned()
566 .collect();
567
568 for tool in tools
570 .iter()
571 .filter(|tool| !tool.starts_with("mcp::") && !current_tools.contains(*tool))
572 {
573 let default_policy = if AUTO_ALLOW_TOOLS.contains(&tool.as_str()) {
574 ToolPolicy::Allow
575 } else {
576 ToolPolicy::Prompt
577 };
578 self.config.policies.insert(tool.clone(), default_policy);
579 }
580
581 let tools_to_remove: Vec<_> = self
583 .config
584 .policies
585 .keys()
586 .filter(|tool| !new_tools.contains(*tool))
587 .cloned()
588 .collect();
589
590 for tool in tools_to_remove {
591 self.config.policies.shift_remove(&tool);
592 }
593
594 self.config.available_tools = tools;
596
597 Self::ensure_network_constraints(&mut self.config);
598
599 self.save_config()
600 }
601
602 pub fn update_mcp_tools(
604 &mut self,
605 provider_tools: &HashMap<String, Vec<String>>,
606 ) -> Result<()> {
607 let stored_providers: HashSet<String> = self.config.mcp.providers.keys().cloned().collect();
608
609 for (provider, tools) in provider_tools {
611 let entry = self
612 .config
613 .mcp
614 .providers
615 .entry(provider.clone())
616 .or_insert_with(McpProviderPolicy::default);
617
618 let existing_tools: HashSet<String> = entry.tools.keys().cloned().collect();
619 let advertised: HashSet<String> = tools.iter().cloned().collect();
620
621 for tool in tools {
623 entry
624 .tools
625 .entry(tool.clone())
626 .or_insert(ToolPolicy::Prompt);
627 }
628
629 for stale in existing_tools.difference(&advertised) {
631 entry.tools.shift_remove(stale.as_str());
632 }
633 }
634
635 let advertised_providers: HashSet<String> = provider_tools.keys().cloned().collect();
637 for provider in stored_providers
638 .difference(&advertised_providers)
639 .cloned()
640 .collect::<Vec<_>>()
641 {
642 self.config.mcp.providers.shift_remove(provider.as_str());
643 }
644
645 let stale_runtime_keys: Vec<_> = self
647 .config
648 .policies
649 .keys()
650 .filter(|name| name.starts_with("mcp::"))
651 .cloned()
652 .collect();
653 for key in stale_runtime_keys {
654 self.config.policies.shift_remove(&key);
655 }
656
657 let mut available: Vec<String> = self
659 .config
660 .available_tools
661 .iter()
662 .filter(|name| !name.starts_with("mcp::"))
663 .cloned()
664 .collect();
665
666 for (provider, policy) in &self.config.mcp.providers {
667 for tool in policy.tools.keys() {
668 available.push(format!("mcp::{}::{}", provider, tool));
669 }
670 }
671
672 available.sort();
673 available.dedup();
674 self.config.available_tools = available;
675
676 self.save_config()
677 }
678
679 pub fn get_mcp_tool_policy(&self, provider: &str, tool: &str) -> ToolPolicy {
681 self.config
682 .mcp
683 .providers
684 .get(provider)
685 .and_then(|policy| policy.tools.get(tool))
686 .cloned()
687 .unwrap_or(ToolPolicy::Prompt)
688 }
689
690 pub fn set_mcp_tool_policy(
692 &mut self,
693 provider: &str,
694 tool: &str,
695 policy: ToolPolicy,
696 ) -> Result<()> {
697 let entry = self
698 .config
699 .mcp
700 .providers
701 .entry(provider.to_string())
702 .or_insert_with(McpProviderPolicy::default);
703 entry.tools.insert(tool.to_string(), policy);
704 self.save_config()
705 }
706
707 pub fn mcp_allowlist(&self) -> &McpAllowListConfig {
709 &self.config.mcp.allowlist
710 }
711
712 pub fn set_mcp_allowlist(&mut self, allowlist: McpAllowListConfig) -> Result<()> {
714 self.config.mcp.allowlist = allowlist;
715 self.save_config()
716 }
717
718 pub fn get_policy(&self, tool_name: &str) -> ToolPolicy {
720 if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
721 return self.get_mcp_tool_policy(&provider, &tool);
722 }
723
724 self.config
725 .policies
726 .get(tool_name)
727 .cloned()
728 .unwrap_or(ToolPolicy::Prompt)
729 }
730
731 pub fn get_constraints(&self, tool_name: &str) -> Option<&ToolConstraints> {
733 self.config.constraints.get(tool_name)
734 }
735
736 pub fn should_execute_tool(&mut self, tool_name: &str) -> Result<bool> {
738 if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
739 return match self.get_mcp_tool_policy(&provider, &tool) {
740 ToolPolicy::Allow => Ok(true),
741 ToolPolicy::Deny => Ok(false),
742 ToolPolicy::Prompt => {
743 if ToolPolicyManager::is_auto_allow_tool(tool_name) {
744 self.set_mcp_tool_policy(&provider, &tool, ToolPolicy::Allow)?;
745 Ok(true)
746 } else {
747 self.prompt_user_for_tool(tool_name)
748 }
749 }
750 };
751 }
752
753 match self.get_policy(tool_name) {
754 ToolPolicy::Allow => Ok(true),
755 ToolPolicy::Deny => Ok(false),
756 ToolPolicy::Prompt => {
757 if AUTO_ALLOW_TOOLS.contains(&tool_name) {
758 self.set_policy(tool_name, ToolPolicy::Allow)?;
759 return Ok(true);
760 }
761 let should_execute = self.prompt_user_for_tool(tool_name)?;
762 Ok(should_execute)
763 }
764 }
765 }
766
767 pub fn is_auto_allow_tool(tool_name: &str) -> bool {
768 AUTO_ALLOW_TOOLS.contains(&tool_name)
769 }
770
771 fn prompt_user_for_tool(&mut self, tool_name: &str) -> Result<bool> {
773 let interactive = std::io::stdin().is_terminal() && std::io::stdout().is_terminal();
774 let mut renderer = AnsiRenderer::stdout();
775 let banner_style = theme::banner_style();
776
777 if !interactive {
778 let message = format!(
779 "Non-interactive environment detected. Auto-approving '{}' tool.",
780 tool_name
781 );
782 renderer.line_with_style(banner_style, &message)?;
783 return Ok(true);
784 }
785
786 let header = format!("Tool Permission Request: {}", tool_name);
787 renderer.line_with_style(banner_style, &header)?;
788 renderer.line_with_style(
789 banner_style,
790 &format!("The agent wants to use the '{}' tool.", tool_name),
791 )?;
792 renderer.line_with_style(banner_style, "")?;
793 renderer.line_with_style(
794 banner_style,
795 "This decision applies to the current request only.",
796 )?;
797 renderer.line_with_style(
798 banner_style,
799 "Update the policy file or use CLI flags to change the default.",
800 )?;
801 renderer.line_with_style(banner_style, "")?;
802
803 if AUTO_ALLOW_TOOLS.contains(&tool_name) {
804 renderer.line_with_style(
805 banner_style,
806 &format!(
807 "Auto-approving '{}' tool (default trusted tool).",
808 tool_name
809 ),
810 )?;
811 return Ok(true);
812 }
813
814 let rgb = theme::banner_color();
815 let to_ansi_256 = |value: u8| -> u8 {
816 if value < 48 {
817 0
818 } else if value < 114 {
819 1
820 } else {
821 ((value - 35) / 40).min(5)
822 }
823 };
824 let rgb_to_index = |r: u8, g: u8, b: u8| -> u8 {
825 let r_idx = to_ansi_256(r);
826 let g_idx = to_ansi_256(g);
827 let b_idx = to_ansi_256(b);
828 16 + 36 * r_idx + 6 * g_idx + b_idx
829 };
830 let color_index = rgb_to_index(rgb.0, rgb.1, rgb.2);
831 let dialog_color = ConsoleColor::Color256(color_index);
832 let tinted_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
833
834 let mut dialog_theme = ColorfulTheme::default();
835 dialog_theme.prompt_style = tinted_style;
836 dialog_theme.prompt_prefix = style("—".to_string()).for_stderr().fg(dialog_color);
837 dialog_theme.prompt_suffix = style("—".to_string()).for_stderr().fg(dialog_color);
838 dialog_theme.hint_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
839 dialog_theme.defaults_style = dialog_theme.hint_style.clone();
840 dialog_theme.success_prefix = style("✓".to_string()).for_stderr().fg(dialog_color);
841 dialog_theme.success_suffix = style("·".to_string()).for_stderr().fg(dialog_color);
842 dialog_theme.error_prefix = style("✗".to_string()).for_stderr().fg(dialog_color);
843 dialog_theme.error_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
844 dialog_theme.values_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
845
846 let prompt_text = format!("Allow the agent to use '{}'?", tool_name);
847
848 match Confirm::with_theme(&dialog_theme)
849 .with_prompt(prompt_text)
850 .default(false)
851 .interact()
852 {
853 Ok(confirmed) => {
854 let message = if confirmed {
855 format!("✓ Approved: '{}' tool will run now", tool_name)
856 } else {
857 format!("✗ Denied: '{}' tool will not run", tool_name)
858 };
859 let style = if confirmed {
860 MessageStyle::Tool
861 } else {
862 MessageStyle::Error
863 };
864 renderer.line(style, &message)?;
865 Ok(confirmed)
866 }
867 Err(e) => {
868 renderer.line(
869 MessageStyle::Error,
870 &format!("Failed to read confirmation: {}", e),
871 )?;
872 Ok(false)
873 }
874 }
875 }
876
877 pub fn set_policy(&mut self, tool_name: &str, policy: ToolPolicy) -> Result<()> {
879 if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
880 return self.set_mcp_tool_policy(&provider, &tool, policy);
881 }
882
883 self.config.policies.insert(tool_name.to_string(), policy);
884 self.save_config()
885 }
886
887 pub fn reset_all_to_prompt(&mut self) -> Result<()> {
889 for policy in self.config.policies.values_mut() {
890 *policy = ToolPolicy::Prompt;
891 }
892 for provider in self.config.mcp.providers.values_mut() {
893 for policy in provider.tools.values_mut() {
894 *policy = ToolPolicy::Prompt;
895 }
896 }
897 self.save_config()
898 }
899
900 pub fn allow_all_tools(&mut self) -> Result<()> {
902 for policy in self.config.policies.values_mut() {
903 *policy = ToolPolicy::Allow;
904 }
905 for provider in self.config.mcp.providers.values_mut() {
906 for policy in provider.tools.values_mut() {
907 *policy = ToolPolicy::Allow;
908 }
909 }
910 self.save_config()
911 }
912
913 pub fn deny_all_tools(&mut self) -> Result<()> {
915 for policy in self.config.policies.values_mut() {
916 *policy = ToolPolicy::Deny;
917 }
918 for provider in self.config.mcp.providers.values_mut() {
919 for policy in provider.tools.values_mut() {
920 *policy = ToolPolicy::Deny;
921 }
922 }
923 self.save_config()
924 }
925
926 pub fn get_policy_summary(&self) -> IndexMap<String, ToolPolicy> {
928 let mut summary = self.config.policies.clone();
929 for (provider, policy) in &self.config.mcp.providers {
930 for (tool, status) in &policy.tools {
931 summary.insert(format!("mcp::{}::{}", provider, tool), status.clone());
932 }
933 }
934 summary
935 }
936
937 fn save_config(&self) -> Result<()> {
939 Self::write_config(&self.config_path, &self.config)
940 }
941
942 pub fn print_status(&self) {
944 println!("{}", style("Tool Policy Status").cyan().bold());
945 println!("Config file: {}", self.config_path.display());
946 println!();
947
948 let summary = self.get_policy_summary();
949
950 if summary.is_empty() {
951 println!("No tools configured yet.");
952 return;
953 }
954
955 let mut allow_count = 0;
956 let mut prompt_count = 0;
957 let mut deny_count = 0;
958
959 for (tool, policy) in &summary {
960 let (status, color_name) = match policy {
961 ToolPolicy::Allow => {
962 allow_count += 1;
963 ("ALLOW", "green")
964 }
965 ToolPolicy::Prompt => {
966 prompt_count += 1;
967 ("PROMPT", "yellow")
968 }
969 ToolPolicy::Deny => {
970 deny_count += 1;
971 ("DENY", "red")
972 }
973 };
974
975 let status_styled = match color_name {
976 "green" => style(status).green(),
977 "yellow" => style(status).yellow(),
978 "red" => style(status).red(),
979 _ => style(status),
980 };
981
982 println!(
983 " {} {}",
984 style(format!("{:15}", tool)).cyan(),
985 status_styled
986 );
987 }
988
989 println!();
990 println!(
991 "Summary: {} allowed, {} prompt, {} denied",
992 style(allow_count).green(),
993 style(prompt_count).yellow(),
994 style(deny_count).red()
995 );
996 }
997
998 pub fn config_path(&self) -> &Path {
1000 &self.config_path
1001 }
1002}
1003
/// Optional limits attached to a tool. Every field defaults to `None` when
/// absent from the policy file.
///
/// NOTE(review): only `max_response_bytes`, `allowed_url_schemes` and
/// `denied_url_hosts` are written in this file (for the `tools::CURL` entry);
/// the descriptions of the other fields are inferred from their names —
/// confirm against the code that enforces them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolConstraints {
    /// Presumably the modes the tool may be invoked in.
    #[serde(default)]
    pub allowed_modes: Option<Vec<String>>,
    /// Presumably a cap on results returned per call.
    #[serde(default)]
    pub max_results_per_call: Option<usize>,
    /// Presumably a cap on items processed per call.
    #[serde(default)]
    pub max_items_per_call: Option<usize>,
    /// Presumably the response format used when none is requested.
    #[serde(default)]
    pub default_response_format: Option<String>,
    /// Presumably a cap on bytes read per call.
    #[serde(default)]
    pub max_bytes_per_read: Option<usize>,
    /// Cap on the response size in bytes.
    #[serde(default)]
    pub max_response_bytes: Option<usize>,
    /// URL schemes the tool may use (the curl default is `https` only).
    #[serde(default)]
    pub allowed_url_schemes: Option<Vec<String>>,
    /// Hosts / host suffixes the tool must not contact.
    #[serde(default)]
    pub denied_url_hosts: Option<Vec<String>>,
}
1032
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::constants::tools;
    use tempfile::tempdir;

    /// A configuration survives a JSON round trip with its tools and policies intact.
    #[test]
    fn test_tool_policy_config_serialization() {
        let mut config = ToolPolicyConfig::default();
        config.available_tools = vec![tools::READ_FILE.to_string(), tools::WRITE_FILE.to_string()];
        config
            .policies
            .insert(tools::READ_FILE.to_string(), ToolPolicy::Allow);
        config
            .policies
            .insert(tools::WRITE_FILE.to_string(), ToolPolicy::Prompt);

        let json = serde_json::to_string_pretty(&config).unwrap();
        let deserialized: ToolPolicyConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.available_tools, deserialized.available_tools);
        assert_eq!(config.policies, deserialized.policies);
    }

    /// Loading a stored config keeps its policies, injects the auto-allow
    /// tools, and newly discovered tools can be added with a `Prompt` policy.
    #[test]
    fn test_policy_updates() {
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("tool-policy.json");

        let mut config = ToolPolicyConfig::default();
        config.available_tools = vec!["tool1".to_string()];
        config
            .policies
            .insert("tool1".to_string(), ToolPolicy::Prompt);

        let content = serde_json::to_string_pretty(&config).unwrap();
        fs::write(&config_path, content).unwrap();

        let mut loaded_config = ToolPolicyManager::load_or_create_config(&config_path).unwrap();

        // Loading always forces the auto-allow tools to `Allow`.
        for tool in AUTO_ALLOW_TOOLS {
            assert_eq!(loaded_config.policies.get(*tool), Some(&ToolPolicy::Allow));
        }
        assert_eq!(
            loaded_config.policies.get("tool1"),
            Some(&ToolPolicy::Prompt)
        );

        let new_tools = vec!["tool1".to_string(), "tool2".to_string()];
        let current_tools: std::collections::HashSet<_> =
            loaded_config.available_tools.iter().cloned().collect();

        for tool in &new_tools {
            if !current_tools.contains(tool) {
                loaded_config
                    .policies
                    .insert(tool.clone(), ToolPolicy::Prompt);
            }
        }

        loaded_config.available_tools = new_tools;

        // tool1 + tool2, plus the auto-allow tools injected during loading.
        assert_eq!(loaded_config.policies.len(), 2 + AUTO_ALLOW_TOOLS.len());
        assert_eq!(
            loaded_config.policies.get("tool2"),
            Some(&ToolPolicy::Prompt)
        );
    }
}