1use super::core::{
2 pending_approval_placeholder_message, transition_tool_call_state, ToolCallStateSeed,
3 ToolCallStateTransition,
4};
5use super::parallel_state_merge::merge_parallel_state_patches;
6use super::plugin_runtime::emit_tool_phase;
7use super::{Agent, AgentLoopError, BaseAgent, RunCancellationToken};
8use crate::contracts::runtime::action::Action;
9use crate::contracts::runtime::behavior::AgentBehavior;
10use crate::contracts::runtime::phase::{Phase, StepContext};
11use crate::contracts::runtime::state::{reduce_state_actions, AnyStateAction, ScopeContext};
12use crate::contracts::runtime::tool_call::{CallerContext, ToolGate};
13use crate::contracts::runtime::tool_call::{Tool, ToolDescriptor, ToolResult};
14use crate::contracts::runtime::{
15 ActivityManager, PendingToolCall, SuspendTicket, SuspendedCall, ToolCallResumeMode,
16};
17use crate::contracts::runtime::{
18 DecisionReplayPolicy, StreamResult, ToolCallOutcome, ToolCallStatus, ToolExecution,
19 ToolExecutionEffect, ToolExecutionRequest, ToolExecutionResult, ToolExecutor,
20 ToolExecutorError,
21};
22use crate::contracts::thread::Thread;
23use crate::contracts::thread::{Message, MessageMetadata, ToolCall};
24use crate::contracts::{RunContext, Suspension};
25use crate::engine::convert::tool_response;
26use crate::engine::tool_execution::merge_context_patch_into_effect;
27use crate::runtime::run_context::{await_or_cancel, is_cancelled, CancelAware};
28use async_trait::async_trait;
29use serde_json::Value;
30use std::collections::HashMap;
31use std::sync::Arc;
32use tirea_state::{apply_patch, Patch, TrackedPatch};
33
34#[derive(Debug)]
39pub enum ExecuteToolsOutcome {
40 Completed(Thread),
42 Suspended {
44 thread: Thread,
45 suspended_call: Box<SuspendedCall>,
46 },
47}
48
49impl ExecuteToolsOutcome {
50 pub fn into_thread(self) -> Thread {
52 match self {
53 Self::Completed(t) | Self::Suspended { thread: t, .. } => t,
54 }
55 }
56
57 pub fn is_suspended(&self) -> bool {
59 matches!(self, Self::Suspended { .. })
60 }
61}
62
63pub(super) struct AppliedToolResults {
64 pub(super) suspended_calls: Vec<SuspendedCall>,
65 pub(super) state_snapshot: Option<Value>,
66}
67
68#[derive(Clone)]
69pub(super) struct ToolPhaseContext<'a> {
70 pub(super) tool_descriptors: &'a [ToolDescriptor],
71 pub(super) agent_behavior: Option<&'a dyn AgentBehavior>,
72 pub(super) activity_manager: Arc<dyn ActivityManager>,
73 pub(super) run_policy: &'a tirea_contract::RunPolicy,
74 pub(super) run_identity: tirea_contract::runtime::RunIdentity,
75 pub(super) caller_context: CallerContext,
76 pub(super) thread_id: &'a str,
77 pub(super) thread_messages: &'a [Arc<Message>],
78 pub(super) cancellation_token: Option<&'a RunCancellationToken>,
79}
80
81impl<'a> ToolPhaseContext<'a> {
82 pub(super) fn from_request(request: &'a ToolExecutionRequest<'a>) -> Self {
83 Self {
84 tool_descriptors: request.tool_descriptors,
85 agent_behavior: request.agent_behavior,
86 activity_manager: request.activity_manager.clone(),
87 run_policy: request.run_policy,
88 run_identity: request.run_identity.clone(),
89 caller_context: request.caller_context.clone(),
90 thread_id: request.thread_id,
91 thread_messages: request.thread_messages,
92 cancellation_token: request.cancellation_token,
93 }
94 }
95}
96
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Saturates at `u64::MAX`; yields 0 if the system clock reports a time before
/// the epoch.
fn now_unix_millis() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
102
103fn suspended_call_from_tool_result(call: &ToolCall, result: &ToolResult) -> SuspendedCall {
104 if let Some(mut explicit) = result.suspension() {
105 if explicit.pending.id.trim().is_empty() || explicit.pending.name.trim().is_empty() {
106 explicit.pending =
107 PendingToolCall::new(call.id.clone(), call.name.clone(), call.arguments.clone());
108 }
109 return SuspendedCall::new(call, explicit);
110 }
111
112 let mut suspension = Suspension::new(&call.id, format!("tool:{}", call.name))
113 .with_parameters(call.arguments.clone());
114 if let Some(message) = result.message.as_ref() {
115 suspension = suspension.with_message(message.clone());
116 }
117
118 SuspendedCall::new(
119 call,
120 SuspendTicket::new(
121 suspension,
122 PendingToolCall::new(call.id.clone(), call.name.clone(), call.arguments.clone()),
123 ToolCallResumeMode::ReplayToolCall,
124 ),
125 )
126}
127
128fn persist_tool_call_status(
129 step: &StepContext<'_>,
130 call: &ToolCall,
131 status: ToolCallStatus,
132 suspended_call: Option<&SuspendedCall>,
133) -> Result<crate::contracts::runtime::ToolCallState, AgentLoopError> {
134 let current_state = step.ctx().tool_call_state_for(&call.id).map_err(|e| {
135 AgentLoopError::StateError(format!(
136 "failed to read tool call state for '{}' before setting {:?}: {e}",
137 call.id, status
138 ))
139 })?;
140 let previous_status = current_state
141 .as_ref()
142 .map(|state| state.status)
143 .unwrap_or(ToolCallStatus::New);
144 let current_resume_token = current_state
145 .as_ref()
146 .and_then(|state| state.resume_token.clone());
147 let current_resume = current_state
148 .as_ref()
149 .and_then(|state| state.resume.clone());
150
151 let (next_resume_token, next_resume) = match status {
152 ToolCallStatus::Running => {
153 if matches!(previous_status, ToolCallStatus::Resuming) {
154 (current_resume_token.clone(), current_resume.clone())
155 } else {
156 (None, None)
157 }
158 }
159 ToolCallStatus::Suspended => (
160 suspended_call
161 .map(|entry| entry.ticket.pending.id.clone())
162 .or(current_resume_token.clone()),
163 None,
164 ),
165 ToolCallStatus::Succeeded
166 | ToolCallStatus::Failed
167 | ToolCallStatus::Cancelled
168 | ToolCallStatus::New
169 | ToolCallStatus::Resuming => (current_resume_token, current_resume),
170 };
171
172 let Some(runtime_state) = transition_tool_call_state(
173 current_state,
174 ToolCallStateSeed {
175 call_id: &call.id,
176 tool_name: &call.name,
177 arguments: &call.arguments,
178 status: ToolCallStatus::New,
179 resume_token: None,
180 },
181 ToolCallStateTransition {
182 status,
183 resume_token: next_resume_token,
184 resume: next_resume,
185 updated_at: now_unix_millis(),
186 },
187 ) else {
188 return Err(AgentLoopError::StateError(format!(
189 "invalid tool call status transition for '{}': {:?} -> {:?}",
190 call.id, previous_status, status
191 )));
192 };
193
194 step.ctx()
195 .set_tool_call_state_for(&call.id, runtime_state.clone())
196 .map_err(|e| {
197 AgentLoopError::StateError(format!(
198 "failed to persist tool call state for '{}' as {:?}: {e}",
199 call.id, status
200 ))
201 })?;
202
203 Ok(runtime_state)
204}
205
206fn map_tool_executor_error(err: AgentLoopError, thread_id: &str) -> ToolExecutorError {
207 match err {
208 AgentLoopError::Cancelled => ToolExecutorError::Cancelled {
209 thread_id: thread_id.to_string(),
210 },
211 other => ToolExecutorError::Failed {
212 message: other.to_string(),
213 },
214 }
215}
216
/// Replay flavour of the parallel tool executor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelToolExecutionMode {
    /// Decisions are replayed with `DecisionReplayPolicy::BatchAllSuspended`.
    BatchApproval,
    /// Decisions are replayed with `DecisionReplayPolicy::Immediate`.
    Streaming,
}
223
224#[derive(Debug, Clone, Copy, PartialEq, Eq)]
226pub struct ParallelToolExecutor {
227 mode: ParallelToolExecutionMode,
228}
229
230impl ParallelToolExecutor {
231 pub const fn batch_approval() -> Self {
232 Self {
233 mode: ParallelToolExecutionMode::BatchApproval,
234 }
235 }
236
237 pub const fn streaming() -> Self {
238 Self {
239 mode: ParallelToolExecutionMode::Streaming,
240 }
241 }
242
243 fn mode_name(self) -> &'static str {
244 match self.mode {
245 ParallelToolExecutionMode::BatchApproval => "parallel_batch_approval",
246 ParallelToolExecutionMode::Streaming => "parallel_streaming",
247 }
248 }
249}
250
251impl Default for ParallelToolExecutor {
252 fn default() -> Self {
253 Self::streaming()
254 }
255}
256
257#[async_trait]
258impl ToolExecutor for ParallelToolExecutor {
259 async fn execute(
260 &self,
261 request: ToolExecutionRequest<'_>,
262 ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
263 let thread_id = request.thread_id;
264 let phase_ctx = ToolPhaseContext::from_request(&request);
265 execute_tools_parallel_with_phases(request.tools, request.calls, request.state, phase_ctx)
266 .await
267 .map_err(|e| map_tool_executor_error(e, thread_id))
268 }
269
270 fn name(&self) -> &'static str {
271 self.mode_name()
272 }
273
274 fn requires_parallel_patch_conflict_check(&self) -> bool {
275 true
276 }
277
278 fn decision_replay_policy(&self) -> DecisionReplayPolicy {
279 match self.mode {
280 ParallelToolExecutionMode::BatchApproval => DecisionReplayPolicy::BatchAllSuspended,
281 ParallelToolExecutionMode::Streaming => DecisionReplayPolicy::Immediate,
282 }
283 }
284}
285
/// Tool executor that runs calls one after another.
///
/// Each call sees the state produced by the calls before it, and execution stops
/// after the first suspended call.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialToolExecutor;
289
290#[async_trait]
291impl ToolExecutor for SequentialToolExecutor {
292 async fn execute(
293 &self,
294 request: ToolExecutionRequest<'_>,
295 ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
296 let thread_id = request.thread_id;
297 let phase_ctx = ToolPhaseContext::from_request(&request);
298 execute_tools_sequential_with_phases(request.tools, request.calls, request.state, phase_ctx)
299 .await
300 .map_err(|e| map_tool_executor_error(e, thread_id))
301 }
302
303 fn name(&self) -> &'static str {
304 "sequential"
305 }
306}
307
308pub(super) fn apply_tool_results_to_session(
309 run_ctx: &mut RunContext,
310 results: &[ToolExecutionResult],
311 metadata: Option<MessageMetadata>,
312 check_parallel_patch_conflicts: bool,
313) -> Result<AppliedToolResults, AgentLoopError> {
314 apply_tool_results_impl(
315 run_ctx,
316 results,
317 metadata,
318 check_parallel_patch_conflicts,
319 None,
320 )
321}
322
323pub(super) fn apply_tool_results_impl(
324 run_ctx: &mut RunContext,
325 results: &[ToolExecutionResult],
326 metadata: Option<MessageMetadata>,
327 check_parallel_patch_conflicts: bool,
328 tool_msg_ids: Option<&HashMap<String, String>>,
329) -> Result<AppliedToolResults, AgentLoopError> {
330 let suspended: Vec<SuspendedCall> = results
332 .iter()
333 .filter_map(|r| {
334 if matches!(r.outcome, ToolCallOutcome::Suspended) {
335 r.suspended_call.clone()
336 } else {
337 None
338 }
339 })
340 .collect();
341
342 let all_serialized_state_actions: Vec<tirea_contract::SerializedStateAction> = results
344 .iter()
345 .flat_map(|r| r.serialized_state_actions.iter().cloned())
346 .collect();
347 if !all_serialized_state_actions.is_empty() {
348 run_ctx.add_serialized_state_actions(all_serialized_state_actions);
349 }
350
351 let base_snapshot = run_ctx
352 .snapshot()
353 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
354 let patches = merge_parallel_state_patches(
355 &base_snapshot,
356 results,
357 check_parallel_patch_conflicts,
358 run_ctx.lattice_registry(),
359 )?;
360 let mut state_changed = !patches.is_empty();
361 run_ctx.add_thread_patches(patches);
362
363 let tool_messages: Vec<Arc<Message>> = results
365 .iter()
366 .flat_map(|r| {
367 let is_suspended = matches!(r.outcome, ToolCallOutcome::Suspended);
368 let mut msgs = if is_suspended {
369 vec![Message::tool(
370 &r.execution.call.id,
371 pending_approval_placeholder_message(&r.execution.call.name),
372 )]
373 } else {
374 let mut tool_msg = tool_response(&r.execution.call.id, &r.execution.result);
375 if let Some(id) = tool_msg_ids.and_then(|ids| ids.get(&r.execution.call.id)) {
376 tool_msg = tool_msg.with_id(id.clone());
377 }
378 vec![tool_msg]
379 };
380 for reminder in &r.reminders {
381 msgs.push(Message::internal_system(format!(
382 "<system-reminder>{}</system-reminder>",
383 reminder
384 )));
385 }
386 if let Some(ref meta) = metadata {
387 for msg in &mut msgs {
388 msg.metadata = Some(meta.clone());
389 }
390 }
391 msgs.into_iter().map(Arc::new).collect::<Vec<_>>()
392 })
393 .collect();
394
395 run_ctx.add_messages(tool_messages);
396
397 let user_messages: Vec<Arc<Message>> = results
399 .iter()
400 .flat_map(|r| {
401 r.user_messages
402 .iter()
403 .map(|s| s.trim())
404 .filter(|s| !s.is_empty())
405 .map(|text| {
406 let mut msg = Message::user(text.to_string());
407 if let Some(ref meta) = metadata {
408 msg.metadata = Some(meta.clone());
409 }
410 Arc::new(msg)
411 })
412 .collect::<Vec<_>>()
413 })
414 .collect();
415 if !user_messages.is_empty() {
416 run_ctx.add_messages(user_messages);
417 }
418 if !suspended.is_empty() {
419 let state = run_ctx
420 .snapshot()
421 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
422 let actions: Vec<AnyStateAction> = suspended
423 .iter()
424 .map(|call| call.clone().into_state_action())
425 .collect();
426 let patches = reduce_state_actions(actions, &state, "agent_loop", &ScopeContext::run())
427 .map_err(|e| {
428 AgentLoopError::StateError(format!("failed to reduce suspended call actions: {e}"))
429 })?;
430 for patch in patches {
431 if !patch.patch().is_empty() {
432 state_changed = true;
433 run_ctx.add_thread_patch(patch);
434 }
435 }
436 let state_snapshot = if state_changed {
437 Some(
438 run_ctx
439 .snapshot()
440 .map_err(|e| AgentLoopError::StateError(e.to_string()))?,
441 )
442 } else {
443 None
444 };
445 return Ok(AppliedToolResults {
446 suspended_calls: suspended,
447 state_snapshot,
448 });
449 }
450
451 let state_snapshot = if state_changed {
458 Some(
459 run_ctx
460 .snapshot()
461 .map_err(|e| AgentLoopError::StateError(e.to_string()))?,
462 )
463 } else {
464 None
465 };
466
467 Ok(AppliedToolResults {
468 suspended_calls: Vec::new(),
469 state_snapshot,
470 })
471}
472
473fn tool_result_metadata_from_run_ctx(
474 run_ctx: &RunContext,
475 run_id: Option<&str>,
476) -> Option<MessageMetadata> {
477 let run_id = run_id.map(|id| id.to_string()).or_else(|| {
478 run_ctx.messages().iter().rev().find_map(|m| {
479 m.metadata
480 .as_ref()
481 .and_then(|meta| meta.run_id.as_ref().cloned())
482 })
483 });
484
485 let step_index = run_ctx
486 .messages()
487 .iter()
488 .rev()
489 .find_map(|m| m.metadata.as_ref().and_then(|meta| meta.step_index));
490
491 if run_id.is_none() && step_index.is_none() {
492 None
493 } else {
494 Some(MessageMetadata { run_id, step_index })
495 }
496}
497
498#[allow(dead_code)]
499pub(super) fn next_step_index(run_ctx: &RunContext) -> u32 {
500 run_ctx
501 .messages()
502 .iter()
503 .filter_map(|m| m.metadata.as_ref().and_then(|meta| meta.step_index))
504 .max()
505 .map(|v| v.saturating_add(1))
506 .unwrap_or(0)
507}
508
509pub(super) fn step_metadata(run_id: Option<String>, step_index: u32) -> MessageMetadata {
510 MessageMetadata {
511 run_id,
512 step_index: Some(step_index),
513 }
514}
515
516pub async fn execute_tools(
520 thread: Thread,
521 result: &StreamResult,
522 tools: &HashMap<String, Arc<dyn Tool>>,
523 parallel: bool,
524) -> Result<ExecuteToolsOutcome, AgentLoopError> {
525 let parallel_executor = ParallelToolExecutor::streaming();
526 let sequential_executor = SequentialToolExecutor;
527 let executor: &dyn ToolExecutor = if parallel {
528 ¶llel_executor
529 } else {
530 &sequential_executor
531 };
532 execute_tools_with_agent_and_executor(thread, result, tools, executor, None).await
533}
534
535pub async fn execute_tools_with_config(
537 thread: Thread,
538 result: &StreamResult,
539 tools: &HashMap<String, Arc<dyn Tool>>,
540 agent: &dyn Agent,
541) -> Result<ExecuteToolsOutcome, AgentLoopError> {
542 execute_tools_with_agent_and_executor(
543 thread,
544 result,
545 tools,
546 agent.tool_executor().as_ref(),
547 Some(agent.behavior()),
548 )
549 .await
550}
551
552pub(super) fn caller_context_for_tool_execution(
553 run_ctx: &RunContext,
554 _state: &Value,
555) -> CallerContext {
556 CallerContext::new(
557 Some(run_ctx.thread_id().to_string()),
558 run_ctx.run_identity().run_id_opt().map(ToOwned::to_owned),
559 run_ctx.run_identity().agent_id_opt().map(ToOwned::to_owned),
560 run_ctx.messages().to_vec(),
561 )
562}
563
564pub async fn execute_tools_with_behaviors(
566 thread: Thread,
567 result: &StreamResult,
568 tools: &HashMap<String, Arc<dyn Tool>>,
569 parallel: bool,
570 behavior: Arc<dyn AgentBehavior>,
571) -> Result<ExecuteToolsOutcome, AgentLoopError> {
572 let executor: Arc<dyn ToolExecutor> = if parallel {
573 Arc::new(ParallelToolExecutor::streaming())
574 } else {
575 Arc::new(SequentialToolExecutor)
576 };
577 let agent = BaseAgent::default()
578 .with_behavior(behavior)
579 .with_tool_executor(executor);
580 execute_tools_with_config(thread, result, tools, &agent).await
581}
582
583async fn execute_tools_with_agent_and_executor(
584 thread: Thread,
585 result: &StreamResult,
586 tools: &HashMap<String, Arc<dyn Tool>>,
587 executor: &dyn ToolExecutor,
588 behavior: Option<&dyn AgentBehavior>,
589) -> Result<ExecuteToolsOutcome, AgentLoopError> {
590 let rebuilt_state = thread
592 .rebuild_state()
593 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
594 let mut run_ctx = RunContext::new(
595 &thread.id,
596 rebuilt_state.clone(),
597 thread.messages.clone(),
598 tirea_contract::RunPolicy::default(),
599 );
600
601 let tool_descriptors: Vec<ToolDescriptor> =
602 tools.values().map(|t| t.descriptor().clone()).collect();
603 if let Some(behavior) = behavior {
605 let run_start_patches = super::plugin_runtime::behavior_run_phase_block(
606 &run_ctx,
607 &tool_descriptors,
608 behavior,
609 &[Phase::RunStart],
610 |_| {},
611 |_| (),
612 )
613 .await?
614 .1;
615 if !run_start_patches.is_empty() {
616 run_ctx.add_thread_patches(run_start_patches);
617 }
618 }
619
620 let replay_executor: Arc<dyn ToolExecutor> = match executor.decision_replay_policy() {
621 DecisionReplayPolicy::BatchAllSuspended => Arc::new(ParallelToolExecutor::batch_approval()),
622 DecisionReplayPolicy::Immediate => Arc::new(ParallelToolExecutor::streaming()),
623 };
624 let replay_config = BaseAgent::default().with_tool_executor(replay_executor);
625 let replay = super::drain_resuming_tool_calls_and_replay(
626 &mut run_ctx,
627 tools,
628 &replay_config,
629 &tool_descriptors,
630 )
631 .await?;
632
633 if replay.replayed {
634 let suspended = run_ctx.suspended_calls().values().next().cloned();
635 let delta = run_ctx.take_delta();
636 let mut out_thread = thread;
637 for msg in delta.messages {
638 out_thread = out_thread.with_message((*msg).clone());
639 }
640 out_thread = out_thread.with_patches(delta.patches);
641 return if let Some(first) = suspended {
642 Ok(ExecuteToolsOutcome::Suspended {
643 thread: out_thread,
644 suspended_call: Box::new(first),
645 })
646 } else {
647 Ok(ExecuteToolsOutcome::Completed(out_thread))
648 };
649 }
650
651 if result.tool_calls.is_empty() {
652 let delta = run_ctx.take_delta();
653 let mut out_thread = thread;
654 for msg in delta.messages {
655 out_thread = out_thread.with_message((*msg).clone());
656 }
657 out_thread = out_thread.with_patches(delta.patches);
658 return Ok(ExecuteToolsOutcome::Completed(out_thread));
659 }
660
661 let current_state = run_ctx
662 .snapshot()
663 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
664 let caller_context = caller_context_for_tool_execution(&run_ctx, ¤t_state);
665 let results = executor
666 .execute(ToolExecutionRequest {
667 tools,
668 calls: &result.tool_calls,
669 state: ¤t_state,
670 tool_descriptors: &tool_descriptors,
671 agent_behavior: behavior,
672 activity_manager: tirea_contract::runtime::activity::NoOpActivityManager::arc(),
673 run_policy: run_ctx.run_policy(),
674 run_identity: run_ctx.run_identity().clone(),
675 caller_context,
676 thread_id: run_ctx.thread_id(),
677 thread_messages: run_ctx.messages(),
678 state_version: run_ctx.version(),
679 cancellation_token: None,
680 })
681 .await?;
682
683 let metadata = tool_result_metadata_from_run_ctx(&run_ctx, None);
684 let applied = apply_tool_results_to_session(
685 &mut run_ctx,
686 &results,
687 metadata,
688 executor.requires_parallel_patch_conflict_check(),
689 )?;
690 let suspended = applied.suspended_calls.into_iter().next();
691
692 let delta = run_ctx.take_delta();
694 let mut out_thread = thread;
695 for msg in delta.messages {
696 out_thread = out_thread.with_message((*msg).clone());
697 }
698 out_thread = out_thread.with_patches(delta.patches);
699
700 if let Some(first) = suspended {
701 Ok(ExecuteToolsOutcome::Suspended {
702 thread: out_thread,
703 suspended_call: Box::new(first),
704 })
705 } else {
706 Ok(ExecuteToolsOutcome::Completed(out_thread))
707 }
708}
709
710pub(super) async fn execute_tools_parallel_with_phases(
712 tools: &HashMap<String, Arc<dyn Tool>>,
713 calls: &[crate::contracts::thread::ToolCall],
714 state: &Value,
715 phase_ctx: ToolPhaseContext<'_>,
716) -> Result<Vec<ToolExecutionResult>, AgentLoopError> {
717 use futures::future::join_all;
718
719 if is_cancelled(phase_ctx.cancellation_token) {
720 return Err(cancelled_error(phase_ctx.thread_id));
721 }
722
723 let run_policy_owned = phase_ctx.run_policy.clone();
725 let thread_id = phase_ctx.thread_id.to_string();
726 let thread_messages = Arc::new(phase_ctx.thread_messages.to_vec());
727 let tool_descriptors = phase_ctx.tool_descriptors.to_vec();
728 let agent = phase_ctx.agent_behavior;
729
730 let futures = calls.iter().map(|call| {
731 let tool = tools.get(&call.name).cloned();
732 let state = state.clone();
733 let call = call.clone();
734 let tool_descriptors = tool_descriptors.clone();
735 let activity_manager = phase_ctx.activity_manager.clone();
736 let rt = run_policy_owned.clone();
737 let run_identity = phase_ctx.run_identity.clone();
738 let caller_context = phase_ctx.caller_context.clone();
739 let sid = thread_id.clone();
740 let thread_messages = thread_messages.clone();
741
742 async move {
743 execute_single_tool_with_phases_impl(
744 tool.as_deref(),
745 &call,
746 &state,
747 &ToolPhaseContext {
748 tool_descriptors: &tool_descriptors,
749 agent_behavior: agent,
750 activity_manager,
751 run_policy: &rt,
752 run_identity,
753 caller_context,
754 thread_id: &sid,
755 thread_messages: thread_messages.as_slice(),
756 cancellation_token: None,
757 },
758 )
759 .await
760 }
761 });
762
763 let join_future = join_all(futures);
764 let results = match await_or_cancel(phase_ctx.cancellation_token, join_future).await {
765 CancelAware::Cancelled => return Err(cancelled_error(&thread_id)),
766 CancelAware::Value(results) => results,
767 };
768 let results: Vec<ToolExecutionResult> = results.into_iter().collect::<Result<_, _>>()?;
769 Ok(results)
770}
771
772pub(super) async fn execute_tools_sequential_with_phases(
774 tools: &HashMap<String, Arc<dyn Tool>>,
775 calls: &[crate::contracts::thread::ToolCall],
776 initial_state: &Value,
777 phase_ctx: ToolPhaseContext<'_>,
778) -> Result<Vec<ToolExecutionResult>, AgentLoopError> {
779 use tirea_state::apply_patch;
780
781 if is_cancelled(phase_ctx.cancellation_token) {
782 return Err(cancelled_error(phase_ctx.thread_id));
783 }
784
785 let mut state = initial_state.clone();
786 let mut results = Vec::with_capacity(calls.len());
787
788 for call in calls {
789 let tool = tools.get(&call.name).cloned();
790 let call_phase_ctx = ToolPhaseContext {
791 tool_descriptors: phase_ctx.tool_descriptors,
792 agent_behavior: phase_ctx.agent_behavior,
793 activity_manager: phase_ctx.activity_manager.clone(),
794 run_policy: phase_ctx.run_policy,
795 run_identity: phase_ctx.run_identity.clone(),
796 caller_context: phase_ctx.caller_context.clone(),
797 thread_id: phase_ctx.thread_id,
798 thread_messages: phase_ctx.thread_messages,
799 cancellation_token: None,
800 };
801 let result = match await_or_cancel(
802 phase_ctx.cancellation_token,
803 execute_single_tool_with_phases_impl(tool.as_deref(), call, &state, &call_phase_ctx),
804 )
805 .await
806 {
807 CancelAware::Cancelled => return Err(cancelled_error(phase_ctx.thread_id)),
808 CancelAware::Value(result) => result?,
809 };
810
811 if let Some(ref patch) = result.execution.patch {
813 state = apply_patch(&state, patch.patch()).map_err(|e| {
814 AgentLoopError::StateError(format!(
815 "failed to apply tool patch for call '{}': {}",
816 result.execution.call.id, e
817 ))
818 })?;
819 }
820 for pp in &result.pending_patches {
822 state = apply_patch(&state, pp.patch()).map_err(|e| {
823 AgentLoopError::StateError(format!(
824 "failed to apply plugin patch for call '{}': {}",
825 result.execution.call.id, e
826 ))
827 })?;
828 }
829
830 results.push(result);
831
832 if results
833 .last()
834 .is_some_and(|r| matches!(r.outcome, ToolCallOutcome::Suspended))
835 {
836 break;
837 }
838 }
839
840 Ok(results)
841}
842
/// Test-only entry point that runs a single tool call through the phase pipeline.
#[cfg(test)]
pub(super) async fn execute_single_tool_with_phases(
    tool: Option<&dyn Tool>,
    call: &crate::contracts::thread::ToolCall,
    state: &Value,
    phase_ctx: &ToolPhaseContext<'_>,
) -> Result<ToolExecutionResult, AgentLoopError> {
    execute_single_tool_with_phases_impl(tool, call, state, phase_ctx)
        .await
}
853
854pub(super) async fn execute_single_tool_with_phases_deferred(
855 tool: Option<&dyn Tool>,
856 call: &crate::contracts::thread::ToolCall,
857 state: &Value,
858 phase_ctx: &ToolPhaseContext<'_>,
859) -> Result<ToolExecutionResult, AgentLoopError> {
860 execute_single_tool_with_phases_impl(tool, call, state, phase_ctx).await
861}
862
863async fn execute_single_tool_with_phases_impl(
864 tool: Option<&dyn Tool>,
865 call: &crate::contracts::thread::ToolCall,
866 state: &Value,
867 phase_ctx: &ToolPhaseContext<'_>,
868) -> Result<ToolExecutionResult, AgentLoopError> {
869 let doc = tirea_state::DocCell::new(state.clone());
871 let ops = std::sync::Mutex::new(Vec::new());
872 let pending_messages = std::sync::Mutex::new(Vec::new());
873 let plugin_scope = phase_ctx.run_policy;
874 let mut plugin_tool_call_ctx = crate::contracts::ToolCallContext::new(
875 &doc,
876 &ops,
877 "plugin_phase",
878 "plugin:tool_phase",
879 plugin_scope,
880 &pending_messages,
881 tirea_contract::runtime::activity::NoOpActivityManager::arc(),
882 )
883 .with_run_identity(phase_ctx.run_identity.clone())
884 .with_caller_context(phase_ctx.caller_context.clone());
885 if let Some(token) = phase_ctx.cancellation_token {
886 plugin_tool_call_ctx = plugin_tool_call_ctx.with_cancellation_token(token);
887 }
888
889 let mut step = StepContext::new(
891 plugin_tool_call_ctx,
892 phase_ctx.thread_id,
893 phase_ctx.thread_messages,
894 phase_ctx.tool_descriptors.to_vec(),
895 );
896 step.gate = Some(ToolGate::from_tool_call(call));
897 emit_tool_phase(
899 Phase::BeforeToolExecute,
900 &mut step,
901 phase_ctx.agent_behavior,
902 &doc,
903 )
904 .await?;
905
906 let (mut execution, outcome, suspended_call, tool_actions) = if step.tool_blocked() {
908 let reason = step
909 .gate
910 .as_ref()
911 .and_then(|g| g.block_reason.clone())
912 .unwrap_or_else(|| "Blocked by plugin".to_string());
913 (
914 ToolExecution {
915 call: call.clone(),
916 result: ToolResult::error(&call.name, reason),
917 patch: None,
918 },
919 ToolCallOutcome::Failed,
920 None,
921 Vec::<Box<dyn Action>>::new(),
922 )
923 } else if let Some(plugin_result) = step.tool_result().cloned() {
924 let outcome = ToolCallOutcome::from_tool_result(&plugin_result);
925 (
926 ToolExecution {
927 call: call.clone(),
928 result: plugin_result,
929 patch: None,
930 },
931 outcome,
932 None,
933 Vec::<Box<dyn Action>>::new(),
934 )
935 } else {
936 match tool {
937 None => (
938 ToolExecution {
939 call: call.clone(),
940 result: ToolResult::error(
941 &call.name,
942 format!("Tool '{}' not found", call.name),
943 ),
944 patch: None,
945 },
946 ToolCallOutcome::Failed,
947 None,
948 Vec::<Box<dyn Action>>::new(),
949 ),
950 Some(tool) => {
951 if let Err(e) = tool.validate_args(&call.arguments) {
952 (
953 ToolExecution {
954 call: call.clone(),
955 result: ToolResult::error(&call.name, e.to_string()),
956 patch: None,
957 },
958 ToolCallOutcome::Failed,
959 None,
960 Vec::<Box<dyn Action>>::new(),
961 )
962 } else if step.tool_pending() {
963 let Some(suspend_ticket) =
964 step.gate.as_ref().and_then(|g| g.suspend_ticket.clone())
965 else {
966 return Err(AgentLoopError::StateError(
967 "tool is pending but suspend ticket is missing".to_string(),
968 ));
969 };
970 (
971 ToolExecution {
972 call: call.clone(),
973 result: ToolResult::suspended(
974 &call.name,
975 "Execution suspended; awaiting external decision",
976 ),
977 patch: None,
978 },
979 ToolCallOutcome::Suspended,
980 Some(SuspendedCall::new(call, suspend_ticket)),
981 Vec::<Box<dyn Action>>::new(),
982 )
983 } else {
984 persist_tool_call_status(&step, call, ToolCallStatus::Running, None)?;
985 let tool_doc = tirea_state::DocCell::new(state.clone());
987 let tool_ops = std::sync::Mutex::new(Vec::new());
988 let tool_pending_msgs = std::sync::Mutex::new(Vec::new());
989 let mut tool_ctx = crate::contracts::ToolCallContext::new(
990 &tool_doc,
991 &tool_ops,
992 &call.id,
993 format!("tool:{}", call.name),
994 plugin_scope,
995 &tool_pending_msgs,
996 phase_ctx.activity_manager.clone(),
997 )
998 .with_run_identity(phase_ctx.run_identity.clone())
999 .with_caller_context(phase_ctx.caller_context.clone());
1000 if let Some(token) = phase_ctx.cancellation_token {
1001 tool_ctx = tool_ctx.with_cancellation_token(token);
1002 }
1003 let mut effect =
1004 match tool.execute_effect(call.arguments.clone(), &tool_ctx).await {
1005 Ok(effect) => effect,
1006 Err(e) => ToolExecutionEffect::from(ToolResult::error(
1007 &call.name,
1008 e.to_string(),
1009 )),
1010 };
1011
1012 let context_patch = tool_ctx.take_patch();
1013 if let Err(result) =
1014 merge_context_patch_into_effect(call, &mut effect, context_patch)
1015 {
1016 effect = ToolExecutionEffect::from(*result);
1017 }
1018 let (result, actions) = effect.into_parts();
1019 let outcome = ToolCallOutcome::from_tool_result(&result);
1020
1021 let suspended_call = if matches!(outcome, ToolCallOutcome::Suspended) {
1022 Some(suspended_call_from_tool_result(call, &result))
1023 } else {
1024 None
1025 };
1026
1027 (
1028 ToolExecution {
1029 call: call.clone(),
1030 result,
1031 patch: None,
1032 },
1033 outcome,
1034 suspended_call,
1035 actions,
1036 )
1037 }
1038 }
1039 }
1040 };
1041
1042 if let Some(gate) = step.gate.as_mut() {
1044 gate.result = Some(execution.result.clone());
1045 }
1046
1047 let mut tool_state_actions = Vec::<AnyStateAction>::new();
1050 let mut other_actions = Vec::<Box<dyn Action>>::new();
1051 for action in tool_actions {
1052 if action.is_state_action() {
1053 if let Some(sa) = action.into_state_action() {
1054 tool_state_actions.push(sa);
1055 }
1056 } else {
1057 other_actions.push(action);
1058 }
1059 }
1060 for action in &other_actions {
1062 action
1063 .validate(Phase::AfterToolExecute)
1064 .map_err(AgentLoopError::StateError)?;
1065 }
1066 for action in other_actions {
1067 action.apply(&mut step);
1068 }
1069
1070 emit_tool_phase(
1072 Phase::AfterToolExecute,
1073 &mut step,
1074 phase_ctx.agent_behavior,
1075 &doc,
1076 )
1077 .await?;
1078
1079 let terminal_tool_call_state = match outcome {
1080 ToolCallOutcome::Suspended => Some(persist_tool_call_status(
1081 &step,
1082 call,
1083 ToolCallStatus::Suspended,
1084 suspended_call.as_ref(),
1085 )?),
1086 ToolCallOutcome::Succeeded => Some(persist_tool_call_status(
1087 &step,
1088 call,
1089 ToolCallStatus::Succeeded,
1090 None,
1091 )?),
1092 ToolCallOutcome::Failed => Some(persist_tool_call_status(
1093 &step,
1094 call,
1095 ToolCallStatus::Failed,
1096 None,
1097 )?),
1098 };
1099
1100 if let Some(tool_call_state) = terminal_tool_call_state {
1101 tool_state_actions.push(tool_call_state.into_state_action());
1102 }
1103
1104 if !matches!(outcome, ToolCallOutcome::Suspended) {
1107 let cleanup_path = format!("__tool_call_scope.{}.suspended_call", call.id);
1108 let cleanup_patch = Patch::with_ops(vec![tirea_state::Op::delete(
1109 tirea_state::parse_path(&cleanup_path),
1110 )]);
1111 let tracked = TrackedPatch::new(cleanup_patch).with_source("framework:scope_cleanup");
1112 step.emit_patch(tracked);
1113 }
1114
1115 let mut serialized_state_actions: Vec<tirea_contract::SerializedStateAction> =
1117 tool_state_actions
1118 .iter()
1119 .map(|a| a.to_serialized_state_action())
1120 .collect();
1121
1122 let tool_scope_ctx = ScopeContext::for_call(&call.id);
1123 let execution_patch_parts = reduce_tool_state_actions(
1124 state,
1125 tool_state_actions,
1126 &format!("tool:{}", call.name),
1127 &tool_scope_ctx,
1128 )?;
1129 execution.patch = merge_tracked_patches(&execution_patch_parts, &format!("tool:{}", call.name));
1130
1131 let phase_base_state = if let Some(tool_patch) = execution.patch.as_ref() {
1132 tirea_state::apply_patch(state, tool_patch.patch()).map_err(|e| {
1133 AgentLoopError::StateError(format!(
1134 "failed to apply tool patch for call '{}': {}",
1135 call.id, e
1136 ))
1137 })?
1138 } else {
1139 state.clone()
1140 };
1141 let pending_patches = apply_tracked_patches_checked(
1142 &phase_base_state,
1143 std::mem::take(&mut step.pending_patches),
1144 &call.id,
1145 )?;
1146
1147 let reminders = step.messaging.reminders.clone();
1148 let user_messages = std::mem::take(&mut step.messaging.user_messages);
1149
1150 serialized_state_actions.extend(step.take_pending_serialized_state_actions());
1152
1153 Ok(ToolExecutionResult {
1154 execution,
1155 outcome,
1156 suspended_call,
1157 reminders,
1158 user_messages,
1159 pending_patches,
1160 serialized_state_actions,
1161 })
1162}
1163
1164fn reduce_tool_state_actions(
1165 base_state: &Value,
1166 actions: Vec<AnyStateAction>,
1167 source: &str,
1168 scope_ctx: &ScopeContext,
1169) -> Result<Vec<TrackedPatch>, AgentLoopError> {
1170 reduce_state_actions(actions, base_state, source, scope_ctx).map_err(|e| {
1171 AgentLoopError::StateError(format!("failed to reduce tool state actions: {e}"))
1172 })
1173}
1174
1175fn merge_tracked_patches(patches: &[TrackedPatch], source: &str) -> Option<TrackedPatch> {
1176 let mut merged = Patch::new();
1177 for tracked in patches {
1178 merged.extend(tracked.patch().clone());
1179 }
1180 if merged.is_empty() {
1181 None
1182 } else {
1183 Some(TrackedPatch::new(merged).with_source(source.to_string()))
1184 }
1185}
1186
1187fn apply_tracked_patches_checked(
1188 base_state: &Value,
1189 patches: Vec<TrackedPatch>,
1190 call_id: &str,
1191) -> Result<Vec<TrackedPatch>, AgentLoopError> {
1192 let mut rolling = base_state.clone();
1193 let mut validated = Vec::with_capacity(patches.len());
1194 for tracked in patches {
1195 if tracked.patch().is_empty() {
1196 continue;
1197 }
1198 rolling = apply_patch(&rolling, tracked.patch()).map_err(|e| {
1199 AgentLoopError::StateError(format!(
1200 "failed to apply pending state patch for call '{}': {}",
1201 call_id, e
1202 ))
1203 })?;
1204 validated.push(tracked);
1205 }
1206 Ok(validated)
1207}
1208
1209fn cancelled_error(_thread_id: &str) -> AgentLoopError {
1210 AgentLoopError::Cancelled
1211}
1212
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tirea_state::Op;

    /// Two independent `set` patches applied to an empty object should both
    /// be accepted unchanged.
    #[test]
    fn apply_tracked_patches_checked_keeps_valid_sequence() {
        let first = Patch::new().with_op(Op::set(tirea_state::path!("alpha"), json!(1)));
        let second = Patch::new().with_op(Op::set(tirea_state::path!("beta"), json!(2)));
        let patches = vec![
            TrackedPatch::new(first).with_source("test:first"),
            TrackedPatch::new(second).with_source("test:second"),
        ];

        let validated = apply_tracked_patches_checked(&json!({}), patches, "call_1")
            .expect("patches valid");

        assert_eq!(validated.len(), 2);
        for tracked in &validated {
            assert_eq!(tracked.patch().ops().len(), 1);
        }
    }

    /// Incrementing a path that does not exist must surface a `StateError`
    /// mentioning the offending call id.
    #[test]
    fn apply_tracked_patches_checked_reports_invalid_sequence() {
        let broken = Patch::new().with_op(Op::increment(tirea_state::path!("counter"), 1_i64));
        let patches = vec![TrackedPatch::new(broken).with_source("test:broken")];

        let error = apply_tracked_patches_checked(&json!({}), patches, "call_1")
            .expect_err("increment against missing path should fail");

        match error {
            AgentLoopError::StateError(message) => assert!(
                message.contains("failed to apply pending state patch for call 'call_1'")
            ),
            _ => panic!("expected AgentLoopError::StateError"),
        }
    }
}