steer_core/tools/
backend.rs1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use crate::tools::ExecutionContext;
6use steer_tools::{ToolCall, ToolError, ToolSchema, result::ToolResult};
7
/// Descriptive information about a tool backend, as returned by [`ToolBackend::metadata`].
///
/// Built with [`BackendMetadata::new`] and refined with the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BackendMetadata {
    /// Name of the backend.
    pub name: String,
    /// Free-form kind of backend (the tests in this file use `"Mock"`).
    pub backend_type: String,
    /// Optional location of the backend; its format is backend-specific
    /// (presumably a path or URL — confirm with implementors).
    pub location: Option<String>,
    /// Arbitrary extra key/value details about the backend.
    pub additional_info: HashMap<String, String>,
}

impl BackendMetadata {
    /// Creates metadata with the given name and type, no location, and no extra info.
    pub fn new(name: String, backend_type: String) -> Self {
        Self {
            name,
            backend_type,
            location: None,
            additional_info: HashMap::new(),
        }
    }

    /// Returns the metadata with `location` set, replacing any previous location.
    pub fn with_location(self, location: String) -> Self {
        Self {
            location: Some(location),
            ..self
        }
    }

    /// Returns the metadata with `key` mapped to `value` in `additional_info`.
    ///
    /// An existing value for `key` is overwritten.
    pub fn with_info(mut self, key: String, value: String) -> Self {
        self.additional_info.insert(key, value);
        self
    }
}
37
38#[async_trait]
44pub trait ToolBackend: Send + Sync {
45 async fn execute(
54 &self,
55 tool_call: &ToolCall,
56 context: &ExecutionContext,
57 ) -> Result<ToolResult, ToolError>;
58
59 async fn supported_tools(&self) -> Vec<String>;
64
65 async fn get_tool_schemas(&self) -> Vec<ToolSchema>;
70
71 fn metadata(&self) -> BackendMetadata;
76
77 async fn health_check(&self) -> bool {
82 true
83 }
84
85 async fn requires_approval(&self, _tool_name: &str) -> Result<bool, ToolError> {
96 Ok(true)
99 }
100}
101
102pub struct BackendRegistry {
108 backends: Vec<(String, Arc<dyn ToolBackend>)>,
109 tool_mapping: HashMap<String, Arc<dyn ToolBackend>>,
110}
111
112impl BackendRegistry {
113 pub fn new() -> Self {
115 Self {
116 backends: Vec::new(),
117 tool_mapping: HashMap::new(),
118 }
119 }
120
121 pub async fn register(&mut self, name: String, backend: Arc<dyn ToolBackend>) {
131 for tool_name in backend.supported_tools().await {
133 self.tool_mapping
134 .insert(tool_name.to_string(), backend.clone());
135 }
136 self.backends.push((name, backend));
137 }
138
139 pub fn get_backend_for_tool(&self, tool_name: &str) -> Option<&Arc<dyn ToolBackend>> {
147 self.tool_mapping.get(tool_name)
148 }
149
150 pub fn backends(&self) -> &Vec<(String, Arc<dyn ToolBackend>)> {
154 &self.backends
155 }
156
157 pub fn tool_mappings(&self) -> &HashMap<String, Arc<dyn ToolBackend>> {
161 &self.tool_mapping
162 }
163
164 pub async fn supported_tools(&self) -> Vec<String> {
168 self.tool_mapping.keys().cloned().collect()
169 }
170
171 pub fn unregister(&mut self, name: &str) -> bool {
176 if let Some(pos) = self.backends.iter().position(|(n, _)| n == name) {
177 let (_, backend) = self.backends.remove(pos);
178
179 self.tool_mapping
181 .retain(|_tool, mapped_backend| !Arc::ptr_eq(mapped_backend, &backend));
182
183 true
184 } else {
185 false
186 }
187 }
188
189 pub fn clear(&mut self) {
191 self.backends.clear();
192 self.tool_mapping.clear();
193 }
194
195 pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
200 let futures = self
201 .backends
202 .iter()
203 .map(|(_, backend)| backend.get_tool_schemas());
204 let all_schemas = futures::future::join_all(futures).await;
205
206 let mut all_tools = Vec::new();
207 for schemas in all_schemas {
208 all_tools.extend(schemas);
209 }
210
211 all_tools
212 }
213}
214
215impl Default for BackendRegistry {
216 fn default() -> Self {
217 Self::new()
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    /// Test double: "executes" any call by echoing the tool name and its own name.
    struct MockBackend {
        name: String,
        tools: Vec<&'static str>,
    }

    #[async_trait]
    impl ToolBackend for MockBackend {
        async fn execute(
            &self,
            tool_call: &ToolCall,
            _context: &ExecutionContext,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::External(steer_tools::result::ExternalResult {
                tool_name: self.name.clone(),
                payload: format!("Mock execution of {} by {}", tool_call.name, self.name),
            }))
        }

        async fn supported_tools(&self) -> Vec<String> {
            self.tools.iter().map(|&s| s.to_string()).collect()
        }

        async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
            Vec::new()
        }

        fn metadata(&self) -> BackendMetadata {
            BackendMetadata::new(self.name.clone(), "Mock".to_string())
        }
    }

    /// Name of the backend the registry currently routes `tool` to, if any.
    fn routed_to(registry: &BackendRegistry, tool: &str) -> Option<String> {
        registry
            .get_backend_for_tool(tool)
            .map(|backend| backend.metadata().name)
    }

    #[tokio::test]
    async fn test_backend_registry() {
        let mut registry = BackendRegistry::new();

        let backend1 = Arc::new(MockBackend {
            name: "backend1".to_string(),
            tools: vec!["tool1", "tool2"],
        });

        let backend2 = Arc::new(MockBackend {
            name: "backend2".to_string(),
            tools: vec!["tool3", "tool4"],
        });

        registry.register("backend1".to_string(), backend1).await;
        registry.register("backend2".to_string(), backend2).await;

        // Each tool must be routed to the backend that declared it, not just to "some" backend.
        assert_eq!(routed_to(&registry, "tool1").as_deref(), Some("backend1"));
        assert_eq!(routed_to(&registry, "tool3").as_deref(), Some("backend2"));
        assert_eq!(routed_to(&registry, "unknown_tool"), None);

        // `supported_tools` has hash-map order, so sort before comparing.
        let mut supported = registry.supported_tools().await;
        supported.sort();
        assert_eq!(supported, vec!["tool1", "tool2", "tool3", "tool4"]);

        assert!(registry.unregister("backend1"));
        assert!(!registry.unregister("backend1"));
        assert!(!registry.unregister("nonexistent"));
        assert_eq!(registry.backends().len(), 1);

        assert_eq!(routed_to(&registry, "tool1"), None);
        assert_eq!(routed_to(&registry, "tool2"), None);
        assert_eq!(routed_to(&registry, "tool3").as_deref(), Some("backend2"));

        registry.clear();
        assert!(registry.backends().is_empty());
        assert!(registry.tool_mappings().is_empty());
    }

    #[tokio::test]
    async fn test_mock_backend_execution() {
        let backend = MockBackend {
            name: "test".to_string(),
            tools: vec!["test_tool"],
        };

        let tool_call = ToolCall {
            name: "test_tool".to_string(),
            parameters: json!({}),
            id: "test_id".to_string(),
        };

        let context = ExecutionContext::new(
            "session".to_string(),
            "operation".to_string(),
            "tool_call".to_string(),
            CancellationToken::new(),
        );

        let result = backend.execute(&tool_call, &context).await.unwrap();
        match result {
            ToolResult::External(external) => {
                assert_eq!(external.tool_name, "test");
                assert_eq!(external.payload, "Mock execution of test_tool by test");
            }
            _ => panic!("expected ToolResult::External from MockBackend"),
        }

        // Trait defaults: healthy, and every tool requires approval.
        assert!(backend.health_check().await);
        assert!(backend.requires_approval("test_tool").await.unwrap());
    }
}