Skip to main content

steer_core/tools/
executor.rs

1use crate::app::domain::types::{SessionId, ToolCallId};
2use crate::config::LlmConfigProvider;
3use crate::config::model::ModelId;
4use crate::tools::error::Result;
5use std::sync::Arc;
6use tokio_util::sync::CancellationToken;
7use tracing::{Span, debug, error, instrument};
8
9use crate::app::validation::{ValidationContext, ValidatorRegistry};
10use crate::tools::builtin_tool::{BuiltinToolContext, BuiltinToolError};
11use crate::tools::registry::ToolRegistry;
12use crate::tools::resolver::BackendResolver;
13use crate::tools::services::ToolServices;
14use crate::tools::{BackendRegistry, ExecutionContext};
15use steer_tools::{ToolCall, ToolSchema, result::ToolResult};
16
17#[derive(Clone)]
18pub struct ToolExecutor {
19    pub(crate) backend_registry: Arc<BackendRegistry>,
20    pub(crate) validators: Arc<ValidatorRegistry>,
21    pub(crate) llm_config_provider: Option<LlmConfigProvider>,
22    pub(crate) tool_registry: Option<Arc<ToolRegistry>>,
23    pub(crate) tool_services: Option<Arc<ToolServices>>,
24}
25
26impl ToolExecutor {
27    pub fn with_components(
28        backend_registry: Arc<BackendRegistry>,
29        validators: Arc<ValidatorRegistry>,
30    ) -> Self {
31        Self {
32            backend_registry,
33            validators,
34            llm_config_provider: None,
35            tool_registry: None,
36            tool_services: None,
37        }
38    }
39
40    pub fn with_all_components(
41        backend_registry: Arc<BackendRegistry>,
42        validators: Arc<ValidatorRegistry>,
43        llm_config_provider: LlmConfigProvider,
44    ) -> Self {
45        Self {
46            backend_registry,
47            validators,
48            llm_config_provider: Some(llm_config_provider),
49            tool_registry: None,
50            tool_services: None,
51        }
52    }
53
54    pub fn with_builtin_tools(
55        mut self,
56        registry: Arc<ToolRegistry>,
57        services: Arc<ToolServices>,
58    ) -> Self {
59        self.tool_registry = Some(registry);
60        self.tool_services = Some(services);
61        self
62    }
63
64    pub async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
65        if let Some(registry) = &self.tool_registry
66            && registry.is_builtin_tool(tool_name)
67        {
68            return Ok(registry.requires_approval(tool_name));
69        }
70
71        match self.backend_registry.get_backend_for_tool(tool_name) {
72            Some(backend) => Ok(backend.requires_approval(tool_name).await?),
73            None => Err(steer_tools::ToolError::UnknownTool(tool_name.to_string()).into()),
74        }
75    }
76
77    pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
78        self.get_tool_schemas_with_capabilities(super::Capabilities::all())
79            .await
80    }
81
82    pub async fn get_tool_schemas_with_resolver(
83        &self,
84        session_resolver: Option<&dyn BackendResolver>,
85    ) -> Vec<ToolSchema> {
86        self.get_tool_schemas_with_capabilities_and_resolver(
87            super::Capabilities::all(),
88            session_resolver,
89        )
90        .await
91    }
92
93    pub async fn get_tool_schemas_with_capabilities(
94        &self,
95        capabilities: super::Capabilities,
96    ) -> Vec<ToolSchema> {
97        self.get_tool_schemas_with_capabilities_and_resolver(capabilities, None)
98            .await
99    }
100
101    pub async fn get_tool_schemas_with_capabilities_and_resolver(
102        &self,
103        capabilities: super::Capabilities,
104        session_resolver: Option<&dyn BackendResolver>,
105    ) -> Vec<ToolSchema> {
106        let mut schemas = Vec::new();
107        let mut builtin_tool_names = std::collections::HashSet::new();
108
109        if let Some(registry) = &self.tool_registry {
110            let builtin_schemas = registry.available_schemas(capabilities).await;
111            for schema in &builtin_schemas {
112                builtin_tool_names.insert(schema.name.clone());
113            }
114            schemas.extend(builtin_schemas);
115        }
116
117        if let Some(resolver) = session_resolver {
118            for schema in resolver.get_tool_schemas().await {
119                if !builtin_tool_names.contains(&schema.name) {
120                    schemas.push(schema);
121                }
122            }
123        }
124
125        for schema in self.backend_registry.get_tool_schemas().await {
126            if !builtin_tool_names.contains(&schema.name) {
127                schemas.push(schema);
128            }
129        }
130
131        schemas
132    }
133
134    pub fn is_builtin_tool(&self, tool_name: &str) -> bool {
135        self.tool_registry
136            .as_ref()
137            .is_some_and(|r| r.is_builtin_tool(tool_name))
138    }
139
140    /// Get the list of supported tools
141    pub async fn supported_tools(&self) -> Vec<String> {
142        let schemas = self.get_tool_schemas().await;
143        schemas.into_iter().map(|s| s.name).collect()
144    }
145
146    /// Get the backend registry
147    pub fn backend_registry(&self) -> &Arc<BackendRegistry> {
148        &self.backend_registry
149    }
150
151    pub fn workspace(&self) -> Option<Arc<dyn crate::workspace::Workspace>> {
152        self.tool_services
153            .as_ref()
154            .map(|services| services.workspace.clone())
155    }
156
157    #[instrument(skip(self, tool_call, session_id, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
158    pub async fn execute_tool_with_session(
159        &self,
160        tool_call: &ToolCall,
161        session_id: SessionId,
162        token: CancellationToken,
163    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
164        self.execute_tool_with_session_resolver(tool_call, session_id, None, token, None)
165            .await
166    }
167
168    #[instrument(skip(self, tool_call, session_id, invoking_model, token, session_resolver), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
169    pub async fn execute_tool_with_session_resolver(
170        &self,
171        tool_call: &ToolCall,
172        session_id: SessionId,
173        invoking_model: Option<ModelId>,
174        token: CancellationToken,
175        session_resolver: Option<&dyn BackendResolver>,
176    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
177        let tool_name = &tool_call.name;
178
179        if let Some((registry, services)) =
180            self.tool_registry.as_ref().zip(self.tool_services.as_ref())
181            && let Some(tool) = registry.builtin_tool(tool_name)
182        {
183            debug!(target: "tool_executor", "Executing builtin tool: {}", tool_name);
184            return self
185                .execute_builtin_tool(tool, tool_call, session_id, invoking_model, services, token)
186                .await;
187        }
188
189        self.execute_tool_with_resolver(tool_call, token, session_resolver)
190            .await
191    }
192
193    #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
194    pub async fn execute_tool_with_cancellation(
195        &self,
196        tool_call: &ToolCall,
197        token: CancellationToken,
198    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
199        self.execute_tool_with_resolver(tool_call, token, None)
200            .await
201    }
202
203    #[instrument(skip(self, tool_call, token, session_resolver), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
204    pub async fn execute_tool_with_resolver(
205        &self,
206        tool_call: &ToolCall,
207        token: CancellationToken,
208        session_resolver: Option<&dyn BackendResolver>,
209    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
210        let tool_name = &tool_call.name;
211        let tool_id = &tool_call.id;
212
213        Span::current().record("tool.name", tool_name);
214        Span::current().record("tool.id", tool_id);
215
216        if let Some(validator) = self.validators.get_validator(tool_name)
217            && let Some(ref llm_config_provider) = self.llm_config_provider
218        {
219            let validation_context = ValidationContext {
220                cancellation_token: token.clone(),
221                llm_config_provider: llm_config_provider.clone(),
222            };
223
224            let validation_result = validator
225                .validate(tool_call, &validation_context)
226                .await
227                .map_err(|e| {
228                    steer_tools::ToolError::InternalError(format!("Validation failed: {e}"))
229                })?;
230
231            if !validation_result.allowed {
232                return Err(steer_tools::ToolError::InternalError(
233                    validation_result
234                        .reason
235                        .unwrap_or_else(|| "Tool execution was denied".to_string()),
236                ));
237            }
238        }
239
240        let mut builder = ExecutionContext::builder(
241            "default".to_string(),
242            "default".to_string(),
243            tool_call.id.clone(),
244            token,
245        );
246
247        if let Some(provider) = &self.llm_config_provider {
248            builder = builder.llm_config_provider(provider.clone());
249        }
250
251        let context = builder.build();
252
253        if let Some(resolver) = session_resolver
254            && let Some(backend) = resolver.resolve(tool_name).await
255        {
256            debug!(target: "tool_executor", "Executing session MCP tool: {} ({})", tool_name, tool_id);
257            return backend.execute(tool_call, &context).await;
258        }
259
260        let backend = self
261            .backend_registry
262            .get_backend_for_tool(tool_name)
263            .cloned()
264            .ok_or_else(|| {
265                error!(target: "tool_executor", "No backend for tool: {} ({})", tool_name, tool_id);
266                steer_tools::ToolError::UnknownTool(tool_name.clone())
267            })?;
268
269        debug!(target: "tool_executor", "Executing external tool: {} ({})", tool_name, tool_id);
270        backend.execute(tool_call, &context).await
271    }
272
273    async fn execute_builtin_tool(
274        &self,
275        tool: &dyn super::builtin_tool::BuiltinToolErased,
276        tool_call: &ToolCall,
277        session_id: SessionId,
278        invoking_model: Option<ModelId>,
279        services: &Arc<ToolServices>,
280        token: CancellationToken,
281    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
282        let ctx = BuiltinToolContext {
283            tool_call_id: ToolCallId(tool_call.id.clone()),
284            session_id,
285            invoking_model,
286            cancellation_token: token,
287            services: services.clone(),
288        };
289
290        let output = tool
291            .execute_erased(tool_call.parameters.clone(), &ctx)
292            .await
293            .map_err(|e| match e {
294                BuiltinToolError::InvalidParams(msg) => steer_tools::ToolError::InvalidParams {
295                    tool_name: tool_call.name.clone(),
296                    message: msg,
297                },
298                BuiltinToolError::Execution(err) => steer_tools::ToolError::Execution(err),
299                BuiltinToolError::MissingCapability(cap) => {
300                    steer_tools::ToolError::InternalError(format!("Missing capability: {cap}"))
301                }
302                BuiltinToolError::Cancelled => {
303                    steer_tools::ToolError::Cancelled(tool_call.name.clone())
304                }
305                BuiltinToolError::Timeout => {
306                    steer_tools::ToolError::Timeout(tool_call.name.clone())
307                }
308            })?;
309
310        Ok(output)
311    }
312
313    /// Execute a tool directly without validation - for user-initiated bash commands
314    #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
315    pub async fn execute_tool_direct(
316        &self,
317        tool_call: &ToolCall,
318        token: CancellationToken,
319    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
320        let tool_name = &tool_call.name;
321        let tool_id = &tool_call.id;
322
323        Span::current().record("tool.name", tool_name);
324        Span::current().record("tool.id", tool_id);
325
326        if let Some((registry, services)) =
327            self.tool_registry.as_ref().zip(self.tool_services.as_ref())
328            && let Some(tool) = registry.builtin_tool(tool_name)
329        {
330            debug!(
331                target: "app.tool_executor.execute_tool_direct",
332                "Executing builtin tool {} ({}) directly (no validation)",
333                tool_name,
334                tool_id
335            );
336            return self
337                .execute_builtin_tool(tool, tool_call, SessionId::new(), None, services, token)
338                .await;
339        }
340
341        // Create execution context
342        let mut builder = ExecutionContext::builder(
343            "direct".to_string(), // Mark as direct execution
344            "direct".to_string(),
345            tool_call.id.clone(),
346            token,
347        );
348
349        // Add LLM config provider if available
350        if let Some(provider) = &self.llm_config_provider {
351            builder = builder.llm_config_provider(provider.clone());
352        }
353
354        let context = builder.build();
355
356        // Otherwise check external backends
357        let backend = self
358            .backend_registry
359            .get_backend_for_tool(tool_name)
360            .cloned()
361            .ok_or_else(|| {
362                error!(
363                    target: "app.tool_executor.execute_tool_direct",
364                    "No backend configured for tool: {} ({})",
365                    tool_name,
366                    tool_id
367                );
368                steer_tools::ToolError::UnknownTool(tool_name.clone())
369            })?;
370
371        debug!(
372            target: "app.tool_executor.execute_tool_direct",
373            "Executing external tool {} ({}) directly via backend (no validation)",
374            tool_name,
375            tool_id
376        );
377
378        backend.execute(tool_call, &context).await
379    }
380}