1use anyhow::{Context, Result};
8use console::{Color as ConsoleColor, Style as ConsoleStyle, style};
9use dialoguer::{Confirm, theme::ColorfulTheme};
10use indexmap::IndexMap;
11use is_terminal::IsTerminal;
12use serde::{Deserialize, Serialize};
13use std::collections::{BTreeMap, HashMap, HashSet};
14use std::fs;
15use std::path::{Path, PathBuf};
16
17use crate::ui::theme;
18use crate::utils::ansi::{AnsiRenderer, MessageStyle};
19
20use crate::config::constants::tools;
21use crate::config::core::tools::{ToolPolicy as ConfigToolPolicy, ToolsConfig};
22use crate::config::mcp::{McpAllowListConfig, McpAllowListRules};
23
/// Built-in tools that are treated as trusted: their policy is forced to
/// `ToolPolicy::Allow` whenever defaults are applied, and a `Prompt` decision
/// for them is promoted to `Allow` on first use.
const AUTO_ALLOW_TOOLS: &[&str] = &["run_terminal_cmd", "bash"];

/// Default upper bound (64 KiB) on the response size seeded into the curl
/// tool's constraints when none is configured.
const DEFAULT_CURL_MAX_RESPONSE_BYTES: usize = 65_536;
26
27#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
29#[serde(rename_all = "lowercase")]
30pub enum ToolPolicy {
31 Allow,
33 Prompt,
35 Deny,
37}
38
39impl Default for ToolPolicy {
40 fn default() -> Self {
41 ToolPolicy::Prompt
42 }
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ToolPolicyConfig {
48 pub version: u32,
50 pub available_tools: Vec<String>,
52 pub policies: IndexMap<String, ToolPolicy>,
54 #[serde(default)]
56 pub constraints: IndexMap<String, ToolConstraints>,
57 #[serde(default)]
59 pub mcp: McpPolicyStore,
60}
61
62impl Default for ToolPolicyConfig {
63 fn default() -> Self {
64 Self {
65 version: 1,
66 available_tools: Vec::new(),
67 policies: IndexMap::new(),
68 constraints: IndexMap::new(),
69 mcp: McpPolicyStore::default(),
70 }
71 }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct McpPolicyStore {
77 #[serde(default = "default_secure_mcp_allowlist")]
79 pub allowlist: McpAllowListConfig,
80 #[serde(default)]
82 pub providers: IndexMap<String, McpProviderPolicy>,
83}
84
85impl Default for McpPolicyStore {
86 fn default() -> Self {
87 Self {
88 allowlist: default_secure_mcp_allowlist(),
89 providers: IndexMap::new(),
90 }
91 }
92}
93
94#[derive(Debug, Clone, Serialize, Deserialize, Default)]
96pub struct McpProviderPolicy {
97 #[serde(default)]
98 pub tools: IndexMap<String, ToolPolicy>,
99}
100
101fn default_secure_mcp_allowlist() -> McpAllowListConfig {
102 let mut allowlist = McpAllowListConfig::default();
103 allowlist.enforce = true;
104
105 allowlist.default.logging = Some(vec![
106 "mcp.provider_initialized".to_string(),
107 "mcp.provider_initialization_failed".to_string(),
108 "mcp.tool_filtered".to_string(),
109 "mcp.tool_execution".to_string(),
110 "mcp.tool_failed".to_string(),
111 "mcp.tool_denied".to_string(),
112 ]);
113
114 allowlist.default.configuration = Some(BTreeMap::from([
115 (
116 "client".to_string(),
117 vec![
118 "max_concurrent_connections".to_string(),
119 "request_timeout_seconds".to_string(),
120 "retry_attempts".to_string(),
121 ],
122 ),
123 (
124 "ui".to_string(),
125 vec![
126 "mode".to_string(),
127 "max_events".to_string(),
128 "show_provider_names".to_string(),
129 ],
130 ),
131 (
132 "server".to_string(),
133 vec![
134 "enabled".to_string(),
135 "bind_address".to_string(),
136 "port".to_string(),
137 "transport".to_string(),
138 "name".to_string(),
139 "version".to_string(),
140 ],
141 ),
142 ]));
143
144 let mut time_rules = McpAllowListRules::default();
145 time_rules.tools = Some(vec![
146 "get_*".to_string(),
147 "list_*".to_string(),
148 "convert_timezone".to_string(),
149 "describe_timezone".to_string(),
150 "time_*".to_string(),
151 ]);
152 time_rules.resources = Some(vec!["timezone:*".to_string(), "location:*".to_string()]);
153 time_rules.logging = Some(vec![
154 "mcp.tool_execution".to_string(),
155 "mcp.tool_failed".to_string(),
156 "mcp.tool_denied".to_string(),
157 "mcp.tool_filtered".to_string(),
158 "mcp.provider_initialized".to_string(),
159 ]);
160 time_rules.configuration = Some(BTreeMap::from([
161 (
162 "provider".to_string(),
163 vec!["max_concurrent_requests".to_string()],
164 ),
165 (
166 "time".to_string(),
167 vec!["local_timezone_override".to_string()],
168 ),
169 ]));
170 allowlist.providers.insert("time".to_string(), time_rules);
171
172 let mut context_rules = McpAllowListRules::default();
173 context_rules.tools = Some(vec![
174 "search_*".to_string(),
175 "fetch_*".to_string(),
176 "list_*".to_string(),
177 "context7_*".to_string(),
178 "get_*".to_string(),
179 ]);
180 context_rules.resources = Some(vec![
181 "docs::*".to_string(),
182 "snippets::*".to_string(),
183 "repositories::*".to_string(),
184 "context7::*".to_string(),
185 ]);
186 context_rules.prompts = Some(vec![
187 "context7::*".to_string(),
188 "support::*".to_string(),
189 "docs::*".to_string(),
190 ]);
191 context_rules.logging = Some(vec![
192 "mcp.tool_execution".to_string(),
193 "mcp.tool_failed".to_string(),
194 "mcp.tool_denied".to_string(),
195 "mcp.tool_filtered".to_string(),
196 "mcp.provider_initialized".to_string(),
197 ]);
198 context_rules.configuration = Some(BTreeMap::from([
199 (
200 "provider".to_string(),
201 vec!["max_concurrent_requests".to_string()],
202 ),
203 (
204 "context7".to_string(),
205 vec![
206 "workspace".to_string(),
207 "search_scope".to_string(),
208 "max_results".to_string(),
209 ],
210 ),
211 ]));
212 allowlist
213 .providers
214 .insert("context7".to_string(), context_rules);
215
216 let mut seq_rules = McpAllowListRules::default();
217 seq_rules.tools = Some(vec![
218 "plan".to_string(),
219 "critique".to_string(),
220 "reflect".to_string(),
221 "decompose".to_string(),
222 "sequential_*".to_string(),
223 ]);
224 seq_rules.prompts = Some(vec![
225 "sequential-thinking::*".to_string(),
226 "plan".to_string(),
227 "reflect".to_string(),
228 "critique".to_string(),
229 ]);
230 seq_rules.logging = Some(vec![
231 "mcp.tool_execution".to_string(),
232 "mcp.tool_failed".to_string(),
233 "mcp.tool_denied".to_string(),
234 "mcp.tool_filtered".to_string(),
235 "mcp.provider_initialized".to_string(),
236 ]);
237 seq_rules.configuration = Some(BTreeMap::from([
238 (
239 "provider".to_string(),
240 vec!["max_concurrent_requests".to_string()],
241 ),
242 (
243 "sequencing".to_string(),
244 vec!["max_depth".to_string(), "max_branches".to_string()],
245 ),
246 ]));
247 allowlist
248 .providers
249 .insert("sequential-thinking".to_string(), seq_rules);
250
251 allowlist
252}
253
/// Splits an `mcp::<provider>::<tool>` policy key into `(provider, tool)`.
///
/// Everything after the second `::` belongs to the tool, so `mcp::a::b::c`
/// yields `("a", "b::c")`. Returns `None` when the key lacks the `mcp::`
/// prefix, has no second separator, or has an empty provider or tool.
fn parse_mcp_policy_key(tool_name: &str) -> Option<(String, String)> {
    let remainder = tool_name.strip_prefix("mcp::")?;
    let (provider, tool) = remainder.split_once("::")?;
    if provider.is_empty() || tool.is_empty() {
        return None;
    }
    Some((provider.to_owned(), tool.to_owned()))
}
263
264#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct AlternativeToolPolicyConfig {
267 pub version: u32,
269 pub default: AlternativeDefaultPolicy,
271 pub tools: IndexMap<String, AlternativeToolPolicy>,
273 #[serde(default)]
275 pub constraints: IndexMap<String, ToolConstraints>,
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct AlternativeDefaultPolicy {
281 pub allow: bool,
283 pub rate_limit_per_run: u32,
285 pub max_concurrent: u32,
287 pub fs_write: bool,
289 pub network: bool,
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
295pub struct AlternativeToolPolicy {
296 pub allow: bool,
298 #[serde(default)]
300 pub fs_write: bool,
301 #[serde(default)]
303 pub network: bool,
304 #[serde(default)]
306 pub args_policy: Option<AlternativeArgsPolicy>,
307}
308
309#[derive(Debug, Clone, Serialize, Deserialize)]
311pub struct AlternativeArgsPolicy {
312 pub deny_substrings: Vec<String>,
314}
315
316#[derive(Clone)]
318pub struct ToolPolicyManager {
319 config_path: PathBuf,
320 config: ToolPolicyConfig,
321}
322
323impl ToolPolicyManager {
324 pub fn new() -> Result<Self> {
326 let config_path = Self::get_config_path()?;
327 let config = Self::load_or_create_config(&config_path)?;
328
329 Ok(Self {
330 config_path,
331 config,
332 })
333 }
334
335 pub fn new_with_workspace(workspace_root: &PathBuf) -> Result<Self> {
337 let config_path = Self::get_workspace_config_path(workspace_root)?;
338 let config = Self::load_or_create_config(&config_path)?;
339
340 Ok(Self {
341 config_path,
342 config,
343 })
344 }
345
346 fn get_config_path() -> Result<PathBuf> {
348 let home_dir = dirs::home_dir().context("Could not determine home directory")?;
349
350 let vtcode_dir = home_dir.join(".vtcode");
351 if !vtcode_dir.exists() {
352 fs::create_dir_all(&vtcode_dir).context("Failed to create ~/.vtcode directory")?;
353 }
354
355 Ok(vtcode_dir.join("tool-policy.json"))
356 }
357
358 fn get_workspace_config_path(workspace_root: &PathBuf) -> Result<PathBuf> {
360 let workspace_vtcode_dir = workspace_root.join(".vtcode");
361
362 if !workspace_vtcode_dir.exists() {
363 fs::create_dir_all(&workspace_vtcode_dir).with_context(|| {
364 format!(
365 "Failed to create workspace policy directory at {}",
366 workspace_vtcode_dir.display()
367 )
368 })?;
369 }
370
371 Ok(workspace_vtcode_dir.join("tool-policy.json"))
372 }
373
374 fn load_or_create_config(config_path: &PathBuf) -> Result<ToolPolicyConfig> {
376 if config_path.exists() {
377 let content =
378 fs::read_to_string(config_path).context("Failed to read tool policy config")?;
379
380 if let Ok(alt_config) = serde_json::from_str::<AlternativeToolPolicyConfig>(&content) {
382 return Ok(Self::convert_from_alternative(alt_config));
384 }
385
386 match serde_json::from_str(&content) {
388 Ok(mut config) => {
389 Self::apply_auto_allow_defaults(&mut config);
390 Self::ensure_network_constraints(&mut config);
391 Ok(config)
392 }
393 Err(parse_err) => {
394 eprintln!(
395 "Warning: Invalid tool policy config at {} ({}). Resetting to defaults.",
396 config_path.display(),
397 parse_err
398 );
399 Self::reset_to_default(config_path)
400 }
401 }
402 } else {
403 let mut config = ToolPolicyConfig::default();
405 Self::apply_auto_allow_defaults(&mut config);
406 Self::ensure_network_constraints(&mut config);
407 Ok(config)
408 }
409 }
410
411 fn apply_auto_allow_defaults(config: &mut ToolPolicyConfig) {
412 for tool in AUTO_ALLOW_TOOLS {
413 config
414 .policies
415 .entry((*tool).to_string())
416 .and_modify(|policy| *policy = ToolPolicy::Allow)
417 .or_insert(ToolPolicy::Allow);
418 if !config.available_tools.contains(&tool.to_string()) {
419 config.available_tools.push(tool.to_string());
420 }
421 }
422 Self::ensure_network_constraints(config);
423 }
424
425 fn ensure_network_constraints(config: &mut ToolPolicyConfig) {
426 let entry = config
427 .constraints
428 .entry(tools::CURL.to_string())
429 .or_insert_with(ToolConstraints::default);
430
431 if entry.max_response_bytes.is_none() {
432 entry.max_response_bytes = Some(DEFAULT_CURL_MAX_RESPONSE_BYTES);
433 }
434 if entry.allowed_url_schemes.is_none() {
435 entry.allowed_url_schemes = Some(vec!["https".to_string()]);
436 }
437 if entry.denied_url_hosts.is_none() {
438 entry.denied_url_hosts = Some(vec![
439 "localhost".to_string(),
440 "127.0.0.1".to_string(),
441 "0.0.0.0".to_string(),
442 "::1".to_string(),
443 ".localhost".to_string(),
444 ".local".to_string(),
445 ".internal".to_string(),
446 ".lan".to_string(),
447 ]);
448 }
449 }
450
451 fn reset_to_default(config_path: &PathBuf) -> Result<ToolPolicyConfig> {
452 let backup_path = config_path.with_extension("json.bak");
453
454 if let Err(err) = fs::rename(config_path, &backup_path) {
455 eprintln!(
456 "Warning: Unable to back up invalid tool policy config ({}). {}",
457 config_path.display(),
458 err
459 );
460 } else {
461 eprintln!(
462 "Backed up invalid tool policy config to {}",
463 backup_path.display()
464 );
465 }
466
467 let default_config = ToolPolicyConfig::default();
468 Self::write_config(config_path.as_path(), &default_config)?;
469 Ok(default_config)
470 }
471
472 fn write_config(path: &Path, config: &ToolPolicyConfig) -> Result<()> {
473 if let Some(parent) = path.parent() {
474 if !parent.exists() {
475 fs::create_dir_all(parent).with_context(|| {
476 format!(
477 "Failed to create directory for tool policy config at {}",
478 parent.display()
479 )
480 })?;
481 }
482 }
483
484 let serialized = serde_json::to_string_pretty(config)
485 .context("Failed to serialize tool policy config")?;
486
487 fs::write(path, serialized)
488 .with_context(|| format!("Failed to write tool policy config: {}", path.display()))
489 }
490
491 fn convert_from_alternative(alt_config: AlternativeToolPolicyConfig) -> ToolPolicyConfig {
493 let mut policies = IndexMap::new();
494
495 for (tool_name, alt_policy) in alt_config.tools {
497 let policy = if alt_policy.allow {
498 ToolPolicy::Allow
499 } else {
500 ToolPolicy::Deny
501 };
502 policies.insert(tool_name, policy);
503 }
504
505 let mut config = ToolPolicyConfig {
506 version: alt_config.version,
507 available_tools: policies.keys().cloned().collect(),
508 policies,
509 constraints: alt_config.constraints,
510 mcp: McpPolicyStore::default(),
511 };
512 Self::apply_auto_allow_defaults(&mut config);
513 config
514 }
515
516 fn apply_config_policy(&mut self, tool_name: &str, policy: ConfigToolPolicy) {
517 let runtime_policy = match policy {
518 ConfigToolPolicy::Allow => ToolPolicy::Allow,
519 ConfigToolPolicy::Prompt => ToolPolicy::Prompt,
520 ConfigToolPolicy::Deny => ToolPolicy::Deny,
521 };
522
523 self.config
524 .policies
525 .insert(tool_name.to_string(), runtime_policy);
526 }
527
528 fn resolve_config_policy(tools_config: &ToolsConfig, tool_name: &str) -> ConfigToolPolicy {
529 if let Some(policy) = tools_config.policies.get(tool_name) {
530 return policy.clone();
531 }
532
533 match tool_name {
534 tools::LIST_FILES => tools_config
535 .policies
536 .get("list_dir")
537 .or_else(|| tools_config.policies.get("list_directory"))
538 .cloned(),
539 _ => None,
540 }
541 .unwrap_or_else(|| tools_config.default_policy.clone())
542 }
543
544 pub fn apply_tools_config(&mut self, tools_config: &ToolsConfig) -> Result<()> {
546 if self.config.available_tools.is_empty() {
547 return Ok(());
548 }
549
550 for tool in self.config.available_tools.clone() {
551 let config_policy = Self::resolve_config_policy(tools_config, &tool);
552 self.apply_config_policy(&tool, config_policy);
553 }
554
555 Self::apply_auto_allow_defaults(&mut self.config);
556 self.save_config()
557 }
558
559 pub fn update_available_tools(&mut self, tools: Vec<String>) -> Result<()> {
561 let current_tools: HashSet<_> = self.config.policies.keys().cloned().collect();
562 let new_tools: HashSet<_> = tools
563 .iter()
564 .filter(|name| !name.starts_with("mcp::"))
565 .cloned()
566 .collect();
567
568 let mut has_changes = false;
569
570 for tool in tools
572 .iter()
573 .filter(|tool| !tool.starts_with("mcp::") && !current_tools.contains(*tool))
574 {
575 let default_policy = if AUTO_ALLOW_TOOLS.contains(&tool.as_str()) {
576 ToolPolicy::Allow
577 } else {
578 ToolPolicy::Prompt
579 };
580 self.config.policies.insert(tool.clone(), default_policy);
581 has_changes = true;
582 }
583
584 let tools_to_remove: Vec<_> = self
586 .config
587 .policies
588 .keys()
589 .filter(|tool| !new_tools.contains(*tool))
590 .cloned()
591 .collect();
592
593 for tool in tools_to_remove {
594 self.config.policies.shift_remove(&tool);
595 has_changes = true;
596 }
597
598 if self.config.available_tools != tools {
600 self.config.available_tools = tools;
602 has_changes = true;
603 }
604
605 Self::ensure_network_constraints(&mut self.config);
606
607 if has_changes {
608 self.save_config()
609 } else {
610 Ok(())
611 }
612 }
613
614 pub fn update_mcp_tools(
616 &mut self,
617 provider_tools: &HashMap<String, Vec<String>>,
618 ) -> Result<()> {
619 let stored_providers: HashSet<String> = self.config.mcp.providers.keys().cloned().collect();
620 let mut has_changes = false;
621
622 for (provider, tools) in provider_tools {
624 let entry = self
625 .config
626 .mcp
627 .providers
628 .entry(provider.clone())
629 .or_insert_with(McpProviderPolicy::default);
630
631 let existing_tools: HashSet<String> = entry.tools.keys().cloned().collect();
632 let advertised: HashSet<String> = tools.iter().cloned().collect();
633
634 for tool in tools {
636 if !existing_tools.contains(tool) {
637 entry.tools.insert(tool.clone(), ToolPolicy::Prompt);
638 has_changes = true;
639 }
640 }
641
642 for stale in existing_tools.difference(&advertised) {
644 entry.tools.shift_remove(stale.as_str());
645 has_changes = true;
646 }
647 }
648
649 let advertised_providers: HashSet<String> = provider_tools.keys().cloned().collect();
651 for provider in stored_providers
652 .difference(&advertised_providers)
653 .cloned()
654 .collect::<Vec<_>>()
655 {
656 self.config.mcp.providers.shift_remove(provider.as_str());
657 has_changes = true;
658 }
659
660 let stale_runtime_keys: Vec<_> = self
662 .config
663 .policies
664 .keys()
665 .filter(|name| name.starts_with("mcp::"))
666 .cloned()
667 .collect();
668
669 for key in stale_runtime_keys {
670 self.config.policies.shift_remove(&key);
671 has_changes = true;
672 }
673
674 let mut available: Vec<String> = self
676 .config
677 .available_tools
678 .iter()
679 .filter(|name| !name.starts_with("mcp::"))
680 .cloned()
681 .collect();
682
683 for (provider, policy) in &self.config.mcp.providers {
684 for tool in policy.tools.keys() {
685 available.push(format!("mcp::{}::{}", provider, tool));
686 }
687 }
688
689 available.sort();
690 available.dedup();
691
692 if self.config.available_tools != available {
694 self.config.available_tools = available;
695 has_changes = true;
696 }
697
698 if has_changes {
699 self.save_config()
700 } else {
701 Ok(())
702 }
703 }
704
705 pub fn get_mcp_tool_policy(&self, provider: &str, tool: &str) -> ToolPolicy {
707 self.config
708 .mcp
709 .providers
710 .get(provider)
711 .and_then(|policy| policy.tools.get(tool))
712 .cloned()
713 .unwrap_or(ToolPolicy::Prompt)
714 }
715
716 pub fn set_mcp_tool_policy(
718 &mut self,
719 provider: &str,
720 tool: &str,
721 policy: ToolPolicy,
722 ) -> Result<()> {
723 let entry = self
724 .config
725 .mcp
726 .providers
727 .entry(provider.to_string())
728 .or_insert_with(McpProviderPolicy::default);
729 entry.tools.insert(tool.to_string(), policy);
730 self.save_config()
731 }
732
733 pub fn mcp_allowlist(&self) -> &McpAllowListConfig {
735 &self.config.mcp.allowlist
736 }
737
738 pub fn set_mcp_allowlist(&mut self, allowlist: McpAllowListConfig) -> Result<()> {
740 self.config.mcp.allowlist = allowlist;
741 self.save_config()
742 }
743
744 pub fn get_policy(&self, tool_name: &str) -> ToolPolicy {
746 if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
747 return self.get_mcp_tool_policy(&provider, &tool);
748 }
749
750 self.config
751 .policies
752 .get(tool_name)
753 .cloned()
754 .unwrap_or(ToolPolicy::Prompt)
755 }
756
757 pub fn get_constraints(&self, tool_name: &str) -> Option<&ToolConstraints> {
759 self.config.constraints.get(tool_name)
760 }
761
762 pub fn should_execute_tool(&mut self, tool_name: &str) -> Result<bool> {
764 if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
765 return match self.get_mcp_tool_policy(&provider, &tool) {
766 ToolPolicy::Allow => Ok(true),
767 ToolPolicy::Deny => Ok(false),
768 ToolPolicy::Prompt => {
769 if ToolPolicyManager::is_auto_allow_tool(tool_name) {
770 self.set_mcp_tool_policy(&provider, &tool, ToolPolicy::Allow)?;
771 Ok(true)
772 } else {
773 self.prompt_user_for_tool(tool_name)
774 }
775 }
776 };
777 }
778
779 match self.get_policy(tool_name) {
780 ToolPolicy::Allow => Ok(true),
781 ToolPolicy::Deny => Ok(false),
782 ToolPolicy::Prompt => {
783 if AUTO_ALLOW_TOOLS.contains(&tool_name) {
784 self.set_policy(tool_name, ToolPolicy::Allow)?;
785 return Ok(true);
786 }
787 let should_execute = self.prompt_user_for_tool(tool_name)?;
788 Ok(should_execute)
789 }
790 }
791 }
792
793 pub fn is_auto_allow_tool(tool_name: &str) -> bool {
794 AUTO_ALLOW_TOOLS.contains(&tool_name)
795 }
796
797 fn prompt_user_for_tool(&mut self, tool_name: &str) -> Result<bool> {
799 let interactive = std::io::stdin().is_terminal() && std::io::stdout().is_terminal();
800 let mut renderer = AnsiRenderer::stdout();
801 let banner_style = theme::banner_style();
802
803 if !interactive {
804 let message = format!(
805 "Non-interactive environment detected. Auto-approving '{}' tool.",
806 tool_name
807 );
808 renderer.line_with_style(banner_style, &message)?;
809 return Ok(true);
810 }
811
812 let header = format!("Tool Permission Request: {}", tool_name);
813 renderer.line_with_style(banner_style, &header)?;
814 renderer.line_with_style(
815 banner_style,
816 &format!("The agent wants to use the '{}' tool.", tool_name),
817 )?;
818 renderer.line_with_style(banner_style, "")?;
819 renderer.line_with_style(
820 banner_style,
821 "This decision applies to the current request only.",
822 )?;
823 renderer.line_with_style(
824 banner_style,
825 "Update the policy file or use CLI flags to change the default.",
826 )?;
827 renderer.line_with_style(banner_style, "")?;
828
829 if AUTO_ALLOW_TOOLS.contains(&tool_name) {
830 renderer.line_with_style(
831 banner_style,
832 &format!(
833 "Auto-approving '{}' tool (default trusted tool).",
834 tool_name
835 ),
836 )?;
837 return Ok(true);
838 }
839
840 let rgb = theme::banner_color();
841 let to_ansi_256 = |value: u8| -> u8 {
842 if value < 48 {
843 0
844 } else if value < 114 {
845 1
846 } else {
847 ((value - 35) / 40).min(5)
848 }
849 };
850 let rgb_to_index = |r: u8, g: u8, b: u8| -> u8 {
851 let r_idx = to_ansi_256(r);
852 let g_idx = to_ansi_256(g);
853 let b_idx = to_ansi_256(b);
854 16 + 36 * r_idx + 6 * g_idx + b_idx
855 };
856 let color_index = rgb_to_index(rgb.0, rgb.1, rgb.2);
857 let dialog_color = ConsoleColor::Color256(color_index);
858 let tinted_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
859
860 let mut dialog_theme = ColorfulTheme::default();
861 dialog_theme.prompt_style = tinted_style;
862 dialog_theme.prompt_prefix = style("—".to_string()).for_stderr().fg(dialog_color);
863 dialog_theme.prompt_suffix = style("—".to_string()).for_stderr().fg(dialog_color);
864 dialog_theme.hint_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
865 dialog_theme.defaults_style = dialog_theme.hint_style.clone();
866 dialog_theme.success_prefix = style("✓".to_string()).for_stderr().fg(dialog_color);
867 dialog_theme.success_suffix = style("·".to_string()).for_stderr().fg(dialog_color);
868 dialog_theme.error_prefix = style("✗".to_string()).for_stderr().fg(dialog_color);
869 dialog_theme.error_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
870 dialog_theme.values_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
871
872 let prompt_text = format!("Allow the agent to use '{}'?", tool_name);
873
874 match Confirm::with_theme(&dialog_theme)
875 .with_prompt(prompt_text)
876 .default(false)
877 .interact()
878 {
879 Ok(confirmed) => {
880 let message = if confirmed {
881 format!("✓ Approved: '{}' tool will run now", tool_name)
882 } else {
883 format!("✗ Denied: '{}' tool will not run", tool_name)
884 };
885 let style = if confirmed {
886 MessageStyle::Tool
887 } else {
888 MessageStyle::Error
889 };
890 renderer.line(style, &message)?;
891 Ok(confirmed)
892 }
893 Err(e) => {
894 renderer.line(
895 MessageStyle::Error,
896 &format!("Failed to read confirmation: {}", e),
897 )?;
898 Ok(false)
899 }
900 }
901 }
902
903 pub fn set_policy(&mut self, tool_name: &str, policy: ToolPolicy) -> Result<()> {
905 if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
906 return self.set_mcp_tool_policy(&provider, &tool, policy);
907 }
908
909 self.config.policies.insert(tool_name.to_string(), policy);
910 self.save_config()
911 }
912
913 pub fn reset_all_to_prompt(&mut self) -> Result<()> {
915 for policy in self.config.policies.values_mut() {
916 *policy = ToolPolicy::Prompt;
917 }
918 for provider in self.config.mcp.providers.values_mut() {
919 for policy in provider.tools.values_mut() {
920 *policy = ToolPolicy::Prompt;
921 }
922 }
923 self.save_config()
924 }
925
926 pub fn allow_all_tools(&mut self) -> Result<()> {
928 for policy in self.config.policies.values_mut() {
929 *policy = ToolPolicy::Allow;
930 }
931 for provider in self.config.mcp.providers.values_mut() {
932 for policy in provider.tools.values_mut() {
933 *policy = ToolPolicy::Allow;
934 }
935 }
936 self.save_config()
937 }
938
939 pub fn deny_all_tools(&mut self) -> Result<()> {
941 for policy in self.config.policies.values_mut() {
942 *policy = ToolPolicy::Deny;
943 }
944 for provider in self.config.mcp.providers.values_mut() {
945 for policy in provider.tools.values_mut() {
946 *policy = ToolPolicy::Deny;
947 }
948 }
949 self.save_config()
950 }
951
952 pub fn get_policy_summary(&self) -> IndexMap<String, ToolPolicy> {
954 let mut summary = self.config.policies.clone();
955 for (provider, policy) in &self.config.mcp.providers {
956 for (tool, status) in &policy.tools {
957 summary.insert(format!("mcp::{}::{}", provider, tool), status.clone());
958 }
959 }
960 summary
961 }
962
963 fn save_config(&self) -> Result<()> {
965 Self::write_config(&self.config_path, &self.config)
966 }
967
968 pub fn print_status(&self) {
970 println!("{}", style("Tool Policy Status").cyan().bold());
971 println!("Config file: {}", self.config_path.display());
972 println!();
973
974 let summary = self.get_policy_summary();
975
976 if summary.is_empty() {
977 println!("No tools configured yet.");
978 return;
979 }
980
981 let mut allow_count = 0;
982 let mut prompt_count = 0;
983 let mut deny_count = 0;
984
985 for (tool, policy) in &summary {
986 let (status, color_name) = match policy {
987 ToolPolicy::Allow => {
988 allow_count += 1;
989 ("ALLOW", "green")
990 }
991 ToolPolicy::Prompt => {
992 prompt_count += 1;
993 ("PROMPT", "yellow")
994 }
995 ToolPolicy::Deny => {
996 deny_count += 1;
997 ("DENY", "red")
998 }
999 };
1000
1001 let status_styled = match color_name {
1002 "green" => style(status).green(),
1003 "yellow" => style(status).yellow(),
1004 "red" => style(status).red(),
1005 _ => style(status),
1006 };
1007
1008 println!(
1009 " {} {}",
1010 style(format!("{:15}", tool)).cyan(),
1011 status_styled
1012 );
1013 }
1014
1015 println!();
1016 println!(
1017 "Summary: {} allowed, {} prompt, {} denied",
1018 style(allow_count).green(),
1019 style(prompt_count).yellow(),
1020 style(deny_count).red()
1021 );
1022 }
1023
1024 pub fn config_path(&self) -> &Path {
1026 &self.config_path
1027 }
1028}
1029
1030#[derive(Debug, Clone, Default, Serialize, Deserialize)]
1032pub struct ToolConstraints {
1033 #[serde(default)]
1035 pub allowed_modes: Option<Vec<String>>,
1036 #[serde(default)]
1038 pub max_results_per_call: Option<usize>,
1039 #[serde(default)]
1041 pub max_items_per_call: Option<usize>,
1042 #[serde(default)]
1044 pub default_response_format: Option<String>,
1045 #[serde(default)]
1047 pub max_bytes_per_read: Option<usize>,
1048 #[serde(default)]
1050 pub max_response_bytes: Option<usize>,
1051 #[serde(default)]
1053 pub allowed_url_schemes: Option<Vec<String>>,
1054 #[serde(default)]
1056 pub denied_url_hosts: Option<Vec<String>>,
1057}
1058
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::constants::tools;
    use tempfile::tempdir;

    #[test]
    fn test_tool_policy_config_serialization() {
        let mut config = ToolPolicyConfig::default();
        config.available_tools = vec![tools::READ_FILE.to_string(), tools::WRITE_FILE.to_string()];
        config
            .policies
            .insert(tools::READ_FILE.to_string(), ToolPolicy::Allow);
        config
            .policies
            .insert(tools::WRITE_FILE.to_string(), ToolPolicy::Prompt);

        let json = serde_json::to_string_pretty(&config).unwrap();
        let deserialized: ToolPolicyConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.available_tools, deserialized.available_tools);
        assert_eq!(config.policies, deserialized.policies);
    }

    #[test]
    fn test_policy_updates() {
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("tool-policy.json");

        let mut config = ToolPolicyConfig::default();
        config.available_tools = vec!["tool1".to_string()];
        config
            .policies
            .insert("tool1".to_string(), ToolPolicy::Prompt);

        let content = serde_json::to_string_pretty(&config).unwrap();
        fs::write(&config_path, content).unwrap();

        let mut loaded_config = ToolPolicyManager::load_or_create_config(&config_path).unwrap();

        // Loading applies the auto-allow defaults, so the trusted tools are
        // already present next to `tool1`. The old `len() == 2` assertion
        // ignored them and always failed.
        for tool in AUTO_ALLOW_TOOLS {
            assert_eq!(loaded_config.policies.get(*tool), Some(&ToolPolicy::Allow));
        }

        let new_tools = vec!["tool1".to_string(), "tool2".to_string()];
        for tool in &new_tools {
            loaded_config
                .policies
                .entry(tool.clone())
                .or_insert(ToolPolicy::Prompt);
        }

        assert_eq!(
            loaded_config.policies.len(),
            new_tools.len() + AUTO_ALLOW_TOOLS.len()
        );
        assert_eq!(
            loaded_config.policies.get("tool1"),
            Some(&ToolPolicy::Prompt)
        );
        assert_eq!(
            loaded_config.policies.get("tool2"),
            Some(&ToolPolicy::Prompt)
        );
    }

    #[test]
    fn test_new_config_seeds_curl_constraints() {
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("tool-policy.json");

        let config = ToolPolicyManager::load_or_create_config(&config_path).unwrap();
        let curl = config
            .constraints
            .get(tools::CURL)
            .expect("curl constraints should be seeded");

        assert_eq!(curl.max_response_bytes, Some(DEFAULT_CURL_MAX_RESPONSE_BYTES));
        assert_eq!(curl.allowed_url_schemes, Some(vec!["https".to_string()]));
        assert!(curl
            .denied_url_hosts
            .as_ref()
            .map_or(false, |hosts| hosts.iter().any(|host| host == "localhost")));
    }

    #[test]
    fn test_invalid_config_is_backed_up_and_replaced() {
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("tool-policy.json");
        fs::write(&config_path, "{ not json").unwrap();

        let config = ToolPolicyManager::load_or_create_config(&config_path).unwrap();

        assert_eq!(config.version, 1);
        assert!(config_path.with_extension("json.bak").exists());
        assert!(config_path.exists());
    }

    #[test]
    fn test_parse_mcp_policy_key() {
        assert_eq!(
            parse_mcp_policy_key("mcp::time::get_time"),
            Some(("time".to_string(), "get_time".to_string()))
        );
        assert_eq!(
            parse_mcp_policy_key("mcp::ctx::a::b"),
            Some(("ctx".to_string(), "a::b".to_string()))
        );
        assert_eq!(parse_mcp_policy_key("mcp::time"), None);
        assert_eq!(parse_mcp_policy_key("mcp::::tool"), None);
        assert_eq!(parse_mcp_policy_key("read_file"), None);
    }

    #[test]
    fn test_mcp_policies_route_to_mcp_store() {
        let dir = tempdir().unwrap();
        let mut manager = ToolPolicyManager {
            config_path: dir.path().join("tool-policy.json"),
            config: ToolPolicyConfig::default(),
        };

        manager
            .set_policy("mcp::time::get_time", ToolPolicy::Allow)
            .unwrap();

        assert_eq!(manager.get_policy("mcp::time::get_time"), ToolPolicy::Allow);
        assert_eq!(
            manager.get_mcp_tool_policy("time", "get_time"),
            ToolPolicy::Allow
        );
        assert!(!manager
            .config
            .policies
            .contains_key("mcp::time::get_time"));
        assert_eq!(manager.get_policy("mcp::time::unknown"), ToolPolicy::Prompt);
    }
}