steer_core/tools/
executor.rs

1use crate::config::LlmConfigProvider;
2use crate::tools::error::Result;
3use std::sync::Arc;
4use tokio_util::sync::CancellationToken;
5use tracing::{Span, debug, error, instrument};
6
7use crate::api::ToolCall;
8use crate::app::validation::{ValidationContext, ValidatorRegistry};
9use crate::tools::{BackendRegistry, ExecutionContext};
10use crate::workspace::Workspace;
11use steer_tools::ToolSchema;
12use steer_tools::result::ToolResult;
13
14/// Manages the execution of tools called by the AI model
15#[derive(Clone)]
16pub struct ToolExecutor {
17    /// Optional workspace for executing workspace tools
18    pub(crate) workspace: Arc<dyn Workspace>,
19    /// Registry for external tool backends (MCP servers, etc.)
20    pub(crate) backend_registry: Arc<BackendRegistry>,
21    /// Validators for tool execution
22    pub(crate) validators: Arc<ValidatorRegistry>,
23    /// Provider for LLM configuration
24    pub(crate) llm_config_provider: Option<LlmConfigProvider>,
25}
26
27impl ToolExecutor {
28    /// Create a ToolExecutor with a workspace for workspace tools
29    pub fn with_workspace(workspace: Arc<dyn Workspace>) -> Self {
30        Self {
31            workspace,
32            backend_registry: Arc::new(BackendRegistry::new()),
33            validators: Arc::new(ValidatorRegistry::new()),
34            llm_config_provider: None,
35        }
36    }
37
38    /// Create a ToolExecutor with custom components
39    pub fn with_components(
40        workspace: Arc<dyn Workspace>,
41        backend_registry: Arc<BackendRegistry>,
42        validators: Arc<ValidatorRegistry>,
43    ) -> Self {
44        Self {
45            workspace,
46            backend_registry,
47            validators,
48            llm_config_provider: None,
49        }
50    }
51
52    /// Create a ToolExecutor with all components including LLM config provider
53    pub fn with_all_components(
54        workspace: Arc<dyn Workspace>,
55        backend_registry: Arc<BackendRegistry>,
56        validators: Arc<ValidatorRegistry>,
57        llm_config_provider: LlmConfigProvider,
58    ) -> Self {
59        Self {
60            workspace,
61            backend_registry,
62            validators,
63            llm_config_provider: Some(llm_config_provider),
64        }
65    }
66
67    pub async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
68        // First check if it's a workspace tool
69        let workspace_tools = self.workspace.available_tools().await;
70        if workspace_tools.iter().any(|t| t.name == tool_name) {
71            return Ok(self.workspace.requires_approval(tool_name).await?);
72        }
73
74        // Otherwise check external backends
75        match self.backend_registry.get_backend_for_tool(tool_name) {
76            Some(backend) => Ok(backend.requires_approval(tool_name).await?),
77            None => Err(steer_tools::ToolError::UnknownTool(tool_name.to_string()).into()),
78        }
79    }
80
81    pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
82        let mut schemas = Vec::new();
83
84        // Add workspace tools if available
85        schemas.extend(self.workspace.available_tools().await);
86
87        // Add external backend tools
88        schemas.extend(self.backend_registry.get_tool_schemas().await);
89
90        schemas
91    }
92
93    /// Get the list of supported tools
94    pub async fn supported_tools(&self) -> Vec<String> {
95        let schemas = self.get_tool_schemas().await;
96        schemas.into_iter().map(|s| s.name).collect()
97    }
98
99    /// Get the backend registry
100    pub fn backend_registry(&self) -> &Arc<BackendRegistry> {
101        &self.backend_registry
102    }
103
104    #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
105    pub async fn execute_tool_with_cancellation(
106        &self,
107        tool_call: &ToolCall,
108        token: CancellationToken,
109    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
110        let tool_name = &tool_call.name;
111        let tool_id = &tool_call.id;
112
113        Span::current().record("tool.name", tool_name);
114        Span::current().record("tool.id", tool_id);
115
116        // Pre-execution validation
117        if let Some(validator) = self.validators.get_validator(tool_name) {
118            // Only validate if we have an LLM config provider
119            if let Some(ref llm_config_provider) = self.llm_config_provider {
120                let validation_context = ValidationContext {
121                    cancellation_token: token.clone(),
122                    llm_config_provider: llm_config_provider.clone(),
123                };
124
125                let validation_result = validator
126                    .validate(tool_call, &validation_context)
127                    .await
128                    .map_err(|e| {
129                        steer_tools::ToolError::InternalError(format!("Validation failed: {e}"))
130                    })?;
131
132                if !validation_result.allowed {
133                    return Err(steer_tools::ToolError::InternalError(
134                        validation_result
135                            .reason
136                            .unwrap_or_else(|| "Tool execution was denied".to_string()),
137                    ));
138                }
139            }
140            // If no LLM config provider, skip validation (allow execution)
141        }
142
143        // Create execution context
144        let mut builder = ExecutionContext::builder(
145            "default".to_string(), // TODO: Get real session ID
146            "default".to_string(), // TODO: Get real operation ID
147            tool_call.id.clone(),
148            token,
149        );
150
151        // Add LLM config provider if available
152        if let Some(provider) = &self.llm_config_provider {
153            builder = builder.llm_config_provider(provider.clone());
154        }
155
156        let context = builder.build();
157
158        // First check if it's a workspace tool
159        let workspace_tools = self.workspace.available_tools().await;
160        if workspace_tools.iter().any(|t| &t.name == tool_name) {
161            debug!(
162                target: "app.tool_executor.execute_tool_with_cancellation",
163                "Executing workspace tool {} ({}) with cancellation",
164                tool_name,
165                tool_id
166            );
167
168            return self
169                .execute_workspace_tool(&self.workspace, tool_call, &context)
170                .await;
171        }
172
173        // Otherwise check external backends
174        let backend = self
175            .backend_registry
176            .get_backend_for_tool(tool_name)
177            .cloned()
178            .ok_or_else(|| {
179                error!(
180                    target: "app.tool_executor.execute_tool_with_cancellation",
181                    "No backend configured for tool: {} ({})",
182                    tool_name,
183                    tool_id
184                );
185                steer_tools::ToolError::UnknownTool(tool_name.clone())
186            })?;
187
188        debug!(
189            target: "app.tool_executor.execute_tool_with_cancellation",
190            "Executing external tool {} ({}) via backend with cancellation",
191            tool_name,
192            tool_id
193        );
194
195        backend.execute(tool_call, &context).await
196    }
197
198    /// Execute a tool directly without validation - for user-initiated bash commands
199    #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
200    pub async fn execute_tool_direct(
201        &self,
202        tool_call: &ToolCall,
203        token: CancellationToken,
204    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
205        let tool_name = &tool_call.name;
206        let tool_id = &tool_call.id;
207
208        Span::current().record("tool.name", tool_name);
209        Span::current().record("tool.id", tool_id);
210
211        // Create execution context
212        let mut builder = ExecutionContext::builder(
213            "direct".to_string(), // Mark as direct execution
214            "direct".to_string(),
215            tool_call.id.clone(),
216            token,
217        );
218
219        // Add LLM config provider if available
220        if let Some(provider) = &self.llm_config_provider {
221            builder = builder.llm_config_provider(provider.clone());
222        }
223
224        let context = builder.build();
225
226        // First check if it's a workspace tool (no validation for direct execution)
227        let workspace_tools = self.workspace.available_tools().await;
228        if workspace_tools.iter().any(|t| &t.name == tool_name) {
229            debug!(
230                target: "app.tool_executor.execute_tool_direct",
231                "Executing workspace tool {} ({}) directly (no validation)",
232                tool_name,
233                tool_id
234            );
235
236            return self
237                .execute_workspace_tool(&self.workspace, tool_call, &context)
238                .await;
239        }
240
241        // Otherwise check external backends
242        let backend = self
243            .backend_registry
244            .get_backend_for_tool(tool_name)
245            .cloned()
246            .ok_or_else(|| {
247                error!(
248                    target: "app.tool_executor.execute_tool_direct",
249                    "No backend configured for tool: {} ({})",
250                    tool_name,
251                    tool_id
252                );
253                steer_tools::ToolError::UnknownTool(tool_name.clone())
254            })?;
255
256        debug!(
257            target: "app.tool_executor.execute_tool_direct",
258            "Executing external tool {} ({}) directly via backend (no validation)",
259            tool_name,
260            tool_id
261        );
262
263        backend.execute(tool_call, &context).await
264    }
265
266    /// Helper method to execute a workspace tool
267    async fn execute_workspace_tool(
268        &self,
269        workspace: &Arc<dyn Workspace>,
270        tool_call: &ToolCall,
271        context: &ExecutionContext,
272    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
273        // Convert ExecutionContext to steer-tools ExecutionContext
274        let tools_context = steer_tools::ExecutionContext::new(context.tool_call_id.clone())
275            .with_cancellation_token(context.cancellation_token.clone());
276
277        workspace
278            .execute_tool(tool_call, tools_context)
279            .await
280            .map_err(|e| {
281                steer_tools::ToolError::InternalError(format!("Workspace execution failed: {e}"))
282            })
283    }
284}