1mod astgrep;
4mod builtins;
5mod cache;
6mod declarations;
7mod error;
8mod executors;
9mod legacy;
10mod policy;
11mod pty;
12mod registration;
13mod utils;
14
15pub use declarations::{build_function_declarations, build_function_declarations_for_level};
16pub use error::{ToolErrorType, ToolExecutionError, classify_error};
17pub use registration::{ToolExecutorFn, ToolHandler, ToolRegistration};
18
19use builtins::register_builtin_tools;
20use utils::normalize_tool_output;
21
22use crate::config::PtyConfig;
23use crate::config::ToolsConfig;
24use crate::config::constants::tools;
25use crate::tool_policy::{ToolPolicy, ToolPolicyManager};
26use crate::tools::ast_grep::AstGrepEngine;
27use crate::tools::grep_search::GrepSearchManager;
28use anyhow::{Result, anyhow};
29use serde_json::Value;
30use std::collections::{HashMap, HashSet};
31use std::path::PathBuf;
32use std::sync::Arc;
33use std::sync::atomic::AtomicUsize;
34use tracing::{debug, warn};
35
36use super::bash_tool::BashTool;
37use super::command::CommandTool;
38use super::curl_tool::CurlTool;
39use super::file_ops::FileOpsTool;
40use super::plan::PlanManager;
41use super::search::SearchTool;
42use super::simple_search::SimpleSearchTool;
43use super::srgn::SrgnTool;
44use crate::mcp_client::{McpClient, McpToolExecutor, McpToolInfo};
45
46#[cfg(test)]
47use super::traits::Tool;
48#[cfg(test)]
49use crate::config::types::CapabilityLevel;
50
/// Registry of every tool the agent can invoke, plus the policy/approval state consulted
/// before each execution.
///
/// Tools registered locally are looked up by name in `tool_lookup`; names that are not
/// registered fall back to the optional MCP client.
#[derive(Clone)]
pub struct ToolRegistry {
    /// Workspace root every built-in tool is constructed against.
    workspace_root: PathBuf,
    /// Built-in search tool (shares `grep_search`).
    search_tool: SearchTool,
    /// Built-in simple search tool.
    simple_search_tool: SimpleSearchTool,
    /// Built-in bash tool.
    bash_tool: BashTool,
    /// Built-in file operations tool (shares `grep_search`).
    file_ops_tool: FileOpsTool,
    /// Built-in command tool.
    command_tool: CommandTool,
    /// Built-in curl tool.
    curl_tool: CurlTool,
    /// Grep manager shared by `search_tool` and `file_ops_tool`.
    grep_search: Arc<GrepSearchManager>,
    /// AST-grep engine; `None` if it failed to initialise at construction time.
    ast_grep_engine: Option<Arc<AstGrepEngine>>,
    /// Tool policy manager; `None` if it failed to initialise at construction time.
    tool_policy: Option<ToolPolicyManager>,
    /// PTY settings supplied at construction.
    pty_config: PtyConfig,
    /// Count of active PTY sessions, starting at 0 (presumably maintained by
    /// `start_pty_session` / `end_pty_session` — verify in the `pty` module).
    active_pty_sessions: Arc<AtomicUsize>,
    /// Built-in srgn tool.
    srgn_tool: SrgnTool,
    /// Plan manager handed out (cloned) to callers.
    plan_manager: PlanManager,
    /// Optional MCP client used for tool names that are not registered locally.
    mcp_client: Option<Arc<McpClient>>,
    /// MCP provider name -> sorted, de-duplicated tool names; rebuilt by `refresh_mcp_tools`.
    mcp_tool_index: HashMap<String, Vec<String>>,
    /// Registered tools in registration order.
    tool_registrations: Vec<ToolRegistration>,
    /// Tool name -> index into `tool_registrations`.
    tool_lookup: HashMap<&'static str, usize>,
    /// Names whose next `execute_tool` call skips the policy check; each entry is consumed
    /// by that call.
    preapproved_tools: HashSet<String>,
    /// When `Some`, only these tool names may execute (full-auto mode).
    full_auto_allowlist: Option<HashSet<String>>,
}
74
/// Outcome of evaluating whether a tool may run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolPermissionDecision {
    /// The tool may run; the registry also records it as pre-approved for its next execution.
    Allow,
    /// The tool must not run.
    Deny,
    /// The effective policy is `Prompt`: confirmation is required before running.
    Prompt,
}
81
82impl ToolRegistry {
83 pub fn new(workspace_root: PathBuf) -> Self {
84 Self::new_with_config(workspace_root, PtyConfig::default())
85 }
86
87 pub fn new_with_config(workspace_root: PathBuf, pty_config: PtyConfig) -> Self {
88 let grep_search = Arc::new(GrepSearchManager::new(workspace_root.clone()));
89
90 let search_tool = SearchTool::new(workspace_root.clone(), grep_search.clone());
91 let simple_search_tool = SimpleSearchTool::new(workspace_root.clone());
92 let bash_tool = BashTool::new(workspace_root.clone());
93 let file_ops_tool = FileOpsTool::new(workspace_root.clone(), grep_search.clone());
94 let command_tool = CommandTool::new(workspace_root.clone());
95 let curl_tool = CurlTool::new();
96 let srgn_tool = SrgnTool::new(workspace_root.clone());
97 let plan_manager = PlanManager::new();
98
99 let ast_grep_engine = match AstGrepEngine::new() {
100 Ok(engine) => Some(Arc::new(engine)),
101 Err(err) => {
102 eprintln!("Warning: Failed to initialize AST-grep engine: {}", err);
103 None
104 }
105 };
106
107 let policy_manager = match ToolPolicyManager::new_with_workspace(&workspace_root) {
108 Ok(manager) => Some(manager),
109 Err(err) => {
110 eprintln!("Warning: Failed to initialize tool policy manager: {}", err);
111 None
112 }
113 };
114
115 let mut registry = Self {
116 workspace_root,
117 search_tool,
118 simple_search_tool,
119 bash_tool,
120 file_ops_tool,
121 command_tool,
122 curl_tool,
123 grep_search,
124 ast_grep_engine,
125 tool_policy: policy_manager,
126 pty_config,
127 active_pty_sessions: Arc::new(AtomicUsize::new(0)),
128 srgn_tool,
129 plan_manager,
130 mcp_client: None,
131 mcp_tool_index: HashMap::new(),
132 tool_registrations: Vec::new(),
133 tool_lookup: HashMap::new(),
134 preapproved_tools: HashSet::new(),
135 full_auto_allowlist: None,
136 };
137
138 register_builtin_tools(&mut registry);
139 registry
140 }
141
142 pub fn register_tool(&mut self, registration: ToolRegistration) -> Result<()> {
143 if self.tool_lookup.contains_key(registration.name()) {
144 return Err(anyhow!(format!(
145 "Tool '{}' is already registered",
146 registration.name()
147 )));
148 }
149
150 let index = self.tool_registrations.len();
151 self.tool_lookup.insert(registration.name(), index);
152 self.tool_registrations.push(registration);
153 Ok(())
154 }
155
156 pub fn available_tools(&self) -> Vec<String> {
157 self.tool_registrations
158 .iter()
159 .map(|registration| registration.name().to_string())
160 .collect()
161 }
162
163 fn mcp_policy_keys(&self) -> Vec<String> {
164 let mut keys = Vec::new();
165 for (provider, tools) in &self.mcp_tool_index {
166 for tool in tools {
167 keys.push(format!("mcp::{}::{}", provider, tool));
168 }
169 }
170 keys
171 }
172
173 fn find_mcp_provider(&self, tool_name: &str) -> Option<String> {
174 for (provider, tools) in &self.mcp_tool_index {
175 if tools.iter().any(|candidate| candidate == tool_name) {
176 return Some(provider.clone());
177 }
178 }
179 None
180 }
181
182 pub fn enable_full_auto_mode(&mut self, allowed_tools: &[String]) {
183 let mut normalized: HashSet<String> = HashSet::new();
184 if allowed_tools
185 .iter()
186 .any(|tool| tool.trim() == tools::WILDCARD_ALL)
187 {
188 for tool in self.available_tools() {
189 normalized.insert(tool);
190 }
191 } else {
192 for tool in allowed_tools {
193 let trimmed = tool.trim();
194 if !trimmed.is_empty() {
195 normalized.insert(trimmed.to_string());
196 }
197 }
198 }
199
200 self.full_auto_allowlist = Some(normalized);
201 }
202
203 pub fn current_full_auto_allowlist(&self) -> Option<Vec<String>> {
204 self.full_auto_allowlist.as_ref().map(|set| {
205 let mut items: Vec<String> = set.iter().cloned().collect();
206 items.sort();
207 items
208 })
209 }
210
211 pub fn has_tool(&self, name: &str) -> bool {
212 self.tool_lookup.contains_key(name)
213 }
214
215 pub fn with_ast_grep(mut self, engine: Arc<AstGrepEngine>) -> Self {
216 self.ast_grep_engine = Some(engine);
217 self
218 }
219
220 pub fn workspace_root(&self) -> &PathBuf {
221 &self.workspace_root
222 }
223
224 pub fn plan_manager(&self) -> PlanManager {
225 self.plan_manager.clone()
226 }
227
228 pub fn current_plan(&self) -> crate::tools::TaskPlan {
229 self.plan_manager.snapshot()
230 }
231
232 pub async fn initialize_async(&mut self) -> Result<()> {
233 Ok(())
234 }
235
236 pub fn apply_config_policies(&mut self, tools_config: &ToolsConfig) -> Result<()> {
237 if let Ok(policy_manager) = self.policy_manager_mut() {
238 policy_manager.apply_tools_config(tools_config)?;
239 }
240
241 Ok(())
242 }
243
244 pub async fn execute_tool(&mut self, name: &str, args: Value) -> Result<Value> {
245 if let Some(allowlist) = &self.full_auto_allowlist
246 && !allowlist.contains(name)
247 {
248 let error = ToolExecutionError::new(
249 name.to_string(),
250 ToolErrorType::PolicyViolation,
251 format!(
252 "Tool '{}' is not permitted while full-auto mode is active",
253 name
254 ),
255 );
256 return Ok(error.to_json_value());
257 }
258
259 let skip_policy_prompt = self.preapproved_tools.remove(name);
260
261 if !skip_policy_prompt
262 && let Ok(policy_manager) = self.policy_manager_mut()
263 && !policy_manager.should_execute_tool(name)?
264 {
265 let error = ToolExecutionError::new(
266 name.to_string(),
267 ToolErrorType::PolicyViolation,
268 format!("Tool '{}' execution denied by policy", name),
269 );
270 return Ok(error.to_json_value());
271 }
272
273 let args = match self.apply_policy_constraints(name, args) {
274 Ok(args) => args,
275 Err(err) => {
276 let error = ToolExecutionError::with_original_error(
277 name.to_string(),
278 ToolErrorType::InvalidParameters,
279 "Failed to apply policy constraints".to_string(),
280 err.to_string(),
281 );
282 return Ok(error.to_json_value());
283 }
284 };
285
286 let registration = match self
287 .tool_lookup
288 .get(name)
289 .and_then(|index| self.tool_registrations.get(*index))
290 {
291 Some(registration) => registration,
292 None => {
293 if let Some(mcp_client) = &self.mcp_client {
295 if name.starts_with("mcp_") {
297 let actual_tool_name = &name[4..]; match mcp_client.has_mcp_tool(actual_tool_name).await {
299 Ok(true) => {
300 debug!(
301 "MCP tool '{}' found, executing via MCP client",
302 actual_tool_name
303 );
304 return self.execute_mcp_tool(actual_tool_name, args).await;
305 }
306 Ok(false) => {
307 if let Some(resolved_name) =
308 self.resolve_mcp_tool_alias(actual_tool_name).await
309 {
310 if resolved_name != actual_tool_name {
311 debug!(
312 "Resolved MCP tool alias '{}' to '{}'",
313 actual_tool_name, resolved_name
314 );
315 return self.execute_mcp_tool(&resolved_name, args).await;
316 }
317 }
318
319 let error = ToolExecutionError::new(
321 name.to_string(),
322 ToolErrorType::ToolNotFound,
323 format!("Unknown MCP tool: {}", actual_tool_name),
324 );
325 return Ok(error.to_json_value());
326 }
327 Err(e) => {
328 warn!(
329 "Error checking MCP tool availability for '{}': {}",
330 actual_tool_name, e
331 );
332 let error = ToolExecutionError::with_original_error(
333 name.to_string(),
334 ToolErrorType::ExecutionError,
335 format!(
336 "Failed to verify MCP tool '{}' due to provider errors",
337 actual_tool_name
338 ),
339 e.to_string(),
340 );
341 return Ok(error.to_json_value());
342 }
343 }
344 } else {
345 match mcp_client.has_mcp_tool(name).await {
347 Ok(true) => {
348 debug!(
349 "Tool '{}' not found in registry, delegating to MCP client",
350 name
351 );
352 return self.execute_mcp_tool(name, args).await;
353 }
354 Ok(false) => {
355 let error = ToolExecutionError::new(
357 name.to_string(),
358 ToolErrorType::ToolNotFound,
359 format!("Unknown tool: {}", name),
360 );
361 return Ok(error.to_json_value());
362 }
363 Err(e) => {
364 warn!("Error checking MCP tool availability for '{}': {}", name, e);
365 let error = ToolExecutionError::with_original_error(
366 name.to_string(),
367 ToolErrorType::ExecutionError,
368 format!(
369 "Failed to verify MCP tool '{}' due to provider errors",
370 name
371 ),
372 e.to_string(),
373 );
374 return Ok(error.to_json_value());
375 }
376 }
377 }
378 } else {
379 let error = ToolExecutionError::new(
381 name.to_string(),
382 ToolErrorType::ToolNotFound,
383 format!("Unknown tool: {}", name),
384 );
385 return Ok(error.to_json_value());
386 }
387 }
388 };
389
390 let uses_pty = registration.uses_pty();
391 if uses_pty && let Err(err) = self.start_pty_session() {
392 let error = ToolExecutionError::with_original_error(
393 name.to_string(),
394 ToolErrorType::ExecutionError,
395 "Failed to start PTY session".to_string(),
396 err.to_string(),
397 );
398 return Ok(error.to_json_value());
399 }
400
401 let handler = registration.handler();
402 let result = match handler {
403 ToolHandler::RegistryFn(executor) => executor(self, args).await,
404 ToolHandler::TraitObject(tool) => tool.execute(args).await,
405 };
406
407 if uses_pty {
408 self.end_pty_session();
409 }
410
411 match result {
412 Ok(value) => Ok(normalize_tool_output(value)),
413 Err(err) => {
414 let error_type = classify_error(&err);
415 let error = ToolExecutionError::with_original_error(
416 name.to_string(),
417 error_type,
418 format!("Tool execution failed: {}", err),
419 err.to_string(),
420 );
421 Ok(error.to_json_value())
422 }
423 }
424 }
425
426 pub fn with_mcp_client(mut self, mcp_client: Arc<McpClient>) -> Self {
428 self.mcp_client = Some(mcp_client);
429 self
430 }
431
432 pub fn mcp_client(&self) -> Option<&Arc<McpClient>> {
434 self.mcp_client.as_ref()
435 }
436
437 pub async fn list_mcp_tools(&self) -> Result<Vec<McpToolInfo>> {
439 if let Some(mcp_client) = &self.mcp_client {
440 mcp_client.list_mcp_tools().await
441 } else {
442 Ok(Vec::new())
443 }
444 }
445
446 pub async fn has_mcp_tool(&self, tool_name: &str) -> bool {
448 if let Some(mcp_client) = &self.mcp_client {
449 match mcp_client.has_mcp_tool(tool_name).await {
450 Ok(true) => true,
451 Ok(false) => false,
452 Err(_) => {
453 false
455 }
456 }
457 } else {
458 false
459 }
460 }
461
462 pub async fn execute_mcp_tool(&self, tool_name: &str, args: Value) -> Result<Value> {
464 if let Some(mcp_client) = &self.mcp_client {
465 mcp_client.execute_mcp_tool(tool_name, args).await
466 } else {
467 Err(anyhow::anyhow!("MCP client not available"))
468 }
469 }
470
471 async fn resolve_mcp_tool_alias(&self, tool_name: &str) -> Option<String> {
472 let Some(mcp_client) = &self.mcp_client else {
473 return None;
474 };
475
476 let normalized = normalize_mcp_tool_identifier(tool_name);
477 if normalized.is_empty() {
478 return None;
479 }
480
481 let tools = match mcp_client.list_mcp_tools().await {
482 Ok(list) => list,
483 Err(err) => {
484 warn!(
485 "Failed to list MCP tools while resolving alias '{}': {}",
486 tool_name, err
487 );
488 return None;
489 }
490 };
491
492 for tool in tools {
493 if normalize_mcp_tool_identifier(&tool.name) == normalized {
494 return Some(tool.name);
495 }
496 }
497
498 None
499 }
500
501 pub async fn refresh_mcp_tools(&mut self) -> Result<()> {
503 if let Some(mcp_client) = &self.mcp_client {
504 debug!(
505 "Refreshing MCP tools for {} providers",
506 mcp_client.get_status().provider_count
507 );
508
509 let tools = mcp_client.list_mcp_tools().await?;
510 let mut provider_map: HashMap<String, Vec<String>> = HashMap::new();
511
512 for tool in tools {
513 provider_map
514 .entry(tool.provider.clone())
515 .or_default()
516 .push(tool.name.clone());
517 }
518
519 for tools in provider_map.values_mut() {
520 tools.sort();
521 tools.dedup();
522 }
523
524 self.mcp_tool_index = provider_map;
525
526 if let Some(policy_manager) = self.tool_policy.as_mut() {
527 policy_manager.update_mcp_tools(&self.mcp_tool_index)?;
528 let allowlist = policy_manager.mcp_allowlist().clone();
529 mcp_client.update_allowlist(allowlist);
530 }
531
532 self.sync_policy_available_tools();
533 Ok(())
534 } else {
535 debug!("No MCP client configured, nothing to refresh");
536 Ok(())
537 }
538 }
539}
540
541impl ToolRegistry {
542 pub fn preflight_tool_permission(&mut self, name: &str) -> Result<bool> {
544 match self.evaluate_tool_policy(name)? {
545 ToolPermissionDecision::Allow => Ok(true),
546 ToolPermissionDecision::Deny => Ok(false),
547 ToolPermissionDecision::Prompt => Ok(true),
548 }
549 }
550
551 pub fn evaluate_tool_policy(&mut self, name: &str) -> Result<ToolPermissionDecision> {
552 if let Some(tool_name) = name.strip_prefix("mcp_") {
553 return self.evaluate_mcp_tool_policy(name, tool_name);
554 }
555
556 if let Some(allowlist) = self.full_auto_allowlist.as_ref() {
557 if !allowlist.contains(name) {
558 return Ok(ToolPermissionDecision::Deny);
559 }
560
561 if let Some(policy_manager) = self.tool_policy.as_mut() {
562 match policy_manager.get_policy(name) {
563 ToolPolicy::Deny => return Ok(ToolPermissionDecision::Deny),
564 ToolPolicy::Allow | ToolPolicy::Prompt => {
565 self.preapproved_tools.insert(name.to_string());
566 return Ok(ToolPermissionDecision::Allow);
567 }
568 }
569 }
570
571 self.preapproved_tools.insert(name.to_string());
572 return Ok(ToolPermissionDecision::Allow);
573 }
574
575 if let Some(policy_manager) = self.tool_policy.as_mut() {
576 match policy_manager.get_policy(name) {
577 ToolPolicy::Allow => {
578 self.preapproved_tools.insert(name.to_string());
579 Ok(ToolPermissionDecision::Allow)
580 }
581 ToolPolicy::Deny => Ok(ToolPermissionDecision::Deny),
582 ToolPolicy::Prompt => {
583 if ToolPolicyManager::is_auto_allow_tool(name) {
584 policy_manager.set_policy(name, ToolPolicy::Allow)?;
585 self.preapproved_tools.insert(name.to_string());
586 Ok(ToolPermissionDecision::Allow)
587 } else {
588 Ok(ToolPermissionDecision::Prompt)
589 }
590 }
591 }
592 } else {
593 self.preapproved_tools.insert(name.to_string());
594 Ok(ToolPermissionDecision::Allow)
595 }
596 }
597
598 fn evaluate_mcp_tool_policy(
599 &mut self,
600 full_name: &str,
601 tool_name: &str,
602 ) -> Result<ToolPermissionDecision> {
603 let provider = match self.find_mcp_provider(tool_name) {
604 Some(provider) => provider,
605 None => {
606 return Ok(ToolPermissionDecision::Prompt);
608 }
609 };
610
611 if let Some(allowlist) = self.full_auto_allowlist.as_ref() {
612 if !allowlist.contains(full_name) {
613 return Ok(ToolPermissionDecision::Deny);
614 }
615
616 if let Some(policy_manager) = self.tool_policy.as_mut() {
617 match policy_manager.get_mcp_tool_policy(&provider, tool_name) {
618 ToolPolicy::Deny => return Ok(ToolPermissionDecision::Deny),
619 ToolPolicy::Allow | ToolPolicy::Prompt => {
620 self.preapproved_tools.insert(full_name.to_string());
621 return Ok(ToolPermissionDecision::Allow);
622 }
623 }
624 }
625
626 self.preapproved_tools.insert(full_name.to_string());
627 return Ok(ToolPermissionDecision::Allow);
628 }
629
630 if let Some(policy_manager) = self.tool_policy.as_mut() {
631 match policy_manager.get_mcp_tool_policy(&provider, tool_name) {
632 ToolPolicy::Allow => {
633 self.preapproved_tools.insert(full_name.to_string());
634 Ok(ToolPermissionDecision::Allow)
635 }
636 ToolPolicy::Deny => Ok(ToolPermissionDecision::Deny),
637 ToolPolicy::Prompt => Ok(ToolPermissionDecision::Prompt),
638 }
639 } else {
640 self.preapproved_tools.insert(full_name.to_string());
641 Ok(ToolPermissionDecision::Allow)
642 }
643 }
644
645 pub fn mark_tool_preapproved(&mut self, name: &str) {
646 self.preapproved_tools.insert(name.to_string());
647 }
648
649 pub fn persist_mcp_tool_policy(&mut self, name: &str, policy: ToolPolicy) -> Result<()> {
650 if !name.starts_with("mcp_") {
651 return Ok(());
652 }
653
654 let Some(tool_name) = name.strip_prefix("mcp_") else {
655 return Ok(());
656 };
657
658 let Some(provider) = self.find_mcp_provider(tool_name) else {
659 return Ok(());
660 };
661
662 if let Some(manager) = self.tool_policy.as_mut() {
663 manager.set_mcp_tool_policy(&provider, tool_name, policy)?;
664 }
665
666 Ok(())
667 }
668}
669
/// Reduces an MCP tool name to a comparison key.
///
/// Only ASCII letters and digits are kept, and letters are lowercased. Separators such as
/// `-`, `_` and `.` and all non-ASCII characters are dropped, so `"Context7.Lookup"` and
/// `"context7-lookup"` map to the same key.
fn normalize_mcp_tool_identifier(value: &str) -> String {
    value
        .chars()
        .filter(char::is_ascii_alphanumeric)
        .map(|ch| ch.to_ascii_lowercase())
        .collect()
}
679
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use serde_json::json;
    use tempfile::TempDir;

    const CUSTOM_TOOL_NAME: &str = "custom_test_tool";

    /// Minimal tool that echoes its arguments back inside a success envelope.
    struct CustomEchoTool;

    #[async_trait]
    impl Tool for CustomEchoTool {
        async fn execute(&self, args: Value) -> Result<Value> {
            Ok(json!({
                "success": true,
                "args": args,
            }))
        }

        fn name(&self) -> &'static str {
            CUSTOM_TOOL_NAME
        }

        fn description(&self) -> &'static str {
            "Custom echo tool for testing"
        }
    }

    #[tokio::test]
    async fn registers_builtin_tools() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let registry = ToolRegistry::new(temp_dir.path().to_path_buf());
        let available = registry.available_tools();

        assert!(available.contains(&tools::READ_FILE.to_string()));
        assert!(available.contains(&tools::RUN_TERMINAL_CMD.to_string()));
        assert!(available.contains(&tools::CURL.to_string()));
        Ok(())
    }

    #[tokio::test]
    async fn allows_registering_custom_tools() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut registry = ToolRegistry::new(temp_dir.path().to_path_buf());

        registry.register_tool(ToolRegistration::from_tool_instance(
            CUSTOM_TOOL_NAME,
            CapabilityLevel::CodeSearch,
            CustomEchoTool,
        ))?;

        registry.sync_policy_available_tools();

        registry.allow_all_tools().ok();

        let available = registry.available_tools();
        assert!(available.contains(&CUSTOM_TOOL_NAME.to_string()));

        let response = registry
            .execute_tool(CUSTOM_TOOL_NAME, json!({"input": "value"}))
            .await?;
        assert!(response["success"].as_bool().unwrap_or(false));
        Ok(())
    }

    #[tokio::test]
    async fn rejects_duplicate_tool_registration() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut registry = ToolRegistry::new(temp_dir.path().to_path_buf());

        registry.register_tool(ToolRegistration::from_tool_instance(
            CUSTOM_TOOL_NAME,
            CapabilityLevel::CodeSearch,
            CustomEchoTool,
        ))?;
        let duplicate = registry.register_tool(ToolRegistration::from_tool_instance(
            CUSTOM_TOOL_NAME,
            CapabilityLevel::CodeSearch,
            CustomEchoTool,
        ));

        assert!(duplicate.is_err());
        let occurrences = registry
            .available_tools()
            .iter()
            .filter(|name| name.as_str() == CUSTOM_TOOL_NAME)
            .count();
        assert_eq!(occurrences, 1);
        Ok(())
    }

    #[tokio::test]
    async fn full_auto_allowlist_enforced() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut registry = ToolRegistry::new(temp_dir.path().to_path_buf());

        registry.enable_full_auto_mode(&[tools::READ_FILE.to_string()]);

        assert!(registry.preflight_tool_permission(tools::READ_FILE)?);
        assert!(!registry.preflight_tool_permission(tools::RUN_TERMINAL_CMD)?);

        Ok(())
    }

    #[tokio::test]
    async fn full_auto_allowlist_trims_and_skips_blank_entries() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let mut registry = ToolRegistry::new(temp_dir.path().to_path_buf());

        registry.enable_full_auto_mode(&[format!("  {}  ", tools::READ_FILE), "   ".to_string()]);

        assert_eq!(
            registry.current_full_auto_allowlist(),
            Some(vec![tools::READ_FILE.to_string()])
        );
        Ok(())
    }

    #[test]
    fn normalizes_mcp_tool_identifiers() {
        assert_eq!(
            normalize_mcp_tool_identifier("sequential-thinking"),
            "sequentialthinking"
        );
        assert_eq!(
            normalize_mcp_tool_identifier("Context7.Lookup"),
            "context7lookup"
        );
        assert_eq!(normalize_mcp_tool_identifier("alpha_beta"), "alphabeta");
        assert_eq!(normalize_mcp_tool_identifier(""), "");
        assert_eq!(normalize_mcp_tool_identifier("Ünïcode_Tool"), "ncodetool");
    }
}