// tirea_agent_loop/runtime/loop_runner/state_commit.rs

1use super::{AgentLoopError, StateCommitError, StateCommitter};
2use crate::contracts::storage::VersionPrecondition;
3use crate::contracts::thread::CheckpointReason;
4use crate::contracts::RunContext;
5use crate::contracts::ThreadChangeSet;
6use async_trait::async_trait;
7use std::sync::Arc;
8
9#[derive(Clone)]
10pub struct ChannelStateCommitter {
11    tx: tokio::sync::mpsc::UnboundedSender<ThreadChangeSet>,
12}
13
14impl ChannelStateCommitter {
15    pub fn new(tx: tokio::sync::mpsc::UnboundedSender<ThreadChangeSet>) -> Self {
16        Self { tx }
17    }
18}
19
20#[async_trait]
21impl StateCommitter for ChannelStateCommitter {
22    async fn commit(
23        &self,
24        _thread_id: &str,
25        changeset: ThreadChangeSet,
26        precondition: VersionPrecondition,
27    ) -> Result<u64, StateCommitError> {
28        let next_version = match precondition {
29            VersionPrecondition::Any => 1,
30            VersionPrecondition::Exact(version) => version.saturating_add(1),
31        };
32        self.tx
33            .send(changeset)
34            .map_err(|e| StateCommitError::new(format!("channel state commit failed: {e}")))?;
35        Ok(next_version)
36    }
37}
38
39pub(super) async fn commit_pending_delta(
40    run_ctx: &mut RunContext,
41    reason: CheckpointReason,
42    force: bool,
43    run_id: &str,
44    parent_run_id: Option<&str>,
45    state_committer: Option<&Arc<dyn StateCommitter>>,
46) -> Result<(), AgentLoopError> {
47    let Some(committer) = state_committer else {
48        return Ok(());
49    };
50
51    let delta = run_ctx.take_delta();
52    if !force && delta.is_empty() {
53        return Ok(());
54    }
55
56    let changeset = ThreadChangeSet::from_parts(
57        run_id.to_string(),
58        parent_run_id.map(str::to_string),
59        reason,
60        delta.messages,
61        delta.patches,
62        None,
63    );
64    let precondition = VersionPrecondition::Exact(run_ctx.version());
65    let committed_version = committer
66        .commit(run_ctx.thread_id(), changeset, precondition)
67        .await
68        .map_err(|e| AgentLoopError::StateError(format!("state commit failed: {e}")))?;
69    run_ctx.set_version(committed_version, Some(super::current_unix_millis()));
70    Ok(())
71}
72
73pub(super) struct PendingDeltaCommitContext<'a> {
74    run_id: &'a str,
75    parent_run_id: Option<&'a str>,
76    state_committer: Option<&'a Arc<dyn StateCommitter>>,
77}
78
79impl<'a> PendingDeltaCommitContext<'a> {
80    pub(super) fn new(
81        run_id: &'a str,
82        parent_run_id: Option<&'a str>,
83        state_committer: Option<&'a Arc<dyn StateCommitter>>,
84    ) -> Self {
85        Self {
86            run_id,
87            parent_run_id,
88            state_committer,
89        }
90    }
91
92    pub(super) async fn commit(
93        &self,
94        run_ctx: &mut RunContext,
95        reason: CheckpointReason,
96        force: bool,
97    ) -> Result<(), AgentLoopError> {
98        commit_pending_delta(
99            run_ctx,
100            reason,
101            force,
102            self.run_id,
103            self.parent_run_id,
104            self.state_committer,
105        )
106        .await
107    }
108}