steer_core/tools/
local_backend.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use crate::api::ToolCall;
6use crate::config::LlmConfigProvider;
7use crate::tools::{BackendMetadata, ExecutionContext, ToolBackend};
8use crate::tools::{DispatchAgentTool, FetchTool};
9use steer_tools::tools::{read_only_workspace_tools, workspace_tools};
10use steer_tools::{
11    ExecutionContext as SteerExecutionContext, Tool, ToolError, ToolSchema, result::ToolResult,
12    traits::ExecutableTool,
13};
14
15// Tool wrappers for server-side tools
16struct FetchToolWrapper(FetchTool);
17struct DispatchAgentToolWrapper(DispatchAgentTool);
18
19#[async_trait]
20impl Tool for FetchToolWrapper {
21    type Output = ToolResult;
22
23    fn name(&self) -> &'static str {
24        self.0.name()
25    }
26
27    fn description(&self) -> String {
28        self.0.description()
29    }
30
31    fn input_schema(&self) -> &'static steer_tools::InputSchema {
32        self.0.input_schema()
33    }
34
35    async fn execute(
36        &self,
37        params: serde_json::Value,
38        ctx: &SteerExecutionContext,
39    ) -> Result<Self::Output, ToolError> {
40        let result = self.0.execute(params, ctx).await?;
41        Ok(ToolResult::Fetch(result))
42    }
43
44    fn requires_approval(&self) -> bool {
45        self.0.requires_approval()
46    }
47}
48
49#[async_trait]
50impl Tool for DispatchAgentToolWrapper {
51    type Output = ToolResult;
52
53    fn name(&self) -> &'static str {
54        self.0.name()
55    }
56
57    fn description(&self) -> String {
58        self.0.description()
59    }
60
61    fn input_schema(&self) -> &'static steer_tools::InputSchema {
62        self.0.input_schema()
63    }
64
65    async fn execute(
66        &self,
67        params: serde_json::Value,
68        ctx: &SteerExecutionContext,
69    ) -> Result<Self::Output, ToolError> {
70        let result = self.0.execute(params, ctx).await?;
71        Ok(ToolResult::Agent(result))
72    }
73
74    fn requires_approval(&self) -> bool {
75        self.0.requires_approval()
76    }
77}
78
79/// Local backend that executes tools in the current process
80///
81/// This backend uses the steer-tools implementations directly.
82pub struct LocalBackend {
83    /// The tool registry containing all available tools
84    registry: HashMap<String, Box<dyn ExecutableTool>>,
85}
86
87impl Default for LocalBackend {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92
93impl LocalBackend {
94    /// Create a new empty LocalBackend
95    pub fn new() -> Self {
96        Self {
97            registry: HashMap::new(),
98        }
99    }
100
101    /// Create a backend from a collection of tool instances
102    pub fn from_tools(tools: Vec<Box<dyn ExecutableTool>>) -> Self {
103        let mut registry = HashMap::new();
104        tools.into_iter().for_each(|tool| {
105            registry.insert(tool.name().to_string(), tool);
106        });
107        Self { registry }
108    }
109
110    /// Create a backend with only specific tools enabled by name
111    ///
112    /// This method takes a list of tool names and creates a backend
113    /// containing only those tools from the full set of available tools.
114    pub fn with_tools(
115        tool_names: Vec<String>,
116        llm_config_provider: Arc<LlmConfigProvider>,
117        workspace: Arc<dyn crate::workspace::Workspace>,
118    ) -> Self {
119        let mut all_tools = workspace_tools();
120        all_tools.push(Box::new(FetchToolWrapper(FetchTool {
121            llm_config_provider: llm_config_provider.clone(),
122        })));
123        all_tools.push(Box::new(DispatchAgentToolWrapper(DispatchAgentTool {
124            llm_config_provider: llm_config_provider.clone(),
125            workspace,
126        })));
127
128        let filtered_tools: Vec<Box<dyn ExecutableTool>> = all_tools
129            .into_iter()
130            .filter(|tool| tool_names.contains(&tool.name().to_string()))
131            .collect();
132
133        Self::from_tools(filtered_tools)
134    }
135
136    /// Create a backend excluding specific tools by name
137    ///
138    /// This method takes a list of tool names to exclude and creates a backend
139    /// containing all other tools from the full set of available tools.
140    pub fn without_tools(
141        excluded_tools: Vec<String>,
142        llm_config_provider: Arc<LlmConfigProvider>,
143        workspace: Arc<dyn crate::workspace::Workspace>,
144    ) -> Self {
145        let mut all_tools = workspace_tools();
146        all_tools.push(Box::new(FetchToolWrapper(FetchTool {
147            llm_config_provider: llm_config_provider.clone(),
148        })));
149        all_tools.push(Box::new(DispatchAgentToolWrapper(DispatchAgentTool {
150            llm_config_provider: llm_config_provider.clone(),
151            workspace,
152        })));
153
154        let filtered_tools: Vec<Box<dyn ExecutableTool>> = all_tools
155            .into_iter()
156            .filter(|tool| !excluded_tools.contains(&tool.name().to_string()))
157            .collect();
158
159        Self::from_tools(filtered_tools)
160    }
161
162    /// Create a new LocalBackend with all tools (workspace + server tools)
163    pub fn full(
164        llm_config_provider: Arc<LlmConfigProvider>,
165        workspace: Arc<dyn crate::workspace::Workspace>,
166    ) -> Self {
167        let mut tools = workspace_tools();
168        // Add server-side tools
169        tools.push(Box::new(FetchToolWrapper(FetchTool {
170            llm_config_provider: llm_config_provider.clone(),
171        })));
172        tools.push(Box::new(DispatchAgentToolWrapper(DispatchAgentTool {
173            llm_config_provider: llm_config_provider.clone(),
174            workspace,
175        })));
176        Self::from_tools(tools)
177    }
178
179    /// Create a LocalBackend with only server-side tools
180    pub fn server_only(
181        llm_config_provider: Arc<LlmConfigProvider>,
182        workspace: Arc<dyn crate::workspace::Workspace>,
183    ) -> Self {
184        Self::from_tools(vec![
185            Box::new(FetchToolWrapper(FetchTool {
186                llm_config_provider: llm_config_provider.clone(),
187            })),
188            Box::new(DispatchAgentToolWrapper(DispatchAgentTool {
189                llm_config_provider: llm_config_provider.clone(),
190                workspace,
191            })),
192        ])
193    }
194
195    /// Create a LocalBackend with read-only tools
196    ///
197    /// This creates a backend with only read-only tools, useful for
198    /// sandboxed or restricted execution environments.
199    pub fn read_only(llm_config_provider: Arc<LlmConfigProvider>) -> Self {
200        let mut tools = read_only_workspace_tools();
201        // Add server-side tools (they're read-only too)
202        tools.push(Box::new(FetchToolWrapper(FetchTool {
203            llm_config_provider: llm_config_provider.clone(),
204        })));
205        Self::from_tools(tools)
206    }
207
208    /// Check if a tool is available in this backend
209    pub fn has_tool(&self, tool_name: &str) -> bool {
210        self.registry.contains_key(tool_name)
211    }
212}
213
214#[async_trait]
215impl ToolBackend for LocalBackend {
216    async fn execute(
217        &self,
218        tool_call: &ToolCall,
219        context: &ExecutionContext,
220    ) -> Result<ToolResult, ToolError> {
221        // Get the tool from the registry
222        let tool = self
223            .registry
224            .get(&tool_call.name)
225            .ok_or_else(|| ToolError::UnknownTool(tool_call.name.clone()))?;
226
227        // Create execution context for steer-tools
228        let steer_context = SteerExecutionContext::new(tool_call.id.clone())
229            .with_cancellation_token(context.cancellation_token.clone());
230
231        // Execute the tool and get the result
232        tool.run(tool_call.parameters.clone(), &steer_context).await
233    }
234
235    async fn supported_tools(&self) -> Vec<String> {
236        // Return the tools we currently have in the registry
237        self.registry.keys().cloned().collect()
238    }
239
240    async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
241        self.registry
242            .iter()
243            .map(|(name, tool)| ToolSchema {
244                name: name.clone(),
245                description: tool.description().to_string(),
246                input_schema: tool.input_schema().clone(),
247            })
248            .collect()
249    }
250
251    fn metadata(&self) -> BackendMetadata {
252        BackendMetadata::new("Local".to_string(), "Local".to_string())
253            .with_location("localhost".to_string())
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `backend` has no tools and reports none through any accessor.
    async fn assert_backend_is_empty(backend: &LocalBackend) {
        assert!(backend.registry.is_empty());
        assert!(!backend.has_tool("no_such_tool"));
        assert!(backend.supported_tools().await.is_empty());
        assert!(backend.get_tool_schemas().await.is_empty());
    }

    #[tokio::test]
    async fn test_local_backend_creation() {
        let backend = LocalBackend::new();
        assert_eq!(backend.registry.len(), 0);
    }

    #[tokio::test]
    async fn test_default_backend_is_empty() {
        assert_backend_is_empty(&LocalBackend::default()).await;
    }

    #[tokio::test]
    async fn test_from_no_tools_is_empty() {
        assert_backend_is_empty(&LocalBackend::from_tools(Vec::new())).await;
    }

    #[tokio::test]
    async fn test_local_backend_metadata() {
        let backend = LocalBackend::new();
        let metadata = backend.metadata();
        assert_eq!(metadata.name, "Local");
        assert_eq!(metadata.backend_type, "Local");
        assert_eq!(metadata.location, Some("localhost".to_string()));
    }
}