1use crate::contracts::runtime::behavior::AgentBehavior;
4use crate::contracts::runtime::tool_call::ToolCallContext;
5use crate::contracts::runtime::tool_call::{Tool, ToolExecutionEffect, ToolResult};
6pub use crate::contracts::runtime::ToolExecution;
7use crate::contracts::thread::ToolCall;
8use crate::contracts::{reduce_state_actions, AnyStateAction, ScopeContext};
9use futures::future::join_all;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use tirea_contract::RunPolicy;
14use tirea_state::{apply_patch, DocCell, Patch, TrackedPatch};
15
16const DIRECT_STATE_WRITE_DENIED_ERROR_CODE: &str = "tool_context_state_write_not_allowed";
17
18pub(crate) fn merge_context_patch_into_effect(
19 call: &ToolCall,
20 _effect: &mut ToolExecutionEffect,
21 context_patch: TrackedPatch,
22) -> Result<(), Box<ToolResult>> {
23 if context_patch.patch().is_empty() {
24 return Ok(());
25 }
26
27 Err(Box::new(ToolResult::error_with_code(
29 &call.name,
30 DIRECT_STATE_WRITE_DENIED_ERROR_CODE,
31 "direct ToolCallContext state writes are disabled; emit ToolExecutionEffect actions instead",
32 )))
33}
34
35pub async fn execute_single_tool(
48 tool: Option<&dyn Tool>,
49 call: &ToolCall,
50 state: &Value,
51) -> ToolExecution {
52 execute_single_tool_with_run_policy_and_behavior(tool, call, state, None, None).await
53}
54
55pub async fn execute_single_tool_with_run_policy(
57 tool: Option<&dyn Tool>,
58 call: &ToolCall,
59 state: &Value,
60 run_policy: Option<&RunPolicy>,
61) -> ToolExecution {
62 execute_single_tool_with_run_policy_and_behavior(tool, call, state, run_policy, None).await
63}
64
65pub async fn execute_single_tool_with_run_policy_and_behavior(
67 tool: Option<&dyn Tool>,
68 call: &ToolCall,
69 state: &Value,
70 run_policy: Option<&RunPolicy>,
71 _behavior: Option<&dyn AgentBehavior>,
72) -> ToolExecution {
73 let Some(tool) = tool else {
74 return ToolExecution {
75 call: call.clone(),
76 result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
77 patch: None,
78 };
79 };
80
81 let doc = DocCell::new(state.clone());
83 let ops = Mutex::new(Vec::new());
84 let default_run_policy = RunPolicy::default();
85 let run_policy = run_policy.unwrap_or(&default_run_policy);
86 let pending_messages = Mutex::new(Vec::new());
87 let ctx = ToolCallContext::new(
88 &doc,
89 &ops,
90 &call.id,
91 format!("tool:{}", call.name),
92 run_policy,
93 &pending_messages,
94 tirea_contract::runtime::activity::NoOpActivityManager::arc(),
95 );
96
97 if let Err(e) = tool.validate_args(&call.arguments) {
99 return ToolExecution {
100 call: call.clone(),
101 result: ToolResult::error(&call.name, e.to_string()),
102 patch: None,
103 };
104 }
105
106 let mut effect = match tool.execute_effect(call.arguments.clone(), &ctx).await {
108 Ok(effect) => effect,
109 Err(e) => ToolExecutionEffect::from(ToolResult::error(&call.name, e.to_string())),
110 };
111
112 let context_patch = ctx.take_patch();
113 if let Err(result) = merge_context_patch_into_effect(call, &mut effect, context_patch) {
114 return ToolExecution {
115 call: call.clone(),
116 result: *result,
117 patch: None,
118 };
119 }
120 let (result, actions) = effect.into_parts();
121 let state_actions: Vec<AnyStateAction> = actions
122 .into_iter()
123 .filter_map(|a| {
124 if a.is_state_action() {
125 a.into_state_action()
126 } else {
127 None
128 }
129 })
130 .collect();
131 let tool_scope_ctx = ScopeContext::for_call(&call.id);
132 let action_patches = match reduce_state_actions(
133 state_actions,
134 state,
135 &format!("tool:{}", call.name),
136 &tool_scope_ctx,
137 ) {
138 Ok(patches) => patches,
139 Err(err) => {
140 return ToolExecution {
141 call: call.clone(),
142 result: ToolResult::error(
143 &call.name,
144 format!("tool state action reduce failed: {err}"),
145 ),
146 patch: None,
147 };
148 }
149 };
150
151 let mut merged_patch = Patch::new();
152 for tracked in action_patches {
153 merged_patch.extend(tracked.patch().clone());
154 }
155
156 let patch = if merged_patch.is_empty() {
157 None
158 } else {
159 Some(TrackedPatch::new(merged_patch).with_source(format!("tool:{}", call.name)))
160 };
161
162 ToolExecution {
163 call: call.clone(),
164 result,
165 patch,
166 }
167}
168
169pub async fn execute_tools_parallel(
171 tools: &HashMap<String, Arc<dyn Tool>>,
172 calls: &[ToolCall],
173 state: &Value,
174) -> Vec<ToolExecution> {
175 let tasks = calls.iter().map(|call| {
176 let tool = tools.get(&call.name).cloned();
177 let state = state.clone();
178 async move { execute_single_tool(tool.as_deref(), call, &state).await }
179 });
180 join_all(tasks).await
181}
182
183pub async fn execute_tools_sequential(
185 tools: &HashMap<String, Arc<dyn Tool>>,
186 calls: &[ToolCall],
187 state: &Value,
188) -> (Value, Vec<ToolExecution>) {
189 let mut state = state.clone();
190 let mut executions = Vec::with_capacity(calls.len());
191
192 for call in calls {
193 let exec = execute_single_tool(tools.get(&call.name).map(Arc::as_ref), call, &state).await;
194 if let Some(patch) = exec.patch.as_ref() {
195 if let Ok(next) = apply_patch(&state, patch.patch()) {
196 state = next;
197 }
198 }
199 executions.push(exec);
200 }
201
202 (state, executions)
203}
204
205pub fn collect_patches(executions: &[ToolExecution]) -> Vec<TrackedPatch> {
207 executions.iter().filter_map(|e| e.patch.clone()).collect()
208}
209
#[cfg(test)]
mod tests {
    // Unit tests for single-tool execution, state-action patches, direct-write rejection,
    // run-policy plumbing, argument validation, and patch collection.
    use super::*;
    use crate::contracts::runtime::state::AnyStateAction;
    use crate::contracts::runtime::state::StateSpec;
    use crate::contracts::runtime::tool_call::{ToolDescriptor, ToolError};
    use crate::contracts::ToolCallContext;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tirea_contract::testing::TestFixtureState;
    use tirea_state::{PatchSink, Path as TPath, State, TireaResult};

    /// Minimal tool that returns its arguments unchanged as the success payload.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("echo", "Echo", "Echo the input")
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("echo", args))
        }
    }

    /// Test state stored under the `"counter"` path. Its `i64` action adds to `value`.
    #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
    struct EffectCounterState {
        value: i64,
    }

    // Placeholder state reference. The tests in this module never read the counter through it.
    struct EffectCounterRef;

    impl State for EffectCounterState {
        type Ref<'a> = EffectCounterRef;
        const PATH: &'static str = "counter";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            EffectCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            // An absent (null) counter is treated as the default state.
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for EffectCounterState {
        type Action = i64;

        // Each action is an increment added to the current counter value.
        fn reduce(&mut self, action: Self::Action) {
            self.value += action;
        }
    }

    /// Tool whose effect carries a single `+2` counter action and performs no direct writes.
    struct EffectTool;

    #[async_trait]
    impl Tool for EffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("effect", "Effect", "Tool returning state actions")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("effect", json!({})))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            Ok(
                crate::contracts::runtime::ToolExecutionEffect::new(ToolResult::success(
                    "effect",
                    json!({}),
                ))
                .with_action(AnyStateAction::new::<EffectCounterState>(2)),
            )
        }
    }

    /// Tool that writes state directly through its context, which the executor must reject.
    struct DirectWriteEffectTool;

    #[async_trait]
    impl Tool for DirectWriteEffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new(
                "direct_write_effect",
                "DirectWriteEffect",
                "writes state directly in execute_effect",
            )
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(
                "direct_write_effect",
                json!({"ok": true}),
            ))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            // Mutates state through the context instead of returning an action.
            let state = ctx.state_of::<TestFixtureState>();
            state
                .set_label(Some("direct_write".to_string()))
                .expect("failed to set label");
            Ok(crate::contracts::runtime::ToolExecutionEffect::new(
                ToolResult::success("direct_write_effect", json!({"ok": true})),
            ))
        }
    }

    /// A missing tool yields an error result and no patch.
    #[tokio::test]
    async fn test_execute_single_tool_not_found() {
        let call = ToolCall::new("call_1", "nonexistent", json!({}));
        let state = json!({});

        let exec = execute_single_tool(None, &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec.patch.is_none());
    }

    /// A successful tool result is passed through unchanged.
    #[tokio::test]
    async fn test_execute_single_tool_success() {
        let tool = EchoTool;
        let call = ToolCall::new("call_1", "echo", json!({"msg": "hello"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["msg"], "hello");
    }

    /// State actions in the effect are reduced into a patch: counter 1 + 2 == 3.
    #[tokio::test]
    async fn test_execute_single_tool_applies_state_actions_from_effect() {
        let tool = EffectTool;
        let call = ToolCall::new("call_1", "effect", json!({}));
        let state = json!({"counter": {"value": 1}});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;
        let patch = exec.patch.expect("patch should be emitted");
        let next = apply_patch(&state, patch.patch()).expect("patch should apply");

        assert_eq!(next["counter"]["value"], 3);
    }

    /// Direct context writes are rejected with the dedicated error code and no patch.
    #[tokio::test]
    async fn test_execute_single_tool_rejects_direct_context_writes_in_strict_mode() {
        let tool = DirectWriteEffectTool;
        let call = ToolCall::new("call_1", "direct_write_effect", json!({}));
        let state = json!({});
        let scope = RunPolicy::default();

        let exec =
            execute_single_tool_with_run_policy(Some(&tool), &call, &state, Some(&scope)).await;
        assert!(exec.result.is_error());
        assert_eq!(
            exec.result.data["error"]["code"],
            json!("tool_context_state_write_not_allowed")
        );
        assert!(exec.patch.is_none());
    }

    /// `collect_patches` keeps only executions that carry a patch.
    #[tokio::test]
    async fn test_collect_patches() {
        use tirea_state::{path, Op, Patch};

        let executions = vec![
            ToolExecution {
                call: ToolCall::new("1", "a", json!({})),
                result: ToolResult::success("a", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("a"), json!(1))),
                )),
            },
            ToolExecution {
                call: ToolCall::new("2", "b", json!({})),
                result: ToolResult::success("b", json!({})),
                patch: None,
            },
            ToolExecution {
                call: ToolCall::new("3", "c", json!({})),
                result: ToolResult::success("c", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("c"), json!(3))),
                )),
            },
        ];

        let patches = collect_patches(&executions);
        assert_eq!(patches.len(), 2);
    }

    /// A tool returning `Err` produces an error result that contains the tool's message.
    #[tokio::test]
    async fn test_tool_execution_error() {
        struct FailingTool;

        #[async_trait]
        impl Tool for FailingTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("failing", "Failing", "Always fails")
            }

            async fn execute(
                &self,
                _args: Value,
                _ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                Err(ToolError::ExecutionFailed(
                    "Intentional failure".to_string(),
                ))
            }
        }

        let tool = FailingTool;
        let call = ToolCall::new("call_1", "failing", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec
            .result
            .message
            .as_ref()
            .unwrap()
            .contains("Intentional failure"));
    }

    /// Without an explicit run policy, the context reports no parent tool call id.
    #[tokio::test]
    async fn test_execute_single_tool_with_default_run_identity_has_no_parent_tool_call() {
        struct RunIdentityReaderTool;

        #[async_trait]
        impl Tool for RunIdentityReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new(
                    "run_identity_reader",
                    "RunIdentityReader",
                    "Reads run identity",
                )
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let parent_tool_call_id = ctx
                    .run_identity()
                    .parent_tool_call_id_opt()
                    .unwrap_or("none");
                Ok(ToolResult::success(
                    "run_identity_reader",
                    json!({"parent_tool_call_id": parent_tool_call_id}),
                ))
            }
        }

        let tool = RunIdentityReaderTool;
        let call = ToolCall::new("call_1", "run_identity_reader", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_run_policy(Some(&tool), &call, &state, None).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["parent_tool_call_id"], "none");
    }

    /// Runs the tool with `None` and with an explicit `RunPolicy::new()`; neither yields a
    /// parent tool call id.
    #[tokio::test]
    async fn test_execute_single_tool_with_run_policy_none() {
        struct RunPolicyCheckerTool;

        #[async_trait]
        impl Tool for RunPolicyCheckerTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new(
                    "run_policy_checker",
                    "RunPolicyChecker",
                    "Checks runtime option presence",
                )
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                // NOTE(review): "has_run_policy" is hard-coded to `true`, so the assertions on it
                // below cannot fail. They do not actually check run-policy presence.
                Ok(ToolResult::success(
                    "run_policy_checker",
                    json!({
                        "has_run_policy": true,
                        "has_parent_tool_call_id": ctx.run_identity().parent_tool_call_id_opt().is_some()
                    }),
                ))
            }
        }

        let tool = RunPolicyCheckerTool;
        let call = ToolCall::new("call_1", "run_policy_checker", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_run_policy(Some(&tool), &call, &state, None).await;
        assert_eq!(exec.result.data["has_run_policy"], true);
        assert_eq!(exec.result.data["has_parent_tool_call_id"], false);

        let run_policy = RunPolicy::new();
        let exec =
            execute_single_tool_with_run_policy(Some(&tool), &call, &state, Some(&run_policy))
                .await;
        assert_eq!(exec.result.data["has_run_policy"], true);
        assert_eq!(exec.result.data["has_parent_tool_call_id"], false);
    }

    /// The run policy passed by the caller is visible to the tool through its context.
    #[tokio::test]
    async fn test_execute_with_run_policy() {
        struct SensitiveReaderTool;

        #[async_trait]
        impl Tool for SensitiveReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("sensitive", "Sensitive", "Reads sensitive key")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                // Echo the policy's allowed-tools list, or an empty list when none is set.
                let allowed_tools = ctx
                    .run_policy()
                    .allowed_tools()
                    .map(|items| items.to_vec())
                    .unwrap_or_default();
                Ok(ToolResult::success(
                    "sensitive",
                    json!({"allowed_tools": allowed_tools}),
                ))
            }
        }

        let mut run_policy = RunPolicy::new();
        run_policy
            .set_allowed_tools_if_absent(Some(&["sensitive".to_string(), "echo".to_string()]));

        let tool = SensitiveReaderTool;
        let call = ToolCall::new("call_1", "sensitive", json!({}));
        let state = json!({});

        let exec =
            execute_single_tool_with_run_policy(Some(&tool), &call, &state, Some(&run_policy))
                .await;

        assert!(exec.result.is_success());
        assert_eq!(
            exec.result.data["allowed_tools"],
            json!(["sensitive", "echo"])
        );
    }

    /// Tool whose schema requires a string `name`. `executed` records whether `execute` ran.
    struct StrictSchemaTool {
        executed: std::sync::atomic::AtomicBool,
    }

    #[async_trait]
    impl Tool for StrictSchemaTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("strict", "Strict", "Requires a string 'name'").with_parameters(
                json!({
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    },
                    "required": ["name"]
                }),
            )
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.executed
                .store(true, std::sync::atomic::Ordering::SeqCst);
            Ok(ToolResult::success("strict", args))
        }
    }

    /// A missing required field is rejected before `execute` runs.
    #[tokio::test]
    async fn test_validate_args_blocks_invalid_before_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            exec.result.message.as_ref().unwrap().contains("name"),
            "error should mention the missing field"
        );
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    /// Valid arguments reach `execute`.
    #[tokio::test]
    async fn test_validate_args_passes_valid_to_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": "Alice"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert!(
            tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() should be called for valid args"
        );
    }

    /// A field of the wrong type is rejected before `execute` runs.
    #[tokio::test]
    async fn test_validate_args_wrong_type_blocks_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": 42}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }
}
689}