1use crate::app::domain::types::{SessionId, ToolCallId};
2use crate::config::LlmConfigProvider;
3use crate::config::model::ModelId;
4use crate::tools::error::Result;
5use std::sync::Arc;
6use tokio_util::sync::CancellationToken;
7use tracing::{Span, debug, error, instrument};
8
9use crate::app::validation::{ValidationContext, ValidatorRegistry};
10use crate::tools::builtin_tool::{BuiltinToolContext, BuiltinToolError};
11use crate::tools::registry::ToolRegistry;
12use crate::tools::resolver::BackendResolver;
13use crate::tools::services::ToolServices;
14use crate::tools::{BackendRegistry, ExecutionContext};
15use steer_tools::{ToolCall, ToolSchema, result::ToolResult};
16
17#[derive(Clone)]
18pub struct ToolExecutor {
19 pub(crate) backend_registry: Arc<BackendRegistry>,
20 pub(crate) validators: Arc<ValidatorRegistry>,
21 pub(crate) llm_config_provider: Option<LlmConfigProvider>,
22 pub(crate) tool_registry: Option<Arc<ToolRegistry>>,
23 pub(crate) tool_services: Option<Arc<ToolServices>>,
24}
25
26impl ToolExecutor {
27 pub fn with_components(
28 backend_registry: Arc<BackendRegistry>,
29 validators: Arc<ValidatorRegistry>,
30 ) -> Self {
31 Self {
32 backend_registry,
33 validators,
34 llm_config_provider: None,
35 tool_registry: None,
36 tool_services: None,
37 }
38 }
39
40 pub fn with_all_components(
41 backend_registry: Arc<BackendRegistry>,
42 validators: Arc<ValidatorRegistry>,
43 llm_config_provider: LlmConfigProvider,
44 ) -> Self {
45 Self {
46 backend_registry,
47 validators,
48 llm_config_provider: Some(llm_config_provider),
49 tool_registry: None,
50 tool_services: None,
51 }
52 }
53
54 pub fn with_builtin_tools(
55 mut self,
56 registry: Arc<ToolRegistry>,
57 services: Arc<ToolServices>,
58 ) -> Self {
59 self.tool_registry = Some(registry);
60 self.tool_services = Some(services);
61 self
62 }
63
64 pub async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
65 if let Some(registry) = &self.tool_registry
66 && registry.is_builtin_tool(tool_name)
67 {
68 return Ok(registry.requires_approval(tool_name));
69 }
70
71 match self.backend_registry.get_backend_for_tool(tool_name) {
72 Some(backend) => Ok(backend.requires_approval(tool_name).await?),
73 None => Err(steer_tools::ToolError::UnknownTool(tool_name.to_string()).into()),
74 }
75 }
76
77 pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
78 self.get_tool_schemas_with_capabilities(super::Capabilities::all())
79 .await
80 }
81
82 pub async fn get_tool_schemas_with_resolver(
83 &self,
84 session_resolver: Option<&dyn BackendResolver>,
85 ) -> Vec<ToolSchema> {
86 self.get_tool_schemas_with_capabilities_and_resolver(
87 super::Capabilities::all(),
88 session_resolver,
89 )
90 .await
91 }
92
93 pub async fn get_tool_schemas_with_capabilities(
94 &self,
95 capabilities: super::Capabilities,
96 ) -> Vec<ToolSchema> {
97 self.get_tool_schemas_with_capabilities_and_resolver(capabilities, None)
98 .await
99 }
100
101 pub async fn get_tool_schemas_with_capabilities_and_resolver(
102 &self,
103 capabilities: super::Capabilities,
104 session_resolver: Option<&dyn BackendResolver>,
105 ) -> Vec<ToolSchema> {
106 let mut schemas = Vec::new();
107 let mut builtin_tool_names = std::collections::HashSet::new();
108
109 if let Some(registry) = &self.tool_registry {
110 let builtin_schemas = registry.available_schemas(capabilities).await;
111 for schema in &builtin_schemas {
112 builtin_tool_names.insert(schema.name.clone());
113 }
114 schemas.extend(builtin_schemas);
115 }
116
117 if let Some(resolver) = session_resolver {
118 for schema in resolver.get_tool_schemas().await {
119 if !builtin_tool_names.contains(&schema.name) {
120 schemas.push(schema);
121 }
122 }
123 }
124
125 for schema in self.backend_registry.get_tool_schemas().await {
126 if !builtin_tool_names.contains(&schema.name) {
127 schemas.push(schema);
128 }
129 }
130
131 schemas
132 }
133
134 pub fn is_builtin_tool(&self, tool_name: &str) -> bool {
135 self.tool_registry
136 .as_ref()
137 .is_some_and(|r| r.is_builtin_tool(tool_name))
138 }
139
140 pub async fn supported_tools(&self) -> Vec<String> {
142 let schemas = self.get_tool_schemas().await;
143 schemas.into_iter().map(|s| s.name).collect()
144 }
145
146 pub fn backend_registry(&self) -> &Arc<BackendRegistry> {
148 &self.backend_registry
149 }
150
151 pub fn workspace(&self) -> Option<Arc<dyn crate::workspace::Workspace>> {
152 self.tool_services
153 .as_ref()
154 .map(|services| services.workspace.clone())
155 }
156
157 #[instrument(skip(self, tool_call, session_id, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
158 pub async fn execute_tool_with_session(
159 &self,
160 tool_call: &ToolCall,
161 session_id: SessionId,
162 token: CancellationToken,
163 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
164 self.execute_tool_with_session_resolver(tool_call, session_id, None, token, None)
165 .await
166 }
167
168 #[instrument(skip(self, tool_call, session_id, invoking_model, token, session_resolver), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
169 pub async fn execute_tool_with_session_resolver(
170 &self,
171 tool_call: &ToolCall,
172 session_id: SessionId,
173 invoking_model: Option<ModelId>,
174 token: CancellationToken,
175 session_resolver: Option<&dyn BackendResolver>,
176 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
177 let tool_name = &tool_call.name;
178
179 if let Some((registry, services)) =
180 self.tool_registry.as_ref().zip(self.tool_services.as_ref())
181 && let Some(tool) = registry.builtin_tool(tool_name)
182 {
183 debug!(target: "tool_executor", "Executing builtin tool: {}", tool_name);
184 return self
185 .execute_builtin_tool(tool, tool_call, session_id, invoking_model, services, token)
186 .await;
187 }
188
189 self.execute_tool_with_resolver(tool_call, token, session_resolver)
190 .await
191 }
192
193 #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
194 pub async fn execute_tool_with_cancellation(
195 &self,
196 tool_call: &ToolCall,
197 token: CancellationToken,
198 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
199 self.execute_tool_with_resolver(tool_call, token, None)
200 .await
201 }
202
203 #[instrument(skip(self, tool_call, token, session_resolver), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
204 pub async fn execute_tool_with_resolver(
205 &self,
206 tool_call: &ToolCall,
207 token: CancellationToken,
208 session_resolver: Option<&dyn BackendResolver>,
209 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
210 let tool_name = &tool_call.name;
211 let tool_id = &tool_call.id;
212
213 Span::current().record("tool.name", tool_name);
214 Span::current().record("tool.id", tool_id);
215
216 if let Some(validator) = self.validators.get_validator(tool_name)
217 && let Some(ref llm_config_provider) = self.llm_config_provider
218 {
219 let validation_context = ValidationContext {
220 cancellation_token: token.clone(),
221 llm_config_provider: llm_config_provider.clone(),
222 };
223
224 let validation_result = validator
225 .validate(tool_call, &validation_context)
226 .await
227 .map_err(|e| {
228 steer_tools::ToolError::InternalError(format!("Validation failed: {e}"))
229 })?;
230
231 if !validation_result.allowed {
232 return Err(steer_tools::ToolError::InternalError(
233 validation_result
234 .reason
235 .unwrap_or_else(|| "Tool execution was denied".to_string()),
236 ));
237 }
238 }
239
240 let mut builder = ExecutionContext::builder(
241 "default".to_string(),
242 "default".to_string(),
243 tool_call.id.clone(),
244 token,
245 );
246
247 if let Some(provider) = &self.llm_config_provider {
248 builder = builder.llm_config_provider(provider.clone());
249 }
250
251 let context = builder.build();
252
253 if let Some(resolver) = session_resolver
254 && let Some(backend) = resolver.resolve(tool_name).await
255 {
256 debug!(target: "tool_executor", "Executing session MCP tool: {} ({})", tool_name, tool_id);
257 return backend.execute(tool_call, &context).await;
258 }
259
260 let backend = self
261 .backend_registry
262 .get_backend_for_tool(tool_name)
263 .cloned()
264 .ok_or_else(|| {
265 error!(target: "tool_executor", "No backend for tool: {} ({})", tool_name, tool_id);
266 steer_tools::ToolError::UnknownTool(tool_name.clone())
267 })?;
268
269 debug!(target: "tool_executor", "Executing external tool: {} ({})", tool_name, tool_id);
270 backend.execute(tool_call, &context).await
271 }
272
273 async fn execute_builtin_tool(
274 &self,
275 tool: &dyn super::builtin_tool::BuiltinToolErased,
276 tool_call: &ToolCall,
277 session_id: SessionId,
278 invoking_model: Option<ModelId>,
279 services: &Arc<ToolServices>,
280 token: CancellationToken,
281 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
282 let ctx = BuiltinToolContext {
283 tool_call_id: ToolCallId(tool_call.id.clone()),
284 session_id,
285 invoking_model,
286 cancellation_token: token,
287 services: services.clone(),
288 };
289
290 let output = tool
291 .execute_erased(tool_call.parameters.clone(), &ctx)
292 .await
293 .map_err(|e| match e {
294 BuiltinToolError::InvalidParams(msg) => steer_tools::ToolError::InvalidParams {
295 tool_name: tool_call.name.clone(),
296 message: msg,
297 },
298 BuiltinToolError::Execution(err) => steer_tools::ToolError::Execution(err),
299 BuiltinToolError::MissingCapability(cap) => {
300 steer_tools::ToolError::InternalError(format!("Missing capability: {cap}"))
301 }
302 BuiltinToolError::Cancelled => {
303 steer_tools::ToolError::Cancelled(tool_call.name.clone())
304 }
305 BuiltinToolError::Timeout => {
306 steer_tools::ToolError::Timeout(tool_call.name.clone())
307 }
308 })?;
309
310 Ok(output)
311 }
312
313 #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
315 pub async fn execute_tool_direct(
316 &self,
317 tool_call: &ToolCall,
318 token: CancellationToken,
319 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
320 let tool_name = &tool_call.name;
321 let tool_id = &tool_call.id;
322
323 Span::current().record("tool.name", tool_name);
324 Span::current().record("tool.id", tool_id);
325
326 if let Some((registry, services)) =
327 self.tool_registry.as_ref().zip(self.tool_services.as_ref())
328 && let Some(tool) = registry.builtin_tool(tool_name)
329 {
330 debug!(
331 target: "app.tool_executor.execute_tool_direct",
332 "Executing builtin tool {} ({}) directly (no validation)",
333 tool_name,
334 tool_id
335 );
336 return self
337 .execute_builtin_tool(tool, tool_call, SessionId::new(), None, services, token)
338 .await;
339 }
340
341 let mut builder = ExecutionContext::builder(
343 "direct".to_string(), "direct".to_string(),
345 tool_call.id.clone(),
346 token,
347 );
348
349 if let Some(provider) = &self.llm_config_provider {
351 builder = builder.llm_config_provider(provider.clone());
352 }
353
354 let context = builder.build();
355
356 let backend = self
358 .backend_registry
359 .get_backend_for_tool(tool_name)
360 .cloned()
361 .ok_or_else(|| {
362 error!(
363 target: "app.tool_executor.execute_tool_direct",
364 "No backend configured for tool: {} ({})",
365 tool_name,
366 tool_id
367 );
368 steer_tools::ToolError::UnknownTool(tool_name.clone())
369 })?;
370
371 debug!(
372 target: "app.tool_executor.execute_tool_direct",
373 "Executing external tool {} ({}) directly via backend (no validation)",
374 tool_name,
375 tool_id
376 );
377
378 backend.execute(tool_call, &context).await
379 }
380}