vtcode_core/tools/registry/mod.rs

1//! Tool registry and function declarations
2
3mod astgrep;
4mod builtins;
5mod cache;
6mod declarations;
7mod error;
8mod executors;
9mod legacy;
10mod policy;
11mod pty;
12mod registration;
13mod utils;
14
15pub use declarations::{build_function_declarations, build_function_declarations_for_level};
16pub use error::{ToolErrorType, ToolExecutionError, classify_error};
17pub use registration::{ToolExecutorFn, ToolHandler, ToolRegistration};
18
19use builtins::register_builtin_tools;
20use utils::normalize_tool_output;
21
22use crate::config::PtyConfig;
23use crate::config::ToolsConfig;
24use crate::config::constants::tools;
25use crate::tool_policy::{ToolPolicy, ToolPolicyManager};
26use crate::tools::ast_grep::AstGrepEngine;
27use crate::tools::grep_search::GrepSearchManager;
28use anyhow::{Result, anyhow};
29use serde_json::Value;
30use std::collections::{HashMap, HashSet};
31use std::path::PathBuf;
32use std::sync::Arc;
33use std::sync::atomic::AtomicUsize;
34
35use super::bash_tool::BashTool;
36use super::command::CommandTool;
37use super::curl_tool::CurlTool;
38use super::file_ops::FileOpsTool;
39use super::plan::PlanManager;
40use super::search::SearchTool;
41use super::simple_search::SimpleSearchTool;
42use super::srgn::SrgnTool;
43
44#[cfg(test)]
45use super::traits::Tool;
46#[cfg(test)]
47use crate::config::types::CapabilityLevel;
48
49#[derive(Clone)]
50pub struct ToolRegistry {
51    workspace_root: PathBuf,
52    search_tool: SearchTool,
53    simple_search_tool: SimpleSearchTool,
54    bash_tool: BashTool,
55    file_ops_tool: FileOpsTool,
56    command_tool: CommandTool,
57    curl_tool: CurlTool,
58    grep_search: Arc<GrepSearchManager>,
59    ast_grep_engine: Option<Arc<AstGrepEngine>>,
60    tool_policy: Option<ToolPolicyManager>,
61    pty_config: PtyConfig,
62    active_pty_sessions: Arc<AtomicUsize>,
63    srgn_tool: SrgnTool,
64    plan_manager: PlanManager,
65    tool_registrations: Vec<ToolRegistration>,
66    tool_lookup: HashMap<&'static str, usize>,
67    preapproved_tools: HashSet<String>,
68    full_auto_allowlist: Option<HashSet<String>>,
69}
70
/// Outcome of evaluating whether a tool may run.
///
/// `Hash` is derived alongside `Eq` so decisions can be used as set or map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolPermissionDecision {
    /// The tool may run without further confirmation.
    Allow,
    /// The tool must not run.
    Deny,
    /// The tool's policy asks for confirmation before it runs.
    Prompt,
}
77
78impl ToolRegistry {
79    pub fn new(workspace_root: PathBuf) -> Self {
80        Self::new_with_config(workspace_root, PtyConfig::default())
81    }
82
83    pub fn new_with_config(workspace_root: PathBuf, pty_config: PtyConfig) -> Self {
84        let grep_search = Arc::new(GrepSearchManager::new(workspace_root.clone()));
85
86        let search_tool = SearchTool::new(workspace_root.clone(), grep_search.clone());
87        let simple_search_tool = SimpleSearchTool::new(workspace_root.clone());
88        let bash_tool = BashTool::new(workspace_root.clone());
89        let file_ops_tool = FileOpsTool::new(workspace_root.clone(), grep_search.clone());
90        let command_tool = CommandTool::new(workspace_root.clone());
91        let curl_tool = CurlTool::new();
92        let srgn_tool = SrgnTool::new(workspace_root.clone());
93        let plan_manager = PlanManager::new();
94
95        let ast_grep_engine = match AstGrepEngine::new() {
96            Ok(engine) => Some(Arc::new(engine)),
97            Err(err) => {
98                eprintln!("Warning: Failed to initialize AST-grep engine: {}", err);
99                None
100            }
101        };
102
103        let policy_manager = match ToolPolicyManager::new_with_workspace(&workspace_root) {
104            Ok(manager) => Some(manager),
105            Err(err) => {
106                eprintln!("Warning: Failed to initialize tool policy manager: {}", err);
107                None
108            }
109        };
110
111        let mut registry = Self {
112            workspace_root,
113            search_tool,
114            simple_search_tool,
115            bash_tool,
116            file_ops_tool,
117            command_tool,
118            curl_tool,
119            grep_search,
120            ast_grep_engine,
121            tool_policy: policy_manager,
122            pty_config,
123            active_pty_sessions: Arc::new(AtomicUsize::new(0)),
124            srgn_tool,
125            plan_manager,
126            tool_registrations: Vec::new(),
127            tool_lookup: HashMap::new(),
128            preapproved_tools: HashSet::new(),
129            full_auto_allowlist: None,
130        };
131
132        register_builtin_tools(&mut registry);
133        registry
134    }
135
136    pub fn register_tool(&mut self, registration: ToolRegistration) -> Result<()> {
137        if self.tool_lookup.contains_key(registration.name()) {
138            return Err(anyhow!(format!(
139                "Tool '{}' is already registered",
140                registration.name()
141            )));
142        }
143
144        let index = self.tool_registrations.len();
145        self.tool_lookup.insert(registration.name(), index);
146        self.tool_registrations.push(registration);
147        Ok(())
148    }
149
150    pub fn available_tools(&self) -> Vec<String> {
151        self.tool_registrations
152            .iter()
153            .map(|registration| registration.name().to_string())
154            .collect()
155    }
156
157    pub fn enable_full_auto_mode(&mut self, allowed_tools: &[String]) {
158        let mut normalized: HashSet<String> = HashSet::new();
159        if allowed_tools
160            .iter()
161            .any(|tool| tool.trim() == tools::WILDCARD_ALL)
162        {
163            for tool in self.available_tools() {
164                normalized.insert(tool);
165            }
166        } else {
167            for tool in allowed_tools {
168                let trimmed = tool.trim();
169                if !trimmed.is_empty() {
170                    normalized.insert(trimmed.to_string());
171                }
172            }
173        }
174
175        self.full_auto_allowlist = Some(normalized);
176    }
177
178    pub fn current_full_auto_allowlist(&self) -> Option<Vec<String>> {
179        self.full_auto_allowlist.as_ref().map(|set| {
180            let mut items: Vec<String> = set.iter().cloned().collect();
181            items.sort();
182            items
183        })
184    }
185
186    pub fn has_tool(&self, name: &str) -> bool {
187        self.tool_lookup.contains_key(name)
188    }
189
190    pub fn with_ast_grep(mut self, engine: Arc<AstGrepEngine>) -> Self {
191        self.ast_grep_engine = Some(engine);
192        self
193    }
194
195    pub fn workspace_root(&self) -> &PathBuf {
196        &self.workspace_root
197    }
198
199    pub fn plan_manager(&self) -> PlanManager {
200        self.plan_manager.clone()
201    }
202
203    pub fn current_plan(&self) -> crate::tools::TaskPlan {
204        self.plan_manager.snapshot()
205    }
206
207    pub async fn initialize_async(&mut self) -> Result<()> {
208        Ok(())
209    }
210
211    pub fn apply_config_policies(&mut self, tools_config: &ToolsConfig) -> Result<()> {
212        if let Ok(policy_manager) = self.policy_manager_mut() {
213            policy_manager.apply_tools_config(tools_config)?;
214        }
215
216        Ok(())
217    }
218
219    pub async fn execute_tool(&mut self, name: &str, args: Value) -> Result<Value> {
220        if let Some(allowlist) = &self.full_auto_allowlist {
221            if !allowlist.contains(name) {
222                let error = ToolExecutionError::new(
223                    name.to_string(),
224                    ToolErrorType::PolicyViolation,
225                    format!(
226                        "Tool '{}' is not permitted while full-auto mode is active",
227                        name
228                    ),
229                );
230                return Ok(error.to_json_value());
231            }
232        }
233
234        let skip_policy_prompt = self.preapproved_tools.remove(name);
235
236        if !skip_policy_prompt {
237            if let Ok(policy_manager) = self.policy_manager_mut() {
238                if !policy_manager.should_execute_tool(name)? {
239                    let error = ToolExecutionError::new(
240                        name.to_string(),
241                        ToolErrorType::PolicyViolation,
242                        format!("Tool '{}' execution denied by policy", name),
243                    );
244                    return Ok(error.to_json_value());
245                }
246            }
247        }
248
249        let args = match self.apply_policy_constraints(name, args) {
250            Ok(args) => args,
251            Err(err) => {
252                let error = ToolExecutionError::with_original_error(
253                    name.to_string(),
254                    ToolErrorType::InvalidParameters,
255                    "Failed to apply policy constraints".to_string(),
256                    err.to_string(),
257                );
258                return Ok(error.to_json_value());
259            }
260        };
261
262        let registration = match self
263            .tool_lookup
264            .get(name)
265            .and_then(|index| self.tool_registrations.get(*index))
266        {
267            Some(registration) => registration,
268            None => {
269                let error = ToolExecutionError::new(
270                    name.to_string(),
271                    ToolErrorType::ToolNotFound,
272                    format!("Unknown tool: {}", name),
273                );
274                return Ok(error.to_json_value());
275            }
276        };
277
278        let uses_pty = registration.uses_pty();
279        if uses_pty {
280            if let Err(err) = self.start_pty_session() {
281                let error = ToolExecutionError::with_original_error(
282                    name.to_string(),
283                    ToolErrorType::ExecutionError,
284                    "Failed to start PTY session".to_string(),
285                    err.to_string(),
286                );
287                return Ok(error.to_json_value());
288            }
289        }
290
291        let handler = registration.handler();
292        let result = match handler {
293            ToolHandler::RegistryFn(executor) => executor(self, args).await,
294            ToolHandler::TraitObject(tool) => tool.execute(args).await,
295        };
296
297        if uses_pty {
298            self.end_pty_session();
299        }
300
301        match result {
302            Ok(value) => Ok(normalize_tool_output(value)),
303            Err(err) => {
304                let error_type = classify_error(&err);
305                let error = ToolExecutionError::with_original_error(
306                    name.to_string(),
307                    error_type,
308                    format!("Tool execution failed: {}", err),
309                    err.to_string(),
310                );
311                Ok(error.to_json_value())
312            }
313        }
314    }
315}
316
317impl ToolRegistry {
318    /// Prompt for permission before starting long-running tool executions to avoid spinner conflicts
319    pub fn preflight_tool_permission(&mut self, name: &str) -> Result<bool> {
320        match self.evaluate_tool_policy(name)? {
321            ToolPermissionDecision::Allow => Ok(true),
322            ToolPermissionDecision::Deny => Ok(false),
323            ToolPermissionDecision::Prompt => Ok(true),
324        }
325    }
326
327    pub fn evaluate_tool_policy(&mut self, name: &str) -> Result<ToolPermissionDecision> {
328        if let Some(allowlist) = self.full_auto_allowlist.as_ref() {
329            if !allowlist.contains(name) {
330                return Ok(ToolPermissionDecision::Deny);
331            }
332
333            if let Some(policy_manager) = self.tool_policy.as_mut() {
334                match policy_manager.get_policy(name) {
335                    ToolPolicy::Deny => return Ok(ToolPermissionDecision::Deny),
336                    ToolPolicy::Allow | ToolPolicy::Prompt => {
337                        self.preapproved_tools.insert(name.to_string());
338                        return Ok(ToolPermissionDecision::Allow);
339                    }
340                }
341            }
342
343            self.preapproved_tools.insert(name.to_string());
344            return Ok(ToolPermissionDecision::Allow);
345        }
346
347        if let Some(policy_manager) = self.tool_policy.as_mut() {
348            match policy_manager.get_policy(name) {
349                ToolPolicy::Allow => {
350                    self.preapproved_tools.insert(name.to_string());
351                    Ok(ToolPermissionDecision::Allow)
352                }
353                ToolPolicy::Deny => Ok(ToolPermissionDecision::Deny),
354                ToolPolicy::Prompt => {
355                    if ToolPolicyManager::is_auto_allow_tool(name) {
356                        policy_manager.set_policy(name, ToolPolicy::Allow)?;
357                        self.preapproved_tools.insert(name.to_string());
358                        Ok(ToolPermissionDecision::Allow)
359                    } else {
360                        Ok(ToolPermissionDecision::Prompt)
361                    }
362                }
363            }
364        } else {
365            self.preapproved_tools.insert(name.to_string());
366            Ok(ToolPermissionDecision::Allow)
367        }
368    }
369
370    pub fn mark_tool_preapproved(&mut self, name: &str) {
371        self.preapproved_tools.insert(name.to_string());
372    }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;
    use tempfile::TempDir;

    const CUSTOM_TOOL_NAME: &str = "custom_test_tool";

    /// Test tool that reports success and echoes its arguments back under `"args"`.
    struct CustomEchoTool;

    #[async_trait]
    impl Tool for CustomEchoTool {
        async fn execute(&self, args: Value) -> Result<Value> {
            let echoed = json!({
                "success": true,
                "args": args,
            });
            Ok(echoed)
        }

        fn name(&self) -> &'static str {
            CUSTOM_TOOL_NAME
        }

        fn description(&self) -> &'static str {
            "Custom echo tool for testing"
        }
    }

    /// Builds a registry rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned as well, so the caller keeps the directory alive for as long as
    /// the registry under test.
    fn registry_in_temp_dir() -> Result<(TempDir, ToolRegistry)> {
        let workspace = TempDir::new()?;
        let registry = ToolRegistry::new(workspace.path().to_path_buf());
        Ok((workspace, registry))
    }

    #[tokio::test]
    async fn registers_builtin_tools() -> Result<()> {
        let (_workspace, registry) = registry_in_temp_dir()?;
        let available = registry.available_tools();

        for expected in [tools::READ_FILE, tools::RUN_TERMINAL_CMD, tools::CURL] {
            assert!(available.contains(&expected.to_string()));
        }
        Ok(())
    }

    #[tokio::test]
    async fn allows_registering_custom_tools() -> Result<()> {
        let (_workspace, mut registry) = registry_in_temp_dir()?;

        let custom = ToolRegistration::from_tool_instance(
            CUSTOM_TOOL_NAME,
            CapabilityLevel::CodeSearch,
            CustomEchoTool,
        );
        registry.register_tool(custom)?;
        registry.sync_policy_available_tools();
        registry.allow_all_tools().ok();

        assert!(
            registry
                .available_tools()
                .contains(&CUSTOM_TOOL_NAME.to_string())
        );

        let response = registry
            .execute_tool(CUSTOM_TOOL_NAME, json!({"input": "value"}))
            .await?;
        assert!(response["success"].as_bool().unwrap_or(false));
        Ok(())
    }

    #[tokio::test]
    async fn full_auto_allowlist_enforced() -> Result<()> {
        let (_workspace, mut registry) = registry_in_temp_dir()?;

        registry.enable_full_auto_mode(&[tools::READ_FILE.to_string()]);

        assert!(registry.preflight_tool_permission(tools::READ_FILE)?);
        assert!(!registry.preflight_tool_permission(tools::RUN_TERMINAL_CMD)?);

        Ok(())
    }
}