1use crate::contracts::runtime::behavior::AgentBehavior;
4use crate::contracts::runtime::tool_call::ToolCallContext;
5use crate::contracts::runtime::tool_call::{Tool, ToolExecutionEffect, ToolResult};
6pub use crate::contracts::runtime::ToolExecution;
7use crate::contracts::thread::ToolCall;
8use crate::contracts::{reduce_state_actions, AnyStateAction, ScopeContext};
9use futures::future::join_all;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use tirea_contract::RunConfig;
14use tirea_state::{apply_patch, DocCell, Patch, TrackedPatch};
15
16const DIRECT_STATE_WRITE_DENIED_ERROR_CODE: &str = "tool_context_state_write_not_allowed";
17
18pub(crate) fn merge_context_patch_into_effect(
19 call: &ToolCall,
20 _effect: &mut ToolExecutionEffect,
21 context_patch: TrackedPatch,
22) -> Result<(), Box<ToolResult>> {
23 if context_patch.patch().is_empty() {
24 return Ok(());
25 }
26
27 Err(Box::new(ToolResult::error_with_code(
29 &call.name,
30 DIRECT_STATE_WRITE_DENIED_ERROR_CODE,
31 "direct ToolCallContext state writes are disabled; emit ToolExecutionEffect actions instead",
32 )))
33}
34
35pub async fn execute_single_tool(
48 tool: Option<&dyn Tool>,
49 call: &ToolCall,
50 state: &Value,
51) -> ToolExecution {
52 execute_single_tool_with_scope_and_behavior(tool, call, state, None, None).await
53}
54
55pub async fn execute_single_tool_with_scope(
57 tool: Option<&dyn Tool>,
58 call: &ToolCall,
59 state: &Value,
60 scope: Option<&RunConfig>,
61) -> ToolExecution {
62 execute_single_tool_with_scope_and_behavior(tool, call, state, scope, None).await
63}
64
65pub async fn execute_single_tool_with_scope_and_behavior(
67 tool: Option<&dyn Tool>,
68 call: &ToolCall,
69 state: &Value,
70 scope: Option<&RunConfig>,
71 _behavior: Option<&dyn AgentBehavior>,
72) -> ToolExecution {
73 let Some(tool) = tool else {
74 return ToolExecution {
75 call: call.clone(),
76 result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
77 patch: None,
78 };
79 };
80
81 let doc = DocCell::new(state.clone());
83 let ops = Mutex::new(Vec::new());
84 let default_scope = RunConfig::default();
85 let scope = scope.unwrap_or(&default_scope);
86 let pending_messages = Mutex::new(Vec::new());
87 let ctx = ToolCallContext::new(
88 &doc,
89 &ops,
90 &call.id,
91 format!("tool:{}", call.name),
92 scope,
93 &pending_messages,
94 tirea_contract::runtime::activity::NoOpActivityManager::arc(),
95 );
96
97 if let Err(e) = tool.validate_args(&call.arguments) {
99 return ToolExecution {
100 call: call.clone(),
101 result: ToolResult::error(&call.name, e.to_string()),
102 patch: None,
103 };
104 }
105
106 let mut effect = match tool.execute_effect(call.arguments.clone(), &ctx).await {
108 Ok(effect) => effect,
109 Err(e) => ToolExecutionEffect::from(ToolResult::error(&call.name, e.to_string())),
110 };
111
112 let context_patch = ctx.take_patch();
113 if let Err(result) = merge_context_patch_into_effect(call, &mut effect, context_patch) {
114 return ToolExecution {
115 call: call.clone(),
116 result: *result,
117 patch: None,
118 };
119 }
120 let (result, actions) = effect.into_parts();
121 let state_actions: Vec<AnyStateAction> = actions
122 .into_iter()
123 .filter_map(|a| {
124 if a.is_state_action() {
125 a.into_state_action()
126 } else {
127 None
128 }
129 })
130 .collect();
131 let tool_scope_ctx = ScopeContext::for_call(&call.id);
132 let action_patches = match reduce_state_actions(
133 state_actions,
134 state,
135 &format!("tool:{}", call.name),
136 &tool_scope_ctx,
137 ) {
138 Ok(patches) => patches,
139 Err(err) => {
140 return ToolExecution {
141 call: call.clone(),
142 result: ToolResult::error(
143 &call.name,
144 format!("tool state action reduce failed: {err}"),
145 ),
146 patch: None,
147 };
148 }
149 };
150
151 let mut merged_patch = Patch::new();
152 for tracked in action_patches {
153 merged_patch.extend(tracked.patch().clone());
154 }
155
156 let patch = if merged_patch.is_empty() {
157 None
158 } else {
159 Some(TrackedPatch::new(merged_patch).with_source(format!("tool:{}", call.name)))
160 };
161
162 ToolExecution {
163 call: call.clone(),
164 result,
165 patch,
166 }
167}
168
169pub async fn execute_tools_parallel(
171 tools: &HashMap<String, Arc<dyn Tool>>,
172 calls: &[ToolCall],
173 state: &Value,
174) -> Vec<ToolExecution> {
175 let tasks = calls.iter().map(|call| {
176 let tool = tools.get(&call.name).cloned();
177 let state = state.clone();
178 async move { execute_single_tool(tool.as_deref(), call, &state).await }
179 });
180 join_all(tasks).await
181}
182
183pub async fn execute_tools_sequential(
185 tools: &HashMap<String, Arc<dyn Tool>>,
186 calls: &[ToolCall],
187 state: &Value,
188) -> (Value, Vec<ToolExecution>) {
189 let mut state = state.clone();
190 let mut executions = Vec::with_capacity(calls.len());
191
192 for call in calls {
193 let exec = execute_single_tool(tools.get(&call.name).map(Arc::as_ref), call, &state).await;
194 if let Some(patch) = exec.patch.as_ref() {
195 if let Ok(next) = apply_patch(&state, patch.patch()) {
196 state = next;
197 }
198 }
199 executions.push(exec);
200 }
201
202 (state, executions)
203}
204
205pub fn collect_patches(executions: &[ToolExecution]) -> Vec<TrackedPatch> {
207 executions.iter().filter_map(|e| e.patch.clone()).collect()
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::runtime::state::AnyStateAction;
    use crate::contracts::runtime::state::StateSpec;
    use crate::contracts::runtime::tool_call::{ToolDescriptor, ToolError};
    use crate::contracts::ToolCallContext;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tirea_contract::testing::TestFixtureState;
    use tirea_state::{PatchSink, Path as TPath, State, TireaResult};

    /// Returns its arguments unchanged as a successful result.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("echo", "Echo", "Echo the input")
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("echo", args))
        }
    }

    /// Test state stored under `counter`; each `i64` action is added to `value`.
    #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
    struct EffectCounterState {
        value: i64,
    }

    /// Placeholder ref type; these tests never read the counter through a ref.
    struct EffectCounterRef;

    impl State for EffectCounterState {
        type Ref<'a> = EffectCounterRef;
        const PATH: &'static str = "counter";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            EffectCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for EffectCounterState {
        type Action = i64;

        fn reduce(&mut self, action: Self::Action) {
            self.value += action;
        }
    }

    /// Emits a single `+2` counter action through its effect, with no direct writes.
    struct EffectTool;

    #[async_trait]
    impl Tool for EffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("effect", "Effect", "Tool returning state actions")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("effect", json!({})))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            Ok(
                crate::contracts::runtime::ToolExecutionEffect::new(ToolResult::success(
                    "effect",
                    json!({}),
                ))
                .with_action(AnyStateAction::new::<EffectCounterState>(2)),
            )
        }
    }

    /// Writes state directly through the context, which the executor must reject.
    struct DirectWriteEffectTool;

    #[async_trait]
    impl Tool for DirectWriteEffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new(
                "direct_write_effect",
                "DirectWriteEffect",
                "writes state directly in execute_effect",
            )
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(
                "direct_write_effect",
                json!({"ok": true}),
            ))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            let state = ctx.state_of::<TestFixtureState>();
            state
                .set_label(Some("direct_write".to_string()))
                .expect("failed to set label");
            Ok(crate::contracts::runtime::ToolExecutionEffect::new(
                ToolResult::success("direct_write_effect", json!({"ok": true})),
            ))
        }
    }

    /// Builds a tool registry containing `EchoTool` ("echo") and `EffectTool` ("effect").
    fn test_tools() -> HashMap<String, Arc<dyn Tool>> {
        let mut tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tools.insert("echo".to_string(), Arc::new(EchoTool));
        tools.insert("effect".to_string(), Arc::new(EffectTool));
        tools
    }

    #[tokio::test]
    async fn test_execute_single_tool_not_found() {
        let call = ToolCall::new("call_1", "nonexistent", json!({}));
        let state = json!({});

        let exec = execute_single_tool(None, &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_execute_single_tool_success() {
        let tool = EchoTool;
        let call = ToolCall::new("call_1", "echo", json!({"msg": "hello"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["msg"], "hello");
    }

    #[tokio::test]
    async fn test_execute_single_tool_applies_state_actions_from_effect() {
        let tool = EffectTool;
        let call = ToolCall::new("call_1", "effect", json!({}));
        let state = json!({"counter": {"value": 1}});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;
        let patch = exec.patch.expect("patch should be emitted");
        let next = apply_patch(&state, patch.patch()).expect("patch should apply");

        assert_eq!(next["counter"]["value"], 3);
    }

    #[tokio::test]
    async fn test_execute_single_tool_rejects_direct_context_writes_in_strict_mode() {
        let tool = DirectWriteEffectTool;
        let call = ToolCall::new("call_1", "direct_write_effect", json!({}));
        let state = json!({});
        let scope = RunConfig::default();

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;
        assert!(exec.result.is_error());
        assert_eq!(
            exec.result.data["error"]["code"],
            json!("tool_context_state_write_not_allowed")
        );
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_collect_patches() {
        use tirea_state::{path, Op, Patch};

        let executions = vec![
            ToolExecution {
                call: ToolCall::new("1", "a", json!({})),
                result: ToolResult::success("a", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("a"), json!(1))),
                )),
            },
            ToolExecution {
                call: ToolCall::new("2", "b", json!({})),
                result: ToolResult::success("b", json!({})),
                patch: None,
            },
            ToolExecution {
                call: ToolCall::new("3", "c", json!({})),
                result: ToolResult::success("c", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("c"), json!(3))),
                )),
            },
        ];

        let patches = collect_patches(&executions);
        assert_eq!(patches.len(), 2);

        // Patches keep execution order: "a" first, then "c".
        let first = apply_patch(&json!({}), patches[0].patch()).expect("patch should apply");
        let second = apply_patch(&json!({}), patches[1].patch()).expect("patch should apply");
        assert_eq!(first, json!({"a": 1}));
        assert_eq!(second, json!({"c": 3}));
    }

    #[tokio::test]
    async fn test_execute_tools_sequential_threads_state_between_calls() {
        let tools = test_tools();
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "effect", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let (final_state, executions) = execute_tools_sequential(&tools, &calls, &state).await;

        assert_eq!(executions.len(), 2);
        assert!(executions.iter().all(|exec| exec.result.is_success()));
        assert!(executions.iter().all(|exec| exec.patch.is_some()));
        // The second call saw the first call's output: 1 + 2 + 2.
        assert_eq!(final_state["counter"]["value"], 5);
        // The caller's snapshot is untouched.
        assert_eq!(state["counter"]["value"], 1);
    }

    #[tokio::test]
    async fn test_execute_tools_parallel_uses_same_snapshot_and_preserves_order() {
        let tools = test_tools();
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "effect", json!({})),
            ToolCall::new("call_3", "nonexistent", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let executions = execute_tools_parallel(&tools, &calls, &state).await;

        assert_eq!(executions.len(), 3);
        assert_eq!(executions[0].call.id, "call_1");
        assert_eq!(executions[1].call.id, "call_2");
        assert_eq!(executions[2].call.id, "call_3");

        // Both effect calls were computed from the same input snapshot.
        for exec in &executions[..2] {
            assert!(exec.result.is_success());
            let patch = exec.patch.as_ref().expect("patch should be emitted");
            let next = apply_patch(&state, patch.patch()).expect("patch should apply");
            assert_eq!(next["counter"]["value"], 3);
        }

        // An unknown tool fails only its own call.
        assert!(executions[2].result.is_error());
        assert!(executions[2].patch.is_none());
    }

    #[tokio::test]
    async fn test_tool_execution_error() {
        struct FailingTool;

        #[async_trait]
        impl Tool for FailingTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("failing", "Failing", "Always fails")
            }

            async fn execute(
                &self,
                _args: Value,
                _ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                Err(ToolError::ExecutionFailed(
                    "Intentional failure".to_string(),
                ))
            }
        }

        let tool = FailingTool;
        let call = ToolCall::new("call_1", "failing", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec
            .result
            .message
            .as_ref()
            .unwrap()
            .contains("Intentional failure"));
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_scope_reads() {
        struct ScopeReaderTool;

        #[async_trait]
        impl Tool for ScopeReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("scope_reader", "ScopeReader", "Reads scope values")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let user_id = ctx
                    .config_value("user_id")
                    .and_then(|v| v.as_str())
                    .unwrap_or("unknown");
                Ok(ToolResult::success(
                    "scope_reader",
                    json!({"user_id": user_id}),
                ))
            }
        }

        let mut scope = RunConfig::new();
        scope.set("user_id", "u-42").unwrap();

        let tool = ScopeReaderTool;
        let call = ToolCall::new("call_1", "scope_reader", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["user_id"], "u-42");
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_scope_none() {
        struct ScopeCheckerTool;

        #[async_trait]
        impl Tool for ScopeCheckerTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("scope_checker", "ScopeChecker", "Checks scope presence")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let has_user_id = ctx.config_value("user_id").is_some();
                Ok(ToolResult::success(
                    "scope_checker",
                    json!({"has_user_id": has_user_id}),
                ))
            }
        }

        let tool = ScopeCheckerTool;
        let call = ToolCall::new("call_1", "scope_checker", json!({}));
        let state = json!({});

        // `None` falls back to `RunConfig::default()`; the tool still runs.
        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, None).await;
        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["has_user_id"], false);

        let scope = RunConfig::new();
        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;
        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["has_user_id"], false);
    }

    #[tokio::test]
    async fn test_execute_with_scope_sensitive_key() {
        struct SensitiveReaderTool;

        #[async_trait]
        impl Tool for SensitiveReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("sensitive", "Sensitive", "Reads sensitive key")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let scope = ctx.run_config();
                let token = scope.value("token").and_then(|v| v.as_str()).unwrap();
                let is_sensitive = scope.is_sensitive("token");
                Ok(ToolResult::success(
                    "sensitive",
                    json!({"token_len": token.len(), "is_sensitive": is_sensitive}),
                ))
            }
        }

        let mut scope = RunConfig::new();
        scope.set_sensitive("token", "super-secret-token").unwrap();

        let tool = SensitiveReaderTool;
        let call = ToolCall::new("call_1", "sensitive", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["token_len"], 18);
        assert_eq!(exec.result.data["is_sensitive"], true);
    }

    /// Requires a string `name` argument and records whether `execute` ran.
    struct StrictSchemaTool {
        executed: std::sync::atomic::AtomicBool,
    }

    #[async_trait]
    impl Tool for StrictSchemaTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("strict", "Strict", "Requires a string 'name'").with_parameters(
                json!({
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    },
                    "required": ["name"]
                }),
            )
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.executed
                .store(true, std::sync::atomic::Ordering::SeqCst);
            Ok(ToolResult::success("strict", args))
        }
    }

    #[tokio::test]
    async fn test_validate_args_blocks_invalid_before_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            exec.result.message.as_ref().unwrap().contains("name"),
            "error should mention the missing field"
        );
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    #[tokio::test]
    async fn test_validate_args_passes_valid_to_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": "Alice"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert!(
            tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() should be called for valid args"
        );
    }

    #[tokio::test]
    async fn test_validate_args_wrong_type_blocks_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": 42}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }
}