Skip to main content

steer_core/tools/
executor.rs

1use crate::app::domain::types::{SessionId, ToolCallId};
2use crate::config::LlmConfigProvider;
3use crate::config::model::ModelId;
4use crate::tools::error::Result;
5use std::sync::Arc;
6use tokio_util::sync::CancellationToken;
7use tracing::{Span, debug, error, instrument};
8
9use crate::app::validation::{ValidationContext, ValidatorRegistry};
10use crate::tools::registry::ToolRegistry;
11use crate::tools::resolver::BackendResolver;
12use crate::tools::services::ToolServices;
13use crate::tools::static_tool::{StaticToolContext, StaticToolError};
14use crate::tools::{BackendRegistry, ExecutionContext};
15use steer_tools::{ToolCall, ToolSchema, result::ToolResult};
16
17#[derive(Clone)]
18pub struct ToolExecutor {
19    pub(crate) backend_registry: Arc<BackendRegistry>,
20    pub(crate) validators: Arc<ValidatorRegistry>,
21    pub(crate) llm_config_provider: Option<LlmConfigProvider>,
22    pub(crate) tool_registry: Option<Arc<ToolRegistry>>,
23    pub(crate) tool_services: Option<Arc<ToolServices>>,
24}
25
26impl ToolExecutor {
27    pub fn with_components(
28        backend_registry: Arc<BackendRegistry>,
29        validators: Arc<ValidatorRegistry>,
30    ) -> Self {
31        Self {
32            backend_registry,
33            validators,
34            llm_config_provider: None,
35            tool_registry: None,
36            tool_services: None,
37        }
38    }
39
40    pub fn with_all_components(
41        backend_registry: Arc<BackendRegistry>,
42        validators: Arc<ValidatorRegistry>,
43        llm_config_provider: LlmConfigProvider,
44    ) -> Self {
45        Self {
46            backend_registry,
47            validators,
48            llm_config_provider: Some(llm_config_provider),
49            tool_registry: None,
50            tool_services: None,
51        }
52    }
53
54    pub fn with_static_tools(
55        mut self,
56        registry: Arc<ToolRegistry>,
57        services: Arc<ToolServices>,
58    ) -> Self {
59        self.tool_registry = Some(registry);
60        self.tool_services = Some(services);
61        self
62    }
63
64    pub async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
65        if let Some(registry) = &self.tool_registry
66            && registry.is_static_tool(tool_name)
67        {
68            return Ok(registry.requires_approval(tool_name));
69        }
70
71        match self.backend_registry.get_backend_for_tool(tool_name) {
72            Some(backend) => Ok(backend.requires_approval(tool_name).await?),
73            None => Err(steer_tools::ToolError::UnknownTool(tool_name.to_string()).into()),
74        }
75    }
76
77    pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
78        self.get_tool_schemas_with_capabilities(super::Capabilities::all())
79            .await
80    }
81
82    pub async fn get_tool_schemas_with_resolver(
83        &self,
84        session_resolver: Option<&dyn BackendResolver>,
85    ) -> Vec<ToolSchema> {
86        self.get_tool_schemas_with_capabilities_and_resolver(
87            super::Capabilities::all(),
88            session_resolver,
89        )
90        .await
91    }
92
93    pub async fn get_tool_schemas_with_capabilities(
94        &self,
95        capabilities: super::Capabilities,
96    ) -> Vec<ToolSchema> {
97        self.get_tool_schemas_with_capabilities_and_resolver(capabilities, None)
98            .await
99    }
100
101    pub async fn get_tool_schemas_with_capabilities_and_resolver(
102        &self,
103        capabilities: super::Capabilities,
104        session_resolver: Option<&dyn BackendResolver>,
105    ) -> Vec<ToolSchema> {
106        let mut schemas = Vec::new();
107        let mut static_tool_names = std::collections::HashSet::new();
108
109        if let Some(registry) = &self.tool_registry {
110            let static_schemas = registry.available_schemas(capabilities).await;
111            for schema in &static_schemas {
112                static_tool_names.insert(schema.name.clone());
113            }
114            schemas.extend(static_schemas);
115        }
116
117        if let Some(resolver) = session_resolver {
118            for schema in resolver.get_tool_schemas().await {
119                if !static_tool_names.contains(&schema.name) {
120                    schemas.push(schema);
121                }
122            }
123        }
124
125        for schema in self.backend_registry.get_tool_schemas().await {
126            if !static_tool_names.contains(&schema.name) {
127                schemas.push(schema);
128            }
129        }
130
131        schemas
132    }
133
134    pub fn is_static_tool(&self, tool_name: &str) -> bool {
135        self.tool_registry
136            .as_ref()
137            .is_some_and(|r| r.is_static_tool(tool_name))
138    }
139
140    /// Get the list of supported tools
141    pub async fn supported_tools(&self) -> Vec<String> {
142        let schemas = self.get_tool_schemas().await;
143        schemas.into_iter().map(|s| s.name).collect()
144    }
145
146    /// Get the backend registry
147    pub fn backend_registry(&self) -> &Arc<BackendRegistry> {
148        &self.backend_registry
149    }
150
151    pub fn workspace(&self) -> Option<Arc<dyn crate::workspace::Workspace>> {
152        self.tool_services
153            .as_ref()
154            .map(|services| services.workspace.clone())
155    }
156
157    #[instrument(skip(self, tool_call, session_id, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
158    pub async fn execute_tool_with_session(
159        &self,
160        tool_call: &ToolCall,
161        session_id: SessionId,
162        token: CancellationToken,
163    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
164        self.execute_tool_with_session_resolver(tool_call, session_id, None, token, None)
165            .await
166    }
167
168    #[instrument(skip(self, tool_call, session_id, invoking_model, token, session_resolver), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
169    pub async fn execute_tool_with_session_resolver(
170        &self,
171        tool_call: &ToolCall,
172        session_id: SessionId,
173        invoking_model: Option<ModelId>,
174        token: CancellationToken,
175        session_resolver: Option<&dyn BackendResolver>,
176    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
177        let tool_name = &tool_call.name;
178
179        if let Some((registry, services)) =
180            self.tool_registry.as_ref().zip(self.tool_services.as_ref())
181            && let Some(tool) = registry.static_tool(tool_name)
182        {
183            debug!(target: "tool_executor", "Executing static tool: {}", tool_name);
184            return self
185                .execute_static_tool(tool, tool_call, session_id, invoking_model, services, token)
186                .await;
187        }
188
189        self.execute_tool_with_resolver(tool_call, token, session_resolver)
190            .await
191    }
192
193    #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
194    pub async fn execute_tool_with_cancellation(
195        &self,
196        tool_call: &ToolCall,
197        token: CancellationToken,
198    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
199        self.execute_tool_with_resolver(tool_call, token, None)
200            .await
201    }
202
203    #[instrument(skip(self, tool_call, token, session_resolver), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
204    pub async fn execute_tool_with_resolver(
205        &self,
206        tool_call: &ToolCall,
207        token: CancellationToken,
208        session_resolver: Option<&dyn BackendResolver>,
209    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
210        let tool_name = &tool_call.name;
211        let tool_id = &tool_call.id;
212
213        Span::current().record("tool.name", tool_name);
214        Span::current().record("tool.id", tool_id);
215
216        if let Some(validator) = self.validators.get_validator(tool_name)
217            && let Some(ref llm_config_provider) = self.llm_config_provider
218        {
219            let validation_context = ValidationContext {
220                cancellation_token: token.clone(),
221                llm_config_provider: llm_config_provider.clone(),
222            };
223
224            let validation_result = validator
225                .validate(tool_call, &validation_context)
226                .await
227                .map_err(|e| {
228                    steer_tools::ToolError::InternalError(format!("Validation failed: {e}"))
229                })?;
230
231            if !validation_result.allowed {
232                return Err(steer_tools::ToolError::InternalError(
233                    validation_result
234                        .reason
235                        .unwrap_or_else(|| "Tool execution was denied".to_string()),
236                ));
237            }
238        }
239
240        let mut builder = ExecutionContext::builder(
241            "default".to_string(),
242            "default".to_string(),
243            tool_call.id.clone(),
244            token,
245        );
246
247        if let Some(provider) = &self.llm_config_provider {
248            builder = builder.llm_config_provider(provider.clone());
249        }
250
251        let context = builder.build();
252
253        if let Some(resolver) = session_resolver
254            && let Some(backend) = resolver.resolve(tool_name).await
255        {
256            debug!(target: "tool_executor", "Executing session MCP tool: {} ({})", tool_name, tool_id);
257            return backend.execute(tool_call, &context).await;
258        }
259
260        let backend = self
261            .backend_registry
262            .get_backend_for_tool(tool_name)
263            .cloned()
264            .ok_or_else(|| {
265                error!(target: "tool_executor", "No backend for tool: {} ({})", tool_name, tool_id);
266                steer_tools::ToolError::UnknownTool(tool_name.clone())
267            })?;
268
269        debug!(target: "tool_executor", "Executing external tool: {} ({})", tool_name, tool_id);
270        backend.execute(tool_call, &context).await
271    }
272
273    async fn execute_static_tool(
274        &self,
275        tool: &dyn super::static_tool::StaticToolErased,
276        tool_call: &ToolCall,
277        session_id: SessionId,
278        invoking_model: Option<ModelId>,
279        services: &Arc<ToolServices>,
280        token: CancellationToken,
281    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
282        let ctx = StaticToolContext {
283            tool_call_id: ToolCallId(tool_call.id.clone()),
284            session_id,
285            invoking_model,
286            cancellation_token: token,
287            services: services.clone(),
288        };
289
290        let output = tool
291            .execute_erased(tool_call.parameters.clone(), &ctx)
292            .await
293            .map_err(|e| match e {
294                StaticToolError::InvalidParams(msg) => steer_tools::ToolError::InvalidParams {
295                    tool_name: tool_call.name.clone(),
296                    message: msg,
297                },
298                StaticToolError::Execution(err) => steer_tools::ToolError::Execution(err),
299                StaticToolError::MissingCapability(cap) => {
300                    steer_tools::ToolError::InternalError(format!("Missing capability: {cap}"))
301                }
302                StaticToolError::Cancelled => {
303                    steer_tools::ToolError::Cancelled(tool_call.name.clone())
304                }
305                StaticToolError::Timeout => steer_tools::ToolError::Timeout(tool_call.name.clone()),
306            })?;
307
308        Ok(output)
309    }
310
311    /// Execute a tool directly without validation - for user-initiated bash commands
312    #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
313    pub async fn execute_tool_direct(
314        &self,
315        tool_call: &ToolCall,
316        token: CancellationToken,
317    ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
318        let tool_name = &tool_call.name;
319        let tool_id = &tool_call.id;
320
321        Span::current().record("tool.name", tool_name);
322        Span::current().record("tool.id", tool_id);
323
324        if let Some((registry, services)) =
325            self.tool_registry.as_ref().zip(self.tool_services.as_ref())
326            && let Some(tool) = registry.static_tool(tool_name)
327        {
328            debug!(
329                target: "app.tool_executor.execute_tool_direct",
330                "Executing static tool {} ({}) directly (no validation)",
331                tool_name,
332                tool_id
333            );
334            return self
335                .execute_static_tool(tool, tool_call, SessionId::new(), None, services, token)
336                .await;
337        }
338
339        // Create execution context
340        let mut builder = ExecutionContext::builder(
341            "direct".to_string(), // Mark as direct execution
342            "direct".to_string(),
343            tool_call.id.clone(),
344            token,
345        );
346
347        // Add LLM config provider if available
348        if let Some(provider) = &self.llm_config_provider {
349            builder = builder.llm_config_provider(provider.clone());
350        }
351
352        let context = builder.build();
353
354        // Otherwise check external backends
355        let backend = self
356            .backend_registry
357            .get_backend_for_tool(tool_name)
358            .cloned()
359            .ok_or_else(|| {
360                error!(
361                    target: "app.tool_executor.execute_tool_direct",
362                    "No backend configured for tool: {} ({})",
363                    tool_name,
364                    tool_id
365                );
366                steer_tools::ToolError::UnknownTool(tool_name.clone())
367            })?;
368
369        debug!(
370            target: "app.tool_executor.execute_tool_direct",
371            "Executing external tool {} ({}) directly via backend (no validation)",
372            tool_name,
373            tool_id
374        );
375
376        backend.execute(tool_call, &context).await
377    }
378}