steer_core/tools/
backend.rs1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use crate::api::ToolCall;
6use crate::tools::ExecutionContext;
7use steer_tools::ToolSchema;
8use steer_tools::{ToolError, result::ToolResult};
9
/// Descriptive information about a tool backend.
///
/// Values are normally created with [`BackendMetadata::new`] and then refined
/// with the chainable `with_*` methods.
#[derive(Clone, Debug)]
pub struct BackendMetadata {
    /// Name of the backend.
    pub name: String,
    /// Free-form label describing what kind of backend this is.
    pub backend_type: String,
    /// Optional location string; `None` until set via `with_location`.
    pub location: Option<String>,
    /// Arbitrary extra key/value pairs attached to the backend.
    pub additional_info: HashMap<String, String>,
}

impl BackendMetadata {
    /// Builds metadata carrying only a name and a backend type.
    ///
    /// `location` starts as `None` and `additional_info` starts empty.
    pub fn new(name: String, backend_type: String) -> Self {
        let location = None;
        let additional_info = HashMap::new();
        Self {
            name,
            backend_type,
            location,
            additional_info,
        }
    }

    /// Returns the metadata with `location` set, replacing any previous value.
    pub fn with_location(self, location: String) -> Self {
        Self {
            location: Some(location),
            ..self
        }
    }

    /// Returns the metadata with `key` mapped to `value` in `additional_info`.
    ///
    /// An existing entry under the same key is overwritten.
    pub fn with_info(self, key: String, value: String) -> Self {
        let Self {
            name,
            backend_type,
            location,
            mut additional_info,
        } = self;
        additional_info.insert(key, value);
        Self {
            name,
            backend_type,
            location,
            additional_info,
        }
    }
}
39
40#[async_trait]
46pub trait ToolBackend: Send + Sync {
47 async fn execute(
56 &self,
57 tool_call: &ToolCall,
58 context: &ExecutionContext,
59 ) -> Result<ToolResult, ToolError>;
60
61 async fn supported_tools(&self) -> Vec<String>;
66
67 async fn get_tool_schemas(&self) -> Vec<ToolSchema>;
72
73 fn metadata(&self) -> BackendMetadata;
78
79 async fn health_check(&self) -> bool {
84 true
85 }
86
87 async fn requires_approval(&self, _tool_name: &str) -> Result<bool, ToolError> {
98 Ok(true)
101 }
102}
103
104pub struct BackendRegistry {
110 backends: Vec<(String, Arc<dyn ToolBackend>)>,
111 tool_mapping: HashMap<String, Arc<dyn ToolBackend>>,
112}
113
114impl BackendRegistry {
115 pub fn new() -> Self {
117 Self {
118 backends: Vec::new(),
119 tool_mapping: HashMap::new(),
120 }
121 }
122
123 pub async fn register(&mut self, name: String, backend: Arc<dyn ToolBackend>) {
133 for tool_name in backend.supported_tools().await {
135 self.tool_mapping
136 .insert(tool_name.to_string(), backend.clone());
137 }
138 self.backends.push((name, backend));
139 }
140
141 pub fn get_backend_for_tool(&self, tool_name: &str) -> Option<&Arc<dyn ToolBackend>> {
149 self.tool_mapping.get(tool_name)
150 }
151
152 pub fn backends(&self) -> &Vec<(String, Arc<dyn ToolBackend>)> {
156 &self.backends
157 }
158
159 pub fn tool_mappings(&self) -> &HashMap<String, Arc<dyn ToolBackend>> {
163 &self.tool_mapping
164 }
165
166 pub async fn supported_tools(&self) -> Vec<String> {
170 self.tool_mapping.keys().cloned().collect()
171 }
172
173 pub fn unregister(&mut self, name: &str) -> bool {
178 if let Some(pos) = self.backends.iter().position(|(n, _)| n == name) {
179 let (_, backend) = self.backends.remove(pos);
180
181 self.tool_mapping
183 .retain(|_tool, mapped_backend| !Arc::ptr_eq(mapped_backend, &backend));
184
185 true
186 } else {
187 false
188 }
189 }
190
191 pub fn clear(&mut self) {
193 self.backends.clear();
194 self.tool_mapping.clear();
195 }
196
197 pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
202 let futures = self
203 .backends
204 .iter()
205 .map(|(_, backend)| backend.get_tool_schemas());
206 let all_schemas = futures::future::join_all(futures).await;
207
208 let mut all_tools = Vec::new();
209 for schemas in all_schemas {
210 all_tools.extend(schemas);
211 }
212
213 all_tools
214 }
215}
216
217impl Default for BackendRegistry {
218 fn default() -> Self {
219 Self::new()
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ToolCall;
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    /// Test double that advertises a fixed tool list and echoes every call.
    struct MockBackend {
        name: String,
        tools: Vec<&'static str>,
    }

    /// Builds a shared `MockBackend` with the given name and tool list.
    fn mock(name: &str, tools: Vec<&'static str>) -> Arc<MockBackend> {
        Arc::new(MockBackend {
            name: name.to_string(),
            tools,
        })
    }

    #[async_trait]
    impl ToolBackend for MockBackend {
        async fn execute(
            &self,
            tool_call: &ToolCall,
            _context: &ExecutionContext,
        ) -> Result<ToolResult, ToolError> {
            let tool_name = self.name.clone();
            let payload = format!("Mock execution of {} by {}", tool_call.name, self.name);
            Ok(ToolResult::External(steer_tools::result::ExternalResult {
                tool_name,
                payload,
            }))
        }

        async fn supported_tools(&self) -> Vec<String> {
            self.tools.iter().copied().map(String::from).collect()
        }

        async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
            Vec::new()
        }

        fn metadata(&self) -> BackendMetadata {
            BackendMetadata::new(self.name.clone(), "Mock".to_string())
        }
    }

    #[tokio::test]
    async fn test_backend_registry() {
        let mut registry = BackendRegistry::new();

        let first = mock("backend1", vec!["tool1", "tool2"]);
        let second = mock("backend2", vec!["tool3", "tool4"]);

        registry
            .register("backend1".to_string(), first.clone())
            .await;
        registry
            .register("backend2".to_string(), second.clone())
            .await;

        // Tool lookup resolves registered tools and rejects unknown ones.
        assert!(registry.get_backend_for_tool("tool1").is_some());
        assert!(registry.get_backend_for_tool("tool3").is_some());
        assert!(registry.get_backend_for_tool("unknown_tool").is_none());

        // Every tool from both backends is reported.
        let supported = registry.supported_tools().await;
        assert_eq!(supported.len(), 4);
        assert!(supported.iter().any(|tool| tool == "tool1"));
        assert!(supported.iter().any(|tool| tool == "tool4"));

        // Unregistering succeeds for a known name only.
        assert!(registry.unregister("backend1"));
        assert!(!registry.unregister("nonexistent"));

        // Only the removed backend's tools disappear.
        assert!(registry.get_backend_for_tool("tool1").is_none());
        assert!(registry.get_backend_for_tool("tool3").is_some());
    }

    #[tokio::test]
    async fn test_mock_backend_execution() {
        let backend = MockBackend {
            name: "test".to_string(),
            tools: vec!["test_tool"],
        };

        let tool_call = ToolCall {
            id: "test_id".to_string(),
            name: "test_tool".to_string(),
            parameters: json!({}),
        };

        let context = ExecutionContext::new(
            "session".to_string(),
            "operation".to_string(),
            "tool_call".to_string(),
            CancellationToken::new(),
        );

        let result = backend.execute(&tool_call, &context).await.unwrap();
        if let ToolResult::External(external) = result {
            for expected in ["Mock execution", "test_tool", "test"] {
                assert!(external.payload.contains(expected));
            }
        } else {
            unreachable!("External result");
        }
    }
}