1use crate::app::domain::types::{SessionId, ToolCallId};
2use crate::config::LlmConfigProvider;
3use crate::config::model::ModelId;
4use crate::tools::error::Result;
5use std::sync::Arc;
6use tokio_util::sync::CancellationToken;
7use tracing::{Span, debug, error, instrument};
8
9use crate::app::validation::{ValidationContext, ValidatorRegistry};
10use crate::tools::registry::ToolRegistry;
11use crate::tools::resolver::BackendResolver;
12use crate::tools::services::ToolServices;
13use crate::tools::static_tool::{StaticToolContext, StaticToolError};
14use crate::tools::{BackendRegistry, ExecutionContext};
15use steer_tools::{ToolCall, ToolSchema, result::ToolResult};
16
17#[derive(Clone)]
18pub struct ToolExecutor {
19 pub(crate) backend_registry: Arc<BackendRegistry>,
20 pub(crate) validators: Arc<ValidatorRegistry>,
21 pub(crate) llm_config_provider: Option<LlmConfigProvider>,
22 pub(crate) tool_registry: Option<Arc<ToolRegistry>>,
23 pub(crate) tool_services: Option<Arc<ToolServices>>,
24}
25
26impl ToolExecutor {
27 pub fn with_components(
28 backend_registry: Arc<BackendRegistry>,
29 validators: Arc<ValidatorRegistry>,
30 ) -> Self {
31 Self {
32 backend_registry,
33 validators,
34 llm_config_provider: None,
35 tool_registry: None,
36 tool_services: None,
37 }
38 }
39
40 pub fn with_all_components(
41 backend_registry: Arc<BackendRegistry>,
42 validators: Arc<ValidatorRegistry>,
43 llm_config_provider: LlmConfigProvider,
44 ) -> Self {
45 Self {
46 backend_registry,
47 validators,
48 llm_config_provider: Some(llm_config_provider),
49 tool_registry: None,
50 tool_services: None,
51 }
52 }
53
54 pub fn with_static_tools(
55 mut self,
56 registry: Arc<ToolRegistry>,
57 services: Arc<ToolServices>,
58 ) -> Self {
59 self.tool_registry = Some(registry);
60 self.tool_services = Some(services);
61 self
62 }
63
64 pub async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
65 if let Some(registry) = &self.tool_registry
66 && registry.is_static_tool(tool_name)
67 {
68 return Ok(registry.requires_approval(tool_name));
69 }
70
71 match self.backend_registry.get_backend_for_tool(tool_name) {
72 Some(backend) => Ok(backend.requires_approval(tool_name).await?),
73 None => Err(steer_tools::ToolError::UnknownTool(tool_name.to_string()).into()),
74 }
75 }
76
77 pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
78 self.get_tool_schemas_with_capabilities(super::Capabilities::all())
79 .await
80 }
81
82 pub async fn get_tool_schemas_with_resolver(
83 &self,
84 session_resolver: Option<&dyn BackendResolver>,
85 ) -> Vec<ToolSchema> {
86 self.get_tool_schemas_with_capabilities_and_resolver(
87 super::Capabilities::all(),
88 session_resolver,
89 )
90 .await
91 }
92
93 pub async fn get_tool_schemas_with_capabilities(
94 &self,
95 capabilities: super::Capabilities,
96 ) -> Vec<ToolSchema> {
97 self.get_tool_schemas_with_capabilities_and_resolver(capabilities, None)
98 .await
99 }
100
101 pub async fn get_tool_schemas_with_capabilities_and_resolver(
102 &self,
103 capabilities: super::Capabilities,
104 session_resolver: Option<&dyn BackendResolver>,
105 ) -> Vec<ToolSchema> {
106 let mut schemas = Vec::new();
107 let mut static_tool_names = std::collections::HashSet::new();
108
109 if let Some(registry) = &self.tool_registry {
110 let static_schemas = registry.available_schemas(capabilities).await;
111 for schema in &static_schemas {
112 static_tool_names.insert(schema.name.clone());
113 }
114 schemas.extend(static_schemas);
115 }
116
117 if let Some(resolver) = session_resolver {
118 for schema in resolver.get_tool_schemas().await {
119 if !static_tool_names.contains(&schema.name) {
120 schemas.push(schema);
121 }
122 }
123 }
124
125 for schema in self.backend_registry.get_tool_schemas().await {
126 if !static_tool_names.contains(&schema.name) {
127 schemas.push(schema);
128 }
129 }
130
131 schemas
132 }
133
134 pub fn is_static_tool(&self, tool_name: &str) -> bool {
135 self.tool_registry
136 .as_ref()
137 .is_some_and(|r| r.is_static_tool(tool_name))
138 }
139
140 pub async fn supported_tools(&self) -> Vec<String> {
142 let schemas = self.get_tool_schemas().await;
143 schemas.into_iter().map(|s| s.name).collect()
144 }
145
146 pub fn backend_registry(&self) -> &Arc<BackendRegistry> {
148 &self.backend_registry
149 }
150
151 pub fn workspace(&self) -> Option<Arc<dyn crate::workspace::Workspace>> {
152 self.tool_services
153 .as_ref()
154 .map(|services| services.workspace.clone())
155 }
156
157 #[instrument(skip(self, tool_call, session_id, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
158 pub async fn execute_tool_with_session(
159 &self,
160 tool_call: &ToolCall,
161 session_id: SessionId,
162 token: CancellationToken,
163 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
164 self.execute_tool_with_session_resolver(tool_call, session_id, None, token, None)
165 .await
166 }
167
168 #[instrument(skip(self, tool_call, session_id, invoking_model, token, session_resolver), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
169 pub async fn execute_tool_with_session_resolver(
170 &self,
171 tool_call: &ToolCall,
172 session_id: SessionId,
173 invoking_model: Option<ModelId>,
174 token: CancellationToken,
175 session_resolver: Option<&dyn BackendResolver>,
176 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
177 let tool_name = &tool_call.name;
178
179 if let Some((registry, services)) =
180 self.tool_registry.as_ref().zip(self.tool_services.as_ref())
181 && let Some(tool) = registry.static_tool(tool_name)
182 {
183 debug!(target: "tool_executor", "Executing static tool: {}", tool_name);
184 return self
185 .execute_static_tool(tool, tool_call, session_id, invoking_model, services, token)
186 .await;
187 }
188
189 self.execute_tool_with_resolver(tool_call, token, session_resolver)
190 .await
191 }
192
193 #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
194 pub async fn execute_tool_with_cancellation(
195 &self,
196 tool_call: &ToolCall,
197 token: CancellationToken,
198 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
199 self.execute_tool_with_resolver(tool_call, token, None)
200 .await
201 }
202
203 #[instrument(skip(self, tool_call, token, session_resolver), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
204 pub async fn execute_tool_with_resolver(
205 &self,
206 tool_call: &ToolCall,
207 token: CancellationToken,
208 session_resolver: Option<&dyn BackendResolver>,
209 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
210 let tool_name = &tool_call.name;
211 let tool_id = &tool_call.id;
212
213 Span::current().record("tool.name", tool_name);
214 Span::current().record("tool.id", tool_id);
215
216 if let Some(validator) = self.validators.get_validator(tool_name)
217 && let Some(ref llm_config_provider) = self.llm_config_provider
218 {
219 let validation_context = ValidationContext {
220 cancellation_token: token.clone(),
221 llm_config_provider: llm_config_provider.clone(),
222 };
223
224 let validation_result = validator
225 .validate(tool_call, &validation_context)
226 .await
227 .map_err(|e| {
228 steer_tools::ToolError::InternalError(format!("Validation failed: {e}"))
229 })?;
230
231 if !validation_result.allowed {
232 return Err(steer_tools::ToolError::InternalError(
233 validation_result
234 .reason
235 .unwrap_or_else(|| "Tool execution was denied".to_string()),
236 ));
237 }
238 }
239
240 let mut builder = ExecutionContext::builder(
241 "default".to_string(),
242 "default".to_string(),
243 tool_call.id.clone(),
244 token,
245 );
246
247 if let Some(provider) = &self.llm_config_provider {
248 builder = builder.llm_config_provider(provider.clone());
249 }
250
251 let context = builder.build();
252
253 if let Some(resolver) = session_resolver
254 && let Some(backend) = resolver.resolve(tool_name).await
255 {
256 debug!(target: "tool_executor", "Executing session MCP tool: {} ({})", tool_name, tool_id);
257 return backend.execute(tool_call, &context).await;
258 }
259
260 let backend = self
261 .backend_registry
262 .get_backend_for_tool(tool_name)
263 .cloned()
264 .ok_or_else(|| {
265 error!(target: "tool_executor", "No backend for tool: {} ({})", tool_name, tool_id);
266 steer_tools::ToolError::UnknownTool(tool_name.clone())
267 })?;
268
269 debug!(target: "tool_executor", "Executing external tool: {} ({})", tool_name, tool_id);
270 backend.execute(tool_call, &context).await
271 }
272
273 async fn execute_static_tool(
274 &self,
275 tool: &dyn super::static_tool::StaticToolErased,
276 tool_call: &ToolCall,
277 session_id: SessionId,
278 invoking_model: Option<ModelId>,
279 services: &Arc<ToolServices>,
280 token: CancellationToken,
281 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
282 let ctx = StaticToolContext {
283 tool_call_id: ToolCallId(tool_call.id.clone()),
284 session_id,
285 invoking_model,
286 cancellation_token: token,
287 services: services.clone(),
288 };
289
290 let output = tool
291 .execute_erased(tool_call.parameters.clone(), &ctx)
292 .await
293 .map_err(|e| match e {
294 StaticToolError::InvalidParams(msg) => steer_tools::ToolError::InvalidParams {
295 tool_name: tool_call.name.clone(),
296 message: msg,
297 },
298 StaticToolError::Execution(err) => steer_tools::ToolError::Execution(err),
299 StaticToolError::MissingCapability(cap) => {
300 steer_tools::ToolError::InternalError(format!("Missing capability: {cap}"))
301 }
302 StaticToolError::Cancelled => {
303 steer_tools::ToolError::Cancelled(tool_call.name.clone())
304 }
305 StaticToolError::Timeout => steer_tools::ToolError::Timeout(tool_call.name.clone()),
306 })?;
307
308 Ok(output)
309 }
310
311 #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
313 pub async fn execute_tool_direct(
314 &self,
315 tool_call: &ToolCall,
316 token: CancellationToken,
317 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
318 let tool_name = &tool_call.name;
319 let tool_id = &tool_call.id;
320
321 Span::current().record("tool.name", tool_name);
322 Span::current().record("tool.id", tool_id);
323
324 if let Some((registry, services)) =
325 self.tool_registry.as_ref().zip(self.tool_services.as_ref())
326 && let Some(tool) = registry.static_tool(tool_name)
327 {
328 debug!(
329 target: "app.tool_executor.execute_tool_direct",
330 "Executing static tool {} ({}) directly (no validation)",
331 tool_name,
332 tool_id
333 );
334 return self
335 .execute_static_tool(tool, tool_call, SessionId::new(), None, services, token)
336 .await;
337 }
338
339 let mut builder = ExecutionContext::builder(
341 "direct".to_string(), "direct".to_string(),
343 tool_call.id.clone(),
344 token,
345 );
346
347 if let Some(provider) = &self.llm_config_provider {
349 builder = builder.llm_config_provider(provider.clone());
350 }
351
352 let context = builder.build();
353
354 let backend = self
356 .backend_registry
357 .get_backend_for_tool(tool_name)
358 .cloned()
359 .ok_or_else(|| {
360 error!(
361 target: "app.tool_executor.execute_tool_direct",
362 "No backend configured for tool: {} ({})",
363 tool_name,
364 tool_id
365 );
366 steer_tools::ToolError::UnknownTool(tool_name.clone())
367 })?;
368
369 debug!(
370 target: "app.tool_executor.execute_tool_direct",
371 "Executing external tool {} ({}) directly via backend (no validation)",
372 tool_name,
373 tool_id
374 );
375
376 backend.execute(tool_call, &context).await
377 }
378}