// steer_core/tools/executor.rs

1use crate::app::domain::types::{SessionId, ToolCallId};
2use crate::config::LlmConfigProvider;
3use crate::tools::error::Result;
4use std::sync::Arc;
5use tokio_util::sync::CancellationToken;
6use tracing::{Span, debug, error, instrument};
7
8use crate::app::validation::{ValidationContext, ValidatorRegistry};
9use crate::tools::registry::ToolRegistry;
10use crate::tools::resolver::BackendResolver;
11use crate::tools::services::ToolServices;
12use crate::tools::static_tool::{StaticToolContext, StaticToolError};
13use crate::tools::{BackendRegistry, ExecutionContext};
14use steer_tools::{ToolCall, ToolSchema, result::ToolResult};
15
16#[derive(Clone)]
17pub struct ToolExecutor {
18    pub(crate) backend_registry: Arc<BackendRegistry>,
19    pub(crate) validators: Arc<ValidatorRegistry>,
20    pub(crate) llm_config_provider: Option<LlmConfigProvider>,
21    pub(crate) tool_registry: Option<Arc<ToolRegistry>>,
22    pub(crate) tool_services: Option<Arc<ToolServices>>,
23}
24
25impl ToolExecutor {
26    pub fn with_components(
27        backend_registry: Arc<BackendRegistry>,
28        validators: Arc<ValidatorRegistry>,
29    ) -> Self {
30        Self {
31            backend_registry,
32            validators,
33            llm_config_provider: None,
34            tool_registry: None,
35            tool_services: None,
36        }
37    }
38
39    pub fn with_all_components(
40        backend_registry: Arc<BackendRegistry>,
41        validators: Arc<ValidatorRegistry>,
42        llm_config_provider: LlmConfigProvider,
43    ) -> Self {
44        Self {
45            backend_registry,
46            validators,
47            llm_config_provider: Some(llm_config_provider),
48            tool_registry: None,
49            tool_services: None,
50        }
51    }
52
53    pub fn with_static_tools(
54        mut self,
55        registry: Arc<ToolRegistry>,
56        services: Arc<ToolServices>,
57    ) -> Self {
58        self.tool_registry = Some(registry);
59        self.tool_services = Some(services);
60        self
61    }
62
63    pub async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
64        if let Some(registry) = &self.tool_registry
65            && registry.is_static_tool(tool_name)
66        {
67            return Ok(registry.requires_approval(tool_name));
68        }
69
70        match self.backend_registry.get_backend_for_tool(tool_name) {
71            Some(backend) => Ok(backend.requires_approval(tool_name).await?),
72            None => Err(steer_tools::ToolError::UnknownTool(tool_name.to_string()).into()),
73        }
74    }
75
76    pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
77        self.get_tool_schemas_with_capabilities(super::Capabilities::all())
78            .await
79    }
80
81    pub async fn get_tool_schemas_with_resolver(
82        &self,
83        session_resolver: Option<&dyn BackendResolver>,
84    ) -> Vec<ToolSchema> {
85        self.get_tool_schemas_with_capabilities_and_resolver(
86            super::Capabilities::all(),
87            session_resolver,
88        )
89        .await
90    }
91
92    pub async fn get_tool_schemas_with_capabilities(
93        &self,
94        capabilities: super::Capabilities,
95    ) -> Vec<ToolSchema> {
96        self.get_tool_schemas_with_capabilities_and_resolver(capabilities, None)
97            .await
98    }
99
100    pub async fn get_tool_schemas_with_capabilities_and_resolver(
101        &self,
102        capabilities: super::Capabilities,
103        session_resolver: Option<&dyn BackendResolver>,
104    ) -> Vec<ToolSchema> {
105        let mut schemas = Vec::new();
106        let mut static_tool_names = std::collections::HashSet::new();
107
108        if let Some(registry) = &self.tool_registry {
109            let static_schemas = registry.available_schemas(capabilities).await;
110            for schema in &static_schemas {
111                static_tool_names.insert(schema.name.clone());
112            }
113            schemas.extend(static_schemas);
114        }
115
116        if let Some(resolver) = session_resolver {
117            for schema in resolver.get_tool_schemas().await {
118                if !static_tool_names.contains(&schema.name) {
119                    schemas.push(schema);
120                }
121            }
122        }
123
124        for schema in self.backend_registry.get_tool_schemas().await {
125            if !static_tool_names.contains(&schema.name) {
126                schemas.push(schema);
127            }
128        }
129
130        schemas
131    }
132
133    pub fn is_static_tool(&self, tool_name: &str) -> bool {
134        self.tool_registry
135            .as_ref()
136            .is_some_and(|r| r.is_static_tool(tool_name))
137    }
138
139    /// Get the list of supported tools
140    pub async fn supported_tools(&self) -> Vec<String> {
141        let schemas = self.get_tool_schemas().await;
142        schemas.into_iter().map(|s| s.name).collect()
143    }
144
145    /// Get the backend registry
146    pub fn backend_registry(&self) -> &Arc<BackendRegistry> {
147        &self.backend_registry
148    }
149
150    pub fn workspace(&self) -> Option<Arc<dyn crate::workspace::Workspace>> {
151        self.tool_services
152            .as_ref()
153            .map(|services| services.workspace.clone())
154    }
155
156    #[instrument(skip(self, tool_call, session_id, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
157    pub async fn execute_tool_with_session(
158        &self,
159        tool_call: &ToolCall,
160        session_id: SessionId,
161        token: CancellationToken,
162    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
163        self.execute_tool_with_session_resolver(tool_call, session_id, token, None)
164            .await
165    }
166
167    #[instrument(skip(self, tool_call, session_id, token, session_resolver), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
168    pub async fn execute_tool_with_session_resolver(
169        &self,
170        tool_call: &ToolCall,
171        session_id: SessionId,
172        token: CancellationToken,
173        session_resolver: Option<&dyn BackendResolver>,
174    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
175        let tool_name = &tool_call.name;
176
177        if let Some((registry, services)) =
178            self.tool_registry.as_ref().zip(self.tool_services.as_ref())
179            && let Some(tool) = registry.static_tool(tool_name)
180        {
181            debug!(target: "tool_executor", "Executing static tool: {}", tool_name);
182            return self
183                .execute_static_tool(tool, tool_call, session_id, services, token)
184                .await;
185        }
186
187        self.execute_tool_with_resolver(tool_call, token, session_resolver)
188            .await
189    }
190
191    #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
192    pub async fn execute_tool_with_cancellation(
193        &self,
194        tool_call: &ToolCall,
195        token: CancellationToken,
196    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
197        self.execute_tool_with_resolver(tool_call, token, None)
198            .await
199    }
200
201    #[instrument(skip(self, tool_call, token, session_resolver), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
202    pub async fn execute_tool_with_resolver(
203        &self,
204        tool_call: &ToolCall,
205        token: CancellationToken,
206        session_resolver: Option<&dyn BackendResolver>,
207    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
208        let tool_name = &tool_call.name;
209        let tool_id = &tool_call.id;
210
211        Span::current().record("tool.name", tool_name);
212        Span::current().record("tool.id", tool_id);
213
214        if let Some(validator) = self.validators.get_validator(tool_name)
215            && let Some(ref llm_config_provider) = self.llm_config_provider
216        {
217            let validation_context = ValidationContext {
218                cancellation_token: token.clone(),
219                llm_config_provider: llm_config_provider.clone(),
220            };
221
222            let validation_result = validator
223                .validate(tool_call, &validation_context)
224                .await
225                .map_err(|e| {
226                    steer_tools::ToolError::InternalError(format!("Validation failed: {e}"))
227                })?;
228
229            if !validation_result.allowed {
230                return Err(steer_tools::ToolError::InternalError(
231                    validation_result
232                        .reason
233                        .unwrap_or_else(|| "Tool execution was denied".to_string()),
234                ));
235            }
236        }
237
238        let mut builder = ExecutionContext::builder(
239            "default".to_string(),
240            "default".to_string(),
241            tool_call.id.clone(),
242            token,
243        );
244
245        if let Some(provider) = &self.llm_config_provider {
246            builder = builder.llm_config_provider(provider.clone());
247        }
248
249        let context = builder.build();
250
251        if let Some(resolver) = session_resolver
252            && let Some(backend) = resolver.resolve(tool_name).await
253        {
254            debug!(target: "tool_executor", "Executing session MCP tool: {} ({})", tool_name, tool_id);
255            return backend.execute(tool_call, &context).await;
256        }
257
258        let backend = self
259            .backend_registry
260            .get_backend_for_tool(tool_name)
261            .cloned()
262            .ok_or_else(|| {
263                error!(target: "tool_executor", "No backend for tool: {} ({})", tool_name, tool_id);
264                steer_tools::ToolError::UnknownTool(tool_name.clone())
265            })?;
266
267        debug!(target: "tool_executor", "Executing external tool: {} ({})", tool_name, tool_id);
268        backend.execute(tool_call, &context).await
269    }
270
271    async fn execute_static_tool(
272        &self,
273        tool: &dyn super::static_tool::StaticToolErased,
274        tool_call: &ToolCall,
275        session_id: SessionId,
276        services: &Arc<ToolServices>,
277        token: CancellationToken,
278    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
279        let ctx = StaticToolContext {
280            tool_call_id: ToolCallId(tool_call.id.clone()),
281            session_id,
282            cancellation_token: token,
283            services: services.clone(),
284        };
285
286        let output = tool
287            .execute_erased(tool_call.parameters.clone(), &ctx)
288            .await
289            .map_err(|e| match e {
290                StaticToolError::InvalidParams(msg) => steer_tools::ToolError::InvalidParams {
291                    tool_name: tool_call.name.clone(),
292                    message: msg,
293                },
294                StaticToolError::Execution(err) => steer_tools::ToolError::Execution(err),
295                StaticToolError::MissingCapability(cap) => {
296                    steer_tools::ToolError::InternalError(format!("Missing capability: {cap}"))
297                }
298                StaticToolError::Cancelled => {
299                    steer_tools::ToolError::Cancelled(tool_call.name.clone())
300                }
301                StaticToolError::Timeout => steer_tools::ToolError::Timeout(tool_call.name.clone()),
302            })?;
303
304        Ok(output)
305    }
306
307    /// Execute a tool directly without validation - for user-initiated bash commands
308    #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
309    pub async fn execute_tool_direct(
310        &self,
311        tool_call: &ToolCall,
312        token: CancellationToken,
313    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
314        let tool_name = &tool_call.name;
315        let tool_id = &tool_call.id;
316
317        Span::current().record("tool.name", tool_name);
318        Span::current().record("tool.id", tool_id);
319
320        if let Some((registry, services)) =
321            self.tool_registry.as_ref().zip(self.tool_services.as_ref())
322            && let Some(tool) = registry.static_tool(tool_name)
323        {
324            debug!(
325                target: "app.tool_executor.execute_tool_direct",
326                "Executing static tool {} ({}) directly (no validation)",
327                tool_name,
328                tool_id
329            );
330            return self
331                .execute_static_tool(tool, tool_call, SessionId::new(), services, token)
332                .await;
333        }
334
335        // Create execution context
336        let mut builder = ExecutionContext::builder(
337            "direct".to_string(), // Mark as direct execution
338            "direct".to_string(),
339            tool_call.id.clone(),
340            token,
341        );
342
343        // Add LLM config provider if available
344        if let Some(provider) = &self.llm_config_provider {
345            builder = builder.llm_config_provider(provider.clone());
346        }
347
348        let context = builder.build();
349
350        // Otherwise check external backends
351        let backend = self
352            .backend_registry
353            .get_backend_for_tool(tool_name)
354            .cloned()
355            .ok_or_else(|| {
356                error!(
357                    target: "app.tool_executor.execute_tool_direct",
358                    "No backend configured for tool: {} ({})",
359                    tool_name,
360                    tool_id
361                );
362                steer_tools::ToolError::UnknownTool(tool_name.clone())
363            })?;
364
365        debug!(
366            target: "app.tool_executor.execute_tool_direct",
367            "Executing external tool {} ({}) directly via backend (no validation)",
368            tool_name,
369            tool_id
370        );
371
372        backend.execute(tool_call, &context).await
373    }
374}