1use super::core::{
2 pending_approval_placeholder_message, transition_tool_call_state, ToolCallStateSeed,
3 ToolCallStateTransition,
4};
5use super::parallel_state_merge::merge_parallel_state_patches;
6use super::plugin_runtime::emit_tool_phase;
7use super::{
8 Agent, AgentLoopError, BaseAgent, RunCancellationToken, TOOL_SCOPE_CALLER_MESSAGES_KEY,
9 TOOL_SCOPE_CALLER_STATE_KEY, TOOL_SCOPE_CALLER_THREAD_ID_KEY,
10};
11use crate::contracts::runtime::action::Action;
12use crate::contracts::runtime::behavior::AgentBehavior;
13use crate::contracts::runtime::phase::{Phase, StepContext};
14use crate::contracts::runtime::state::{reduce_state_actions, AnyStateAction, ScopeContext};
15use crate::contracts::runtime::tool_call::ToolGate;
16use crate::contracts::runtime::tool_call::{Tool, ToolDescriptor, ToolResult};
17use crate::contracts::runtime::{
18 ActivityManager, PendingToolCall, SuspendTicket, SuspendedCall, ToolCallResumeMode,
19};
20use crate::contracts::runtime::{
21 DecisionReplayPolicy, StreamResult, ToolCallOutcome, ToolCallStatus, ToolExecution,
22 ToolExecutionEffect, ToolExecutionRequest, ToolExecutionResult, ToolExecutor,
23 ToolExecutorError,
24};
25use crate::contracts::thread::Thread;
26use crate::contracts::thread::{Message, MessageMetadata, ToolCall};
27use crate::contracts::{RunContext, Suspension};
28use crate::engine::convert::tool_response;
29use crate::engine::tool_execution::merge_context_patch_into_effect;
30use crate::runtime::run_context::{await_or_cancel, is_cancelled, CancelAware};
31use async_trait::async_trait;
32use serde_json::Value;
33use std::collections::HashMap;
34use std::sync::Arc;
35use tirea_state::{apply_patch, Patch, TrackedPatch};
36
/// Result of executing the tool calls of a stream result against a [`Thread`].
#[derive(Debug)]
pub enum ExecuteToolsOutcome {
    /// No suspended call was reported; carries the updated thread.
    Completed(Thread),
    /// A call is suspended and is waiting on something outside this function.
    Suspended {
        /// The thread with the messages and patches produced so far applied.
        thread: Thread,
        /// The suspended call reported to the caller. The functions in this file
        /// that build this variant report only the first suspended call they find.
        suspended_call: Box<SuspendedCall>,
    },
}
51
52impl ExecuteToolsOutcome {
53 pub fn into_thread(self) -> Thread {
55 match self {
56 Self::Completed(t) | Self::Suspended { thread: t, .. } => t,
57 }
58 }
59
60 pub fn is_suspended(&self) -> bool {
62 matches!(self, Self::Suspended { .. })
63 }
64}
65
/// Summary returned after tool results have been applied to a run context.
pub(super) struct AppliedToolResults {
    /// Suspended calls collected from the results whose outcome was `Suspended`.
    pub(super) suspended_calls: Vec<SuspendedCall>,
    /// Snapshot of the state taken after applying the results; `None` when no
    /// non-empty patch was added.
    pub(super) state_snapshot: Option<Value>,
}
70
/// Borrowed per-run inputs shared by every tool call executed in a phase.
#[derive(Clone)]
pub(super) struct ToolPhaseContext<'a> {
    /// Descriptors of the available tools; copied into each call's `StepContext`.
    pub(super) tool_descriptors: &'a [ToolDescriptor],
    /// Optional behavior that receives the before/after tool-execute phases.
    pub(super) agent_behavior: Option<&'a dyn AgentBehavior>,
    /// Activity manager handed to the tool's own call context.
    pub(super) activity_manager: Arc<dyn ActivityManager>,
    /// Run configuration used as the scope of both the plugin and tool contexts.
    pub(super) run_config: &'a tirea_contract::RunConfig,
    /// Identifier of the thread the calls belong to.
    pub(super) thread_id: &'a str,
    /// Messages of the thread, exposed to the step context.
    pub(super) thread_messages: &'a [Arc<Message>],
    /// Optional cancellation token; when present it is attached to the call contexts.
    pub(super) cancellation_token: Option<&'a RunCancellationToken>,
}
81
82impl<'a> ToolPhaseContext<'a> {
83 pub(super) fn from_request(request: &'a ToolExecutionRequest<'a>) -> Self {
84 Self {
85 tool_descriptors: request.tool_descriptors,
86 agent_behavior: request.agent_behavior,
87 activity_manager: request.activity_manager.clone(),
88 run_config: request.run_config,
89 thread_id: request.thread_id,
90 thread_messages: request.thread_messages,
91 cancellation_token: request.cancellation_token,
92 }
93 }
94}
95
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch, and
/// saturates at `u64::MAX` if the millisecond count does not fit in a `u64`.
fn now_unix_millis() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => {
            let millis = elapsed.as_millis();
            if millis > u128::from(u64::MAX) {
                u64::MAX
            } else {
                millis as u64
            }
        }
        Err(_) => 0,
    }
}
101
102fn suspended_call_from_tool_result(call: &ToolCall, result: &ToolResult) -> SuspendedCall {
103 if let Some(mut explicit) = result.suspension() {
104 if explicit.pending.id.trim().is_empty() || explicit.pending.name.trim().is_empty() {
105 explicit.pending =
106 PendingToolCall::new(call.id.clone(), call.name.clone(), call.arguments.clone());
107 }
108 return SuspendedCall::new(call, explicit);
109 }
110
111 let mut suspension = Suspension::new(&call.id, format!("tool:{}", call.name))
112 .with_parameters(call.arguments.clone());
113 if let Some(message) = result.message.as_ref() {
114 suspension = suspension.with_message(message.clone());
115 }
116
117 SuspendedCall::new(
118 call,
119 SuspendTicket::new(
120 suspension,
121 PendingToolCall::new(call.id.clone(), call.name.clone(), call.arguments.clone()),
122 ToolCallResumeMode::ReplayToolCall,
123 ),
124 )
125}
126
127fn persist_tool_call_status(
128 step: &StepContext<'_>,
129 call: &ToolCall,
130 status: ToolCallStatus,
131 suspended_call: Option<&SuspendedCall>,
132) -> Result<crate::contracts::runtime::ToolCallState, AgentLoopError> {
133 let current_state = step.ctx().tool_call_state_for(&call.id).map_err(|e| {
134 AgentLoopError::StateError(format!(
135 "failed to read tool call state for '{}' before setting {:?}: {e}",
136 call.id, status
137 ))
138 })?;
139 let previous_status = current_state
140 .as_ref()
141 .map(|state| state.status)
142 .unwrap_or(ToolCallStatus::New);
143 let current_resume_token = current_state
144 .as_ref()
145 .and_then(|state| state.resume_token.clone());
146 let current_resume = current_state
147 .as_ref()
148 .and_then(|state| state.resume.clone());
149
150 let (next_resume_token, next_resume) = match status {
151 ToolCallStatus::Running => {
152 if matches!(previous_status, ToolCallStatus::Resuming) {
153 (current_resume_token.clone(), current_resume.clone())
154 } else {
155 (None, None)
156 }
157 }
158 ToolCallStatus::Suspended => (
159 suspended_call
160 .map(|entry| entry.ticket.pending.id.clone())
161 .or(current_resume_token.clone()),
162 None,
163 ),
164 ToolCallStatus::Succeeded
165 | ToolCallStatus::Failed
166 | ToolCallStatus::Cancelled
167 | ToolCallStatus::New
168 | ToolCallStatus::Resuming => (current_resume_token, current_resume),
169 };
170
171 let Some(runtime_state) = transition_tool_call_state(
172 current_state,
173 ToolCallStateSeed {
174 call_id: &call.id,
175 tool_name: &call.name,
176 arguments: &call.arguments,
177 status: ToolCallStatus::New,
178 resume_token: None,
179 },
180 ToolCallStateTransition {
181 status,
182 resume_token: next_resume_token,
183 resume: next_resume,
184 updated_at: now_unix_millis(),
185 },
186 ) else {
187 return Err(AgentLoopError::StateError(format!(
188 "invalid tool call status transition for '{}': {:?} -> {:?}",
189 call.id, previous_status, status
190 )));
191 };
192
193 step.ctx()
194 .set_tool_call_state_for(&call.id, runtime_state.clone())
195 .map_err(|e| {
196 AgentLoopError::StateError(format!(
197 "failed to persist tool call state for '{}' as {:?}: {e}",
198 call.id, status
199 ))
200 })?;
201
202 Ok(runtime_state)
203}
204
205fn map_tool_executor_error(err: AgentLoopError, thread_id: &str) -> ToolExecutorError {
206 match err {
207 AgentLoopError::Cancelled => ToolExecutorError::Cancelled {
208 thread_id: thread_id.to_string(),
209 },
210 other => ToolExecutorError::Failed {
211 message: other.to_string(),
212 },
213 }
214}
215
/// Flavour of parallel tool execution; the two modes differ in executor name and
/// in the decision replay policy they report.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelToolExecutionMode {
    /// Reported as `"parallel_batch_approval"`; replay policy `BatchAllSuspended`.
    BatchApproval,
    /// Reported as `"parallel_streaming"`; replay policy `Immediate`.
    Streaming,
}
222
/// Tool executor that runs all calls of a request concurrently.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParallelToolExecutor {
    /// Selects the executor name and decision replay policy.
    mode: ParallelToolExecutionMode,
}
228
229impl ParallelToolExecutor {
230 pub const fn batch_approval() -> Self {
231 Self {
232 mode: ParallelToolExecutionMode::BatchApproval,
233 }
234 }
235
236 pub const fn streaming() -> Self {
237 Self {
238 mode: ParallelToolExecutionMode::Streaming,
239 }
240 }
241
242 fn mode_name(self) -> &'static str {
243 match self.mode {
244 ParallelToolExecutionMode::BatchApproval => "parallel_batch_approval",
245 ParallelToolExecutionMode::Streaming => "parallel_streaming",
246 }
247 }
248}
249
/// The default parallel executor uses streaming mode.
impl Default for ParallelToolExecutor {
    fn default() -> Self {
        Self::streaming()
    }
}
255
256#[async_trait]
257impl ToolExecutor for ParallelToolExecutor {
258 async fn execute(
259 &self,
260 request: ToolExecutionRequest<'_>,
261 ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
262 let thread_id = request.thread_id;
263 let phase_ctx = ToolPhaseContext::from_request(&request);
264 execute_tools_parallel_with_phases(request.tools, request.calls, request.state, phase_ctx)
265 .await
266 .map_err(|e| map_tool_executor_error(e, thread_id))
267 }
268
269 fn name(&self) -> &'static str {
270 self.mode_name()
271 }
272
273 fn requires_parallel_patch_conflict_check(&self) -> bool {
274 true
275 }
276
277 fn decision_replay_policy(&self) -> DecisionReplayPolicy {
278 match self.mode {
279 ParallelToolExecutionMode::BatchApproval => DecisionReplayPolicy::BatchAllSuspended,
280 ParallelToolExecutionMode::Streaming => DecisionReplayPolicy::Immediate,
281 }
282 }
283}
284
/// Tool executor that runs the calls of a request one after another.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialToolExecutor;
288
289#[async_trait]
290impl ToolExecutor for SequentialToolExecutor {
291 async fn execute(
292 &self,
293 request: ToolExecutionRequest<'_>,
294 ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
295 let thread_id = request.thread_id;
296 let phase_ctx = ToolPhaseContext::from_request(&request);
297 execute_tools_sequential_with_phases(request.tools, request.calls, request.state, phase_ctx)
298 .await
299 .map_err(|e| map_tool_executor_error(e, thread_id))
300 }
301
302 fn name(&self) -> &'static str {
303 "sequential"
304 }
305}
306
/// Applies tool execution results to `run_ctx` without overriding tool message ids.
///
/// Thin wrapper over [`apply_tool_results_impl`] that passes `None` for
/// `tool_msg_ids`.
pub(super) fn apply_tool_results_to_session(
    run_ctx: &mut RunContext,
    results: &[ToolExecutionResult],
    metadata: Option<MessageMetadata>,
    check_parallel_patch_conflicts: bool,
) -> Result<AppliedToolResults, AgentLoopError> {
    apply_tool_results_impl(
        run_ctx,
        results,
        metadata,
        check_parallel_patch_conflicts,
        None,
    )
}
321
322pub(super) fn apply_tool_results_impl(
323 run_ctx: &mut RunContext,
324 results: &[ToolExecutionResult],
325 metadata: Option<MessageMetadata>,
326 check_parallel_patch_conflicts: bool,
327 tool_msg_ids: Option<&HashMap<String, String>>,
328) -> Result<AppliedToolResults, AgentLoopError> {
329 let suspended: Vec<SuspendedCall> = results
331 .iter()
332 .filter_map(|r| {
333 if matches!(r.outcome, ToolCallOutcome::Suspended) {
334 r.suspended_call.clone()
335 } else {
336 None
337 }
338 })
339 .collect();
340
341 let all_serialized_state_actions: Vec<tirea_contract::SerializedStateAction> = results
343 .iter()
344 .flat_map(|r| r.serialized_state_actions.iter().cloned())
345 .collect();
346 if !all_serialized_state_actions.is_empty() {
347 run_ctx.add_serialized_state_actions(all_serialized_state_actions);
348 }
349
350 let base_snapshot = run_ctx
351 .snapshot()
352 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
353 let patches = merge_parallel_state_patches(
354 &base_snapshot,
355 results,
356 check_parallel_patch_conflicts,
357 run_ctx.lattice_registry(),
358 )?;
359 let mut state_changed = !patches.is_empty();
360 run_ctx.add_thread_patches(patches);
361
362 let tool_messages: Vec<Arc<Message>> = results
364 .iter()
365 .flat_map(|r| {
366 let is_suspended = matches!(r.outcome, ToolCallOutcome::Suspended);
367 let mut msgs = if is_suspended {
368 vec![Message::tool(
369 &r.execution.call.id,
370 pending_approval_placeholder_message(&r.execution.call.name),
371 )]
372 } else {
373 let mut tool_msg = tool_response(&r.execution.call.id, &r.execution.result);
374 if let Some(id) = tool_msg_ids.and_then(|ids| ids.get(&r.execution.call.id)) {
375 tool_msg = tool_msg.with_id(id.clone());
376 }
377 vec![tool_msg]
378 };
379 for reminder in &r.reminders {
380 msgs.push(Message::internal_system(format!(
381 "<system-reminder>{}</system-reminder>",
382 reminder
383 )));
384 }
385 if let Some(ref meta) = metadata {
386 for msg in &mut msgs {
387 msg.metadata = Some(meta.clone());
388 }
389 }
390 msgs.into_iter().map(Arc::new).collect::<Vec<_>>()
391 })
392 .collect();
393
394 run_ctx.add_messages(tool_messages);
395
396 let user_messages: Vec<Arc<Message>> = results
398 .iter()
399 .flat_map(|r| {
400 r.user_messages
401 .iter()
402 .map(|s| s.trim())
403 .filter(|s| !s.is_empty())
404 .map(|text| {
405 let mut msg = Message::user(text.to_string());
406 if let Some(ref meta) = metadata {
407 msg.metadata = Some(meta.clone());
408 }
409 Arc::new(msg)
410 })
411 .collect::<Vec<_>>()
412 })
413 .collect();
414 if !user_messages.is_empty() {
415 run_ctx.add_messages(user_messages);
416 }
417 if !suspended.is_empty() {
418 let state = run_ctx
419 .snapshot()
420 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
421 let actions: Vec<AnyStateAction> = suspended
422 .iter()
423 .map(|call| call.clone().into_state_action())
424 .collect();
425 let patches = reduce_state_actions(actions, &state, "agent_loop", &ScopeContext::run())
426 .map_err(|e| {
427 AgentLoopError::StateError(format!("failed to reduce suspended call actions: {e}"))
428 })?;
429 for patch in patches {
430 if !patch.patch().is_empty() {
431 state_changed = true;
432 run_ctx.add_thread_patch(patch);
433 }
434 }
435 let state_snapshot = if state_changed {
436 Some(
437 run_ctx
438 .snapshot()
439 .map_err(|e| AgentLoopError::StateError(e.to_string()))?,
440 )
441 } else {
442 None
443 };
444 return Ok(AppliedToolResults {
445 suspended_calls: suspended,
446 state_snapshot,
447 });
448 }
449
450 let state_snapshot = if state_changed {
457 Some(
458 run_ctx
459 .snapshot()
460 .map_err(|e| AgentLoopError::StateError(e.to_string()))?,
461 )
462 } else {
463 None
464 };
465
466 Ok(AppliedToolResults {
467 suspended_calls: Vec::new(),
468 state_snapshot,
469 })
470}
471
472fn tool_result_metadata_from_run_ctx(
473 run_ctx: &RunContext,
474 run_id: Option<&str>,
475) -> Option<MessageMetadata> {
476 let run_id = run_id.map(|id| id.to_string()).or_else(|| {
477 run_ctx.messages().iter().rev().find_map(|m| {
478 m.metadata
479 .as_ref()
480 .and_then(|meta| meta.run_id.as_ref().cloned())
481 })
482 });
483
484 let step_index = run_ctx
485 .messages()
486 .iter()
487 .rev()
488 .find_map(|m| m.metadata.as_ref().and_then(|meta| meta.step_index));
489
490 if run_id.is_none() && step_index.is_none() {
491 None
492 } else {
493 Some(MessageMetadata { run_id, step_index })
494 }
495}
496
497#[allow(dead_code)]
498pub(super) fn next_step_index(run_ctx: &RunContext) -> u32 {
499 run_ctx
500 .messages()
501 .iter()
502 .filter_map(|m| m.metadata.as_ref().and_then(|meta| meta.step_index))
503 .max()
504 .map(|v| v.saturating_add(1))
505 .unwrap_or(0)
506}
507
/// Builds message metadata for the given run id and step index.
///
/// The step index is always set (`Some(step_index)`); the run id is passed through
/// unchanged.
pub(super) fn step_metadata(run_id: Option<String>, step_index: u32) -> MessageMetadata {
    MessageMetadata {
        run_id,
        step_index: Some(step_index),
    }
}
514
515pub async fn execute_tools(
519 thread: Thread,
520 result: &StreamResult,
521 tools: &HashMap<String, Arc<dyn Tool>>,
522 parallel: bool,
523) -> Result<ExecuteToolsOutcome, AgentLoopError> {
524 let parallel_executor = ParallelToolExecutor::streaming();
525 let sequential_executor = SequentialToolExecutor;
526 let executor: &dyn ToolExecutor = if parallel {
527 ¶llel_executor
528 } else {
529 &sequential_executor
530 };
531 execute_tools_with_agent_and_executor(thread, result, tools, executor, None).await
532}
533
534pub async fn execute_tools_with_config(
536 thread: Thread,
537 result: &StreamResult,
538 tools: &HashMap<String, Arc<dyn Tool>>,
539 agent: &dyn Agent,
540) -> Result<ExecuteToolsOutcome, AgentLoopError> {
541 execute_tools_with_agent_and_executor(
542 thread,
543 result,
544 tools,
545 agent.tool_executor().as_ref(),
546 Some(agent.behavior()),
547 )
548 .await
549}
550
551pub(super) fn scope_with_tool_caller_context(
552 run_ctx: &RunContext,
553 state: &Value,
554) -> Result<tirea_contract::RunConfig, AgentLoopError> {
555 let mut rt = run_ctx.run_config.clone();
556 if rt.value(TOOL_SCOPE_CALLER_THREAD_ID_KEY).is_none() {
557 rt.set(
558 TOOL_SCOPE_CALLER_THREAD_ID_KEY,
559 run_ctx.thread_id().to_string(),
560 )
561 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
562 }
563 if rt.value(TOOL_SCOPE_CALLER_STATE_KEY).is_none() {
564 rt.set(TOOL_SCOPE_CALLER_STATE_KEY, state.clone())
565 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
566 }
567 if rt.value(TOOL_SCOPE_CALLER_MESSAGES_KEY).is_none() {
568 rt.set(TOOL_SCOPE_CALLER_MESSAGES_KEY, run_ctx.messages().to_vec())
569 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
570 }
571 Ok(rt)
572}
573
574pub async fn execute_tools_with_behaviors(
576 thread: Thread,
577 result: &StreamResult,
578 tools: &HashMap<String, Arc<dyn Tool>>,
579 parallel: bool,
580 behavior: Arc<dyn AgentBehavior>,
581) -> Result<ExecuteToolsOutcome, AgentLoopError> {
582 let executor: Arc<dyn ToolExecutor> = if parallel {
583 Arc::new(ParallelToolExecutor::streaming())
584 } else {
585 Arc::new(SequentialToolExecutor)
586 };
587 let agent = BaseAgent::default()
588 .with_behavior(behavior)
589 .with_tool_executor(executor);
590 execute_tools_with_config(thread, result, tools, &agent).await
591}
592
593async fn execute_tools_with_agent_and_executor(
594 thread: Thread,
595 result: &StreamResult,
596 tools: &HashMap<String, Arc<dyn Tool>>,
597 executor: &dyn ToolExecutor,
598 behavior: Option<&dyn AgentBehavior>,
599) -> Result<ExecuteToolsOutcome, AgentLoopError> {
600 let rebuilt_state = thread
602 .rebuild_state()
603 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
604 let mut run_ctx = RunContext::new(
605 &thread.id,
606 rebuilt_state.clone(),
607 thread.messages.clone(),
608 tirea_contract::RunConfig::default(),
609 );
610
611 let tool_descriptors: Vec<ToolDescriptor> =
612 tools.values().map(|t| t.descriptor().clone()).collect();
613 if let Some(behavior) = behavior {
615 let run_start_patches = super::plugin_runtime::behavior_run_phase_block(
616 &run_ctx,
617 &tool_descriptors,
618 behavior,
619 &[Phase::RunStart],
620 |_| {},
621 |_| (),
622 )
623 .await?
624 .1;
625 if !run_start_patches.is_empty() {
626 run_ctx.add_thread_patches(run_start_patches);
627 }
628 }
629
630 let replay_executor: Arc<dyn ToolExecutor> = match executor.decision_replay_policy() {
631 DecisionReplayPolicy::BatchAllSuspended => Arc::new(ParallelToolExecutor::batch_approval()),
632 DecisionReplayPolicy::Immediate => Arc::new(ParallelToolExecutor::streaming()),
633 };
634 let replay_config = BaseAgent::default().with_tool_executor(replay_executor);
635 let replay = super::drain_resuming_tool_calls_and_replay(
636 &mut run_ctx,
637 tools,
638 &replay_config,
639 &tool_descriptors,
640 )
641 .await?;
642
643 if replay.replayed {
644 let suspended = run_ctx.suspended_calls().values().next().cloned();
645 let delta = run_ctx.take_delta();
646 let mut out_thread = thread;
647 for msg in delta.messages {
648 out_thread = out_thread.with_message((*msg).clone());
649 }
650 out_thread = out_thread.with_patches(delta.patches);
651 return if let Some(first) = suspended {
652 Ok(ExecuteToolsOutcome::Suspended {
653 thread: out_thread,
654 suspended_call: Box::new(first),
655 })
656 } else {
657 Ok(ExecuteToolsOutcome::Completed(out_thread))
658 };
659 }
660
661 if result.tool_calls.is_empty() {
662 let delta = run_ctx.take_delta();
663 let mut out_thread = thread;
664 for msg in delta.messages {
665 out_thread = out_thread.with_message((*msg).clone());
666 }
667 out_thread = out_thread.with_patches(delta.patches);
668 return Ok(ExecuteToolsOutcome::Completed(out_thread));
669 }
670
671 let current_state = run_ctx
672 .snapshot()
673 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
674 let rt_for_tools = scope_with_tool_caller_context(&run_ctx, ¤t_state)?;
675 let results = executor
676 .execute(ToolExecutionRequest {
677 tools,
678 calls: &result.tool_calls,
679 state: ¤t_state,
680 tool_descriptors: &tool_descriptors,
681 agent_behavior: behavior,
682 activity_manager: tirea_contract::runtime::activity::NoOpActivityManager::arc(),
683 run_config: &rt_for_tools,
684 thread_id: run_ctx.thread_id(),
685 thread_messages: run_ctx.messages(),
686 state_version: run_ctx.version(),
687 cancellation_token: None,
688 })
689 .await?;
690
691 let metadata = tool_result_metadata_from_run_ctx(&run_ctx, None);
692 let applied = apply_tool_results_to_session(
693 &mut run_ctx,
694 &results,
695 metadata,
696 executor.requires_parallel_patch_conflict_check(),
697 )?;
698 let suspended = applied.suspended_calls.into_iter().next();
699
700 let delta = run_ctx.take_delta();
702 let mut out_thread = thread;
703 for msg in delta.messages {
704 out_thread = out_thread.with_message((*msg).clone());
705 }
706 out_thread = out_thread.with_patches(delta.patches);
707
708 if let Some(first) = suspended {
709 Ok(ExecuteToolsOutcome::Suspended {
710 thread: out_thread,
711 suspended_call: Box::new(first),
712 })
713 } else {
714 Ok(ExecuteToolsOutcome::Completed(out_thread))
715 }
716}
717
718pub(super) async fn execute_tools_parallel_with_phases(
720 tools: &HashMap<String, Arc<dyn Tool>>,
721 calls: &[crate::contracts::thread::ToolCall],
722 state: &Value,
723 phase_ctx: ToolPhaseContext<'_>,
724) -> Result<Vec<ToolExecutionResult>, AgentLoopError> {
725 use futures::future::join_all;
726
727 if is_cancelled(phase_ctx.cancellation_token) {
728 return Err(cancelled_error(phase_ctx.thread_id));
729 }
730
731 let run_config_owned = phase_ctx.run_config.clone();
733 let thread_id = phase_ctx.thread_id.to_string();
734 let thread_messages = Arc::new(phase_ctx.thread_messages.to_vec());
735 let tool_descriptors = phase_ctx.tool_descriptors.to_vec();
736 let agent = phase_ctx.agent_behavior;
737
738 let futures = calls.iter().map(|call| {
739 let tool = tools.get(&call.name).cloned();
740 let state = state.clone();
741 let call = call.clone();
742 let tool_descriptors = tool_descriptors.clone();
743 let activity_manager = phase_ctx.activity_manager.clone();
744 let rt = run_config_owned.clone();
745 let sid = thread_id.clone();
746 let thread_messages = thread_messages.clone();
747
748 async move {
749 execute_single_tool_with_phases_impl(
750 tool.as_deref(),
751 &call,
752 &state,
753 &ToolPhaseContext {
754 tool_descriptors: &tool_descriptors,
755 agent_behavior: agent,
756 activity_manager,
757 run_config: &rt,
758 thread_id: &sid,
759 thread_messages: thread_messages.as_slice(),
760 cancellation_token: None,
761 },
762 )
763 .await
764 }
765 });
766
767 let join_future = join_all(futures);
768 let results = match await_or_cancel(phase_ctx.cancellation_token, join_future).await {
769 CancelAware::Cancelled => return Err(cancelled_error(&thread_id)),
770 CancelAware::Value(results) => results,
771 };
772 let results: Vec<ToolExecutionResult> = results.into_iter().collect::<Result<_, _>>()?;
773 Ok(results)
774}
775
776pub(super) async fn execute_tools_sequential_with_phases(
778 tools: &HashMap<String, Arc<dyn Tool>>,
779 calls: &[crate::contracts::thread::ToolCall],
780 initial_state: &Value,
781 phase_ctx: ToolPhaseContext<'_>,
782) -> Result<Vec<ToolExecutionResult>, AgentLoopError> {
783 use tirea_state::apply_patch;
784
785 if is_cancelled(phase_ctx.cancellation_token) {
786 return Err(cancelled_error(phase_ctx.thread_id));
787 }
788
789 let mut state = initial_state.clone();
790 let mut results = Vec::with_capacity(calls.len());
791
792 for call in calls {
793 let tool = tools.get(&call.name).cloned();
794 let call_phase_ctx = ToolPhaseContext {
795 tool_descriptors: phase_ctx.tool_descriptors,
796 agent_behavior: phase_ctx.agent_behavior,
797 activity_manager: phase_ctx.activity_manager.clone(),
798 run_config: phase_ctx.run_config,
799 thread_id: phase_ctx.thread_id,
800 thread_messages: phase_ctx.thread_messages,
801 cancellation_token: None,
802 };
803 let result = match await_or_cancel(
804 phase_ctx.cancellation_token,
805 execute_single_tool_with_phases_impl(tool.as_deref(), call, &state, &call_phase_ctx),
806 )
807 .await
808 {
809 CancelAware::Cancelled => return Err(cancelled_error(phase_ctx.thread_id)),
810 CancelAware::Value(result) => result?,
811 };
812
813 if let Some(ref patch) = result.execution.patch {
815 state = apply_patch(&state, patch.patch()).map_err(|e| {
816 AgentLoopError::StateError(format!(
817 "failed to apply tool patch for call '{}': {}",
818 result.execution.call.id, e
819 ))
820 })?;
821 }
822 for pp in &result.pending_patches {
824 state = apply_patch(&state, pp.patch()).map_err(|e| {
825 AgentLoopError::StateError(format!(
826 "failed to apply plugin patch for call '{}': {}",
827 result.execution.call.id, e
828 ))
829 })?;
830 }
831
832 results.push(result);
833
834 if results
835 .last()
836 .is_some_and(|r| matches!(r.outcome, ToolCallOutcome::Suspended))
837 {
838 break;
839 }
840 }
841
842 Ok(results)
843}
844
/// Test-only entry point that runs a single tool call through its phases by
/// delegating to [`execute_single_tool_with_phases_impl`].
#[cfg(test)]
pub(super) async fn execute_single_tool_with_phases(
    tool: Option<&dyn Tool>,
    call: &crate::contracts::thread::ToolCall,
    state: &Value,
    phase_ctx: &ToolPhaseContext<'_>,
) -> Result<ToolExecutionResult, AgentLoopError> {
    execute_single_tool_with_phases_impl(tool, call, state, phase_ctx).await
}
855
/// Runs a single tool call through its phases by delegating to
/// [`execute_single_tool_with_phases_impl`].
///
/// NOTE(review): the body is a plain delegation; what "deferred" refers to is not
/// visible in this function — presumably it describes how callers handle the
/// returned patches. Confirm against the call sites.
pub(super) async fn execute_single_tool_with_phases_deferred(
    tool: Option<&dyn Tool>,
    call: &crate::contracts::thread::ToolCall,
    state: &Value,
    phase_ctx: &ToolPhaseContext<'_>,
) -> Result<ToolExecutionResult, AgentLoopError> {
    execute_single_tool_with_phases_impl(tool, call, state, phase_ctx).await
}
864
/// Runs one tool call through the `BeforeToolExecute` phase, the tool itself
/// (unless the call was blocked, answered, or suspended beforehand), and the
/// `AfterToolExecute` phase, then packages everything into a
/// [`ToolExecutionResult`].
///
/// `tool` is `None` when no tool is registered under `call.name`; that produces a
/// failed result rather than an error. `state` is the base document that both the
/// plugin context and the tool context start from.
///
/// # Errors
/// Returns [`AgentLoopError`] when emitting a phase fails, when the gate is pending
/// without a suspend ticket, when a non-state action fails validation, when the
/// tool call status cannot be persisted, or when reducing or applying the
/// resulting patches fails.
async fn execute_single_tool_with_phases_impl(
    tool: Option<&dyn Tool>,
    call: &crate::contracts::thread::ToolCall,
    state: &Value,
    phase_ctx: &ToolPhaseContext<'_>,
) -> Result<ToolExecutionResult, AgentLoopError> {
    // Context used by the plugin phases; it works on its own copy of the state.
    let doc = tirea_state::DocCell::new(state.clone());
    let ops = std::sync::Mutex::new(Vec::new());
    let pending_messages = std::sync::Mutex::new(Vec::new());
    let plugin_scope = phase_ctx.run_config;
    let mut plugin_tool_call_ctx = crate::contracts::ToolCallContext::new(
        &doc,
        &ops,
        "plugin_phase",
        "plugin:tool_phase",
        plugin_scope,
        &pending_messages,
        tirea_contract::runtime::activity::NoOpActivityManager::arc(),
    );
    if let Some(token) = phase_ctx.cancellation_token {
        plugin_tool_call_ctx = plugin_tool_call_ctx.with_cancellation_token(token);
    }

    let mut step = StepContext::new(
        plugin_tool_call_ctx,
        phase_ctx.thread_id,
        phase_ctx.thread_messages,
        phase_ctx.tool_descriptors.to_vec(),
    );
    // The gate carries the call through the phases; the branches below read it to
    // see whether the call was blocked, pre-answered, or suspended.
    step.gate = Some(ToolGate::from_tool_call(call));
    emit_tool_phase(
        Phase::BeforeToolExecute,
        &mut step,
        phase_ctx.agent_behavior,
        &doc,
    )
    .await?;

    // Decide how the call resolves. Precedence: blocked, then a result already set
    // on the step, then missing tool, invalid arguments, pending suspension, and
    // finally actual execution.
    let (mut execution, outcome, suspended_call, tool_actions) = if step.tool_blocked() {
        let reason = step
            .gate
            .as_ref()
            .and_then(|g| g.block_reason.clone())
            .unwrap_or_else(|| "Blocked by plugin".to_string());
        (
            ToolExecution {
                call: call.clone(),
                result: ToolResult::error(&call.name, reason),
                patch: None,
            },
            ToolCallOutcome::Failed,
            None,
            Vec::<Box<dyn Action>>::new(),
        )
    } else if let Some(plugin_result) = step.tool_result().cloned() {
        // A result is already present on the step, so the tool is not run.
        let outcome = ToolCallOutcome::from_tool_result(&plugin_result);
        (
            ToolExecution {
                call: call.clone(),
                result: plugin_result,
                patch: None,
            },
            outcome,
            None,
            Vec::<Box<dyn Action>>::new(),
        )
    } else {
        match tool {
            None => (
                ToolExecution {
                    call: call.clone(),
                    result: ToolResult::error(
                        &call.name,
                        format!("Tool '{}' not found", call.name),
                    ),
                    patch: None,
                },
                ToolCallOutcome::Failed,
                None,
                Vec::<Box<dyn Action>>::new(),
            ),
            Some(tool) => {
                if let Err(e) = tool.validate_args(&call.arguments) {
                    (
                        ToolExecution {
                            call: call.clone(),
                            result: ToolResult::error(&call.name, e.to_string()),
                            patch: None,
                        },
                        ToolCallOutcome::Failed,
                        None,
                        Vec::<Box<dyn Action>>::new(),
                    )
                } else if step.tool_pending() {
                    // Suspended before execution: a ticket must be on the gate.
                    let Some(suspend_ticket) =
                        step.gate.as_ref().and_then(|g| g.suspend_ticket.clone())
                    else {
                        return Err(AgentLoopError::StateError(
                            "tool is pending but suspend ticket is missing".to_string(),
                        ));
                    };
                    (
                        ToolExecution {
                            call: call.clone(),
                            result: ToolResult::suspended(
                                &call.name,
                                "Execution suspended; awaiting external decision",
                            ),
                            patch: None,
                        },
                        ToolCallOutcome::Suspended,
                        Some(SuspendedCall::new(call, suspend_ticket)),
                        Vec::<Box<dyn Action>>::new(),
                    )
                } else {
                    // Mark the call as running before handing control to the tool.
                    persist_tool_call_status(&step, call, ToolCallStatus::Running, None)?;
                    // The tool gets a separate context over its own copy of the state.
                    let tool_doc = tirea_state::DocCell::new(state.clone());
                    let tool_ops = std::sync::Mutex::new(Vec::new());
                    let tool_pending_msgs = std::sync::Mutex::new(Vec::new());
                    let mut tool_ctx = crate::contracts::ToolCallContext::new(
                        &tool_doc,
                        &tool_ops,
                        &call.id,
                        format!("tool:{}", call.name),
                        plugin_scope,
                        &tool_pending_msgs,
                        phase_ctx.activity_manager.clone(),
                    );
                    if let Some(token) = phase_ctx.cancellation_token {
                        tool_ctx = tool_ctx.with_cancellation_token(token);
                    }
                    // A tool error becomes an error result instead of aborting.
                    let mut effect =
                        match tool.execute_effect(call.arguments.clone(), &tool_ctx).await {
                            Ok(effect) => effect,
                            Err(e) => ToolExecutionEffect::from(ToolResult::error(
                                &call.name,
                                e.to_string(),
                            )),
                        };

                    // Fold the patch recorded on the tool context into the effect;
                    // if that is rejected, the rejection result replaces the effect.
                    let context_patch = tool_ctx.take_patch();
                    if let Err(result) =
                        merge_context_patch_into_effect(call, &mut effect, context_patch)
                    {
                        effect = ToolExecutionEffect::from(*result);
                    }
                    let (result, actions) = effect.into_parts();
                    let outcome = ToolCallOutcome::from_tool_result(&result);

                    let suspended_call = if matches!(outcome, ToolCallOutcome::Suspended) {
                        Some(suspended_call_from_tool_result(call, &result))
                    } else {
                        None
                    };

                    (
                        ToolExecution {
                            call: call.clone(),
                            result,
                            patch: None,
                        },
                        outcome,
                        suspended_call,
                        actions,
                    )
                }
            }
        }
    };

    // Expose the result on the gate so the after-phase can observe it.
    if let Some(gate) = step.gate.as_mut() {
        gate.result = Some(execution.result.clone());
    }

    // Split the tool's actions: state actions are reduced into the execution patch
    // later; the rest are validated and applied to the step now.
    let mut tool_state_actions = Vec::<AnyStateAction>::new();
    let mut other_actions = Vec::<Box<dyn Action>>::new();
    for action in tool_actions {
        if action.is_state_action() {
            if let Some(sa) = action.into_state_action() {
                tool_state_actions.push(sa);
            }
        } else {
            other_actions.push(action);
        }
    }
    // Validate every action before applying any of them.
    for action in &other_actions {
        action
            .validate(Phase::AfterToolExecute)
            .map_err(AgentLoopError::StateError)?;
    }
    for action in other_actions {
        action.apply(&mut step);
    }

    emit_tool_phase(
        Phase::AfterToolExecute,
        &mut step,
        phase_ctx.agent_behavior,
        &doc,
    )
    .await?;

    // Persist the status matching the outcome; every arm yields `Some`.
    let terminal_tool_call_state = match outcome {
        ToolCallOutcome::Suspended => Some(persist_tool_call_status(
            &step,
            call,
            ToolCallStatus::Suspended,
            suspended_call.as_ref(),
        )?),
        ToolCallOutcome::Succeeded => Some(persist_tool_call_status(
            &step,
            call,
            ToolCallStatus::Succeeded,
            None,
        )?),
        ToolCallOutcome::Failed => Some(persist_tool_call_status(
            &step,
            call,
            ToolCallStatus::Failed,
            None,
        )?),
    };

    if let Some(tool_call_state) = terminal_tool_call_state {
        tool_state_actions.push(tool_call_state.into_state_action());
    }

    // A call that is not suspended no longer needs its suspended-call entry in the
    // per-call scope, so a delete for that path is emitted.
    if !matches!(outcome, ToolCallOutcome::Suspended) {
        let cleanup_path = format!("__tool_call_scope.{}.suspended_call", call.id);
        let cleanup_patch = Patch::with_ops(vec![tirea_state::Op::delete(
            tirea_state::parse_path(&cleanup_path),
        )]);
        let tracked = TrackedPatch::new(cleanup_patch).with_source("framework:scope_cleanup");
        step.emit_patch(tracked);
    }

    // Serialize the state actions before they are consumed by the reducer.
    let mut serialized_state_actions: Vec<tirea_contract::SerializedStateAction> =
        tool_state_actions
            .iter()
            .map(|a| a.to_serialized_state_action())
            .collect();

    let tool_scope_ctx = ScopeContext::for_call(&call.id);
    let execution_patch_parts = reduce_tool_state_actions(
        state,
        tool_state_actions,
        &format!("tool:{}", call.name),
        &tool_scope_ctx,
    )?;
    execution.patch = merge_tracked_patches(&execution_patch_parts, &format!("tool:{}", call.name));

    // Pending patches from the step are checked against the state as it looks
    // after the tool's own patch has been applied.
    let phase_base_state = if let Some(tool_patch) = execution.patch.as_ref() {
        tirea_state::apply_patch(state, tool_patch.patch()).map_err(|e| {
            AgentLoopError::StateError(format!(
                "failed to apply tool patch for call '{}': {}",
                call.id, e
            ))
        })?
    } else {
        state.clone()
    };
    let pending_patches = apply_tracked_patches_checked(
        &phase_base_state,
        std::mem::take(&mut step.pending_patches),
        &call.id,
    )?;

    let reminders = step.messaging.reminders.clone();
    let user_messages = std::mem::take(&mut step.messaging.user_messages);

    // Actions serialized by the step are appended after the tool's own.
    serialized_state_actions.extend(step.take_pending_serialized_state_actions());

    Ok(ToolExecutionResult {
        execution,
        outcome,
        suspended_call,
        reminders,
        user_messages,
        pending_patches,
        serialized_state_actions,
    })
}
1161
/// Reduces `actions` against `base_state` into tracked patches attributed to
/// `source`, within the given scope.
///
/// # Errors
/// Wraps any reducer failure in [`AgentLoopError::StateError`].
fn reduce_tool_state_actions(
    base_state: &Value,
    actions: Vec<AnyStateAction>,
    source: &str,
    scope_ctx: &ScopeContext,
) -> Result<Vec<TrackedPatch>, AgentLoopError> {
    reduce_state_actions(actions, base_state, source, scope_ctx).map_err(|e| {
        AgentLoopError::StateError(format!("failed to reduce tool state actions: {e}"))
    })
}
1172
1173fn merge_tracked_patches(patches: &[TrackedPatch], source: &str) -> Option<TrackedPatch> {
1174 let mut merged = Patch::new();
1175 for tracked in patches {
1176 merged.extend(tracked.patch().clone());
1177 }
1178 if merged.is_empty() {
1179 None
1180 } else {
1181 Some(TrackedPatch::new(merged).with_source(source.to_string()))
1182 }
1183}
1184
1185fn apply_tracked_patches_checked(
1186 base_state: &Value,
1187 patches: Vec<TrackedPatch>,
1188 call_id: &str,
1189) -> Result<Vec<TrackedPatch>, AgentLoopError> {
1190 let mut rolling = base_state.clone();
1191 let mut validated = Vec::with_capacity(patches.len());
1192 for tracked in patches {
1193 if tracked.patch().is_empty() {
1194 continue;
1195 }
1196 rolling = apply_patch(&rolling, tracked.patch()).map_err(|e| {
1197 AgentLoopError::StateError(format!(
1198 "failed to apply pending state patch for call '{}': {}",
1199 call_id, e
1200 ))
1201 })?;
1202 validated.push(tracked);
1203 }
1204 Ok(validated)
1205}
1206
/// Builds the error reported when a run is cancelled.
///
/// The thread id is accepted but currently unused: `AgentLoopError::Cancelled`
/// carries no payload here.
fn cancelled_error(_thread_id: &str) -> AgentLoopError {
    AgentLoopError::Cancelled
}
1210
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tirea_state::Op;

    /// Two independent `set` patches both apply and are both kept.
    #[test]
    fn apply_tracked_patches_checked_keeps_valid_sequence() {
        let first =
            TrackedPatch::new(Patch::new().with_op(Op::set(tirea_state::path!("alpha"), json!(1))))
                .with_source("test:first");
        let second =
            TrackedPatch::new(Patch::new().with_op(Op::set(tirea_state::path!("beta"), json!(2))))
                .with_source("test:second");

        let kept = apply_tracked_patches_checked(&json!({}), vec![first, second], "call_1")
            .expect("patches valid");

        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].patch().ops().len(), 1);
        assert_eq!(kept[1].patch().ops().len(), 1);
    }

    /// Incrementing a path that does not exist is reported with the call id.
    #[test]
    fn apply_tracked_patches_checked_reports_invalid_sequence() {
        let broken = TrackedPatch::new(
            Patch::new().with_op(Op::increment(tirea_state::path!("counter"), 1_i64)),
        )
        .with_source("test:broken");

        let outcome = apply_tracked_patches_checked(&json!({}), vec![broken], "call_1");
        let error = outcome.expect_err("increment against missing path should fail");

        assert!(matches!(error, AgentLoopError::StateError(message)
            if message.contains("failed to apply pending state patch for call 'call_1'")));
    }
}
1247}