tirea_agent_loop/runtime/loop_runner/
state_commit.rs1use super::{AgentLoopError, RunIdentity, StateCommitError, StateCommitter};
2use crate::contracts::storage::{RunOrigin, VersionPrecondition};
3use crate::contracts::thread::CheckpointReason;
4use crate::contracts::{RunContext, RunMeta, TerminationReason, ThreadChangeSet};
5use async_trait::async_trait;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
9#[derive(Clone)]
10pub struct ChannelStateCommitter {
11 tx: tokio::sync::mpsc::UnboundedSender<ThreadChangeSet>,
12 version: Arc<AtomicU64>,
13}
14
15impl ChannelStateCommitter {
16 pub fn new(tx: tokio::sync::mpsc::UnboundedSender<ThreadChangeSet>) -> Self {
17 Self {
18 tx,
19 version: Arc::new(AtomicU64::new(0)),
20 }
21 }
22}
23
24#[async_trait]
25impl StateCommitter for ChannelStateCommitter {
26 async fn commit(
27 &self,
28 _thread_id: &str,
29 changeset: ThreadChangeSet,
30 _precondition: VersionPrecondition,
31 ) -> Result<u64, StateCommitError> {
32 let next_version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
33 self.tx
34 .send(changeset)
35 .map_err(|e| StateCommitError::new(format!("channel state commit failed: {e}")))?;
36 Ok(next_version)
37 }
38}
39
40pub(super) async fn commit_pending_delta(
41 run_ctx: &mut RunContext,
42 reason: CheckpointReason,
43 force: bool,
44 state_committer: Option<&Arc<dyn StateCommitter>>,
45 run_identity: &RunIdentity,
46 termination: Option<&TerminationReason>,
47) -> Result<(), AgentLoopError> {
48 let Some(committer) = state_committer else {
49 return Ok(());
50 };
51
52 let delta = run_ctx.take_delta();
53 if !force && delta.is_empty() {
54 return Ok(());
55 }
56
57 let snapshot = if reason == CheckpointReason::RunFinished {
60 match run_ctx.snapshot() {
61 Ok(state) => Some(state),
62 Err(e) => {
63 tracing::warn!(error = %e, "failed to compute RunFinished snapshot; continuing without snapshot");
64 None
65 }
66 }
67 } else {
68 None
69 };
70
71 let mut changeset = ThreadChangeSet::from_parts(
72 run_identity.run_id.clone(),
73 run_identity.parent_run_id.clone(),
74 reason,
75 delta.messages,
76 delta.patches,
77 delta.state_actions,
78 snapshot,
79 );
80
81 if let Some(termination) = termination {
85 let agent_id = run_identity.agent_id.clone();
86 let origin: RunOrigin = run_identity.origin;
87 let parent_thread_id = None; let (status, termination_code, termination_detail) = map_termination(termination);
89 changeset.run_meta = Some(RunMeta {
90 agent_id,
91 origin,
92 status,
93 parent_thread_id,
94 termination_code,
95 termination_detail,
96 });
97 }
98
99 let precondition = VersionPrecondition::Exact(run_ctx.version());
100 let committed_version = committer
101 .commit(run_ctx.thread_id(), changeset, precondition)
102 .await
103 .map_err(|e| AgentLoopError::StateError(format!("state commit failed: {e}")))?;
104 run_ctx.set_version(committed_version, Some(super::current_unix_millis()));
105 Ok(())
106}
107
108fn map_termination(
109 termination: &TerminationReason,
110) -> (
111 crate::contracts::storage::RunStatus,
112 Option<String>,
113 Option<String>,
114) {
115 let (status, _) = termination.to_run_status();
116 match termination {
117 TerminationReason::NaturalEnd => (status, Some("natural".to_string()), None),
118 TerminationReason::BehaviorRequested => {
119 (status, Some("behavior_requested".to_string()), None)
120 }
121 TerminationReason::Suspended => (status, Some("input_required".to_string()), None),
122 TerminationReason::Cancelled => (status, Some("cancelled".to_string()), None),
123 TerminationReason::Error(message) => {
124 (status, Some("error".to_string()), Some(message.clone()))
125 }
126 TerminationReason::Stopped(stopped) => (
127 status,
128 Some(stopped.code.trim().to_ascii_lowercase()),
129 stopped.detail.clone(),
130 ),
131 }
132}
133
134pub(super) struct PendingDeltaCommitContext<'a> {
135 run_identity: &'a RunIdentity,
136 state_committer: Option<&'a Arc<dyn StateCommitter>>,
137}
138
139impl<'a> PendingDeltaCommitContext<'a> {
140 pub(super) fn new(
141 run_identity: &'a RunIdentity,
142 state_committer: Option<&'a Arc<dyn StateCommitter>>,
143 ) -> Self {
144 Self {
145 run_identity,
146 state_committer,
147 }
148 }
149
150 pub(super) async fn commit(
151 &self,
152 run_ctx: &mut RunContext,
153 reason: CheckpointReason,
154 force: bool,
155 ) -> Result<(), AgentLoopError> {
156 commit_pending_delta(
157 run_ctx,
158 reason,
159 force,
160 self.state_committer,
161 self.run_identity,
162 None,
163 )
164 .await
165 }
166
167 pub(super) async fn commit_run_finished(
168 &self,
169 run_ctx: &mut RunContext,
170 termination: &TerminationReason,
171 ) -> Result<(), AgentLoopError> {
172 commit_pending_delta(
173 run_ctx,
174 CheckpointReason::RunFinished,
175 true,
176 self.state_committer,
177 self.run_identity,
178 Some(termination),
179 )
180 .await
181 }
182}