Skip to main content

tirea_agent_loop/runtime/loop_runner/
state_commit.rs

1use super::{AgentLoopError, RunExecutionContext, StateCommitError, StateCommitter};
2use crate::contracts::storage::{RunOrigin, VersionPrecondition};
3use crate::contracts::thread::CheckpointReason;
4use crate::contracts::{RunContext, RunMeta, TerminationReason, ThreadChangeSet};
5use async_trait::async_trait;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::Arc;
8
9#[derive(Clone)]
10pub struct ChannelStateCommitter {
11    tx: tokio::sync::mpsc::UnboundedSender<ThreadChangeSet>,
12    version: Arc<AtomicU64>,
13}
14
15impl ChannelStateCommitter {
16    pub fn new(tx: tokio::sync::mpsc::UnboundedSender<ThreadChangeSet>) -> Self {
17        Self {
18            tx,
19            version: Arc::new(AtomicU64::new(0)),
20        }
21    }
22}
23
24#[async_trait]
25impl StateCommitter for ChannelStateCommitter {
26    async fn commit(
27        &self,
28        _thread_id: &str,
29        changeset: ThreadChangeSet,
30        _precondition: VersionPrecondition,
31    ) -> Result<u64, StateCommitError> {
32        let next_version = self.version.fetch_add(1, Ordering::SeqCst) + 1;
33        self.tx
34            .send(changeset)
35            .map_err(|e| StateCommitError::new(format!("channel state commit failed: {e}")))?;
36        Ok(next_version)
37    }
38}
39
40pub(super) async fn commit_pending_delta(
41    run_ctx: &mut RunContext,
42    reason: CheckpointReason,
43    force: bool,
44    state_committer: Option<&Arc<dyn StateCommitter>>,
45    execution_ctx: &RunExecutionContext,
46    termination: Option<&TerminationReason>,
47) -> Result<(), AgentLoopError> {
48    let Some(committer) = state_committer else {
49        return Ok(());
50    };
51
52    let delta = run_ctx.take_delta();
53    if !force && delta.is_empty() {
54        return Ok(());
55    }
56
57    // On RunFinished, write a full state snapshot to bound the action/patch
58    // replay window to a single run.
59    let snapshot = if reason == CheckpointReason::RunFinished {
60        match run_ctx.snapshot() {
61            Ok(state) => Some(state),
62            Err(e) => {
63                tracing::warn!(error = %e, "failed to compute RunFinished snapshot; continuing without snapshot");
64                None
65            }
66        }
67    } else {
68        None
69    };
70
71    let mut changeset = ThreadChangeSet::from_parts(
72        execution_ctx.run_id.clone(),
73        execution_ctx.parent_run_id.clone(),
74        reason,
75        delta.messages,
76        delta.patches,
77        delta.state_actions,
78        snapshot,
79    );
80
81    // Loop always emits run-finished RunMeta. Whether this metadata is used to
82    // materialize/maintain durable run mappings is decided by the outer
83    // orchestration layer's StateCommitter policy.
84    if let Some(termination) = termination {
85        let agent_id = execution_ctx.agent_id.clone();
86        let origin: RunOrigin = execution_ctx.origin;
87        let parent_thread_id = None; // Already set on the initial changeset.
88        let (status, termination_code, termination_detail) = map_termination(termination);
89        changeset.run_meta = Some(RunMeta {
90            agent_id,
91            origin,
92            status,
93            parent_thread_id,
94            termination_code,
95            termination_detail,
96        });
97    }
98
99    let precondition = VersionPrecondition::Exact(run_ctx.version());
100    let committed_version = committer
101        .commit(run_ctx.thread_id(), changeset, precondition)
102        .await
103        .map_err(|e| AgentLoopError::StateError(format!("state commit failed: {e}")))?;
104    run_ctx.set_version(committed_version, Some(super::current_unix_millis()));
105    Ok(())
106}
107
108fn map_termination(
109    termination: &TerminationReason,
110) -> (
111    crate::contracts::storage::RunStatus,
112    Option<String>,
113    Option<String>,
114) {
115    let (status, _) = termination.to_run_status();
116    match termination {
117        TerminationReason::NaturalEnd => (status, Some("natural".to_string()), None),
118        TerminationReason::BehaviorRequested => {
119            (status, Some("behavior_requested".to_string()), None)
120        }
121        TerminationReason::Suspended => (status, Some("input_required".to_string()), None),
122        TerminationReason::Cancelled => (status, Some("cancelled".to_string()), None),
123        TerminationReason::Error(message) => {
124            (status, Some("error".to_string()), Some(message.clone()))
125        }
126        TerminationReason::Stopped(stopped) => (
127            status,
128            Some(stopped.code.trim().to_ascii_lowercase()),
129            stopped.detail.clone(),
130        ),
131    }
132}
133
134pub(super) struct PendingDeltaCommitContext<'a> {
135    execution_ctx: &'a RunExecutionContext,
136    state_committer: Option<&'a Arc<dyn StateCommitter>>,
137}
138
139impl<'a> PendingDeltaCommitContext<'a> {
140    pub(super) fn new(
141        execution_ctx: &'a RunExecutionContext,
142        state_committer: Option<&'a Arc<dyn StateCommitter>>,
143    ) -> Self {
144        Self {
145            execution_ctx,
146            state_committer,
147        }
148    }
149
150    pub(super) async fn commit(
151        &self,
152        run_ctx: &mut RunContext,
153        reason: CheckpointReason,
154        force: bool,
155    ) -> Result<(), AgentLoopError> {
156        commit_pending_delta(
157            run_ctx,
158            reason,
159            force,
160            self.state_committer,
161            self.execution_ctx,
162            None,
163        )
164        .await
165    }
166
167    pub(super) async fn commit_run_finished(
168        &self,
169        run_ctx: &mut RunContext,
170        termination: &TerminationReason,
171    ) -> Result<(), AgentLoopError> {
172        commit_pending_delta(
173            run_ctx,
174            CheckpointReason::RunFinished,
175            true,
176            self.state_committer,
177            self.execution_ctx,
178            Some(termination),
179        )
180        .await
181    }
182}