vtcode_core/tools/registry/policy.rs

1use anyhow::{Result, anyhow};
2use reqwest::Url;
3use serde_json::{Value, json};
4
5use crate::config::constants::tools;
6use crate::tool_policy::{ToolPolicy, ToolPolicyManager};
7
8use super::ToolRegistry;
9
10impl ToolRegistry {
11    pub(super) fn sync_policy_available_tools(&mut self) {
12        let mut available = self.available_tools();
13        available.extend(self.mcp_policy_keys());
14        if let Some(ref mut pm) = self.tool_policy
15            && let Err(err) = pm.update_available_tools(available)
16        {
17            eprintln!("Warning: Failed to update tool policies: {}", err);
18        }
19    }
20
21    pub(super) fn apply_policy_constraints(&self, name: &str, mut args: Value) -> Result<Value> {
22        if let Some(constraints) = self
23            .tool_policy
24            .as_ref()
25            .and_then(|tp| tp.get_constraints(name))
26            .cloned()
27        {
28            let obj = args
29                .as_object_mut()
30                .ok_or_else(|| anyhow!("Error: tool arguments must be an object"))?;
31
32            if let Some(fmt) = constraints.default_response_format {
33                obj.entry("response_format").or_insert(json!(fmt));
34            }
35
36            if let Some(allowed) = constraints.allowed_modes
37                && let Some(mode) = obj.get("mode").and_then(|v| v.as_str())
38                && !allowed.iter().any(|m| m == mode)
39            {
40                return Err(anyhow!(format!(
41                    "Mode '{}' not allowed by policy for '{}'. Allowed: {}",
42                    mode,
43                    name,
44                    allowed.join(", ")
45                )));
46            }
47
48            match name {
49                n if n == tools::LIST_FILES => {
50                    if let Some(cap) = constraints.max_items_per_call {
51                        let requested = obj
52                            .get("max_items")
53                            .and_then(|v| v.as_u64())
54                            .unwrap_or(cap as u64) as usize;
55                        if requested > cap {
56                            obj.insert("max_items".to_string(), json!(cap));
57                            obj.insert(
58                                "_policy_note".to_string(),
59                                json!(format!("Capped max_items to {} by policy", cap)),
60                            );
61                        }
62                    }
63                }
64                n if n == tools::GREP_SEARCH => {
65                    if let Some(cap) = constraints.max_results_per_call {
66                        let requested = obj
67                            .get("max_results")
68                            .and_then(|v| v.as_u64())
69                            .unwrap_or(cap as u64) as usize;
70                        if requested > cap {
71                            obj.insert("max_results".to_string(), json!(cap));
72                            obj.insert(
73                                "_policy_note".to_string(),
74                                json!(format!("Capped max_results to {} by policy", cap)),
75                            );
76                        }
77                    }
78                }
79                n if n == tools::READ_FILE => {
80                    if let Some(cap) = constraints.max_bytes_per_read {
81                        let requested = obj
82                            .get("max_bytes")
83                            .and_then(|v| v.as_u64())
84                            .unwrap_or(cap as u64) as usize;
85                        if requested > cap {
86                            obj.insert("max_bytes".to_string(), json!(cap));
87                            obj.insert(
88                                "_policy_note".to_string(),
89                                json!(format!("Capped max_bytes to {} by policy", cap)),
90                            );
91                        }
92                    }
93                }
94                n if n == tools::CURL => {
95                    if let Some(cap) = constraints.max_response_bytes {
96                        let requested = obj
97                            .get("max_bytes")
98                            .and_then(|v| v.as_u64())
99                            .unwrap_or(cap as u64) as usize;
100                        if requested > cap {
101                            obj.insert("max_bytes".to_string(), json!(cap));
102                            obj.insert(
103                                "_policy_note".to_string(),
104                                json!(format!("Capped max_bytes to {} bytes by policy", cap)),
105                            );
106                        }
107                    }
108
109                    let url_value = obj
110                        .get("url")
111                        .and_then(|v| v.as_str())
112                        .ok_or_else(|| anyhow!("curl tool requires a 'url' parameter"))?;
113
114                    let parsed = Url::parse(url_value)
115                        .map_err(|err| anyhow!(format!("Invalid URL '{}': {}", url_value, err)))?;
116
117                    if let Some(allowed) = constraints.allowed_url_schemes.as_ref() {
118                        let scheme = parsed.scheme();
119                        if !allowed
120                            .iter()
121                            .any(|candidate| candidate.eq_ignore_ascii_case(scheme))
122                        {
123                            return Err(anyhow!(format!(
124                                "Scheme '{}' is not allowed for curl tool. Allowed schemes: {}",
125                                scheme,
126                                allowed.join(", ")
127                            )));
128                        }
129                    }
130
131                    if let Some(denied_hosts) = constraints.denied_url_hosts.as_ref()
132                        && let Some(host_str) = parsed.host_str()
133                    {
134                        let lowered = host_str.to_lowercase();
135                        let blocked = denied_hosts.iter().any(|pattern| {
136                            let normalized = pattern.to_lowercase();
137                            if normalized.starts_with('.') {
138                                lowered.ends_with(&normalized)
139                            } else {
140                                lowered == normalized
141                                    || lowered.ends_with(&format!(".{}", normalized))
142                            }
143                        });
144                        if blocked {
145                            return Err(anyhow!(format!(
146                                "URL host '{}' is blocked by policy",
147                                host_str
148                            )));
149                        }
150                    }
151                }
152                _ => {}
153            }
154        }
155        Ok(args)
156    }
157
158    pub fn policy_manager_mut(&mut self) -> Result<&mut ToolPolicyManager> {
159        self.tool_policy
160            .as_mut()
161            .ok_or_else(|| anyhow!("Tool policy manager not available"))
162    }
163
164    pub fn policy_manager(&self) -> Result<&ToolPolicyManager> {
165        self.tool_policy
166            .as_ref()
167            .ok_or_else(|| anyhow!("Tool policy manager not available"))
168    }
169
170    pub fn set_policy_manager(&mut self, manager: ToolPolicyManager) {
171        self.tool_policy = Some(manager);
172        self.sync_policy_available_tools();
173    }
174
175    pub fn set_tool_policy(&mut self, tool_name: &str, policy: ToolPolicy) -> Result<()> {
176        self.tool_policy
177            .as_mut()
178            .expect("Tool policy manager not initialized")
179            .set_policy(tool_name, policy)
180    }
181
182    pub fn get_tool_policy(&self, tool_name: &str) -> ToolPolicy {
183        self.tool_policy
184            .as_ref()
185            .map(|tp| tp.get_policy(tool_name))
186            .unwrap_or(ToolPolicy::Allow)
187    }
188
189    pub fn reset_tool_policies(&mut self) -> Result<()> {
190        if let Some(tp) = self.tool_policy.as_mut() {
191            tp.reset_all_to_prompt()
192        } else {
193            Err(anyhow!("Tool policy manager not available"))
194        }
195    }
196
197    pub fn allow_all_tools(&mut self) -> Result<()> {
198        if let Some(tp) = self.tool_policy.as_mut() {
199            tp.allow_all_tools()
200        } else {
201            Err(anyhow!("Tool policy manager not available"))
202        }
203    }
204
205    pub fn deny_all_tools(&mut self) -> Result<()> {
206        if let Some(tp) = self.tool_policy.as_mut() {
207            tp.deny_all_tools()
208        } else {
209            Err(anyhow!("Tool policy manager not available"))
210        }
211    }
212
213    pub fn print_policy_status(&self) {
214        if let Some(tp) = self.tool_policy.as_ref() {
215            tp.print_status();
216        } else {
217            eprintln!("Tool policy manager not available");
218        }
219    }
220}