1use crate::config::LlmConfigProvider;
2use crate::tools::error::Result;
3use std::sync::Arc;
4use tokio_util::sync::CancellationToken;
5use tracing::{Span, debug, error, instrument};
6
7use crate::app::validation::{ValidationContext, ValidatorRegistry};
8use crate::tools::{BackendRegistry, ExecutionContext};
9use crate::workspace::Workspace;
10use steer_tools::{ToolCall, ToolSchema, result::ToolResult};
11
12#[derive(Clone)]
14pub struct ToolExecutor {
15 pub(crate) workspace: Arc<dyn Workspace>,
17 pub(crate) backend_registry: Arc<BackendRegistry>,
19 pub(crate) validators: Arc<ValidatorRegistry>,
21 pub(crate) llm_config_provider: Option<LlmConfigProvider>,
23}
24
25impl ToolExecutor {
26 pub fn with_workspace(workspace: Arc<dyn Workspace>) -> Self {
28 Self {
29 workspace,
30 backend_registry: Arc::new(BackendRegistry::new()),
31 validators: Arc::new(ValidatorRegistry::new()),
32 llm_config_provider: None,
33 }
34 }
35
36 pub fn with_components(
38 workspace: Arc<dyn Workspace>,
39 backend_registry: Arc<BackendRegistry>,
40 validators: Arc<ValidatorRegistry>,
41 ) -> Self {
42 Self {
43 workspace,
44 backend_registry,
45 validators,
46 llm_config_provider: None,
47 }
48 }
49
50 pub fn with_all_components(
52 workspace: Arc<dyn Workspace>,
53 backend_registry: Arc<BackendRegistry>,
54 validators: Arc<ValidatorRegistry>,
55 llm_config_provider: LlmConfigProvider,
56 ) -> Self {
57 Self {
58 workspace,
59 backend_registry,
60 validators,
61 llm_config_provider: Some(llm_config_provider),
62 }
63 }
64
65 pub async fn requires_approval(&self, tool_name: &str) -> Result<bool> {
66 let workspace_tools = self.workspace.available_tools().await;
68 if workspace_tools.iter().any(|t| t.name == tool_name) {
69 return Ok(self.workspace.requires_approval(tool_name).await?);
70 }
71
72 match self.backend_registry.get_backend_for_tool(tool_name) {
74 Some(backend) => Ok(backend.requires_approval(tool_name).await?),
75 None => Err(steer_tools::ToolError::UnknownTool(tool_name.to_string()).into()),
76 }
77 }
78
79 pub async fn get_tool_schemas(&self) -> Vec<ToolSchema> {
80 let mut schemas = Vec::new();
81
82 schemas.extend(self.workspace.available_tools().await);
84
85 schemas.extend(self.backend_registry.get_tool_schemas().await);
87
88 schemas
89 }
90
91 pub async fn supported_tools(&self) -> Vec<String> {
93 let schemas = self.get_tool_schemas().await;
94 schemas.into_iter().map(|s| s.name).collect()
95 }
96
97 pub fn backend_registry(&self) -> &Arc<BackendRegistry> {
99 &self.backend_registry
100 }
101
102 #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
103 pub async fn execute_tool_with_cancellation(
104 &self,
105 tool_call: &ToolCall,
106 token: CancellationToken,
107 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
108 let tool_name = &tool_call.name;
109 let tool_id = &tool_call.id;
110
111 Span::current().record("tool.name", tool_name);
112 Span::current().record("tool.id", tool_id);
113
114 if let Some(validator) = self.validators.get_validator(tool_name) {
116 if let Some(ref llm_config_provider) = self.llm_config_provider {
118 let validation_context = ValidationContext {
119 cancellation_token: token.clone(),
120 llm_config_provider: llm_config_provider.clone(),
121 };
122
123 let validation_result = validator
124 .validate(tool_call, &validation_context)
125 .await
126 .map_err(|e| {
127 steer_tools::ToolError::InternalError(format!("Validation failed: {e}"))
128 })?;
129
130 if !validation_result.allowed {
131 return Err(steer_tools::ToolError::InternalError(
132 validation_result
133 .reason
134 .unwrap_or_else(|| "Tool execution was denied".to_string()),
135 ));
136 }
137 }
138 }
140
141 let mut builder = ExecutionContext::builder(
143 "default".to_string(), "default".to_string(), tool_call.id.clone(),
146 token,
147 );
148
149 if let Some(provider) = &self.llm_config_provider {
151 builder = builder.llm_config_provider(provider.clone());
152 }
153
154 let context = builder.build();
155
156 let workspace_tools = self.workspace.available_tools().await;
158 if workspace_tools.iter().any(|t| &t.name == tool_name) {
159 debug!(
160 target: "app.tool_executor.execute_tool_with_cancellation",
161 "Executing workspace tool {} ({}) with cancellation",
162 tool_name,
163 tool_id
164 );
165
166 return self
167 .execute_workspace_tool(&self.workspace, tool_call, &context)
168 .await;
169 }
170
171 let backend = self
173 .backend_registry
174 .get_backend_for_tool(tool_name)
175 .cloned()
176 .ok_or_else(|| {
177 error!(
178 target: "app.tool_executor.execute_tool_with_cancellation",
179 "No backend configured for tool: {} ({})",
180 tool_name,
181 tool_id
182 );
183 steer_tools::ToolError::UnknownTool(tool_name.clone())
184 })?;
185
186 debug!(
187 target: "app.tool_executor.execute_tool_with_cancellation",
188 "Executing external tool {} ({}) via backend with cancellation",
189 tool_name,
190 tool_id
191 );
192
193 backend.execute(tool_call, &context).await
194 }
195
196 #[instrument(skip(self, tool_call, token), fields(tool.name = %tool_call.name, tool.id = %tool_call.id))]
198 pub async fn execute_tool_direct(
199 &self,
200 tool_call: &ToolCall,
201 token: CancellationToken,
202 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
203 let tool_name = &tool_call.name;
204 let tool_id = &tool_call.id;
205
206 Span::current().record("tool.name", tool_name);
207 Span::current().record("tool.id", tool_id);
208
209 let mut builder = ExecutionContext::builder(
211 "direct".to_string(), "direct".to_string(),
213 tool_call.id.clone(),
214 token,
215 );
216
217 if let Some(provider) = &self.llm_config_provider {
219 builder = builder.llm_config_provider(provider.clone());
220 }
221
222 let context = builder.build();
223
224 let workspace_tools = self.workspace.available_tools().await;
226 if workspace_tools.iter().any(|t| &t.name == tool_name) {
227 debug!(
228 target: "app.tool_executor.execute_tool_direct",
229 "Executing workspace tool {} ({}) directly (no validation)",
230 tool_name,
231 tool_id
232 );
233
234 return self
235 .execute_workspace_tool(&self.workspace, tool_call, &context)
236 .await;
237 }
238
239 let backend = self
241 .backend_registry
242 .get_backend_for_tool(tool_name)
243 .cloned()
244 .ok_or_else(|| {
245 error!(
246 target: "app.tool_executor.execute_tool_direct",
247 "No backend configured for tool: {} ({})",
248 tool_name,
249 tool_id
250 );
251 steer_tools::ToolError::UnknownTool(tool_name.clone())
252 })?;
253
254 debug!(
255 target: "app.tool_executor.execute_tool_direct",
256 "Executing external tool {} ({}) directly via backend (no validation)",
257 tool_name,
258 tool_id
259 );
260
261 backend.execute(tool_call, &context).await
262 }
263
264 async fn execute_workspace_tool(
266 &self,
267 workspace: &Arc<dyn Workspace>,
268 tool_call: &ToolCall,
269 context: &ExecutionContext,
270 ) -> std::result::Result<ToolResult, steer_tools::ToolError> {
271 let tools_context = steer_tools::ExecutionContext::new(context.tool_call_id.clone())
273 .with_cancellation_token(context.cancellation_token.clone());
274
275 workspace
276 .execute_tool(tool_call, tools_context)
277 .await
278 .map_err(|e| {
279 steer_tools::ToolError::InternalError(format!("Workspace execution failed: {e}"))
280 })
281 }
282}