// steer_core/tools/backend.rs

use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use steer_tools::{result::ToolResult, ToolCall, ToolError, ToolSchema};

use crate::tools::ExecutionContext;

/// Descriptive information about a tool backend, used for debugging and
/// monitoring.
///
/// Create it with [`BackendMetadata::new`] and extend it with the `with_*`
/// builder methods.
#[derive(Debug, Clone)]
pub struct BackendMetadata {
    /// Human-readable name of the backend.
    pub name: String,
    /// Free-form label describing the kind of backend.
    pub backend_type: String,
    /// Optional free-form location string; `None` unless set via
    /// [`BackendMetadata::with_location`].
    pub location: Option<String>,
    /// Arbitrary extra key/value details; empty unless populated via
    /// [`BackendMetadata::with_info`].
    pub additional_info: HashMap<String, String>,
}

impl BackendMetadata {
    /// Creates metadata with the given name and type, no location, and an
    /// empty `additional_info` map.
    pub fn new(name: String, backend_type: String) -> Self {
        let additional_info = HashMap::new();
        Self {
            name,
            backend_type,
            location: None,
            additional_info,
        }
    }

    /// Returns `self` with `location` set, replacing any previous location.
    pub fn with_location(self, location: String) -> Self {
        Self {
            location: Some(location),
            ..self
        }
    }

    /// Returns `self` with `key` mapped to `value` in `additional_info`.
    /// An existing entry for the same key is overwritten.
    pub fn with_info(self, key: String, value: String) -> Self {
        let mut updated = self;
        updated.additional_info.insert(key, value);
        updated
    }
}

38/// Simple trait for tool execution backends
39///
40/// This trait abstracts different execution environments for tools,
41/// allowing tools to run locally, on remote machines, in containers,
42/// or through proxy services.
43#[async_trait]
44pub trait ToolBackend: Send + Sync {
45    /// Execute a tool call in this backend's environment
46    ///
47    /// # Arguments
48    /// * `tool_call` - The tool call containing name, parameters, and ID
49    /// * `context` - Execution context with session info, cancellation, etc.
50    ///
51    /// # Returns
52    /// The typed tool result on success, or a ToolError on failure
53    async fn execute(
54        &self,
55        tool_call: &ToolCall,
56        context: &ExecutionContext,
57    ) -> Result<ToolResult, ToolError>;
58
59    /// List the tools this backend can handle
60    ///
61    /// Returns a vector of tool names that this backend supports.
62    /// The backend registry uses this to map tools to backends.
63    async fn supported_tools(&self) -> Vec<String>;
64
65    /// Get API tool descriptions for this backend
66    ///
67    /// Returns a vector of ToolSchema objects containing name, description,
68    /// and input schema for each tool this backend supports.
69    async fn get_tool_schemas(&self) -> Vec<ToolSchema>;
70
71    /// Backend metadata for debugging and monitoring
72    ///
73    /// Override this method to provide additional information about
74    /// the backend for observability and troubleshooting.
75    fn metadata(&self) -> BackendMetadata;
76
77    /// Check if the backend is healthy and ready to execute tools
78    ///
79    /// This method can be used for health checks and load balancing.
80    /// Default implementation returns true.
81    async fn health_check(&self) -> bool {
82        true
83    }
84
85    /// Check if a tool requires approval before execution
86    ///
87    /// Returns true if the tool requires user approval before it can be executed.
88    /// Default implementation returns true for safety.
89    ///
90    /// # Arguments
91    /// * `tool_name` - The name of the tool to check
92    ///
93    /// # Returns
94    /// Ok(true) if approval is required, Ok(false) if not, or an error if the tool is unknown
95    async fn requires_approval(&self, _tool_name: &str) -> Result<bool, ToolError> {
96        // Default conservative implementation - require approval
97        // Backends should override this to provide accurate information
98        Ok(true)
99    }
100}
101
102/// Registry that maps tool names to their backends
103///
104/// When a backend is registered, the registry queries its supported_tools()
105/// and creates mappings for each tool. This allows for efficient lookup
106/// of the appropriate backend for a given tool name.
107pub struct BackendRegistry {
108    backends: Vec<(String, Arc<dyn ToolBackend>)>,
109    tool_mapping: HashMap<String, Arc<dyn ToolBackend>>,
110}
111
112impl BackendRegistry {
113    /// Create a new empty backend registry
114    pub fn new() -> Self {
115        Self {
116            backends: Vec::new(),
117            tool_mapping: HashMap::new(),
118        }
119    }
120
121    /// Register a backend with the given name
122    ///
123    /// This method queries the backend's supported_tools() and creates
124    /// mappings for each tool. If a tool is already mapped to another
125    /// backend, it will be overwritten.
126    ///
127    /// # Arguments
128    /// * `name` - A unique name for this backend instance
129    /// * `backend` - The backend implementation
130    pub async fn register(&mut self, name: String, backend: Arc<dyn ToolBackend>) {
131        // Map each tool this backend supports
132        for tool_name in backend.supported_tools().await {
133            self.tool_mapping
134                .insert(tool_name.to_string(), backend.clone());
135        }
136        self.backends.push((name, backend));
137    }
138
139    /// Get the backend for a specific tool
140    ///
141    /// Returns the backend that can handle the given tool name,
142    /// or None if no backend supports that tool.
143    ///
144    /// # Arguments
145    /// * `tool_name` - The name of the tool to look up
146    pub fn get_backend_for_tool(&self, tool_name: &str) -> Option<&Arc<dyn ToolBackend>> {
147        self.tool_mapping.get(tool_name)
148    }
149
150    /// Get all registered backends
151    ///
152    /// Returns a vector of (name, backend) pairs for all registered backends.
153    pub fn backends(&self) -> &Vec<(String, Arc<dyn ToolBackend>)> {
154        &self.backends
155    }
156
157    /// Get all tool mappings
158    ///
159    /// Returns a reference to the tool name -> backend mapping.
160    pub fn tool_mappings(&self) -> &HashMap<String, Arc<dyn ToolBackend>> {
161        &self.tool_mapping
162    }
163
164    /// Check which tools are supported
165    ///
166    /// Returns a vector of all tool names that have registered backends.
167    pub async fn supported_tools(&self) -> Vec<String> {
168        self.tool_mapping.keys().cloned().collect()
169    }
170
171    /// Remove a backend by name
172    ///
173    /// This removes the backend and all its tool mappings.
174    /// Returns true if a backend was removed, false if the name wasn't found.
175    pub fn unregister(&mut self, name: &str) -> bool {
176        if let Some(pos) = self.backends.iter().position(|(n, _)| n == name) {
177            let (_, backend) = self.backends.remove(pos);
178
179            // Remove all tool mappings for this backend
180            self.tool_mapping
181                .retain(|_tool, mapped_backend| !Arc::ptr_eq(mapped_backend, &backend));
182
183            true
184        } else {
185            false
186        }
187    }
188
189    /// Clear all backends and mappings
190    pub fn clear(&mut self) {
191        self.backends.clear();
192        self.tool_mapping.clear();
193    }
194
195    /// Get API tools from all registered backends
196    ///
197    /// Collects and returns all API tool descriptions from all registered backends.
198    /// This provides a unified view of all available tools across all backends.
199    pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
200        let futures = self
201            .backends
202            .iter()
203            .map(|(_, backend)| backend.get_tool_schemas());
204        let all_schemas = futures::future::join_all(futures).await;
205
206        let mut all_tools = Vec::new();
207        for schemas in all_schemas {
208            all_tools.extend(schemas);
209        }
210
211        all_tools
212    }
213}
214
215impl Default for BackendRegistry {
216    fn default() -> Self {
217        Self::new()
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    /// Test double that claims a fixed set of tools and answers every call
    /// with an `External` result describing the call.
    struct MockBackend {
        name: String,
        tools: Vec<&'static str>,
    }

    /// Builds a `MockBackend` named `name` that claims `tools`.
    fn mock_backend(name: &str, tools: Vec<&'static str>) -> MockBackend {
        MockBackend {
            name: name.to_string(),
            tools,
        }
    }

    #[async_trait]
    impl ToolBackend for MockBackend {
        async fn execute(
            &self,
            tool_call: &ToolCall,
            _context: &ExecutionContext,
        ) -> Result<ToolResult, ToolError> {
            let payload = format!("Mock execution of {} by {}", tool_call.name, self.name);
            let external = steer_tools::result::ExternalResult {
                tool_name: self.name.clone(),
                payload,
            };
            Ok(ToolResult::External(external))
        }

        async fn supported_tools(&self) -> Vec<String> {
            self.tools.iter().copied().map(String::from).collect()
        }

        async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
            vec![]
        }

        fn metadata(&self) -> BackendMetadata {
            BackendMetadata::new(self.name.clone(), String::from("Mock"))
        }
    }

    #[tokio::test]
    async fn test_backend_registry() {
        let mut registry = BackendRegistry::new();

        for (name, tools) in [
            ("backend1", vec!["tool1", "tool2"]),
            ("backend2", vec!["tool3", "tool4"]),
        ] {
            registry
                .register(name.to_string(), Arc::new(mock_backend(name, tools)))
                .await;
        }

        // Registered tools resolve; unknown ones do not.
        assert!(registry.get_backend_for_tool("tool1").is_some());
        assert!(registry.get_backend_for_tool("tool3").is_some());
        assert!(registry.get_backend_for_tool("unknown_tool").is_none());

        let supported = registry.supported_tools().await;
        assert_eq!(supported.len(), 4);
        for expected in ["tool1", "tool4"] {
            assert!(supported.iter().any(|tool| tool == expected));
        }

        assert!(registry.unregister("backend1"));
        assert!(!registry.unregister("nonexistent"));

        // Only backend1's tools disappear.
        assert!(registry.get_backend_for_tool("tool1").is_none());
        assert!(registry.get_backend_for_tool("tool3").is_some());
    }

    #[tokio::test]
    async fn test_mock_backend_execution() {
        let backend = mock_backend("test", vec!["test_tool"]);

        let tool_call = ToolCall {
            name: "test_tool".to_string(),
            parameters: json!({}),
            id: "test_id".to_string(),
        };
        let context = ExecutionContext::new(
            "session".to_string(),
            "operation".to_string(),
            "tool_call".to_string(),
            CancellationToken::new(),
        );

        let result = backend.execute(&tool_call, &context).await.unwrap();
        let ToolResult::External(external) = result else {
            unreachable!("External result");
        };
        for fragment in ["Mock execution", "test_tool", "test"] {
            assert!(external.payload.contains(fragment));
        }
    }
}