steer_core/tools/
executor.rs

1use crate::config::LlmConfigProvider;
2use crate::tools::error::Result;
3use std::sync::Arc;
4use tokio_util::sync::CancellationToken;
5use tracing::{Span, debug, error, instrument};
6
7use crate::app::validation::{ValidationContext, ValidatorRegistry};
8use crate::tools::{BackendRegistry, ExecutionContext};
9use crate::workspace::Workspace;
10use steer_tools::{ToolCall, ToolSchema, result::ToolResult};
11
12/// Manages the execution of tools called by the AI model
13#[derive(Clone)]
14pub struct ToolExecutor {
15    /// Optional workspace for executing workspace tools
16    pub(crate) workspace: Arc<dyn Workspace>,
17    /// Registry for external tool backends (MCP servers, etc.)
18    pub(crate) backend_registry: Arc<BackendRegistry>,
19    /// Validators for tool execution
20    pub(crate) validators: Arc<ValidatorRegistry>,
21    /// Provider for LLM configuration
22    pub(crate) llm_config_provider: Option<LlmConfigProvider>,
23}
24
25impl ToolExecutor {
26    /// Create a ToolExecutor with a workspace for workspace tools
27    pub fn with_workspace(workspace: Arc<dyn Workspace>) -> Self {
28        Self {
29            workspace,
30            backend_registry: Arc::new(BackendRegistry::new()),
31            validators: Arc::new(ValidatorRegistry::new()),
32            llm_config_provider: None,
33        }
34    }
35
36    /// Create a ToolExecutor with custom components
37    pub fn with_components(
38        workspace: Arc<dyn Workspace>,
39        backend_registry: Arc<BackendRegistry>,
40        validators: Arc<ValidatorRegistry>,
41    ) -> Self {
42        Self {
43            workspace,
44            backend_registry,
45            validators,
46            llm_config_provider: None,
47        }
48    }
49
50    /// Create a ToolExecutor with all components including LLM config provider
51    pub fn with_all_components(
52        workspace: Arc<dyn Workspace>,
53        backend_registry: Arc<BackendRegistry>,
54        validators: Arc<ValidatorRegistry>,
55        llm_config_provider: LlmConfigProvider,
56    ) -> Self {
57        Self {
58            workspace,
59            backend_registry,
60            validators,
61            llm_config_provider: Some(llm_config_provider),
62        }
63    }
64
65    pub async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
66        // First check if it's a workspace tool
67        let workspace_tools = self.workspace.available_tools().await;
68        if workspace_tools.iter().any(|t| t.name == tool_name) {
69            return Ok(self.workspace.requires_approval(tool_name).await?);
70        }
71
72        // Otherwise check external backends
73        match self.backend_registry.get_backend_for_tool(tool_name) {
74            Some(backend) => Ok(backend.requires_approval(tool_name).await?),
75            None => Err(steer_tools::ToolError::UnknownTool(tool_name.to_string()).into()),
76        }
77    }
78
79    pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
80        let mut schemas = Vec::new();
81
82        // Add workspace tools if available
83        schemas.extend(self.workspace.available_tools().await);
84
85        // Add external backend tools
86        schemas.extend(self.backend_registry.get_tool_schemas().await);
87
88        schemas
89    }
90
91    /// Get the list of supported tools
92    pub async fn supported_tools(&self) -> Vec<String> {
93        let schemas = self.get_tool_schemas().await;
94        schemas.into_iter().map(|s| s.name).collect()
95    }
96
97    /// Get the backend registry
98    pub fn backend_registry(&self) -> &Arc<BackendRegistry> {
99        &self.backend_registry
100    }
101
102    #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
103    pub async fn execute_tool_with_cancellation(
104        &self,
105        tool_call: &ToolCall,
106        token: CancellationToken,
107    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
108        let tool_name = &tool_call.name;
109        let tool_id = &tool_call.id;
110
111        Span::current().record("tool.name", tool_name);
112        Span::current().record("tool.id", tool_id);
113
114        // Pre-execution validation
115        if let Some(validator) = self.validators.get_validator(tool_name) {
116            // Only validate if we have an LLM config provider
117            if let Some(ref llm_config_provider) = self.llm_config_provider {
118                let validation_context = ValidationContext {
119                    cancellation_token: token.clone(),
120                    llm_config_provider: llm_config_provider.clone(),
121                };
122
123                let validation_result = validator
124                    .validate(tool_call, &validation_context)
125                    .await
126                    .map_err(|e| {
127                        steer_tools::ToolError::InternalError(format!("Validation failed: {e}"))
128                    })?;
129
130                if !validation_result.allowed {
131                    return Err(steer_tools::ToolError::InternalError(
132                        validation_result
133                            .reason
134                            .unwrap_or_else(|| "Tool execution was denied".to_string()),
135                    ));
136                }
137            }
138            // If no LLM config provider, skip validation (allow execution)
139        }
140
141        // Create execution context
142        let mut builder = ExecutionContext::builder(
143            "default".to_string(), // TODO: Get real session ID
144            "default".to_string(), // TODO: Get real operation ID
145            tool_call.id.clone(),
146            token,
147        );
148
149        // Add LLM config provider if available
150        if let Some(provider) = &self.llm_config_provider {
151            builder = builder.llm_config_provider(provider.clone());
152        }
153
154        let context = builder.build();
155
156        // First check if it's a workspace tool
157        let workspace_tools = self.workspace.available_tools().await;
158        if workspace_tools.iter().any(|t| &t.name == tool_name) {
159            debug!(
160                target: "app.tool_executor.execute_tool_with_cancellation",
161                "Executing workspace tool {} ({}) with cancellation",
162                tool_name,
163                tool_id
164            );
165
166            return self
167                .execute_workspace_tool(&self.workspace, tool_call, &context)
168                .await;
169        }
170
171        // Otherwise check external backends
172        let backend = self
173            .backend_registry
174            .get_backend_for_tool(tool_name)
175            .cloned()
176            .ok_or_else(|| {
177                error!(
178                    target: "app.tool_executor.execute_tool_with_cancellation",
179                    "No backend configured for tool: {} ({})",
180                    tool_name,
181                    tool_id
182                );
183                steer_tools::ToolError::UnknownTool(tool_name.clone())
184            })?;
185
186        debug!(
187            target: "app.tool_executor.execute_tool_with_cancellation",
188            "Executing external tool {} ({}) via backend with cancellation",
189            tool_name,
190            tool_id
191        );
192
193        backend.execute(tool_call, &context).await
194    }
195
196    /// Execute a tool directly without validation - for user-initiated bash commands
197    #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
198    pub async fn execute_tool_direct(
199        &self,
200        tool_call: &ToolCall,
201        token: CancellationToken,
202    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
203        let tool_name = &tool_call.name;
204        let tool_id = &tool_call.id;
205
206        Span::current().record("tool.name", tool_name);
207        Span::current().record("tool.id", tool_id);
208
209        // Create execution context
210        let mut builder = ExecutionContext::builder(
211            "direct".to_string(), // Mark as direct execution
212            "direct".to_string(),
213            tool_call.id.clone(),
214            token,
215        );
216
217        // Add LLM config provider if available
218        if let Some(provider) = &self.llm_config_provider {
219            builder = builder.llm_config_provider(provider.clone());
220        }
221
222        let context = builder.build();
223
224        // First check if it's a workspace tool (no validation for direct execution)
225        let workspace_tools = self.workspace.available_tools().await;
226        if workspace_tools.iter().any(|t| &t.name == tool_name) {
227            debug!(
228                target: "app.tool_executor.execute_tool_direct",
229                "Executing workspace tool {} ({}) directly (no validation)",
230                tool_name,
231                tool_id
232            );
233
234            return self
235                .execute_workspace_tool(&self.workspace, tool_call, &context)
236                .await;
237        }
238
239        // Otherwise check external backends
240        let backend = self
241            .backend_registry
242            .get_backend_for_tool(tool_name)
243            .cloned()
244            .ok_or_else(|| {
245                error!(
246                    target: "app.tool_executor.execute_tool_direct",
247                    "No backend configured for tool: {} ({})",
248                    tool_name,
249                    tool_id
250                );
251                steer_tools::ToolError::UnknownTool(tool_name.clone())
252            })?;
253
254        debug!(
255            target: "app.tool_executor.execute_tool_direct",
256            "Executing external tool {} ({}) directly via backend (no validation)",
257            tool_name,
258            tool_id
259        );
260
261        backend.execute(tool_call, &context).await
262    }
263
264    /// Helper method to execute a workspace tool
265    async fn execute_workspace_tool(
266        &self,
267        workspace: &Arc<dyn Workspace>,
268        tool_call: &ToolCall,
269        context: &ExecutionContext,
270    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
271        // Convert ExecutionContext to steer-tools ExecutionContext
272        let tools_context = steer_tools::ExecutionContext::new(context.tool_call_id.clone())
273            .with_cancellation_token(context.cancellation_token.clone());
274
275        workspace
276            .execute_tool(tool_call, tools_context)
277            .await
278            .map_err(|e| {
279                // Map WorkspaceError variants to structured ToolError
280                use steer_workspace::WorkspaceError;
281                match e {
282                    WorkspaceError::ToolExecution(msg) => steer_tools::ToolError::Execution {
283                        tool_name: tool_call.name.clone(),
284                        message: msg,
285                    },
286                    WorkspaceError::Io(msg) => steer_tools::ToolError::Io {
287                        tool_name: tool_call.name.clone(),
288                        message: msg,
289                    },
290                    _ => steer_tools::ToolError::InternalError(e.to_string()),
291                }
292            })
293    }
294}