1use crate::contracts::runtime::behavior::AgentBehavior;
4use crate::contracts::runtime::tool_call::ToolCallContext;
5use crate::contracts::runtime::tool_call::{Tool, ToolExecutionEffect, ToolResult};
6pub use crate::contracts::runtime::ToolExecution;
7use crate::contracts::thread::ToolCall;
8use crate::contracts::{reduce_state_actions, AnyStateAction, ScopeContext};
9use futures::future::join_all;
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, Mutex};
13use tirea_contract::RunConfig;
14use tirea_state::{apply_patch, DocCell, Patch, TrackedPatch};
15
16const DIRECT_STATE_WRITE_DENIED_ERROR_CODE: &str = "tool_context_state_write_not_allowed";
17
18pub(crate) fn merge_context_patch_into_effect(
19 call: &ToolCall,
20 _effect: &mut ToolExecutionEffect,
21 context_patch: TrackedPatch,
22) -> Result<(), Box<ToolResult>> {
23 if context_patch.patch().is_empty() {
24 return Ok(());
25 }
26
27 Err(Box::new(ToolResult::error_with_code(
29 &call.name,
30 DIRECT_STATE_WRITE_DENIED_ERROR_CODE,
31 "direct ToolCallContext state writes are disabled; emit ToolExecutionEffect actions instead",
32 )))
33}
34
35pub async fn execute_single_tool(
48 tool: Option<&dyn Tool>,
49 call: &ToolCall,
50 state: &Value,
51) -> ToolExecution {
52 execute_single_tool_with_scope_and_behavior(tool, call, state, None, None).await
53}
54
55pub async fn execute_single_tool_with_scope(
57 tool: Option<&dyn Tool>,
58 call: &ToolCall,
59 state: &Value,
60 scope: Option<&RunConfig>,
61) -> ToolExecution {
62 execute_single_tool_with_scope_and_behavior(tool, call, state, scope, None).await
63}
64
65pub async fn execute_single_tool_with_scope_and_behavior(
67 tool: Option<&dyn Tool>,
68 call: &ToolCall,
69 state: &Value,
70 scope: Option<&RunConfig>,
71 _behavior: Option<&dyn AgentBehavior>,
72) -> ToolExecution {
73 let Some(tool) = tool else {
74 return ToolExecution {
75 call: call.clone(),
76 result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
77 patch: None,
78 };
79 };
80
81 let doc = DocCell::new(state.clone());
83 let ops = Mutex::new(Vec::new());
84 let default_scope = RunConfig::default();
85 let scope = scope.unwrap_or(&default_scope);
86 let pending_messages = Mutex::new(Vec::new());
87 let ctx = ToolCallContext::new(
88 &doc,
89 &ops,
90 &call.id,
91 format!("tool:{}", call.name),
92 scope,
93 &pending_messages,
94 tirea_contract::runtime::activity::NoOpActivityManager::arc(),
95 );
96
97 if let Err(e) = tool.validate_args(&call.arguments) {
99 return ToolExecution {
100 call: call.clone(),
101 result: ToolResult::error(&call.name, e.to_string()),
102 patch: None,
103 };
104 }
105
106 let mut effect = match tool.execute_effect(call.arguments.clone(), &ctx).await {
108 Ok(effect) => effect,
109 Err(e) => ToolExecutionEffect::from(ToolResult::error(&call.name, e.to_string())),
110 };
111
112 let context_patch = ctx.take_patch();
113 if let Err(result) = merge_context_patch_into_effect(call, &mut effect, context_patch) {
114 return ToolExecution {
115 call: call.clone(),
116 result: *result,
117 patch: None,
118 };
119 }
120 let (result, actions) = effect.into_parts();
121 let state_actions: Vec<AnyStateAction> = actions
122 .into_iter()
123 .filter_map(|a| {
124 if a.is_state_action() {
125 a.into_state_action()
126 } else {
127 None
128 }
129 })
130 .collect();
131
132 let tool_scope_ctx = ScopeContext::for_call(&call.id);
133 let action_patches = match reduce_state_actions(
134 state_actions,
135 state,
136 &format!("tool:{}", call.name),
137 &tool_scope_ctx,
138 ) {
139 Ok(patches) => patches,
140 Err(err) => {
141 return ToolExecution {
142 call: call.clone(),
143 result: ToolResult::error(
144 &call.name,
145 format!("tool state action reduce failed: {err}"),
146 ),
147 patch: None,
148 };
149 }
150 };
151
152 let mut merged_patch = Patch::new();
153 for tracked in action_patches {
154 merged_patch.extend(tracked.patch().clone());
155 }
156
157 let patch = if merged_patch.is_empty() {
158 None
159 } else {
160 Some(TrackedPatch::new(merged_patch).with_source(format!("tool:{}", call.name)))
161 };
162
163 ToolExecution {
164 call: call.clone(),
165 result,
166 patch,
167 }
168}
169
170pub async fn execute_tools_parallel(
172 tools: &HashMap<String, Arc<dyn Tool>>,
173 calls: &[ToolCall],
174 state: &Value,
175) -> Vec<ToolExecution> {
176 let tasks = calls.iter().map(|call| {
177 let tool = tools.get(&call.name).cloned();
178 let state = state.clone();
179 async move { execute_single_tool(tool.as_deref(), call, &state).await }
180 });
181 join_all(tasks).await
182}
183
184pub async fn execute_tools_sequential(
186 tools: &HashMap<String, Arc<dyn Tool>>,
187 calls: &[ToolCall],
188 state: &Value,
189) -> (Value, Vec<ToolExecution>) {
190 let mut state = state.clone();
191 let mut executions = Vec::with_capacity(calls.len());
192
193 for call in calls {
194 let exec = execute_single_tool(tools.get(&call.name).map(Arc::as_ref), call, &state).await;
195 if let Some(patch) = exec.patch.as_ref() {
196 if let Ok(next) = apply_patch(&state, patch.patch()) {
197 state = next;
198 }
199 }
200 executions.push(exec);
201 }
202
203 (state, executions)
204}
205
206pub fn collect_patches(executions: &[ToolExecution]) -> Vec<TrackedPatch> {
208 executions.iter().filter_map(|e| e.patch.clone()).collect()
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contracts::runtime::state::AnyStateAction;
    use crate::contracts::runtime::state::StateSpec;
    use crate::contracts::runtime::tool_call::{ToolDescriptor, ToolError};
    use crate::contracts::ToolCallContext;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use tirea_contract::testing::TestFixtureState;
    use tirea_state::{PatchSink, Path as TPath, State, TireaResult};

    /// Tool that returns its arguments unchanged as the success payload.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("echo", "Echo", "Echo the input")
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("echo", args))
        }
    }

    /// Test state slice stored at `counter`; each `i64` action is added to `value`.
    #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
    struct EffectCounterState {
        value: i64,
    }

    /// Unit ref type returned by `EffectCounterState::state_ref`; it exposes no accessors.
    struct EffectCounterRef;

    impl State for EffectCounterState {
        type Ref<'a> = EffectCounterRef;
        const PATH: &'static str = "counter";

        fn state_ref<'a>(_: &'a DocCell, _: TPath, _: PatchSink<'a>) -> Self::Ref<'a> {
            EffectCounterRef
        }

        fn from_value(value: &Value) -> TireaResult<Self> {
            // A missing slice is read as the default counter.
            if value.is_null() {
                return Ok(Self::default());
            }
            serde_json::from_value(value.clone()).map_err(tirea_state::TireaError::Serialization)
        }

        fn to_value(&self) -> TireaResult<Value> {
            serde_json::to_value(self).map_err(tirea_state::TireaError::Serialization)
        }
    }

    impl StateSpec for EffectCounterState {
        type Action = i64;

        fn reduce(&mut self, action: Self::Action) {
            self.value += action;
        }
    }

    /// Tool whose effect carries a single `+2` action for `EffectCounterState`.
    struct EffectTool;

    #[async_trait]
    impl Tool for EffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("effect", "Effect", "Tool returning state actions")
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success("effect", json!({})))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            Ok(
                crate::contracts::runtime::ToolExecutionEffect::new(ToolResult::success(
                    "effect",
                    json!({}),
                ))
                .with_action(AnyStateAction::new::<EffectCounterState>(2)),
            )
        }
    }

    /// Tool that calls `set_label` on `TestFixtureState` through the context
    /// instead of returning an action. Used to exercise direct-write rejection.
    struct DirectWriteEffectTool;

    #[async_trait]
    impl Tool for DirectWriteEffectTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new(
                "direct_write_effect",
                "DirectWriteEffect",
                "writes state directly in execute_effect",
            )
        }

        async fn execute(
            &self,
            _args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult::success(
                "direct_write_effect",
                json!({"ok": true}),
            ))
        }

        async fn execute_effect(
            &self,
            _args: Value,
            ctx: &ToolCallContext<'_>,
        ) -> Result<crate::contracts::runtime::ToolExecutionEffect, ToolError> {
            let state = ctx.state_of::<TestFixtureState>();
            state
                .set_label(Some("direct_write".to_string()))
                .expect("failed to set label");
            Ok(crate::contracts::runtime::ToolExecutionEffect::new(
                ToolResult::success("direct_write_effect", json!({"ok": true})),
            ))
        }
    }

    #[tokio::test]
    async fn test_execute_single_tool_not_found() {
        let call = ToolCall::new("call_1", "nonexistent", json!({}));
        let state = json!({});

        let exec = execute_single_tool(None, &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_execute_single_tool_success() {
        let tool = EchoTool;
        let call = ToolCall::new("call_1", "echo", json!({"msg": "hello"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["msg"], "hello");
    }

    #[tokio::test]
    async fn test_execute_single_tool_applies_state_actions_from_effect() {
        let tool = EffectTool;
        let call = ToolCall::new("call_1", "effect", json!({}));
        let state = json!({"counter": {"value": 1}});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;
        let patch = exec.patch.expect("patch should be emitted");
        let next = apply_patch(&state, patch.patch()).expect("patch should apply");

        assert_eq!(next["counter"]["value"], 3);
    }

    #[tokio::test]
    async fn test_execute_single_tool_rejects_direct_context_writes_in_strict_mode() {
        let tool = DirectWriteEffectTool;
        let call = ToolCall::new("call_1", "direct_write_effect", json!({}));
        let state = json!({});
        let scope = RunConfig::default();

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;
        assert!(exec.result.is_error());
        assert_eq!(
            exec.result.data["error"]["code"],
            json!("tool_context_state_write_not_allowed")
        );
        assert!(exec.patch.is_none());
    }

    #[tokio::test]
    async fn test_collect_patches() {
        use tirea_state::{path, Op, Patch};

        let executions = vec![
            ToolExecution {
                call: ToolCall::new("1", "a", json!({})),
                result: ToolResult::success("a", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("a"), json!(1))),
                )),
            },
            ToolExecution {
                call: ToolCall::new("2", "b", json!({})),
                result: ToolResult::success("b", json!({})),
                patch: None,
            },
            ToolExecution {
                call: ToolCall::new("3", "c", json!({})),
                result: ToolResult::success("c", json!({})),
                patch: Some(TrackedPatch::new(
                    Patch::new().with_op(Op::set(path!("c"), json!(3))),
                )),
            },
        ];

        let patches = collect_patches(&executions);
        assert_eq!(patches.len(), 2);
    }

    #[tokio::test]
    async fn test_tool_execution_error() {
        struct FailingTool;

        #[async_trait]
        impl Tool for FailingTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("failing", "Failing", "Always fails")
            }

            async fn execute(
                &self,
                _args: Value,
                _ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                Err(ToolError::ExecutionFailed(
                    "Intentional failure".to_string(),
                ))
            }
        }

        let tool = FailingTool;
        let call = ToolCall::new("call_1", "failing", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(exec
            .result
            .message
            .as_ref()
            .unwrap()
            .contains("Intentional failure"));
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_scope_reads() {
        /// Tool that reports the `user_id` config value it sees, or "unknown".
        struct ScopeReaderTool;

        #[async_trait]
        impl Tool for ScopeReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("scope_reader", "ScopeReader", "Reads scope values")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let user_id = ctx
                    .config_value("user_id")
                    .and_then(|v| v.as_str())
                    .unwrap_or("unknown");
                Ok(ToolResult::success(
                    "scope_reader",
                    json!({"user_id": user_id}),
                ))
            }
        }

        let mut scope = RunConfig::new();
        scope.set("user_id", "u-42").unwrap();

        let tool = ScopeReaderTool;
        let call = ToolCall::new("call_1", "scope_reader", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["user_id"], "u-42");
    }

    #[tokio::test]
    async fn test_execute_single_tool_with_scope_none() {
        /// Tool that reports whether a `user_id` config value is present.
        struct ScopeCheckerTool;

        #[async_trait]
        impl Tool for ScopeCheckerTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("scope_checker", "ScopeChecker", "Checks scope presence")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let has_user_id = ctx.config_value("user_id").is_some();
                // The context always carries a RunConfig (a default one when the
                // caller passes none), so `has_scope` is constant here.
                Ok(ToolResult::success(
                    "scope_checker",
                    json!({"has_scope": true, "has_user_id": has_user_id}),
                ))
            }
        }

        let tool = ScopeCheckerTool;
        let call = ToolCall::new("call_1", "scope_checker", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, None).await;
        assert_eq!(exec.result.data["has_scope"], true);
        assert_eq!(exec.result.data["has_user_id"], false);

        let scope = RunConfig::new();
        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;
        assert_eq!(exec.result.data["has_scope"], true);
        assert_eq!(exec.result.data["has_user_id"], false);
    }

    #[tokio::test]
    async fn test_execute_with_scope_sensitive_key() {
        /// Tool that reads the `token` config value and reports its length and sensitivity.
        struct SensitiveReaderTool;

        #[async_trait]
        impl Tool for SensitiveReaderTool {
            fn descriptor(&self) -> ToolDescriptor {
                ToolDescriptor::new("sensitive", "Sensitive", "Reads sensitive key")
            }

            async fn execute(
                &self,
                _args: Value,
                ctx: &ToolCallContext<'_>,
            ) -> Result<ToolResult, ToolError> {
                let scope = ctx.run_config();
                let token = scope.value("token").and_then(|v| v.as_str()).unwrap();
                let is_sensitive = scope.is_sensitive("token");
                Ok(ToolResult::success(
                    "sensitive",
                    json!({"token_len": token.len(), "is_sensitive": is_sensitive}),
                ))
            }
        }

        let mut scope = RunConfig::new();
        scope.set_sensitive("token", "super-secret-token").unwrap();

        let tool = SensitiveReaderTool;
        let call = ToolCall::new("call_1", "sensitive", json!({}));
        let state = json!({});

        let exec = execute_single_tool_with_scope(Some(&tool), &call, &state, Some(&scope)).await;

        assert!(exec.result.is_success());
        assert_eq!(exec.result.data["token_len"], 18);
        assert_eq!(exec.result.data["is_sensitive"], true);
    }

    /// Tool with a JSON schema requiring a string `name`.
    struct StrictSchemaTool {
        /// Set to `true` once `execute` runs, so tests can tell whether
        /// validation stopped the call.
        executed: std::sync::atomic::AtomicBool,
    }

    #[async_trait]
    impl Tool for StrictSchemaTool {
        fn descriptor(&self) -> ToolDescriptor {
            ToolDescriptor::new("strict", "Strict", "Requires a string 'name'").with_parameters(
                json!({
                    "type": "object",
                    "properties": {
                        "name": { "type": "string" }
                    },
                    "required": ["name"]
                }),
            )
        }

        async fn execute(
            &self,
            args: Value,
            _ctx: &ToolCallContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.executed
                .store(true, std::sync::atomic::Ordering::SeqCst);
            Ok(ToolResult::success("strict", args))
        }
    }

    #[tokio::test]
    async fn test_validate_args_blocks_invalid_before_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        // Required field `name` is missing.
        let call = ToolCall::new("call_1", "strict", json!({}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            exec.result.message.as_ref().unwrap().contains("name"),
            "error should mention the missing field"
        );
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    #[tokio::test]
    async fn test_validate_args_passes_valid_to_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        let call = ToolCall::new("call_1", "strict", json!({"name": "Alice"}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_success());
        assert!(
            tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() should be called for valid args"
        );
    }

    #[tokio::test]
    async fn test_validate_args_wrong_type_blocks_execute() {
        let tool = StrictSchemaTool {
            executed: std::sync::atomic::AtomicBool::new(false),
        };
        // `name` is present but has the wrong type.
        let call = ToolCall::new("call_1", "strict", json!({"name": 42}));
        let state = json!({});

        let exec = execute_single_tool(Some(&tool), &call, &state).await;

        assert!(exec.result.is_error());
        assert!(
            !tool.executed.load(std::sync::atomic::Ordering::SeqCst),
            "execute() must NOT be called when validate_args fails"
        );
    }

    #[tokio::test]
    async fn test_execute_tools_parallel_runs_every_call_against_input_state() {
        let mut tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tools.insert("effect".to_string(), Arc::new(EffectTool));
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "effect", json!({})),
            ToolCall::new("call_3", "missing", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let executions = execute_tools_parallel(&tools, &calls, &state).await;

        assert_eq!(executions.len(), 3);
        // Both effect calls saw the same input state, so each patch yields 1 + 2.
        for exec in &executions[..2] {
            assert!(exec.result.is_success());
            let patch = exec.patch.as_ref().expect("patch should be emitted");
            let next = apply_patch(&state, patch.patch()).expect("patch should apply");
            assert_eq!(next["counter"]["value"], 3);
        }
        // Unknown tool names become error executions, kept in call order.
        assert_eq!(executions[2].call.id, "call_3");
        assert!(executions[2].result.is_error());
        assert!(executions[2].patch.is_none());
        assert_eq!(collect_patches(&executions).len(), 2);
    }

    #[tokio::test]
    async fn test_execute_tools_sequential_threads_state_between_calls() {
        let mut tools: HashMap<String, Arc<dyn Tool>> = HashMap::new();
        tools.insert("effect".to_string(), Arc::new(EffectTool));
        let calls = vec![
            ToolCall::new("call_1", "effect", json!({})),
            ToolCall::new("call_2", "effect", json!({})),
        ];
        let state = json!({"counter": {"value": 1}});

        let (final_state, executions) = execute_tools_sequential(&tools, &calls, &state).await;

        assert_eq!(executions.len(), 2);
        assert!(executions.iter().all(|e| e.result.is_success()));
        // The second call sees the first call's result: 1 + 2 + 2.
        assert_eq!(final_state["counter"]["value"], 5);
        assert_eq!(state["counter"]["value"], 1);
    }
}