Skip to main content

tirea_agent_loop/engine/
tool_execution.rs

1//! Tool execution utilities.
2
3use crate::contracts::runtime::behavior::AgentBehavior;
4use crate::contracts::runtime::tool_call::ToolCallContext;
5use crate::contracts::runtime::tool_call::{Tool, ToolExecutionEffect, ToolResult};
6pub use crate::contracts::runtime::ToolExecution;
7use crate::contracts::thread::ToolCall;
8use crate::contracts::{reduce_state_actions, AnyStateAction, ScopeContext};
9use futures::future::join_all;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use tirea_contract::RunConfig;
14use tirea_state::{apply_patch, DocCell, Patch, TrackedPatch};
15
16const DIRECT_STATE_WRITE_DENIED_ERROR_CODE: &str = "tool_context_state_write_not_allowed";
17
18pub(crate) fn merge_context_patch_into_effect(
19    call: &ToolCall,
20    _effect: &mut ToolExecutionEffect,
21    context_patch: TrackedPatch,
22) -> Result<(), Box<ToolResult>> {
23    if context_patch.patch().is_empty() {
24        return Ok(());
25    }
26
27    // No compatibility mode: tool-side direct state writes are always rejected.
28    Err(Box::new(ToolResult::error_with_code(
29        &call.name,
30        DIRECT_STATE_WRITE_DENIED_ERROR_CODE,
31        "direct ToolCallContext state writes are disabled; emit ToolExecutionEffect actions instead",
32    )))
33}
34
35/// Execute a single tool call.
36///
37/// This function:
38/// 1. Creates a Context from the state snapshot
39/// 2. Executes the tool
40/// 3. Extracts any state changes as a TrackedPatch
41///
42/// # Arguments
43///
44/// * `tool` - The tool to execute (or None if not found)
45/// * `call` - The tool call with id, name, and arguments
46/// * `state` - The current state snapshot (read-only)
47pub async fn execute_single_tool(
48    tool: Option<&dyn Tool>,
49    call: &ToolCall,
50    state: &Value,
51) -> ToolExecution {
52    execute_single_tool_with_scope_and_behavior(tool, call, state, None, None).await
53}
54
55/// Execute a single tool call with an optional scope context.
56pub async fn execute_single_tool_with_scope(
57    tool: Option<&dyn Tool>,
58    call: &ToolCall,
59    state: &Value,
60    scope: Option<&RunConfig>,
61) -> ToolExecution {
62    execute_single_tool_with_scope_and_behavior(tool, call, state, scope, None).await
63}
64
65/// Execute a single tool call with optional scope and behavior router.
66pub async fn execute_single_tool_with_scope_and_behavior(
67    tool: Option<&dyn Tool>,
68    call: &ToolCall,
69    state: &Value,
70    scope: Option<&RunConfig>,
71    _behavior: Option<&dyn AgentBehavior>,
72) -> ToolExecution {
73    let Some(tool) = tool else {
74        return ToolExecution {
75            call: call.clone(),
76            result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
77            patch: None,
78        };
79    };
80
81    // Create context for this tool call
82    let doc = DocCell::new(state.clone());
83    let ops = Mutex::new(Vec::new());
84    let default_scope = RunConfig::default();
85    let scope = scope.unwrap_or(&default_scope);
86    let pending_messages = Mutex::new(Vec::new());
87    let ctx = ToolCallContext::new(
88        &doc,
89        &ops,
90        &call.id,
91        format!("tool:{}", call.name),
92        scope,
93        &pending_messages,
94        tirea_contract::runtime::activity::NoOpActivityManager::arc(),
95    );
96
97    // Validate arguments against the tool's JSON Schema
98    if let Err(e) = tool.validate_args(&call.arguments) {
99        return ToolExecution {
100            call: call.clone(),
101            result: ToolResult::error(&call.name, e.to_string()),
102            patch: None,
103        };
104    }
105
106    // Execute the tool
107    let mut effect = match tool.execute_effect(call.arguments.clone(), &ctx).await {
108        Ok(effect) => effect,
109        Err(e) => ToolExecutionEffect::from(ToolResult::error(&call.name, e.to_string())),
110    };
111
112    let context_patch = ctx.take_patch();
113    if let Err(result) = merge_context_patch_into_effect(call, &mut effect, context_patch) {
114        return ToolExecution {
115            call: call.clone(),
116            result: *result,
117            patch: None,
118        };
119    }
120    let (result, actions) = effect.into_parts();
121    let state_actions: Vec<AnyStateAction> = actions
122        .into_iter()
123        .filter_map(|a| {
124            if a.is_state_action() {
125                a.into_state_action()
126            } else {
127                None
128            }
129        })
130        .collect();
131
132    let tool_scope_ctx = ScopeContext::for_call(&call.id);
133    let action_patches = match reduce_state_actions(
134        state_actions,
135        state,
136        &format!("tool:{}", call.name),
137        &tool_scope_ctx,
138    ) {
139        Ok(patches) => patches,
140        Err(err) => {
141            return ToolExecution {
142                call: call.clone(),
143                result: ToolResult::error(
144                    &call.name,
145                    format!("tool state action reduce failed: {err}"),
146                ),
147                patch: None,
148            };
149        }
150    };
151
152    let mut merged_patch = Patch::new();
153    for tracked in action_patches {
154        merged_patch.extend(tracked.patch().clone());
155    }
156
157    let patch = if merged_patch.is_empty() {
158        None
159    } else {
160        Some(TrackedPatch::new(merged_patch).with_source(format!("tool:{}", call.name)))
161    };
162
163    ToolExecution {
164        call: call.clone(),
165        result,
166        patch,
167    }
168}
169
170/// Execute tool calls in parallel using the same state snapshot for every call.
171pub async fn execute_tools_parallel(
172    tools: &HashMap<String, Arc<dyn Tool>>,
173    calls: &[ToolCall],
174    state: &Value,
175) -> Vec<ToolExecution> {
176    let tasks = calls.iter().map(|call| {
177        let tool = tools.get(&call.name).cloned();
178        let state = state.clone();
179        async move { execute_single_tool(tool.as_deref(), call, &state).await }
180    });
181    join_all(tasks).await
182}
183
184/// Execute tool calls sequentially, applying each resulting patch before the next call.
185pub async fn execute_tools_sequential(
186    tools: &HashMap<String, Arc<dyn Tool>>,
187    calls: &[ToolCall],
188    state: &Value,
189) -> (Value, Vec<ToolExecution>) {
190    let mut state = state.clone();
191    let mut executions = Vec::with_capacity(calls.len());
192
193    for call in calls {
194        let exec = execute_single_tool(tools.get(&call.name).map(Arc::as_ref), call, &state).await;
195        if let Some(patch) = exec.patch.as_ref() {
196            if let Ok(next) = apply_patch(&state, patch.patch()) {
197                state = next;
198            }
199        }
200        executions.push(exec);
201    }
202
203    (state, executions)
204}
205
206/// Collect patches from executions.
207pub fn collect_patches(executions: &[ToolExecution]) -> Vec<TrackedPatch> {
208    executions.iter().filter_map(|e| e.patch.clone()).collect()
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::runtime::state::AnyStateAction;
    use crate::contracts::runtime::state::StateSpec;
    use crate::contracts::runtime::tool_call::{ToolDescriptor, ToolError};
    use crate::contracts::ToolCallContext;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tirea_contract::testing::TestFixtureState;
    use tirea_state::{PatchSink, Path as TPath, State, TireaResult};

    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("echo", "Echo", "Echo the input")
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("echo", args))
        }
    }

    #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
    struct EffectCounterState {
        value: i64,
    }

    struct EffectCounterRef;

    impl State for EffectCounterState {
        type Ref<'a> = EffectCounterRef;
        const PATH: &'static str = "counter";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            EffectCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for EffectCounterState {
        type Action = i64;

        fn reduce(&mut self, action: Self::Action) {
            self.value += action;
        }
    }

    struct EffectTool;

    #[async_trait]
    impl Tool for EffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("effect", "Effect", "Tool returning state actions")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("effect", json!({})))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            Ok(
                crate::contracts::runtime::ToolExecutionEffect::new(ToolResult::success(
                    "effect",
                    json!({}),
                ))
                .with_action(AnyStateAction::new::<EffectCounterState>(2)),
            )
        }
    }

    struct DirectWriteEffectTool;

    #[async_trait]
    impl Tool for DirectWriteEffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new(
                "direct_write_effect",
                "DirectWriteEffect",
                "writes state directly in execute_effect",
            )
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(
                "direct_write_effect",
                json!({"ok": true}),
            ))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            let state = ctx.state_of::<TestFixtureState>();
            state
                .set_label(Some("direct_write".to_string()))
                .expect("failed to set label");
            Ok(crate::contracts::runtime::ToolExecutionEffect::new(
                ToolResult::success("direct_write_effect", json!({"ok": true})),
            ))
        }
    }

    #[tokio::test]
    async fn test_execute_single_tool_not_found() {
        let call = ToolCall::new("call_1", "nonexistent", json!({}));
        let state = json!({});

        let exec = execute_single_tool(None, &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_execute_single_tool_success() {
        let tool = EchoTool;
        let call = ToolCall::new("call_1", "echo", json!({"msg": "hello"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["msg"], "hello");
    }

    #[tokio::test]
    async fn test_execute_single_tool_applies_state_actions_from_effect() {
        let tool = EffectTool;
        let call = ToolCall::new("call_1", "effect", json!({}));
        let state = json!({"counter": {"value": 1}});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;
        let patch = exec.patch.expect("patch should be emitted");
        let next = apply_patch(&state, patch.patch()).expect("patch should apply");

        assert_eq!(next["counter"]["value"], 3);
    }

    #[tokio::test]
    async fn test_execute_single_tool_rejects_direct_context_writes_in_strict_mode() {
        let tool = DirectWriteEffectTool;
        let call = ToolCall::new("call_1", "direct_write_effect", json!({}));
        let state = json!({});
        let scope = RunConfig::default();

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;
        assert!(exec.result.is_error());
        assert_eq!(
            exec.result.data["error"]["code"],
            json!("tool_context_state_write_not_allowed")
        );
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_collect_patches() {
        use tirea_state::{path, Op, Patch};

        let executions = vec![
            ToolExecution {
                call: ToolCall::new("1", "a", json!({})),
                result: ToolResult::success("a", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("a"), json!(1))),
                )),
            },
            ToolExecution {
                call: ToolCall::new("2", "b", json!({})),
                result: ToolResult::success("b", json!({})),
                patch: None,
            },
            ToolExecution {
                call: ToolCall::new("3", "c", json!({})),
                result: ToolResult::success("c", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("c"), json!(3))),
                )),
            },
        ];

        let patches = collect_patches(&executions);
        assert_eq!(patches.len(), 2);
    }

    #[tokio::test]
    async fn test_tool_execution_error() {
        struct FailingTool;

        #[async_trait]
        impl Tool for FailingTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("failing", "Failing", "Always fails")
            }

            async fn execute(
                &self,
                _args: Value,
                _ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                Err(ToolError::ExecutionFailed(
                    "Intentional failure".to_string(),
                ))
            }
        }

        let tool = FailingTool;
        let call = ToolCall::new("call_1", "failing", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec
            .result
            .message
            .as_ref()
            .unwrap()
            .contains("Intentional failure"));
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_scope_reads() {
        /// Tool that reads user_id from scope and returns it.
        struct ScopeReaderTool;

        #[async_trait]
        impl Tool for ScopeReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("scope_reader", "ScopeReader", "Reads scope values")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let user_id = ctx
                    .config_value("user_id")
                    .and_then(|v| v.as_str())
                    .unwrap_or("unknown");
                Ok(ToolResult::success(
                    "scope_reader",
                    json!({"user_id": user_id}),
                ))
            }
        }

        let mut scope = RunConfig::new();
        scope.set("user_id", "u-42").unwrap();

        let tool = ScopeReaderTool;
        let call = ToolCall::new("call_1", "scope_reader", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["user_id"], "u-42");
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_scope_none() {
        /// Tool that reports whether the run config contains a `user_id` key.
        struct ScopeCheckerTool;

        #[async_trait]
        impl Tool for ScopeCheckerTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("scope_checker", "ScopeChecker", "Checks scope presence")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                // ToolCallContext always provides a scope reference (never None).
                // We verify scope access works by probing for a known key.
                let has_user_id = ctx.config_value("user_id").is_some();
                Ok(ToolResult::success(
                    "scope_checker",
                    json!({"has_scope": true, "has_user_id": has_user_id}),
                ))
            }
        }

        let tool = ScopeCheckerTool;
        let call = ToolCall::new("call_1", "scope_checker", json!({}));
        let state = json!({});

        // Without scope — ToolCallContext still provides a (default-empty) scope
        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, None).await;
        assert_eq!(exec.result.data["has_scope"], true);
        assert_eq!(exec.result.data["has_user_id"], false);

        // With scope (empty)
        let scope = RunConfig::new();
        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;
        assert_eq!(exec.result.data["has_scope"], true);
        assert_eq!(exec.result.data["has_user_id"], false);
    }

    #[tokio::test]
    async fn test_execute_with_scope_sensitive_key() {
        /// Tool that reads a sensitive key from scope.
        struct SensitiveReaderTool;

        #[async_trait]
        impl Tool for SensitiveReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("sensitive", "Sensitive", "Reads sensitive key")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let scope = ctx.run_config();
                let token = scope.value("token").and_then(|v| v.as_str()).unwrap();
                let is_sensitive = scope.is_sensitive("token");
                Ok(ToolResult::success(
                    "sensitive",
                    json!({"token_len": token.len(), "is_sensitive": is_sensitive}),
                ))
            }
        }

        let mut scope = RunConfig::new();
        scope.set_sensitive("token", "super-secret-token").unwrap();

        let tool = SensitiveReaderTool;
        let call = ToolCall::new("call_1", "sensitive", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["token_len"], 18);
        assert_eq!(exec.result.data["is_sensitive"], true);
    }

    // =========================================================================
    // validate_args integration: strict schema blocks invalid args at exec path
    // =========================================================================

    /// Tool with a strict schema — execute should never be reached on invalid args.
    struct StrictSchemaTool {
        executed: std::sync::atomic::AtomicBool,
    }

    #[async_trait]
    impl Tool for StrictSchemaTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("strict", "Strict", "Requires a string 'name'").with_parameters(
                json!({
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    },
                    "required": ["name"]
                }),
            )
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.executed
                .store(true, std::sync::atomic::Ordering::SeqCst);
            Ok(ToolResult::success("strict", args))
        }
    }

    #[tokio::test]
    async fn test_validate_args_blocks_invalid_before_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        // Missing required "name" field
        let call = ToolCall::new("call_1", "strict", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            exec.result.message.as_ref().unwrap().contains("name"),
            "error should mention the missing field"
        );
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    #[tokio::test]
    async fn test_validate_args_passes_valid_to_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": "Alice"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert!(
            tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() should be called for valid args"
        );
    }

    #[tokio::test]
    async fn test_validate_args_wrong_type_blocks_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        // "name" should be string, not integer
        let call = ToolCall::new("call_1", "strict", json!({"name": 42}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    // =========================================================================
    // Multi-call execution: sequential threading vs. parallel shared snapshot
    // =========================================================================

    /// Tool registry containing only [`EffectTool`] under the name "effect".
    fn effect_tool_map() -> HashMap<String, Arc<dyn Tool>> {
        let mut tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tools.insert("effect".to_string(), Arc::new(EffectTool));
        tools
    }

    #[tokio::test]
    async fn test_execute_tools_sequential_threads_state_between_calls() {
        let tools = effect_tool_map();
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "effect", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let (final_state, executions) = execute_tools_sequential(&tools, &calls, &state).await;

        assert_eq!(executions.len(), 2);
        assert!(executions.iter().all(|e| e.result.is_success()));
        assert!(executions.iter().all(|e| e.patch.is_some()));
        // The second call observes the first call's +2: 1 -> 3 -> 5.
        assert_eq!(final_state["counter"]["value"], 5);
    }

    #[tokio::test]
    async fn test_execute_tools_parallel_uses_shared_snapshot() {
        let tools = effect_tool_map();
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "missing", json!({})),
            ToolCall::new("call_3", "effect", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let executions = execute_tools_parallel(&tools, &calls, &state).await;

        // Results keep the order of `calls`.
        assert_eq!(executions.len(), 3);
        assert_eq!(executions[0].call.name, "effect");
        assert_eq!(executions[1].call.name, "missing");
        assert_eq!(executions[2].call.name, "effect");

        // Unknown tool names produce an error result without a patch.
        assert!(executions[1].result.is_error());
        assert!(executions[1].patch.is_none());

        // Both effect calls saw the same snapshot, so neither observes the other's +2.
        for exec in [&executions[0], &executions[2]] {
            assert!(exec.result.is_success());
            let patch = exec.patch.as_ref().expect("patch should be emitted");
            let next = apply_patch(&state, patch.patch()).expect("patch should apply");
            assert_eq!(next["counter"]["value"], 3);
        }
    }
}