// steer_core/tools/backend.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use crate::tools::ExecutionContext;
6use steer_tools::{ToolCall, ToolError, ToolSchema, result::ToolResult};
7
/// Descriptive information about a tool backend, intended for debugging and
/// monitoring output.
#[derive(Clone, Debug)]
pub struct BackendMetadata {
    /// Human-readable name of the backend.
    pub name: String,
    /// Free-form label for the kind of backend (the tests in this file use `"Mock"`).
    pub backend_type: String,
    /// Optional location string; `None` unless set with `with_location`.
    pub location: Option<String>,
    /// Extra key/value details; empty unless populated with `with_info`.
    pub additional_info: HashMap<String, String>,
}
16
17impl BackendMetadata {
18    pub fn new(name: String, backend_type: String) -> Self {
19        Self {
20            name,
21            backend_type,
22            location: None,
23            additional_info: HashMap::new(),
24        }
25    }
26
27    pub fn with_location(mut self, location: String) -> Self {
28        self.location = Some(location);
29        self
30    }
31
32    pub fn with_info(mut self, key: String, value: String) -> Self {
33        self.additional_info.insert(key, value);
34        self
35    }
36}
37
38/// Simple trait for tool execution backends
39///
40/// This trait abstracts different execution environments for tools,
41/// allowing tools to run locally, on remote machines, in containers,
42/// or through proxy services.
43#[async_trait]
44pub trait ToolBackend: Send + Sync {
45    /// Execute a tool call in this backend's environment
46    ///
47    /// # Arguments
48    /// * `tool_call` - The tool call containing name, parameters, and ID
49    /// * `context` - Execution context with session info, cancellation, etc.
50    ///
51    /// # Returns
52    /// The typed tool result on success, or a ToolError on failure
53    async fn execute(
54        &self,
55        tool_call: &ToolCall,
56        context: &ExecutionContext,
57    ) -> Result<ToolResult, ToolError>;
58
59    /// List the tools this backend can handle
60    ///
61    /// Returns a vector of tool names that this backend supports.
62    /// The backend registry uses this to map tools to backends.
63    async fn supported_tools(&self) -> Vec<String>;
64
65    /// Get API tool descriptions for this backend
66    ///
67    /// Returns a vector of ToolSchema objects containing name, description,
68    /// and input schema for each tool this backend supports.
69    async fn get_tool_schemas(&self) -> Vec<ToolSchema>;
70
71    /// Backend metadata for debugging and monitoring
72    ///
73    /// Override this method to provide additional information about
74    /// the backend for observability and troubleshooting.
75    fn metadata(&self) -> BackendMetadata;
76
77    /// Check if the backend is healthy and ready to execute tools
78    ///
79    /// This method can be used for health checks and load balancing.
80    /// Default implementation returns true.
81    async fn health_check(&self) -> bool {
82        true
83    }
84
85    /// Check if a tool requires approval before execution
86    ///
87    /// Returns true if the tool requires user approval before it can be executed.
88    /// Default implementation returns true for safety.
89    ///
90    /// # Arguments
91    /// * `tool_name` - The name of the tool to check
92    ///
93    /// # Returns
94    /// Ok(true) if approval is required, Ok(false) if not, or an error if the tool is unknown
95    async fn requires_approval(&self, _tool_name: &str) -> Result<bool, ToolError> {
96        // Default conservative implementation - require approval
97        // Backends should override this to provide accurate information
98        Ok(true)
99    }
100}
101
102/// Registry that maps tool names to their backends
103///
104/// When a backend is registered, the registry queries its supported_tools()
105/// and creates mappings for each tool. This allows for efficient lookup
106/// of the appropriate backend for a given tool name.
107pub struct BackendRegistry {
108    backends: Vec<(String, Arc<dyn ToolBackend>)>,
109    tool_mapping: HashMap<String, Arc<dyn ToolBackend>>,
110}
111
112impl BackendRegistry {
113    /// Create a new empty backend registry
114    pub fn new() -> Self {
115        Self {
116            backends: Vec::new(),
117            tool_mapping: HashMap::new(),
118        }
119    }
120
121    /// Register a backend with the given name
122    ///
123    /// This method queries the backend's supported_tools() and creates
124    /// mappings for each tool. If a tool is already mapped to another
125    /// backend, it will be overwritten.
126    ///
127    /// # Arguments
128    /// * `name` - A unique name for this backend instance
129    /// * `backend` - The backend implementation
130    pub async fn register(&mut self, name: String, backend: Arc<dyn ToolBackend>) {
131        // Map each tool this backend supports
132        for tool_name in backend.supported_tools().await {
133            self.tool_mapping.insert(tool_name.clone(), backend.clone());
134        }
135        self.backends.push((name, backend));
136    }
137
138    /// Get the backend for a specific tool
139    ///
140    /// Returns the backend that can handle the given tool name,
141    /// or None if no backend supports that tool.
142    ///
143    /// # Arguments
144    /// * `tool_name` - The name of the tool to look up
145    pub fn get_backend_for_tool(&self, tool_name: &str) -> Option<&Arc<dyn ToolBackend>> {
146        self.tool_mapping.get(tool_name)
147    }
148
149    /// Get all registered backends
150    ///
151    /// Returns a vector of (name, backend) pairs for all registered backends.
152    pub fn backends(&self) -> &Vec<(String, Arc<dyn ToolBackend>)> {
153        &self.backends
154    }
155
156    /// Get all tool mappings
157    ///
158    /// Returns a reference to the tool name -> backend mapping.
159    pub fn tool_mappings(&self) -> &HashMap<String, Arc<dyn ToolBackend>> {
160        &self.tool_mapping
161    }
162
163    /// Check which tools are supported
164    ///
165    /// Returns a vector of all tool names that have registered backends.
166    pub async fn supported_tools(&self) -> Vec<String> {
167        self.tool_mapping.keys().cloned().collect()
168    }
169
170    /// Remove a backend by name
171    ///
172    /// This removes the backend and all its tool mappings.
173    /// Returns true if a backend was removed, false if the name wasn't found.
174    pub fn unregister(&mut self, name: &str) -> bool {
175        if let Some(pos) = self.backends.iter().position(|(n, _)| n == name) {
176            let (_, backend) = self.backends.remove(pos);
177
178            // Remove all tool mappings for this backend
179            self.tool_mapping
180                .retain(|_tool, mapped_backend| !Arc::ptr_eq(mapped_backend, &backend));
181
182            true
183        } else {
184            false
185        }
186    }
187
188    /// Clear all backends and mappings
189    pub fn clear(&mut self) {
190        self.backends.clear();
191        self.tool_mapping.clear();
192    }
193
194    /// Get API tools from all registered backends
195    ///
196    /// Collects and returns all API tool descriptions from all registered backends.
197    /// This provides a unified view of all available tools across all backends.
198    pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
199        let futures = self
200            .backends
201            .iter()
202            .map(|(_, backend)| backend.get_tool_schemas());
203        let all_schemas = futures::future::join_all(futures).await;
204
205        let mut all_tools = Vec::new();
206        for schemas in all_schemas {
207            all_tools.extend(schemas);
208        }
209
210        all_tools
211    }
212}
213
214impl Default for BackendRegistry {
215    fn default() -> Self {
216        Self::new()
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio_util::sync::CancellationToken;

    /// Test double: claims a fixed set of tools and echoes the call it
    /// received instead of doing real work.
    struct MockBackend {
        name: String,
        tools: Vec<&'static str>,
    }

    impl MockBackend {
        fn new(name: &str, tools: Vec<&'static str>) -> Self {
            Self {
                name: name.to_string(),
                tools,
            }
        }
    }

    #[async_trait]
    impl ToolBackend for MockBackend {
        async fn execute(
            &self,
            tool_call: &ToolCall,
            _context: &ExecutionContext,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::External(steer_tools::result::ExternalResult {
                tool_name: self.name.clone(),
                payload: format!("Mock execution of {} by {}", tool_call.name, self.name),
            }))
        }

        async fn supported_tools(&self) -> Vec<String> {
            self.tools.iter().map(|&s| s.to_string()).collect()
        }

        async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
            Vec::new()
        }

        fn metadata(&self) -> BackendMetadata {
            BackendMetadata::new(self.name.clone(), "Mock".to_string())
        }
    }

    /// Name (as reported by `metadata()`) of the backend mapped to `tool`, if any.
    fn mapped_backend_name(registry: &BackendRegistry, tool: &str) -> Option<String> {
        registry
            .get_backend_for_tool(tool)
            .map(|backend| backend.metadata().name)
    }

    #[tokio::test]
    async fn test_backend_registry() {
        let mut registry = BackendRegistry::new();

        registry
            .register(
                "backend1".to_string(),
                Arc::new(MockBackend::new("backend1", vec!["tool1", "tool2"])),
            )
            .await;
        registry
            .register(
                "backend2".to_string(),
                Arc::new(MockBackend::new("backend2", vec!["tool3", "tool4"])),
            )
            .await;

        // Each tool resolves to the backend that declared it, not just "some" backend.
        assert_eq!(mapped_backend_name(&registry, "tool1").as_deref(), Some("backend1"));
        assert_eq!(mapped_backend_name(&registry, "tool2").as_deref(), Some("backend1"));
        assert_eq!(mapped_backend_name(&registry, "tool3").as_deref(), Some("backend2"));
        assert_eq!(mapped_backend_name(&registry, "tool4").as_deref(), Some("backend2"));
        assert!(registry.get_backend_for_tool("unknown_tool").is_none());
        assert_eq!(registry.backends().len(), 2);

        // Supported tools: HashMap order is unspecified, so compare sorted.
        let mut supported = registry.supported_tools().await;
        supported.sort();
        assert_eq!(supported, vec!["tool1", "tool2", "tool3", "tool4"]);

        // Backend removal
        assert!(registry.unregister("backend1"));
        assert!(!registry.unregister("nonexistent"));

        // backend1's tools are gone; backend2's are untouched.
        assert!(registry.get_backend_for_tool("tool1").is_none());
        assert!(registry.get_backend_for_tool("tool2").is_none());
        assert_eq!(mapped_backend_name(&registry, "tool3").as_deref(), Some("backend2"));
        assert_eq!(registry.backends().len(), 1);
    }

    #[tokio::test]
    async fn test_later_registration_overrides_tool_mapping() {
        let mut registry = BackendRegistry::default();

        registry
            .register(
                "first".to_string(),
                Arc::new(MockBackend::new("first", vec!["shared", "only_first"])),
            )
            .await;
        registry
            .register(
                "second".to_string(),
                Arc::new(MockBackend::new("second", vec!["shared"])),
            )
            .await;

        // The most recent registration wins for a tool both backends support.
        assert_eq!(mapped_backend_name(&registry, "shared").as_deref(), Some("second"));
        assert_eq!(mapped_backend_name(&registry, "only_first").as_deref(), Some("first"));

        // Unregistering the overriding backend drops its mappings; the tool
        // does not fall back to the earlier backend.
        assert!(registry.unregister("second"));
        assert!(registry.get_backend_for_tool("shared").is_none());
        assert_eq!(mapped_backend_name(&registry, "only_first").as_deref(), Some("first"));

        registry.clear();
        assert!(registry.backends().is_empty());
        assert!(registry.tool_mappings().is_empty());
    }

    #[tokio::test]
    async fn test_mock_backend_execution() {
        let backend = MockBackend::new("test", vec!["test_tool"]);

        let tool_call = ToolCall {
            name: "test_tool".to_string(),
            parameters: json!({}),
            id: "test_id".to_string(),
        };

        let context = ExecutionContext::new(
            "session".to_string(),
            "operation".to_string(),
            "tool_call".to_string(),
            CancellationToken::new(),
        );

        let result = backend.execute(&tool_call, &context).await.unwrap();
        match result {
            ToolResult::External(external) => {
                // Exact match: a substring check for "test" would also be
                // satisfied by the tool name "test_tool".
                assert_eq!(external.tool_name, "test");
                assert_eq!(external.payload, "Mock execution of test_tool by test");
            }
            _ => panic!("expected ToolResult::External from MockBackend"),
        }
    }
}