Skip to main content

tirea_agent_loop/engine/
tool_execution.rs

1//! Tool execution utilities.
2
3use crate::contracts::runtime::behavior::AgentBehavior;
4use crate::contracts::runtime::tool_call::ToolCallContext;
5use crate::contracts::runtime::tool_call::{Tool, ToolExecutionEffect, ToolResult};
6pub use crate::contracts::runtime::ToolExecution;
7use crate::contracts::thread::ToolCall;
8use crate::contracts::{reduce_state_actions, AnyStateAction, ScopeContext};
9use futures::future::join_all;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use tirea_contract::RunConfig;
14use tirea_state::{apply_patch, DocCell, Patch, TrackedPatch};
15
16const DIRECT_STATE_WRITE_DENIED_ERROR_CODE: &str = "tool_context_state_write_not_allowed";
17
18pub(crate) fn merge_context_patch_into_effect(
19    call: &ToolCall,
20    _effect: &mut ToolExecutionEffect,
21    context_patch: TrackedPatch,
22) -> Result<(), Box<ToolResult>> {
23    if context_patch.patch().is_empty() {
24        return Ok(());
25    }
26
27    // No compatibility mode: tool-side direct state writes are always rejected.
28    Err(Box::new(ToolResult::error_with_code(
29        &call.name,
30        DIRECT_STATE_WRITE_DENIED_ERROR_CODE,
31        "direct ToolCallContext state writes are disabled; emit ToolExecutionEffect actions instead",
32    )))
33}
34
35/// Execute a single tool call.
36///
37/// This function:
38/// 1. Creates a Context from the state snapshot
39/// 2. Executes the tool
40/// 3. Extracts any state changes as a TrackedPatch
41///
42/// # Arguments
43///
44/// * `tool` - The tool to execute (or None if not found)
45/// * `call` - The tool call with id, name, and arguments
46/// * `state` - The current state snapshot (read-only)
47pub async fn execute_single_tool(
48    tool: Option<&dyn Tool>,
49    call: &ToolCall,
50    state: &Value,
51) -> ToolExecution {
52    execute_single_tool_with_scope_and_behavior(tool, call, state, None, None).await
53}
54
55/// Execute a single tool call with an optional scope context.
56pub async fn execute_single_tool_with_scope(
57    tool: Option<&dyn Tool>,
58    call: &ToolCall,
59    state: &Value,
60    scope: Option<&RunConfig>,
61) -> ToolExecution {
62    execute_single_tool_with_scope_and_behavior(tool, call, state, scope, None).await
63}
64
65/// Execute a single tool call with optional scope and behavior router.
66pub async fn execute_single_tool_with_scope_and_behavior(
67    tool: Option<&dyn Tool>,
68    call: &ToolCall,
69    state: &Value,
70    scope: Option<&RunConfig>,
71    _behavior: Option<&dyn AgentBehavior>,
72) -> ToolExecution {
73    let Some(tool) = tool else {
74        return ToolExecution {
75            call: call.clone(),
76            result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
77            patch: None,
78        };
79    };
80
81    // Create context for this tool call
82    let doc = DocCell::new(state.clone());
83    let ops = Mutex::new(Vec::new());
84    let default_scope = RunConfig::default();
85    let scope = scope.unwrap_or(&default_scope);
86    let pending_messages = Mutex::new(Vec::new());
87    let ctx = ToolCallContext::new(
88        &doc,
89        &ops,
90        &call.id,
91        format!("tool:{}", call.name),
92        scope,
93        &pending_messages,
94        tirea_contract::runtime::activity::NoOpActivityManager::arc(),
95    );
96
97    // Validate arguments against the tool's JSON Schema
98    if let Err(e) = tool.validate_args(&call.arguments) {
99        return ToolExecution {
100            call: call.clone(),
101            result: ToolResult::error(&call.name, e.to_string()),
102            patch: None,
103        };
104    }
105
106    // Execute the tool
107    let mut effect = match tool.execute_effect(call.arguments.clone(), &ctx).await {
108        Ok(effect) => effect,
109        Err(e) => ToolExecutionEffect::from(ToolResult::error(&call.name, e.to_string())),
110    };
111
112    let context_patch = ctx.take_patch();
113    if let Err(result) = merge_context_patch_into_effect(call, &mut effect, context_patch) {
114        return ToolExecution {
115            call: call.clone(),
116            result: *result,
117            patch: None,
118        };
119    }
120    let (result, actions) = effect.into_parts();
121    let state_actions: Vec<AnyStateAction> = actions
122        .into_iter()
123        .filter_map(|a| {
124            if a.is_state_action() {
125                a.into_state_action()
126            } else {
127                None
128            }
129        })
130        .collect();
131    let tool_scope_ctx = ScopeContext::for_call(&call.id);
132    let action_patches = match reduce_state_actions(
133        state_actions,
134        state,
135        &format!("tool:{}", call.name),
136        &tool_scope_ctx,
137    ) {
138        Ok(patches) => patches,
139        Err(err) => {
140            return ToolExecution {
141                call: call.clone(),
142                result: ToolResult::error(
143                    &call.name,
144                    format!("tool state action reduce failed: {err}"),
145                ),
146                patch: None,
147            };
148        }
149    };
150
151    let mut merged_patch = Patch::new();
152    for tracked in action_patches {
153        merged_patch.extend(tracked.patch().clone());
154    }
155
156    let patch = if merged_patch.is_empty() {
157        None
158    } else {
159        Some(TrackedPatch::new(merged_patch).with_source(format!("tool:{}", call.name)))
160    };
161
162    ToolExecution {
163        call: call.clone(),
164        result,
165        patch,
166    }
167}
168
169/// Execute tool calls in parallel using the same state snapshot for every call.
170pub async fn execute_tools_parallel(
171    tools: &HashMap<String, Arc<dyn Tool>>,
172    calls: &[ToolCall],
173    state: &Value,
174) -> Vec<ToolExecution> {
175    let tasks = calls.iter().map(|call| {
176        let tool = tools.get(&call.name).cloned();
177        let state = state.clone();
178        async move { execute_single_tool(tool.as_deref(), call, &state).await }
179    });
180    join_all(tasks).await
181}
182
183/// Execute tool calls sequentially, applying each resulting patch before the next call.
184pub async fn execute_tools_sequential(
185    tools: &HashMap<String, Arc<dyn Tool>>,
186    calls: &[ToolCall],
187    state: &Value,
188) -> (Value, Vec<ToolExecution>) {
189    let mut state = state.clone();
190    let mut executions = Vec::with_capacity(calls.len());
191
192    for call in calls {
193        let exec = execute_single_tool(tools.get(&call.name).map(Arc::as_ref), call, &state).await;
194        if let Some(patch) = exec.patch.as_ref() {
195            if let Ok(next) = apply_patch(&state, patch.patch()) {
196                state = next;
197            }
198        }
199        executions.push(exec);
200    }
201
202    (state, executions)
203}
204
205/// Collect patches from executions.
206pub fn collect_patches(executions: &[ToolExecution]) -> Vec<TrackedPatch> {
207    executions.iter().filter_map(|e| e.patch.clone()).collect()
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::runtime::state::AnyStateAction;
    use crate::contracts::runtime::state::StateSpec;
    use crate::contracts::runtime::tool_call::{ToolDescriptor, ToolError};
    use crate::contracts::ToolCallContext;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tirea_contract::testing::TestFixtureState;
    use tirea_state::{PatchSink, Path as TPath, State, TireaResult};

    /// Tool that returns its arguments unchanged.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("echo", "Echo", "Echo the input")
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("echo", args))
        }
    }

    /// Minimal reducible state stored under `counter`; each action is added to `value`.
    #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
    struct EffectCounterState {
        value: i64,
    }

    struct EffectCounterRef;

    impl State for EffectCounterState {
        type Ref<'a> = EffectCounterRef;
        const PATH: &'static str = "counter";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            EffectCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for EffectCounterState {
        type Action = i64;

        fn reduce(&mut self, action: Self::Action) {
            self.value += action;
        }
    }

    /// Tool that changes state only through an effect action (adds 2 to `counter.value`).
    struct EffectTool;

    #[async_trait]
    impl Tool for EffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("effect", "Effect", "Tool returning state actions")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("effect", json!({})))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            Ok(
                crate::contracts::runtime::ToolExecutionEffect::new(ToolResult::success(
                    "effect",
                    json!({}),
                ))
                .with_action(AnyStateAction::new::<EffectCounterState>(2)),
            )
        }
    }

    /// Tool that (disallowed) writes state directly through its `ToolCallContext`.
    struct DirectWriteEffectTool;

    #[async_trait]
    impl Tool for DirectWriteEffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new(
                "direct_write_effect",
                "DirectWriteEffect",
                "writes state directly in execute_effect",
            )
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(
                "direct_write_effect",
                json!({"ok": true}),
            ))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            let state = ctx.state_of::<TestFixtureState>();
            state
                .set_label(Some("direct_write".to_string()))
                .expect("failed to set label");
            Ok(crate::contracts::runtime::ToolExecutionEffect::new(
                ToolResult::success("direct_write_effect", json!({"ok": true})),
            ))
        }
    }

    /// Tool registry containing only `EffectTool`, registered under the name `effect`.
    fn effect_registry() -> HashMap<String, Arc<dyn Tool>> {
        let mut tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tools.insert("effect".to_string(), Arc::new(EffectTool));
        tools
    }

    #[tokio::test]
    async fn test_execute_single_tool_not_found() {
        let call = ToolCall::new("call_1", "nonexistent", json!({}));
        let state = json!({});

        let exec = execute_single_tool(None, &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_execute_single_tool_success() {
        let tool = EchoTool;
        let call = ToolCall::new("call_1", "echo", json!({"msg": "hello"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["msg"], "hello");
    }

    #[tokio::test]
    async fn test_execute_single_tool_applies_state_actions_from_effect() {
        let tool = EffectTool;
        let call = ToolCall::new("call_1", "effect", json!({}));
        let state = json!({"counter": {"value": 1}});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;
        let patch = exec.patch.expect("patch should be emitted");
        let next = apply_patch(&state, patch.patch()).expect("patch should apply");

        assert_eq!(next["counter"]["value"], 3);
    }

    #[tokio::test]
    async fn test_execute_single_tool_rejects_direct_context_writes_in_strict_mode() {
        let tool = DirectWriteEffectTool;
        let call = ToolCall::new("call_1", "direct_write_effect", json!({}));
        let state = json!({});
        let scope = RunConfig::default();

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;
        assert!(exec.result.is_error());
        assert_eq!(
            exec.result.data["error"]["code"],
            json!("tool_context_state_write_not_allowed")
        );
        assert!(exec.patch.is_none());
    }

    #[test]
    fn test_collect_patches() {
        use tirea_state::{path, Op, Patch};

        let executions = vec![
            ToolExecution {
                call: ToolCall::new("1", "a", json!({})),
                result: ToolResult::success("a", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("a"), json!(1))),
                )),
            },
            ToolExecution {
                call: ToolCall::new("2", "b", json!({})),
                result: ToolResult::success("b", json!({})),
                patch: None,
            },
            ToolExecution {
                call: ToolCall::new("3", "c", json!({})),
                result: ToolResult::success("c", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("c"), json!(3))),
                )),
            },
        ];

        let patches = collect_patches(&executions);
        assert_eq!(patches.len(), 2);
    }

    #[tokio::test]
    async fn test_execute_tools_parallel_uses_shared_snapshot() {
        let tools = effect_registry();
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "effect", json!({})),
            ToolCall::new("call_3", "missing", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let execs = execute_tools_parallel(&tools, &calls, &state).await;

        assert_eq!(execs.len(), 3);
        for exec in &execs[..2] {
            assert!(exec.result.is_success());
            let patch = exec.patch.as_ref().expect("patch should be emitted");
            let next = apply_patch(&state, patch.patch()).expect("patch should apply");
            // Each call reduced against the same input snapshot.
            assert_eq!(next["counter"]["value"], 3);
        }
        // Results keep the order of `calls`; the unknown tool is reported as an error.
        assert!(execs[2].result.is_error());
        assert!(execs[2].patch.is_none());
    }

    #[tokio::test]
    async fn test_execute_tools_sequential_threads_state_between_calls() {
        let tools = effect_registry();
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "effect", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let (final_state, execs) = execute_tools_sequential(&tools, &calls, &state).await;

        assert_eq!(execs.len(), 2);
        assert!(execs.iter().all(|e| e.result.is_success()));
        // The second call sees the first call's result: 1 + 2 + 2.
        assert_eq!(final_state["counter"]["value"], 5);
        assert_eq!(collect_patches(&execs).len(), 2);
    }

    #[tokio::test]
    async fn test_tool_execution_error() {
        struct FailingTool;

        #[async_trait]
        impl Tool for FailingTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("failing", "Failing", "Always fails")
            }

            async fn execute(
                &self,
                _args: Value,
                _ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                Err(ToolError::ExecutionFailed(
                    "Intentional failure".to_string(),
                ))
            }
        }

        let tool = FailingTool;
        let call = ToolCall::new("call_1", "failing", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec
            .result
            .message
            .as_ref()
            .unwrap()
            .contains("Intentional failure"));
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_scope_reads() {
        /// Tool that reads user_id from scope and returns it.
        struct ScopeReaderTool;

        #[async_trait]
        impl Tool for ScopeReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("scope_reader", "ScopeReader", "Reads scope values")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let user_id = ctx
                    .config_value("user_id")
                    .and_then(|v| v.as_str())
                    .unwrap_or("unknown");
                Ok(ToolResult::success(
                    "scope_reader",
                    json!({"user_id": user_id}),
                ))
            }
        }

        let mut scope = RunConfig::new();
        scope.set("user_id", "u-42").unwrap();

        let tool = ScopeReaderTool;
        let call = ToolCall::new("call_1", "scope_reader", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["user_id"], "u-42");
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_scope_none() {
        /// Tool that reports whether the run config carries a `user_id` key.
        struct ScopeCheckerTool;

        #[async_trait]
        impl Tool for ScopeCheckerTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("scope_checker", "ScopeChecker", "Checks scope presence")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                // ToolCallContext always provides a scope reference (never None).
                // We verify scope access works by probing for a known key.
                let has_user_id = ctx.config_value("user_id").is_some();
                Ok(ToolResult::success(
                    "scope_checker",
                    json!({"has_scope": true, "has_user_id": has_user_id}),
                ))
            }
        }

        let tool = ScopeCheckerTool;
        let call = ToolCall::new("call_1", "scope_checker", json!({}));
        let state = json!({});

        // Without scope — ToolCallContext still provides a (default-empty) scope
        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, None).await;
        assert_eq!(exec.result.data["has_scope"], true);
        assert_eq!(exec.result.data["has_user_id"], false);

        // With scope (empty)
        let scope = RunConfig::new();
        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;
        assert_eq!(exec.result.data["has_scope"], true);
        assert_eq!(exec.result.data["has_user_id"], false);
    }

    #[tokio::test]
    async fn test_execute_with_scope_sensitive_key() {
        /// Tool that reads a sensitive key from scope.
        struct SensitiveReaderTool;

        #[async_trait]
        impl Tool for SensitiveReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("sensitive", "Sensitive", "Reads sensitive key")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let scope = ctx.run_config();
                let token = scope.value("token").and_then(|v| v.as_str()).unwrap();
                let is_sensitive = scope.is_sensitive("token");
                Ok(ToolResult::success(
                    "sensitive",
                    json!({"token_len": token.len(), "is_sensitive": is_sensitive}),
                ))
            }
        }

        let mut scope = RunConfig::new();
        scope.set_sensitive("token", "super-secret-token").unwrap();

        let tool = SensitiveReaderTool;
        let call = ToolCall::new("call_1", "sensitive", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["token_len"], 18);
        assert_eq!(exec.result.data["is_sensitive"], true);
    }

    // =========================================================================
    // validate_args integration: strict schema blocks invalid args at exec path
    // =========================================================================

    /// Tool with a strict schema — execute should never be reached on invalid args.
    struct StrictSchemaTool {
        executed: std::sync::atomic::AtomicBool,
    }

    #[async_trait]
    impl Tool for StrictSchemaTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("strict", "Strict", "Requires a string 'name'").with_parameters(
                json!({
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    },
                    "required": ["name"]
                }),
            )
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.executed
                .store(true, std::sync::atomic::Ordering::SeqCst);
            Ok(ToolResult::success("strict", args))
        }
    }

    #[tokio::test]
    async fn test_validate_args_blocks_invalid_before_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        // Missing required "name" field
        let call = ToolCall::new("call_1", "strict", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            exec.result.message.as_ref().unwrap().contains("name"),
            "error should mention the missing field"
        );
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    #[tokio::test]
    async fn test_validate_args_passes_valid_to_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": "Alice"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert!(
            tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() should be called for valid args"
        );
    }

    #[tokio::test]
    async fn test_validate_args_wrong_type_blocks_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        // "name" should be string, not integer
        let call = ToolCall::new("call_1", "strict", json!({"name": 42}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }
}