vtcode_core/tools/registry/mod.rs

1//! Tool registry and function declarations
2
3mod astgrep;
4mod builtins;
5mod cache;
6mod declarations;
7mod error;
8mod executors;
9mod legacy;
10mod policy;
11mod pty;
12mod registration;
13mod utils;
14
15pub use declarations::{build_function_declarations, build_function_declarations_for_level};
16pub use error::{ToolErrorType, ToolExecutionError, classify_error};
17pub use registration::{ToolExecutorFn, ToolHandler, ToolRegistration};
18
19use builtins::register_builtin_tools;
20use utils::normalize_tool_output;
21
22use crate::config::PtyConfig;
23use crate::config::ToolsConfig;
24use crate::config::constants::tools;
25use crate::tool_policy::{ToolPolicy, ToolPolicyManager};
26use crate::tools::ast_grep::AstGrepEngine;
27use crate::tools::grep_search::GrepSearchManager;
28use anyhow::{Result, anyhow};
29use serde_json::Value;
30use std::collections::{HashMap, HashSet};
31use std::path::PathBuf;
32use std::sync::Arc;
33use std::sync::atomic::AtomicUsize;
34
35use super::bash_tool::BashTool;
36use super::command::CommandTool;
37use super::file_ops::FileOpsTool;
38use super::search::SearchTool;
39use super::simple_search::SimpleSearchTool;
40use super::srgn::SrgnTool;
41
42#[cfg(test)]
43use super::traits::Tool;
44#[cfg(test)]
45use crate::config::types::CapabilityLevel;
46
47#[derive(Clone)]
48pub struct ToolRegistry {
49    workspace_root: PathBuf,
50    search_tool: SearchTool,
51    simple_search_tool: SimpleSearchTool,
52    bash_tool: BashTool,
53    file_ops_tool: FileOpsTool,
54    command_tool: CommandTool,
55    grep_search: Arc<GrepSearchManager>,
56    ast_grep_engine: Option<Arc<AstGrepEngine>>,
57    tool_policy: Option<ToolPolicyManager>,
58    pty_config: PtyConfig,
59    active_pty_sessions: Arc<AtomicUsize>,
60    srgn_tool: SrgnTool,
61    tool_registrations: Vec<ToolRegistration>,
62    tool_lookup: HashMap<&'static str, usize>,
63    preapproved_tools: HashSet<String>,
64    full_auto_allowlist: Option<HashSet<String>>,
65}
66
67impl ToolRegistry {
68    pub fn new(workspace_root: PathBuf) -> Self {
69        Self::new_with_config(workspace_root, PtyConfig::default())
70    }
71
72    pub fn new_with_config(workspace_root: PathBuf, pty_config: PtyConfig) -> Self {
73        let grep_search = Arc::new(GrepSearchManager::new(workspace_root.clone()));
74
75        let search_tool = SearchTool::new(workspace_root.clone(), grep_search.clone());
76        let simple_search_tool = SimpleSearchTool::new(workspace_root.clone());
77        let bash_tool = BashTool::new(workspace_root.clone());
78        let file_ops_tool = FileOpsTool::new(workspace_root.clone(), grep_search.clone());
79        let command_tool = CommandTool::new(workspace_root.clone());
80        let srgn_tool = SrgnTool::new(workspace_root.clone());
81
82        let ast_grep_engine = match AstGrepEngine::new() {
83            Ok(engine) => Some(Arc::new(engine)),
84            Err(err) => {
85                eprintln!("Warning: Failed to initialize AST-grep engine: {}", err);
86                None
87            }
88        };
89
90        let policy_manager = match ToolPolicyManager::new_with_workspace(&workspace_root) {
91            Ok(manager) => Some(manager),
92            Err(err) => {
93                eprintln!("Warning: Failed to initialize tool policy manager: {}", err);
94                None
95            }
96        };
97
98        let mut registry = Self {
99            workspace_root,
100            search_tool,
101            simple_search_tool,
102            bash_tool,
103            file_ops_tool,
104            command_tool,
105            grep_search,
106            ast_grep_engine,
107            tool_policy: policy_manager,
108            pty_config,
109            active_pty_sessions: Arc::new(AtomicUsize::new(0)),
110            srgn_tool,
111            tool_registrations: Vec::new(),
112            tool_lookup: HashMap::new(),
113            preapproved_tools: HashSet::new(),
114            full_auto_allowlist: None,
115        };
116
117        register_builtin_tools(&mut registry);
118        registry
119    }
120
121    pub fn register_tool(&mut self, registration: ToolRegistration) -> Result<()> {
122        if self.tool_lookup.contains_key(registration.name()) {
123            return Err(anyhow!(format!(
124                "Tool '{}' is already registered",
125                registration.name()
126            )));
127        }
128
129        let index = self.tool_registrations.len();
130        self.tool_lookup.insert(registration.name(), index);
131        self.tool_registrations.push(registration);
132        Ok(())
133    }
134
135    pub fn available_tools(&self) -> Vec<String> {
136        self.tool_registrations
137            .iter()
138            .map(|registration| registration.name().to_string())
139            .collect()
140    }
141
142    pub fn enable_full_auto_mode(&mut self, allowed_tools: &[String]) {
143        let mut normalized: HashSet<String> = HashSet::new();
144        if allowed_tools
145            .iter()
146            .any(|tool| tool.trim() == tools::WILDCARD_ALL)
147        {
148            for tool in self.available_tools() {
149                normalized.insert(tool);
150            }
151        } else {
152            for tool in allowed_tools {
153                let trimmed = tool.trim();
154                if !trimmed.is_empty() {
155                    normalized.insert(trimmed.to_string());
156                }
157            }
158        }
159
160        self.full_auto_allowlist = Some(normalized);
161    }
162
163    pub fn current_full_auto_allowlist(&self) -> Option<Vec<String>> {
164        self.full_auto_allowlist.as_ref().map(|set| {
165            let mut items: Vec<String> = set.iter().cloned().collect();
166            items.sort();
167            items
168        })
169    }
170
171    pub fn has_tool(&self, name: &str) -> bool {
172        self.tool_lookup.contains_key(name)
173    }
174
175    pub fn with_ast_grep(mut self, engine: Arc<AstGrepEngine>) -> Self {
176        self.ast_grep_engine = Some(engine);
177        self
178    }
179
180    pub fn workspace_root(&self) -> &PathBuf {
181        &self.workspace_root
182    }
183
184    pub async fn initialize_async(&mut self) -> Result<()> {
185        Ok(())
186    }
187
188    pub fn apply_config_policies(&mut self, tools_config: &ToolsConfig) -> Result<()> {
189        if let Ok(policy_manager) = self.policy_manager_mut() {
190            policy_manager.apply_tools_config(tools_config)?;
191        }
192
193        Ok(())
194    }
195
196    pub async fn execute_tool(&mut self, name: &str, args: Value) -> Result<Value> {
197        if let Some(allowlist) = &self.full_auto_allowlist {
198            if !allowlist.contains(name) {
199                let error = ToolExecutionError::new(
200                    name.to_string(),
201                    ToolErrorType::PolicyViolation,
202                    format!(
203                        "Tool '{}' is not permitted while full-auto mode is active",
204                        name
205                    ),
206                );
207                return Ok(error.to_json_value());
208            }
209        }
210
211        let skip_policy_prompt = self.preapproved_tools.remove(name);
212
213        if !skip_policy_prompt {
214            if let Ok(policy_manager) = self.policy_manager_mut() {
215                if !policy_manager.should_execute_tool(name)? {
216                    let error = ToolExecutionError::new(
217                        name.to_string(),
218                        ToolErrorType::PolicyViolation,
219                        format!("Tool '{}' execution denied by policy", name),
220                    );
221                    return Ok(error.to_json_value());
222                }
223            }
224        }
225
226        let args = match self.apply_policy_constraints(name, args) {
227            Ok(args) => args,
228            Err(err) => {
229                let error = ToolExecutionError::with_original_error(
230                    name.to_string(),
231                    ToolErrorType::InvalidParameters,
232                    "Failed to apply policy constraints".to_string(),
233                    err.to_string(),
234                );
235                return Ok(error.to_json_value());
236            }
237        };
238
239        let registration = match self
240            .tool_lookup
241            .get(name)
242            .and_then(|index| self.tool_registrations.get(*index))
243        {
244            Some(registration) => registration,
245            None => {
246                let error = ToolExecutionError::new(
247                    name.to_string(),
248                    ToolErrorType::ToolNotFound,
249                    format!("Unknown tool: {}", name),
250                );
251                return Ok(error.to_json_value());
252            }
253        };
254
255        let uses_pty = registration.uses_pty();
256        if uses_pty {
257            if let Err(err) = self.start_pty_session() {
258                let error = ToolExecutionError::with_original_error(
259                    name.to_string(),
260                    ToolErrorType::ExecutionError,
261                    "Failed to start PTY session".to_string(),
262                    err.to_string(),
263                );
264                return Ok(error.to_json_value());
265            }
266        }
267
268        let handler = registration.handler();
269        let result = match handler {
270            ToolHandler::RegistryFn(executor) => executor(self, args).await,
271            ToolHandler::TraitObject(tool) => tool.execute(args).await,
272        };
273
274        if uses_pty {
275            self.end_pty_session();
276        }
277
278        match result {
279            Ok(value) => Ok(normalize_tool_output(value)),
280            Err(err) => {
281                let error_type = classify_error(&err);
282                let error = ToolExecutionError::with_original_error(
283                    name.to_string(),
284                    error_type,
285                    format!("Tool execution failed: {}", err),
286                    err.to_string(),
287                );
288                Ok(error.to_json_value())
289            }
290        }
291    }
292}
293
294impl ToolRegistry {
295    /// Prompt for permission before starting long-running tool executions to avoid spinner conflicts
296    pub fn preflight_tool_permission(&mut self, name: &str) -> Result<bool> {
297        if let Some(allowlist) = self.full_auto_allowlist.as_ref() {
298            if !allowlist.contains(name) {
299                return Ok(false);
300            }
301
302            if let Some(policy_manager) = self.tool_policy.as_mut() {
303                match policy_manager.get_policy(name) {
304                    ToolPolicy::Deny => return Ok(false),
305                    ToolPolicy::Allow | ToolPolicy::Prompt => {
306                        self.preapproved_tools.insert(name.to_string());
307                        return Ok(true);
308                    }
309                }
310            }
311
312            self.preapproved_tools.insert(name.to_string());
313            return Ok(true);
314        }
315
316        if let Ok(policy_manager) = self.policy_manager_mut() {
317            let allowed = policy_manager.should_execute_tool(name)?;
318            if allowed {
319                self.preapproved_tools.insert(name.to_string());
320            }
321            return Ok(allowed);
322        }
323
324        Ok(true)
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;
    use tempfile::TempDir;

    const CUSTOM_TOOL_NAME: &str = "custom_test_tool";

    /// Minimal tool that echoes its arguments back inside a success payload.
    struct CustomEchoTool;

    #[async_trait]
    impl Tool for CustomEchoTool {
        async fn execute(&self, args: Value) -> Result<Value> {
            Ok(json!({
                "success": true,
                "args": args,
            }))
        }

        fn name(&self) -> &'static str {
            CUSTOM_TOOL_NAME
        }

        fn description(&self) -> &'static str {
            "Custom echo tool for testing"
        }
    }

    /// Builds a fresh registration for [`CustomEchoTool`].
    fn custom_registration() -> ToolRegistration {
        ToolRegistration::from_tool_instance(
            CUSTOM_TOOL_NAME,
            CapabilityLevel::CodeSearch,
            CustomEchoTool,
        )
    }

    #[tokio::test]
    async fn registers_builtin_tools() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let registry = ToolRegistry::new(temp_dir.path().to_path_buf());
        let available = registry.available_tools();

        assert!(available.contains(&tools::READ_FILE.to_string()));
        assert!(available.contains(&tools::RUN_TERMINAL_CMD.to_string()));
        Ok(())
    }

    #[tokio::test]
    async fn allows_registering_custom_tools() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut registry = ToolRegistry::new(temp_dir.path().to_path_buf());

        registry.register_tool(custom_registration())?;

        registry.sync_policy_available_tools();

        // Best-effort: the result is deliberately ignored, as in the original test
        // (e.g. if no policy manager could be initialized in the temp workspace).
        registry.allow_all_tools().ok();

        let available = registry.available_tools();
        assert!(available.contains(&CUSTOM_TOOL_NAME.to_string()));

        let response = registry
            .execute_tool(CUSTOM_TOOL_NAME, json!({"input": "value"}))
            .await?;
        assert!(response["success"].as_bool().unwrap_or(false));
        Ok(())
    }

    #[tokio::test]
    async fn rejects_duplicate_tool_registration() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut registry = ToolRegistry::new(temp_dir.path().to_path_buf());

        registry.register_tool(custom_registration())?;
        assert!(registry.register_tool(custom_registration()).is_err());

        // The original registration must survive the rejected duplicate.
        assert!(registry.has_tool(CUSTOM_TOOL_NAME));
        let occurrences = registry
            .available_tools()
            .iter()
            .filter(|name| name.as_str() == CUSTOM_TOOL_NAME)
            .count();
        assert_eq!(occurrences, 1);
        Ok(())
    }

    #[tokio::test]
    async fn full_auto_allowlist_enforced() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut registry = ToolRegistry::new(temp_dir.path().to_path_buf());

        registry.enable_full_auto_mode(&[tools::READ_FILE.to_string()]);

        assert!(registry.preflight_tool_permission(tools::READ_FILE)?);
        assert!(!registry.preflight_tool_permission(tools::RUN_TERMINAL_CMD)?);

        Ok(())
    }

    #[tokio::test]
    async fn wildcard_full_auto_allows_every_registered_tool() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut registry = ToolRegistry::new(temp_dir.path().to_path_buf());

        // Surrounding whitespace must not prevent the wildcard from matching.
        registry.enable_full_auto_mode(&[format!("  {}  ", tools::WILDCARD_ALL)]);

        let allowlist = registry.current_full_auto_allowlist().unwrap_or_default();
        for tool in registry.available_tools() {
            assert!(allowlist.contains(&tool), "missing {tool} from allowlist");
        }
        Ok(())
    }
}