1use super::core::{
2 drain_agent_append_user_messages, set_agent_suspended_calls, transition_tool_call_state,
3 ToolCallStateSeed, ToolCallStateTransition,
4};
5use super::plugin_runtime::emit_phase_checked;
6use super::{
7 AgentConfig, AgentLoopError, RunCancellationToken, TOOL_SCOPE_CALLER_MESSAGES_KEY,
8 TOOL_SCOPE_CALLER_STATE_KEY, TOOL_SCOPE_CALLER_THREAD_ID_KEY,
9};
10use crate::contracts::runtime::plugin::phase::{Phase, StateEffect, StepContext, ToolContext};
11use crate::contracts::runtime::plugin::AgentPlugin;
12use crate::contracts::runtime::{
13 ActivityManager, PendingToolCall, SuspendTicket, SuspendedCall, ToolCallResumeMode,
14};
15use crate::contracts::runtime::{
16 DecisionReplayPolicy, StreamResult, ToolCallOutcome, ToolCallStatus, ToolExecution,
17 ToolExecutionRequest, ToolExecutionResult, ToolExecutor, ToolExecutorError,
18};
19use crate::contracts::thread::Thread;
20use crate::contracts::thread::{Message, MessageMetadata, ToolCall};
21use crate::contracts::runtime::tool_call::{Tool, ToolDescriptor, ToolResult};
22use crate::contracts::{RunContext, Suspension};
23use crate::engine::convert::tool_response;
24use crate::engine::tool_execution::collect_patches;
25use crate::runtime::run_context::{await_or_cancel, is_cancelled, CancelAware};
26use async_trait::async_trait;
27use serde_json::Value;
28use std::collections::HashMap;
29use std::sync::Arc;
30use tirea_state::{Patch, PatchExt, TrackedPatch};
31
32#[derive(Debug)]
37pub enum ExecuteToolsOutcome {
38 Completed(Thread),
40 Suspended {
42 thread: Thread,
43 suspended_call: SuspendedCall,
44 },
45}
46
47impl ExecuteToolsOutcome {
48 pub fn into_thread(self) -> Thread {
50 match self {
51 Self::Completed(t) | Self::Suspended { thread: t, .. } => t,
52 }
53 }
54
55 pub fn is_suspended(&self) -> bool {
57 matches!(self, Self::Suspended { .. })
58 }
59}
60
61pub(super) struct AppliedToolResults {
62 pub(super) suspended_calls: Vec<SuspendedCall>,
63 pub(super) state_snapshot: Option<Value>,
64}
65
66#[derive(Clone)]
67pub(super) struct ToolPhaseContext<'a> {
68 pub(super) tool_descriptors: &'a [ToolDescriptor],
69 pub(super) plugins: &'a [Arc<dyn AgentPlugin>],
70 pub(super) activity_manager: Arc<dyn ActivityManager>,
71 pub(super) run_config: &'a tirea_contract::RunConfig,
72 pub(super) thread_id: &'a str,
73 pub(super) thread_messages: &'a [Arc<Message>],
74 pub(super) cancellation_token: Option<&'a RunCancellationToken>,
75}
76
77impl<'a> ToolPhaseContext<'a> {
78 pub(super) fn from_request(request: &'a ToolExecutionRequest<'a>) -> Self {
79 Self {
80 tool_descriptors: request.tool_descriptors,
81 plugins: request.plugins,
82 activity_manager: request.activity_manager.clone(),
83 run_config: request.run_config,
84 thread_id: request.thread_id,
85 thread_messages: request.thread_messages,
86 cancellation_token: request.cancellation_token,
87 }
88 }
89}
90
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock is set before the epoch, and saturates at
/// `u64::MAX` rather than truncating.
fn now_unix_millis() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX),
        Err(_) => 0,
    }
}
96
97fn suspended_call_from_tool_result(call: &ToolCall, result: &ToolResult) -> SuspendedCall {
98 if let Some(mut explicit) = result.suspension() {
99 if explicit.pending.id.trim().is_empty() || explicit.pending.name.trim().is_empty() {
100 explicit.pending =
101 PendingToolCall::new(call.id.clone(), call.name.clone(), call.arguments.clone());
102 }
103 return SuspendedCall::new(call, explicit);
104 }
105
106 let mut suspension = Suspension::new(&call.id, format!("tool:{}", call.name))
107 .with_parameters(call.arguments.clone());
108 if let Some(message) = result.message.as_ref() {
109 suspension = suspension.with_message(message.clone());
110 }
111
112 SuspendedCall::new(
113 call,
114 SuspendTicket::new(
115 suspension,
116 PendingToolCall::new(call.id.clone(), call.name.clone(), call.arguments.clone()),
117 ToolCallResumeMode::ReplayToolCall,
118 ),
119 )
120}
121
122fn persist_tool_call_status(
123 step: &StepContext<'_>,
124 call: &ToolCall,
125 status: ToolCallStatus,
126 suspended_call: Option<&SuspendedCall>,
127) -> Result<(), AgentLoopError> {
128 let current_state = step.ctx().tool_call_state_for(&call.id).map_err(|e| {
129 AgentLoopError::StateError(format!(
130 "failed to read tool call state for '{}' before setting {:?}: {e}",
131 call.id, status
132 ))
133 })?;
134 let previous_status = current_state
135 .as_ref()
136 .map(|state| state.status)
137 .unwrap_or(ToolCallStatus::New);
138 let current_resume_token = current_state
139 .as_ref()
140 .and_then(|state| state.resume_token.clone());
141 let current_resume = current_state
142 .as_ref()
143 .and_then(|state| state.resume.clone());
144
145 let (next_resume_token, next_resume) = match status {
146 ToolCallStatus::Running => {
147 if matches!(previous_status, ToolCallStatus::Resuming) {
148 (current_resume_token.clone(), current_resume.clone())
149 } else {
150 (None, None)
151 }
152 }
153 ToolCallStatus::Suspended => (
154 suspended_call
155 .map(|entry| entry.ticket.pending.id.clone())
156 .or(current_resume_token.clone()),
157 None,
158 ),
159 ToolCallStatus::Succeeded
160 | ToolCallStatus::Failed
161 | ToolCallStatus::Cancelled
162 | ToolCallStatus::New
163 | ToolCallStatus::Resuming => (current_resume_token, current_resume),
164 };
165
166 let Some(runtime_state) = transition_tool_call_state(
167 current_state,
168 ToolCallStateSeed {
169 call_id: &call.id,
170 tool_name: &call.name,
171 arguments: &call.arguments,
172 status: ToolCallStatus::New,
173 resume_token: None,
174 },
175 ToolCallStateTransition {
176 status,
177 resume_token: next_resume_token,
178 resume: next_resume,
179 updated_at: now_unix_millis(),
180 },
181 ) else {
182 return Err(AgentLoopError::StateError(format!(
183 "invalid tool call status transition for '{}': {:?} -> {:?}",
184 call.id, previous_status, status
185 )));
186 };
187
188 step.ctx()
189 .set_tool_call_state_for(&call.id, runtime_state)
190 .map_err(|e| {
191 AgentLoopError::StateError(format!(
192 "failed to persist tool call state for '{}' as {:?}: {e}",
193 call.id, status
194 ))
195 })
196}
197
198fn map_tool_executor_error(err: AgentLoopError, thread_id: &str) -> ToolExecutorError {
199 match err {
200 AgentLoopError::Cancelled => ToolExecutorError::Cancelled {
201 thread_id: thread_id.to_string(),
202 },
203 other => ToolExecutorError::Failed {
204 message: other.to_string(),
205 },
206 }
207}
208
/// Scheduling flavour of a [`ParallelToolExecutor`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParallelToolExecutionMode {
    /// Paired with the `BatchAllSuspended` decision-replay policy.
    BatchApproval,
    /// Paired with the `Immediate` decision-replay policy.
    Streaming,
}

/// Tool executor that runs all calls of a batch concurrently.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ParallelToolExecutor {
    mode: ParallelToolExecutionMode,
}

impl ParallelToolExecutor {
    /// Executor in [`ParallelToolExecutionMode::BatchApproval`] mode.
    pub const fn batch_approval() -> Self {
        Self::with_mode(ParallelToolExecutionMode::BatchApproval)
    }

    /// Executor in [`ParallelToolExecutionMode::Streaming`] mode.
    pub const fn streaming() -> Self {
        Self::with_mode(ParallelToolExecutionMode::Streaming)
    }

    /// Shared constructor for the public mode-specific constructors.
    const fn with_mode(mode: ParallelToolExecutionMode) -> Self {
        Self { mode }
    }

    /// Stable executor name reported through `ToolExecutor::name`.
    fn mode_name(self) -> &'static str {
        let Self { mode } = self;
        match mode {
            ParallelToolExecutionMode::Streaming => "parallel_streaming",
            ParallelToolExecutionMode::BatchApproval => "parallel_batch_approval",
        }
    }
}

/// The default executor streams.
impl Default for ParallelToolExecutor {
    fn default() -> Self {
        Self {
            mode: ParallelToolExecutionMode::Streaming,
        }
    }
}
248
249#[async_trait]
250impl ToolExecutor for ParallelToolExecutor {
251 async fn execute(
252 &self,
253 request: ToolExecutionRequest<'_>,
254 ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
255 let thread_id = request.thread_id;
256 let phase_ctx = ToolPhaseContext::from_request(&request);
257 execute_tools_parallel_with_phases(request.tools, request.calls, request.state, phase_ctx)
258 .await
259 .map_err(|e| map_tool_executor_error(e, thread_id))
260 }
261
262 fn name(&self) -> &'static str {
263 self.mode_name()
264 }
265
266 fn requires_parallel_patch_conflict_check(&self) -> bool {
267 true
268 }
269
270 fn decision_replay_policy(&self) -> DecisionReplayPolicy {
271 match self.mode {
272 ParallelToolExecutionMode::BatchApproval => DecisionReplayPolicy::BatchAllSuspended,
273 ParallelToolExecutionMode::Streaming => DecisionReplayPolicy::Immediate,
274 }
275 }
276}
277
/// Tool executor that runs calls one after another. Each call sees the state
/// produced by the previous calls, and execution stops after the first
/// suspended call.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialToolExecutor;
281
282#[async_trait]
283impl ToolExecutor for SequentialToolExecutor {
284 async fn execute(
285 &self,
286 request: ToolExecutionRequest<'_>,
287 ) -> Result<Vec<ToolExecutionResult>, ToolExecutorError> {
288 let thread_id = request.thread_id;
289 let phase_ctx = ToolPhaseContext::from_request(&request);
290 execute_tools_sequential_with_phases(request.tools, request.calls, request.state, phase_ctx)
291 .await
292 .map_err(|e| map_tool_executor_error(e, thread_id))
293 }
294
295 fn name(&self) -> &'static str {
296 "sequential"
297 }
298}
299
300fn validate_parallel_state_patch_conflicts(
301 results: &[ToolExecutionResult],
302) -> Result<(), AgentLoopError> {
303 for (left_idx, left) in results.iter().enumerate() {
304 let mut left_patches: Vec<&TrackedPatch> = Vec::new();
305 if let Some(ref patch) = left.execution.patch {
306 left_patches.push(patch);
307 }
308 left_patches.extend(left.pending_patches.iter());
309
310 if left_patches.is_empty() {
311 continue;
312 }
313
314 for right in results.iter().skip(left_idx + 1) {
315 let mut right_patches: Vec<&TrackedPatch> = Vec::new();
316 if let Some(ref patch) = right.execution.patch {
317 right_patches.push(patch);
318 }
319 right_patches.extend(right.pending_patches.iter());
320
321 if right_patches.is_empty() {
322 continue;
323 }
324
325 for left_patch in &left_patches {
326 for right_patch in &right_patches {
327 let conflicts = left_patch.patch().conflicts_with(right_patch.patch());
328 if let Some(conflict) = conflicts.first() {
329 return Err(AgentLoopError::StateError(format!(
330 "conflicting parallel state patches between '{}' and '{}' at {}",
331 left.execution.call.id, right.execution.call.id, conflict.path
332 )));
333 }
334 }
335 }
336 }
337 }
338
339 Ok(())
340}
341
342pub(super) fn apply_tool_results_to_session(
343 run_ctx: &mut RunContext,
344 results: &[ToolExecutionResult],
345 metadata: Option<MessageMetadata>,
346 check_parallel_patch_conflicts: bool,
347) -> Result<AppliedToolResults, AgentLoopError> {
348 apply_tool_results_impl(
349 run_ctx,
350 results,
351 metadata,
352 check_parallel_patch_conflicts,
353 None,
354 )
355}
356
357pub(super) fn apply_tool_results_impl(
358 run_ctx: &mut RunContext,
359 results: &[ToolExecutionResult],
360 metadata: Option<MessageMetadata>,
361 check_parallel_patch_conflicts: bool,
362 tool_msg_ids: Option<&HashMap<String, String>>,
363) -> Result<AppliedToolResults, AgentLoopError> {
364 let suspended: Vec<SuspendedCall> = results
366 .iter()
367 .filter_map(|r| {
368 if matches!(r.outcome, ToolCallOutcome::Suspended) {
369 r.suspended_call.clone()
370 } else {
371 None
372 }
373 })
374 .collect();
375
376 if check_parallel_patch_conflicts {
377 validate_parallel_state_patch_conflicts(results)?;
378 }
379
380 let mut patches: Vec<TrackedPatch> = collect_patches(
382 &results
383 .iter()
384 .map(|r| r.execution.clone())
385 .collect::<Vec<_>>(),
386 );
387 let mut merged_pending_patch = Patch::new();
388 for r in results {
389 for pending in &r.pending_patches {
390 merged_pending_patch.extend(pending.patch().clone());
391 }
392 }
393 if !merged_pending_patch.is_empty() {
394 patches.push(TrackedPatch::new(merged_pending_patch).with_source("agent_loop"));
395 }
396 let mut state_changed = !patches.is_empty();
397 run_ctx.add_thread_patches(patches);
398
399 let tool_messages: Vec<Arc<Message>> = results
401 .iter()
402 .flat_map(|r| {
403 let is_suspended = matches!(r.outcome, ToolCallOutcome::Suspended);
404 let mut msgs = if is_suspended {
405 vec![Message::tool(
406 &r.execution.call.id,
407 format!(
408 "Tool '{}' is awaiting approval. Execution paused.",
409 r.execution.call.name
410 ),
411 )]
412 } else {
413 let mut tool_msg = tool_response(&r.execution.call.id, &r.execution.result);
414 if let Some(id) = tool_msg_ids.and_then(|ids| ids.get(&r.execution.call.id)) {
415 tool_msg = tool_msg.with_id(id.clone());
416 }
417 vec![tool_msg]
418 };
419 for reminder in &r.reminders {
420 msgs.push(Message::internal_system(format!(
421 "<system-reminder>{}</system-reminder>",
422 reminder
423 )));
424 }
425 if let Some(ref meta) = metadata {
426 for msg in &mut msgs {
427 msg.metadata = Some(meta.clone());
428 }
429 }
430 msgs.into_iter().map(Arc::new).collect::<Vec<_>>()
431 })
432 .collect();
433
434 run_ctx.add_messages(tool_messages);
435 let appended_count = drain_agent_append_user_messages(run_ctx, results, metadata.as_ref())?;
436 if appended_count > 0 {
437 state_changed = true;
438 }
439 let existing_suspended = run_ctx.suspended_calls();
440
441 if !suspended.is_empty() {
442 let mut merged_suspended = existing_suspended;
443 for call in &suspended {
444 merged_suspended.insert(call.call_id.clone(), call.clone());
445 }
446 let state = run_ctx
447 .snapshot()
448 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
449 let patch =
450 set_agent_suspended_calls(&state, merged_suspended.into_values().collect::<Vec<_>>())?;
451 if !patch.patch().is_empty() {
452 state_changed = true;
453 run_ctx.add_thread_patch(patch);
454 }
455 let state_snapshot = if state_changed {
456 Some(
457 run_ctx
458 .snapshot()
459 .map_err(|e| AgentLoopError::StateError(e.to_string()))?,
460 )
461 } else {
462 None
463 };
464 return Ok(AppliedToolResults {
465 suspended_calls: suspended,
466 state_snapshot,
467 });
468 }
469
470 let state_snapshot = if state_changed {
477 Some(
478 run_ctx
479 .snapshot()
480 .map_err(|e| AgentLoopError::StateError(e.to_string()))?,
481 )
482 } else {
483 None
484 };
485
486 Ok(AppliedToolResults {
487 suspended_calls: Vec::new(),
488 state_snapshot,
489 })
490}
491
492fn tool_result_metadata_from_run_ctx(run_ctx: &RunContext) -> Option<MessageMetadata> {
493 let run_id = run_ctx
494 .run_config
495 .value("run_id")
496 .and_then(|v| v.as_str().map(String::from))
497 .or_else(|| {
498 run_ctx.messages().iter().rev().find_map(|m| {
499 m.metadata
500 .as_ref()
501 .and_then(|meta| meta.run_id.as_ref().cloned())
502 })
503 });
504
505 let step_index = run_ctx
506 .messages()
507 .iter()
508 .rev()
509 .find_map(|m| m.metadata.as_ref().and_then(|meta| meta.step_index));
510
511 if run_id.is_none() && step_index.is_none() {
512 None
513 } else {
514 Some(MessageMetadata { run_id, step_index })
515 }
516}
517
518#[allow(dead_code)]
519pub(super) fn next_step_index(run_ctx: &RunContext) -> u32 {
520 run_ctx
521 .messages()
522 .iter()
523 .filter_map(|m| m.metadata.as_ref().and_then(|meta| meta.step_index))
524 .max()
525 .map(|v| v.saturating_add(1))
526 .unwrap_or(0)
527}
528
529pub(super) fn step_metadata(run_id: Option<String>, step_index: u32) -> MessageMetadata {
530 MessageMetadata {
531 run_id,
532 step_index: Some(step_index),
533 }
534}
535
536pub async fn execute_tools(
540 thread: Thread,
541 result: &StreamResult,
542 tools: &HashMap<String, Arc<dyn Tool>>,
543 parallel: bool,
544) -> Result<ExecuteToolsOutcome, AgentLoopError> {
545 execute_tools_with_plugins(thread, result, tools, parallel, &[]).await
546}
547
548pub async fn execute_tools_with_config(
550 thread: Thread,
551 result: &StreamResult,
552 tools: &HashMap<String, Arc<dyn Tool>>,
553 config: &AgentConfig,
554) -> Result<ExecuteToolsOutcome, AgentLoopError> {
555 execute_tools_with_plugins_and_executor(
556 thread,
557 result,
558 tools,
559 config.tool_executor.as_ref(),
560 &config.plugins,
561 )
562 .await
563}
564
565pub(super) fn scope_with_tool_caller_context(
566 run_ctx: &RunContext,
567 state: &Value,
568 _config: Option<&AgentConfig>,
569) -> Result<tirea_contract::RunConfig, AgentLoopError> {
570 let mut rt = run_ctx.run_config.clone();
571 if rt.value(TOOL_SCOPE_CALLER_THREAD_ID_KEY).is_none() {
572 rt.set(
573 TOOL_SCOPE_CALLER_THREAD_ID_KEY,
574 run_ctx.thread_id().to_string(),
575 )
576 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
577 }
578 if rt.value(TOOL_SCOPE_CALLER_STATE_KEY).is_none() {
579 rt.set(TOOL_SCOPE_CALLER_STATE_KEY, state.clone())
580 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
581 }
582 if rt.value(TOOL_SCOPE_CALLER_MESSAGES_KEY).is_none() {
583 rt.set(TOOL_SCOPE_CALLER_MESSAGES_KEY, run_ctx.messages().to_vec())
584 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
585 }
586 Ok(rt)
587}
588
589pub async fn execute_tools_with_plugins(
591 thread: Thread,
592 result: &StreamResult,
593 tools: &HashMap<String, Arc<dyn Tool>>,
594 parallel: bool,
595 plugins: &[Arc<dyn AgentPlugin>],
596) -> Result<ExecuteToolsOutcome, AgentLoopError> {
597 let parallel_executor = ParallelToolExecutor::streaming();
598 let sequential_executor = SequentialToolExecutor;
599 let executor: &dyn ToolExecutor = if parallel {
600 ¶llel_executor
601 } else {
602 &sequential_executor
603 };
604 execute_tools_with_plugins_and_executor(thread, result, tools, executor, plugins).await
605}
606
607pub async fn execute_tools_with_plugins_and_executor(
608 thread: Thread,
609 result: &StreamResult,
610 tools: &HashMap<String, Arc<dyn Tool>>,
611 executor: &dyn ToolExecutor,
612 plugins: &[Arc<dyn AgentPlugin>],
613) -> Result<ExecuteToolsOutcome, AgentLoopError> {
614 let rebuilt_state = thread
616 .rebuild_state()
617 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
618 let mut run_ctx = RunContext::new(
619 &thread.id,
620 rebuilt_state.clone(),
621 thread.messages.clone(),
622 tirea_contract::RunConfig::default(),
623 );
624
625 let tool_descriptors: Vec<ToolDescriptor> =
626 tools.values().map(|t| t.descriptor().clone()).collect();
627 let (_, run_start_patches) = super::run_phase_block(
628 &run_ctx,
629 &tool_descriptors,
630 plugins,
631 &[Phase::RunStart],
632 |_| {},
633 |_| (),
634 )
635 .await?;
636 if !run_start_patches.is_empty() {
637 run_ctx.add_thread_patches(run_start_patches);
638 }
639
640 let replay_executor: Arc<dyn ToolExecutor> = match executor.decision_replay_policy() {
641 DecisionReplayPolicy::BatchAllSuspended => Arc::new(ParallelToolExecutor::batch_approval()),
642 DecisionReplayPolicy::Immediate => Arc::new(ParallelToolExecutor::streaming()),
643 };
644 let replay_config = AgentConfig {
645 plugins: plugins.to_vec(),
646 tool_executor: replay_executor,
647 ..AgentConfig::default()
648 };
649 let replay = super::drain_resuming_tool_calls_and_replay(
650 &mut run_ctx,
651 tools,
652 &replay_config,
653 &tool_descriptors,
654 )
655 .await?;
656
657 if replay.replayed {
658 let suspended = run_ctx.suspended_calls().values().next().cloned();
659 let delta = run_ctx.take_delta();
660 let mut out_thread = thread;
661 for msg in delta.messages {
662 out_thread = out_thread.with_message((*msg).clone());
663 }
664 out_thread = out_thread.with_patches(delta.patches);
665 return if let Some(first) = suspended {
666 Ok(ExecuteToolsOutcome::Suspended {
667 thread: out_thread,
668 suspended_call: first,
669 })
670 } else {
671 Ok(ExecuteToolsOutcome::Completed(out_thread))
672 };
673 }
674
675 if result.tool_calls.is_empty() {
676 let delta = run_ctx.take_delta();
677 let mut out_thread = thread;
678 for msg in delta.messages {
679 out_thread = out_thread.with_message((*msg).clone());
680 }
681 out_thread = out_thread.with_patches(delta.patches);
682 return Ok(ExecuteToolsOutcome::Completed(out_thread));
683 }
684
685 let current_state = run_ctx
686 .snapshot()
687 .map_err(|e| AgentLoopError::StateError(e.to_string()))?;
688 let rt_for_tools = scope_with_tool_caller_context(&run_ctx, ¤t_state, None)?;
689 let results = executor
690 .execute(ToolExecutionRequest {
691 tools,
692 calls: &result.tool_calls,
693 state: ¤t_state,
694 tool_descriptors: &tool_descriptors,
695 plugins,
696 activity_manager: tirea_contract::runtime::activity::NoOpActivityManager::arc(),
697 run_config: &rt_for_tools,
698 thread_id: run_ctx.thread_id(),
699 thread_messages: run_ctx.messages(),
700 state_version: run_ctx.version(),
701 cancellation_token: None,
702 })
703 .await?;
704
705 let metadata = tool_result_metadata_from_run_ctx(&run_ctx);
706 let applied = apply_tool_results_to_session(
707 &mut run_ctx,
708 &results,
709 metadata,
710 executor.requires_parallel_patch_conflict_check(),
711 )?;
712 let suspended = applied.suspended_calls.into_iter().next();
713
714 let delta = run_ctx.take_delta();
716 let mut out_thread = thread;
717 for msg in delta.messages {
718 out_thread = out_thread.with_message((*msg).clone());
719 }
720 out_thread = out_thread.with_patches(delta.patches);
721
722 if let Some(first) = suspended {
723 Ok(ExecuteToolsOutcome::Suspended {
724 thread: out_thread,
725 suspended_call: first,
726 })
727 } else {
728 Ok(ExecuteToolsOutcome::Completed(out_thread))
729 }
730}
731
732pub(super) async fn execute_tools_parallel_with_phases(
734 tools: &HashMap<String, Arc<dyn Tool>>,
735 calls: &[crate::contracts::thread::ToolCall],
736 state: &Value,
737 phase_ctx: ToolPhaseContext<'_>,
738) -> Result<Vec<ToolExecutionResult>, AgentLoopError> {
739 use futures::future::join_all;
740
741 if is_cancelled(phase_ctx.cancellation_token) {
742 return Err(cancelled_error(phase_ctx.thread_id));
743 }
744
745 let run_config_owned = phase_ctx.run_config.clone();
747 let thread_id = phase_ctx.thread_id.to_string();
748 let thread_messages = Arc::new(phase_ctx.thread_messages.to_vec());
749 let tool_descriptors = phase_ctx.tool_descriptors.to_vec();
750 let plugins = phase_ctx.plugins.to_vec();
751
752 let futures = calls.iter().map(|call| {
753 let tool = tools.get(&call.name).cloned();
754 let state = state.clone();
755 let call = call.clone();
756 let plugins = plugins.clone();
757 let tool_descriptors = tool_descriptors.clone();
758 let activity_manager = phase_ctx.activity_manager.clone();
759 let rt = run_config_owned.clone();
760 let sid = thread_id.clone();
761 let thread_messages = thread_messages.clone();
762
763 async move {
764 execute_single_tool_with_phases(
765 tool.as_deref(),
766 &call,
767 &state,
768 &ToolPhaseContext {
769 tool_descriptors: &tool_descriptors,
770 plugins: &plugins,
771 activity_manager,
772 run_config: &rt,
773 thread_id: &sid,
774 thread_messages: thread_messages.as_slice(),
775 cancellation_token: None,
776 },
777 )
778 .await
779 }
780 });
781
782 let join_future = join_all(futures);
783 let results = match await_or_cancel(phase_ctx.cancellation_token, join_future).await {
784 CancelAware::Cancelled => return Err(cancelled_error(&thread_id)),
785 CancelAware::Value(results) => results,
786 };
787 let results: Vec<ToolExecutionResult> = results.into_iter().collect::<Result<_, _>>()?;
788 Ok(results)
789}
790
791pub(super) async fn execute_tools_sequential_with_phases(
793 tools: &HashMap<String, Arc<dyn Tool>>,
794 calls: &[crate::contracts::thread::ToolCall],
795 initial_state: &Value,
796 phase_ctx: ToolPhaseContext<'_>,
797) -> Result<Vec<ToolExecutionResult>, AgentLoopError> {
798 use tirea_state::apply_patch;
799
800 if is_cancelled(phase_ctx.cancellation_token) {
801 return Err(cancelled_error(phase_ctx.thread_id));
802 }
803
804 let mut state = initial_state.clone();
805 let mut results = Vec::with_capacity(calls.len());
806
807 for call in calls {
808 let tool = tools.get(&call.name).cloned();
809 let call_phase_ctx = ToolPhaseContext {
810 tool_descriptors: phase_ctx.tool_descriptors,
811 plugins: phase_ctx.plugins,
812 activity_manager: phase_ctx.activity_manager.clone(),
813 run_config: phase_ctx.run_config,
814 thread_id: phase_ctx.thread_id,
815 thread_messages: phase_ctx.thread_messages,
816 cancellation_token: None,
817 };
818 let result = match await_or_cancel(
819 phase_ctx.cancellation_token,
820 execute_single_tool_with_phases(tool.as_deref(), call, &state, &call_phase_ctx),
821 )
822 .await
823 {
824 CancelAware::Cancelled => return Err(cancelled_error(phase_ctx.thread_id)),
825 CancelAware::Value(result) => result?,
826 };
827
828 if let Some(ref patch) = result.execution.patch {
830 state = apply_patch(&state, patch.patch()).map_err(|e| {
831 AgentLoopError::StateError(format!(
832 "failed to apply tool patch for call '{}': {}",
833 result.execution.call.id, e
834 ))
835 })?;
836 }
837 for pp in &result.pending_patches {
839 state = apply_patch(&state, pp.patch()).map_err(|e| {
840 AgentLoopError::StateError(format!(
841 "failed to apply plugin patch for call '{}': {}",
842 result.execution.call.id, e
843 ))
844 })?;
845 }
846
847 results.push(result);
848
849 if results
850 .last()
851 .is_some_and(|r| matches!(r.outcome, ToolCallOutcome::Suspended))
852 {
853 break;
854 }
855 }
856
857 Ok(results)
858}
859
860pub(super) async fn execute_single_tool_with_phases(
862 tool: Option<&dyn Tool>,
863 call: &crate::contracts::thread::ToolCall,
864 state: &Value,
865 phase_ctx: &ToolPhaseContext<'_>,
866) -> Result<ToolExecutionResult, AgentLoopError> {
867 let doc = tirea_state::DocCell::new(state.clone());
869 let ops = std::sync::Mutex::new(Vec::new());
870 let pending_messages = std::sync::Mutex::new(Vec::new());
871 let plugin_scope = phase_ctx.run_config;
872 let mut plugin_tool_call_ctx = crate::contracts::ToolCallContext::new(
873 &doc,
874 &ops,
875 "plugin_phase",
876 "plugin:tool_phase",
877 plugin_scope,
878 &pending_messages,
879 tirea_contract::runtime::activity::NoOpActivityManager::arc(),
880 );
881 if let Some(token) = phase_ctx.cancellation_token {
882 plugin_tool_call_ctx = plugin_tool_call_ctx.with_cancellation_token(token);
883 }
884
885 let mut step = StepContext::new(
887 plugin_tool_call_ctx,
888 phase_ctx.thread_id,
889 phase_ctx.thread_messages,
890 phase_ctx.tool_descriptors.to_vec(),
891 );
892 step.tool = Some(ToolContext::new(call));
893 emit_phase_checked(Phase::BeforeToolExecute, &mut step, phase_ctx.plugins).await?;
895
896 let (execution, outcome, suspended_call) = if step.tool_blocked() {
898 let reason = step
899 .tool
900 .as_ref()
901 .and_then(|t| t.block_reason.clone())
902 .unwrap_or_else(|| "Blocked by plugin".to_string());
903 (
904 ToolExecution {
905 call: call.clone(),
906 result: ToolResult::error(&call.name, reason),
907 patch: None,
908 },
909 ToolCallOutcome::Failed,
910 None,
911 )
912 } else if let Some(plugin_result) = step.tool_result().cloned() {
913 let outcome = ToolCallOutcome::from_tool_result(&plugin_result);
914 (
915 ToolExecution {
916 call: call.clone(),
917 result: plugin_result,
918 patch: None,
919 },
920 outcome,
921 None,
922 )
923 } else if tool.is_none() {
924 (
925 ToolExecution {
926 call: call.clone(),
927 result: ToolResult::error(&call.name, format!("Tool '{}' not found", call.name)),
928 patch: None,
929 },
930 ToolCallOutcome::Failed,
931 None,
932 )
933 } else if let Err(e) = tool.unwrap().validate_args(&call.arguments) {
934 (
936 ToolExecution {
937 call: call.clone(),
938 result: ToolResult::error(&call.name, e.to_string()),
939 patch: None,
940 },
941 ToolCallOutcome::Failed,
942 None,
943 )
944 } else if step.tool_pending() {
945 let Some(suspend_ticket) = step.tool.as_ref().and_then(|t| t.suspend_ticket.clone()) else {
946 return Err(AgentLoopError::StateError(
947 "tool is pending but suspend ticket is missing".to_string(),
948 ));
949 };
950 (
951 ToolExecution {
952 call: call.clone(),
953 result: ToolResult::suspended(
954 &call.name,
955 "Execution suspended; awaiting external decision",
956 ),
957 patch: None,
958 },
959 ToolCallOutcome::Suspended,
960 Some(SuspendedCall::new(call, suspend_ticket)),
961 )
962 } else {
963 persist_tool_call_status(&step, call, ToolCallStatus::Running, None)?;
964 let tool_doc = tirea_state::DocCell::new(state.clone());
966 let tool_ops = std::sync::Mutex::new(Vec::new());
967 let tool_pending_msgs = std::sync::Mutex::new(Vec::new());
968 let mut tool_ctx = crate::contracts::ToolCallContext::new(
969 &tool_doc,
970 &tool_ops,
971 &call.id,
972 format!("tool:{}", call.name),
973 plugin_scope,
974 &tool_pending_msgs,
975 phase_ctx.activity_manager.clone(),
976 );
977 if let Some(token) = phase_ctx.cancellation_token {
978 tool_ctx = tool_ctx.with_cancellation_token(token);
979 }
980 let result = match tool
981 .unwrap()
982 .execute(call.arguments.clone(), &tool_ctx)
983 .await
984 {
985 Ok(r) => r,
986 Err(e) => ToolResult::error(&call.name, e.to_string()),
987 };
988
989 let patch = tool_ctx.take_patch();
990 let patch = if patch.patch().is_empty() {
991 None
992 } else {
993 Some(patch)
994 };
995 let outcome = ToolCallOutcome::from_tool_result(&result);
996
997 let suspended_call = if matches!(outcome, ToolCallOutcome::Suspended) {
998 Some(suspended_call_from_tool_result(call, &result))
999 } else {
1000 None
1001 };
1002
1003 (
1004 ToolExecution {
1005 call: call.clone(),
1006 result,
1007 patch,
1008 },
1009 outcome,
1010 suspended_call,
1011 )
1012 };
1013
1014 step.set_tool_result(execution.result.clone());
1016
1017 emit_phase_checked(Phase::AfterToolExecute, &mut step, phase_ctx.plugins).await?;
1019
1020 match outcome {
1021 ToolCallOutcome::Suspended => {
1022 persist_tool_call_status(
1023 &step,
1024 call,
1025 ToolCallStatus::Suspended,
1026 suspended_call.as_ref(),
1027 )?;
1028 }
1029 ToolCallOutcome::Succeeded => {
1030 persist_tool_call_status(&step, call, ToolCallStatus::Succeeded, None)?;
1031 }
1032 ToolCallOutcome::Failed => {
1033 persist_tool_call_status(&step, call, ToolCallStatus::Failed, None)?;
1034 }
1035 }
1036
1037 let plugin_patch = step.ctx().take_patch();
1039 if !plugin_patch.patch().is_empty() {
1040 step.emit_patch(plugin_patch);
1041 }
1042
1043 let mut pending_patches = std::mem::take(&mut step.pending_patches);
1044 for effect in std::mem::take(&mut step.state_effects) {
1045 match effect {
1046 StateEffect::Patch(patch) => pending_patches.push(patch),
1047 }
1048 }
1049
1050 Ok(ToolExecutionResult {
1051 execution,
1052 outcome,
1053 suspended_call,
1054 reminders: step.system_reminders.clone(),
1055 pending_patches,
1056 })
1057}
1058
1059fn cancelled_error(_thread_id: &str) -> AgentLoopError {
1060 AgentLoopError::Cancelled
1061}