// steer_core/tools/registry.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use steer_tools::ToolSchema;
5
6use super::backend::ToolBackend;
7use super::capability::Capabilities;
8use super::mcp::McpBackend;
9use super::static_tool::StaticToolErased;
10
11pub struct ToolRegistry {
12    static_tools: HashMap<String, Box<dyn StaticToolErased>>,
13    mcp_backends: Vec<Arc<McpBackend>>,
14}
15
16impl ToolRegistry {
17    pub fn new() -> Self {
18        Self {
19            static_tools: HashMap::new(),
20            mcp_backends: Vec::new(),
21        }
22    }
23
24    pub fn register_static<T: StaticToolErased + 'static>(&mut self, tool: T) {
25        self.static_tools
26            .insert(tool.name().to_string(), Box::new(tool));
27    }
28
29    pub fn register_mcp(&mut self, backend: Arc<McpBackend>) {
30        self.mcp_backends.push(backend);
31    }
32
33    pub async fn available_schemas(&self, available_caps: Capabilities) -> Vec<ToolSchema> {
34        let mut schemas = Vec::new();
35
36        for tool in self.static_tools.values() {
37            if available_caps.satisfies(tool.required_capabilities()) {
38                schemas.push(tool.schema());
39            }
40        }
41
42        for backend in &self.mcp_backends {
43            schemas.extend(backend.get_tool_schemas().await);
44        }
45
46        schemas
47    }
48
49    pub fn static_tool(&self, name: &str) -> Option<&dyn StaticToolErased> {
50        self.static_tools.get(name).map(|b| b.as_ref())
51    }
52
53    pub fn find_mcp_backend(&self, tool_name: &str) -> Option<&Arc<McpBackend>> {
54        self.mcp_backends
55            .iter()
56            .find(|&backend| backend.has_tool(tool_name))
57    }
58
59    pub fn is_static_tool(&self, name: &str) -> bool {
60        self.static_tools.contains_key(name)
61    }
62
63    pub fn static_tool_names(&self) -> Vec<&str> {
64        self.static_tools.keys().map(|s| s.as_str()).collect()
65    }
66
67    pub fn requires_approval(&self, tool_name: &str) -> bool {
68        if let Some(tool) = self.static_tools.get(tool_name) {
69            return tool.requires_approval();
70        }
71        true
72    }
73
74    pub fn required_capabilities(&self, tool_name: &str) -> Option<Capabilities> {
75        self.static_tools
76            .get(tool_name)
77            .map(|t| t.required_capabilities())
78    }
79}
80
81impl Default for ToolRegistry {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::capability::Capabilities;
    use crate::tools::static_tool::{StaticTool, StaticToolContext, StaticToolError};
    use async_trait::async_trait;
    use schemars::JsonSchema;
    use serde::Deserialize;
    use steer_tools::ToolSpec;
    use steer_tools::error::ToolExecutionError;

    // Shared parameter/output/error fixtures for the two test tools below.

    #[derive(Debug, Deserialize, JsonSchema)]
    struct TestParams {
        value: String,
    }

    #[derive(Debug)]
    struct TestOutput {
        result: String,
    }

    impl From<TestOutput> for steer_tools::result::ToolResult {
        fn from(output: TestOutput) -> Self {
            steer_tools::result::ToolResult::External(steer_tools::result::ExternalResult {
                tool_name: "test_tool".to_string(),
                payload: output.result,
            })
        }
    }

    #[derive(Debug, Clone, thiserror::Error)]
    #[error("test tool error: {message}")]
    struct TestToolError {
        message: String,
    }

    // `test_tool`: requires only WORKSPACE and needs no approval.

    struct TestTool;

    struct TestToolSpec;

    impl ToolSpec for TestToolSpec {
        type Params = TestParams;
        type Result = TestOutput;
        type Error = TestToolError;

        const NAME: &'static str = "test_tool";
        const DISPLAY_NAME: &'static str = "Test Tool";

        fn execution_error(error: Self::Error) -> ToolExecutionError {
            ToolExecutionError::External {
                tool_name: Self::NAME.to_string(),
                message: error.to_string(),
            }
        }
    }

    #[async_trait]
    impl StaticTool for TestTool {
        type Params = TestParams;
        type Output = TestOutput;
        type Spec = TestToolSpec;

        const DESCRIPTION: &'static str = "A test tool";
        const REQUIRES_APPROVAL: bool = false;
        const REQUIRED_CAPABILITIES: Capabilities = Capabilities::WORKSPACE;

        async fn execute(
            &self,
            params: Self::Params,
            _ctx: &StaticToolContext,
        ) -> Result<Self::Output, StaticToolError<TestToolError>> {
            Ok(TestOutput {
                result: params.value,
            })
        }
    }

    // `agent_tool`: requires the AGENT capability.

    struct AgentTool;

    struct AgentToolSpec;

    impl ToolSpec for AgentToolSpec {
        type Params = TestParams;
        type Result = TestOutput;
        type Error = TestToolError;

        const NAME: &'static str = "agent_tool";
        const DISPLAY_NAME: &'static str = "Agent Tool";

        fn execution_error(error: Self::Error) -> ToolExecutionError {
            ToolExecutionError::External {
                tool_name: Self::NAME.to_string(),
                message: error.to_string(),
            }
        }
    }

    #[async_trait]
    impl StaticTool for AgentTool {
        type Params = TestParams;
        type Output = TestOutput;
        type Spec = AgentToolSpec;

        const DESCRIPTION: &'static str = "Needs agent spawner";
        const REQUIRES_APPROVAL: bool = false;
        const REQUIRED_CAPABILITIES: Capabilities = Capabilities::AGENT;

        async fn execute(
            &self,
            params: Self::Params,
            _ctx: &StaticToolContext,
        ) -> Result<Self::Output, StaticToolError<TestToolError>> {
            Ok(TestOutput {
                result: params.value,
            })
        }
    }

    #[tokio::test]
    async fn test_capability_filtering() {
        let mut registry = ToolRegistry::new();
        registry.register_static(TestTool);
        registry.register_static(AgentTool);

        let schemas = registry.available_schemas(Capabilities::WORKSPACE).await;
        assert_eq!(schemas.len(), 1);
        assert_eq!(schemas[0].name, "test_tool");

        let schemas = registry.available_schemas(Capabilities::AGENT).await;
        assert_eq!(schemas.len(), 2);
    }

    #[tokio::test]
    async fn test_empty_registry() {
        let registry = ToolRegistry::default();

        assert!(registry
            .available_schemas(Capabilities::WORKSPACE)
            .await
            .is_empty());
        assert!(registry.static_tool_names().is_empty());
        assert!(registry.find_mcp_backend("test_tool").is_none());
        // Unknown tools must fall back to requiring approval.
        assert!(registry.requires_approval("test_tool"));
    }

    #[test]
    fn test_requires_approval() {
        let mut registry = ToolRegistry::new();
        registry.register_static(TestTool);

        assert!(!registry.requires_approval("test_tool"));
        assert!(registry.requires_approval("unknown_tool"));
    }

    #[test]
    fn test_static_tool_lookup() {
        let mut registry = ToolRegistry::new();
        registry.register_static(TestTool);
        registry.register_static(AgentTool);

        assert!(registry.is_static_tool("test_tool"));
        assert!(!registry.is_static_tool("unknown_tool"));

        assert_eq!(
            registry
                .static_tool("test_tool")
                .map(|tool| tool.name().to_string()),
            Some("test_tool".to_string())
        );
        assert!(registry.static_tool("unknown_tool").is_none());

        assert!(registry.required_capabilities("agent_tool").is_some());
        assert!(registry.required_capabilities("unknown_tool").is_none());

        // Sort locally so this test does not depend on the registry's ordering.
        let mut names = registry.static_tool_names();
        names.sort_unstable();
        assert_eq!(names, vec!["agent_tool", "test_tool"]);
    }

    #[test]
    fn test_register_static_replaces_same_name() {
        let mut registry = ToolRegistry::new();
        registry.register_static(TestTool);
        registry.register_static(TestTool);

        assert_eq!(registry.static_tool_names(), vec!["test_tool"]);
    }
}
227}