steer_core/tools/
backend.rs1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use crate::tools::ExecutionContext;
6use steer_tools::{ToolCall, ToolError, ToolSchema, result::ToolResult};
7
/// Descriptive information about a tool backend.
///
/// Built with [`BackendMetadata::new`] and optionally extended through the
/// chaining `with_*` methods, each of which consumes and returns the value.
#[derive(Debug, Clone)]
pub struct BackendMetadata {
    /// Name of the backend.
    pub name: String,
    /// Free-form label for the kind of backend.
    pub backend_type: String,
    /// Optional location string; `None` until set with `with_location`.
    pub location: Option<String>,
    /// Arbitrary extra key/value details; empty until `with_info` is used.
    pub additional_info: HashMap<String, String>,
}

impl BackendMetadata {
    /// Creates metadata carrying only a name and a backend type.
    ///
    /// `location` starts as `None` and `additional_info` starts empty.
    pub fn new(name: String, backend_type: String) -> Self {
        let location = None;
        let additional_info = HashMap::new();
        Self {
            name,
            backend_type,
            location,
            additional_info,
        }
    }

    /// Returns the metadata with `location` set, replacing any earlier value.
    pub fn with_location(self, location: String) -> Self {
        Self {
            location: Some(location),
            ..self
        }
    }

    /// Returns the metadata with `key` mapped to `value` in `additional_info`.
    ///
    /// A value previously stored under the same key is overwritten.
    pub fn with_info(self, key: String, value: String) -> Self {
        let mut updated = self;
        updated.additional_info.insert(key, value);
        updated
    }
}
37
38#[async_trait]
44pub trait ToolBackend: Send + Sync {
45 async fn execute(
54 &self,
55 tool_call: &ToolCall,
56 context: &ExecutionContext,
57 ) -> Result<ToolResult, ToolError>;
58
59 async fn supported_tools(&self) -> Vec<String>;
64
65 async fn get_tool_schemas(&self) -> Vec<ToolSchema>;
70
71 fn metadata(&self) -> BackendMetadata;
76
77 async fn health_check(&self) -> bool {
82 true
83 }
84
85 async fn requires_approval(&self, _tool_name: &str) -> Result<bool, ToolError> {
96 Ok(true)
99 }
100}
101
102pub struct BackendRegistry {
108 backends: Vec<(String, Arc<dyn ToolBackend>)>,
109 tool_mapping: HashMap<String, Arc<dyn ToolBackend>>,
110}
111
112impl BackendRegistry {
113 pub fn new() -> Self {
115 Self {
116 backends: Vec::new(),
117 tool_mapping: HashMap::new(),
118 }
119 }
120
121 pub async fn register(&mut self, name: String, backend: Arc<dyn ToolBackend>) {
131 for tool_name in backend.supported_tools().await {
133 self.tool_mapping.insert(tool_name.clone(), backend.clone());
134 }
135 self.backends.push((name, backend));
136 }
137
138 pub fn get_backend_for_tool(&self, tool_name: &str) -> Option<&Arc<dyn ToolBackend>> {
146 self.tool_mapping.get(tool_name)
147 }
148
149 pub fn backends(&self) -> &Vec<(String, Arc<dyn ToolBackend>)> {
153 &self.backends
154 }
155
156 pub fn tool_mappings(&self) -> &HashMap<String, Arc<dyn ToolBackend>> {
160 &self.tool_mapping
161 }
162
163 pub async fn supported_tools(&self) -> Vec<String> {
167 self.tool_mapping.keys().cloned().collect()
168 }
169
170 pub fn unregister(&mut self, name: &str) -> bool {
175 if let Some(pos) = self.backends.iter().position(|(n, _)| n == name) {
176 let (_, backend) = self.backends.remove(pos);
177
178 self.tool_mapping
180 .retain(|_tool, mapped_backend| !Arc::ptr_eq(mapped_backend, &backend));
181
182 true
183 } else {
184 false
185 }
186 }
187
188 pub fn clear(&mut self) {
190 self.backends.clear();
191 self.tool_mapping.clear();
192 }
193
194 pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
199 let futures = self
200 .backends
201 .iter()
202 .map(|(_, backend)| backend.get_tool_schemas());
203 let all_schemas = futures::future::join_all(futures).await;
204
205 let mut all_tools = Vec::new();
206 for schemas in all_schemas {
207 all_tools.extend(schemas);
208 }
209
210 all_tools
211 }
212}
213
214impl Default for BackendRegistry {
215 fn default() -> Self {
216 Self::new()
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    /// Backend stub that advertises a fixed tool list and answers every call
    /// with a text payload naming the tool and the backend.
    struct MockBackend {
        name: String,
        tools: Vec<&'static str>,
    }

    #[async_trait]
    impl ToolBackend for MockBackend {
        async fn execute(
            &self,
            tool_call: &ToolCall,
            _context: &ExecutionContext,
        ) -> Result<ToolResult, ToolError> {
            let tool_name = self.name.clone();
            let payload = format!("Mock execution of {} by {}", tool_call.name, self.name);
            let external = steer_tools::result::ExternalResult { tool_name, payload };
            Ok(ToolResult::External(external))
        }

        async fn supported_tools(&self) -> Vec<String> {
            self.tools.iter().copied().map(String::from).collect()
        }

        async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
            vec![]
        }

        fn metadata(&self) -> BackendMetadata {
            BackendMetadata::new(self.name.clone(), String::from("Mock"))
        }
    }

    /// Registering two backends routes their tools; unregistering one removes
    /// only that backend's tools.
    #[tokio::test]
    async fn test_backend_registry() {
        let first = Arc::new(MockBackend {
            name: String::from("backend1"),
            tools: vec!["tool1", "tool2"],
        });
        let second = Arc::new(MockBackend {
            name: String::from("backend2"),
            tools: vec!["tool3", "tool4"],
        });

        let mut registry = BackendRegistry::new();
        registry.register(String::from("backend1"), first).await;
        registry.register(String::from("backend2"), second).await;

        // Routing: known tools resolve, unknown ones do not.
        assert!(registry.get_backend_for_tool("tool1").is_some());
        assert!(registry.get_backend_for_tool("tool3").is_some());
        assert!(registry.get_backend_for_tool("unknown_tool").is_none());

        // The union of both tool lists is reported.
        let supported = registry.supported_tools().await;
        assert_eq!(supported.len(), 4);
        assert!(supported.iter().any(|tool| tool == "tool1"));
        assert!(supported.iter().any(|tool| tool == "tool4"));

        // Unregistering succeeds only for names that exist.
        assert!(registry.unregister("backend1"));
        assert!(!registry.unregister("nonexistent"));

        // Only the removed backend's tools disappear.
        assert!(registry.get_backend_for_tool("tool1").is_none());
        assert!(registry.get_backend_for_tool("tool3").is_some());
    }

    /// The mock's `execute` yields an external result describing the call.
    #[tokio::test]
    async fn test_mock_backend_execution() {
        let context = ExecutionContext::new(
            "session".to_string(),
            "operation".to_string(),
            "tool_call".to_string(),
            CancellationToken::new(),
        );
        let tool_call = ToolCall {
            name: "test_tool".to_string(),
            parameters: json!({}),
            id: "test_id".to_string(),
        };
        let backend = MockBackend {
            name: "test".to_string(),
            tools: vec!["test_tool"],
        };

        let result = backend.execute(&tool_call, &context).await.unwrap();
        if let ToolResult::External(external) = result {
            let payload = &external.payload;
            assert!(payload.contains("Mock execution"));
            assert!(payload.contains("test_tool"));
            assert!(payload.contains("test"));
        } else {
            unreachable!("External result")
        }
    }
}