steer_core/tools/
registry.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use steer_tools::ToolSchema;
5
6use super::backend::ToolBackend;
7use super::capability::Capabilities;
8use super::mcp::McpBackend;
9use super::static_tool::StaticToolErased;
10
11pub struct ToolRegistry {
12 static_tools: HashMap<String, Box<dyn StaticToolErased>>,
13 mcp_backends: Vec<Arc<McpBackend>>,
14}
15
16impl ToolRegistry {
17 pub fn new() -> Self {
18 Self {
19 static_tools: HashMap::new(),
20 mcp_backends: Vec::new(),
21 }
22 }
23
24 pub fn register_static<T: StaticToolErased + 'static>(&mut self, tool: T) {
25 self.static_tools
26 .insert(tool.name().to_string(), Box::new(tool));
27 }
28
29 pub fn register_mcp(&mut self, backend: Arc<McpBackend>) {
30 self.mcp_backends.push(backend);
31 }
32
33 pub async fn available_schemas(&self, available_caps: Capabilities) -> Vec<ToolSchema> {
34 let mut schemas = Vec::new();
35
36 for tool in self.static_tools.values() {
37 if available_caps.satisfies(tool.required_capabilities()) {
38 schemas.push(tool.schema());
39 }
40 }
41
42 for backend in &self.mcp_backends {
43 schemas.extend(backend.get_tool_schemas().await);
44 }
45
46 schemas
47 }
48
49 pub fn static_tool(&self, name: &str) -> Option<&dyn StaticToolErased> {
50 self.static_tools.get(name).map(|b| b.as_ref())
51 }
52
53 pub fn find_mcp_backend(&self, tool_name: &str) -> Option<&Arc<McpBackend>> {
54 self.mcp_backends
55 .iter()
56 .find(|&backend| backend.has_tool(tool_name))
57 }
58
59 pub fn is_static_tool(&self, name: &str) -> bool {
60 self.static_tools.contains_key(name)
61 }
62
63 pub fn static_tool_names(&self) -> Vec<&str> {
64 self.static_tools.keys().map(|s| s.as_str()).collect()
65 }
66
67 pub fn requires_approval(&self, tool_name: &str) -> bool {
68 if let Some(tool) = self.static_tools.get(tool_name) {
69 return tool.requires_approval();
70 }
71 true
72 }
73
74 pub fn required_capabilities(&self, tool_name: &str) -> Option<Capabilities> {
75 self.static_tools
76 .get(tool_name)
77 .map(|t| t.required_capabilities())
78 }
79}
80
81impl Default for ToolRegistry {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::capability::Capabilities;
    use crate::tools::static_tool::{StaticToolContext, StaticToolError};
    use async_trait::async_trait;
    use schemars::JsonSchema;
    use serde::Deserialize;
    use steer_tools::ToolSpec;
    use steer_tools::error::ToolExecutionError;

    /// Parameters shared by both fixture tools.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct TestParams {
        value: String,
    }

    /// Output shared by both fixture tools.
    #[derive(Debug)]
    struct TestOutput {
        result: String,
    }

    impl From<TestOutput> for steer_tools::result::ToolResult {
        fn from(output: TestOutput) -> Self {
            let external = steer_tools::result::ExternalResult {
                tool_name: "test_tool".to_string(),
                payload: output.result,
            };
            steer_tools::result::ToolResult::External(external)
        }
    }

    /// Error type shared by both fixture tool specs.
    #[derive(Debug, Clone, thiserror::Error)]
    #[error("test tool error: {message}")]
    struct TestToolError {
        message: String,
    }

    // ---- Fixture tool requiring `Capabilities::WORKSPACE` ----

    struct TestTool;

    struct TestToolSpec;

    impl ToolSpec for TestToolSpec {
        type Params = TestParams;
        type Result = TestOutput;
        type Error = TestToolError;

        const NAME: &'static str = "test_tool";
        const DISPLAY_NAME: &'static str = "Test Tool";

        fn execution_error(error: Self::Error) -> ToolExecutionError {
            let message = error.to_string();
            ToolExecutionError::External {
                tool_name: Self::NAME.to_string(),
                message,
            }
        }
    }

    #[async_trait]
    impl crate::tools::static_tool::StaticTool for TestTool {
        type Params = TestParams;
        type Output = TestOutput;
        type Spec = TestToolSpec;

        const DESCRIPTION: &'static str = "A test tool";
        const REQUIRES_APPROVAL: bool = false;
        const REQUIRED_CAPABILITIES: Capabilities = Capabilities::WORKSPACE;

        /// Echoes the input value back as the result.
        async fn execute(
            &self,
            params: Self::Params,
            _ctx: &StaticToolContext,
        ) -> Result<Self::Output, StaticToolError<TestToolError>> {
            let TestParams { value } = params;
            Ok(TestOutput { result: value })
        }
    }

    // ---- Fixture tool requiring `Capabilities::AGENT` ----

    struct AgentTool;

    struct AgentToolSpec;

    impl ToolSpec for AgentToolSpec {
        type Params = TestParams;
        type Result = TestOutput;
        type Error = TestToolError;

        const NAME: &'static str = "agent_tool";
        const DISPLAY_NAME: &'static str = "Agent Tool";

        fn execution_error(error: Self::Error) -> ToolExecutionError {
            let message = error.to_string();
            ToolExecutionError::External {
                tool_name: Self::NAME.to_string(),
                message,
            }
        }
    }

    #[async_trait]
    impl crate::tools::static_tool::StaticTool for AgentTool {
        type Params = TestParams;
        type Output = TestOutput;
        type Spec = AgentToolSpec;

        const DESCRIPTION: &'static str = "Needs agent spawner";
        const REQUIRES_APPROVAL: bool = false;
        const REQUIRED_CAPABILITIES: Capabilities = Capabilities::AGENT;

        /// Echoes the input value back as the result.
        async fn execute(
            &self,
            params: Self::Params,
            _ctx: &StaticToolContext,
        ) -> Result<Self::Output, StaticToolError<TestToolError>> {
            let TestParams { value } = params;
            Ok(TestOutput { result: value })
        }
    }

    /// Only static tools whose required capabilities are satisfied are listed.
    #[tokio::test]
    async fn test_capability_filtering() {
        let mut registry = ToolRegistry::new();
        registry.register_static(TestTool);
        registry.register_static(AgentTool);

        let workspace_schemas = registry.available_schemas(Capabilities::WORKSPACE).await;
        assert_eq!(workspace_schemas.len(), 1);
        assert_eq!(workspace_schemas[0].name, "test_tool");

        // With `AGENT` both fixture tools are expected to be listed.
        let agent_schemas = registry.available_schemas(Capabilities::AGENT).await;
        assert_eq!(agent_schemas.len(), 2);
    }

    /// Static tools answer for themselves; unknown names require approval.
    #[test]
    fn test_requires_approval() {
        let mut registry = ToolRegistry::new();
        registry.register_static(TestTool);

        let known = registry.requires_approval("test_tool");
        let unknown = registry.requires_approval("unknown_tool");

        assert!(!known);
        assert!(unknown);
    }
}