vtcode_core/tools/registry/mod.rs

1//! Tool registry and function declarations
2
3mod astgrep;
4mod builtins;
5mod cache;
6mod declarations;
7mod error;
8mod executors;
9mod legacy;
10mod policy;
11mod pty;
12mod registration;
13mod utils;
14
15pub use declarations::{build_function_declarations, build_function_declarations_for_level};
16pub use error::{ToolErrorType, ToolExecutionError, classify_error};
17pub use registration::{ToolExecutorFn, ToolHandler, ToolRegistration};
18
19use builtins::register_builtin_tools;
20use utils::normalize_tool_output;
21
22use crate::config::PtyConfig;
23use crate::config::ToolsConfig;
24use crate::config::constants::tools;
25use crate::tool_policy::{ToolPolicy, ToolPolicyManager};
26use crate::tools::ast_grep::AstGrepEngine;
27use crate::tools::grep_search::GrepSearchManager;
28use anyhow::{Result, anyhow};
29use serde_json::Value;
30use std::collections::{HashMap, HashSet};
31use std::path::PathBuf;
32use std::sync::Arc;
33use std::sync::atomic::AtomicUsize;
34
35use super::bash_tool::BashTool;
36use super::command::CommandTool;
37use super::file_ops::FileOpsTool;
38use super::search::SearchTool;
39use super::simple_search::SimpleSearchTool;
40use super::srgn::SrgnTool;
41
42#[cfg(test)]
43use super::traits::Tool;
44#[cfg(test)]
45use crate::config::constants::tools;
46#[cfg(test)]
47use crate::config::types::CapabilityLevel;
48
49#[derive(Clone)]
50pub struct ToolRegistry {
51    workspace_root: PathBuf,
52    search_tool: SearchTool,
53    simple_search_tool: SimpleSearchTool,
54    bash_tool: BashTool,
55    file_ops_tool: FileOpsTool,
56    command_tool: CommandTool,
57    grep_search: Arc<GrepSearchManager>,
58    ast_grep_engine: Option<Arc<AstGrepEngine>>,
59    tool_policy: Option<ToolPolicyManager>,
60    pty_config: PtyConfig,
61    active_pty_sessions: Arc<AtomicUsize>,
62    srgn_tool: SrgnTool,
63    tool_registrations: Vec<ToolRegistration>,
64    tool_lookup: HashMap<&'static str, usize>,
65    preapproved_tools: HashSet<String>,
66    full_auto_allowlist: Option<HashSet<String>>,
67}
68
69impl ToolRegistry {
70    pub fn new(workspace_root: PathBuf) -> Self {
71        Self::new_with_config(workspace_root, PtyConfig::default())
72    }
73
74    pub fn new_with_config(workspace_root: PathBuf, pty_config: PtyConfig) -> Self {
75        let grep_search = Arc::new(GrepSearchManager::new(workspace_root.clone()));
76
77        let search_tool = SearchTool::new(workspace_root.clone(), grep_search.clone());
78        let simple_search_tool = SimpleSearchTool::new(workspace_root.clone());
79        let bash_tool = BashTool::new(workspace_root.clone());
80        let file_ops_tool = FileOpsTool::new(workspace_root.clone(), grep_search.clone());
81        let command_tool = CommandTool::new(workspace_root.clone());
82        let srgn_tool = SrgnTool::new(workspace_root.clone());
83
84        let ast_grep_engine = match AstGrepEngine::new() {
85            Ok(engine) => Some(Arc::new(engine)),
86            Err(err) => {
87                eprintln!("Warning: Failed to initialize AST-grep engine: {}", err);
88                None
89            }
90        };
91
92        let policy_manager = match ToolPolicyManager::new_with_workspace(&workspace_root) {
93            Ok(manager) => Some(manager),
94            Err(err) => {
95                eprintln!("Warning: Failed to initialize tool policy manager: {}", err);
96                None
97            }
98        };
99
100        let mut registry = Self {
101            workspace_root,
102            search_tool,
103            simple_search_tool,
104            bash_tool,
105            file_ops_tool,
106            command_tool,
107            grep_search,
108            ast_grep_engine,
109            tool_policy: policy_manager,
110            pty_config,
111            active_pty_sessions: Arc::new(AtomicUsize::new(0)),
112            srgn_tool,
113            tool_registrations: Vec::new(),
114            tool_lookup: HashMap::new(),
115            preapproved_tools: HashSet::new(),
116            full_auto_allowlist: None,
117        };
118
119        register_builtin_tools(&mut registry);
120        registry
121    }
122
123    pub fn register_tool(&mut self, registration: ToolRegistration) -> Result<()> {
124        if self.tool_lookup.contains_key(registration.name()) {
125            return Err(anyhow!(format!(
126                "Tool '{}' is already registered",
127                registration.name()
128            )));
129        }
130
131        let index = self.tool_registrations.len();
132        self.tool_lookup.insert(registration.name(), index);
133        self.tool_registrations.push(registration);
134        Ok(())
135    }
136
137    pub fn available_tools(&self) -> Vec<String> {
138        self.tool_registrations
139            .iter()
140            .map(|registration| registration.name().to_string())
141            .collect()
142    }
143
144    pub fn enable_full_auto_mode(&mut self, allowed_tools: &[String]) {
145        let mut normalized: HashSet<String> = HashSet::new();
146        if allowed_tools
147            .iter()
148            .any(|tool| tool.trim() == tools::WILDCARD_ALL)
149        {
150            for tool in self.available_tools() {
151                normalized.insert(tool);
152            }
153        } else {
154            for tool in allowed_tools {
155                let trimmed = tool.trim();
156                if !trimmed.is_empty() {
157                    normalized.insert(trimmed.to_string());
158                }
159            }
160        }
161
162        self.full_auto_allowlist = Some(normalized);
163    }
164
165    pub fn current_full_auto_allowlist(&self) -> Option<Vec<String>> {
166        self.full_auto_allowlist.as_ref().map(|set| {
167            let mut items: Vec<String> = set.iter().cloned().collect();
168            items.sort();
169            items
170        })
171    }
172
173    pub fn has_tool(&self, name: &str) -> bool {
174        self.tool_lookup.contains_key(name)
175    }
176
177    pub fn with_ast_grep(mut self, engine: Arc<AstGrepEngine>) -> Self {
178        self.ast_grep_engine = Some(engine);
179        self
180    }
181
182    pub fn workspace_root(&self) -> &PathBuf {
183        &self.workspace_root
184    }
185
186    pub async fn initialize_async(&mut self) -> Result<()> {
187        Ok(())
188    }
189
190    pub fn apply_config_policies(&mut self, tools_config: &ToolsConfig) -> Result<()> {
191        if let Ok(policy_manager) = self.policy_manager_mut() {
192            policy_manager.apply_tools_config(tools_config)?;
193        }
194
195        Ok(())
196    }
197
198    pub async fn execute_tool(&mut self, name: &str, args: Value) -> Result<Value> {
199        if let Some(allowlist) = &self.full_auto_allowlist {
200            if !allowlist.contains(name) {
201                let error = ToolExecutionError::new(
202                    name.to_string(),
203                    ToolErrorType::PolicyViolation,
204                    format!(
205                        "Tool '{}' is not permitted while full-auto mode is active",
206                        name
207                    ),
208                );
209                return Ok(error.to_json_value());
210            }
211        }
212
213        let skip_policy_prompt = self.preapproved_tools.remove(name);
214
215        if !skip_policy_prompt {
216            if let Ok(policy_manager) = self.policy_manager_mut() {
217                if !policy_manager.should_execute_tool(name)? {
218                    let error = ToolExecutionError::new(
219                        name.to_string(),
220                        ToolErrorType::PolicyViolation,
221                        format!("Tool '{}' execution denied by policy", name),
222                    );
223                    return Ok(error.to_json_value());
224                }
225            }
226        }
227
228        let args = match self.apply_policy_constraints(name, args) {
229            Ok(args) => args,
230            Err(err) => {
231                let error = ToolExecutionError::with_original_error(
232                    name.to_string(),
233                    ToolErrorType::InvalidParameters,
234                    "Failed to apply policy constraints".to_string(),
235                    err.to_string(),
236                );
237                return Ok(error.to_json_value());
238            }
239        };
240
241        let registration = match self
242            .tool_lookup
243            .get(name)
244            .and_then(|index| self.tool_registrations.get(*index))
245        {
246            Some(registration) => registration,
247            None => {
248                let error = ToolExecutionError::new(
249                    name.to_string(),
250                    ToolErrorType::ToolNotFound,
251                    format!("Unknown tool: {}", name),
252                );
253                return Ok(error.to_json_value());
254            }
255        };
256
257        let uses_pty = registration.uses_pty();
258        if uses_pty {
259            if let Err(err) = self.start_pty_session() {
260                let error = ToolExecutionError::with_original_error(
261                    name.to_string(),
262                    ToolErrorType::ExecutionError,
263                    "Failed to start PTY session".to_string(),
264                    err.to_string(),
265                );
266                return Ok(error.to_json_value());
267            }
268        }
269
270        let handler = registration.handler();
271        let result = match handler {
272            ToolHandler::RegistryFn(executor) => executor(self, args).await,
273            ToolHandler::TraitObject(tool) => tool.execute(args).await,
274        };
275
276        if uses_pty {
277            self.end_pty_session();
278        }
279
280        match result {
281            Ok(value) => Ok(normalize_tool_output(value)),
282            Err(err) => {
283                let error_type = classify_error(&err);
284                let error = ToolExecutionError::with_original_error(
285                    name.to_string(),
286                    error_type,
287                    format!("Tool execution failed: {}", err),
288                    err.to_string(),
289                );
290                Ok(error.to_json_value())
291            }
292        }
293    }
294}
295
296impl ToolRegistry {
297    /// Prompt for permission before starting long-running tool executions to avoid spinner conflicts
298    pub fn preflight_tool_permission(&mut self, name: &str) -> Result<bool> {
299        if let Some(allowlist) = self.full_auto_allowlist.as_ref() {
300            if !allowlist.contains(name) {
301                return Ok(false);
302            }
303
304            if let Some(policy_manager) = self.tool_policy.as_mut() {
305                match policy_manager.get_policy(name) {
306                    ToolPolicy::Deny => return Ok(false),
307                    ToolPolicy::Allow | ToolPolicy::Prompt => {
308                        self.preapproved_tools.insert(name.to_string());
309                        return Ok(true);
310                    }
311                }
312            }
313
314            self.preapproved_tools.insert(name.to_string());
315            return Ok(true);
316        }
317
318        if let Ok(policy_manager) = self.policy_manager_mut() {
319            let allowed = policy_manager.should_execute_tool(name)?;
320            if allowed {
321                self.preapproved_tools.insert(name.to_string());
322            }
323            return Ok(allowed);
324        }
325
326        Ok(true)
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;
    use tempfile::TempDir;

    const CUSTOM_TOOL_NAME: &str = "custom_test_tool";

    /// Test tool that reports success and echoes its arguments back.
    struct CustomEchoTool;

    #[async_trait]
    impl Tool for CustomEchoTool {
        async fn execute(&self, args: Value) -> Result<Value> {
            let response = json!({
                "success": true,
                "args": args,
            });
            Ok(response)
        }

        fn name(&self) -> &'static str {
            CUSTOM_TOOL_NAME
        }

        fn description(&self) -> &'static str {
            "Custom echo tool for testing"
        }
    }

    #[tokio::test]
    async fn registers_builtin_tools() -> Result<()> {
        let workspace = TempDir::new()?;
        let registry = ToolRegistry::new(workspace.path().to_path_buf());
        let tool_names = registry.available_tools();

        for required in [tools::READ_FILE, tools::RUN_TERMINAL_CMD] {
            assert!(tool_names.contains(&required.to_string()));
        }
        Ok(())
    }

    #[tokio::test]
    async fn allows_registering_custom_tools() -> Result<()> {
        let workspace = TempDir::new()?;
        let mut registry = ToolRegistry::new(workspace.path().to_path_buf());

        let registration = ToolRegistration::from_tool_instance(
            CUSTOM_TOOL_NAME,
            CapabilityLevel::CodeSearch,
            CustomEchoTool,
        );
        registry.register_tool(registration)?;
        registry.sync_policy_available_tools();
        // The outcome is deliberately ignored; only execution below is asserted.
        let _ = registry.allow_all_tools();

        let tool_names = registry.available_tools();
        assert!(tool_names.iter().any(|tool| tool == CUSTOM_TOOL_NAME));

        let output = registry
            .execute_tool(CUSTOM_TOOL_NAME, json!({"input": "value"}))
            .await?;
        let succeeded = output["success"].as_bool().unwrap_or(false);
        assert!(succeeded);
        Ok(())
    }

    #[tokio::test]
    async fn full_auto_allowlist_enforced() -> Result<()> {
        let workspace = TempDir::new()?;
        let mut registry = ToolRegistry::new(workspace.path().to_path_buf());

        registry.enable_full_auto_mode(&[tools::READ_FILE.to_string()]);

        let read_allowed = registry.preflight_tool_permission(tools::READ_FILE)?;
        assert!(read_allowed);

        let terminal_allowed = registry.preflight_tool_permission(tools::RUN_TERMINAL_CMD)?;
        assert!(!terminal_allowed);

        Ok(())
    }
}