use crate::config::LlmConfigProvider;
use crate::tools::error::Result;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::{Span, debug, error, instrument, warn};

use crate::api::ToolCall;
use crate::app::validation::{ValidationContext, ValidatorRegistry};
use crate::tools::{BackendRegistry, ExecutionContext};
use crate::workspace::Workspace;
use steer_tools::ToolSchema;
use steer_tools::result::ToolResult;
13
/// Executes tool calls, dispatching each one to the workspace (checked first)
/// or to a backend from the registry, optionally running a per-tool validator
/// beforehand.
#[derive(Clone)]
pub struct ToolExecutor {
    /// Workspace queried first for tool availability, approval, and execution.
    pub(crate) workspace: Arc<dyn Workspace>,
    /// Backends used for tools the workspace does not advertise.
    pub(crate) backend_registry: Arc<BackendRegistry>,
    /// Per-tool validators; they are only run when `llm_config_provider` is set.
    pub(crate) validators: Arc<ValidatorRegistry>,
    /// Optional LLM configuration, forwarded to validators and execution contexts.
    pub(crate) llm_config_provider: Option<LlmConfigProvider>,
}
26
27impl ToolExecutor {
28 pub fn with_workspace(workspace: Arc<dyn Workspace>) -> Self {
30 Self {
31 workspace,
32 backend_registry: Arc::new(BackendRegistry::new()),
33 validators: Arc::new(ValidatorRegistry::new()),
34 llm_config_provider: None,
35 }
36 }
37
38 pub fn with_components(
40 workspace: Arc<dyn Workspace>,
41 backend_registry: Arc<BackendRegistry>,
42 validators: Arc<ValidatorRegistry>,
43 ) -> Self {
44 Self {
45 workspace,
46 backend_registry,
47 validators,
48 llm_config_provider: None,
49 }
50 }
51
52 pub fn with_all_components(
54 workspace: Arc<dyn Workspace>,
55 backend_registry: Arc<BackendRegistry>,
56 validators: Arc<ValidatorRegistry>,
57 llm_config_provider: LlmConfigProvider,
58 ) -> Self {
59 Self {
60 workspace,
61 backend_registry,
62 validators,
63 llm_config_provider: Some(llm_config_provider),
64 }
65 }
66
67 pub async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
68 let workspace_tools = self.workspace.available_tools().await;
70 if workspace_tools.iter().any(|t| t.name == tool_name) {
71 return Ok(self.workspace.requires_approval(tool_name).await?);
72 }
73
74 match self.backend_registry.get_backend_for_tool(tool_name) {
76 Some(backend) => Ok(backend.requires_approval(tool_name).await?),
77 None => Err(steer_tools::ToolError::UnknownTool(tool_name.to_string()).into()),
78 }
79 }
80
81 pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
82 let mut schemas = Vec::new();
83
84 schemas.extend(self.workspace.available_tools().await);
86
87 schemas.extend(self.backend_registry.get_tool_schemas().await);
89
90 schemas
91 }
92
93 pub async fn supported_tools(&self) -> Vec<String> {
95 let schemas = self.get_tool_schemas().await;
96 schemas.into_iter().map(|s| s.name).collect()
97 }
98
99 pub fn backend_registry(&self) -> &Arc<BackendRegistry> {
101 &self.backend_registry
102 }
103
104 #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
105 pub async fn execute_tool_with_cancellation(
106 &self,
107 tool_call: &ToolCall,
108 token: CancellationToken,
109 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
110 let tool_name = &tool_call.name;
111 let tool_id = &tool_call.id;
112
113 Span::current().record("tool.name", tool_name);
114 Span::current().record("tool.id", tool_id);
115
116 if let Some(validator) = self.validators.get_validator(tool_name) {
118 if let Some(ref llm_config_provider) = self.llm_config_provider {
120 let validation_context = ValidationContext {
121 cancellation_token: token.clone(),
122 llm_config_provider: llm_config_provider.clone(),
123 };
124
125 let validation_result = validator
126 .validate(tool_call, &validation_context)
127 .await
128 .map_err(|e| {
129 steer_tools::ToolError::InternalError(format!("Validation failed: {e}"))
130 })?;
131
132 if !validation_result.allowed {
133 return Err(steer_tools::ToolError::InternalError(
134 validation_result
135 .reason
136 .unwrap_or_else(|| "Tool execution was denied".to_string()),
137 ));
138 }
139 }
140 }
142
143 let mut builder = ExecutionContext::builder(
145 "default".to_string(), "default".to_string(), tool_call.id.clone(),
148 token,
149 );
150
151 if let Some(provider) = &self.llm_config_provider {
153 builder = builder.llm_config_provider(provider.clone());
154 }
155
156 let context = builder.build();
157
158 let workspace_tools = self.workspace.available_tools().await;
160 if workspace_tools.iter().any(|t| &t.name == tool_name) {
161 debug!(
162 target: "app.tool_executor.execute_tool_with_cancellation",
163 "Executing workspace tool {} ({}) with cancellation",
164 tool_name,
165 tool_id
166 );
167
168 return self
169 .execute_workspace_tool(&self.workspace, tool_call, &context)
170 .await;
171 }
172
173 let backend = self
175 .backend_registry
176 .get_backend_for_tool(tool_name)
177 .cloned()
178 .ok_or_else(|| {
179 error!(
180 target: "app.tool_executor.execute_tool_with_cancellation",
181 "No backend configured for tool: {} ({})",
182 tool_name,
183 tool_id
184 );
185 steer_tools::ToolError::UnknownTool(tool_name.clone())
186 })?;
187
188 debug!(
189 target: "app.tool_executor.execute_tool_with_cancellation",
190 "Executing external tool {} ({}) via backend with cancellation",
191 tool_name,
192 tool_id
193 );
194
195 backend.execute(tool_call, &context).await
196 }
197
198 #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
200 pub async fn execute_tool_direct(
201 &self,
202 tool_call: &ToolCall,
203 token: CancellationToken,
204 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
205 let tool_name = &tool_call.name;
206 let tool_id = &tool_call.id;
207
208 Span::current().record("tool.name", tool_name);
209 Span::current().record("tool.id", tool_id);
210
211 let mut builder = ExecutionContext::builder(
213 "direct".to_string(), "direct".to_string(),
215 tool_call.id.clone(),
216 token,
217 );
218
219 if let Some(provider) = &self.llm_config_provider {
221 builder = builder.llm_config_provider(provider.clone());
222 }
223
224 let context = builder.build();
225
226 let workspace_tools = self.workspace.available_tools().await;
228 if workspace_tools.iter().any(|t| &t.name == tool_name) {
229 debug!(
230 target: "app.tool_executor.execute_tool_direct",
231 "Executing workspace tool {} ({}) directly (no validation)",
232 tool_name,
233 tool_id
234 );
235
236 return self
237 .execute_workspace_tool(&self.workspace, tool_call, &context)
238 .await;
239 }
240
241 let backend = self
243 .backend_registry
244 .get_backend_for_tool(tool_name)
245 .cloned()
246 .ok_or_else(|| {
247 error!(
248 target: "app.tool_executor.execute_tool_direct",
249 "No backend configured for tool: {} ({})",
250 tool_name,
251 tool_id
252 );
253 steer_tools::ToolError::UnknownTool(tool_name.clone())
254 })?;
255
256 debug!(
257 target: "app.tool_executor.execute_tool_direct",
258 "Executing external tool {} ({}) directly via backend (no validation)",
259 tool_name,
260 tool_id
261 );
262
263 backend.execute(tool_call, &context).await
264 }
265
266 async fn execute_workspace_tool(
268 &self,
269 workspace: &Arc<dyn Workspace>,
270 tool_call: &ToolCall,
271 context: &ExecutionContext,
272 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
273 let tools_context = steer_tools::ExecutionContext::new(context.tool_call_id.clone())
275 .with_cancellation_token(context.cancellation_token.clone());
276
277 workspace
278 .execute_tool(tool_call, tools_context)
279 .await
280 .map_err(|e| {
281 steer_tools::ToolError::InternalError(format!("Workspace execution failed: {e}"))
282 })
283 }
284}