// vtcode_core/tool_policy.rs

//! Tool policy management system
//!
//! This module manages user preferences for tool usage, storing choices in
//! `~/.vtcode/tool-policy.json` (or `<workspace>/.vtcode/tool-policy.json` for
//! workspace-scoped managers) to minimize repeated prompts while maintaining
//! user control over which tools the agent can use.
6
7use anyhow::{Context, Result};
8use dialoguer::{
9    Confirm,
10    console::{Color as ConsoleColor, Style as ConsoleStyle, style},
11    theme::ColorfulTheme,
12};
13use indexmap::IndexMap;
14use is_terminal::IsTerminal;
15use serde::{Deserialize, Serialize};
16use std::collections::{BTreeMap, HashMap, HashSet};
17use std::fs;
18use std::path::{Path, PathBuf};
19
20use crate::ui::theme;
21use crate::utils::ansi::{AnsiRenderer, MessageStyle};
22
23use crate::config::constants::tools;
24use crate::config::core::tools::{ToolPolicy as ConfigToolPolicy, ToolsConfig};
25use crate::config::mcp::{McpAllowListConfig, McpAllowListRules};
26
/// Built-in tools that are treated as trusted and granted [`ToolPolicy::Allow`].
///
/// `apply_auto_allow_defaults` forces each of these entries to `Allow`, overwriting
/// any stored value. `should_execute_tool` upgrades them from `Prompt` to `Allow`
/// instead of asking. The order matters because it is the order in which missing
/// names are appended to `available_tools`.
///
/// NOTE(review): judging by the names, this list includes shell execution
/// (`RUN_TERMINAL_CMD`, `BASH`) and file editing (`EDIT_FILE`), so by default those
/// run without a confirmation prompt. Confirm that this is intended.
const AUTO_ALLOW_TOOLS: &[&str] = &[
    tools::GREP_SEARCH,
    tools::LIST_FILES,
    tools::UPDATE_PLAN,
    tools::RUN_TERMINAL_CMD,
    tools::READ_FILE,
    tools::EDIT_FILE,
    tools::AST_GREP_SEARCH,
    tools::SIMPLE_SEARCH,
    tools::BASH,
];
/// Response-size cap (64 KiB) that `ensure_network_constraints` installs for `curl`
/// when no `max_response_bytes` is configured.
const DEFAULT_CURL_MAX_RESPONSE_BYTES: usize = 64 * 1024;
39
40/// Tool execution policy
41#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
42#[serde(rename_all = "lowercase")]
43pub enum ToolPolicy {
44    /// Allow tool execution without prompting
45    Allow,
46    /// Prompt user for confirmation each time
47    Prompt,
48    /// Never allow tool execution
49    Deny,
50}
51
52impl Default for ToolPolicy {
53    fn default() -> Self {
54        ToolPolicy::Prompt
55    }
56}
57
/// Persisted tool policy configuration.
///
/// Stored as JSON in `~/.vtcode/tool-policy.json`. When the manager is created with
/// `ToolPolicyManager::new_with_workspace`, it is stored in
/// `<workspace>/.vtcode/tool-policy.json` instead. Fields are serialized in
/// declaration order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolPolicyConfig {
    /// Configuration version for future compatibility
    pub version: u32,
    /// Available tools at time of last update; MCP tools appear as
    /// `mcp::<provider>::<tool>`
    pub available_tools: Vec<String>,
    /// Policy for each built-in tool. MCP tool policies are kept in `mcp.providers`
    /// instead, because `get_policy`/`set_policy` route `mcp::` keys there.
    pub policies: IndexMap<String, ToolPolicy>,
    /// Optional per-tool constraints, keyed by tool name, that scope permissions and
    /// enforce safety. `ensure_network_constraints` fills in the `tools::CURL` entry.
    #[serde(default)]
    pub constraints: IndexMap<String, ToolConstraints>,
    /// MCP-specific policy configuration
    #[serde(default)]
    pub mcp: McpPolicyStore,
}

impl Default for ToolPolicyConfig {
    /// Version 1 with no tools, policies, or constraints, plus the default
    /// (hardened) MCP policy store.
    fn default() -> Self {
        Self {
            version: 1,
            available_tools: Vec::new(),
            policies: IndexMap::new(),
            constraints: IndexMap::new(),
            mcp: McpPolicyStore::default(),
        }
    }
}
86
/// MCP policy state, persisted as the `mcp` field of [`ToolPolicyConfig`] alongside
/// the standard tool policies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpPolicyStore {
    /// Active MCP allow list configuration.
    ///
    /// If the field is missing from the stored JSON, it falls back to
    /// `default_secure_mcp_allowlist()` (enforcement enabled) rather than
    /// `McpAllowListConfig::default()`.
    #[serde(default = "default_secure_mcp_allowlist")]
    pub allowlist: McpAllowListConfig,
    /// Provider-specific tool policies (allow/prompt/deny), keyed by provider name
    #[serde(default)]
    pub providers: IndexMap<String, McpProviderPolicy>,
}

impl Default for McpPolicyStore {
    /// Uses the same hardened allow list as the serde fallback and no provider policies.
    fn default() -> Self {
        Self {
            allowlist: default_secure_mcp_allowlist(),
            providers: IndexMap::new(),
        }
    }
}
106
/// MCP provider policy entry containing per-tool permissions.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct McpProviderPolicy {
    /// Policy per tool name, as advertised by the provider (see `update_mcp_tools`)
    #[serde(default)]
    pub tools: IndexMap<String, ToolPolicy>,
}
113
114fn default_secure_mcp_allowlist() -> McpAllowListConfig {
115    let mut allowlist = McpAllowListConfig::default();
116    allowlist.enforce = true;
117
118    allowlist.default.logging = Some(vec![
119        "mcp.provider_initialized".to_string(),
120        "mcp.provider_initialization_failed".to_string(),
121        "mcp.tool_filtered".to_string(),
122        "mcp.tool_execution".to_string(),
123        "mcp.tool_failed".to_string(),
124        "mcp.tool_denied".to_string(),
125    ]);
126
127    allowlist.default.configuration = Some(BTreeMap::from([
128        (
129            "client".to_string(),
130            vec![
131                "max_concurrent_connections".to_string(),
132                "request_timeout_seconds".to_string(),
133                "retry_attempts".to_string(),
134            ],
135        ),
136        (
137            "ui".to_string(),
138            vec![
139                "mode".to_string(),
140                "max_events".to_string(),
141                "show_provider_names".to_string(),
142            ],
143        ),
144        (
145            "server".to_string(),
146            vec![
147                "enabled".to_string(),
148                "bind_address".to_string(),
149                "port".to_string(),
150                "transport".to_string(),
151                "name".to_string(),
152                "version".to_string(),
153            ],
154        ),
155    ]));
156
157    let mut time_rules = McpAllowListRules::default();
158    time_rules.tools = Some(vec![
159        "get_*".to_string(),
160        "list_*".to_string(),
161        "convert_timezone".to_string(),
162        "describe_timezone".to_string(),
163        "time_*".to_string(),
164    ]);
165    time_rules.resources = Some(vec!["timezone:*".to_string(), "location:*".to_string()]);
166    time_rules.logging = Some(vec![
167        "mcp.tool_execution".to_string(),
168        "mcp.tool_failed".to_string(),
169        "mcp.tool_denied".to_string(),
170        "mcp.tool_filtered".to_string(),
171        "mcp.provider_initialized".to_string(),
172    ]);
173    time_rules.configuration = Some(BTreeMap::from([
174        (
175            "provider".to_string(),
176            vec!["max_concurrent_requests".to_string()],
177        ),
178        (
179            "time".to_string(),
180            vec!["local_timezone_override".to_string()],
181        ),
182    ]));
183    allowlist.providers.insert("time".to_string(), time_rules);
184
185    let mut context_rules = McpAllowListRules::default();
186    context_rules.tools = Some(vec![
187        "search_*".to_string(),
188        "fetch_*".to_string(),
189        "list_*".to_string(),
190        "context7_*".to_string(),
191        "get_*".to_string(),
192    ]);
193    context_rules.resources = Some(vec![
194        "docs::*".to_string(),
195        "snippets::*".to_string(),
196        "repositories::*".to_string(),
197        "context7::*".to_string(),
198    ]);
199    context_rules.prompts = Some(vec![
200        "context7::*".to_string(),
201        "support::*".to_string(),
202        "docs::*".to_string(),
203    ]);
204    context_rules.logging = Some(vec![
205        "mcp.tool_execution".to_string(),
206        "mcp.tool_failed".to_string(),
207        "mcp.tool_denied".to_string(),
208        "mcp.tool_filtered".to_string(),
209        "mcp.provider_initialized".to_string(),
210    ]);
211    context_rules.configuration = Some(BTreeMap::from([
212        (
213            "provider".to_string(),
214            vec!["max_concurrent_requests".to_string()],
215        ),
216        (
217            "context7".to_string(),
218            vec![
219                "workspace".to_string(),
220                "search_scope".to_string(),
221                "max_results".to_string(),
222            ],
223        ),
224    ]));
225    allowlist
226        .providers
227        .insert("context7".to_string(), context_rules);
228
229    let mut seq_rules = McpAllowListRules::default();
230    seq_rules.tools = Some(vec![
231        "plan".to_string(),
232        "critique".to_string(),
233        "reflect".to_string(),
234        "decompose".to_string(),
235        "sequential_*".to_string(),
236    ]);
237    seq_rules.prompts = Some(vec![
238        "sequential-thinking::*".to_string(),
239        "plan".to_string(),
240        "reflect".to_string(),
241        "critique".to_string(),
242    ]);
243    seq_rules.logging = Some(vec![
244        "mcp.tool_execution".to_string(),
245        "mcp.tool_failed".to_string(),
246        "mcp.tool_denied".to_string(),
247        "mcp.tool_filtered".to_string(),
248        "mcp.provider_initialized".to_string(),
249    ]);
250    seq_rules.configuration = Some(BTreeMap::from([
251        (
252            "provider".to_string(),
253            vec!["max_concurrent_requests".to_string()],
254        ),
255        (
256            "sequencing".to_string(),
257            vec!["max_depth".to_string(), "max_branches".to_string()],
258        ),
259    ]));
260    allowlist
261        .providers
262        .insert("sequential-thinking".to_string(), seq_rules);
263
264    allowlist
265}
266
/// Split an MCP policy key of the form `mcp::<provider>::<tool>` into
/// `(provider, tool)`.
///
/// Only the first `::` after the `mcp::` prefix separates provider from tool, so the
/// tool part may itself contain `::`. Returns `None` when the prefix is missing, no
/// separator follows it, or either component is empty.
fn parse_mcp_policy_key(tool_name: &str) -> Option<(String, String)> {
    let remainder = tool_name.strip_prefix("mcp::")?;
    let (provider, tool) = remainder.split_once("::")?;
    if provider.is_empty() || tool.is_empty() {
        return None;
    }
    Some((provider.to_owned(), tool.to_owned()))
}
276
/// Alternative ("user") tool policy file format.
///
/// `load_or_create_config` tries this format before the standard
/// [`ToolPolicyConfig`] layout, and `convert_from_alternative` turns it into the
/// standard form, resetting the MCP store to its defaults. `version`, `default`, and
/// `tools` have no serde default, so a file must contain all three to match this
/// format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeToolPolicyConfig {
    /// Configuration version for future compatibility
    pub version: u32,
    /// Default policy settings
    ///
    /// NOTE(review): `convert_from_alternative` never reads this field, so these
    /// defaults do not affect the converted config.
    pub default: AlternativeDefaultPolicy,
    /// Tool-specific policies keyed by tool name (`allow: true` becomes `Allow`,
    /// otherwise `Deny`)
    pub tools: IndexMap<String, AlternativeToolPolicy>,
    /// Optional per-tool constraints (ignored if absent); copied verbatim on conversion
    #[serde(default)]
    pub constraints: IndexMap<String, ToolConstraints>,
}
290
/// Default policy in alternative format.
///
/// Every field is required when deserializing (none has a serde default).
/// NOTE(review): none of these values is consulted by `convert_from_alternative`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeDefaultPolicy {
    /// Whether to allow by default
    pub allow: bool,
    /// Rate limit per run
    pub rate_limit_per_run: u32,
    /// Max concurrent executions
    pub max_concurrent: u32,
    /// Allow filesystem writes
    pub fs_write: bool,
    /// Allow network access
    pub network: bool,
}
305
/// Tool policy in alternative format.
///
/// NOTE(review): only `allow` is used by `convert_from_alternative`; `fs_write`,
/// `network`, and `args_policy` are parsed but not carried into the standard config.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeToolPolicy {
    /// Whether to allow this tool (`true` → `Allow`, `false` → `Deny`)
    pub allow: bool,
    /// Allow filesystem writes (optional, defaults to `false`)
    #[serde(default)]
    pub fs_write: bool,
    /// Allow network access (optional, defaults to `false`)
    #[serde(default)]
    pub network: bool,
    /// Arguments policy (optional)
    #[serde(default)]
    pub args_policy: Option<AlternativeArgsPolicy>,
}
321
/// Arguments policy in alternative format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeArgsPolicy {
    /// Substrings to deny (required whenever an `args_policy` object is present)
    pub deny_substrings: Vec<String>,
}
328
/// Tool policy manager.
///
/// Holds the in-memory [`ToolPolicyConfig`] together with the path of the JSON file
/// it was loaded from. Mutating methods call `save_config` to persist their changes.
#[derive(Clone)]
pub struct ToolPolicyManager {
    /// Location of the backing `tool-policy.json` file
    config_path: PathBuf,
    /// In-memory copy of the policy configuration
    config: ToolPolicyConfig,
}
335
336impl ToolPolicyManager {
337    /// Create a new tool policy manager
338    pub fn new() -> Result<Self> {
339        let config_path = Self::get_config_path()?;
340        let config = Self::load_or_create_config(&config_path)?;
341
342        Ok(Self {
343            config_path,
344            config,
345        })
346    }
347
348    /// Create a new tool policy manager with workspace-specific config
349    pub fn new_with_workspace(workspace_root: &PathBuf) -> Result<Self> {
350        let config_path = Self::get_workspace_config_path(workspace_root)?;
351        let config = Self::load_or_create_config(&config_path)?;
352
353        Ok(Self {
354            config_path,
355            config,
356        })
357    }
358
359    /// Get the path to the tool policy configuration file
360    fn get_config_path() -> Result<PathBuf> {
361        let home_dir = dirs::home_dir().context("Could not determine home directory")?;
362
363        let vtcode_dir = home_dir.join(".vtcode");
364        if !vtcode_dir.exists() {
365            fs::create_dir_all(&vtcode_dir).context("Failed to create ~/.vtcode directory")?;
366        }
367
368        Ok(vtcode_dir.join("tool-policy.json"))
369    }
370
371    /// Get the path to the workspace-specific tool policy configuration file
372    fn get_workspace_config_path(workspace_root: &PathBuf) -> Result<PathBuf> {
373        let workspace_vtcode_dir = workspace_root.join(".vtcode");
374
375        if !workspace_vtcode_dir.exists() {
376            fs::create_dir_all(&workspace_vtcode_dir).with_context(|| {
377                format!(
378                    "Failed to create workspace policy directory at {}",
379                    workspace_vtcode_dir.display()
380                )
381            })?;
382        }
383
384        Ok(workspace_vtcode_dir.join("tool-policy.json"))
385    }
386
387    /// Load existing config or create new one with all tools as "prompt"
388    fn load_or_create_config(config_path: &PathBuf) -> Result<ToolPolicyConfig> {
389        if config_path.exists() {
390            let content =
391                fs::read_to_string(config_path).context("Failed to read tool policy config")?;
392
393            // Try to parse as alternative format first
394            if let Ok(alt_config) = serde_json::from_str::<AlternativeToolPolicyConfig>(&content) {
395                // Convert alternative format to standard format
396                return Ok(Self::convert_from_alternative(alt_config));
397            }
398
399            // Fall back to standard format with graceful recovery on parse errors
400            match serde_json::from_str(&content) {
401                Ok(mut config) => {
402                    Self::apply_auto_allow_defaults(&mut config);
403                    Self::ensure_network_constraints(&mut config);
404                    Ok(config)
405                }
406                Err(parse_err) => {
407                    eprintln!(
408                        "Warning: Invalid tool policy config at {} ({}). Resetting to defaults.",
409                        config_path.display(),
410                        parse_err
411                    );
412                    Self::reset_to_default(config_path)
413                }
414            }
415        } else {
416            // Create new config with empty tools list
417            let mut config = ToolPolicyConfig::default();
418            Self::apply_auto_allow_defaults(&mut config);
419            Self::ensure_network_constraints(&mut config);
420            Ok(config)
421        }
422    }
423
424    fn apply_auto_allow_defaults(config: &mut ToolPolicyConfig) {
425        for tool in AUTO_ALLOW_TOOLS {
426            config
427                .policies
428                .entry((*tool).to_string())
429                .and_modify(|policy| *policy = ToolPolicy::Allow)
430                .or_insert(ToolPolicy::Allow);
431            if !config.available_tools.contains(&tool.to_string()) {
432                config.available_tools.push(tool.to_string());
433            }
434        }
435        Self::ensure_network_constraints(config);
436    }
437
438    fn ensure_network_constraints(config: &mut ToolPolicyConfig) {
439        let entry = config
440            .constraints
441            .entry(tools::CURL.to_string())
442            .or_insert_with(ToolConstraints::default);
443
444        if entry.max_response_bytes.is_none() {
445            entry.max_response_bytes = Some(DEFAULT_CURL_MAX_RESPONSE_BYTES);
446        }
447        if entry.allowed_url_schemes.is_none() {
448            entry.allowed_url_schemes = Some(vec!["https".to_string()]);
449        }
450        if entry.denied_url_hosts.is_none() {
451            entry.denied_url_hosts = Some(vec![
452                "localhost".to_string(),
453                "127.0.0.1".to_string(),
454                "0.0.0.0".to_string(),
455                "::1".to_string(),
456                ".localhost".to_string(),
457                ".local".to_string(),
458                ".internal".to_string(),
459                ".lan".to_string(),
460            ]);
461        }
462    }
463
464    fn reset_to_default(config_path: &PathBuf) -> Result<ToolPolicyConfig> {
465        let backup_path = config_path.with_extension("json.bak");
466
467        if let Err(err) = fs::rename(config_path, &backup_path) {
468            eprintln!(
469                "Warning: Unable to back up invalid tool policy config ({}). {}",
470                config_path.display(),
471                err
472            );
473        } else {
474            eprintln!(
475                "Backed up invalid tool policy config to {}",
476                backup_path.display()
477            );
478        }
479
480        let default_config = ToolPolicyConfig::default();
481        Self::write_config(config_path.as_path(), &default_config)?;
482        Ok(default_config)
483    }
484
485    fn write_config(path: &Path, config: &ToolPolicyConfig) -> Result<()> {
486        if let Some(parent) = path.parent() {
487            if !parent.exists() {
488                fs::create_dir_all(parent).with_context(|| {
489                    format!(
490                        "Failed to create directory for tool policy config at {}",
491                        parent.display()
492                    )
493                })?;
494            }
495        }
496
497        let serialized = serde_json::to_string_pretty(config)
498            .context("Failed to serialize tool policy config")?;
499
500        fs::write(path, serialized)
501            .with_context(|| format!("Failed to write tool policy config: {}", path.display()))
502    }
503
504    /// Convert alternative format to standard format
505    fn convert_from_alternative(alt_config: AlternativeToolPolicyConfig) -> ToolPolicyConfig {
506        let mut policies = IndexMap::new();
507
508        // Convert tool policies
509        for (tool_name, alt_policy) in alt_config.tools {
510            let policy = if alt_policy.allow {
511                ToolPolicy::Allow
512            } else {
513                ToolPolicy::Deny
514            };
515            policies.insert(tool_name, policy);
516        }
517
518        let mut config = ToolPolicyConfig {
519            version: alt_config.version,
520            available_tools: policies.keys().cloned().collect(),
521            policies,
522            constraints: alt_config.constraints,
523            mcp: McpPolicyStore::default(),
524        };
525        Self::apply_auto_allow_defaults(&mut config);
526        config
527    }
528
529    fn apply_config_policy(&mut self, tool_name: &str, policy: ConfigToolPolicy) {
530        let runtime_policy = match policy {
531            ConfigToolPolicy::Allow => ToolPolicy::Allow,
532            ConfigToolPolicy::Prompt => ToolPolicy::Prompt,
533            ConfigToolPolicy::Deny => ToolPolicy::Deny,
534        };
535
536        self.config
537            .policies
538            .insert(tool_name.to_string(), runtime_policy);
539    }
540
541    fn resolve_config_policy(tools_config: &ToolsConfig, tool_name: &str) -> ConfigToolPolicy {
542        if let Some(policy) = tools_config.policies.get(tool_name) {
543            return policy.clone();
544        }
545
546        match tool_name {
547            tools::LIST_FILES => tools_config
548                .policies
549                .get("list_dir")
550                .or_else(|| tools_config.policies.get("list_directory"))
551                .cloned(),
552            _ => None,
553        }
554        .unwrap_or_else(|| tools_config.default_policy.clone())
555    }
556
557    /// Apply policies defined in vtcode.toml to the runtime policy manager
558    pub fn apply_tools_config(&mut self, tools_config: &ToolsConfig) -> Result<()> {
559        if self.config.available_tools.is_empty() {
560            return Ok(());
561        }
562
563        for tool in self.config.available_tools.clone() {
564            let config_policy = Self::resolve_config_policy(tools_config, &tool);
565            self.apply_config_policy(&tool, config_policy);
566        }
567
568        Self::apply_auto_allow_defaults(&mut self.config);
569        self.save_config()
570    }
571
572    /// Update the tool list and save configuration
573    pub fn update_available_tools(&mut self, tools: Vec<String>) -> Result<()> {
574        let current_tools: HashSet<_> = self.config.policies.keys().cloned().collect();
575        let new_tools: HashSet<_> = tools
576            .iter()
577            .filter(|name| !name.starts_with("mcp::"))
578            .cloned()
579            .collect();
580
581        let mut has_changes = false;
582
583        // Add new tools with appropriate defaults
584        for tool in tools
585            .iter()
586            .filter(|tool| !tool.starts_with("mcp::") && !current_tools.contains(*tool))
587        {
588            let default_policy = if AUTO_ALLOW_TOOLS.contains(&tool.as_str()) {
589                ToolPolicy::Allow
590            } else {
591                ToolPolicy::Prompt
592            };
593            self.config.policies.insert(tool.clone(), default_policy);
594            has_changes = true;
595        }
596
597        // Remove deleted tools - use itertools to find tools to remove
598        let tools_to_remove: Vec<_> = self
599            .config
600            .policies
601            .keys()
602            .filter(|tool| !new_tools.contains(*tool))
603            .cloned()
604            .collect();
605
606        for tool in tools_to_remove {
607            self.config.policies.shift_remove(&tool);
608            has_changes = true;
609        }
610
611        // Check if available tools list has actually changed
612        if self.config.available_tools != tools {
613            // Update available tools list
614            self.config.available_tools = tools;
615            has_changes = true;
616        }
617
618        Self::ensure_network_constraints(&mut self.config);
619
620        if has_changes {
621            self.save_config()
622        } else {
623            Ok(())
624        }
625    }
626
627    /// Synchronize MCP provider tool lists with persisted policies
628    pub fn update_mcp_tools(
629        &mut self,
630        provider_tools: &HashMap<String, Vec<String>>,
631    ) -> Result<()> {
632        let stored_providers: HashSet<String> = self.config.mcp.providers.keys().cloned().collect();
633        let mut has_changes = false;
634
635        // Update or insert provider entries
636        for (provider, tools) in provider_tools {
637            let entry = self
638                .config
639                .mcp
640                .providers
641                .entry(provider.clone())
642                .or_insert_with(McpProviderPolicy::default);
643
644            let existing_tools: HashSet<String> = entry.tools.keys().cloned().collect();
645            let advertised: HashSet<String> = tools.iter().cloned().collect();
646
647            // Add new tools with default Prompt policy
648            for tool in tools {
649                if !existing_tools.contains(tool) {
650                    entry.tools.insert(tool.clone(), ToolPolicy::Prompt);
651                    has_changes = true;
652                }
653            }
654
655            // Remove tools no longer advertised
656            for stale in existing_tools.difference(&advertised) {
657                entry.tools.shift_remove(stale.as_str());
658                has_changes = true;
659            }
660        }
661
662        // Remove providers that are no longer present
663        let advertised_providers: HashSet<String> = provider_tools.keys().cloned().collect();
664        for provider in stored_providers
665            .difference(&advertised_providers)
666            .cloned()
667            .collect::<Vec<_>>()
668        {
669            self.config.mcp.providers.shift_remove(provider.as_str());
670            has_changes = true;
671        }
672
673        // Remove any stale MCP keys from the primary policy map
674        let stale_runtime_keys: Vec<_> = self
675            .config
676            .policies
677            .keys()
678            .filter(|name| name.starts_with("mcp::"))
679            .cloned()
680            .collect();
681
682        for key in stale_runtime_keys {
683            self.config.policies.shift_remove(&key);
684            has_changes = true;
685        }
686
687        // Refresh available tools list with MCP entries included
688        let mut available: Vec<String> = self
689            .config
690            .available_tools
691            .iter()
692            .filter(|name| !name.starts_with("mcp::"))
693            .cloned()
694            .collect();
695
696        for (provider, policy) in &self.config.mcp.providers {
697            for tool in policy.tools.keys() {
698                available.push(format!("mcp::{}::{}", provider, tool));
699            }
700        }
701
702        available.sort();
703        available.dedup();
704
705        // Check if the available tools list has actually changed
706        if self.config.available_tools != available {
707            self.config.available_tools = available;
708            has_changes = true;
709        }
710
711        if has_changes {
712            self.save_config()
713        } else {
714            Ok(())
715        }
716    }
717
718    /// Retrieve policy for a specific MCP tool
719    pub fn get_mcp_tool_policy(&self, provider: &str, tool: &str) -> ToolPolicy {
720        self.config
721            .mcp
722            .providers
723            .get(provider)
724            .and_then(|policy| policy.tools.get(tool))
725            .cloned()
726            .unwrap_or(ToolPolicy::Prompt)
727    }
728
729    /// Update policy for a specific MCP tool
730    pub fn set_mcp_tool_policy(
731        &mut self,
732        provider: &str,
733        tool: &str,
734        policy: ToolPolicy,
735    ) -> Result<()> {
736        let entry = self
737            .config
738            .mcp
739            .providers
740            .entry(provider.to_string())
741            .or_insert_with(McpProviderPolicy::default);
742        entry.tools.insert(tool.to_string(), policy);
743        self.save_config()
744    }
745
    /// Borrow the persisted MCP allow list configuration.
    pub fn mcp_allowlist(&self) -> &McpAllowListConfig {
        &self.config.mcp.allowlist
    }

    /// Replace the persisted MCP allow list configuration, then call `save_config`.
    ///
    /// # Errors
    /// Propagates any error from `save_config`. The in-memory allow list has already
    /// been replaced at that point.
    pub fn set_mcp_allowlist(&mut self, allowlist: McpAllowListConfig) -> Result<()> {
        self.config.mcp.allowlist = allowlist;
        self.save_config()
    }
756
757    /// Get policy for a specific tool
758    pub fn get_policy(&self, tool_name: &str) -> ToolPolicy {
759        if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
760            return self.get_mcp_tool_policy(&provider, &tool);
761        }
762
763        self.config
764            .policies
765            .get(tool_name)
766            .cloned()
767            .unwrap_or(ToolPolicy::Prompt)
768    }
769
    /// Get the optional constraints configured for `tool_name`.
    ///
    /// Returns `None` when the constraints map has no entry for that name.
    pub fn get_constraints(&self, tool_name: &str) -> Option<&ToolConstraints> {
        self.config.constraints.get(tool_name)
    }
774
775    /// Check if tool should be executed based on policy
776    pub fn should_execute_tool(&mut self, tool_name: &str) -> Result<bool> {
777        if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
778            return match self.get_mcp_tool_policy(&provider, &tool) {
779                ToolPolicy::Allow => Ok(true),
780                ToolPolicy::Deny => Ok(false),
781                ToolPolicy::Prompt => {
782                    if ToolPolicyManager::is_auto_allow_tool(tool_name) {
783                        self.set_mcp_tool_policy(&provider, &tool, ToolPolicy::Allow)?;
784                        Ok(true)
785                    } else {
786                        self.prompt_user_for_tool(tool_name)
787                    }
788                }
789            };
790        }
791
792        match self.get_policy(tool_name) {
793            ToolPolicy::Allow => Ok(true),
794            ToolPolicy::Deny => Ok(false),
795            ToolPolicy::Prompt => {
796                if AUTO_ALLOW_TOOLS.contains(&tool_name) {
797                    self.set_policy(tool_name, ToolPolicy::Allow)?;
798                    return Ok(true);
799                }
800                let should_execute = self.prompt_user_for_tool(tool_name)?;
801                Ok(should_execute)
802            }
803        }
804    }
805
806    pub fn is_auto_allow_tool(tool_name: &str) -> bool {
807        AUTO_ALLOW_TOOLS.contains(&tool_name)
808    }
809
810    /// Prompt user for tool execution permission
811    fn prompt_user_for_tool(&mut self, tool_name: &str) -> Result<bool> {
812        let interactive = std::io::stdin().is_terminal() && std::io::stdout().is_terminal();
813        let mut renderer = AnsiRenderer::stdout();
814        let banner_style = theme::banner_style();
815
816        if !interactive {
817            let message = format!(
818                "Non-interactive environment detected. Auto-approving '{}' tool.",
819                tool_name
820            );
821            renderer.line_with_style(banner_style, &message)?;
822            self.set_policy(tool_name, ToolPolicy::Allow)?;
823            return Ok(true);
824        }
825
826        let header = format!("Tool Permission Request: {}", tool_name);
827        renderer.line_with_style(banner_style, &header)?;
828        renderer.line_with_style(
829            banner_style,
830            &format!("The agent wants to use the '{}' tool.", tool_name),
831        )?;
832        renderer.line_with_style(banner_style, "")?;
833        renderer.line_with_style(
834            banner_style,
835            "This decision applies to the current request only.",
836        )?;
837        renderer.line_with_style(
838            banner_style,
839            "Update the policy file or use CLI flags to change the default.",
840        )?;
841        renderer.line_with_style(banner_style, "")?;
842
843        if AUTO_ALLOW_TOOLS.contains(&tool_name) {
844            renderer.line_with_style(
845                banner_style,
846                &format!(
847                    "Auto-approving '{}' tool (default trusted tool).",
848                    tool_name
849                ),
850            )?;
851            return Ok(true);
852        }
853
854        let rgb = theme::banner_color();
855        let to_ansi_256 = |value: u8| -> u8 {
856            if value < 48 {
857                0
858            } else if value < 114 {
859                1
860            } else {
861                ((value - 35) / 40).min(5)
862            }
863        };
864        let rgb_to_index = |r: u8, g: u8, b: u8| -> u8 {
865            let r_idx = to_ansi_256(r);
866            let g_idx = to_ansi_256(g);
867            let b_idx = to_ansi_256(b);
868            16 + 36 * r_idx + 6 * g_idx + b_idx
869        };
870        let color_index = rgb_to_index(rgb.0, rgb.1, rgb.2);
871        let dialog_color = ConsoleColor::Color256(color_index);
872        let tinted_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
873
874        let mut dialog_theme = ColorfulTheme::default();
875        dialog_theme.prompt_style = tinted_style;
876        dialog_theme.prompt_prefix = style("—".to_string()).for_stderr().fg(dialog_color);
877        dialog_theme.prompt_suffix = style("—".to_string()).for_stderr().fg(dialog_color);
878        dialog_theme.hint_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
879        dialog_theme.defaults_style = dialog_theme.hint_style.clone();
880        dialog_theme.success_prefix = style("✓".to_string()).for_stderr().fg(dialog_color);
881        dialog_theme.success_suffix = style("·".to_string()).for_stderr().fg(dialog_color);
882        dialog_theme.error_prefix = style("✗".to_string()).for_stderr().fg(dialog_color);
883        dialog_theme.error_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
884        dialog_theme.values_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
885
886        let prompt_text = format!("Allow the agent to use '{}'?", tool_name);
887
888        match Confirm::with_theme(&dialog_theme)
889            .with_prompt(prompt_text)
890            .default(false)
891            .interact()
892        {
893            Ok(confirmed) => {
894                let message = if confirmed {
895                    format!("✓ Approved: '{}' tool will run now", tool_name)
896                } else {
897                    format!("✗ Denied: '{}' tool will not run", tool_name)
898                };
899                let style = if confirmed {
900                    MessageStyle::Tool
901                } else {
902                    MessageStyle::Error
903                };
904                renderer.line(style, &message)?;
905                Ok(confirmed)
906            }
907            Err(e) => {
908                renderer.line(
909                    MessageStyle::Error,
910                    &format!("Failed to read confirmation: {}", e),
911                )?;
912                Ok(false)
913            }
914        }
915    }
916
917    /// Set policy for a specific tool
918    pub fn set_policy(&mut self, tool_name: &str, policy: ToolPolicy) -> Result<()> {
919        if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
920            return self.set_mcp_tool_policy(&provider, &tool, policy);
921        }
922
923        self.config.policies.insert(tool_name.to_string(), policy);
924        self.save_config()
925    }
926
927    /// Reset all tools to prompt
928    pub fn reset_all_to_prompt(&mut self) -> Result<()> {
929        for policy in self.config.policies.values_mut() {
930            *policy = ToolPolicy::Prompt;
931        }
932        for provider in self.config.mcp.providers.values_mut() {
933            for policy in provider.tools.values_mut() {
934                *policy = ToolPolicy::Prompt;
935            }
936        }
937        self.save_config()
938    }
939
940    /// Allow all tools
941    pub fn allow_all_tools(&mut self) -> Result<()> {
942        for policy in self.config.policies.values_mut() {
943            *policy = ToolPolicy::Allow;
944        }
945        for provider in self.config.mcp.providers.values_mut() {
946            for policy in provider.tools.values_mut() {
947                *policy = ToolPolicy::Allow;
948            }
949        }
950        self.save_config()
951    }
952
953    /// Deny all tools
954    pub fn deny_all_tools(&mut self) -> Result<()> {
955        for policy in self.config.policies.values_mut() {
956            *policy = ToolPolicy::Deny;
957        }
958        for provider in self.config.mcp.providers.values_mut() {
959            for policy in provider.tools.values_mut() {
960                *policy = ToolPolicy::Deny;
961            }
962        }
963        self.save_config()
964    }
965
966    /// Get summary of current policies
967    pub fn get_policy_summary(&self) -> IndexMap<String, ToolPolicy> {
968        let mut summary = self.config.policies.clone();
969        for (provider, policy) in &self.config.mcp.providers {
970            for (tool, status) in &policy.tools {
971                summary.insert(format!("mcp::{}::{}", provider, tool), status.clone());
972            }
973        }
974        summary
975    }
976
977    /// Save configuration to file
978    fn save_config(&self) -> Result<()> {
979        Self::write_config(&self.config_path, &self.config)
980    }
981
982    /// Print current policy status
983    pub fn print_status(&self) {
984        println!("{}", style("Tool Policy Status").cyan().bold());
985        println!("Config file: {}", self.config_path.display());
986        println!();
987
988        let summary = self.get_policy_summary();
989
990        if summary.is_empty() {
991            println!("No tools configured yet.");
992            return;
993        }
994
995        let mut allow_count = 0;
996        let mut prompt_count = 0;
997        let mut deny_count = 0;
998
999        for (tool, policy) in &summary {
1000            let (status, color_name) = match policy {
1001                ToolPolicy::Allow => {
1002                    allow_count += 1;
1003                    ("ALLOW", "green")
1004                }
1005                ToolPolicy::Prompt => {
1006                    prompt_count += 1;
1007                    ("PROMPT", "yellow")
1008                }
1009                ToolPolicy::Deny => {
1010                    deny_count += 1;
1011                    ("DENY", "red")
1012                }
1013            };
1014
1015            let status_styled = match color_name {
1016                "green" => style(status).green(),
1017                "yellow" => style(status).yellow(),
1018                "red" => style(status).red(),
1019                _ => style(status),
1020            };
1021
1022            println!(
1023                "  {} {}",
1024                style(format!("{:15}", tool)).cyan(),
1025                status_styled
1026            );
1027        }
1028
1029        println!();
1030        println!(
1031            "Summary: {} allowed, {} prompt, {} denied",
1032            style(allow_count).green(),
1033            style(prompt_count).yellow(),
1034            style(deny_count).red()
1035        );
1036    }
1037
1038    /// Expose path of the underlying policy configuration file
1039    pub fn config_path(&self) -> &Path {
1040        &self.config_path
1041    }
1042}
1043
1044/// Scoped, optional constraints for a tool to align with safe defaults
1045#[derive(Debug, Clone, Default, Serialize, Deserialize)]
1046pub struct ToolConstraints {
1047    /// Whitelisted modes for tools that support modes (e.g., 'terminal')
1048    #[serde(default)]
1049    pub allowed_modes: Option<Vec<String>>,
1050    /// Cap on results for list/search-like tools
1051    #[serde(default)]
1052    pub max_results_per_call: Option<usize>,
1053    /// Cap on items scanned for file listing
1054    #[serde(default)]
1055    pub max_items_per_call: Option<usize>,
1056    /// Default response format if unspecified by caller
1057    #[serde(default)]
1058    pub default_response_format: Option<String>,
1059    /// Cap maximum bytes when reading files
1060    #[serde(default)]
1061    pub max_bytes_per_read: Option<usize>,
1062    /// Cap maximum bytes when fetching over the network
1063    #[serde(default)]
1064    pub max_response_bytes: Option<usize>,
1065    /// Allowed URL schemes for network tools
1066    #[serde(default)]
1067    pub allowed_url_schemes: Option<Vec<String>>,
1068    /// Denied URL hosts or suffixes for network tools
1069    #[serde(default)]
1070    pub denied_url_hosts: Option<Vec<String>>,
1071}
1072
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::constants::tools;
    use tempfile::tempdir;

    /// A config with tools and policies survives a JSON round trip unchanged.
    #[test]
    fn test_tool_policy_config_serialization() {
        let mut original = ToolPolicyConfig::default();
        original.available_tools = [tools::READ_FILE, tools::WRITE_FILE]
            .iter()
            .map(|name| name.to_string())
            .collect();
        for (name, policy) in [
            (tools::READ_FILE, ToolPolicy::Allow),
            (tools::WRITE_FILE, ToolPolicy::Prompt),
        ] {
            original.policies.insert(name.to_string(), policy);
        }

        let encoded = serde_json::to_string_pretty(&original).unwrap();
        let decoded: ToolPolicyConfig = serde_json::from_str(&encoded).unwrap();

        assert_eq!(original.available_tools, decoded.available_tools);
        assert_eq!(original.policies, decoded.policies);
    }

    /// A config loaded from disk can be extended with a newly discovered tool
    /// that defaults to `Prompt`.
    #[test]
    fn test_policy_updates() {
        let temp = tempdir().unwrap();
        let policy_file = temp.path().join("tool-policy.json");

        let mut seed = ToolPolicyConfig::default();
        seed.available_tools = vec!["tool1".to_string()];
        seed.policies.insert("tool1".to_string(), ToolPolicy::Prompt);

        // Persist the seed so the loader reads it back from disk.
        let serialized = serde_json::to_string_pretty(&seed).unwrap();
        fs::write(&policy_file, serialized).unwrap();

        let mut loaded = ToolPolicyManager::load_or_create_config(&policy_file).unwrap();

        // Simulate a refreshed tool list: anything not previously known gets Prompt.
        let refreshed = vec!["tool1".to_string(), "tool2".to_string()];
        let known: std::collections::HashSet<&String> = loaded.available_tools.iter().collect();
        let added: Vec<String> = refreshed
            .iter()
            .filter(|candidate| !known.contains(candidate))
            .cloned()
            .collect();

        for tool in added {
            loaded.policies.insert(tool, ToolPolicy::Prompt);
        }
        loaded.available_tools = refreshed;

        assert_eq!(loaded.policies.len(), 2);
        assert_eq!(loaded.policies.get("tool2"), Some(&ToolPolicy::Prompt));
    }
}