Skip to main content

tirea_agent_loop/engine/
tool_execution.rs

1//! Tool execution utilities.
2
3pub use crate::contracts::runtime::ToolExecution;
4use crate::contracts::thread::ToolCall;
5use crate::contracts::tool::context::ToolCallContext;
6use crate::contracts::tool::{Tool, ToolResult};
7use futures::future::join_all;
8use serde_json::Value;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11use tirea_contract::RunConfig;
12use tirea_state::{apply_patch, DocCell, TrackedPatch};
13
14/// Execute a single tool call.
15///
16/// This function:
17/// 1. Creates a Context from the state snapshot
18/// 2. Executes the tool
19/// 3. Extracts any state changes as a TrackedPatch
20///
21/// # Arguments
22///
23/// * `tool` - The tool to execute (or None if not found)
24/// * `call` - The tool call with id, name, and arguments
25/// * `state` - The current state snapshot (read-only)
26pub async fn execute_single_tool(
27    tool: Option<&dyn Tool>,
28    call: &ToolCall,
29    state: &Value,
30) -> ToolExecution {
31    execute_single_tool_with_scope(tool, call, state, None).await
32}
33
34/// Execute a single tool call with an optional scope context.
35pub async fn execute_single_tool_with_scope(
36    tool: Option<&dyn Tool>,
37    call: &ToolCall,
38    state: &Value,
39    scope: Option<&RunConfig>,
40) -> ToolExecution {
41    let Some(tool) = tool else {
42        return ToolExecution {
43            call: call.clone(),
44            result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
45            patch: None,
46        };
47    };
48
49    // Create context for this tool call
50    let doc = DocCell::new(state.clone());
51    let ops = Mutex::new(Vec::new());
52    let default_scope = RunConfig::default();
53    let scope = scope.unwrap_or(&default_scope);
54    let pending_messages = Mutex::new(Vec::new());
55    let ctx = ToolCallContext::new(
56        &doc,
57        &ops,
58        &call.id,
59        format!("tool:{}", call.name),
60        scope,
61        &pending_messages,
62        None,
63    );
64
65    // Validate arguments against the tool's JSON Schema
66    if let Err(e) = tool.validate_args(&call.arguments) {
67        return ToolExecution {
68            call: call.clone(),
69            result: ToolResult::error(&call.name, e.to_string()),
70            patch: None,
71        };
72    }
73
74    // Execute the tool
75    let result = match tool.execute(call.arguments.clone(), &ctx).await {
76        Ok(r) => r,
77        Err(e) => ToolResult::error(&call.name, e.to_string()),
78    };
79
80    // Extract any state changes
81    let patch = ctx.take_patch();
82    let patch = if patch.patch().is_empty() {
83        None
84    } else {
85        Some(patch)
86    };
87
88    ToolExecution {
89        call: call.clone(),
90        result,
91        patch,
92    }
93}
94
95/// Execute tool calls in parallel using the same state snapshot for every call.
96pub async fn execute_tools_parallel(
97    tools: &HashMap<String, Arc<dyn Tool>>,
98    calls: &[ToolCall],
99    state: &Value,
100) -> Vec<ToolExecution> {
101    let tasks = calls.iter().map(|call| {
102        let tool = tools.get(&call.name).cloned();
103        let state = state.clone();
104        async move { execute_single_tool(tool.as_deref(), call, &state).await }
105    });
106    join_all(tasks).await
107}
108
109/// Execute tool calls sequentially, applying each resulting patch before the next call.
110pub async fn execute_tools_sequential(
111    tools: &HashMap<String, Arc<dyn Tool>>,
112    calls: &[ToolCall],
113    state: &Value,
114) -> (Value, Vec<ToolExecution>) {
115    let mut state = state.clone();
116    let mut executions = Vec::with_capacity(calls.len());
117
118    for call in calls {
119        let exec = execute_single_tool(tools.get(&call.name).map(Arc::as_ref), call, &state).await;
120        if let Some(patch) = exec.patch.as_ref() {
121            if let Ok(next) = apply_patch(&state, patch.patch()) {
122                state = next;
123            }
124        }
125        executions.push(exec);
126    }
127
128    (state, executions)
129}
130
131/// Collect patches from executions.
132pub fn collect_patches(executions: &[ToolExecution]) -> Vec<TrackedPatch> {
133    executions.iter().filter_map(|e| e.patch.clone()).collect()
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::tool::{ToolDescriptor, ToolError};
    use crate::contracts::ToolCallContext;
    use async_trait::async_trait;
    use serde_json::json;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Tool that always succeeds and returns its arguments unchanged.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("echo", "Echo", "Echo the input")
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("echo", args))
        }
    }

    /// Build a registry that contains only [`EchoTool`], registered as "echo".
    fn echo_registry() -> HashMap<String, Arc<dyn Tool>> {
        let echo: Arc<dyn Tool> = Arc::new(EchoTool);
        let mut tools = HashMap::new();
        tools.insert("echo".to_string(), echo);
        tools
    }

    #[tokio::test]
    async fn test_execute_single_tool_not_found() {
        let call = ToolCall::new("call_1", "nonexistent", json!({}));
        let state = json!({});

        let exec = execute_single_tool(None, &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_execute_single_tool_success() {
        let tool = EchoTool;
        let call = ToolCall::new("call_1", "echo", json!({"msg": "hello"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["msg"], "hello");
    }

    // `collect_patches` is synchronous, so no async runtime is needed here.
    #[test]
    fn test_collect_patches() {
        use tirea_state::{path, Op, Patch};

        let executions = vec![
            ToolExecution {
                call: ToolCall::new("1", "a", json!({})),
                result: ToolResult::success("a", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("a"), json!(1))),
                )),
            },
            ToolExecution {
                call: ToolCall::new("2", "b", json!({})),
                result: ToolResult::success("b", json!({})),
                patch: None,
            },
            ToolExecution {
                call: ToolCall::new("3", "c", json!({})),
                result: ToolResult::success("c", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("c"), json!(3))),
                )),
            },
        ];

        let patches = collect_patches(&executions);
        assert_eq!(patches.len(), 2);
    }

    #[tokio::test]
    async fn test_tool_execution_error() {
        /// Tool whose `execute` always returns an error.
        struct FailingTool;

        #[async_trait]
        impl Tool for FailingTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("failing", "Failing", "Always fails")
            }

            async fn execute(
                &self,
                _args: Value,
                _ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                Err(ToolError::ExecutionFailed(
                    "Intentional failure".to_string(),
                ))
            }
        }

        let tool = FailingTool;
        let call = ToolCall::new("call_1", "failing", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec
            .result
            .message
            .as_ref()
            .unwrap()
            .contains("Intentional failure"));
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_scope_reads() {
        /// Tool that reads user_id from scope and returns it.
        struct ScopeReaderTool;

        #[async_trait]
        impl Tool for ScopeReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("scope_reader", "ScopeReader", "Reads scope values")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let user_id = ctx
                    .config_value("user_id")
                    .and_then(|v| v.as_str())
                    .unwrap_or("unknown");
                Ok(ToolResult::success(
                    "scope_reader",
                    json!({"user_id": user_id}),
                ))
            }
        }

        let mut scope = RunConfig::new();
        scope.set("user_id", "u-42").unwrap();

        let tool = ScopeReaderTool;
        let call = ToolCall::new("call_1", "scope_reader", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["user_id"], "u-42");
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_scope_none() {
        /// Tool that reports whether `user_id` is present in the scope it is given.
        struct ScopeCheckerTool;

        #[async_trait]
        impl Tool for ScopeCheckerTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("scope_checker", "ScopeChecker", "Checks scope presence")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                // ToolCallContext always provides a scope reference (never None).
                // We verify scope access works by probing for a known key.
                let has_user_id = ctx.config_value("user_id").is_some();
                Ok(ToolResult::success(
                    "scope_checker",
                    json!({"has_scope": true, "has_user_id": has_user_id}),
                ))
            }
        }

        let tool = ScopeCheckerTool;
        let call = ToolCall::new("call_1", "scope_checker", json!({}));
        let state = json!({});

        // Without scope — ToolCallContext still provides a (default-empty) scope
        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, None).await;
        assert_eq!(exec.result.data["has_scope"], true);
        assert_eq!(exec.result.data["has_user_id"], false);

        // With scope (empty)
        let scope = RunConfig::new();
        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;
        assert_eq!(exec.result.data["has_scope"], true);
        assert_eq!(exec.result.data["has_user_id"], false);
    }

    #[tokio::test]
    async fn test_execute_with_scope_sensitive_key() {
        /// Tool that reads a sensitive key from scope.
        struct SensitiveReaderTool;

        #[async_trait]
        impl Tool for SensitiveReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("sensitive", "Sensitive", "Reads sensitive key")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let scope = ctx.run_config();
                let token = scope.value("token").and_then(|v| v.as_str()).unwrap();
                let is_sensitive = scope.is_sensitive("token");
                Ok(ToolResult::success(
                    "sensitive",
                    json!({"token_len": token.len(), "is_sensitive": is_sensitive}),
                ))
            }
        }

        let mut scope = RunConfig::new();
        scope.set_sensitive("token", "super-secret-token").unwrap();

        let tool = SensitiveReaderTool;
        let call = ToolCall::new("call_1", "sensitive", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["token_len"], 18);
        assert_eq!(exec.result.data["is_sensitive"], true);
    }

    // =========================================================================
    // Multi-call helpers: parallel and sequential execution
    // =========================================================================

    #[tokio::test]
    async fn test_execute_tools_parallel_keeps_call_order() {
        let tools = echo_registry();
        let calls = vec![
            ToolCall::new("call_1", "echo", json!({"n": 1})),
            ToolCall::new("call_2", "missing", json!({})),
            ToolCall::new("call_3", "echo", json!({"n": 3})),
        ];
        let state = json!({"untouched": true});

        let execs = execute_tools_parallel(&tools, &calls, &state).await;

        assert_eq!(execs.len(), 3);
        assert!(execs[0].result.is_success());
        assert_eq!(execs[0].result.data["n"], 1);
        assert!(
            execs[1].result.is_error(),
            "a call to an unregistered tool must yield an error result"
        );
        assert!(execs[2].result.is_success());
        assert_eq!(execs[2].result.data["n"], 3);
    }

    #[tokio::test]
    async fn test_execute_tools_parallel_empty_calls() {
        let tools = echo_registry();
        let state = json!({});

        let execs = execute_tools_parallel(&tools, &[], &state).await;

        assert!(execs.is_empty());
    }

    #[tokio::test]
    async fn test_execute_tools_sequential_without_patches_keeps_state() {
        let tools = echo_registry();
        let calls = vec![
            ToolCall::new("call_1", "echo", json!({"n": 1})),
            ToolCall::new("call_2", "missing", json!({})),
            ToolCall::new("call_3", "echo", json!({"n": 2})),
        ];
        let state = json!({"counter": 7});

        let (final_state, execs) = execute_tools_sequential(&tools, &calls, &state).await;

        // None of these calls writes state, so the snapshot must come back unchanged.
        assert_eq!(final_state, state);
        assert_eq!(execs.len(), 3);
        assert!(execs[0].result.is_success());
        assert!(execs[1].result.is_error());
        assert_eq!(execs[2].result.data["n"], 2);
        assert!(collect_patches(&execs).is_empty());
    }

    // =========================================================================
    // validate_args integration: strict schema blocks invalid args at exec path
    // =========================================================================

    /// Tool with a strict schema — execute should never be reached on invalid args.
    struct StrictSchemaTool {
        /// Set to `true` the first time `execute` runs.
        executed: AtomicBool,
    }

    #[async_trait]
    impl Tool for StrictSchemaTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("strict", "Strict", "Requires a string 'name'").with_parameters(
                json!({
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    },
                    "required": ["name"]
                }),
            )
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.executed.store(true, Ordering::SeqCst);
            Ok(ToolResult::success("strict", args))
        }
    }

    #[tokio::test]
    async fn test_validate_args_blocks_invalid_before_execute() {
        let tool = StrictSchemaTool {
            executed: AtomicBool::new(false),
        };
        // Missing required "name" field
        let call = ToolCall::new("call_1", "strict", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            exec.result.message.as_ref().unwrap().contains("name"),
            "error should mention the missing field"
        );
        assert!(
            !tool.executed.load(Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    #[tokio::test]
    async fn test_validate_args_passes_valid_to_execute() {
        let tool = StrictSchemaTool {
            executed: AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": "Alice"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert!(
            tool.executed.load(Ordering::SeqCst),
            "execute() should be called for valid args"
        );
    }

    #[tokio::test]
    async fn test_validate_args_wrong_type_blocks_execute() {
        let tool = StrictSchemaTool {
            executed: AtomicBool::new(false),
        };
        // "name" should be string, not integer
        let call = ToolCall::new("call_1", "strict", json!({"name": 42}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            !tool.executed.load(Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }
}