1pub use crate::contracts::runtime::ToolExecution;
4use crate::contracts::thread::ToolCall;
5use crate::contracts::tool::context::ToolCallContext;
6use crate::contracts::tool::{Tool, ToolResult};
7use futures::future::join_all;
8use serde_json::Value;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11use tirea_contract::RunConfig;
12use tirea_state::{apply_patch, DocCell, TrackedPatch};
13
14pub async fn execute_single_tool(
27 tool: Option<&dyn Tool>,
28 call: &ToolCall,
29 state: &Value,
30) -> ToolExecution {
31 execute_single_tool_with_scope(tool, call, state, None).await
32}
33
34pub async fn execute_single_tool_with_scope(
36 tool: Option<&dyn Tool>,
37 call: &ToolCall,
38 state: &Value,
39 scope: Option<&RunConfig>,
40) -> ToolExecution {
41 let Some(tool) = tool else {
42 return ToolExecution {
43 call: call.clone(),
44 result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
45 patch: None,
46 };
47 };
48
49 let doc = DocCell::new(state.clone());
51 let ops = Mutex::new(Vec::new());
52 let default_scope = RunConfig::default();
53 let scope = scope.unwrap_or(&default_scope);
54 let pending_messages = Mutex::new(Vec::new());
55 let ctx = ToolCallContext::new(
56 &doc,
57 &ops,
58 &call.id,
59 format!("tool:{}", call.name),
60 scope,
61 &pending_messages,
62 None,
63 );
64
65 if let Err(e) = tool.validate_args(&call.arguments) {
67 return ToolExecution {
68 call: call.clone(),
69 result: ToolResult::error(&call.name, e.to_string()),
70 patch: None,
71 };
72 }
73
74 let result = match tool.execute(call.arguments.clone(), &ctx).await {
76 Ok(r) => r,
77 Err(e) => ToolResult::error(&call.name, e.to_string()),
78 };
79
80 let patch = ctx.take_patch();
82 let patch = if patch.patch().is_empty() {
83 None
84 } else {
85 Some(patch)
86 };
87
88 ToolExecution {
89 call: call.clone(),
90 result,
91 patch,
92 }
93}
94
95pub async fn execute_tools_parallel(
97 tools: &HashMap<String, Arc<dyn Tool>>,
98 calls: &[ToolCall],
99 state: &Value,
100) -> Vec<ToolExecution> {
101 let tasks = calls.iter().map(|call| {
102 let tool = tools.get(&call.name).cloned();
103 let state = state.clone();
104 async move { execute_single_tool(tool.as_deref(), call, &state).await }
105 });
106 join_all(tasks).await
107}
108
109pub async fn execute_tools_sequential(
111 tools: &HashMap<String, Arc<dyn Tool>>,
112 calls: &[ToolCall],
113 state: &Value,
114) -> (Value, Vec<ToolExecution>) {
115 let mut state = state.clone();
116 let mut executions = Vec::with_capacity(calls.len());
117
118 for call in calls {
119 let exec = execute_single_tool(tools.get(&call.name).map(Arc::as_ref), call, &state).await;
120 if let Some(patch) = exec.patch.as_ref() {
121 if let Ok(next) = apply_patch(&state, patch.patch()) {
122 state = next;
123 }
124 }
125 executions.push(exec);
126 }
127
128 (state, executions)
129}
130
131pub fn collect_patches(executions: &[ToolExecution]) -> Vec<TrackedPatch> {
133 executions.iter().filter_map(|e| e.patch.clone()).collect()
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::tool::{ToolDescriptor, ToolError};
    use crate::contracts::ToolCallContext;
    use async_trait::async_trait;
    use serde_json::json;

    /// Returns its arguments unchanged as the success payload; never writes state.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("echo", "Echo", "Echo the input")
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("echo", args))
        }
    }

    /// Registry containing only `EchoTool`, registered under the name "echo".
    fn echo_registry() -> HashMap<String, Arc<dyn Tool>> {
        let mut tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tools.insert("echo".to_string(), Arc::new(EchoTool));
        tools
    }

    #[tokio::test]
    async fn test_execute_single_tool_not_found() {
        let call = ToolCall::new("call_1", "nonexistent", json!({}));
        let state = json!({});

        let exec = execute_single_tool(None, &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_execute_single_tool_success() {
        let tool = EchoTool;
        let call = ToolCall::new("call_1", "echo", json!({"msg": "hello"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["msg"], "hello");
    }

    // Purely synchronous: no async runtime needed.
    #[test]
    fn test_collect_patches() {
        use tirea_state::{path, Op, Patch};

        let executions = vec![
            ToolExecution {
                call: ToolCall::new("1", "a", json!({})),
                result: ToolResult::success("a", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("a"), json!(1))),
                )),
            },
            ToolExecution {
                call: ToolCall::new("2", "b", json!({})),
                result: ToolResult::success("b", json!({})),
                patch: None,
            },
            ToolExecution {
                call: ToolCall::new("3", "c", json!({})),
                result: ToolResult::success("c", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("c"), json!(3))),
                )),
            },
        ];

        let patches = collect_patches(&executions);
        assert_eq!(patches.len(), 2);
    }

    #[tokio::test]
    async fn test_tool_execution_error() {
        struct FailingTool;

        #[async_trait]
        impl Tool for FailingTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("failing", "Failing", "Always fails")
            }

            async fn execute(
                &self,
                _args: Value,
                _ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                Err(ToolError::ExecutionFailed(
                    "Intentional failure".to_string(),
                ))
            }
        }

        let tool = FailingTool;
        let call = ToolCall::new("call_1", "failing", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec
            .result
            .message
            .as_ref()
            .unwrap()
            .contains("Intentional failure"));
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_scope_reads() {
        struct ScopeReaderTool;

        #[async_trait]
        impl Tool for ScopeReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("scope_reader", "ScopeReader", "Reads scope values")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let user_id = ctx
                    .config_value("user_id")
                    .and_then(|v| v.as_str())
                    .unwrap_or("unknown");
                Ok(ToolResult::success(
                    "scope_reader",
                    json!({"user_id": user_id}),
                ))
            }
        }

        let mut scope = RunConfig::new();
        scope.set("user_id", "u-42").unwrap();

        let tool = ScopeReaderTool;
        let call = ToolCall::new("call_1", "scope_reader", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["user_id"], "u-42");
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_scope_none() {
        struct ScopeCheckerTool;

        #[async_trait]
        impl Tool for ScopeCheckerTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("scope_checker", "ScopeChecker", "Checks scope presence")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let has_user_id = ctx.config_value("user_id").is_some();
                // `has_scope` is always true: it only proves the tool ran with
                // some config available.
                Ok(ToolResult::success(
                    "scope_checker",
                    json!({"has_scope": true, "has_user_id": has_user_id}),
                ))
            }
        }

        let tool = ScopeCheckerTool;
        let call = ToolCall::new("call_1", "scope_checker", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, None).await;
        assert_eq!(exec.result.data["has_scope"], true);
        assert_eq!(exec.result.data["has_user_id"], false);

        let scope = RunConfig::new();
        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;
        assert_eq!(exec.result.data["has_scope"], true);
        assert_eq!(exec.result.data["has_user_id"], false);
    }

    #[tokio::test]
    async fn test_execute_with_scope_sensitive_key() {
        struct SensitiveReaderTool;

        #[async_trait]
        impl Tool for SensitiveReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("sensitive", "Sensitive", "Reads sensitive key")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let scope = ctx.run_config();
                let token = scope.value("token").and_then(|v| v.as_str()).unwrap();
                let is_sensitive = scope.is_sensitive("token");
                Ok(ToolResult::success(
                    "sensitive",
                    json!({"token_len": token.len(), "is_sensitive": is_sensitive}),
                ))
            }
        }

        let mut scope = RunConfig::new();
        scope.set_sensitive("token", "super-secret-token").unwrap();

        let tool = SensitiveReaderTool;
        let call = ToolCall::new("call_1", "sensitive", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["token_len"], 18);
        assert_eq!(exec.result.data["is_sensitive"], true);
    }

    /// Requires a string `name` argument; records whether `execute` was reached.
    struct StrictSchemaTool {
        executed: std::sync::atomic::AtomicBool,
    }

    #[async_trait]
    impl Tool for StrictSchemaTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("strict", "Strict", "Requires a string 'name'").with_parameters(
                json!({
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    },
                    "required": ["name"]
                }),
            )
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.executed
                .store(true, std::sync::atomic::Ordering::SeqCst);
            Ok(ToolResult::success("strict", args))
        }
    }

    #[tokio::test]
    async fn test_validate_args_blocks_invalid_before_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            exec.result.message.as_ref().unwrap().contains("name"),
            "error should mention the missing field"
        );
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    #[tokio::test]
    async fn test_validate_args_passes_valid_to_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": "Alice"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert!(
            tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() should be called for valid args"
        );
    }

    #[tokio::test]
    async fn test_validate_args_wrong_type_blocks_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": 42}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    #[tokio::test]
    async fn test_execute_tools_parallel_preserves_order_and_isolates_missing_tool() {
        let tools = echo_registry();
        let calls = vec![
            ToolCall::new("call_1", "echo", json!({"n": 1})),
            ToolCall::new("call_2", "missing", json!({})),
            ToolCall::new("call_3", "echo", json!({"n": 3})),
        ];
        let state = json!({"k": "v"});

        let executions = execute_tools_parallel(&tools, &calls, &state).await;

        assert_eq!(executions.len(), calls.len());
        for (exec, call) in executions.iter().zip(&calls) {
            assert_eq!(exec.call.id, call.id);
        }
        assert!(executions[0].result.is_success());
        assert_eq!(executions[0].result.data["n"], 1);
        assert!(executions[1].result.is_error());
        assert!(executions[1].patch.is_none());
        assert!(executions[2].result.is_success());
        assert_eq!(executions[2].result.data["n"], 3);
    }

    #[tokio::test]
    async fn test_execute_tools_sequential_without_patches_keeps_state() {
        let tools = echo_registry();
        let calls = vec![
            ToolCall::new("call_1", "echo", json!({"msg": "a"})),
            ToolCall::new("call_2", "missing", json!({})),
        ];
        let state = json!({"counter": 1});

        let (final_state, executions) = execute_tools_sequential(&tools, &calls, &state).await;

        assert_eq!(final_state, state);
        assert_eq!(executions.len(), 2);
        assert!(executions[0].result.is_success());
        assert_eq!(executions[0].result.data["msg"], "a");
        assert!(executions[1].result.is_error());
        assert!(collect_patches(&executions).is_empty());
    }

    #[tokio::test]
    async fn test_execute_tools_with_no_calls() {
        let tools = echo_registry();
        let state = json!({"x": 1});

        assert!(execute_tools_parallel(&tools, &[], &state).await.is_empty());

        let (final_state, executions) = execute_tools_sequential(&tools, &[], &state).await;
        assert_eq!(final_state, state);
        assert!(executions.is_empty());
    }
}