Skip to main content

tirea_agent_loop/runtime/
run_context.rs

1use crate::contracts::storage::VersionPrecondition;
2use crate::contracts::thread::ThreadChangeSet;
3use async_trait::async_trait;
4use futures::future::pending;
5use thiserror::Error;
6use tokio_util::sync::CancellationToken;
7
8pub type RunCancellationToken = CancellationToken;
9
/// Result of awaiting a future while also watching an optional cancellation
/// token.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CancelAware<T> {
    /// The awaited future completed and produced this value.
    Value(T),
    /// Cancellation won the race; the future's output (if any) was not returned.
    Cancelled,
}
15
16pub fn is_cancelled(token: Option<&RunCancellationToken>) -> bool {
17    token.is_some_and(RunCancellationToken::is_cancelled)
18}
19
20pub async fn cancelled(token: Option<&RunCancellationToken>) {
21    if let Some(token) = token {
22        token.cancelled().await;
23    } else {
24        pending::<()>().await;
25    }
26}
27
28pub async fn await_or_cancel<T, F>(token: Option<&RunCancellationToken>, fut: F) -> CancelAware<T>
29where
30    F: std::future::Future<Output = T>,
31{
32    if let Some(token) = token {
33        tokio::select! {
34            _ = token.cancelled() => CancelAware::Cancelled,
35            value = fut => CancelAware::Value(value),
36        }
37    } else {
38        CancelAware::Value(fut.await)
39    }
40}
41
42/// Error returned by state commit sinks.
43#[derive(Debug, Clone, Error)]
44#[error("{message}")]
45pub struct StateCommitError {
46    pub message: String,
47}
48
49impl StateCommitError {
50    pub fn new(message: impl Into<String>) -> Self {
51        Self {
52            message: message.into(),
53        }
54    }
55}
56
57/// Sink for committed thread deltas.
58#[async_trait]
59pub trait StateCommitter: Send + Sync {
60    /// Commit a single change set for a thread.
61    ///
62    /// Returns the committed storage version after the write succeeds.
63    async fn commit(
64        &self,
65        thread_id: &str,
66        changeset: ThreadChangeSet,
67        precondition: VersionPrecondition,
68    ) -> Result<u64, StateCommitError>;
69}
70
71impl std::fmt::Debug for dyn StateCommitter {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.write_str("<StateCommitter>")
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{timeout, Duration};

    /// `is_cancelled` treats `None` as never cancelled and otherwise mirrors
    /// the token's state.
    #[test]
    fn is_cancelled_reflects_token_state() {
        assert!(!is_cancelled(None));

        let token = RunCancellationToken::new();
        assert!(!is_cancelled(Some(&token)));

        token.cancel();
        assert!(is_cancelled(Some(&token)));
    }

    #[tokio::test]
    async fn await_or_cancel_returns_value_without_token() {
        let out = await_or_cancel(None, async { 42usize }).await;
        assert_eq!(out, CancelAware::Value(42));
    }

    /// A live (never-cancelled) token must not interfere with a future that
    /// completes normally.
    #[tokio::test]
    async fn await_or_cancel_returns_value_when_token_not_cancelled() {
        let token = RunCancellationToken::new();
        let out = await_or_cancel(Some(&token), async { "done" }).await;
        assert_eq!(out, CancelAware::Value("done"));
    }

    /// A token that is already cancelled on entry resolves immediately, even
    /// though the wrapped future would never complete on its own.
    #[tokio::test]
    async fn await_or_cancel_short_circuits_on_already_cancelled_token() {
        let token = RunCancellationToken::new();
        token.cancel();
        let out = timeout(
            Duration::from_millis(300),
            await_or_cancel(Some(&token), std::future::pending::<u8>()),
        )
        .await
        .expect("await_or_cancel should not wait on the future once cancelled");
        assert_eq!(out, CancelAware::Cancelled);
    }

    #[tokio::test]
    async fn await_or_cancel_returns_cancelled_when_token_cancelled() {
        let token = RunCancellationToken::new();
        let token_for_task = token.clone();
        let handle = tokio::spawn(async move {
            await_or_cancel(Some(&token_for_task), async {
                tokio::time::sleep(Duration::from_secs(5)).await;
                7usize
            })
            .await
        });

        token.cancel();
        let out = timeout(Duration::from_millis(300), handle)
            .await
            .expect("await_or_cancel should resolve quickly after cancellation")
            .expect("task should not panic");
        assert_eq!(out, CancelAware::Cancelled);
    }

    #[tokio::test]
    async fn cancelled_waits_for_token_signal() {
        let token = RunCancellationToken::new();
        let token_for_task = token.clone();
        let handle = tokio::spawn(async move {
            cancelled(Some(&token_for_task)).await;
            true
        });

        token.cancel();
        let done = timeout(Duration::from_millis(300), handle)
            .await
            .expect("cancelled() should return after token cancellation")
            .expect("task should not panic");
        assert!(done);
    }

    /// Without a token, `cancelled` is documented to stay pending forever; a
    /// short timeout is enough to show that it does not resolve on its own.
    #[tokio::test]
    async fn cancelled_never_resolves_without_token() {
        let waited = timeout(Duration::from_millis(20), cancelled(None)).await;
        assert!(waited.is_err(), "cancelled(None) should stay pending");
    }
}
124}