Skip to main content

tirea_agent_loop/engine/
tool_execution.rs

1//! Tool execution utilities.
2
3use crate::contracts::runtime::behavior::AgentBehavior;
4use crate::contracts::runtime::tool_call::ToolCallContext;
5use crate::contracts::runtime::tool_call::{Tool, ToolExecutionEffect, ToolResult};
6pub use crate::contracts::runtime::ToolExecution;
7use crate::contracts::thread::ToolCall;
8use crate::contracts::{reduce_state_actions, AnyStateAction, ScopeContext};
9use futures::future::join_all;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use tirea_contract::RunPolicy;
14use tirea_state::{apply_patch, DocCell, Patch, TrackedPatch};
15
16const DIRECT_STATE_WRITE_DENIED_ERROR_CODE: &str = "tool_context_state_write_not_allowed";
17
18pub(crate) fn merge_context_patch_into_effect(
19    call: &ToolCall,
20    _effect: &mut ToolExecutionEffect,
21    context_patch: TrackedPatch,
22) -> Result<(), Box<ToolResult>> {
23    if context_patch.patch().is_empty() {
24        return Ok(());
25    }
26
27    // No compatibility mode: tool-side direct state writes are always rejected.
28    Err(Box::new(ToolResult::error_with_code(
29        &call.name,
30        DIRECT_STATE_WRITE_DENIED_ERROR_CODE,
31        "direct ToolCallContext state writes are disabled; emit ToolExecutionEffect actions instead",
32    )))
33}
34
35/// Execute a single tool call.
36///
37/// This function:
38/// 1. Creates a Context from the state snapshot
39/// 2. Executes the tool
40/// 3. Extracts any state changes as a TrackedPatch
41///
42/// # Arguments
43///
44/// * `tool` - The tool to execute (or None if not found)
45/// * `call` - The tool call with id, name, and arguments
46/// * `state` - The current state snapshot (read-only)
47pub async fn execute_single_tool(
48    tool: Option<&dyn Tool>,
49    call: &ToolCall,
50    state: &Value,
51) -> ToolExecution {
52    execute_single_tool_with_run_policy_and_behavior(tool, call, state, None, None).await
53}
54
55/// Execute a single tool call with optional run policy.
56pub async fn execute_single_tool_with_run_policy(
57    tool: Option<&dyn Tool>,
58    call: &ToolCall,
59    state: &Value,
60    run_policy: Option<&RunPolicy>,
61) -> ToolExecution {
62    execute_single_tool_with_run_policy_and_behavior(tool, call, state, run_policy, None).await
63}
64
65/// Execute a single tool call with optional run policy and behavior router.
66pub async fn execute_single_tool_with_run_policy_and_behavior(
67    tool: Option<&dyn Tool>,
68    call: &ToolCall,
69    state: &Value,
70    run_policy: Option<&RunPolicy>,
71    _behavior: Option<&dyn AgentBehavior>,
72) -> ToolExecution {
73    let Some(tool) = tool else {
74        return ToolExecution {
75            call: call.clone(),
76            result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
77            patch: None,
78        };
79    };
80
81    // Create context for this tool call
82    let doc = DocCell::new(state.clone());
83    let ops = Mutex::new(Vec::new());
84    let default_run_policy = RunPolicy::default();
85    let run_policy = run_policy.unwrap_or(&default_run_policy);
86    let pending_messages = Mutex::new(Vec::new());
87    let ctx = ToolCallContext::new(
88        &doc,
89        &ops,
90        &call.id,
91        format!("tool:{}", call.name),
92        run_policy,
93        &pending_messages,
94        tirea_contract::runtime::activity::NoOpActivityManager::arc(),
95    );
96
97    // Validate arguments against the tool's JSON Schema
98    if let Err(e) = tool.validate_args(&call.arguments) {
99        return ToolExecution {
100            call: call.clone(),
101            result: ToolResult::error(&call.name, e.to_string()),
102            patch: None,
103        };
104    }
105
106    // Execute the tool
107    let mut effect = match tool.execute_effect(call.arguments.clone(), &ctx).await {
108        Ok(effect) => effect,
109        Err(e) => ToolExecutionEffect::from(ToolResult::error(&call.name, e.to_string())),
110    };
111
112    let context_patch = ctx.take_patch();
113    if let Err(result) = merge_context_patch_into_effect(call, &mut effect, context_patch) {
114        return ToolExecution {
115            call: call.clone(),
116            result: *result,
117            patch: None,
118        };
119    }
120    let (result, actions) = effect.into_parts();
121    let state_actions: Vec<AnyStateAction> = actions
122        .into_iter()
123        .filter_map(|a| {
124            if a.is_state_action() {
125                a.into_state_action()
126            } else {
127                None
128            }
129        })
130        .collect();
131    let tool_scope_ctx = ScopeContext::for_call(&call.id);
132    let action_patches = match reduce_state_actions(
133        state_actions,
134        state,
135        &format!("tool:{}", call.name),
136        &tool_scope_ctx,
137    ) {
138        Ok(patches) => patches,
139        Err(err) => {
140            return ToolExecution {
141                call: call.clone(),
142                result: ToolResult::error(
143                    &call.name,
144                    format!("tool state action reduce failed: {err}"),
145                ),
146                patch: None,
147            };
148        }
149    };
150
151    let mut merged_patch = Patch::new();
152    for tracked in action_patches {
153        merged_patch.extend(tracked.patch().clone());
154    }
155
156    let patch = if merged_patch.is_empty() {
157        None
158    } else {
159        Some(TrackedPatch::new(merged_patch).with_source(format!("tool:{}", call.name)))
160    };
161
162    ToolExecution {
163        call: call.clone(),
164        result,
165        patch,
166    }
167}
168
169/// Execute tool calls in parallel using the same state snapshot for every call.
170pub async fn execute_tools_parallel(
171    tools: &HashMap<String, Arc<dyn Tool>>,
172    calls: &[ToolCall],
173    state: &Value,
174) -> Vec<ToolExecution> {
175    let tasks = calls.iter().map(|call| {
176        let tool = tools.get(&call.name).cloned();
177        let state = state.clone();
178        async move { execute_single_tool(tool.as_deref(), call, &state).await }
179    });
180    join_all(tasks).await
181}
182
183/// Execute tool calls sequentially, applying each resulting patch before the next call.
184pub async fn execute_tools_sequential(
185    tools: &HashMap<String, Arc<dyn Tool>>,
186    calls: &[ToolCall],
187    state: &Value,
188) -> (Value, Vec<ToolExecution>) {
189    let mut state = state.clone();
190    let mut executions = Vec::with_capacity(calls.len());
191
192    for call in calls {
193        let exec = execute_single_tool(tools.get(&call.name).map(Arc::as_ref), call, &state).await;
194        if let Some(patch) = exec.patch.as_ref() {
195            if let Ok(next) = apply_patch(&state, patch.patch()) {
196                state = next;
197            }
198        }
199        executions.push(exec);
200    }
201
202    (state, executions)
203}
204
205/// Collect patches from executions.
206pub fn collect_patches(executions: &[ToolExecution]) -> Vec<TrackedPatch> {
207    executions.iter().filter_map(|e| e.patch.clone()).collect()
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::runtime::state::AnyStateAction;
    use crate::contracts::runtime::state::StateSpec;
    use crate::contracts::runtime::tool_call::{ToolDescriptor, ToolError};
    use crate::contracts::ToolCallContext;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tirea_contract::testing::TestFixtureState;
    use tirea_state::{PatchSink, Path as TPath, State, TireaResult};

    /// Returns its arguments unchanged as a success result.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("echo", "Echo", "Echo the input")
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("echo", args))
        }
    }

    /// Minimal state stored at `counter`, reduced by adding `i64` actions.
    #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
    struct EffectCounterState {
        value: i64,
    }

    struct EffectCounterRef;

    impl State for EffectCounterState {
        type Ref<'a> = EffectCounterRef;
        const PATH: &'static str = "counter";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            EffectCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for EffectCounterState {
        type Action = i64;

        fn reduce(&mut self, action: Self::Action) {
            self.value += action;
        }
    }

    /// Emits a single `+2` state action on `EffectCounterState`.
    struct EffectTool;

    #[async_trait]
    impl Tool for EffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("effect", "Effect", "Tool returning state actions")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("effect", json!({})))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            Ok(
                crate::contracts::runtime::ToolExecutionEffect::new(ToolResult::success(
                    "effect",
                    json!({}),
                ))
                .with_action(AnyStateAction::new::<EffectCounterState>(2)),
            )
        }
    }

    /// Writes state through the context instead of returning actions (must be rejected).
    struct DirectWriteEffectTool;

    #[async_trait]
    impl Tool for DirectWriteEffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new(
                "direct_write_effect",
                "DirectWriteEffect",
                "writes state directly in execute_effect",
            )
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(
                "direct_write_effect",
                json!({"ok": true}),
            ))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            let state = ctx.state_of::<TestFixtureState>();
            state
                .set_label(Some("direct_write".to_string()))
                .expect("failed to set label");
            Ok(crate::contracts::runtime::ToolExecutionEffect::new(
                ToolResult::success("direct_write_effect", json!({"ok": true})),
            ))
        }
    }

    /// Registry containing only `EffectTool` under the name `effect`.
    fn effect_tools() -> HashMap<String, Arc<dyn Tool>> {
        let mut tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tools.insert("effect".to_string(), Arc::new(EffectTool));
        tools
    }

    #[tokio::test]
    async fn test_execute_single_tool_not_found() {
        let call = ToolCall::new("call_1", "nonexistent", json!({}));
        let state = json!({});

        let exec = execute_single_tool(None, &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_execute_single_tool_success() {
        let tool = EchoTool;
        let call = ToolCall::new("call_1", "echo", json!({"msg": "hello"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["msg"], "hello");
    }

    #[tokio::test]
    async fn test_execute_single_tool_applies_state_actions_from_effect() {
        let tool = EffectTool;
        let call = ToolCall::new("call_1", "effect", json!({}));
        let state = json!({"counter": {"value": 1}});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;
        let patch = exec.patch.expect("patch should be emitted");
        let next = apply_patch(&state, patch.patch()).expect("patch should apply");

        assert_eq!(next["counter"]["value"], 3);
    }

    #[tokio::test]
    async fn test_execute_single_tool_rejects_direct_context_writes_in_strict_mode() {
        let tool = DirectWriteEffectTool;
        let call = ToolCall::new("call_1", "direct_write_effect", json!({}));
        let state = json!({});
        let scope = RunPolicy::default();

        let exec =
            execute_single_tool_with_run_policy(Some(&tool), &call, &state, Some(&scope)).await;
        assert!(exec.result.is_error());
        assert_eq!(
            exec.result.data["error"]["code"],
            json!("tool_context_state_write_not_allowed")
        );
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_execute_tools_parallel_uses_same_snapshot_for_every_call() {
        let tools = effect_tools();
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "effect", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let executions = execute_tools_parallel(&tools, &calls, &state).await;

        assert_eq!(executions.len(), 2);
        for (exec, call) in executions.iter().zip(&calls) {
            assert_eq!(exec.call.id, call.id);
            let patch = exec.patch.as_ref().expect("patch should be emitted");
            // Each call reduced against the original snapshot, not the other's output.
            let next = apply_patch(&state, patch.patch()).expect("patch should apply");
            assert_eq!(next["counter"]["value"], 3);
        }
    }

    #[tokio::test]
    async fn test_execute_tools_sequential_threads_state_between_calls() {
        let tools = effect_tools();
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "effect", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let (final_state, executions) = execute_tools_sequential(&tools, &calls, &state).await;

        assert_eq!(executions.len(), 2);
        assert!(executions.iter().all(|exec| exec.patch.is_some()));
        assert_eq!(final_state["counter"]["value"], 5);
    }

    #[tokio::test]
    async fn test_execute_tools_sequential_missing_tool_leaves_state_unchanged() {
        let tools = effect_tools();
        let calls = vec![ToolCall::new("call_1", "nonexistent", json!({}))];
        let state = json!({"counter": {"value": 1}});

        let (final_state, executions) = execute_tools_sequential(&tools, &calls, &state).await;

        assert_eq!(final_state, state);
        assert_eq!(executions.len(), 1);
        assert!(executions[0].result.is_error());
        assert!(executions[0].patch.is_none());
    }

    #[test]
    fn test_collect_patches() {
        use tirea_state::{path, Op, Patch};

        let executions = vec![
            ToolExecution {
                call: ToolCall::new("1", "a", json!({})),
                result: ToolResult::success("a", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("a"), json!(1))),
                )),
            },
            ToolExecution {
                call: ToolCall::new("2", "b", json!({})),
                result: ToolResult::success("b", json!({})),
                patch: None,
            },
            ToolExecution {
                call: ToolCall::new("3", "c", json!({})),
                result: ToolResult::success("c", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("c"), json!(3))),
                )),
            },
        ];

        let patches = collect_patches(&executions);
        assert_eq!(patches.len(), 2);

        // The collected patches are the first and third ones, in order.
        let mut doc = json!({});
        for patch in &patches {
            doc = apply_patch(&doc, patch.patch()).expect("patch should apply");
        }
        assert_eq!(doc, json!({"a": 1, "c": 3}));
    }

    #[tokio::test]
    async fn test_tool_execution_error() {
        struct FailingTool;

        #[async_trait]
        impl Tool for FailingTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("failing", "Failing", "Always fails")
            }

            async fn execute(
                &self,
                _args: Value,
                _ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                Err(ToolError::ExecutionFailed(
                    "Intentional failure".to_string(),
                ))
            }
        }

        let tool = FailingTool;
        let call = ToolCall::new("call_1", "failing", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec
            .result
            .message
            .as_ref()
            .unwrap()
            .contains("Intentional failure"));
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_default_run_identity_has_no_parent_tool_call() {
        /// Tool that reads the default run identity and returns parent lineage.
        struct RunIdentityReaderTool;

        #[async_trait]
        impl Tool for RunIdentityReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new(
                    "run_identity_reader",
                    "RunIdentityReader",
                    "Reads run identity",
                )
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let parent_tool_call_id = ctx
                    .run_identity()
                    .parent_tool_call_id_opt()
                    .unwrap_or("none");
                Ok(ToolResult::success(
                    "run_identity_reader",
                    json!({"parent_tool_call_id": parent_tool_call_id}),
                ))
            }
        }

        let tool = RunIdentityReaderTool;
        let call = ToolCall::new("call_1", "run_identity_reader", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_run_policy(Some(&tool), &call, &state, None).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["parent_tool_call_id"], "none");
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_run_policy_none() {
        /// Tool that checks typed run-policy defaults when none are supplied.
        struct RunPolicyCheckerTool;

        #[async_trait]
        impl Tool for RunPolicyCheckerTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new(
                    "run_policy_checker",
                    "RunPolicyChecker",
                    "Checks runtime option presence",
                )
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                // NOTE: `has_run_policy` is hard-coded; asserting on it only shows the
                // tool ran with a context. The parent-id field is the real check.
                Ok(ToolResult::success(
                    "run_policy_checker",
                    json!({
                        "has_run_policy": true,
                        "has_parent_tool_call_id": ctx.run_identity().parent_tool_call_id_opt().is_some()
                    }),
                ))
            }
        }

        let tool = RunPolicyCheckerTool;
        let call = ToolCall::new("call_1", "run_policy_checker", json!({}));
        let state = json!({});

        // Without explicit run policy, ToolCallContext still provides defaults.
        let exec = execute_single_tool_with_run_policy(Some(&tool), &call, &state, None).await;
        assert_eq!(exec.result.data["has_run_policy"], true);
        assert_eq!(exec.result.data["has_parent_tool_call_id"], false);

        // With explicit empty run policy.
        let run_policy = RunPolicy::new();
        let exec =
            execute_single_tool_with_run_policy(Some(&tool), &call, &state, Some(&run_policy))
                .await;
        assert_eq!(exec.result.data["has_run_policy"], true);
        assert_eq!(exec.result.data["has_parent_tool_call_id"], false);
    }

    #[tokio::test]
    async fn test_execute_with_run_policy() {
        /// Tool that reads typed policy values from the run policy.
        struct SensitiveReaderTool;

        #[async_trait]
        impl Tool for SensitiveReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("sensitive", "Sensitive", "Reads sensitive key")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let allowed_tools = ctx
                    .run_policy()
                    .allowed_tools()
                    .map(|items| items.to_vec())
                    .unwrap_or_default();
                Ok(ToolResult::success(
                    "sensitive",
                    json!({"allowed_tools": allowed_tools}),
                ))
            }
        }

        let mut run_policy = RunPolicy::new();
        run_policy
            .set_allowed_tools_if_absent(Some(&["sensitive".to_string(), "echo".to_string()]));

        let tool = SensitiveReaderTool;
        let call = ToolCall::new("call_1", "sensitive", json!({}));
        let state = json!({});

        let exec =
            execute_single_tool_with_run_policy(Some(&tool), &call, &state, Some(&run_policy))
                .await;

        assert!(exec.result.is_success());
        assert_eq!(
            exec.result.data["allowed_tools"],
            json!(["sensitive", "echo"])
        );
    }

    // =========================================================================
    // validate_args integration: strict schema blocks invalid args at exec path
    // =========================================================================

    /// Tool with a strict schema — execute should never be reached on invalid args.
    struct StrictSchemaTool {
        executed: std::sync::atomic::AtomicBool,
    }

    #[async_trait]
    impl Tool for StrictSchemaTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("strict", "Strict", "Requires a string 'name'").with_parameters(
                json!({
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    },
                    "required": ["name"]
                }),
            )
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.executed
                .store(true, std::sync::atomic::Ordering::SeqCst);
            Ok(ToolResult::success("strict", args))
        }
    }

    #[tokio::test]
    async fn test_validate_args_blocks_invalid_before_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        // Missing required "name" field
        let call = ToolCall::new("call_1", "strict", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            exec.result.message.as_ref().unwrap().contains("name"),
            "error should mention the missing field"
        );
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    #[tokio::test]
    async fn test_validate_args_passes_valid_to_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": "Alice"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert!(
            tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() should be called for valid args"
        );
    }

    #[tokio::test]
    async fn test_validate_args_wrong_type_blocks_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        // "name" should be string, not integer
        let call = ToolCall::new("call_1", "strict", json!({"name": 42}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }
}