steer_core/app/tool_executor.rs

1use crate::config::LlmConfigProvider;
2use crate::error::{Error, Result};
3use std::sync::Arc;
4use tokio_util::sync::CancellationToken;
5use tracing::{Span, debug, error, instrument};
6
7use crate::api::ToolCall;
8use crate::app::validation::{ValidationContext, ValidatorRegistry};
9use crate::tools::{BackendRegistry, ExecutionContext};
10use crate::workspace::Workspace;
11use steer_tools::ToolSchema;
12use steer_tools::{ToolError, result::ToolResult};
13
14/// Manages the execution of tools called by the AI model
15#[derive(Clone)]
16pub struct ToolExecutor {
17    /// Optional workspace for executing workspace tools
18    pub(crate) workspace: Arc<dyn Workspace>,
19    /// Registry for external tool backends (MCP servers, etc.)
20    pub(crate) backend_registry: Arc<BackendRegistry>,
21    /// Validators for tool execution
22    pub(crate) validators: Arc<ValidatorRegistry>,
23    /// Provider for LLM configuration
24    pub(crate) llm_config_provider: Option<LlmConfigProvider>,
25}
26
27impl ToolExecutor {
28    /// Create a ToolExecutor with a workspace for workspace tools
29    pub fn with_workspace(workspace: Arc<dyn Workspace>) -> Self {
30        Self {
31            workspace,
32            backend_registry: Arc::new(BackendRegistry::new()),
33            validators: Arc::new(ValidatorRegistry::new()),
34            llm_config_provider: None,
35        }
36    }
37
38    /// Create a ToolExecutor with custom components
39    pub fn with_components(
40        workspace: Arc<dyn Workspace>,
41        backend_registry: Arc<BackendRegistry>,
42        validators: Arc<ValidatorRegistry>,
43    ) -> Self {
44        Self {
45            workspace,
46            backend_registry,
47            validators,
48            llm_config_provider: None,
49        }
50    }
51
52    /// Create a ToolExecutor with all components including LLM config provider
53    pub fn with_all_components(
54        workspace: Arc<dyn Workspace>,
55        backend_registry: Arc<BackendRegistry>,
56        validators: Arc<ValidatorRegistry>,
57        llm_config_provider: LlmConfigProvider,
58    ) -> Self {
59        Self {
60            workspace,
61            backend_registry,
62            validators,
63            llm_config_provider: Some(llm_config_provider),
64        }
65    }
66
67    pub async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
68        // First check if it's a workspace tool
69        let workspace_tools = self.workspace.available_tools().await;
70        if workspace_tools.iter().any(|t| t.name == tool_name) {
71            return self
72                .workspace
73                .requires_approval(tool_name)
74                .await
75                .map_err(|e| {
76                    Error::Tool(steer_tools::ToolError::InternalError(format!(
77                        "Failed to check approval requirement: {e}"
78                    )))
79                });
80        }
81
82        // Otherwise check external backends
83        match self.backend_registry.get_backend_for_tool(tool_name) {
84            Some(backend) => backend.requires_approval(tool_name).await.map_err(|e| {
85                Error::Tool(steer_tools::ToolError::InternalError(format!(
86                    "Failed to check approval requirement: {e}"
87                )))
88            }),
89            None => Err(Error::Tool(steer_tools::ToolError::UnknownTool(
90                tool_name.to_string(),
91            ))),
92        }
93    }
94
95    pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
96        let mut schemas = Vec::new();
97
98        // Add workspace tools if available
99        schemas.extend(self.workspace.available_tools().await);
100
101        // Add external backend tools
102        schemas.extend(self.backend_registry.get_tool_schemas().await);
103
104        schemas
105    }
106
107    /// Get the list of supported tools
108    pub async fn supported_tools(&self) -> Vec<String> {
109        let schemas = self.get_tool_schemas().await;
110        schemas.into_iter().map(|s| s.name).collect()
111    }
112
113    /// Get the backend registry
114    pub fn backend_registry(&self) -> &Arc<BackendRegistry> {
115        &self.backend_registry
116    }
117
118    #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
119    pub async fn execute_tool_with_cancellation(
120        &self,
121        tool_call: &ToolCall,
122        token: CancellationToken,
123    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
124        let tool_name = &tool_call.name;
125        let tool_id = &tool_call.id;
126
127        Span::current().record("tool.name", tool_name);
128        Span::current().record("tool.id", tool_id);
129
130        // Pre-execution validation
131        if let Some(validator) = self.validators.get_validator(tool_name) {
132            // Only validate if we have an LLM config provider
133            if let Some(ref llm_config_provider) = self.llm_config_provider {
134                let validation_context = ValidationContext {
135                    cancellation_token: token.clone(),
136                    llm_config_provider: llm_config_provider.clone(),
137                };
138
139                let validation_result = validator
140                    .validate(tool_call, &validation_context)
141                    .await
142                    .map_err(|e| ToolError::InternalError(format!("Validation failed: {e}")))?;
143
144                if !validation_result.allowed {
145                    return Err(ToolError::InternalError(
146                        validation_result
147                            .reason
148                            .unwrap_or_else(|| "Tool execution was denied".to_string()),
149                    ));
150                }
151            }
152            // If no LLM config provider, skip validation (allow execution)
153        }
154
155        // Create execution context
156        let mut builder = ExecutionContext::builder(
157            "default".to_string(), // TODO: Get real session ID
158            "default".to_string(), // TODO: Get real operation ID
159            tool_call.id.clone(),
160            token,
161        );
162
163        // Add LLM config provider if available
164        if let Some(provider) = &self.llm_config_provider {
165            builder = builder.llm_config_provider(provider.clone());
166        }
167
168        let context = builder.build();
169
170        // First check if it's a workspace tool
171        let workspace_tools = self.workspace.available_tools().await;
172        if workspace_tools.iter().any(|t| &t.name == tool_name) {
173            debug!(
174                target: "app.tool_executor.execute_tool_with_cancellation",
175                "Executing workspace tool {} ({}) with cancellation",
176                tool_name,
177                tool_id
178            );
179
180            return self
181                .execute_workspace_tool(&self.workspace, tool_call, &context)
182                .await;
183        }
184
185        // Otherwise check external backends
186        let backend = self
187            .backend_registry
188            .get_backend_for_tool(tool_name)
189            .cloned()
190            .ok_or_else(|| {
191                error!(
192                    target: "app.tool_executor.execute_tool_with_cancellation",
193                    "No backend configured for tool: {} ({})",
194                    tool_name,
195                    tool_id
196                );
197                ToolError::UnknownTool(tool_name.clone())
198            })?;
199
200        debug!(
201            target: "app.tool_executor.execute_tool_with_cancellation",
202            "Executing external tool {} ({}) via backend with cancellation",
203            tool_name,
204            tool_id
205        );
206
207        backend.execute(tool_call, &context).await
208    }
209
210    /// Execute a tool directly without validation - for user-initiated bash commands
211    #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
212    pub async fn execute_tool_direct(
213        &self,
214        tool_call: &ToolCall,
215        token: CancellationToken,
216    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
217        let tool_name = &tool_call.name;
218        let tool_id = &tool_call.id;
219
220        Span::current().record("tool.name", tool_name);
221        Span::current().record("tool.id", tool_id);
222
223        // Create execution context
224        let mut builder = ExecutionContext::builder(
225            "direct".to_string(), // Mark as direct execution
226            "direct".to_string(),
227            tool_call.id.clone(),
228            token,
229        );
230
231        // Add LLM config provider if available
232        if let Some(provider) = &self.llm_config_provider {
233            builder = builder.llm_config_provider(provider.clone());
234        }
235
236        let context = builder.build();
237
238        // First check if it's a workspace tool (no validation for direct execution)
239        let workspace_tools = self.workspace.available_tools().await;
240        if workspace_tools.iter().any(|t| &t.name == tool_name) {
241            debug!(
242                target: "app.tool_executor.execute_tool_direct",
243                "Executing workspace tool {} ({}) directly (no validation)",
244                tool_name,
245                tool_id
246            );
247
248            return self
249                .execute_workspace_tool(&self.workspace, tool_call, &context)
250                .await;
251        }
252
253        // Otherwise check external backends
254        let backend = self
255            .backend_registry
256            .get_backend_for_tool(tool_name)
257            .cloned()
258            .ok_or_else(|| {
259                error!(
260                    target: "app.tool_executor.execute_tool_direct",
261                    "No backend configured for tool: {} ({})",
262                    tool_name,
263                    tool_id
264                );
265                ToolError::UnknownTool(tool_name.clone())
266            })?;
267
268        debug!(
269            target: "app.tool_executor.execute_tool_direct",
270            "Executing external tool {} ({}) directly via backend (no validation)",
271            tool_name,
272            tool_id
273        );
274
275        backend.execute(tool_call, &context).await
276    }
277
278    /// Helper method to execute a workspace tool
279    async fn execute_workspace_tool(
280        &self,
281        workspace: &Arc<dyn Workspace>,
282        tool_call: &ToolCall,
283        context: &ExecutionContext,
284    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
285        // Convert ExecutionContext to steer-tools ExecutionContext
286        let tools_context = steer_tools::ExecutionContext::new(context.tool_call_id.clone())
287            .with_cancellation_token(context.cancellation_token.clone());
288
289        workspace
290            .execute_tool(tool_call, tools_context)
291            .await
292            .map_err(|e| ToolError::InternalError(format!("Workspace execution failed: {e}")))
293    }
294}