steer_core/tools/
local_backend.rs1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use crate::api::ToolCall;
6use crate::config::LlmConfigProvider;
7use crate::tools::{BackendMetadata, ExecutionContext, ToolBackend};
8use crate::tools::{DispatchAgentTool, FetchTool};
9use steer_tools::tools::{read_only_workspace_tools, workspace_tools};
10use steer_tools::{
11 ExecutionContext as SteerExecutionContext, Tool, ToolError, ToolSchema, result::ToolResult,
12 traits::ExecutableTool,
13};
14
15struct FetchToolWrapper(FetchTool);
17struct DispatchAgentToolWrapper(DispatchAgentTool);
18
19#[async_trait]
20impl Tool for FetchToolWrapper {
21 type Output = ToolResult;
22
23 fn name(&self) -> &'static str {
24 self.0.name()
25 }
26
27 fn description(&self) -> String {
28 self.0.description()
29 }
30
31 fn input_schema(&self) -> &'static steer_tools::InputSchema {
32 self.0.input_schema()
33 }
34
35 async fn execute(
36 &self,
37 params: serde_json::Value,
38 ctx: &SteerExecutionContext,
39 ) -> Result<Self::Output, ToolError> {
40 let result = self.0.execute(params, ctx).await?;
41 Ok(ToolResult::Fetch(result))
42 }
43
44 fn requires_approval(&self) -> bool {
45 self.0.requires_approval()
46 }
47}
48
49#[async_trait]
50impl Tool for DispatchAgentToolWrapper {
51 type Output = ToolResult;
52
53 fn name(&self) -> &'static str {
54 self.0.name()
55 }
56
57 fn description(&self) -> String {
58 self.0.description()
59 }
60
61 fn input_schema(&self) -> &'static steer_tools::InputSchema {
62 self.0.input_schema()
63 }
64
65 async fn execute(
66 &self,
67 params: serde_json::Value,
68 ctx: &SteerExecutionContext,
69 ) -> Result<Self::Output, ToolError> {
70 let result = self.0.execute(params, ctx).await?;
71 Ok(ToolResult::Agent(result))
72 }
73
74 fn requires_approval(&self) -> bool {
75 self.0.requires_approval()
76 }
77}
78
79pub struct LocalBackend {
83 registry: HashMap<String, Box<dyn ExecutableTool>>,
85}
86
87impl Default for LocalBackend {
88 fn default() -> Self {
89 Self::new()
90 }
91}
92
93impl LocalBackend {
94 pub fn new() -> Self {
96 Self {
97 registry: HashMap::new(),
98 }
99 }
100
101 pub fn from_tools(tools: Vec<Box<dyn ExecutableTool>>) -> Self {
103 let mut registry = HashMap::new();
104 tools.into_iter().for_each(|tool| {
105 registry.insert(tool.name().to_string(), tool);
106 });
107 Self { registry }
108 }
109
110 pub fn with_tools(
115 tool_names: Vec<String>,
116 llm_config_provider: Arc<LlmConfigProvider>,
117 workspace: Arc<dyn crate::workspace::Workspace>,
118 ) -> Self {
119 let mut all_tools = workspace_tools();
120 all_tools.push(Box::new(FetchToolWrapper(FetchTool {
121 llm_config_provider: llm_config_provider.clone(),
122 })));
123 all_tools.push(Box::new(DispatchAgentToolWrapper(DispatchAgentTool {
124 llm_config_provider: llm_config_provider.clone(),
125 workspace,
126 })));
127
128 let filtered_tools: Vec<Box<dyn ExecutableTool>> = all_tools
129 .into_iter()
130 .filter(|tool| tool_names.contains(&tool.name().to_string()))
131 .collect();
132
133 Self::from_tools(filtered_tools)
134 }
135
136 pub fn without_tools(
141 excluded_tools: Vec<String>,
142 llm_config_provider: Arc<LlmConfigProvider>,
143 workspace: Arc<dyn crate::workspace::Workspace>,
144 ) -> Self {
145 let mut all_tools = workspace_tools();
146 all_tools.push(Box::new(FetchToolWrapper(FetchTool {
147 llm_config_provider: llm_config_provider.clone(),
148 })));
149 all_tools.push(Box::new(DispatchAgentToolWrapper(DispatchAgentTool {
150 llm_config_provider: llm_config_provider.clone(),
151 workspace,
152 })));
153
154 let filtered_tools: Vec<Box<dyn ExecutableTool>> = all_tools
155 .into_iter()
156 .filter(|tool| !excluded_tools.contains(&tool.name().to_string()))
157 .collect();
158
159 Self::from_tools(filtered_tools)
160 }
161
162 pub fn full(
164 llm_config_provider: Arc<LlmConfigProvider>,
165 workspace: Arc<dyn crate::workspace::Workspace>,
166 ) -> Self {
167 let mut tools = workspace_tools();
168 tools.push(Box::new(FetchToolWrapper(FetchTool {
170 llm_config_provider: llm_config_provider.clone(),
171 })));
172 tools.push(Box::new(DispatchAgentToolWrapper(DispatchAgentTool {
173 llm_config_provider: llm_config_provider.clone(),
174 workspace,
175 })));
176 Self::from_tools(tools)
177 }
178
179 pub fn server_only(
181 llm_config_provider: Arc<LlmConfigProvider>,
182 workspace: Arc<dyn crate::workspace::Workspace>,
183 ) -> Self {
184 Self::from_tools(vec![
185 Box::new(FetchToolWrapper(FetchTool {
186 llm_config_provider: llm_config_provider.clone(),
187 })),
188 Box::new(DispatchAgentToolWrapper(DispatchAgentTool {
189 llm_config_provider: llm_config_provider.clone(),
190 workspace,
191 })),
192 ])
193 }
194
195 pub fn read_only(llm_config_provider: Arc<LlmConfigProvider>) -> Self {
200 let mut tools = read_only_workspace_tools();
201 tools.push(Box::new(FetchToolWrapper(FetchTool {
203 llm_config_provider: llm_config_provider.clone(),
204 })));
205 Self::from_tools(tools)
206 }
207
208 pub fn has_tool(&self, tool_name: &str) -> bool {
210 self.registry.contains_key(tool_name)
211 }
212}
213
214#[async_trait]
215impl ToolBackend for LocalBackend {
216 async fn execute(
217 &self,
218 tool_call: &ToolCall,
219 context: &ExecutionContext,
220 ) -> Result<ToolResult, ToolError> {
221 let tool = self
223 .registry
224 .get(&tool_call.name)
225 .ok_or_else(|| ToolError::UnknownTool(tool_call.name.clone()))?;
226
227 let steer_context = SteerExecutionContext::new(tool_call.id.clone())
229 .with_cancellation_token(context.cancellation_token.clone());
230
231 tool.run(tool_call.parameters.clone(), &steer_context).await
233 }
234
235 async fn supported_tools(&self) -> Vec<String> {
236 self.registry.keys().cloned().collect()
238 }
239
240 async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
241 self.registry
242 .iter()
243 .map(|(name, tool)| ToolSchema {
244 name: name.clone(),
245 description: tool.description().to_string(),
246 input_schema: tool.input_schema().clone(),
247 })
248 .collect()
249 }
250
251 fn metadata(&self) -> BackendMetadata {
252 BackendMetadata::new("Local".to_string(), "Local".to_string())
253 .with_location("localhost".to_string())
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed backend has nothing registered.
    #[test]
    fn test_local_backend_creation() {
        let backend = LocalBackend::new();
        assert!(backend.registry.is_empty());
    }

    /// The backend identifies itself as the local backend on localhost.
    #[test]
    fn test_local_backend_metadata() {
        let BackendMetadata {
            name,
            backend_type,
            location,
            ..
        } = LocalBackend::new().metadata();

        assert_eq!(name, "Local");
        assert_eq!(backend_type, "Local");
        assert_eq!(location, Some("localhost".to_string()));
    }
}