tirea_agent_loop/runtime/loop_runner/
state_commit.rs1use super::{AgentLoopError, StateCommitError, StateCommitter};
2use crate::contracts::storage::VersionPrecondition;
3use crate::contracts::thread::CheckpointReason;
4use crate::contracts::RunContext;
5use crate::contracts::ThreadChangeSet;
6use async_trait::async_trait;
7use std::sync::Arc;
8
9#[derive(Clone)]
10pub struct ChannelStateCommitter {
11 tx: tokio::sync::mpsc::UnboundedSender<ThreadChangeSet>,
12}
13
14impl ChannelStateCommitter {
15 pub fn new(tx: tokio::sync::mpsc::UnboundedSender<ThreadChangeSet>) -> Self {
16 Self { tx }
17 }
18}
19
20#[async_trait]
21impl StateCommitter for ChannelStateCommitter {
22 async fn commit(
23 &self,
24 _thread_id: &str,
25 changeset: ThreadChangeSet,
26 precondition: VersionPrecondition,
27 ) -> Result<u64, StateCommitError> {
28 let next_version = match precondition {
29 VersionPrecondition::Any => 1,
30 VersionPrecondition::Exact(version) => version.saturating_add(1),
31 };
32 self.tx
33 .send(changeset)
34 .map_err(|e| StateCommitError::new(format!("channel state commit failed: {e}")))?;
35 Ok(next_version)
36 }
37}
38
39pub(super) async fn commit_pending_delta(
40 run_ctx: &mut RunContext,
41 reason: CheckpointReason,
42 force: bool,
43 run_id: &str,
44 parent_run_id: Option<&str>,
45 state_committer: Option<&Arc<dyn StateCommitter>>,
46) -> Result<(), AgentLoopError> {
47 let Some(committer) = state_committer else {
48 return Ok(());
49 };
50
51 let delta = run_ctx.take_delta();
52 if !force && delta.is_empty() {
53 return Ok(());
54 }
55
56 let changeset = ThreadChangeSet::from_parts(
57 run_id.to_string(),
58 parent_run_id.map(str::to_string),
59 reason,
60 delta.messages,
61 delta.patches,
62 None,
63 );
64 let precondition = VersionPrecondition::Exact(run_ctx.version());
65 let committed_version = committer
66 .commit(run_ctx.thread_id(), changeset, precondition)
67 .await
68 .map_err(|e| AgentLoopError::StateError(format!("state commit failed: {e}")))?;
69 run_ctx.set_version(committed_version, Some(super::current_unix_millis()));
70 Ok(())
71}
72
73pub(super) struct PendingDeltaCommitContext<'a> {
74 run_id: &'a str,
75 parent_run_id: Option<&'a str>,
76 state_committer: Option<&'a Arc<dyn StateCommitter>>,
77}
78
79impl<'a> PendingDeltaCommitContext<'a> {
80 pub(super) fn new(
81 run_id: &'a str,
82 parent_run_id: Option<&'a str>,
83 state_committer: Option<&'a Arc<dyn StateCommitter>>,
84 ) -> Self {
85 Self {
86 run_id,
87 parent_run_id,
88 state_committer,
89 }
90 }
91
92 pub(super) async fn commit(
93 &self,
94 run_ctx: &mut RunContext,
95 reason: CheckpointReason,
96 force: bool,
97 ) -> Result<(), AgentLoopError> {
98 commit_pending_delta(
99 run_ctx,
100 reason,
101 force,
102 self.run_id,
103 self.parent_run_id,
104 self.state_committer,
105 )
106 .await
107 }
108}