vtcode_core/
tool_policy.rs

//! Tool policy management system
//!
//! This module manages user preferences for tool usage, storing choices in
//! `~/.vtcode/tool-policy.json` (or `<workspace>/.vtcode/tool-policy.json` for a
//! workspace-scoped manager) to minimize repeated prompts while maintaining
//! user control over which tools the agent can use.
6
7use anyhow::{Context, Result};
8use console::{Color as ConsoleColor, Style as ConsoleStyle, style};
9use dialoguer::{Confirm, theme::ColorfulTheme};
10use indexmap::IndexMap;
11use is_terminal::IsTerminal;
12use serde::{Deserialize, Serialize};
13use std::collections::{BTreeMap, HashMap, HashSet};
14use std::fs;
15use std::path::{Path, PathBuf};
16
17use crate::ui::theme;
18use crate::utils::ansi::{AnsiRenderer, MessageStyle};
19
20use crate::config::constants::tools;
21use crate::config::core::tools::{ToolPolicy as ConfigToolPolicy, ToolsConfig};
22use crate::config::mcp::{McpAllowListConfig, McpAllowListRules};
23
/// Built-in tools whose policy is forced to `ToolPolicy::Allow` so they never prompt.
const AUTO_ALLOW_TOOLS: &[&str] = &["run_terminal_cmd", "bash"];

/// Default `max_response_bytes` constraint (64 KiB) applied to the curl tool
/// when no explicit limit is configured.
const DEFAULT_CURL_MAX_RESPONSE_BYTES: usize = 65_536;
26
27/// Tool execution policy
28#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
29#[serde(rename_all = "lowercase")]
30pub enum ToolPolicy {
31    /// Allow tool execution without prompting
32    Allow,
33    /// Prompt user for confirmation each time
34    Prompt,
35    /// Never allow tool execution
36    Deny,
37}
38
39impl Default for ToolPolicy {
40    fn default() -> Self {
41        ToolPolicy::Prompt
42    }
43}
44
/// Tool policy configuration persisted as `tool-policy.json`.
///
/// `ToolPolicyManager::new` uses `~/.vtcode/tool-policy.json`;
/// `ToolPolicyManager::new_with_workspace` uses `<workspace>/.vtcode/tool-policy.json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolPolicyConfig {
    /// Configuration version for future compatibility
    pub version: u32,
    /// Available tools at time of last update; may also contain
    /// `mcp::<provider>::<tool>` names added by `update_mcp_tools`
    pub available_tools: Vec<String>,
    /// Policy for each built-in tool; MCP tool policies are kept in `mcp.providers`
    /// and `mcp::` keys are stripped from this map by the update methods
    pub policies: IndexMap<String, ToolPolicy>,
    /// Optional per-tool constraints to scope permissions and enforce safety
    #[serde(default)]
    pub constraints: IndexMap<String, ToolConstraints>,
    /// MCP-specific policy configuration
    #[serde(default)]
    pub mcp: McpPolicyStore,
}
61
62impl Default for ToolPolicyConfig {
63    fn default() -> Self {
64        Self {
65            version: 1,
66            available_tools: Vec::new(),
67            policies: IndexMap::new(),
68            constraints: IndexMap::new(),
69            mcp: McpPolicyStore::default(),
70        }
71    }
72}
73
/// Stored MCP policy state, persisted alongside standard tool policies
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpPolicyStore {
    /// Active MCP allow list configuration; when missing from the stored file it
    /// is filled in by `default_secure_mcp_allowlist()`
    #[serde(default = "default_secure_mcp_allowlist")]
    pub allowlist: McpAllowListConfig,
    /// Provider-specific tool policies (allow/prompt/deny), keyed by provider name;
    /// empty when missing from the stored file
    #[serde(default)]
    pub providers: IndexMap<String, McpProviderPolicy>,
}
84
85impl Default for McpPolicyStore {
86    fn default() -> Self {
87        Self {
88            allowlist: default_secure_mcp_allowlist(),
89            providers: IndexMap::new(),
90        }
91    }
92}
93
/// MCP provider policy entry containing per-tool permissions
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct McpProviderPolicy {
    /// Policy per tool name advertised by the provider; tools first seen by
    /// `ToolPolicyManager::update_mcp_tools` are inserted as `Prompt`
    #[serde(default)]
    pub tools: IndexMap<String, ToolPolicy>,
}
100
101fn default_secure_mcp_allowlist() -> McpAllowListConfig {
102    let mut allowlist = McpAllowListConfig::default();
103    allowlist.enforce = true;
104
105    allowlist.default.logging = Some(vec![
106        "mcp.provider_initialized".to_string(),
107        "mcp.provider_initialization_failed".to_string(),
108        "mcp.tool_filtered".to_string(),
109        "mcp.tool_execution".to_string(),
110        "mcp.tool_failed".to_string(),
111        "mcp.tool_denied".to_string(),
112    ]);
113
114    allowlist.default.configuration = Some(BTreeMap::from([
115        (
116            "client".to_string(),
117            vec![
118                "max_concurrent_connections".to_string(),
119                "request_timeout_seconds".to_string(),
120                "retry_attempts".to_string(),
121            ],
122        ),
123        (
124            "ui".to_string(),
125            vec![
126                "mode".to_string(),
127                "max_events".to_string(),
128                "show_provider_names".to_string(),
129            ],
130        ),
131        (
132            "server".to_string(),
133            vec![
134                "enabled".to_string(),
135                "bind_address".to_string(),
136                "port".to_string(),
137                "transport".to_string(),
138                "name".to_string(),
139                "version".to_string(),
140            ],
141        ),
142    ]));
143
144    let mut time_rules = McpAllowListRules::default();
145    time_rules.tools = Some(vec![
146        "get_*".to_string(),
147        "list_*".to_string(),
148        "convert_timezone".to_string(),
149        "describe_timezone".to_string(),
150        "time_*".to_string(),
151    ]);
152    time_rules.resources = Some(vec!["timezone:*".to_string(), "location:*".to_string()]);
153    time_rules.logging = Some(vec![
154        "mcp.tool_execution".to_string(),
155        "mcp.tool_failed".to_string(),
156        "mcp.tool_denied".to_string(),
157        "mcp.tool_filtered".to_string(),
158        "mcp.provider_initialized".to_string(),
159    ]);
160    time_rules.configuration = Some(BTreeMap::from([
161        (
162            "provider".to_string(),
163            vec!["max_concurrent_requests".to_string()],
164        ),
165        (
166            "time".to_string(),
167            vec!["local_timezone_override".to_string()],
168        ),
169    ]));
170    allowlist.providers.insert("time".to_string(), time_rules);
171
172    let mut context_rules = McpAllowListRules::default();
173    context_rules.tools = Some(vec![
174        "search_*".to_string(),
175        "fetch_*".to_string(),
176        "list_*".to_string(),
177        "context7_*".to_string(),
178        "get_*".to_string(),
179    ]);
180    context_rules.resources = Some(vec![
181        "docs::*".to_string(),
182        "snippets::*".to_string(),
183        "repositories::*".to_string(),
184        "context7::*".to_string(),
185    ]);
186    context_rules.prompts = Some(vec![
187        "context7::*".to_string(),
188        "support::*".to_string(),
189        "docs::*".to_string(),
190    ]);
191    context_rules.logging = Some(vec![
192        "mcp.tool_execution".to_string(),
193        "mcp.tool_failed".to_string(),
194        "mcp.tool_denied".to_string(),
195        "mcp.tool_filtered".to_string(),
196        "mcp.provider_initialized".to_string(),
197    ]);
198    context_rules.configuration = Some(BTreeMap::from([
199        (
200            "provider".to_string(),
201            vec!["max_concurrent_requests".to_string()],
202        ),
203        (
204            "context7".to_string(),
205            vec![
206                "workspace".to_string(),
207                "search_scope".to_string(),
208                "max_results".to_string(),
209            ],
210        ),
211    ]));
212    allowlist
213        .providers
214        .insert("context7".to_string(), context_rules);
215
216    let mut seq_rules = McpAllowListRules::default();
217    seq_rules.tools = Some(vec![
218        "plan".to_string(),
219        "critique".to_string(),
220        "reflect".to_string(),
221        "decompose".to_string(),
222        "sequential_*".to_string(),
223    ]);
224    seq_rules.prompts = Some(vec![
225        "sequential-thinking::*".to_string(),
226        "plan".to_string(),
227        "reflect".to_string(),
228        "critique".to_string(),
229    ]);
230    seq_rules.logging = Some(vec![
231        "mcp.tool_execution".to_string(),
232        "mcp.tool_failed".to_string(),
233        "mcp.tool_denied".to_string(),
234        "mcp.tool_filtered".to_string(),
235        "mcp.provider_initialized".to_string(),
236    ]);
237    seq_rules.configuration = Some(BTreeMap::from([
238        (
239            "provider".to_string(),
240            vec!["max_concurrent_requests".to_string()],
241        ),
242        (
243            "sequencing".to_string(),
244            vec!["max_depth".to_string(), "max_branches".to_string()],
245        ),
246    ]));
247    allowlist
248        .providers
249        .insert("sequential-thinking".to_string(), seq_rules);
250
251    allowlist
252}
253
/// Splits an `mcp::<provider>::<tool>` policy key into `(provider, tool)`.
///
/// Returns `None` unless the key starts with `mcp::` and both the provider and
/// the tool segments are non-empty. Any further `::` separators remain part of
/// the tool segment.
fn parse_mcp_policy_key(tool_name: &str) -> Option<(String, String)> {
    let remainder = tool_name.strip_prefix("mcp::")?;
    let (provider, tool) = remainder.split_once("::")?;
    if provider.is_empty() || tool.is_empty() {
        return None;
    }
    Some((provider.to_owned(), tool.to_owned()))
}
263
/// Alternative tool policy configuration format (user's format)
///
/// `load_or_create_config` tries this shape before the standard one and converts
/// it with `convert_from_alternative`, which carries over `version`, each tool's
/// `allow` flag, and `constraints`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeToolPolicyConfig {
    /// Configuration version for future compatibility
    pub version: u32,
    /// Default policy settings (not read by `convert_from_alternative`)
    pub default: AlternativeDefaultPolicy,
    /// Tool-specific policies
    pub tools: IndexMap<String, AlternativeToolPolicy>,
    /// Optional per-tool constraints (empty when absent)
    #[serde(default)]
    pub constraints: IndexMap<String, ToolConstraints>,
}
277
/// Default policy in alternative format
///
/// NOTE(review): none of these fields are read by `convert_from_alternative`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeDefaultPolicy {
    /// Whether to allow by default
    pub allow: bool,
    /// Rate limit per run
    pub rate_limit_per_run: u32,
    /// Max concurrent executions
    pub max_concurrent: u32,
    /// Allow filesystem writes
    pub fs_write: bool,
    /// Allow network access
    pub network: bool,
}
292
/// Tool policy in alternative format
///
/// When converted, only `allow` is used: `true` becomes `ToolPolicy::Allow` and
/// `false` becomes `ToolPolicy::Deny`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeToolPolicy {
    /// Whether to allow this tool
    pub allow: bool,
    /// Allow filesystem writes (optional, `false` when absent)
    #[serde(default)]
    pub fs_write: bool,
    /// Allow network access (optional, `false` when absent)
    #[serde(default)]
    pub network: bool,
    /// Arguments policy (optional)
    #[serde(default)]
    pub args_policy: Option<AlternativeArgsPolicy>,
}
308
/// Arguments policy in alternative format
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeArgsPolicy {
    /// Substrings to deny (presumably matched against tool arguments — not
    /// enforced by the conversion in this file; confirm with callers)
    pub deny_substrings: Vec<String>,
}
315
316/// Tool policy manager
317#[derive(Clone)]
318pub struct ToolPolicyManager {
319    config_path: PathBuf,
320    config: ToolPolicyConfig,
321}
322
323impl ToolPolicyManager {
324    /// Create a new tool policy manager
325    pub fn new() -> Result<Self> {
326        let config_path = Self::get_config_path()?;
327        let config = Self::load_or_create_config(&config_path)?;
328
329        Ok(Self {
330            config_path,
331            config,
332        })
333    }
334
335    /// Create a new tool policy manager with workspace-specific config
336    pub fn new_with_workspace(workspace_root: &PathBuf) -> Result<Self> {
337        let config_path = Self::get_workspace_config_path(workspace_root)?;
338        let config = Self::load_or_create_config(&config_path)?;
339
340        Ok(Self {
341            config_path,
342            config,
343        })
344    }
345
346    /// Get the path to the tool policy configuration file
347    fn get_config_path() -> Result<PathBuf> {
348        let home_dir = dirs::home_dir().context("Could not determine home directory")?;
349
350        let vtcode_dir = home_dir.join(".vtcode");
351        if !vtcode_dir.exists() {
352            fs::create_dir_all(&vtcode_dir).context("Failed to create ~/.vtcode directory")?;
353        }
354
355        Ok(vtcode_dir.join("tool-policy.json"))
356    }
357
358    /// Get the path to the workspace-specific tool policy configuration file
359    fn get_workspace_config_path(workspace_root: &PathBuf) -> Result<PathBuf> {
360        let workspace_vtcode_dir = workspace_root.join(".vtcode");
361
362        if !workspace_vtcode_dir.exists() {
363            fs::create_dir_all(&workspace_vtcode_dir).with_context(|| {
364                format!(
365                    "Failed to create workspace policy directory at {}",
366                    workspace_vtcode_dir.display()
367                )
368            })?;
369        }
370
371        Ok(workspace_vtcode_dir.join("tool-policy.json"))
372    }
373
374    /// Load existing config or create new one with all tools as "prompt"
375    fn load_or_create_config(config_path: &PathBuf) -> Result<ToolPolicyConfig> {
376        if config_path.exists() {
377            let content =
378                fs::read_to_string(config_path).context("Failed to read tool policy config")?;
379
380            // Try to parse as alternative format first
381            if let Ok(alt_config) = serde_json::from_str::<AlternativeToolPolicyConfig>(&content) {
382                // Convert alternative format to standard format
383                return Ok(Self::convert_from_alternative(alt_config));
384            }
385
386            // Fall back to standard format with graceful recovery on parse errors
387            match serde_json::from_str(&content) {
388                Ok(mut config) => {
389                    Self::apply_auto_allow_defaults(&mut config);
390                    Self::ensure_network_constraints(&mut config);
391                    Ok(config)
392                }
393                Err(parse_err) => {
394                    eprintln!(
395                        "Warning: Invalid tool policy config at {} ({}). Resetting to defaults.",
396                        config_path.display(),
397                        parse_err
398                    );
399                    Self::reset_to_default(config_path)
400                }
401            }
402        } else {
403            // Create new config with empty tools list
404            let mut config = ToolPolicyConfig::default();
405            Self::apply_auto_allow_defaults(&mut config);
406            Self::ensure_network_constraints(&mut config);
407            Ok(config)
408        }
409    }
410
411    fn apply_auto_allow_defaults(config: &mut ToolPolicyConfig) {
412        for tool in AUTO_ALLOW_TOOLS {
413            config
414                .policies
415                .entry((*tool).to_string())
416                .and_modify(|policy| *policy = ToolPolicy::Allow)
417                .or_insert(ToolPolicy::Allow);
418            if !config.available_tools.contains(&tool.to_string()) {
419                config.available_tools.push(tool.to_string());
420            }
421        }
422        Self::ensure_network_constraints(config);
423    }
424
425    fn ensure_network_constraints(config: &mut ToolPolicyConfig) {
426        let entry = config
427            .constraints
428            .entry(tools::CURL.to_string())
429            .or_insert_with(ToolConstraints::default);
430
431        if entry.max_response_bytes.is_none() {
432            entry.max_response_bytes = Some(DEFAULT_CURL_MAX_RESPONSE_BYTES);
433        }
434        if entry.allowed_url_schemes.is_none() {
435            entry.allowed_url_schemes = Some(vec!["https".to_string()]);
436        }
437        if entry.denied_url_hosts.is_none() {
438            entry.denied_url_hosts = Some(vec![
439                "localhost".to_string(),
440                "127.0.0.1".to_string(),
441                "0.0.0.0".to_string(),
442                "::1".to_string(),
443                ".localhost".to_string(),
444                ".local".to_string(),
445                ".internal".to_string(),
446                ".lan".to_string(),
447            ]);
448        }
449    }
450
451    fn reset_to_default(config_path: &PathBuf) -> Result<ToolPolicyConfig> {
452        let backup_path = config_path.with_extension("json.bak");
453
454        if let Err(err) = fs::rename(config_path, &backup_path) {
455            eprintln!(
456                "Warning: Unable to back up invalid tool policy config ({}). {}",
457                config_path.display(),
458                err
459            );
460        } else {
461            eprintln!(
462                "Backed up invalid tool policy config to {}",
463                backup_path.display()
464            );
465        }
466
467        let default_config = ToolPolicyConfig::default();
468        Self::write_config(config_path.as_path(), &default_config)?;
469        Ok(default_config)
470    }
471
472    fn write_config(path: &Path, config: &ToolPolicyConfig) -> Result<()> {
473        if let Some(parent) = path.parent() {
474            if !parent.exists() {
475                fs::create_dir_all(parent).with_context(|| {
476                    format!(
477                        "Failed to create directory for tool policy config at {}",
478                        parent.display()
479                    )
480                })?;
481            }
482        }
483
484        let serialized = serde_json::to_string_pretty(config)
485            .context("Failed to serialize tool policy config")?;
486
487        fs::write(path, serialized)
488            .with_context(|| format!("Failed to write tool policy config: {}", path.display()))
489    }
490
491    /// Convert alternative format to standard format
492    fn convert_from_alternative(alt_config: AlternativeToolPolicyConfig) -> ToolPolicyConfig {
493        let mut policies = IndexMap::new();
494
495        // Convert tool policies
496        for (tool_name, alt_policy) in alt_config.tools {
497            let policy = if alt_policy.allow {
498                ToolPolicy::Allow
499            } else {
500                ToolPolicy::Deny
501            };
502            policies.insert(tool_name, policy);
503        }
504
505        let mut config = ToolPolicyConfig {
506            version: alt_config.version,
507            available_tools: policies.keys().cloned().collect(),
508            policies,
509            constraints: alt_config.constraints,
510            mcp: McpPolicyStore::default(),
511        };
512        Self::apply_auto_allow_defaults(&mut config);
513        config
514    }
515
516    fn apply_config_policy(&mut self, tool_name: &str, policy: ConfigToolPolicy) {
517        let runtime_policy = match policy {
518            ConfigToolPolicy::Allow => ToolPolicy::Allow,
519            ConfigToolPolicy::Prompt => ToolPolicy::Prompt,
520            ConfigToolPolicy::Deny => ToolPolicy::Deny,
521        };
522
523        self.config
524            .policies
525            .insert(tool_name.to_string(), runtime_policy);
526    }
527
528    fn resolve_config_policy(tools_config: &ToolsConfig, tool_name: &str) -> ConfigToolPolicy {
529        if let Some(policy) = tools_config.policies.get(tool_name) {
530            return policy.clone();
531        }
532
533        match tool_name {
534            tools::LIST_FILES => tools_config
535                .policies
536                .get("list_dir")
537                .or_else(|| tools_config.policies.get("list_directory"))
538                .cloned(),
539            _ => None,
540        }
541        .unwrap_or_else(|| tools_config.default_policy.clone())
542    }
543
544    /// Apply policies defined in vtcode.toml to the runtime policy manager
545    pub fn apply_tools_config(&mut self, tools_config: &ToolsConfig) -> Result<()> {
546        if self.config.available_tools.is_empty() {
547            return Ok(());
548        }
549
550        for tool in self.config.available_tools.clone() {
551            let config_policy = Self::resolve_config_policy(tools_config, &tool);
552            self.apply_config_policy(&tool, config_policy);
553        }
554
555        Self::apply_auto_allow_defaults(&mut self.config);
556        self.save_config()
557    }
558
559    /// Update the tool list and save configuration
560    pub fn update_available_tools(&mut self, tools: Vec<String>) -> Result<()> {
561        let current_tools: HashSet<_> = self.config.policies.keys().cloned().collect();
562        let new_tools: HashSet<_> = tools
563            .iter()
564            .filter(|name| !name.starts_with("mcp::"))
565            .cloned()
566            .collect();
567
568        let mut has_changes = false;
569
570        // Add new tools with appropriate defaults
571        for tool in tools
572            .iter()
573            .filter(|tool| !tool.starts_with("mcp::") && !current_tools.contains(*tool))
574        {
575            let default_policy = if AUTO_ALLOW_TOOLS.contains(&tool.as_str()) {
576                ToolPolicy::Allow
577            } else {
578                ToolPolicy::Prompt
579            };
580            self.config.policies.insert(tool.clone(), default_policy);
581            has_changes = true;
582        }
583
584        // Remove deleted tools - use itertools to find tools to remove
585        let tools_to_remove: Vec<_> = self
586            .config
587            .policies
588            .keys()
589            .filter(|tool| !new_tools.contains(*tool))
590            .cloned()
591            .collect();
592
593        for tool in tools_to_remove {
594            self.config.policies.shift_remove(&tool);
595            has_changes = true;
596        }
597
598        // Check if available tools list has actually changed
599        if self.config.available_tools != tools {
600            // Update available tools list
601            self.config.available_tools = tools;
602            has_changes = true;
603        }
604
605        Self::ensure_network_constraints(&mut self.config);
606
607        if has_changes {
608            self.save_config()
609        } else {
610            Ok(())
611        }
612    }
613
614    /// Synchronize MCP provider tool lists with persisted policies
615    pub fn update_mcp_tools(
616        &mut self,
617        provider_tools: &HashMap<String, Vec<String>>,
618    ) -> Result<()> {
619        let stored_providers: HashSet<String> = self.config.mcp.providers.keys().cloned().collect();
620        let mut has_changes = false;
621
622        // Update or insert provider entries
623        for (provider, tools) in provider_tools {
624            let entry = self
625                .config
626                .mcp
627                .providers
628                .entry(provider.clone())
629                .or_insert_with(McpProviderPolicy::default);
630
631            let existing_tools: HashSet<String> = entry.tools.keys().cloned().collect();
632            let advertised: HashSet<String> = tools.iter().cloned().collect();
633
634            // Add new tools with default Prompt policy
635            for tool in tools {
636                if !existing_tools.contains(tool) {
637                    entry.tools.insert(tool.clone(), ToolPolicy::Prompt);
638                    has_changes = true;
639                }
640            }
641
642            // Remove tools no longer advertised
643            for stale in existing_tools.difference(&advertised) {
644                entry.tools.shift_remove(stale.as_str());
645                has_changes = true;
646            }
647        }
648
649        // Remove providers that are no longer present
650        let advertised_providers: HashSet<String> = provider_tools.keys().cloned().collect();
651        for provider in stored_providers
652            .difference(&advertised_providers)
653            .cloned()
654            .collect::<Vec<_>>()
655        {
656            self.config.mcp.providers.shift_remove(provider.as_str());
657            has_changes = true;
658        }
659
660        // Remove any stale MCP keys from the primary policy map
661        let stale_runtime_keys: Vec<_> = self
662            .config
663            .policies
664            .keys()
665            .filter(|name| name.starts_with("mcp::"))
666            .cloned()
667            .collect();
668
669        for key in stale_runtime_keys {
670            self.config.policies.shift_remove(&key);
671            has_changes = true;
672        }
673
674        // Refresh available tools list with MCP entries included
675        let mut available: Vec<String> = self
676            .config
677            .available_tools
678            .iter()
679            .filter(|name| !name.starts_with("mcp::"))
680            .cloned()
681            .collect();
682
683        for (provider, policy) in &self.config.mcp.providers {
684            for tool in policy.tools.keys() {
685                available.push(format!("mcp::{}::{}", provider, tool));
686            }
687        }
688
689        available.sort();
690        available.dedup();
691
692        // Check if the available tools list has actually changed
693        if self.config.available_tools != available {
694            self.config.available_tools = available;
695            has_changes = true;
696        }
697
698        if has_changes {
699            self.save_config()
700        } else {
701            Ok(())
702        }
703    }
704
705    /// Retrieve policy for a specific MCP tool
706    pub fn get_mcp_tool_policy(&self, provider: &str, tool: &str) -> ToolPolicy {
707        self.config
708            .mcp
709            .providers
710            .get(provider)
711            .and_then(|policy| policy.tools.get(tool))
712            .cloned()
713            .unwrap_or(ToolPolicy::Prompt)
714    }
715
716    /// Update policy for a specific MCP tool
717    pub fn set_mcp_tool_policy(
718        &mut self,
719        provider: &str,
720        tool: &str,
721        policy: ToolPolicy,
722    ) -> Result<()> {
723        let entry = self
724            .config
725            .mcp
726            .providers
727            .entry(provider.to_string())
728            .or_insert_with(McpProviderPolicy::default);
729        entry.tools.insert(tool.to_string(), policy);
730        self.save_config()
731    }
732
733    /// Access the persisted MCP allow list configuration
734    pub fn mcp_allowlist(&self) -> &McpAllowListConfig {
735        &self.config.mcp.allowlist
736    }
737
738    /// Replace the persisted MCP allow list configuration
739    pub fn set_mcp_allowlist(&mut self, allowlist: McpAllowListConfig) -> Result<()> {
740        self.config.mcp.allowlist = allowlist;
741        self.save_config()
742    }
743
744    /// Get policy for a specific tool
745    pub fn get_policy(&self, tool_name: &str) -> ToolPolicy {
746        if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
747            return self.get_mcp_tool_policy(&provider, &tool);
748        }
749
750        self.config
751            .policies
752            .get(tool_name)
753            .cloned()
754            .unwrap_or(ToolPolicy::Prompt)
755    }
756
757    /// Get optional constraints for a specific tool
758    pub fn get_constraints(&self, tool_name: &str) -> Option<&ToolConstraints> {
759        self.config.constraints.get(tool_name)
760    }
761
762    /// Check if tool should be executed based on policy
763    pub fn should_execute_tool(&mut self, tool_name: &str) -> Result<bool> {
764        if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
765            return match self.get_mcp_tool_policy(&provider, &tool) {
766                ToolPolicy::Allow => Ok(true),
767                ToolPolicy::Deny => Ok(false),
768                ToolPolicy::Prompt => {
769                    if ToolPolicyManager::is_auto_allow_tool(tool_name) {
770                        self.set_mcp_tool_policy(&provider, &tool, ToolPolicy::Allow)?;
771                        Ok(true)
772                    } else {
773                        self.prompt_user_for_tool(tool_name)
774                    }
775                }
776            };
777        }
778
779        match self.get_policy(tool_name) {
780            ToolPolicy::Allow => Ok(true),
781            ToolPolicy::Deny => Ok(false),
782            ToolPolicy::Prompt => {
783                if AUTO_ALLOW_TOOLS.contains(&tool_name) {
784                    self.set_policy(tool_name, ToolPolicy::Allow)?;
785                    return Ok(true);
786                }
787                let should_execute = self.prompt_user_for_tool(tool_name)?;
788                Ok(should_execute)
789            }
790        }
791    }
792
793    pub fn is_auto_allow_tool(tool_name: &str) -> bool {
794        AUTO_ALLOW_TOOLS.contains(&tool_name)
795    }
796
797    /// Prompt user for tool execution permission
798    fn prompt_user_for_tool(&mut self, tool_name: &str) -> Result<bool> {
799        let interactive = std::io::stdin().is_terminal() && std::io::stdout().is_terminal();
800        let mut renderer = AnsiRenderer::stdout();
801        let banner_style = theme::banner_style();
802
803        if !interactive {
804            let message = format!(
805                "Non-interactive environment detected. Auto-approving '{}' tool.",
806                tool_name
807            );
808            renderer.line_with_style(banner_style, &message)?;
809            return Ok(true);
810        }
811
812        let header = format!("Tool Permission Request: {}", tool_name);
813        renderer.line_with_style(banner_style, &header)?;
814        renderer.line_with_style(
815            banner_style,
816            &format!("The agent wants to use the '{}' tool.", tool_name),
817        )?;
818        renderer.line_with_style(banner_style, "")?;
819        renderer.line_with_style(
820            banner_style,
821            "This decision applies to the current request only.",
822        )?;
823        renderer.line_with_style(
824            banner_style,
825            "Update the policy file or use CLI flags to change the default.",
826        )?;
827        renderer.line_with_style(banner_style, "")?;
828
829        if AUTO_ALLOW_TOOLS.contains(&tool_name) {
830            renderer.line_with_style(
831                banner_style,
832                &format!(
833                    "Auto-approving '{}' tool (default trusted tool).",
834                    tool_name
835                ),
836            )?;
837            return Ok(true);
838        }
839
840        let rgb = theme::banner_color();
841        let to_ansi_256 = |value: u8| -> u8 {
842            if value < 48 {
843                0
844            } else if value < 114 {
845                1
846            } else {
847                ((value - 35) / 40).min(5)
848            }
849        };
850        let rgb_to_index = |r: u8, g: u8, b: u8| -> u8 {
851            let r_idx = to_ansi_256(r);
852            let g_idx = to_ansi_256(g);
853            let b_idx = to_ansi_256(b);
854            16 + 36 * r_idx + 6 * g_idx + b_idx
855        };
856        let color_index = rgb_to_index(rgb.0, rgb.1, rgb.2);
857        let dialog_color = ConsoleColor::Color256(color_index);
858        let tinted_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
859
860        let mut dialog_theme = ColorfulTheme::default();
861        dialog_theme.prompt_style = tinted_style;
862        dialog_theme.prompt_prefix = style("—".to_string()).for_stderr().fg(dialog_color);
863        dialog_theme.prompt_suffix = style("—".to_string()).for_stderr().fg(dialog_color);
864        dialog_theme.hint_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
865        dialog_theme.defaults_style = dialog_theme.hint_style.clone();
866        dialog_theme.success_prefix = style("✓".to_string()).for_stderr().fg(dialog_color);
867        dialog_theme.success_suffix = style("·".to_string()).for_stderr().fg(dialog_color);
868        dialog_theme.error_prefix = style("✗".to_string()).for_stderr().fg(dialog_color);
869        dialog_theme.error_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
870        dialog_theme.values_style = ConsoleStyle::new().for_stderr().fg(dialog_color);
871
872        let prompt_text = format!("Allow the agent to use '{}'?", tool_name);
873
874        match Confirm::with_theme(&dialog_theme)
875            .with_prompt(prompt_text)
876            .default(false)
877            .interact()
878        {
879            Ok(confirmed) => {
880                let message = if confirmed {
881                    format!("✓ Approved: '{}' tool will run now", tool_name)
882                } else {
883                    format!("✗ Denied: '{}' tool will not run", tool_name)
884                };
885                let style = if confirmed {
886                    MessageStyle::Tool
887                } else {
888                    MessageStyle::Error
889                };
890                renderer.line(style, &message)?;
891                Ok(confirmed)
892            }
893            Err(e) => {
894                renderer.line(
895                    MessageStyle::Error,
896                    &format!("Failed to read confirmation: {}", e),
897                )?;
898                Ok(false)
899            }
900        }
901    }
902
903    /// Set policy for a specific tool
904    pub fn set_policy(&mut self, tool_name: &str, policy: ToolPolicy) -> Result<()> {
905        if let Some((provider, tool)) = parse_mcp_policy_key(tool_name) {
906            return self.set_mcp_tool_policy(&provider, &tool, policy);
907        }
908
909        self.config.policies.insert(tool_name.to_string(), policy);
910        self.save_config()
911    }
912
913    /// Reset all tools to prompt
914    pub fn reset_all_to_prompt(&mut self) -> Result<()> {
915        for policy in self.config.policies.values_mut() {
916            *policy = ToolPolicy::Prompt;
917        }
918        for provider in self.config.mcp.providers.values_mut() {
919            for policy in provider.tools.values_mut() {
920                *policy = ToolPolicy::Prompt;
921            }
922        }
923        self.save_config()
924    }
925
926    /// Allow all tools
927    pub fn allow_all_tools(&mut self) -> Result<()> {
928        for policy in self.config.policies.values_mut() {
929            *policy = ToolPolicy::Allow;
930        }
931        for provider in self.config.mcp.providers.values_mut() {
932            for policy in provider.tools.values_mut() {
933                *policy = ToolPolicy::Allow;
934            }
935        }
936        self.save_config()
937    }
938
939    /// Deny all tools
940    pub fn deny_all_tools(&mut self) -> Result<()> {
941        for policy in self.config.policies.values_mut() {
942            *policy = ToolPolicy::Deny;
943        }
944        for provider in self.config.mcp.providers.values_mut() {
945            for policy in provider.tools.values_mut() {
946                *policy = ToolPolicy::Deny;
947            }
948        }
949        self.save_config()
950    }
951
952    /// Get summary of current policies
953    pub fn get_policy_summary(&self) -> IndexMap<String, ToolPolicy> {
954        let mut summary = self.config.policies.clone();
955        for (provider, policy) in &self.config.mcp.providers {
956            for (tool, status) in &policy.tools {
957                summary.insert(format!("mcp::{}::{}", provider, tool), status.clone());
958            }
959        }
960        summary
961    }
962
963    /// Save configuration to file
964    fn save_config(&self) -> Result<()> {
965        Self::write_config(&self.config_path, &self.config)
966    }
967
968    /// Print current policy status
969    pub fn print_status(&self) {
970        println!("{}", style("Tool Policy Status").cyan().bold());
971        println!("Config file: {}", self.config_path.display());
972        println!();
973
974        let summary = self.get_policy_summary();
975
976        if summary.is_empty() {
977            println!("No tools configured yet.");
978            return;
979        }
980
981        let mut allow_count = 0;
982        let mut prompt_count = 0;
983        let mut deny_count = 0;
984
985        for (tool, policy) in &summary {
986            let (status, color_name) = match policy {
987                ToolPolicy::Allow => {
988                    allow_count += 1;
989                    ("ALLOW", "green")
990                }
991                ToolPolicy::Prompt => {
992                    prompt_count += 1;
993                    ("PROMPT", "yellow")
994                }
995                ToolPolicy::Deny => {
996                    deny_count += 1;
997                    ("DENY", "red")
998                }
999            };
1000
1001            let status_styled = match color_name {
1002                "green" => style(status).green(),
1003                "yellow" => style(status).yellow(),
1004                "red" => style(status).red(),
1005                _ => style(status),
1006            };
1007
1008            println!(
1009                "  {} {}",
1010                style(format!("{:15}", tool)).cyan(),
1011                status_styled
1012            );
1013        }
1014
1015        println!();
1016        println!(
1017            "Summary: {} allowed, {} prompt, {} denied",
1018            style(allow_count).green(),
1019            style(prompt_count).yellow(),
1020            style(deny_count).red()
1021        );
1022    }
1023
1024    /// Expose path of the underlying policy configuration file
1025    pub fn config_path(&self) -> &Path {
1026        &self.config_path
1027    }
1028}
1029
1030/// Scoped, optional constraints for a tool to align with safe defaults
1031#[derive(Debug, Clone, Default, Serialize, Deserialize)]
1032pub struct ToolConstraints {
1033    /// Whitelisted modes for tools that support modes (e.g., 'terminal')
1034    #[serde(default)]
1035    pub allowed_modes: Option<Vec<String>>,
1036    /// Cap on results for list/search-like tools
1037    #[serde(default)]
1038    pub max_results_per_call: Option<usize>,
1039    /// Cap on items scanned for file listing
1040    #[serde(default)]
1041    pub max_items_per_call: Option<usize>,
1042    /// Default response format if unspecified by caller
1043    #[serde(default)]
1044    pub default_response_format: Option<String>,
1045    /// Cap maximum bytes when reading files
1046    #[serde(default)]
1047    pub max_bytes_per_read: Option<usize>,
1048    /// Cap maximum bytes when fetching over the network
1049    #[serde(default)]
1050    pub max_response_bytes: Option<usize>,
1051    /// Allowed URL schemes for network tools
1052    #[serde(default)]
1053    pub allowed_url_schemes: Option<Vec<String>>,
1054    /// Denied URL hosts or suffixes for network tools
1055    #[serde(default)]
1056    pub denied_url_hosts: Option<Vec<String>>,
1057}
1058
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::constants::tools;
    use tempfile::tempdir;

    /// Round-trips a populated `ToolPolicyConfig` through JSON.
    #[test]
    fn test_tool_policy_config_serialization() {
        let mut config = ToolPolicyConfig::default();
        config.available_tools = vec![tools::READ_FILE.to_string(), tools::WRITE_FILE.to_string()];
        config
            .policies
            .insert(tools::READ_FILE.to_string(), ToolPolicy::Allow);
        config
            .policies
            .insert(tools::WRITE_FILE.to_string(), ToolPolicy::Prompt);

        let json = serde_json::to_string_pretty(&config).unwrap();
        let deserialized: ToolPolicyConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.available_tools, deserialized.available_tools);
        assert_eq!(config.policies, deserialized.policies);
    }

    /// The manual `Default` impl must keep the conservative `Prompt` policy.
    #[test]
    fn test_tool_policy_default_is_prompt() {
        assert_eq!(ToolPolicy::default(), ToolPolicy::Prompt);
    }

    /// `#[serde(rename_all = "lowercase")]` defines the on-disk spelling of policies.
    #[test]
    fn test_tool_policy_serde_uses_lowercase_names() {
        assert_eq!(serde_json::to_string(&ToolPolicy::Allow).unwrap(), "\"allow\"");
        assert_eq!(serde_json::to_string(&ToolPolicy::Prompt).unwrap(), "\"prompt\"");
        assert_eq!(serde_json::to_string(&ToolPolicy::Deny).unwrap(), "\"deny\"");

        let parsed: ToolPolicy = serde_json::from_str("\"deny\"").unwrap();
        assert_eq!(parsed, ToolPolicy::Deny);
    }

    /// Every `ToolConstraints` field is optional, so an empty object must deserialize.
    #[test]
    fn test_tool_constraints_deserialize_from_empty_object() {
        let constraints: ToolConstraints = serde_json::from_str("{}").unwrap();
        assert!(constraints.allowed_modes.is_none());
        assert!(constraints.max_response_bytes.is_none());
        assert!(constraints.allowed_url_schemes.is_none());
        assert!(constraints.denied_url_hosts.is_none());
    }

    /// Loads a persisted config, checks it survived the trip through disk, then adds a
    /// newly available tool with the default `Prompt` policy.
    #[test]
    fn test_policy_updates() {
        let dir = tempdir().unwrap();
        let config_path = dir.path().join("tool-policy.json");

        let mut config = ToolPolicyConfig::default();
        config.available_tools = vec!["tool1".to_string()];
        config
            .policies
            .insert("tool1".to_string(), ToolPolicy::Prompt);

        // Save initial config
        let content = serde_json::to_string_pretty(&config).unwrap();
        fs::write(&config_path, content).unwrap();

        // Load and update
        let mut loaded_config = ToolPolicyManager::load_or_create_config(&config_path).unwrap();

        // The policy written above must be read back unchanged.
        assert_eq!(
            loaded_config.policies.get("tool1"),
            Some(&ToolPolicy::Prompt)
        );

        // Add new tool
        let new_tools = vec!["tool1".to_string(), "tool2".to_string()];
        let current_tools: HashSet<&String> = loaded_config.available_tools.iter().collect();

        for tool in &new_tools {
            if !current_tools.contains(tool) {
                loaded_config
                    .policies
                    .insert(tool.clone(), ToolPolicy::Prompt);
            }
        }

        loaded_config.available_tools = new_tools;

        assert_eq!(loaded_config.policies.len(), 2);
        assert_eq!(
            loaded_config.policies.get("tool1"),
            Some(&ToolPolicy::Prompt)
        );
        assert_eq!(
            loaded_config.policies.get("tool2"),
            Some(&ToolPolicy::Prompt)
        );
    }
}