1use crate::config::LlmConfigProvider;
2use crate::error::{Error, Result};
3use std::sync::Arc;
4use tokio_util::sync::CancellationToken;
5use tracing::{Span, debug, error, instrument};
6
7use crate::api::ToolCall;
8use crate::app::validation::{ValidationContext, ValidatorRegistry};
9use crate::tools::{BackendRegistry, ExecutionContext};
10use crate::workspace::Workspace;
11use steer_tools::ToolSchema;
12use steer_tools::{ToolError, result::ToolResult};
13
14#[derive(Clone)]
16pub struct ToolExecutor {
17 pub(crate) workspace: Arc<dyn Workspace>,
19 pub(crate) backend_registry: Arc<BackendRegistry>,
21 pub(crate) validators: Arc<ValidatorRegistry>,
23 pub(crate) llm_config_provider: Option<LlmConfigProvider>,
25}
26
27impl ToolExecutor {
28 pub fn with_workspace(workspace: Arc<dyn Workspace>) -> Self {
30 Self {
31 workspace,
32 backend_registry: Arc::new(BackendRegistry::new()),
33 validators: Arc::new(ValidatorRegistry::new()),
34 llm_config_provider: None,
35 }
36 }
37
38 pub fn with_components(
40 workspace: Arc<dyn Workspace>,
41 backend_registry: Arc<BackendRegistry>,
42 validators: Arc<ValidatorRegistry>,
43 ) -> Self {
44 Self {
45 workspace,
46 backend_registry,
47 validators,
48 llm_config_provider: None,
49 }
50 }
51
52 pub fn with_all_components(
54 workspace: Arc<dyn Workspace>,
55 backend_registry: Arc<BackendRegistry>,
56 validators: Arc<ValidatorRegistry>,
57 llm_config_provider: LlmConfigProvider,
58 ) -> Self {
59 Self {
60 workspace,
61 backend_registry,
62 validators,
63 llm_config_provider: Some(llm_config_provider),
64 }
65 }
66
67 pub async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
68 let workspace_tools = self.workspace.available_tools().await;
70 if workspace_tools.iter().any(|t| t.name == tool_name) {
71 return self
72 .workspace
73 .requires_approval(tool_name)
74 .await
75 .map_err(|e| {
76 Error::Tool(steer_tools::ToolError::InternalError(format!(
77 "Failed to check approval requirement: {e}"
78 )))
79 });
80 }
81
82 match self.backend_registry.get_backend_for_tool(tool_name) {
84 Some(backend) => backend.requires_approval(tool_name).await.map_err(|e| {
85 Error::Tool(steer_tools::ToolError::InternalError(format!(
86 "Failed to check approval requirement: {e}"
87 )))
88 }),
89 None => Err(Error::Tool(steer_tools::ToolError::UnknownTool(
90 tool_name.to_string(),
91 ))),
92 }
93 }
94
95 pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
96 let mut schemas = Vec::new();
97
98 schemas.extend(self.workspace.available_tools().await);
100
101 schemas.extend(self.backend_registry.get_tool_schemas().await);
103
104 schemas
105 }
106
107 pub async fn supported_tools(&self) -> Vec<String> {
109 let schemas = self.get_tool_schemas().await;
110 schemas.into_iter().map(|s| s.name).collect()
111 }
112
113 pub fn backend_registry(&self) -> &Arc<BackendRegistry> {
115 &self.backend_registry
116 }
117
118 #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
119 pub async fn execute_tool_with_cancellation(
120 &self,
121 tool_call: &ToolCall,
122 token: CancellationToken,
123 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
124 let tool_name = &tool_call.name;
125 let tool_id = &tool_call.id;
126
127 Span::current().record("tool.name", tool_name);
128 Span::current().record("tool.id", tool_id);
129
130 if let Some(validator) = self.validators.get_validator(tool_name) {
132 if let Some(ref llm_config_provider) = self.llm_config_provider {
134 let validation_context = ValidationContext {
135 cancellation_token: token.clone(),
136 llm_config_provider: llm_config_provider.clone(),
137 };
138
139 let validation_result = validator
140 .validate(tool_call, &validation_context)
141 .await
142 .map_err(|e| ToolError::InternalError(format!("Validation failed: {e}")))?;
143
144 if !validation_result.allowed {
145 return Err(ToolError::InternalError(
146 validation_result
147 .reason
148 .unwrap_or_else(|| "Tool execution was denied".to_string()),
149 ));
150 }
151 }
152 }
154
155 let mut builder = ExecutionContext::builder(
157 "default".to_string(), "default".to_string(), tool_call.id.clone(),
160 token,
161 );
162
163 if let Some(provider) = &self.llm_config_provider {
165 builder = builder.llm_config_provider(provider.clone());
166 }
167
168 let context = builder.build();
169
170 let workspace_tools = self.workspace.available_tools().await;
172 if workspace_tools.iter().any(|t| &t.name == tool_name) {
173 debug!(
174 target: "app.tool_executor.execute_tool_with_cancellation",
175 "Executing workspace tool {} ({}) with cancellation",
176 tool_name,
177 tool_id
178 );
179
180 return self
181 .execute_workspace_tool(&self.workspace, tool_call, &context)
182 .await;
183 }
184
185 let backend = self
187 .backend_registry
188 .get_backend_for_tool(tool_name)
189 .cloned()
190 .ok_or_else(|| {
191 error!(
192 target: "app.tool_executor.execute_tool_with_cancellation",
193 "No backend configured for tool: {} ({})",
194 tool_name,
195 tool_id
196 );
197 ToolError::UnknownTool(tool_name.clone())
198 })?;
199
200 debug!(
201 target: "app.tool_executor.execute_tool_with_cancellation",
202 "Executing external tool {} ({}) via backend with cancellation",
203 tool_name,
204 tool_id
205 );
206
207 backend.execute(tool_call, &context).await
208 }
209
210 #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
212 pub async fn execute_tool_direct(
213 &self,
214 tool_call: &ToolCall,
215 token: CancellationToken,
216 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
217 let tool_name = &tool_call.name;
218 let tool_id = &tool_call.id;
219
220 Span::current().record("tool.name", tool_name);
221 Span::current().record("tool.id", tool_id);
222
223 let mut builder = ExecutionContext::builder(
225 "direct".to_string(), "direct".to_string(),
227 tool_call.id.clone(),
228 token,
229 );
230
231 if let Some(provider) = &self.llm_config_provider {
233 builder = builder.llm_config_provider(provider.clone());
234 }
235
236 let context = builder.build();
237
238 let workspace_tools = self.workspace.available_tools().await;
240 if workspace_tools.iter().any(|t| &t.name == tool_name) {
241 debug!(
242 target: "app.tool_executor.execute_tool_direct",
243 "Executing workspace tool {} ({}) directly (no validation)",
244 tool_name,
245 tool_id
246 );
247
248 return self
249 .execute_workspace_tool(&self.workspace, tool_call, &context)
250 .await;
251 }
252
253 let backend = self
255 .backend_registry
256 .get_backend_for_tool(tool_name)
257 .cloned()
258 .ok_or_else(|| {
259 error!(
260 target: "app.tool_executor.execute_tool_direct",
261 "No backend configured for tool: {} ({})",
262 tool_name,
263 tool_id
264 );
265 ToolError::UnknownTool(tool_name.clone())
266 })?;
267
268 debug!(
269 target: "app.tool_executor.execute_tool_direct",
270 "Executing external tool {} ({}) directly via backend (no validation)",
271 tool_name,
272 tool_id
273 );
274
275 backend.execute(tool_call, &context).await
276 }
277
278 async fn execute_workspace_tool(
280 &self,
281 workspace: &Arc<dyn Workspace>,
282 tool_call: &ToolCall,
283 context: &ExecutionContext,
284 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
285 let tools_context = steer_tools::ExecutionContext::new(context.tool_call_id.clone())
287 .with_cancellation_token(context.cancellation_token.clone());
288
289 workspace
290 .execute_tool(tool_call, tools_context)
291 .await
292 .map_err(|e| ToolError::InternalError(format!("Workspace execution failed: {e}")))
293 }
294}