steer_core/tools/backend.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use crate::api::ToolCall;
6use crate::tools::ExecutionContext;
7use steer_tools::ToolSchema;
8use steer_tools::{ToolError, result::ToolResult};
9
/// Descriptive information about a tool backend, used for debugging and monitoring.
///
/// Create one with [`BackendMetadata::new`], then optionally refine it with the
/// consuming `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct BackendMetadata {
    /// Name of the backend.
    pub name: String,
    /// Free-form label describing what kind of backend this is.
    pub backend_type: String,
    /// Where the backend runs, if one was provided via `with_location`.
    pub location: Option<String>,
    /// Arbitrary extra key/value details added via `with_info`.
    pub additional_info: HashMap<String, String>,
}

impl BackendMetadata {
    /// Builds metadata with the given name and type, no location, and no extra info.
    pub fn new(name: String, backend_type: String) -> Self {
        let additional_info = HashMap::new();
        Self {
            name,
            backend_type,
            location: None,
            additional_info,
        }
    }

    /// Returns the metadata with `location` set, replacing any earlier value.
    pub fn with_location(self, location: String) -> Self {
        Self {
            location: Some(location),
            ..self
        }
    }

    /// Returns the metadata with `key -> value` recorded in `additional_info`.
    ///
    /// An existing entry for the same key is overwritten.
    pub fn with_info(self, key: String, value: String) -> Self {
        let mut updated = self;
        updated.additional_info.insert(key, value);
        updated
    }
}
39
40/// Simple trait for tool execution backends
41///
42/// This trait abstracts different execution environments for tools,
43/// allowing tools to run locally, on remote machines, in containers,
44/// or through proxy services.
45#[async_trait]
46pub trait ToolBackend: Send + Sync {
47    /// Execute a tool call in this backend's environment
48    ///
49    /// # Arguments
50    /// * `tool_call` - The tool call containing name, parameters, and ID
51    /// * `context` - Execution context with session info, cancellation, etc.
52    ///
53    /// # Returns
54    /// The typed tool result on success, or a ToolError on failure
55    async fn execute(
56        &self,
57        tool_call: &ToolCall,
58        context: &ExecutionContext,
59    ) -> Result<ToolResult, ToolError>;
60
61    /// List the tools this backend can handle
62    ///
63    /// Returns a vector of tool names that this backend supports.
64    /// The backend registry uses this to map tools to backends.
65    async fn supported_tools(&self) -> Vec<String>;
66
67    /// Get API tool descriptions for this backend
68    ///
69    /// Returns a vector of ToolSchema objects containing name, description,
70    /// and input schema for each tool this backend supports.
71    async fn get_tool_schemas(&self) -> Vec<ToolSchema>;
72
73    /// Backend metadata for debugging and monitoring
74    ///
75    /// Override this method to provide additional information about
76    /// the backend for observability and troubleshooting.
77    fn metadata(&self) -> BackendMetadata;
78
79    /// Check if the backend is healthy and ready to execute tools
80    ///
81    /// This method can be used for health checks and load balancing.
82    /// Default implementation returns true.
83    async fn health_check(&self) -> bool {
84        true
85    }
86
87    /// Check if a tool requires approval before execution
88    ///
89    /// Returns true if the tool requires user approval before it can be executed.
90    /// Default implementation returns true for safety.
91    ///
92    /// # Arguments
93    /// * `tool_name` - The name of the tool to check
94    ///
95    /// # Returns
96    /// Ok(true) if approval is required, Ok(false) if not, or an error if the tool is unknown
97    async fn requires_approval(&self, _tool_name: &str) -> Result<bool, ToolError> {
98        // Default conservative implementation - require approval
99        // Backends should override this to provide accurate information
100        Ok(true)
101    }
102}
103
104/// Registry that maps tool names to their backends
105///
106/// When a backend is registered, the registry queries its supported_tools()
107/// and creates mappings for each tool. This allows for efficient lookup
108/// of the appropriate backend for a given tool name.
109pub struct BackendRegistry {
110    backends: Vec<(String, Arc<dyn ToolBackend>)>,
111    tool_mapping: HashMap<String, Arc<dyn ToolBackend>>,
112}
113
114impl BackendRegistry {
115    /// Create a new empty backend registry
116    pub fn new() -> Self {
117        Self {
118            backends: Vec::new(),
119            tool_mapping: HashMap::new(),
120        }
121    }
122
123    /// Register a backend with the given name
124    ///
125    /// This method queries the backend's supported_tools() and creates
126    /// mappings for each tool. If a tool is already mapped to another
127    /// backend, it will be overwritten.
128    ///
129    /// # Arguments
130    /// * `name` - A unique name for this backend instance
131    /// * `backend` - The backend implementation
132    pub async fn register(&mut self, name: String, backend: Arc<dyn ToolBackend>) {
133        // Map each tool this backend supports
134        for tool_name in backend.supported_tools().await {
135            self.tool_mapping
136                .insert(tool_name.to_string(), backend.clone());
137        }
138        self.backends.push((name, backend));
139    }
140
141    /// Get the backend for a specific tool
142    ///
143    /// Returns the backend that can handle the given tool name,
144    /// or None if no backend supports that tool.
145    ///
146    /// # Arguments
147    /// * `tool_name` - The name of the tool to look up
148    pub fn get_backend_for_tool(&self, tool_name: &str) -> Option<&Arc<dyn ToolBackend>> {
149        self.tool_mapping.get(tool_name)
150    }
151
152    /// Get all registered backends
153    ///
154    /// Returns a vector of (name, backend) pairs for all registered backends.
155    pub fn backends(&self) -> &Vec<(String, Arc<dyn ToolBackend>)> {
156        &self.backends
157    }
158
159    /// Get all tool mappings
160    ///
161    /// Returns a reference to the tool name -> backend mapping.
162    pub fn tool_mappings(&self) -> &HashMap<String, Arc<dyn ToolBackend>> {
163        &self.tool_mapping
164    }
165
166    /// Check which tools are supported
167    ///
168    /// Returns a vector of all tool names that have registered backends.
169    pub async fn supported_tools(&self) -> Vec<String> {
170        self.tool_mapping.keys().cloned().collect()
171    }
172
173    /// Remove a backend by name
174    ///
175    /// This removes the backend and all its tool mappings.
176    /// Returns true if a backend was removed, false if the name wasn't found.
177    pub fn unregister(&mut self, name: &str) -> bool {
178        if let Some(pos) = self.backends.iter().position(|(n, _)| n == name) {
179            let (_, backend) = self.backends.remove(pos);
180
181            // Remove all tool mappings for this backend
182            self.tool_mapping
183                .retain(|_tool, mapped_backend| !Arc::ptr_eq(mapped_backend, &backend));
184
185            true
186        } else {
187            false
188        }
189    }
190
191    /// Clear all backends and mappings
192    pub fn clear(&mut self) {
193        self.backends.clear();
194        self.tool_mapping.clear();
195    }
196
197    /// Get API tools from all registered backends
198    ///
199    /// Collects and returns all API tool descriptions from all registered backends.
200    /// This provides a unified view of all available tools across all backends.
201    pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
202        let futures = self
203            .backends
204            .iter()
205            .map(|(_, backend)| backend.get_tool_schemas());
206        let all_schemas = futures::future::join_all(futures).await;
207
208        let mut all_tools = Vec::new();
209        for schemas in all_schemas {
210            all_tools.extend(schemas);
211        }
212
213        all_tools
214    }
215}
216
217impl Default for BackendRegistry {
218    fn default() -> Self {
219        Self::new()
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ToolCall;
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    /// Backend that claims a fixed set of tools and echoes the call it received.
    struct MockBackend {
        name: String,
        tools: Vec<&'static str>,
    }

    #[async_trait]
    impl ToolBackend for MockBackend {
        async fn execute(
            &self,
            tool_call: &ToolCall,
            _context: &ExecutionContext,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::External(steer_tools::result::ExternalResult {
                tool_name: self.name.clone(),
                payload: format!("Mock execution of {} by {}", tool_call.name, self.name),
            }))
        }

        async fn supported_tools(&self) -> Vec<String> {
            self.tools.iter().map(|&s| s.to_string()).collect()
        }

        async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
            Vec::new()
        }

        fn metadata(&self) -> BackendMetadata {
            BackendMetadata::new(self.name.clone(), "Mock".to_string())
        }
    }

    /// Build a shared mock backend with the given name and tools.
    fn mock(name: &str, tools: Vec<&'static str>) -> Arc<MockBackend> {
        Arc::new(MockBackend {
            name: name.to_string(),
            tools,
        })
    }

    /// Name of the backend the registry routes `tool` to, if any.
    fn backend_name_for(registry: &BackendRegistry, tool: &str) -> Option<String> {
        registry
            .get_backend_for_tool(tool)
            .map(|backend| backend.metadata().name)
    }

    #[tokio::test]
    async fn test_backend_registry() {
        let mut registry = BackendRegistry::new();

        registry
            .register("backend1".to_string(), mock("backend1", vec!["tool1", "tool2"]))
            .await;
        registry
            .register("backend2".to_string(), mock("backend2", vec!["tool3", "tool4"]))
            .await;

        // Each tool routes to the backend that declared it.
        assert_eq!(backend_name_for(&registry, "tool1").as_deref(), Some("backend1"));
        assert_eq!(backend_name_for(&registry, "tool3").as_deref(), Some("backend2"));
        assert!(registry.get_backend_for_tool("unknown_tool").is_none());
        assert_eq!(registry.backends().len(), 2);

        let supported = registry.supported_tools().await;
        assert_eq!(supported.len(), 4);
        assert!(supported.contains(&"tool1".to_string()));
        assert!(supported.contains(&"tool4".to_string()));

        assert!(registry.unregister("backend1"));
        assert!(!registry.unregister("nonexistent"));

        // backend1's tools are gone; backend2's are untouched.
        assert!(registry.get_backend_for_tool("tool1").is_none());
        assert!(registry.get_backend_for_tool("tool2").is_none());
        assert_eq!(backend_name_for(&registry, "tool4").as_deref(), Some("backend2"));
        assert_eq!(registry.backends().len(), 1);
    }

    #[tokio::test]
    async fn test_later_registration_overrides_shared_tool() {
        let mut registry = BackendRegistry::new();
        registry
            .register("first".to_string(), mock("first", vec!["shared", "only_first"]))
            .await;
        registry
            .register("second".to_string(), mock("second", vec!["shared"]))
            .await;

        assert_eq!(backend_name_for(&registry, "shared").as_deref(), Some("second"));
        assert_eq!(backend_name_for(&registry, "only_first").as_deref(), Some("first"));
        assert_eq!(registry.tool_mappings().len(), 2);
    }

    #[tokio::test]
    async fn test_clear_removes_everything() {
        let mut registry = BackendRegistry::new();
        registry
            .register("backend".to_string(), mock("backend", vec!["tool"]))
            .await;

        registry.clear();

        assert!(registry.backends().is_empty());
        assert!(registry.tool_mappings().is_empty());
        assert!(registry.supported_tools().await.is_empty());
        assert!(registry.get_tool_schemas().await.is_empty());
    }

    #[tokio::test]
    async fn test_mock_backend_execution() {
        let backend = MockBackend {
            name: "test".to_string(),
            tools: vec!["test_tool"],
        };

        let tool_call = ToolCall {
            name: "test_tool".to_string(),
            parameters: json!({}),
            id: "test_id".to_string(),
        };

        let context = ExecutionContext::new(
            "session".to_string(),
            "operation".to_string(),
            "tool_call".to_string(),
            CancellationToken::new(),
        );

        let result = backend.execute(&tool_call, &context).await.unwrap();
        match result {
            ToolResult::External(external) => {
                // Exact match: a substring check for "test" would also pass on
                // "test_tool" alone and prove nothing about the backend name.
                assert_eq!(external.tool_name, "test");
                assert_eq!(external.payload, "Mock execution of test_tool by test");
            }
            // Reachable if execute() regresses, so fail loudly rather than
            // claiming impossibility with unreachable!.
            _ => panic!("expected ToolResult::External from MockBackend"),
        }
    }
}
332}