1use crate::app::domain::types::{SessionId, ToolCallId};
2use crate::config::LlmConfigProvider;
3use crate::tools::error::Result;
4use std::sync::Arc;
5use tokio_util::sync::CancellationToken;
6use tracing::{Span, debug, error, instrument};
7
8use crate::app::validation::{ValidationContext, ValidatorRegistry};
9use crate::tools::registry::ToolRegistry;
10use crate::tools::resolver::BackendResolver;
11use crate::tools::services::ToolServices;
12use crate::tools::static_tool::{StaticToolContext, StaticToolError};
13use crate::tools::{BackendRegistry, ExecutionContext};
14use steer_tools::{ToolCall, ToolSchema, result::ToolResult};
15
16#[derive(Clone)]
17pub struct ToolExecutor {
18 pub(crate) backend_registry: Arc<BackendRegistry>,
19 pub(crate) validators: Arc<ValidatorRegistry>,
20 pub(crate) llm_config_provider: Option<LlmConfigProvider>,
21 pub(crate) tool_registry: Option<Arc<ToolRegistry>>,
22 pub(crate) tool_services: Option<Arc<ToolServices>>,
23}
24
25impl ToolExecutor {
26 pub fn with_components(
27 backend_registry: Arc<BackendRegistry>,
28 validators: Arc<ValidatorRegistry>,
29 ) -> Self {
30 Self {
31 backend_registry,
32 validators,
33 llm_config_provider: None,
34 tool_registry: None,
35 tool_services: None,
36 }
37 }
38
39 pub fn with_all_components(
40 backend_registry: Arc<BackendRegistry>,
41 validators: Arc<ValidatorRegistry>,
42 llm_config_provider: LlmConfigProvider,
43 ) -> Self {
44 Self {
45 backend_registry,
46 validators,
47 llm_config_provider: Some(llm_config_provider),
48 tool_registry: None,
49 tool_services: None,
50 }
51 }
52
53 pub fn with_static_tools(
54 mut self,
55 registry: Arc<ToolRegistry>,
56 services: Arc<ToolServices>,
57 ) -> Self {
58 self.tool_registry = Some(registry);
59 self.tool_services = Some(services);
60 self
61 }
62
63 pub async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
64 if let Some(registry) = &self.tool_registry
65 && registry.is_static_tool(tool_name)
66 {
67 return Ok(registry.requires_approval(tool_name));
68 }
69
70 match self.backend_registry.get_backend_for_tool(tool_name) {
71 Some(backend) => Ok(backend.requires_approval(tool_name).await?),
72 None => Err(steer_tools::ToolError::UnknownTool(tool_name.to_string()).into()),
73 }
74 }
75
76 pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
77 self.get_tool_schemas_with_capabilities(super::Capabilities::all())
78 .await
79 }
80
81 pub async fn get_tool_schemas_with_resolver(
82 &self,
83 session_resolver: Option<&dyn BackendResolver>,
84 ) -> Vec<ToolSchema> {
85 self.get_tool_schemas_with_capabilities_and_resolver(
86 super::Capabilities::all(),
87 session_resolver,
88 )
89 .await
90 }
91
92 pub async fn get_tool_schemas_with_capabilities(
93 &self,
94 capabilities: super::Capabilities,
95 ) -> Vec<ToolSchema> {
96 self.get_tool_schemas_with_capabilities_and_resolver(capabilities, None)
97 .await
98 }
99
100 pub async fn get_tool_schemas_with_capabilities_and_resolver(
101 &self,
102 capabilities: super::Capabilities,
103 session_resolver: Option<&dyn BackendResolver>,
104 ) -> Vec<ToolSchema> {
105 let mut schemas = Vec::new();
106 let mut static_tool_names = std::collections::HashSet::new();
107
108 if let Some(registry) = &self.tool_registry {
109 let static_schemas = registry.available_schemas(capabilities).await;
110 for schema in &static_schemas {
111 static_tool_names.insert(schema.name.clone());
112 }
113 schemas.extend(static_schemas);
114 }
115
116 if let Some(resolver) = session_resolver {
117 for schema in resolver.get_tool_schemas().await {
118 if !static_tool_names.contains(&schema.name) {
119 schemas.push(schema);
120 }
121 }
122 }
123
124 for schema in self.backend_registry.get_tool_schemas().await {
125 if !static_tool_names.contains(&schema.name) {
126 schemas.push(schema);
127 }
128 }
129
130 schemas
131 }
132
133 pub fn is_static_tool(&self, tool_name: &str) -> bool {
134 self.tool_registry
135 .as_ref()
136 .is_some_and(|r| r.is_static_tool(tool_name))
137 }
138
139 pub async fn supported_tools(&self) -> Vec<String> {
141 let schemas = self.get_tool_schemas().await;
142 schemas.into_iter().map(|s| s.name).collect()
143 }
144
145 pub fn backend_registry(&self) -> &Arc<BackendRegistry> {
147 &self.backend_registry
148 }
149
150 pub fn workspace(&self) -> Option<Arc<dyn crate::workspace::Workspace>> {
151 self.tool_services
152 .as_ref()
153 .map(|services| services.workspace.clone())
154 }
155
156 #[instrument(skip(self, tool_call, session_id, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
157 pub async fn execute_tool_with_session(
158 &self,
159 tool_call: &ToolCall,
160 session_id: SessionId,
161 token: CancellationToken,
162 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
163 self.execute_tool_with_session_resolver(tool_call, session_id, token, None)
164 .await
165 }
166
167 #[instrument(skip(self, tool_call, session_id, token, session_resolver), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
168 pub async fn execute_tool_with_session_resolver(
169 &self,
170 tool_call: &ToolCall,
171 session_id: SessionId,
172 token: CancellationToken,
173 session_resolver: Option<&dyn BackendResolver>,
174 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
175 let tool_name = &tool_call.name;
176
177 if let Some((registry, services)) =
178 self.tool_registry.as_ref().zip(self.tool_services.as_ref())
179 && let Some(tool) = registry.static_tool(tool_name)
180 {
181 debug!(target: "tool_executor", "Executing static tool: {}", tool_name);
182 return self
183 .execute_static_tool(tool, tool_call, session_id, services, token)
184 .await;
185 }
186
187 self.execute_tool_with_resolver(tool_call, token, session_resolver)
188 .await
189 }
190
191 #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
192 pub async fn execute_tool_with_cancellation(
193 &self,
194 tool_call: &ToolCall,
195 token: CancellationToken,
196 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
197 self.execute_tool_with_resolver(tool_call, token, None)
198 .await
199 }
200
201 #[instrument(skip(self, tool_call, token, session_resolver), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
202 pub async fn execute_tool_with_resolver(
203 &self,
204 tool_call: &ToolCall,
205 token: CancellationToken,
206 session_resolver: Option<&dyn BackendResolver>,
207 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
208 let tool_name = &tool_call.name;
209 let tool_id = &tool_call.id;
210
211 Span::current().record("tool.name", tool_name);
212 Span::current().record("tool.id", tool_id);
213
214 if let Some(validator) = self.validators.get_validator(tool_name)
215 && let Some(ref llm_config_provider) = self.llm_config_provider
216 {
217 let validation_context = ValidationContext {
218 cancellation_token: token.clone(),
219 llm_config_provider: llm_config_provider.clone(),
220 };
221
222 let validation_result = validator
223 .validate(tool_call, &validation_context)
224 .await
225 .map_err(|e| {
226 steer_tools::ToolError::InternalError(format!("Validation failed: {e}"))
227 })?;
228
229 if !validation_result.allowed {
230 return Err(steer_tools::ToolError::InternalError(
231 validation_result
232 .reason
233 .unwrap_or_else(|| "Tool execution was denied".to_string()),
234 ));
235 }
236 }
237
238 let mut builder = ExecutionContext::builder(
239 "default".to_string(),
240 "default".to_string(),
241 tool_call.id.clone(),
242 token,
243 );
244
245 if let Some(provider) = &self.llm_config_provider {
246 builder = builder.llm_config_provider(provider.clone());
247 }
248
249 let context = builder.build();
250
251 if let Some(resolver) = session_resolver
252 && let Some(backend) = resolver.resolve(tool_name).await
253 {
254 debug!(target: "tool_executor", "Executing session MCP tool: {} ({})", tool_name, tool_id);
255 return backend.execute(tool_call, &context).await;
256 }
257
258 let backend = self
259 .backend_registry
260 .get_backend_for_tool(tool_name)
261 .cloned()
262 .ok_or_else(|| {
263 error!(target: "tool_executor", "No backend for tool: {} ({})", tool_name, tool_id);
264 steer_tools::ToolError::UnknownTool(tool_name.clone())
265 })?;
266
267 debug!(target: "tool_executor", "Executing external tool: {} ({})", tool_name, tool_id);
268 backend.execute(tool_call, &context).await
269 }
270
271 async fn execute_static_tool(
272 &self,
273 tool: &dyn super::static_tool::StaticToolErased,
274 tool_call: &ToolCall,
275 session_id: SessionId,
276 services: &Arc<ToolServices>,
277 token: CancellationToken,
278 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
279 let ctx = StaticToolContext {
280 tool_call_id: ToolCallId(tool_call.id.clone()),
281 session_id,
282 cancellation_token: token,
283 services: services.clone(),
284 };
285
286 let output = tool
287 .execute_erased(tool_call.parameters.clone(), &ctx)
288 .await
289 .map_err(|e| match e {
290 StaticToolError::InvalidParams(msg) => steer_tools::ToolError::InvalidParams {
291 tool_name: tool_call.name.clone(),
292 message: msg,
293 },
294 StaticToolError::Execution(err) => steer_tools::ToolError::Execution(err),
295 StaticToolError::MissingCapability(cap) => {
296 steer_tools::ToolError::InternalError(format!("Missing capability: {cap}"))
297 }
298 StaticToolError::Cancelled => {
299 steer_tools::ToolError::Cancelled(tool_call.name.clone())
300 }
301 StaticToolError::Timeout => steer_tools::ToolError::Timeout(tool_call.name.clone()),
302 })?;
303
304 Ok(output)
305 }
306
307 #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
309 pub async fn execute_tool_direct(
310 &self,
311 tool_call: &ToolCall,
312 token: CancellationToken,
313 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
314 let tool_name = &tool_call.name;
315 let tool_id = &tool_call.id;
316
317 Span::current().record("tool.name", tool_name);
318 Span::current().record("tool.id", tool_id);
319
320 if let Some((registry, services)) =
321 self.tool_registry.as_ref().zip(self.tool_services.as_ref())
322 && let Some(tool) = registry.static_tool(tool_name)
323 {
324 debug!(
325 target: "app.tool_executor.execute_tool_direct",
326 "Executing static tool {} ({}) directly (no validation)",
327 tool_name,
328 tool_id
329 );
330 return self
331 .execute_static_tool(tool, tool_call, SessionId::new(), services, token)
332 .await;
333 }
334
335 let mut builder = ExecutionContext::builder(
337 "direct".to_string(), "direct".to_string(),
339 tool_call.id.clone(),
340 token,
341 );
342
343 if let Some(provider) = &self.llm_config_provider {
345 builder = builder.llm_config_provider(provider.clone());
346 }
347
348 let context = builder.build();
349
350 let backend = self
352 .backend_registry
353 .get_backend_for_tool(tool_name)
354 .cloned()
355 .ok_or_else(|| {
356 error!(
357 target: "app.tool_executor.execute_tool_direct",
358 "No backend configured for tool: {} ({})",
359 tool_name,
360 tool_id
361 );
362 steer_tools::ToolError::UnknownTool(tool_name.clone())
363 })?;
364
365 debug!(
366 target: "app.tool_executor.execute_tool_direct",
367 "Executing external tool {} ({}) directly via backend (no validation)",
368 tool_name,
369 tool_id
370 );
371
372 backend.execute(tool_call, &context).await
373 }
374}