1use std::sync::Arc;
12
13use crate::task::TaskMetadata;
14
15#[derive(Clone, Debug)]
21pub struct TaskExecutionContext {
22 pub workflow_id: Arc<str>,
24 pub instance_id: Arc<str>,
26 pub task_id: Arc<str>,
28 pub metadata: TaskMetadata,
30 pub workflow_metadata_json: Option<Arc<str>>,
32}
33
/// Workflow-wide state handed to the tasks of a workflow.
///
/// Every field is reference-counted, so clones share the same codec and
/// metadata allocations.
pub struct WorkflowContext<C, M> {
    /// Identifier of the workflow.
    pub workflow_id: Arc<str>,
    /// Shared codec instance.
    pub codec: Arc<C>,
    /// Shared, typed workflow metadata.
    pub metadata: Arc<M>,
    /// Optional serialized form of the metadata (the name suggests JSON).
    pub metadata_json: Option<Arc<str>>,
}
48
49impl<C, M> Clone for WorkflowContext<C, M> {
50 fn clone(&self) -> Self {
51 Self {
52 workflow_id: Arc::clone(&self.workflow_id),
53 codec: Arc::clone(&self.codec),
54 metadata: Arc::clone(&self.metadata),
55 metadata_json: self.metadata_json.clone(),
56 }
57 }
58}
59
60impl<C, M> WorkflowContext<C, M> {
61 pub fn new(workflow_id: impl Into<Arc<str>>, codec: Arc<C>, metadata: Arc<M>) -> Self {
63 Self {
64 workflow_id: workflow_id.into(),
65 codec,
66 metadata,
67 metadata_json: None,
68 }
69 }
70
71 #[must_use]
73 pub fn workflow_id(&self) -> &str {
74 &self.workflow_id
75 }
76
77 #[must_use]
79 pub fn codec(&self) -> Arc<C> {
80 self.codec.clone()
81 }
82
83 #[must_use]
85 pub fn metadata(&self) -> Arc<M> {
86 self.metadata.clone()
87 }
88}
89
90use std::cell::RefCell;
91
92std::thread_local! {
93 static THREAD_LOCAL_TASK_CTX: RefCell<Option<TaskExecutionContext>> = const { RefCell::new(None) };
98}
99
100pub fn with_thread_local_task_context<R>(ctx: TaskExecutionContext, f: impl FnOnce() -> R) -> R {
103 THREAD_LOCAL_TASK_CTX.with(|cell| {
104 let prev = cell.borrow_mut().replace(ctx);
105 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
106 *cell.borrow_mut() = prev;
107 match result {
108 Ok(r) => r,
109 Err(e) => std::panic::resume_unwind(e),
110 }
111 })
112}
113
114#[must_use]
116pub fn get_thread_local_task_context() -> Option<TaskExecutionContext> {
117 THREAD_LOCAL_TASK_CTX.with(|cell| cell.borrow().clone())
118}
119
#[cfg(feature = "tokio")]
mod task_local_ctx {
    //! Tokio task-local storage for the task execution context, with a
    //! fallback to the thread-local slot.

    use super::TaskExecutionContext;

    tokio::task_local! {
        // Set for the duration of a `with_task_context` future.
        static TASK_EXEC_CTX: Option<TaskExecutionContext>;
    }

    /// Awaits `fut` with `ctx` installed in Tokio task-local storage (via
    /// `LocalKey::scope`), so `get_task_context` called from inside `fut`
    /// sees it.
    pub async fn with_task_context<F: std::future::Future>(
        ctx: TaskExecutionContext,
        fut: F,
    ) -> F::Output {
        let scoped = TASK_EXEC_CTX.scope(Some(ctx), fut);
        scoped.await
    }

    /// Returns the current task execution context.
    ///
    /// The Tokio task-local set by `with_task_context` wins. If it is not
    /// accessible or holds `None`, this falls back to
    /// `super::get_thread_local_task_context`.
    #[must_use]
    pub fn get_task_context() -> Option<TaskExecutionContext> {
        if let Ok(Some(ctx)) = TASK_EXEC_CTX.try_with(Clone::clone) {
            return Some(ctx);
        }
        super::get_thread_local_task_context()
    }
}
151
152#[cfg(feature = "tokio")]
153pub use task_local_ctx::{get_task_context, with_task_context};
154
155#[cfg(not(feature = "tokio"))]
159#[must_use]
160pub fn get_task_context() -> Option<TaskExecutionContext> {
161 get_thread_local_task_context()
162}
163
#[macro_export]
/// Expands to `$crate::context::get_task_context()`, yielding the current
/// `Option<TaskExecutionContext>`.
macro_rules! task_context {
    () => (
        $crate::context::get_task_context()
    );
}
181
// Only the tests that need a Tokio runtime are gated on the `tokio` feature;
// the thread-local and macro tests run in every configuration.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Builds the fixture context shared by the tests.
    fn make_task_ctx() -> TaskExecutionContext {
        TaskExecutionContext {
            workflow_id: Arc::from("wf-1"),
            instance_id: Arc::from("inst-1"),
            task_id: Arc::from("task-a"),
            metadata: TaskMetadata::default(),
            workflow_metadata_json: None,
        }
    }

    #[test]
    fn thread_local_roundtrip() {
        assert!(get_thread_local_task_context().is_none());

        let result = with_thread_local_task_context(make_task_ctx(), || {
            let inner = get_thread_local_task_context().unwrap();
            assert_eq!(&*inner.workflow_id, "wf-1");
            assert_eq!(&*inner.instance_id, "inst-1");
            assert_eq!(&*inner.task_id, "task-a");
            42
        });
        assert_eq!(result, 42);

        // Leaving the scope must clear the slot again.
        assert!(get_thread_local_task_context().is_none());
    }

    #[test]
    fn thread_local_restores_on_panic() {
        let ctx = make_task_ctx();
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            with_thread_local_task_context(ctx, || {
                panic!("boom");
            })
        }));
        assert!(result.is_err());
        assert!(get_thread_local_task_context().is_none());
    }

    #[test]
    fn thread_local_nested_scopes_restore_outer() {
        let outer = make_task_ctx();
        let mut inner = make_task_ctx();
        inner.task_id = Arc::from("task-b");

        with_thread_local_task_context(outer, || {
            with_thread_local_task_context(inner, || {
                let c = get_thread_local_task_context().unwrap();
                assert_eq!(&*c.task_id, "task-b");
            });
            // The outer context must be visible again after the inner scope.
            let c = get_thread_local_task_context().unwrap();
            assert_eq!(&*c.task_id, "task-a");
        });
        assert!(get_thread_local_task_context().is_none());
    }

    #[cfg(feature = "tokio")]
    #[test]
    fn task_local_roundtrip() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            assert!(get_task_context().is_none());

            let inner = with_task_context(make_task_ctx(), async {
                let c = get_task_context().unwrap();
                assert_eq!(&*c.task_id, "task-a");
                c
            })
            .await;

            assert_eq!(&*inner.workflow_id, "wf-1");
        });
    }

    #[cfg(feature = "tokio")]
    #[test]
    fn task_local_falls_back_to_thread_local() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            // No task-local scope is active, so the lookup must come from the
            // thread-local slot.
            let result = with_thread_local_task_context(make_task_ctx(), get_task_context);
            let ctx = result.unwrap();
            assert_eq!(&*ctx.instance_id, "inst-1");
        });
    }

    #[test]
    fn macro_works() {
        with_thread_local_task_context(make_task_ctx(), || {
            let c = task_context!().unwrap();
            assert_eq!(&*c.task_id, "task-a");
        });
    }
}