tirea_agent_loop/runtime/
run_context.rs1use crate::contracts::storage::VersionPrecondition;
2use crate::contracts::thread::ThreadChangeSet;
3use async_trait::async_trait;
4use futures::future::pending;
5use thiserror::Error;
6use tokio_util::sync::CancellationToken;
7
8pub type RunCancellationToken = CancellationToken;
9
/// Outcome of an operation that can be interrupted by cancellation.
///
/// Returned by [`await_or_cancel`]: either the wrapped future's output or a
/// marker that cancellation was observed instead.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CancelAware<T> {
    /// The operation completed and produced this value.
    Value(T),
    /// Cancellation was observed instead of a value.
    Cancelled,
}
15
16pub fn is_cancelled(token: Option<&RunCancellationToken>) -> bool {
17 token.is_some_and(RunCancellationToken::is_cancelled)
18}
19
20pub async fn cancelled(token: Option<&RunCancellationToken>) {
21 if let Some(token) = token {
22 token.cancelled().await;
23 } else {
24 pending::<()>().await;
25 }
26}
27
28pub async fn await_or_cancel<T, F>(token: Option<&RunCancellationToken>, fut: F) -> CancelAware<T>
29where
30 F: std::future::Future<Output = T>,
31{
32 if let Some(token) = token {
33 tokio::select! {
34 _ = token.cancelled() => CancelAware::Cancelled,
35 value = fut => CancelAware::Value(value),
36 }
37 } else {
38 CancelAware::Value(fut.await)
39 }
40}
41
42#[derive(Debug, Clone, Error)]
44#[error("{message}")]
45pub struct StateCommitError {
46 pub message: String,
47}
48
49impl StateCommitError {
50 pub fn new(message: impl Into<String>) -> Self {
51 Self {
52 message: message.into(),
53 }
54 }
55}
56
57#[async_trait]
59pub trait StateCommitter: Send + Sync {
60 async fn commit(
64 &self,
65 thread_id: &str,
66 changeset: ThreadChangeSet,
67 precondition: VersionPrecondition,
68 ) -> Result<u64, StateCommitError>;
69}
70
71impl std::fmt::Debug for dyn StateCommitter {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 f.write_str("<StateCommitter>")
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{timeout, Duration};

    /// Without a token the future's output is returned wrapped in `Value`.
    #[tokio::test]
    async fn await_or_cancel_returns_value_without_token() {
        let outcome = await_or_cancel(None, async { 42usize }).await;
        assert_eq!(outcome, CancelAware::Value(42));
    }

    /// A long-running future is abandoned as soon as the token is cancelled.
    #[tokio::test]
    async fn await_or_cancel_returns_cancelled_when_token_cancelled() {
        let token = RunCancellationToken::new();
        let worker = {
            let child = token.clone();
            tokio::spawn(async move {
                // Far longer than the timeout below; only cancellation can
                // make this finish in time.
                let slow = async {
                    tokio::time::sleep(Duration::from_secs(5)).await;
                    7usize
                };
                await_or_cancel(Some(&child), slow).await
            })
        };

        token.cancel();
        let joined = timeout(Duration::from_millis(300), worker)
            .await
            .expect("await_or_cancel should resolve quickly after cancellation");
        let outcome = joined.expect("task should not panic");
        assert_eq!(outcome, CancelAware::Cancelled);
    }

    /// `cancelled` completes once the token it watches is cancelled.
    #[tokio::test]
    async fn cancelled_waits_for_token_signal() {
        let token = RunCancellationToken::new();
        let worker = {
            let child = token.clone();
            tokio::spawn(async move {
                cancelled(Some(&child)).await;
                true
            })
        };

        token.cancel();
        let joined = timeout(Duration::from_millis(300), worker)
            .await
            .expect("cancelled() should return after token cancellation");
        let finished = joined.expect("task should not panic");
        assert!(finished);
    }
}
125
/// Key string `__agent_tool_caller_thread_id`.
///
/// NOTE(review): presumably the tool-scope entry holding the caller's thread
/// id; confirm against the code that reads it.
pub const TOOL_SCOPE_CALLER_THREAD_ID_KEY: &str = "__agent_tool_caller_thread_id";
/// Key string `__agent_tool_caller_agent_id`.
///
/// NOTE(review): presumably the tool-scope entry holding the caller's agent
/// id; confirm against the code that reads it.
pub const TOOL_SCOPE_CALLER_AGENT_ID_KEY: &str = "__agent_tool_caller_agent_id";
/// Key string `__agent_tool_caller_state`.
///
/// NOTE(review): presumably the tool-scope entry holding the caller's state;
/// confirm against the code that reads it.
pub const TOOL_SCOPE_CALLER_STATE_KEY: &str = "__agent_tool_caller_state";
/// Key string `__agent_tool_caller_messages`.
///
/// NOTE(review): presumably the tool-scope entry holding the caller's
/// messages; confirm against the code that reads it.
pub const TOOL_SCOPE_CALLER_MESSAGES_KEY: &str = "__agent_tool_caller_messages";